diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -448,6 +448,27 @@ /// Sets whether the output tensor is sparse or not. void setHasSparseOut(bool s) { hasSparseOut = s; } + // FIXME: Dim or Level? + void setLoopDependentSliceDim(unsigned l, unsigned t, unsigned dim) { + assert(l < numLoops && dim < numLoops); + ldxToDependentSlice[t][l] = dim; + sliceToRelatedldx[t][dim].push_back(l); + } + + // Returns true if loop `l` depends on a slice of any tensor. + bool hasDependentSliceDim(unsigned l) const { + return llvm::any_of(ldxToDependentSlice, + [l](const auto &m) { + return m[l].has_value(); + }); + } + + // Returns the list of loops that appear in the affine indexing expression + // on t[d]. + std::vector &getRelatedLoops(unsigned t, unsigned d) { + return sliceToRelatedldx[t][d]; + } + /// Convenience getters to immediately access the stored nodes. /// Typically it is inadvisible to keep the reference around, as in /// `TensorExpr &te = merger.exp(e)`, since insertions into the merger @@ -511,6 +532,11 @@ // Map that converts pair to the corresponding LoopId. std::vector>> lvlToLoop; + // Map from loop idx to the dependent slice [tid, dim] pair (if any). + std::vector>> ldxToDependentSlice; + // Map from the dependent slice [tid, dim] pair to a list of loop idx.
+ std::vector>> sliceToRelatedldx; + llvm::SmallVector tensorExps; llvm::SmallVector latPoints; llvm::SmallVector> latSets; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h @@ -99,6 +99,7 @@ topSort.reserve(capacity); } + ArrayRef<LoopId> getTopSort() const { return topSort; } ArrayRef getTopSortSlice(LoopOrd n, LoopOrd m) const; ArrayRef getLoopStackUpTo(LoopOrd n) const; ArrayRef getCurrentLoopStack() const; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -109,6 +109,14 @@ SmallVector iterTypes; }; +/// A helper class that visits an affine expression and collects every +/// AffineDimExpr appearing in it into `dims`, in visitation order. Unlike +/// AffineDimFinder, it does not filter by iterator type. +struct AffineDimCollector : public AffineExprVisitor { + void visitDimExpr(AffineDimExpr expr) { dims.push_back(expr); } + SmallVector dims; +}; + } // namespace //===----------------------------------------------------------------------===// @@ -254,6 +262,66 @@ } } +/// Helper method to inspect affine expressions for slice-based codegen. +/// Rejects cases where the same index is used more than once. Only a plain +/// (non-compound) index assigns a level type to its loop. For an index inside +/// a compound expression, when `needSlice` is set, the loop is instead +/// recorded as depending on the slice of `tensor` at level `dim`; a loop may +/// depend on at most one slice. Mul and Constant expressions are rejected +/// when `needSlice` is set. Unlike `findAffine`, this never introduces +/// filter loops.
+static bool findSliceBasedAffine(Merger &merger, unsigned tensor, unsigned dim, + AffineExpr a, DimLevelType dlt, bool needSlice, + bool isSubExp = false) { + switch (a.getKind()) { + case AffineExprKind::DimId: { + unsigned idx = a.cast().getPosition(); + if (!isUndefDLT(merger.getDimLevelType(tensor, idx))) + return false; // used more than once + if (!isSubExp) + merger.setLevelAndType(tensor, idx, dim, dlt); + + if (isSubExp && needSlice) { + // The current loop appears in more than one affine expression. + // Now, for simplicity, we do not handle this case. + if (merger.hasDependentSliceDim(idx)) { + // TODO: This can be supported by co-iterating slices if the loop idx + // appears in the affine indices of different tensors, or by taking + // slices on multiple dimensions when it is on the same tensor. + // E.g., + // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0] + // d0_1 = getNextSliceOffset t0 along lvl0 + // d0_2 = getNextSliceOffset t1 along lvl0 + // if d0_1 == d0_2 then d0 = d0_1 = d0_2 + // else increase min(d0_1, d0_2). + return false; + } + merger.setLoopDependentSliceDim(idx, tensor, dim); + } + return true; + } + case AffineExprKind::Constant: + case AffineExprKind::Mul: + // TODO: Support Mul and Constant AffineExp for slice-based codegen + if (needSlice) + return false; + [[fallthrough]]; + case AffineExprKind::Add: { + auto binOp = a.dyn_cast<AffineBinaryOpExpr>(); + if (!binOp) // A constant is a leaf: nothing to recurse into. + return true; + // Level types are not set for loops in a compound expression like d0 + d1; + // recursing checks admissibility and records slice dependencies (if any). + return findSliceBasedAffine(merger, tensor, dim, binOp.getLHS(), dlt, + needSlice, true) && + findSliceBasedAffine(merger, tensor, dim, binOp.getRHS(), dlt, + needSlice, true); + } + default: + return false; + } +} + /// Get the total number of compound affine expressions in the /// `getMatchingIndexingMap` for the given tensor.
For the following inputs: /// @@ -262,7 +330,8 @@ /// /// Returns 1 (because the first level is compressed and its corresponding /// indexing-expression is `d0 + d1`) -static unsigned getNumCompoundAffineOnSparseLvls(AffineMap map, Value tensor) { +static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map, + Value tensor) { // The `tensor` is not guaranted to have `RankedTensorType`, therefore // we can't use `getRankedTensorType`/`getSparseTensorType` here. // However, we don't need to handle `StorageSpecifierType`, so we @@ -305,20 +374,20 @@ /// Get the total number of sparse levels with compound affine /// expressions, summed over all operands of the `GenericOp`. -static unsigned getNumCompoundAffineOnSparseLvls(linalg::GenericOp op) { +static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) { unsigned num = 0; for (OpOperand &t : op->getOpOperands()) - num += getNumCompoundAffineOnSparseLvls(op.getMatchingIndexingMap(&t), - t.get()); + num += getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(&t), + t.get()); return num; } -static bool hasCompoundAffineOnSparseOut(linalg::GenericOp op) { +static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) { OpOperand *out = op.getDpsInitOperand(0); if (getSparseTensorType(out->get()).isAllDense()) return false; - return getNumCompoundAffineOnSparseLvls(op.getMatchingIndexingMap(out), - out->get()); + return getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(out), + out->get()); } /// Helper method to inspect sparse encodings in the tensor types. @@ -326,7 +395,7 @@ /// Returns true if the sparse annotations and affine subscript /// expressions of all tensors are admissible. Returns false if /// no annotations are found or inadmissible constructs occur. -static bool findSparseAnnotations(CodegenEnv &env) { +static bool findSparseAnnotations(CodegenEnv &env, bool sliceBased) { bool annotated = false; // `filterLdx` may be mutated by `findAffine`.
LoopId filterLdx = env.merger().getStartingFilterLoopId(); @@ -335,17 +404,26 @@ const auto enc = getSparseTensorEncoding(t.get().getType()); if (enc) annotated = true; + const Level lvlRank = map.getNumResults(); assert(!enc || lvlRank == enc.getLvlRank()); assert(static_cast(env.op().getRank(&t)) == lvlRank); + + // A sparse tensor with a non-trivial index expression on any sparse level + // must be sliced (only consulted by slice-based codegen). + bool needSlice = + enc && getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) != 0; for (Level l = 0; l < lvlRank; l++) { const TensorId tid = t.getOperandNumber(); - // FIXME: `toOrigDim` is deprecated. - // FIXME: above we asserted that there are `lvlRank` many results, - // but this is assuming there are in fact `dimRank` many results instead. - const AffineExpr a = map.getResult(toOrigDim(enc, l)); - if (!findAffine(env.merger(), tid, l, a, enc.getLvlType(l), filterLdx)) - return false; // inadmissible affine expression + AffineExpr a = map.getResult(toOrigDim(enc, l)); + DimLevelType dlt = enc.getLvlType(l); + if (sliceBased) { + if (!findSliceBasedAffine(env.merger(), tid, l, a, dlt, needSlice)) + return false; // inadmissible affine expression + } else { + if (!findAffine(env.merger(), tid, l, a, dlt, filterLdx)) + return false; // inadmissible affine expression + } } } assert(filterLdx == env.merger().getNumLoops()); @@ -469,11 +547,11 @@ } } -static void tryLoosenAffineDenseConstraints(linalg::GenericOp op, - std::optional &fldx, - AffineExpr &fa, - std::optional &tldx, - AffineExpr &ta) { +static void tryLoosenAffineConstraints(linalg::GenericOp op, + std::optional &fldx, + AffineExpr &fa, + std::optional &tldx, + AffineExpr &ta) { // We use a heuristic here to only pick one dim expression from each // compound affine expression to establish the order between two dense // dimensions. @@ -494,7 +572,7 @@ } if (!ta.isa()) { // Heuristic: we prefer reduction loop for rhs to reduce the chance - // addint reduce < parallel ordering.
+ // adding reduce < parallel ordering. finder.setPickedIterType(utils::IteratorType::reduction); finder.walkPostOrder(ta); ta = finder.getDimExpr(); @@ -503,14 +581,183 @@ } } -/// Computes a topologically sorted iteration graph for the linalg operation. -/// Ensures all tensors are visited in natural coordinate order. This is -/// essential for sparse storage formats since these only support access -/// along fixed levels. Even for dense storage formats, however, the natural -/// coordinate order yields innermost unit-stride access with better spatial -/// locality. +/// Reorders `target` in place so that its loops follow the same relative +/// order as in `order` (e.g. the topological loop order). Every element of +/// `target` is expected to occur in `order`. +static void sortArrayBasedOnOrder(std::vector<unsigned> &target, + ArrayRef<unsigned> order) { + // A single pass over `order` avoids a sort comparator that rescans `order` + // on every comparison and asserts that it is never given equal elements. + std::vector<unsigned> sorted; + sorted.reserve(target.size()); + for (unsigned ldx : order) + if (std::find(target.begin(), target.end(), ldx) != target.end()) + sorted.push_back(ldx); + assert(sorted.size() == target.size() && "loop missing from the order"); + target = std::move(sorted); +} + +static void addFilterLoopBasedConstraints(CodegenEnv &env, OpOperand &t, + OpOperand *skip, SortMask mask, + std::vector> &adjM, + std::vector &inDegree) { + // Get map and encoding. + auto map = env.op().getMatchingIndexingMap(&t); + auto enc = getSparseTensorEncoding(t.get().getType()); + + // Each tensor expression and optional dimension ordering (row-major + // by default) puts an ordering constraint on the loop indices. For + // example, the tensor expression A_ijk forces the ordering i < j < k + // on the loop indices if no explicit dimension ordering is given.
+ for (unsigned l = 0, rank = map.getNumResults(); l < rank; l++) { + AffineExpr ta = map.getResult(toOrigDim(enc, l)); + std::optional tldx = + env.merger().getLoopId(t.getOperandNumber(), l); + // Filter loops should be constructed after all the dependent loops, + // i.e., d0 + d1 < filter_loop(d0 + d1) + if (tldx && env.merger().isFilterLoop(*tldx)) { + // Filter loops only guard compound expressions on non-dense levels. + assert(!ta.isa<AffineDimExpr>() && !isDenseDLT(enc.getLvlType(l))); + addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt, tldx); + // Now that the ordering of affine expression is captured by filter + // loop idx, we only need to ensure the affine ordering against filter + // loop. Thus, we reset the affine expression to nil here to mark it as + // resolved. + ta = AffineExpr(); + } + + // Skip tensor during cycle resolution, though order between filter loop + // and dependent loops needs to be guaranteed unconditionally. + if (&t == skip) + continue; + + if (l > 0) { + AffineExpr fa = map.getResult(toOrigDim(enc, l - 1)); + std::optional fldx = + env.merger().getLoopId(t.getOperandNumber(), l - 1); + + // Applying order constraints on every pair of dimExpr between two + // compound affine expressions can sometimes be too strict: + // E.g, for [dense, dense] -> (d0 + d1, d2 + d3). + // It is totally fine to have loop sequence d0->d2->d1->d3 instead of + // requiring d0 < d2, d1 < d2, d0 < d3, d1 < d3. + // The same loosening is applied with the slice-based algorithm, as it + // introduces no filter loops for affine indices on sparse levels. + // TODO: do we really need the condition? + if (!includesDense(mask)) + tryLoosenAffineConstraints(env.op(), fldx, fa, tldx, ta); + + // (d0 + d1) < (d2 + d3), or + // filter_loop_d-1 < (d2 + d3), or + // (d0 + d1) < filter_loop_d, or + // filter_loop_d-1 < filter_loop_d depending on whether fa/ta is reset + // above.
+ addAffineOrderings(adjM, inDegree, fa, ta, fldx, tldx); + } + } +} + +/// Adds the iteration-graph ordering constraints implied by the indexing map +/// of `t` when using slice-based codegen. Tensors whose sparse levels only +/// have trivial index expressions fall back to `addFilterLoopBasedConstraints`. +/// Otherwise, for every pair of consecutive levels, one loop of each level +/// expression is picked and ordered after the other loops of that expression. +static void addSliceBasedConstraints(CodegenEnv &env, OpOperand &t, + OpOperand *skip, SortMask mask, + std::vector> &adjM, + std::vector &inDegree) { + // Get map and encoding. + auto map = env.op().getMatchingIndexingMap(&t); + auto enc = getSparseTensorEncoding(t.get().getType()); + + // No special treatment for simple indices. + if (getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) == 0) + return addFilterLoopBasedConstraints(env, t, skip, mask, adjM, inDegree); + + // Skip tensor during cycle resolution. Unlike the filter-loop-based case, + // there is no ordering that must be kept unconditionally here. + if (&t == skip) + return; + + // Adds the edge `from -> to` (loop `from` is ordered before loop `to`) + // unless the edge is already present. + auto addEdge = [&adjM, &inDegree](unsigned from, unsigned to) { + if (!adjM[from][to]) { + adjM[from][to] = true; + inDegree[to]++; + } + }; + + AffineDimFinder finder(env.op()); + finder.setPickedIterType(utils::IteratorType::reduction); + // To compute the iteration graph for tensor[d0 + d1 + d3, d4 + d5 + d6], + // we require that there exist d_x \in {d0, d1, d3} and d_y \in {d4, d5, d6} + // such that d_x < d_y and {d0, d1, d3} - d_x < {d4, d5, d6} - d_y, where + // `a < b` means that loop `a` is ordered before loop `b`. + for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) { + AffineExpr fa = map.getResult(toOrigDim(enc, d - 1)); + AffineExpr ta = map.getResult(toOrigDim(enc, d)); + + // This is a heuristic: we pick an arbitrary reduction loop from lhs and + // rhs and use them as d_x and d_y.
+ finder.walkPostOrder(fa); + AffineDimExpr fexp = finder.getDimExpr(); + unsigned fldx = fexp.getPosition(); + + finder.walkPostOrder(ta); + AffineDimExpr texp = finder.getDimExpr(); + unsigned tldx = texp.getPosition(); + + // d_x < d_y + addEdge(fldx, tldx); + + AffineDimCollector fCollector; + fCollector.walkPostOrder(fa); + AffineDimCollector tCollector; + tCollector.walkPostOrder(ta); + + // Make sure d_x and d_y are ordered last within their own expressions. + for (auto fd : fCollector.dims) { + unsigned f = fd.getPosition(); + if (f != fldx) + addEdge(f, fldx); + } + for (auto td : tCollector.dims) { + unsigned tDim = td.getPosition(); + if (tDim != tldx) + addEdge(tDim, tldx); + } + // Since we only support affine addition, the order between two dim + // expressions does not really matter. + // {d0, d1, d3} - d_x < {d4, d5, d6} - d_y + // This is to ensure that the affine expressions are reduced in sparse + // tensor level ordering. + // TODO: this ordering could probably be loosened if we support + // out-of-order reduction. + // TODO: the evaluation order needs to be ensured to + // support affine multiplication. + for (auto fd : fCollector.dims) { + unsigned f = fd.getPosition(); + if (f == fldx) // skip d_x + continue; + + for (auto td : tCollector.dims) { + unsigned tDim = td.getPosition(); + if (tDim != tldx) // skip d_y + addEdge(f, tDim); + } + } + } +} + +/// Computes a topologically sorted iteration graph for the linalg operation. +/// Ensures all tensors are visited in natural coordinate order. This is +/// essential for sparse storage formats since these only support access +/// along fixed levels. Even for dense storage formats, however, the natural +/// coordinate order yields innermost unit-stride access with better spatial +/// locality.
static bool computeIterationGraph(CodegenEnv &env, SortMask mask, - OpOperand *skip = nullptr) { + OpOperand *skip, bool useSlice = false) { // Set up an n x n from/to adjacency matrix of the iteration graph // for the implicit loop indices i_0 .. i_n-1. const LoopId n = env.merger().getNumLoops(); @@ -522,7 +769,10 @@ // Get map and encoding. const auto map = env.op().getMatchingIndexingMap(&t); const auto enc = getSparseTensorEncoding(t.get().getType()); - assert(map.getNumDims() + getNumCompoundAffineOnSparseLvls(env.op()) == n); + // The slice-based algorithm creates no filter loops. + assert(map.getNumDims() + + (useSlice ? 0 : getNumNonTrivialIdxExpOnSparseLvls(env.op())) == + n); // Skips dense inputs/outputs when not requested. const bool isDenseInput = !enc && env.op().isDpsInput(&t); @@ -549,63 +799,12 @@ } } } - - // Each tensor expression and optional dimension ordering (row-major - // by default) puts an ordering constraint on the loop indices. For - // example, the tensor expresion A_ijk forces the ordering i < j < k - // on the loop indices if no explicit dimension ordering is given. - const Level lvlRank = map.getNumResults(); - assert(!enc || lvlRank == enc.getLvlRank()); - for (Level l = 0; l < lvlRank; l++) { - // FIXME: `toOrigDim` is deprecated. - // FIXME: above we asserted that there are `lvlRank` many results, - // but this is assuming there are in fact `dimRank` many results instead. - AffineExpr ta = map.getResult(toOrigDim(enc, l)); - std::optional tldx = - env.merger().getLoopId(t.getOperandNumber(), l); - - // Filter loops should be constructed after all the dependent loops, - // i.e., d0 + d1 < filter_loop(d0 + d1) - if (tldx && env.merger().isFilterLoop(*tldx)) { - assert(!ta.isa() && !isDenseDLT(enc.getLvlType(l))); - addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt, - tldx); - // Now that the ordering of affine expression is captured by filter - // loop idx, we only need to ensure the affine ordering against filter - // loop.
Thus, we reset the affine express to nil here to mark it as - // resolved. - ta = AffineExpr(); - } - - // Skip tensor during cycle resolution, though order between filter loop - // and dependent loops need to be guaranteed unconditionally. - if (&t == skip) - continue; - - if (l > 0) { - // FIXME: `toOrigDim` is deprecated. - // FIXME: above we asserted that there are `lvlRank` many results, - // but this is assuming there are in fact `dimRank` many results. - AffineExpr fa = map.getResult(toOrigDim(enc, l - 1)); - std::optional fldx = - env.merger().getLoopId(t.getOperandNumber(), l - 1); - - // Applying order constraints on every pair of dimExpr between two - // compound affine expressions can sometime too strict: - // E.g, for [dense, dense] -> (d0 + d1, d2 + d3). - // It is totally fine to have loop sequence d0->d2->d1->d3 instead of - // requiring d0 < d2, d1 < d2, d0 < d3, d1 < d3. - if (!includesDense(mask)) - tryLoosenAffineDenseConstraints(env.op(), fldx, fa, tldx, ta); - - // (d0 + d1) < (d2 + d3), or - // filter_loop_d-1 < (d2 + d3), or - // (d0 + d1) < filter_loop_d, or - // filter_loop_d-1 < filter_loop_d depending on whether fa/ta is reset - // above. - addAffineOrderings(adjM, inDegree, fa, ta, fldx, tldx); - } - } + // Add the loop ordering constraints implied by this tensor's indexing + // map, using the slice-based rules when requested. + if (useSlice) + addSliceBasedConstraints(env, t, skip, mask, adjM, inDegree); + else + addFilterLoopBasedConstraints(env, t, skip, mask, adjM, inDegree); } // Topologically sort the iteration graph to determine loop order. // Report failure for a cyclic iteration graph. @@ -1596,21 +1795,25 @@ PatternRewriter &rewriter) const override { // Only accept single output operations without affine index on sparse // output.
- if (op.getNumDpsInits() != 1 || hasCompoundAffineOnSparseOut(op)) + if (op.getNumDpsInits() != 1 || hasNonTrivialAffineOnSparseOut(op)) return failure(); - if (options.enableIndexReduction) - llvm_unreachable("not yet implemented"); - // Sets up a code generation environment. const unsigned numTensors = op->getNumOperands(); const unsigned numLoops = op.getNumLoops(); - const unsigned numFilterLoops = getNumCompoundAffineOnSparseLvls(op); - CodegenEnv env(op, options, numTensors, numLoops, numFilterLoops); + const unsigned numFilterLoops = getNumNonTrivialIdxExpOnSparseLvls(op); + // TODO: we should probably always use slice-based codegen whenever + // possible; we could even intermix slice-based and filter-loop-based codegen. + const bool sliceBased = options.enableIndexReduction && numFilterLoops != 0; + + // When the slice-based algorithm is used for affine indices, no filter + // loops are needed. + CodegenEnv env(op, options, numTensors, numLoops, + /*numFilterLoops=*/sliceBased ? 0 : numFilterLoops); // Detects sparse annotations and translates the per-level sparsity // information for all tensors to loop indices in the kernel. - if (!findSparseAnnotations(env)) + if (!findSparseAnnotations(env, sliceBased)) return failure(); // Constructs the tensor expressions tree from `op`, returns failure if the @@ -1635,7 +1838,7 @@ SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput, SortMask::kIncludeUndef, SortMask::kSparseOnly}; for (const SortMask mask : allMasks) { - if (computeIterationGraph(env, mask)) { + if (computeIterationGraph(env, mask, nullptr, sliceBased)) { hasCycle = false; if (env.isAdmissibleTopoOrder()) { isAdmissible = true; @@ -1644,11 +1847,27 @@ // else try a set of less strict constraints. } } - if (hasCycle) - return resolveCycle(env, rewriter); // one last shot + if (hasCycle) { + return sliceBased + ? failure() // TODO: should cycle be resolved differently?
+ : resolveCycle(env, rewriter); // one last shot + } + if (!isAdmissible) return failure(); // inadmissible expression, reject + // Only the slice-based algorithm records related loops; order them + // according to the chosen topological loop order. + if (sliceBased) { + for (OpOperand &t : env.op()->getOpOperands()) { + unsigned rank = env.op().getMatchingIndexingMap(&t).getNumResults(); + for (unsigned i = 0; i < rank; i++) + sortArrayBasedOnOrder( + env.merger().getRelatedLoops(t.getOperandNumber(), i), + env.getTopSort()); + } + } + // Recursively generates code if admissible. env.startEmit(); genBuffers(env, rewriter); diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -220,7 +220,11 @@ loopToLvl(numTensors, std::vector>(numLoops, std::nullopt)), lvlToLoop(numTensors, - std::vector>(numLoops, std::nullopt)) {} + std::vector>(numLoops, std::nullopt)), + ldxToDependentSlice(numTensors, std::vector>( + numLoops, std::nullopt)), + sliceToRelatedldx(numTensors, std::vector>( + numLoops, std::vector())) {} //===----------------------------------------------------------------------===// // Lattice methods. @@ -763,6 +767,8 @@ const LoopId i = loop(b); const auto dlt = lvlTypes[t][i]; llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(dlt); + if (ldxToDependentSlice[t][i]) + llvm::dbgs() << "_D_" << *ldxToDependentSlice[t][i]; } } }