diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -31,19 +31,31 @@ /// stored directly. For binary operations, e0 and e1 denote the index of the /// children tensor expressions. struct TensorExp { - TensorExp(Kind k, unsigned x, unsigned y, Value v) - : kind(k), e0(x), e1(y), val(v) { - assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) || - (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) || - (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val)); + TensorExp(Kind k, unsigned x, unsigned y, Value v) : kind(k), val(v) { + assert((kind == Kind::kTensor && x != -1u && y == -1u && !val) || + (kind == Kind::kInvariant && x == -1u && y == -1u && val) || + (kind >= Kind::kMulF && x != -1u && y != -1u && !val)); + if (kind == Kind::kTensor) { + tensor = x; + } else if (kind >= Kind::kMulF) { + e0 = x; + e1 = y; + } } /// Tensor expression kind. Kind kind; - /// Indices of children expression(s). - unsigned e0; - unsigned e1; + union { + /// Expressions representing tensors simply have a tensor number. + unsigned tensor; + + /// Binary operations hold the indices of their child expressions. + struct { + unsigned e0; + unsigned e1; + }; + }; /// Direct link to IR for an invariant. During code generation, /// field is used to cache "hoisted" loop invariant tensor loads. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -214,7 +214,7 @@ static unsigned isConjunction(Merger &merger, unsigned tensor, unsigned exp) { switch (merger.exp(exp).kind) { case Kind::kTensor: - return merger.exp(exp).e0 == tensor; + return merger.exp(exp).tensor == tensor; case Kind::kMulF: case Kind::kMulI: return isConjunction(merger, tensor, merger.exp(exp).e0) || @@ -455,7 +455,7 @@ } // Actual load. SmallVector args; - OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0]; + OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; unsigned tensor = t->getOperandNumber(); auto map = op.getTiedIndexingMap(t); auto enc = getSparseTensorEncoding(t->get().getType()); @@ -653,7 +653,7 @@ if (merger.exp(exp).kind == Kind::kTensor) { // Inspect tensor indices. bool atLevel = ldx == -1u; - OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0]; + OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; auto map = op.getTiedIndexingMap(t); auto enc = getSparseTensorEncoding(t->get().getType()); for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -72,7 +72,8 @@ if (p0 != p1) { // Is this a straightforward copy? unsigned e = latPoints[p1].exp; - if (tensorExps[e].kind == Kind::kTensor && tensorExps[e].e0 == outTensor) + if (tensorExps[e].kind == Kind::kTensor && + tensorExps[e].tensor == outTensor) continue; // Conjunction already covered? for (unsigned p2 : latSets[s]) { @@ -150,11 +151,11 @@ void Merger::dumpExp(unsigned e) const { switch (tensorExps[e].kind) { case Kind::kTensor: - if (tensorExps[e].e0 == syntheticTensor) + if (tensorExps[e].tensor == syntheticTensor) llvm::dbgs() << "synthetic_"; - else if (tensorExps[e].e0 == outTensor) + else if (tensorExps[e].tensor == outTensor) llvm::dbgs() << "output_"; - llvm::dbgs() << "tensor_" << tensorExps[e].e0; + llvm::dbgs() << "tensor_" << tensorExps[e].tensor; break; case Kind::kInvariant: llvm::dbgs() << "invariant";