diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -445,9 +445,32 @@
       callback(b, tensor(b), getLvl(b), getDimLevelType(b));
   }
 
   /// Sets whether the output tensor is sparse or not.
   void setHasSparseOut(bool s) { hasSparseOut = s; }
 
+  /// Records that loop `l` appears in the compound affine expression that
+  /// indexes level `dim` of tensor `t`, i.e. that the slice taken on that
+  /// level depends on loop `l`.
+  void setLoopDependentSliceDim(unsigned l, unsigned t, unsigned dim) {
+    assert(dim < numLoops);
+    ldxToDependentSlice[t][l] = dim;
+    sliceToRelatedldx[t][dim].push_back(l);
+  }
+
+  /// Returns true if loop `l` is recorded as a dependent slice loop of any tensor.
+  bool hasDependentSliceDim(unsigned l) const {
+    return llvm::any_of(ldxToDependentSlice,
+                        [l](const std::vector<std::optional<unsigned>> &m) {
+                          return m[l].has_value();
+                        });
+  }
+
+  /// Returns the loops that appear in the affine expression indexing level
+  /// `d` of tensor `t`.
+  std::vector<unsigned> &getRelatedLoops(unsigned t, unsigned d) {
+    return sliceToRelatedldx[t][d];
+  }
+
   /// Convenience getters to immediately access the stored nodes.
   /// Typically it is inadvisible to keep the reference around, as in
   /// `TensorExpr &te = merger.exp(e)`, since insertions into the merger
@@ -511,6 +534,11 @@
   // Map that converts pair<TensorId, Level> to the corresponding LoopId.
   std::vector<std::vector<std::optional<LoopId>>> lvlToLoop;
 
+  // [tid][loop] -> level of tensor `tid` whose affine index uses `loop`.
+  std::vector<std::vector<std::optional<unsigned>>> ldxToDependentSlice;
+  // [tid][level] -> loops used by the affine index on that tensor level.
+  std::vector<std::vector<std::vector<unsigned>>> sliceToRelatedldx;
+
   llvm::SmallVector<TensorExp> tensorExps;
   llvm::SmallVector<LatPoint> latPoints;
   llvm::SmallVector<SmallVector<LatPointId>> latSets;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -99,6 +99,7 @@
     topSort.reserve(capacity);
   }
 
+  ArrayRef<LoopId> getTopSort() const { return topSort; }
   ArrayRef<LoopId> getTopSortSlice(LoopOrd n, LoopOrd m) const;
   ArrayRef<LoopId> getLoopStackUpTo(LoopOrd n) const;
   ArrayRef<LoopId> getCurrentLoopStack() const;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -109,6 +109,14 @@
   SmallVector<utils::IteratorType> iterTypes;
 };
 
+/// A helper class that visits an affine expression and records every
+/// AffineDimExpr it encounters in `dims`, in visitation order and including
+/// repeated occurrences.
+struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
+  void visitDimExpr(AffineDimExpr expr) { dims.push_back(expr); }
+  SmallVector<AffineDimExpr> dims;
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -254,6 +262,66 @@
   }
 }
 
+/// Helper method to inspect affine expressions for slice-based codegen.
+/// A plain index `d` that directly indexes `tensor[dim]` records `dim` and
+/// `dlt` for loop `d`; an index whose level type is already set is rejected
+/// (it is used more than once). Indices inside a compound `Add` are not
+/// assigned a level. Instead, when `needSlice` is set, each one is recorded
+/// as a loop that the slice on `tensor[dim]` depends on, and an index that
+/// already has a dependent slice is rejected. `Mul` and `Constant` are
+/// rejected when `needSlice` is set; otherwise they are only checked.
+static bool findSliceBasedAffine(Merger &merger, unsigned tensor, unsigned dim,
+                                 AffineExpr a, DimLevelType dlt, bool needSlice,
+                                 bool isSubExp = false) {
+  switch (a.getKind()) {
+  case AffineExprKind::DimId: {
+    unsigned idx = a.cast<AffineDimExpr>().getPosition();
+    if (!isUndefDLT(merger.getDimLevelType(tensor, idx)))
+      return false; // used more than once
+    if (!isSubExp)
+      merger.setLevelAndType(tensor, idx, dim, dlt);
+
+    if (isSubExp && needSlice) {
+      // The current loop already appears in another compound affine
+      // expression. For simplicity, we do not handle this case for now.
+      if (merger.hasDependentSliceDim(idx)) {
+        // TODO: This can be supported by co-iterating slices if the loop idx
+        // appears in affine indices of different tensors, or by taking slices
+        // on multiple levels when it is on the same tensor.
+        // E.g.,
+        // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0]
+        // d0_1 = getNextSliceOffset t0 along lvl0
+        // d0_2 = getNextSliceOffset t1 along lvl0
+        // if d0_1 == d0_2 then d0 = d0_1 = d0_1
+        // else increase min(d0_1, d0_2).
+        return false;
+      }
+      merger.setLoopDependentSliceDim(idx, tensor, dim);
+    }
+    return true;
+  }
+  case AffineExprKind::Constant:
+  case AffineExprKind::Mul:
+    // TODO: Support Mul and Constant AffineExp for slice-based codegen
+    if (needSlice)
+      return false;
+    [[fallthrough]];
+  case AffineExprKind::Add: {
+    // A constant has no operands and is trivially admissible; it must not be
+    // cast to a binary expression.
+    auto binOp = a.dyn_cast<AffineBinaryOpExpr>();
+    if (!binOp)
+      return true;
+    // Only recurse to check that both operands are admissible.
+    return findSliceBasedAffine(merger, tensor, dim, binOp.getLHS(), dlt,
+                                needSlice, true) &&
+           findSliceBasedAffine(merger, tensor, dim, binOp.getRHS(), dlt,
+                                needSlice, true);
+  }
+  default:
+    return false;
+  }
+}
+
 /// Get the total number of compound affine expressions in the
 /// `getMatchingIndexingMap` for the given tensor. For the following inputs:
 ///
@@ -326,7 +394,7 @@
 /// Returns true if the sparse annotations and affine subscript
 /// expressions of all tensors are admissible. Returns false if
 /// no annotations are found or inadmissible constructs occur.
-static bool findSparseAnnotations(CodegenEnv &env) {
+static bool findSparseAnnotations(CodegenEnv &env, bool sliceBased) {
   bool annotated = false;
   // `filterLdx` may be mutated by `findAffine`.
   LoopId filterLdx = env.merger().getStartingFilterLoopId();
@@ -335,17 +403,23 @@
     const auto enc = getSparseTensorEncoding(t.get().getType());
     if (enc)
       annotated = true;
     const Level lvlRank = map.getNumResults();
     assert(!enc || lvlRank == enc.getLvlRank());
     assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
+    // A tensor with compound affine indices on sparse levels must be
+    // accessed through slices under slice-based codegen.
+    bool needSlice = enc && getNumCompoundAffineOnSparseLvls(map, t.get()) != 0;
     for (Level l = 0; l < lvlRank; l++) {
       const TensorId tid = t.getOperandNumber();
-      // FIXME: `toOrigDim` is deprecated.
-      // FIXME: above we asserted that there are `lvlRank` many results,
-      // but this is assuming there are in fact `dimRank` many results instead.
-      const AffineExpr a = map.getResult(toOrigDim(enc, l));
-      if (!findAffine(env.merger(), tid, l, a, enc.getLvlType(l), filterLdx))
-        return false; // inadmissible affine expression
+      const AffineExpr a = map.getResult(toOrigDim(enc, l));
+      const DimLevelType dlt = enc.getLvlType(l);
+      if (sliceBased) {
+        if (!findSliceBasedAffine(env.merger(), tid, l, a, dlt, needSlice))
+          return false; // inadmissible affine expression
+      } else {
+        if (!findAffine(env.merger(), tid, l, a, dlt, filterLdx))
+          return false; // inadmissible affine expression
+      }
     }
   }
   assert(filterLdx == env.merger().getNumLoops());
@@ -469,11 +543,11 @@
   }
 }
 
-static void tryLoosenAffineDenseConstraints(linalg::GenericOp op,
-                                            std::optional<LoopId> &fldx,
-                                            AffineExpr &fa,
-                                            std::optional<LoopId> &tldx,
-                                            AffineExpr &ta) {
+static void tryLoosenAffineConstraints(linalg::GenericOp op,
+                                       std::optional<LoopId> &fldx,
+                                       AffineExpr &fa,
+                                       std::optional<LoopId> &tldx,
+                                       AffineExpr &ta) {
   // We use a heuristic here to only pick one dim expression from each
   // compound affine expression to establish the order between two dense
   // dimensions.
@@ -494,7 +568,7 @@
   }
   if (!ta.isa<AffineDimExpr>()) {
     // Heuristic: we prefer reduction loop for rhs to reduce the chance
-    // addint reduce < parallel ordering.
+    // adding reduce < parallel ordering.
     finder.setPickedIterType(utils::IteratorType::reduction);
     finder.walkPostOrder(ta);
     ta = finder.getDimExpr();
@@ -503,14 +577,182 @@
   }
 }
 
+/// Sorts `target` by the position of each loop index in `order` (e.g. the
+/// topological loop order); every loop in `target` must occur in `order`.
+static void sortArrayBasedOnOrder(std::vector<unsigned> &target,
+                                  ArrayRef<unsigned> order) {
+  auto position = [&order](unsigned ldx) {
+    const auto *it = std::find(order.begin(), order.end(), ldx);
+    assert(it != order.end() && "loop is missing from the order");
+    return it - order.begin();
+  };
+  // Keep the comparator a strict weak ordering: it must not assert that its
+  // operands differ, as std::sort may compare an element with itself.
+  std::sort(target.begin(), target.end(), [&position](unsigned l, unsigned r) {
+    return position(l) < position(r);
+  });
+}
+
+static void addFilterLoopBasedConstraints(CodegenEnv &env, OpOperand &t,
+                                          OpOperand *skip, SortMask mask,
+                                          std::vector<std::vector<bool>> &adjM,
+                                          std::vector<unsigned> &inDegree) {
+  // Get map and encoding.
+  auto map = env.op().getMatchingIndexingMap(&t);
+  auto enc = getSparseTensorEncoding(t.get().getType());
+
+  // Each tensor expression and optional dimension ordering (row-major
+  // by default) puts an ordering constraint on the loop indices. For
+  // example, the tensor expression A_ijk forces the ordering i < j < k
+  // on the loop indices if no explicit dimension ordering is given.
+  for (unsigned l = 0, rank = map.getNumResults(); l < rank; l++) {
+    AffineExpr ta = map.getResult(toOrigDim(enc, l));
+    std::optional<LoopId> tldx =
+        env.merger().getLoopId(t.getOperandNumber(), l);
+    // Filter loops should be constructed after all the dependent loops,
+    // i.e., d0 + d1 < filter_loop(d0 + d1)
+    if (tldx && env.merger().isFilterLoop(*tldx)) {
+      assert(!ta.isa<AffineDimExpr>() && !isDenseDLT(enc.getLvlType(l)));
+      addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt, tldx);
+      // Now that the ordering of affine expression is captured by filter
+      // loop idx, we only need to ensure the affine ordering against filter
+      // loop. Thus, we reset the affine express to nil here to mark it as
+      // resolved.
+      ta = AffineExpr();
+    }
+
+    // Skip tensor during cycle resolution, though order between filter loop
+    // and dependent loops need to be guaranteed unconditionally.
+    if (&t == skip)
+      continue;
+
+    if (l > 0) {
+      AffineExpr fa = map.getResult(toOrigDim(enc, l - 1));
+      std::optional<LoopId> fldx =
+          env.merger().getLoopId(t.getOperandNumber(), l - 1);
+
+      // Applying order constraints on every pair of dimExpr between two
+      // compound affine expressions can sometimes be too strict:
+      // E.g, for [dense, dense] -> (d0 + d1, d2 + d3).
+      // It is totally fine to have loop sequence d0->d2->d1->d3 instead of
+      // requiring d0 < d2, d1 < d2, d0 < d3, d1 < d3.
+      // This also applies to tensors without compound indices on sparse levels
+      // under the slice-based algorithm, which delegates to this function.
+      // TODO: do we really need the condition?
+      if (!includesDense(mask))
+        tryLoosenAffineConstraints(env.op(), fldx, fa, tldx, ta);
+
+      // (d0 + d1) < (d2 + d3), or
+      // filter_loop_d-1 < (d2 + d3), or
+      // (d0 + d1) < filter_loop_d, or
+      // filter_loop_d-1 < filter_loop_d depending on whether fa/ta is reset
+      // above.
+      addAffineOrderings(adjM, inDegree, fa, ta, fldx, tldx);
+    }
+  }
+}
+
+static void addSliceBasedConstraints(CodegenEnv &env, OpOperand &t,
+                                     OpOperand *skip, SortMask mask,
+                                     std::vector<std::vector<bool>> &adjM,
+                                     std::vector<unsigned> &inDegree) {
+  // Get map and encoding.
+  auto map = env.op().getMatchingIndexingMap(&t);
+  auto enc = getSparseTensorEncoding(t.get().getType());
+
+  // No special treatment for simple indices.
+  if (getNumCompoundAffineOnSparseLvls(map, t.get()) == 0)
+    return addFilterLoopBasedConstraints(env, t, skip, mask, adjM, inDegree);
+
+  // Skip tensor during cycle resolution (no filter loops exist for it, so
+  // there is no ordering that must be kept unconditionally).
+  if (&t == skip)
+    return;
+
+  AffineDimFinder finder(env.op());
+  finder.setPickedIterType(utils::IteratorType::reduction);
+  // For tensor[d0 + d1 + d3, d4 + d5 + d6], pick d_x in {d0, d1, d3} and d_y
+  // in {d4, d5, d6}, then add the edges (a -> b places loop a before loop b)
+  // d_x -> d_y and ({d0, d1, d3} - d_x) -> ({d4, d5, d6} - d_y).
+  for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
+    AffineExpr fa = map.getResult(toOrigDim(enc, d - 1));
+    AffineExpr ta = map.getResult(toOrigDim(enc, d));
+
+    // This is a heuristic: we pick an arbitrary reduction loop (as preferred
+    // by `finder`) from lhs and rhs and use them as d_x and d_y.
+    finder.walkPostOrder(fa);
+    AffineDimExpr fexp = finder.getDimExpr();
+    unsigned fldx = fexp.getPosition();
+
+    finder.walkPostOrder(ta);
+    AffineDimExpr texp = finder.getDimExpr();
+    unsigned tldx = texp.getPosition();
+
+    // d_x -> d_y
+    if (!adjM[fldx][tldx]) {
+      adjM[fldx][tldx] = true;
+      inDegree[tldx]++;
+    }
+
+    AffineDimCollector fCollector;
+    fCollector.walkPostOrder(fa);
+    AffineDimCollector tCollector;
+    tCollector.walkPostOrder(ta);
+
+    // Make sure d_x and d_y come last within their own expressions.
+    for (auto fd : fCollector.dims) {
+      unsigned fLoop = fd.getPosition();
+      if (fLoop == fldx)
+        continue;
+      if (!adjM[fLoop][fldx]) {
+        adjM[fLoop][fldx] = true;
+        inDegree[fldx]++;
+      }
+    }
+    for (auto td : tCollector.dims) {
+      unsigned tLoop = td.getPosition();
+      if (tLoop == tldx)
+        continue;
+      if (!adjM[tLoop][tldx]) {
+        adjM[tLoop][tldx] = true;
+        inDegree[tldx]++;
+      }
+    }
+    // Since only affine addition is supported, the order among the loops of
+    // a single expression does not matter. We still require
+    // ({d0, d1, d3} - d_x) -> ({d4, d5, d6} - d_y)
+    // to ensure that the affine expressions are reduced in sparse tensor
+    // level order.
+    // TODO: this ordering could probably be loosened if we support
+    // out-of-order reduction.
+    // TODO: the evaluation order needs to be ensured to support affine
+    // multiplication.
+    for (auto fd : fCollector.dims) {
+      unsigned fLoop = fd.getPosition();
+      if (fLoop == fldx) // skip d_x
+        continue;
+
+      for (auto td : tCollector.dims) {
+        unsigned tLoop = td.getPosition();
+        if (tLoop == tldx) // skip d_y
+          continue;
+        if (!adjM[fLoop][tLoop]) {
+          adjM[fLoop][tLoop] = true;
+          inDegree[tLoop]++;
+        }
+      }
+    }
+  }
+}
+
 /// Computes a topologically sorted iteration graph for the linalg operation.
 /// Ensures all tensors are visited in natural coordinate order. This is
 /// essential for sparse storage formats since these only support access
 /// along fixed levels. Even for dense storage formats, however, the natural
 /// coordinate order yields innermost unit-stride access with better spatial
 /// locality.
 static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
-                                  OpOperand *skip = nullptr) {
+                                  OpOperand *skip, bool useSlice = false) {
   // Set up an n x n from/to adjacency matrix of the iteration graph
   // for the implicit loop indices i_0 .. i_n-1.
   const LoopId n = env.merger().getNumLoops();
@@ -549,63 +791,12 @@
         }
       }
     }
-
-    // Each tensor expression and optional dimension ordering (row-major
-    // by default) puts an ordering constraint on the loop indices. For
-    // example, the tensor expresion A_ijk forces the ordering i < j < k
-    // on the loop indices if no explicit dimension ordering is given.
-    const Level lvlRank = map.getNumResults();
-    assert(!enc || lvlRank == enc.getLvlRank());
-    for (Level l = 0; l < lvlRank; l++) {
-      // FIXME: `toOrigDim` is deprecated.
-      // FIXME: above we asserted that there are `lvlRank` many results,
-      // but this is assuming there are in fact `dimRank` many results instead.
-      AffineExpr ta = map.getResult(toOrigDim(enc, l));
-      std::optional<LoopId> tldx =
-          env.merger().getLoopId(t.getOperandNumber(), l);
-
-      // Filter loops should be constructed after all the dependent loops,
-      // i.e., d0 + d1 < filter_loop(d0 + d1)
-      if (tldx && env.merger().isFilterLoop(*tldx)) {
-        assert(!ta.isa<AffineDimExpr>() && !isDenseDLT(enc.getLvlType(l)));
-        addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt,
-                           tldx);
-        // Now that the ordering of affine expression is captured by filter
-        // loop idx, we only need to ensure the affine ordering against filter
-        // loop. Thus, we reset the affine express to nil here to mark it as
-        // resolved.
-        ta = AffineExpr();
-      }
-
-      // Skip tensor during cycle resolution, though order between filter loop
-      // and dependent loops need to be guaranteed unconditionally.
-      if (&t == skip)
-        continue;
-
-      if (l > 0) {
-        // FIXME: `toOrigDim` is deprecated.
-        // FIXME: above we asserted that there are `lvlRank` many results,
-        // but this is assuming there are in fact `dimRank` many results.
-        AffineExpr fa = map.getResult(toOrigDim(enc, l - 1));
-        std::optional<LoopId> fldx =
-            env.merger().getLoopId(t.getOperandNumber(), l - 1);
-
-        // Applying order constraints on every pair of dimExpr between two
-        // compound affine expressions can sometime too strict:
-        // E.g, for [dense, dense] -> (d0 + d1, d2 + d3).
-        // It is totally fine to have loop sequence d0->d2->d1->d3 instead of
-        // requiring d0 < d2, d1 < d2, d0 < d3, d1 < d3.
-        if (!includesDense(mask))
-          tryLoosenAffineDenseConstraints(env.op(), fldx, fa, tldx, ta);
-
-        // (d0 + d1) < (d2 + d3), or
-        // filter_loop_d-1 < (d2 + d3), or
-        // (d0 + d1) < filter_loop_d, or
-        // filter_loop_d-1 < filter_loop_d depending on whether fa/ta is reset
-        // above.
-        addAffineOrderings(adjM, inDegree, fa, ta, fldx, tldx);
-      }
-    }
+    // Add the ordering constraints imposed by the indexing map of `t`, using
+    // either the slice-based or the filter-loop-based affine index handling.
+    if (useSlice)
+      addSliceBasedConstraints(env, t, skip, mask, adjM, inDegree);
+    else
+      addFilterLoopBasedConstraints(env, t, skip, mask, adjM, inDegree);
   }
   // Topologically sort the iteration graph to determine loop order.
   // Report failure for a cyclic iteration graph.
@@ -1600,18 +1793,22 @@
   if (op.getNumDpsInits() != 1 || hasCompoundAffineOnSparseOut(op))
     return failure();
 
-  if (options.enableSliceBasedAffine)
-    llvm_unreachable("not yet implemented");
-
   // Sets up a code generation environment.
   const unsigned numTensors = op->getNumOperands();
   const unsigned numLoops = op.getNumLoops();
   const unsigned numFilterLoops = getNumCompoundAffineOnSparseLvls(op);
-  CodegenEnv env(op, options, numTensors, numLoops, numFilterLoops);
+  // TODO: we should probably always use slice-based codegen whenever
+  // possible; we could even intermix slice-based and filter-loop-based codegen.
+  bool sliceBased = options.enableSliceBasedAffine && numFilterLoops != 0;
+
+  // With the slice-based algorithm, compound affine indices on sparse levels
+  // are handled through slices, so no filter loops are needed.
+  CodegenEnv env(op, options, numTensors, numLoops,
+                 /*numFilterLoops=*/sliceBased ? 0 : numFilterLoops);
 
   // Detects sparse annotations and translates the per-level sparsity
   // information for all tensors to loop indices in the kernel.
-  if (!findSparseAnnotations(env))
+  if (!findSparseAnnotations(env, sliceBased))
     return failure();
 
   // Constructs the tensor expressions tree from `op`, returns failure if the
@@ -1636,7 +1833,7 @@
       SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput,
       SortMask::kIncludeUndef, SortMask::kSparseOnly};
   for (const SortMask mask : allMasks) {
-    if (computeIterationGraph(env, mask)) {
+    if (computeIterationGraph(env, mask, nullptr, sliceBased)) {
       hasCycle = false;
       if (env.isAdmissibleTopoOrder()) {
         isAdmissible = true;
@@ -1645,11 +1842,29 @@
       // else try a set of less strict constraints.
     }
   }
-  if (hasCycle)
-    return resolveCycle(env, rewriter); // one last shot
+  if (hasCycle) {
+    return sliceBased
+               ? failure() // TODO: should cycle be resolved differently?
+               : resolveCycle(env, rewriter); // one last shot
+  }
+
   if (!isAdmissible)
     return failure(); // inadmissible expression, reject
 
+  // Order the loops that each sliced tensor level depends on by the final
+  // loop order. These lists are only populated by slice-based codegen, and
+  // the merger allocates them for the first getNumLoops() levels only.
+  if (sliceBased) {
+    const unsigned numMergerLoops = env.merger().getNumLoops();
+    for (OpOperand &t : op->getOpOperands()) {
+      const TensorId tid = t.getOperandNumber();
+      const unsigned rank = op.getMatchingIndexingMap(&t).getNumResults();
+      for (unsigned l = 0, e = std::min(rank, numMergerLoops); l < e; l++)
+        sortArrayBasedOnOrder(env.merger().getRelatedLoops(tid, l),
+                              env.getTopSort());
+    }
+  }
+
   // Recursively generates code if admissible.
   env.startEmit();
   genBuffers(env, rewriter);
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -220,7 +220,11 @@
       loopToLvl(numTensors,
                 std::vector<std::optional<Level>>(numLoops, std::nullopt)),
       lvlToLoop(numTensors,
-                std::vector<std::optional<LoopId>>(numLoops, std::nullopt)) {}
+                std::vector<std::optional<LoopId>>(numLoops, std::nullopt)),
+      ldxToDependentSlice(numTensors, std::vector<std::optional<unsigned>>(
+                                          numLoops, std::nullopt)),
+      sliceToRelatedldx(numTensors, std::vector<std::vector<unsigned>>(
+                                        numLoops, std::vector<unsigned>())) {}
 
 //===----------------------------------------------------------------------===//
 // Lattice methods.
@@ -763,6 +767,8 @@
       const LoopId i = loop(b);
       const auto dlt = lvlTypes[t][i];
       llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(dlt);
+      if (ldxToDependentSlice[t][i])
+        llvm::dbgs() << "_D_" << *ldxToDependentSlice[t][i];
     }
   }
 }