diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -138,6 +138,11 @@ /// Returns true if any set bit corresponds to queried dim. bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const; + /// Builds the iteration lattices in a bottom-up traversal given the remaining + /// tensor (sub)expression and the next loop index in the iteration graph. + /// Returns the index of the lattice set built for the root expression. + unsigned buildLattices(unsigned exp, unsigned idx); + /// Setter void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -302,37 +302,6 @@ return false; } -/// Builds the iteration lattices in a bottom-up traversal given the remaining -/// tensor (sub)expression and the next loop index in the iteration graph. -static unsigned buildLattices(Merger &merger, linalg::GenericOp op, - unsigned exp, unsigned idx) { - Kind kind = merger.exp(exp).kind; - if (kind == Kind::kTensor || kind == Kind::kInvariant) { - // Either the index is really used in the tensor expression, or it is - // set to the undefined index in that dimension. An invariant expression - // is set to a synthetic tensor with undefined indices only. - unsigned s = merger.addSet(); - unsigned t = kind == Kind::kTensor ?
merger.exp(exp).e0 - : op.getNumInputsAndOutputs(); - merger.set(s).push_back(merger.addLat(t, idx, exp)); - return s; - } - unsigned s0 = buildLattices(merger, op, merger.exp(exp).e0, idx); - unsigned s1 = buildLattices(merger, op, merger.exp(exp).e1, idx); - switch (kind) { - case Kind::kTensor: - case Kind::kInvariant: - llvm_unreachable("handled above"); - case Kind::kMulF: - case Kind::kMulI: - return merger.takeConj(kind, s0, s1); - case Kind::kAddF: - case Kind::kAddI: - return merger.takeDisj(kind, s0, s1); - } - llvm_unreachable("unexpected expression kind"); -} - /// Maps sparse integer option to actual integral storage type. static Type genIntType(PatternRewriter &rewriter, unsigned width) { if (width == 0) @@ -1121,7 +1090,7 @@ // in play for a non-singleton loop sequence. Location loc = op.getLoc(); unsigned idx = topSort[at]; - unsigned lts = merger.optimizeSet(buildLattices(merger, op, exp, idx)); + unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx)); unsigned lsize = merger.set(lts).size(); assert(lsize != 0); unsigned l0 = merger.set(lts)[0]; diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -134,5 +134,32 @@ return false; } +unsigned Merger::buildLattices(unsigned e, unsigned idx) { + Kind kind = exp(e).kind; + if (kind == Kind::kTensor || kind == Kind::kInvariant) { + // Either the index is really used in the tensor expression, or it is + // set to the undefined index in that dimension. An invariant expression + // is set to the synthetic tensor (numTensors - 1) with undefined indices. + unsigned s = addSet(); + unsigned t = kind == Kind::kTensor ?
exp(e).e0 : numTensors - 1; + set(s).push_back(addLat(t, idx, e)); + return s; + } + unsigned s0 = buildLattices(exp(e).e0, idx); + unsigned s1 = buildLattices(exp(e).e1, idx); + switch (kind) { + case Kind::kTensor: + case Kind::kInvariant: + llvm_unreachable("handled above"); + case Kind::kMulF: + case Kind::kMulI: + return takeConj(kind, s0, s1); + case Kind::kAddF: + case Kind::kAddI: + return takeDisj(kind, s0, s1); + } + llvm_unreachable("unexpected expression kind"); +} + } // namespace sparse_tensor } // namespace mlir