[mlir][sparse] build tensor expressions through the Merger utility

This patch turns the SSA-to-tensor-expression builder into a Merger member.

In Merger.h: include LinalgOps.h, move the buildLattices declaration next to
the new public buildTensorExpFromLinalg(linalg::GenericOp), and declare the
private recursive helper buildTensorExp(op, val).

In Sparsification.cpp: obtain the root expression from
merger.buildTensorExpFromLinalg(op) instead of fetching the region
terminator at the call site.

In Merger.cpp: add "Lattice methods." / "Builder methods." section banners,
move the buildLattices definition into the builder section, define the two
builder methods, read tensorExps[e] directly in the straightforward-copy
check, and make dumpExp print "synthetic_" / "output_" prefixes.

NOTE(review): Sparsification.cpp gets a single hunk whose old and new start
lines are equal (1224), so nothing above the call site is removed. If the
file-static buildTensorExp(merger, op, val) helper used by the old call site
is still defined there, it is now unreferenced (-Wunused-function) and
should be deleted as part of this patch; confirm.
NOTE(review): Merger.h now depends on the Linalg dialect; confirm that the
SparseTensor Utils library links against it (no build-file change is part
of this patch).

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
 #define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
 
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/IR/Value.h"
 #include "llvm/ADT/BitVector.h"
 
@@ -148,11 +149,6 @@
   /// Returns true if any set bit corresponds to queried dim.
   bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const;
 
-  /// Builds the iteration lattices in a bottom-up traversal given the remaining
-  /// tensor (sub)expression and the next loop index in the iteration graph.
-  /// Returns index of the root expression.
-  unsigned buildLattices(unsigned exp, unsigned idx);
-
   /// Setter
   void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
 
@@ -169,7 +165,19 @@
   void dumpBits(const llvm::BitVector &bits) const;
 #endif
 
+  /// Builds the iteration lattices in a bottom-up traversal given the remaining
+  /// tensor (sub)expression and the next loop index in the iteration graph.
+  /// Returns index of the root expression.
+  unsigned buildLattices(unsigned exp, unsigned idx);
+
+  /// Builds a tensor expression from the given Linalg operation.
+  /// Returns index of the root expression on success.
+  Optional<unsigned> buildTensorExpFromLinalg(linalg::GenericOp op);
+
 private:
+  /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
+  Optional<unsigned> buildTensorExp(linalg::GenericOp op, Value val);
+
   const unsigned outTensor;
   const unsigned syntheticTensor;
   const unsigned numTensors;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1224,14 +1224,12 @@
         !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
       return failure();
 
-    // Finds the terminating yield statement and builds the tensor
-    // expression for the Linalg operation in SSA form.
-    Operation *yield = op.region().front().getTerminator();
-    Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0));
+    // Builds the tensor expression for the Linalg operation in SSA form.
+    Optional<unsigned> exp = merger.buildTensorExpFromLinalg(op);
     if (!exp.hasValue())
-      return failure(); // build failure
+      return failure();
 
-    // Reject an inadmissable tensor expression.
+    // Rejects an inadmissable tensor expression.
     if (!isAdmissableTensorExp(merger, op, exp.getValue()))
       return failure();
 
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -14,6 +14,10 @@
 namespace mlir {
 namespace sparse_tensor {
 
+//
+// Lattice methods.
+//
+
 unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
   unsigned e = tensorExps.size();
   tensorExps.push_back(TensorExp(k, e0, e1, v));
@@ -68,7 +72,7 @@
     if (p0 != p1) {
       // Is this a straightforward copy?
       unsigned e = latPoints[p1].exp;
-      if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
+      if (tensorExps[e].kind == Kind::kTensor && tensorExps[e].e0 == outTensor)
         continue;
       // Conjunction already covered?
       for (unsigned p2 : latSets[s]) {
@@ -137,33 +141,6 @@
   return false;
 }
 
-unsigned Merger::buildLattices(unsigned e, unsigned idx) {
-  Kind kind = exp(e).kind;
-  if (kind == Kind::kTensor || kind == Kind::kInvariant) {
-    // Either the index is really used in the tensor expression, or it is
-    // set to the undefined index in that dimension. An invariant expression
-    // is set to a synthetic tensor with undefined indices only.
-    unsigned s = addSet();
-    unsigned t = kind == Kind::kTensor ? exp(e).e0 : syntheticTensor;
-    set(s).push_back(addLat(t, idx, e));
-    return s;
-  }
-  unsigned s0 = buildLattices(exp(e).e0, idx);
-  unsigned s1 = buildLattices(exp(e).e1, idx);
-  switch (kind) {
-  case Kind::kTensor:
-  case Kind::kInvariant:
-    llvm_unreachable("handled above");
-  case Kind::kMulF:
-  case Kind::kMulI:
-    return takeConj(kind, s0, s1);
-  case Kind::kAddF:
-  case Kind::kAddI:
-    return takeDisj(kind, s0, s1);
-  }
-  llvm_unreachable("unexpected expression kind");
-}
-
 #ifndef NDEBUG
 
 //
@@ -173,6 +150,10 @@
 void Merger::dumpExp(unsigned e) const {
   switch (tensorExps[e].kind) {
   case Kind::kTensor:
+    if (tensorExps[e].e0 == syntheticTensor)
+      llvm::dbgs() << "synthetic_";
+    else if (tensorExps[e].e0 == outTensor)
+      llvm::dbgs() << "output_";
     llvm::dbgs() << "tensor_" << tensorExps[e].e0;
     break;
   case Kind::kInvariant:
@@ -242,5 +223,82 @@
 
 #endif // NDEBUG
 
+//
+// Builder methods.
+//
+
+unsigned Merger::buildLattices(unsigned e, unsigned idx) {
+  Kind kind = exp(e).kind;
+  if (kind == Kind::kTensor || kind == Kind::kInvariant) {
+    // Either the index is really used in the tensor expression, or it is
+    // set to the undefined index in that dimension. An invariant expression
+    // is set to a synthetic tensor with undefined indices only.
+    unsigned s = addSet();
+    unsigned t = kind == Kind::kTensor ? exp(e).e0 : syntheticTensor;
+    set(s).push_back(addLat(t, idx, e));
+    return s;
+  }
+  unsigned s0 = buildLattices(exp(e).e0, idx);
+  unsigned s1 = buildLattices(exp(e).e1, idx);
+  switch (kind) {
+  case Kind::kTensor:
+  case Kind::kInvariant:
+    llvm_unreachable("handled above");
+  case Kind::kMulF:
+  case Kind::kMulI:
+    return takeConj(kind, s0, s1);
+  case Kind::kAddF:
+  case Kind::kAddI:
+    return takeDisj(kind, s0, s1);
+  }
+  llvm_unreachable("unexpected expression kind");
+}
+
+Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
+  Operation *yield = op.region().front().getTerminator();
+  return buildTensorExp(op, yield->getOperand(0));
+}
+
+Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value val) {
+  if (auto arg = val.dyn_cast<BlockArgument>()) {
+    unsigned argN = arg.getArgNumber();
+    // Any argument of the generic op that is not marked as a scalar
+    // argument is considered a tensor, indexed by the implicit loop
+    // bounds. This includes rank-0 tensor arguments.
+    if (arg.getOwner()->getParentOp() == op) {
+      OpOperand *t = op.getInputAndOutputOperands()[argN];
+      if (!op.isScalar(t))
+        return addExp(Kind::kTensor, argN);
+      val = t->get(); // get scalar value
+    }
+    // Any other argument (marked as scalar argument for the generic op
+    // or belonging to an enveloping op) is considered invariant.
+    return addExp(Kind::kInvariant, val);
+  }
+  Operation *def = val.getDefiningOp();
+  if (def->getBlock() != &op.region().front()) {
+    // Something defined outside is invariant.
+    return addExp(Kind::kInvariant, val);
+  } else if (def->getNumOperands() == 2) {
+    // Construct binary operations if subexpressions could be built.
+    auto x = buildTensorExp(op, def->getOperand(0));
+    auto y = buildTensorExp(op, def->getOperand(1));
+    if (x.hasValue() && y.hasValue()) {
+      unsigned e0 = x.getValue();
+      unsigned e1 = y.getValue();
+      if (isa<MulFOp>(def))
+        return addExp(Kind::kMulF, e0, e1);
+      if (isa<MulIOp>(def))
+        return addExp(Kind::kMulI, e0, e1);
+      if (isa<AddFOp>(def))
+        return addExp(Kind::kAddF, e0, e1);
+      if (isa<AddIOp>(def))
+        return addExp(Kind::kAddI, e0, e1);
+    }
+  }
+  // Cannot build.
+  return None;
+}
+
 } // namespace sparse_tensor
 } // namespace mlir