diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -25,24 +25,42 @@
 /// Dimension level type for a tensor (undef means index does not appear).
 enum class Dim { kSparse, kDense, kSingle, kUndef };
 
+/// Parameters of a TensorExp.
+union Params {
+  /// Expressions representing tensors simply have a tensor number.
+  unsigned tensor_num;
+
+  /// Binary operations hold the indices of their child expression(s).
+  struct {
+    unsigned e0;
+    unsigned e1;
+  } children;
+};
+
 /// Tensor expression. Represents a MLIR expression in tensor index notation.
-/// For tensors, e0 denotes the tensor index. For invariants, the IR value is
-/// stored directly. For binary operations, e0 and e1 denote the index of the
-/// children tensor expressions.
+/// For tensors, params.tensor_num denotes the tensor index. For invariants,
+/// the IR value is stored directly. For binary operations, params.children.e0
+/// and params.children.e1 denote the indices of the children tensor
+/// expressions.
 struct TensorExp {
   TensorExp(Kind k, unsigned x, unsigned y, Value v)
-      : kind(k), e0(x), e1(y), val(v) {
-    assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
-           (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
-           (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
+      : kind(k), params{}, val(v) {
+    assert((kind == Kind::kTensor && x != -1u && y == -1u && !val) ||
+           (kind == Kind::kInvariant && x == -1u && y == -1u && val) ||
+           (kind >= Kind::kMulF && x != -1u && y != -1u && !val));
+    if (kind == Kind::kTensor) {
+      params.tensor_num = x;
+    } else if (kind >= Kind::kMulF) {
+      params.children.e0 = x;
+      params.children.e1 = y;
+    }
   }
 
   /// Tensor expression kind.
   Kind kind;
 
-  /// Indices of children expression(s).
-  unsigned e0;
-  unsigned e1;
+  /// Kind-dependent parameters: tensor number or child expression indices.
+  Params params;
 
   /// Direct link to IR for an invariant. During code generation,
   /// field is used to cache "hoisted" loop invariant tensor loads.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -259,11 +259,11 @@
 static unsigned isConjunction(Merger &merger, unsigned tensor, unsigned exp) {
   switch (merger.exp(exp).kind) {
   case Kind::kTensor:
-    return merger.exp(exp).e0 == tensor;
+    return merger.exp(exp).params.tensor_num == tensor;
   case Kind::kMulF:
   case Kind::kMulI:
-    return isConjunction(merger, tensor, merger.exp(exp).e0) ||
-           isConjunction(merger, tensor, merger.exp(exp).e1);
+    return isConjunction(merger, tensor, merger.exp(exp).params.children.e0) ||
+           isConjunction(merger, tensor, merger.exp(exp).params.children.e1);
   default:
     return false;
   }
@@ -500,7 +500,8 @@
   }
   // Actual load.
   SmallVector<Value, 4> args;
-  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0];
+  OpOperand *t =
+      op.getInputAndOutputOperands()[merger.exp(exp).params.tensor_num];
   unsigned tensor = t->getOperandNumber();
   auto map = op.getTiedIndexingMap(t);
   auto enc = getSparseTensorEncoding(t->get().getType());
@@ -673,8 +674,10 @@
     return genTensorLoad(merger, codegen, rewriter, op, exp);
   else if (merger.exp(exp).kind == Kind::kInvariant)
     return genInvariantValue(merger, codegen, rewriter, exp);
-  Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0);
-  Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1);
+  Value v0 =
+      genExp(merger, codegen, rewriter, op, merger.exp(exp).params.children.e0);
+  Value v1 =
+      genExp(merger, codegen, rewriter, op, merger.exp(exp).params.children.e1);
   switch (merger.exp(exp).kind) {
   case Kind::kTensor:
   case Kind::kInvariant:
@@ -698,7 +701,8 @@
   if (merger.exp(exp).kind == Kind::kTensor) {
     // Inspect tensor indices.
    bool atLevel = ldx == -1u;
-    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0];
+    OpOperand *t =
+        op.getInputAndOutputOperands()[merger.exp(exp).params.tensor_num];
     auto map = op.getTiedIndexingMap(t);
     auto enc = getSparseTensorEncoding(t->get().getType());
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
@@ -720,8 +724,8 @@
     // Traverse into the binary operations. Note that we only hoist
     // tensor loads, since subsequent MLIR/LLVM passes know how to
     // deal with all other kinds of derived loop invariants.
-    unsigned e0 = merger.exp(exp).e0;
-    unsigned e1 = merger.exp(exp).e1;
+    unsigned e0 = merger.exp(exp).params.children.e0;
+    unsigned e1 = merger.exp(exp).params.children.e1;
     genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist);
     genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist);
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -68,7 +68,7 @@
     if (p0 != p1) {
       // Is this a straightforward copy?
       unsigned e = latPoints[p1].exp;
-      if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
+      if (exp(e).kind == Kind::kTensor && exp(e).params.tensor_num == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
@@ -144,12 +144,13 @@
     // set to the undefined index in that dimension. An invariant expression
     // is set to a synthetic tensor with undefined indices only.
     unsigned s = addSet();
-    unsigned t = kind == Kind::kTensor ? exp(e).e0 : syntheticTensor;
+    unsigned t =
+        kind == Kind::kTensor ? exp(e).params.tensor_num : syntheticTensor;
     set(s).push_back(addLat(t, idx, e));
     return s;
   }
-  unsigned s0 = buildLattices(exp(e).e0, idx);
-  unsigned s1 = buildLattices(exp(e).e1, idx);
+  unsigned s0 = buildLattices(exp(e).params.children.e0, idx);
+  unsigned s1 = buildLattices(exp(e).params.children.e1, idx);
   switch (kind) {
   case Kind::kTensor:
   case Kind::kInvariant:
@@ -173,7 +174,7 @@
 void Merger::dumpExp(unsigned e) const {
   switch (tensorExps[e].kind) {
   case Kind::kTensor:
-    llvm::dbgs() << "tensor_" << tensorExps[e].e0;
+    llvm::dbgs() << "tensor_" << tensorExps[e].params.tensor_num;
     break;
   case Kind::kInvariant:
     llvm::dbgs() << "invariant";
@@ -181,17 +182,17 @@
   default:
   case Kind::kMulI:
     llvm::dbgs() << "(";
-    dumpExp(tensorExps[e].e0);
+    dumpExp(tensorExps[e].params.children.e0);
     llvm::dbgs() << " * ";
-    dumpExp(tensorExps[e].e1);
+    dumpExp(tensorExps[e].params.children.e1);
     llvm::dbgs() << ")";
     break;
   case Kind::kAddF:
   case Kind::kAddI:
     llvm::dbgs() << "(";
-    dumpExp(tensorExps[e].e0);
+    dumpExp(tensorExps[e].params.children.e0);
     llvm::dbgs() << " + ";
-    dumpExp(tensorExps[e].e1);
+    dumpExp(tensorExps[e].params.children.e1);
     llvm::dbgs() << ")";
     break;
   }