diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -436,18 +436,55 @@ /// Iterates over a set of `TensorLoopId`s, invoking the callback /// for each `TensorLoopId` and passing it the corresponding tensor /// identifier, level, and level-type. - void - foreachTensorLoopId(LatPointId p, - function_ref, DimLevelType)> - callback) const { - for (const TensorLoopId b : latPoints[p].bits.set_bits()) - callback(b, tensor(b), getLvl(b), getDimLevelType(b)); + void foreachTensorLoopId( + LatPointId p, function_ref, DimLevelType, bool)> + callback) { + for (const TensorLoopId b : latPoints[p].bits.set_bits()) { + TensorId t = tensor(b); + if (isLvlWithNonTrivialIdxExp(b)) { + // This must be an undefined dim. + assert(!getLvl(b).has_value()); + // Slice the tid along the dependent dim to iterate current loop. + callback(b, t, ldxToDependencies[loop(b)][t], getDimLevelType(b), true); + } else { + callback(b, t, getLvl(b), getDimLevelType(b), false); + } + } } /// Sets whether the output tensor is sparse or not. void setHasSparseOut(bool s) { hasSparseOut = s; } + /// Establishes the two-way map that l <-> . + void setLoopDependentTensorLevel(LoopId i, TensorId t, Level lvl) { + assert(lvl < numLoops); + ldxToDependencies[i][t] = lvl; + lvlToDependentIdx[t][lvl].push_back(i); + } + + // Whether the ldx has dependent slice. + bool hasDependentLvl(LoopId i, TensorId tid) { + return ldxToDependencies[i][tid].has_value(); + } + + // Returns the list of loop indices appeared in the non-trivial indexing + // expression on t_l, e.g., A[i+j] => {i, j} + std::vector &getDependedLoops(TensorId t, Level lvl) { + return lvlToDependentIdx[t][lvl]; + } + + // Return the defining [tid, dim] for the loop. + std::pair getLoopDefiningDim(Level lvl) const { + return loopBounds[lvl]; + } + + /// Whether the lattice point represents a tensor level with non-trivial index + /// expression on it. + bool isLvlWithNonTrivialIdxExp(TensorLoopId b) const { + return ldxToDependencies[loop(b)][tensor(b)].has_value(); + } + /// Convenience getters to immediately access the stored nodes. /// Typically it is inadvisible to keep the reference around, as in /// `TensorExpr &te = merger.exp(e)`, since insertions into the merger @@ -511,6 +548,21 @@ // Map that converts pair to the corresponding LoopId. std::vector>> lvlToLoop; + // Map from a loop idx to its dependencies if any. + // The dependencies of a loop idx is a set of (tensor, level) pairs. + // It is currently only set for non-trivial index expressions. + // E.g., A[i+j] => i and j will have dependencies {A0} to indicate that + // i and j are used in the non-trivial index expression on A0. + std::vector>> ldxToDependencies; + // The inverse map of ldxToDependencies from tensor level -> dependent loop + // index. + // E.g., A[i+j], we have A0 => {i, j}, to indicate that A0 uses both {i, j} + // to compute it indices. + std::vector>> lvlToDependentIdx; + + // Map from loop index to the [tid, dim] pair that defines the loop boundary. + std::vector> loopBounds; + llvm::SmallVector tensorExps; llvm::SmallVector latPoints; llvm::SmallVector> latSets; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h @@ -99,6 +99,7 @@ topSort.reserve(capacity); } + ArrayRef getTopSort() const { return topSort; }; ArrayRef getTopSortSlice(LoopOrd n, LoopOrd m) const; ArrayRef getLoopStackUpTo(LoopOrd n) const; ArrayRef getCurrentLoopStack() const; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -109,6 +109,12 @@ SmallVector iterTypes; }; +// Flattens an affine expression into a list of AffineDimExprs. +struct AffineDimCollector : public AffineExprVisitor { + void visitDimExpr(AffineDimExpr expr) { dims.push_back(expr); } + SmallVector dims; +}; + } // namespace //===----------------------------------------------------------------------===// @@ -254,6 +260,69 @@ } } +/// Helper method to inspect affine expressions for index variable reduction +/// based codegen. It finds the dependent index set for all tensor levels in the +/// current expression we are generating. +/// +/// For example, when handling A[i+j][j+k], we build the two way mapping in +/// merger between (tensor, level) pairs and their dependent index variable set: +/// A_0 <=> [i, j] and A_1 <=> [j, k] +/// +/// It rejects cases (returns false) +/// 1st, when the same index is used more than once, e.g., A[i+j][i] +/// 2nd, when multiplication is used in the non-trivial index expression. +/// 3rd, when a constant operand is used in the non-trivial index expression. +/// +/// TODO: constant should be easy to handle. +static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl, + AffineExpr a, DimLevelType dlt, + bool isSubExp = false) { + switch (a.getKind()) { + case AffineExprKind::DimId: { + LoopId ldx = a.cast().getPosition(); + if (!isUndefDLT(merger.getDimLevelType(tensor, ldx))) + return false; // used more than once, e.g., A[i][i] + + // TODO: Generalizes the following two cases. A[i] (with trivial index + // expression) can be treated as a special affine index expression. We do + // not necessarily need to differentiate them. + if (!isSubExp) + merger.setLevelAndType(tensor, ldx, lvl, dlt); + + if (isSubExp) { + // The current loops appears in more than one affine expressions on the + // same tensor. We can not handle this case. e.g., A[i+j][i+k], `i` is + // used twice. + if (merger.hasDependentLvl(ldx, tensor)) { + // TODO: This can be supported by coiterate slices if the loop idx is + // appeared on affine index for different tensor, or take slice on + // mulitple dimensions when it is on the same tensor. + // E.g., + // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0] + // d0_1 = getNextSliceOffset t0 along lvl0 + // d0_2 = getNextSliceOffset t1 along lvl0 + // if d0_1 == d0_2 then d0 = d0_1 = d0_1 + // else increase min(d0_1, d0_2). + return false; + } + merger.setLoopDependentTensorLevel(ldx, tensor, lvl); + } + return true; + } + case AffineExprKind::Constant: + case AffineExprKind::Mul: + // TODO: Support Mul and Constant AffineExp for slice-based codegen + return false; + case AffineExprKind::Add: { + auto binOp = a.cast(); + return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), dlt, true) && + findDepIdxSet(merger, tensor, lvl, binOp.getRHS(), dlt, true); + } + default: + return false; + } +} + /// Get the total number of compound affine expressions in the /// `getMatchingIndexingMap` for the given tensor. For the following inputs: /// @@ -262,7 +331,8 @@ /// /// Returns 1 (because the first level is compressed and its corresponding /// indexing-expression is `d0 + d1`) -static unsigned getNumCompoundAffineOnSparseLvls(AffineMap map, Value tensor) { +static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map, + Value tensor) { // The `tensor` is not guaranted to have `RankedTensorType`, therefore // we can't use `getRankedTensorType`/`getSparseTensorType` here. // However, we don't need to handle `StorageSpecifierType`, so we @@ -305,20 +375,20 @@ /// Get the total number of sparse levels with compound affine /// expressions, summed over all operands of the `GenericOp`. -static unsigned getNumCompoundAffineOnSparseLvls(linalg::GenericOp op) { +static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) { unsigned num = 0; for (OpOperand &t : op->getOpOperands()) - num += getNumCompoundAffineOnSparseLvls(op.getMatchingIndexingMap(&t), - t.get()); + num += getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(&t), + t.get()); return num; } -static bool hasCompoundAffineOnSparseOut(linalg::GenericOp op) { +static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) { OpOperand *out = op.getDpsInitOperand(0); if (getSparseTensorType(out->get()).isAllDense()) return false; - return getNumCompoundAffineOnSparseLvls(op.getMatchingIndexingMap(out), - out->get()); + return getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(out), + out->get()); } /// Helper method to inspect sparse encodings in the tensor types. @@ -326,7 +396,7 @@ /// Returns true if the sparse annotations and affine subscript /// expressions of all tensors are admissible. Returns false if /// no annotations are found or inadmissible constructs occur. -static bool findSparseAnnotations(CodegenEnv &env) { +static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) { bool annotated = false; // `filterLdx` may be mutated by `findAffine`. LoopId filterLdx = env.merger().getStartingFilterLoopId(); @@ -335,17 +405,30 @@ const auto enc = getSparseTensorEncoding(t.get().getType()); if (enc) annotated = true; + const Level lvlRank = map.getNumResults(); assert(!enc || lvlRank == enc.getLvlRank()); assert(static_cast(env.op().getRank(&t)) == lvlRank); + + // We only need to do index reduction if there is at least one non-trivial + // index expression on sparse levels. + // If all non-trivial index expression is on dense levels, we can + // efficiently rely on the random access to locate the element. + bool needIdxReduc = + enc && getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) != 0; + // If then current tensor being inspected requires affine index, it need + // to be sliced. for (Level l = 0; l < lvlRank; l++) { const TensorId tid = t.getOperandNumber(); - // FIXME: `toOrigDim` is deprecated. - // FIXME: above we asserted that there are `lvlRank` many results, - // but this is assuming there are in fact `dimRank` many results instead. - const AffineExpr a = map.getResult(toOrigDim(enc, l)); - if (!findAffine(env.merger(), tid, l, a, enc.getLvlType(l), filterLdx)) - return false; // inadmissible affine expression + AffineExpr a = map.getResult(toOrigDim(enc, l)); + DimLevelType dlt = enc.getLvlType(l); + if (idxReducBased && needIdxReduc) { + if (!findDepIdxSet(env.merger(), tid, l, a, dlt)) + return false; // inadmissible affine expression + } else { + if (!findAffine(env.merger(), tid, l, a, dlt, filterLdx)) + return false; // inadmissible affine expression + } } } assert(filterLdx == env.merger().getNumLoops()); @@ -469,11 +552,11 @@ } } -static void tryLoosenAffineDenseConstraints(linalg::GenericOp op, - std::optional &fldx, - AffineExpr &fa, - std::optional &tldx, - AffineExpr &ta) { +static void tryRelaxAffineConstraints(linalg::GenericOp op, + std::optional &fldx, + AffineExpr &fa, + std::optional &tldx, + AffineExpr &ta) { // We use a heuristic here to only pick one dim expression from each // compound affine expression to establish the order between two dense // dimensions. @@ -494,7 +577,7 @@ } if (!ta.isa()) { // Heuristic: we prefer reduction loop for rhs to reduce the chance - // addint reduce < parallel ordering. + // adding reduce < parallel ordering. finder.setPickedIterType(utils::IteratorType::reduction); finder.walkPostOrder(ta); ta = finder.getDimExpr(); @@ -503,14 +586,184 @@ } } +/// Makes target array's elements appear in the same order as the `order` array. +static void sortArrayBasedOnOrder(std::vector &target, + ArrayRef order) { + std::sort(target.begin(), target.end(), [&order](LoopId l, LoopId r) { + assert(l != r); + int idxL = -1, idxR = -1; + for (int i = 0, e = order.size(); i < e; i++) { + if (order[i] == l) + idxL = i; + if (order[i] == r) + idxR = i; + } + assert(idxL >= 0 && idxR >= 0); + return idxL < idxR; + }); +} + +static void addFilterLoopBasedConstraints(CodegenEnv &env, OpOperand &t, + OpOperand *skip, SortMask mask, + std::vector> &adjM, + std::vector &inDegree) { + // Get map and encoding. + auto map = env.op().getMatchingIndexingMap(&t); + auto enc = getSparseTensorEncoding(t.get().getType()); + + // Each tensor expression and optional dimension ordering (row-major + // by default) puts an ordering constraint on the loop indices. For + // example, the tensor expresion A_ijk forces the ordering i < j < k + // on the loop indices if no explicit dimension ordering is given. + for (Level l = 0, rank = map.getNumResults(); l < rank; l++) { + AffineExpr ta = map.getResult(toOrigDim(enc, l)); + std::optional tldx = + env.merger().getLoopId(t.getOperandNumber(), l); + // Filter loops should be constructed after all the dependent loops, + // i.e., d0 + d1 < filter_loop(d0 + d1) + if (tldx && env.merger().isFilterLoop(*tldx)) { + enc.isAllDense(); + assert(!ta.isa() && !isDenseDLT(enc.getDimLevelType()[l])); + addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt, tldx); + // Now that the ordering of affine expression is captured by filter + // loop idx, we only need to ensure the affine ordering against filter + // loop. Thus, we reset the affine express to nil here to mark it as + // resolved. + ta = AffineExpr(); + } + + // Skip tensor during cycle resolution, though order between filter loop + // and dependent loops need to be guaranteed unconditionally. + if (&t == skip) + continue; + + if (l > 0) { + AffineExpr fa = map.getResult(toOrigDim(enc, l - 1)); + std::optional fldx = + env.merger().getLoopId(t.getOperandNumber(), l - 1); + + // Applying order constraints on every pair of dimExpr between two + // compound affine expressions can sometime too strict: + // E.g, for [dense, dense] -> (d0 + d1, d2 + d3). + // It is totally fine to have loop sequence d0->d2->d1->d3 instead of + // requiring d0 < d2, d1 < d2, d0 < d3, d1 < d3. + // We also relax the affine constraint when use slice-based algorithm + // as there is no filter loop for affine index on sparse dimension. + // TODO: do we really need the condition? + if (!includesDense(mask)) + tryRelaxAffineConstraints(env.op(), fldx, fa, tldx, ta); + + // (d0 + d1) < (d2 + d3), or + // filter_loop_d-1 < (d2 + d3), or + // (d0 + d1) < filter_loop_d, or + // filter_loop_d-1 < filter_loop_d depending on whether fa/ta is reset + // above. + addAffineOrderings(adjM, inDegree, fa, ta, fldx, tldx); + } + } +} + +static void addSliceBasedConstraints(CodegenEnv &env, OpOperand &t, + OpOperand *skip, SortMask mask, + std::vector> &adjM, + std::vector &inDegree) { + // Get map and encoding. + auto map = env.op().getMatchingIndexingMap(&t); + auto enc = getSparseTensorEncoding(t.get().getType()); + + // No special treatment for simple indices. + if (getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) == 0) + return addFilterLoopBasedConstraints(env, t, skip, mask, adjM, inDegree); + + // Skip tensor during cycle resolution, though order between filter loop + // and dependent loops need to be guaranteed unconditionally. + if (&t == skip) + return; + + AffineDimFinder finder(env.op()); + finder.setPickedIterType(utils::IteratorType::reduction); + // To compute iteration graph for tensor[d0 + d1 + d3, d4 + d5 + d6], + // we requires there exist d_x \in {d0, d1, d3} and d_y \in {d4, d5, d6}, + // and d_x > d_y && {d0, d1, d3} - d_x > {d4, d5, d6} - d_y + for (Level lvl = 1, rank = map.getNumResults(); lvl < rank; lvl++) { + AffineExpr fa = map.getResult(toOrigDim(enc, lvl - 1)); + AffineExpr ta = map.getResult(toOrigDim(enc, lvl)); + + // This is a heurisitic, we pick an abitrary reduction loop from lhs and + // rhs and use them as d_x and d_y. + finder.walkPostOrder(fa); + AffineDimExpr fexp = finder.getDimExpr(); + LoopId fldx = fexp.getPosition(); + + finder.walkPostOrder(ta); + AffineDimExpr texp = finder.getDimExpr(); + LoopId tldx = texp.getPosition(); + + // d_x > d_y + if (!adjM[fldx][tldx]) { + adjM[fldx][tldx] = true; + inDegree[tldx]++; + } + + AffineDimCollector fCollector; + fCollector.walkPostOrder(fa); + AffineDimCollector tCollector; + tCollector.walkPostOrder(ta); + + // make sure dx and dy is the last; + for (auto fd : fCollector.dims) { + LoopId f = fd.getPosition(); + if (f == fldx) + continue; + if (!adjM[f][fldx]) { + adjM[f][fldx] = true; + inDegree[fldx]++; + } + } + for (auto td : tCollector.dims) { + LoopId t = td.getPosition(); + if (t == tldx) + continue; + if (!adjM[t][tldx]) { + adjM[t][tldx] = true; + inDegree[tldx]++; + } + } + // Since we only support affine addition, the order between two dim + // expression does not really matters. + // {d0, d1, d3} - d_x > {d4, d5, d6} - d_y + // This is to ensure that the affine expressions are reduced in sparse + // tensor level ordering. + // TODO: this ordering could probably be loosen if we support out-of-order + // reduction. + // TODO: the evaluation order need to be ensure to + // support affine multiplication. + for (auto fd : fCollector.dims) { + LoopId f = fd.getPosition(); + if (f == fldx) // skip d_x + continue; + + for (auto td : tCollector.dims) { + LoopId t = td.getPosition(); + if (t == tldx) // skip d_y + continue; + if (!adjM[f][t]) { + adjM[f][t] = true; + inDegree[t]++; + } + } + } + } +} + /// Computes a topologically sorted iteration graph for the linalg operation. -/// Ensures all tensors are visited in natural coordinate order. This is +/// Ensures all tensors are visited in natural index order. This is /// essential for sparse storage formats since these only support access -/// along fixed levels. Even for dense storage formats, however, the natural -/// coordinate order yields innermost unit-stride access with better spatial +/// along fixed dimensions. Even for dense storage formats, however, the natural +/// index order yields innermost unit-stride access with better spatial /// locality. static bool computeIterationGraph(CodegenEnv &env, SortMask mask, - OpOperand *skip = nullptr) { + OpOperand *skip, bool useSlice = false) { // Set up an n x n from/to adjacency matrix of the iteration graph // for the implicit loop indices i_0 .. i_n-1. const LoopId n = env.merger().getNumLoops(); @@ -522,7 +775,8 @@ // Get map and encoding. const auto map = env.op().getMatchingIndexingMap(&t); const auto enc = getSparseTensorEncoding(t.get().getType()); - assert(map.getNumDims() + getNumCompoundAffineOnSparseLvls(env.op()) == n); + assert(map.getNumDims() + getNumNonTrivialIdxExpOnSparseLvls(env.op()) == + n); // Skips dense inputs/outputs when not requested. const bool isDenseInput = !enc && env.op().isDpsInput(&t); @@ -549,63 +803,12 @@ } } } - - // Each tensor expression and optional dimension ordering (row-major - // by default) puts an ordering constraint on the loop indices. For - // example, the tensor expresion A_ijk forces the ordering i < j < k - // on the loop indices if no explicit dimension ordering is given. - const Level lvlRank = map.getNumResults(); - assert(!enc || lvlRank == enc.getLvlRank()); - for (Level l = 0; l < lvlRank; l++) { - // FIXME: `toOrigDim` is deprecated. - // FIXME: above we asserted that there are `lvlRank` many results, - // but this is assuming there are in fact `dimRank` many results instead. - AffineExpr ta = map.getResult(toOrigDim(enc, l)); - std::optional tldx = - env.merger().getLoopId(t.getOperandNumber(), l); - - // Filter loops should be constructed after all the dependent loops, - // i.e., d0 + d1 < filter_loop(d0 + d1) - if (tldx && env.merger().isFilterLoop(*tldx)) { - assert(!ta.isa() && !isDenseDLT(enc.getLvlType(l))); - addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt, - tldx); - // Now that the ordering of affine expression is captured by filter - // loop idx, we only need to ensure the affine ordering against filter - // loop. Thus, we reset the affine express to nil here to mark it as - // resolved. - ta = AffineExpr(); - } - - // Skip tensor during cycle resolution, though order between filter loop - // and dependent loops need to be guaranteed unconditionally. - if (&t == skip) - continue; - - if (l > 0) { - // FIXME: `toOrigDim` is deprecated. - // FIXME: above we asserted that there are `lvlRank` many results, - // but this is assuming there are in fact `dimRank` many results. - AffineExpr fa = map.getResult(toOrigDim(enc, l - 1)); - std::optional fldx = - env.merger().getLoopId(t.getOperandNumber(), l - 1); - - // Applying order constraints on every pair of dimExpr between two - // compound affine expressions can sometime too strict: - // E.g, for [dense, dense] -> (d0 + d1, d2 + d3). - // It is totally fine to have loop sequence d0->d2->d1->d3 instead of - // requiring d0 < d2, d1 < d2, d0 < d3, d1 < d3. - if (!includesDense(mask)) - tryLoosenAffineDenseConstraints(env.op(), fldx, fa, tldx, ta); - - // (d0 + d1) < (d2 + d3), or - // filter_loop_d-1 < (d2 + d3), or - // (d0 + d1) < filter_loop_d, or - // filter_loop_d-1 < filter_loop_d depending on whether fa/ta is reset - // above. - addAffineOrderings(adjM, inDegree, fa, ta, fldx, tldx); - } - } + // Push unrelated loops into sparse iteration space, so these + // will be skipped more often. + if (useSlice) + addSliceBasedConstraints(env, t, skip, mask, adjM, inDegree); + else + addFilterLoopBasedConstraints(env, t, skip, mask, adjM, inDegree); } // Topologically sort the iteration graph to determine loop order. // Report failure for a cyclic iteration graph. @@ -1275,7 +1478,7 @@ SmallVector lvls; env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid, std::optional lvl, - DimLevelType dlt) { + DimLevelType dlt, bool) { assert(env.merger().loop(b) == idx); if (isDenseDLT(dlt) || isUndefDLT(dlt)) { needsUniv = true; @@ -1350,7 +1553,7 @@ bool hasNonUnique = false; env.merger().foreachTensorLoopId(li, [&, ldx](TensorLoopId b, TensorId tid, std::optional lvl, - DimLevelType dlt) { + DimLevelType dlt, bool) { if (simple.test(b)) { if (isUndefDLT(dlt)) { // An undefined dlt in the lattices, we probably mean to @@ -1596,21 +1799,25 @@ PatternRewriter &rewriter) const override { // Only accept single output operations without affine index on sparse // output. - if (op.getNumDpsInits() != 1 || hasCompoundAffineOnSparseOut(op)) + if (op.getNumDpsInits() != 1 || hasNonTrivialAffineOnSparseOut(op)) return failure(); - if (options.enableIndexReduction) - llvm_unreachable("not yet implemented"); - // Sets up a code generation environment. const unsigned numTensors = op->getNumOperands(); const unsigned numLoops = op.getNumLoops(); - const unsigned numFilterLoops = getNumCompoundAffineOnSparseLvls(op); - CodegenEnv env(op, options, numTensors, numLoops, numFilterLoops); + const unsigned numFilterLoops = getNumNonTrivialIdxExpOnSparseLvls(op); + // TODO: we should probably always use slice-based codegen whenever + // possible, we can even intermix slice-based and filter-loop based codegen. + bool idxReducBased = options.enableIndexReduction && numFilterLoops != 0; + + // If we uses slice based algorithm for affine index, we do not need filter + // loop. + CodegenEnv env(op, options, numTensors, numLoops, + /*numFilterLoops=*/idxReducBased ? 0 : numFilterLoops); // Detects sparse annotations and translates the per-level sparsity // information for all tensors to loop indices in the kernel. - if (!findSparseAnnotations(env)) + if (!findSparseAnnotations(env, idxReducBased)) return failure(); // Constructs the tensor expressions tree from `op`, returns failure if the @@ -1635,7 +1842,7 @@ SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput, SortMask::kIncludeUndef, SortMask::kSparseOnly}; for (const SortMask mask : allMasks) { - if (computeIterationGraph(env, mask)) { + if (computeIterationGraph(env, mask, nullptr, idxReducBased)) { hasCycle = false; if (env.isAdmissibleTopoOrder()) { isAdmissible = true; @@ -1644,11 +1851,24 @@ // else try a set of less strict constraints. } } - if (hasCycle) - return resolveCycle(env, rewriter); // one last shot + if (hasCycle) { + return idxReducBased + ? failure() // TODO: should cycle be resolved differently? + : resolveCycle(env, rewriter); // one last shot + } + if (!isAdmissible) return failure(); // inadmissible expression, reject + for (OpOperand &t : env.op()->getOpOperands()) { + Level rank = env.op().getMatchingIndexingMap(&t).getNumResults(); + for (Level lvl = 0; lvl < rank; lvl++) { + sortArrayBasedOnOrder( + env.merger().getDependedLoops(t.getOperandNumber(), lvl), + env.getTopSort()); + } + } + // Recursively generates code if admissible. env.startEmit(); genBuffers(env, rewriter); diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -220,7 +220,12 @@ loopToLvl(numTensors, std::vector>(numLoops, std::nullopt)), lvlToLoop(numTensors, - std::vector>(numLoops, std::nullopt)) {} + std::vector>(numLoops, std::nullopt)), + ldxToDependencies(numLoops, std::vector>( + numTensors, std::nullopt)), + lvlToDependentIdx(numTensors, std::vector>( + numLoops, std::vector())), + loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {} //===----------------------------------------------------------------------===// // Lattice methods. @@ -762,7 +767,10 @@ const TensorId t = tensor(b); const LoopId i = loop(b); const auto dlt = lvlTypes[t][i]; - llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(dlt); + if (isLvlWithNonTrivialIdxExp(b)) + llvm::dbgs() << " DEP_" << t << "_" << i; + else + llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(dlt); } } }