diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -30,19 +30,36 @@
 /// stored directly. For binary operations, e0 and e1 denote the index of the
 /// children tensor expressions.
 struct TensorExp {
-  TensorExp(Kind k, unsigned x, unsigned y, Value v)
-      : kind(k), e0(x), e1(y), val(v) {
-    assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
-           (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
-           (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
+  TensorExp(Kind k, unsigned x, unsigned y, Value v) : kind(k), e1(y), val(v) {
+    assert((kind == Kind::kTensor && x != -1u && y == -1u && !val) ||
+           (kind == Kind::kInvariant && x == -1u && y == -1u && val) ||
+           (kind >= Kind::kMulF && x != -1u && y != -1u && !val));
+    // Write exactly one union member, chosen by the kind. Invariants store
+    // x (== -1u) in e0, so no field is left indeterminate.
+    if (kind == Kind::kTensor)
+      tensor = x;
+    else
+      e0 = x;
   }
 
   /// Tensor expression kind.
   Kind kind;
 
-  /// Indices of children expression(s).
-  unsigned e0;
-  unsigned e1;
+  /// Parameters of a TensorExp; only the member matching `kind` may be read.
+  /// The union holds plain scalars only: an anonymous struct nested in it is
+  /// a compiler extension that ISO C++ does not allow (-pedantic warns), and
+  /// a named struct would break the existing `exp.e0` accesses.
+  union {
+    /// Expressions representing tensors simply have a tensor number.
+    unsigned tensor;
+
+    /// Binary operations hold the index of their first child expression.
+    unsigned e0;
+  };
+
+  /// Binary operations hold the index of their second child expression;
+  /// the constructor's assertion requires -1u for every other kind.
+  unsigned e1;
 
   /// Direct link to IR for an invariant. During code generation,
   /// field is used to cache "hoisted" loop invariant tensor loads.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -259,7 +259,7 @@
 static unsigned isConjunction(Merger &merger, unsigned tensor, unsigned exp) {
   switch (merger.exp(exp).kind) {
   case Kind::kTensor:
-    return merger.exp(exp).e0 == tensor;
+    return merger.exp(exp).tensor == tensor;
   case Kind::kMulF:
   case Kind::kMulI:
     return isConjunction(merger, tensor, merger.exp(exp).e0) ||
@@ -500,7 +500,7 @@
   }
   // Actual load.
   SmallVector<Value, 4> args;
-  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0];
+  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
   unsigned tensor = t->getOperandNumber();
   auto map = op.getTiedIndexingMap(t);
   auto enc = getSparseTensorEncoding(t->get().getType());
@@ -698,7 +698,7 @@
   if (merger.exp(exp).kind == Kind::kTensor) {
     // Inspect tensor indices.
     bool atLevel = ldx == -1u;
-    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0];
+    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
     auto map = op.getTiedIndexingMap(t);
     auto enc = getSparseTensorEncoding(t->get().getType());
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -68,7 +68,7 @@
     if (p0 != p1) {
       // Is this a straightforward copy?
       unsigned e = latPoints[p1].exp;
-      if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
+      if (exp(e).kind == Kind::kTensor && exp(e).tensor == outTensor)
         continue;
       // Conjunction already covered?
       for (unsigned p2 : latSets[s]) {
@@ -144,7 +144,7 @@
     // set to the undefined index in that dimension. An invariant expression
     // is set to a synthetic tensor with undefined indices only.
     unsigned s = addSet();
-    unsigned t = kind == Kind::kTensor ? exp(e).e0 : syntheticTensor;
+    unsigned t = kind == Kind::kTensor ? exp(e).tensor : syntheticTensor;
     set(s).push_back(addLat(t, idx, e));
     return s;
   }
@@ -173,7 +173,7 @@
 void Merger::dumpExp(unsigned e) const {
   switch (tensorExps[e].kind) {
   case Kind::kTensor:
-    llvm::dbgs() << "tensor_" << tensorExps[e].e0;
+    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
     break;
   case Kind::kInvariant:
     llvm::dbgs() << "invariant";