diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -323,6 +323,26 @@
     dimToLoopIdx[t][dim] = i;
   }
 
+  void setLoopDependentSliceDim(unsigned l, unsigned t, unsigned dim) {
+    assert(dim < numLoops);
+    ldxToDependentSlice[t][l] = dim;
+    sliceToRelatedldx[t][dim].push_back(l);
+  }
+
+  // Returns true if the loop idx (ldx) has a dependent slice.
+  bool hasDependentSliceDim(unsigned l) {
+    return llvm::any_of(ldxToDependentSlice,
+                        [l](const std::vector<std::optional<unsigned>> &m) {
+                          return m[l].has_value();
+                        });
+  }
+
+  // Returns the list of loop indices that appear in the affine indexing
+  // expression on t[d].
+  std::vector<unsigned> &getRelatedLoops(unsigned t, unsigned d) {
+    return sliceToRelatedldx[t][d];
+  }
+
   // Iterates the bits of a lattice, for each set bit, converts it into the
   // corresponding tensor dimension and invokes the callback.
   void foreachTidDimPairInBits(
@@ -394,6 +414,11 @@
   // Map that converts pair to the corresponding loop id.
   std::vector<std::vector<std::optional<unsigned>>> dimToLoopIdx;
 
+  // Map from loop idx to the dependent slice [tid, dim] pair (if any).
+  std::vector<std::vector<std::optional<unsigned>>> ldxToDependentSlice;
+  // Map from the dependent slice [tid, dim] pair to a list of loop idx.
+  std::vector<std::vector<std::vector<unsigned>>> sliceToRelatedldx;
+
   llvm::SmallVector<TensorExp> tensorExps;
   llvm::SmallVector<LatPoint> latPoints;
   llvm::SmallVector<SmallVector<unsigned>> latSets;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -99,6 +99,7 @@
     topSort.reserve(capacity);
   }
 
+  ArrayRef<unsigned> getTopSort() const { return ArrayRef(topSort); }
   ArrayRef<unsigned> getTopSortSlice(size_t n, size_t m) const;
   ArrayRef<unsigned> getLoopCurStack() const;
   Value getLoopIdxValue(size_t loopIdx) const;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -134,31 +134,31 @@
 }
 
 bool CodegenEnv::isAdmissibleTopoOrder() {
-  if (!hasSparseOutput())
-    return true;
-
-  OpOperand *lhs = linalgOp.getDpsInitOperand(0);
-  // Accept "truly dynamic" if the output tensor materializes uninitialized
-  // into the computation and insertions occur in lexicographic index order.
-  unsigned nest = 0;
-  auto iteratorTypes = linalgOp.getIteratorTypesArray();
-  for (unsigned i = 0, e = latticeMerger.getNumLoops(); i < e; i++) {
-    if (!latticeMerger.isFilterLoop(topSortAt(i))) {
-      // We only count non-filter loops as filter loops should be considered
-      // as a special type of parallel loops.
-      if (linalg::isReductionIterator(iteratorTypes[topSortAt(i)]))
-        break; // terminate at first reduction
-      nest++;
+  if (hasSparseOutput()) {
+    OpOperand *lhs = op().getDpsInitOperand(0);
+    // Accept "truly dynamic" if the output tensor materializes uninitialized
+    // into the computation and insertions occur in lexicographic index order.
+    unsigned nest = 0;
+    auto iteratorTypes = op().getIteratorTypesArray();
+    for (unsigned i = 0, e = merger().getNumLoops(); i < e; i++) {
+      if (!merger().isFilterLoop(topSortAt(i))) {
+        // We only count non-filter loops as filter loops should be considered
+        // as a special type of parallel loops.
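+        // (Hence a filter loop neither terminates this scan nor increments
+        // `nest`.)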
+        if (linalg::isReductionIterator(iteratorTypes[topSortAt(i)]))
+          break; // terminate at first reduction
+        nest++;
+      }
     }
+    // Determine admissible dynamic insertion situations:
+    // (1) fully injective, since there are no reductions,
+    // (2) admissible 1-d expansion in innermost dimension.
+    if (nest >= op().getRank(lhs) - 1)
+      outerParNest = nest;
+    else
+      return false;
   }
-  // Determine admissible dynamic insertion situations:
-  // (1) fully injective, since there are no reductions,
-  // (2) admissible 1-d expansion in innermost dimension.
-  if (nest >= linalgOp.getRank(lhs) - 1) {
-    outerParNest = nest;
-    return true;
-  }
-  return false;
+
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -83,6 +83,14 @@
   SmallVector<utils::IteratorType> iterTypes;
 };
 
+/// A helper class that visits an affine expression and tries to find an
+/// AffineDimExpr to which the corresponding iterator from a GenericOp matches
+/// the desired iterator type.
+struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
+  void visitDimExpr(AffineDimExpr expr) { dims.push_back(expr); }
+  SmallVector<AffineDimExpr> dims;
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -169,6 +177,52 @@
   return AffineMap::getPermutationMap(perm, env.op().getContext());
 }
 
+/// Get the total number of compound affine expressions in affineMap that are
+/// attached to the given tensor. For the following inputs:
+///
+/// affineMap = (d0, d1, d2) => (d0 + d1, d2)
+/// tensor = ["compressed", "compressed"]
+///
+/// Returns 1 (because the first level is compressed and its corresponding
+/// affineMap is d0 + d1)
+static unsigned getNumCompoundAffineOnSparseDims(AffineMap affineMap,
+                                                 Value tensor) {
+  unsigned num = 0;
+  const auto enc = getSparseTensorEncoding(tensor.getType());
+  if (enc) {
+    const ArrayRef<AffineExpr> exps = affineMap.getResults();
+    const Level lvlRank = enc.getLvlRank();
+    assert(static_cast<Level>(exps.size()) == lvlRank);
+    for (Level l = 0; l < lvlRank; l++) {
+      // FIXME: `toOrigDim` is deprecated.
+      const Dimension d = toOrigDim(enc, l);
+      // FIXME: there's some dim/lvl confusion here; since `d` isn't
+      // guaranteed to be in bounds (for non-permutations).
+      if (!exps[d].isa<AffineDimExpr>() && !enc.isDenseLvl(l))
+        num++;
+    }
+  }
+  return num;
+}
+
+/// Get the total number of compound affine expressions attached on a sparse
+/// level in the given GenericOp.
+static unsigned getNumCompoundAffineOnSparseDims(linalg::GenericOp op) {
+  unsigned num = 0;
+  for (OpOperand &t : op->getOpOperands())
+    num += getNumCompoundAffineOnSparseDims(op.getMatchingIndexingMap(&t),
+                                            t.get());
+  return num;
+}
+
+static bool hasCompoundAffineOnSparseOut(linalg::GenericOp op) {
+  OpOperand *out = op.getDpsInitOperand(0);
+  if (getSparseTensorType(out->get()).isAllDense())
+    return false;
+  return getNumCompoundAffineOnSparseDims(op.getMatchingIndexingMap(out),
+                                          out->get());
+}
+
 /// Helper method to inspect affine expressions. Rejects cases where the
 /// same index is used more than once. Also rejects compound affine
 /// expressions in sparse dimensions.
@@ -215,50 +269,58 @@
   }
 }
 
-/// Get the total number of compound affine expressions in affineMap that are
-/// attached to the given tensor. For the following inputs:
-///
-/// affineMap = (d0, d1, d2) => (d0 + d1, d2)
-/// tensor = ["compressed", "compressed"]
-///
-/// Returns 1 (because the first level is compressed and its corresponding
-/// affineMap is d0 + d1)
-static unsigned getNumCompoundAffineOnSparseDims(AffineMap affineMap,
-                                                 Value tensor) {
-  unsigned num = 0;
-  const auto enc = getSparseTensorEncoding(tensor.getType());
-  if (enc) {
-    const ArrayRef<AffineExpr> exps = affineMap.getResults();
-    const Level lvlRank = enc.getLvlRank();
-    assert(static_cast<Level>(exps.size()) == lvlRank);
-    for (Level l = 0; l < lvlRank; l++) {
-      // FIXME: `toOrigDim` is deprecated.
-      const Dimension d = toOrigDim(enc, l);
-      // FIXME: there's some dim/lvl confusion here; since `d` isn't
-      // guaranteed to be in bounds (for non-permutations).
-      if (!exps[d].isa<AffineDimExpr>() && !enc.isDenseLvl(l))
-        num++;
+/// Helper method to inspect affine expressions for slice-based codegen.
+/// Rejects cases where the same index is used more than once.
+static bool findSliceBasedAffine(Merger &merger, unsigned tensor, unsigned dim,
+                                 AffineExpr a, DimLevelType dlt, bool needSlice,
+                                 bool isSubExp = false) {
+  switch (a.getKind()) {
+  case AffineExprKind::DimId: {
+    unsigned idx = a.cast<AffineDimExpr>().getPosition();
+    if (!isUndefDLT(merger.getDimLevelType(tensor, idx)))
+      return false; // used more than once
+    if (!isSubExp)
+      merger.setDimAndDimLevelType(tensor, idx, dim, dlt);
+
+    if (isSubExp && needSlice) {
+      // The current loop index appears in more than one affine expression.
+      // For now, for simplicity, we do not handle this case.
+      if (merger.hasDependentSliceDim(idx)) {
+        // TODO: This can be supported by coiterating slices if the loop idx
+        // appears in affine indices of different tensors, or by taking slices
+        // on multiple dimensions when it is on the same tensor.
+        // E.g.,
+        // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0]
+        // d0_1 = getNextSliceOffset t0 along lvl0
+        // d0_2 = getNextSliceOffset t1 along lvl0
+        // if d0_1 == d0_2 then d0 = d0_1 = d0_2
+        // else increase min(d0_1, d0_2).
+        return false;
+      }
+      merger.setLoopDependentSliceDim(idx, tensor, dim);
     }
+    return true;
   }
-  return num;
-}
-
-/// Get the total number of compound affine expressions attached on a sparse
-/// level in the given GenericOp.
-static unsigned getNumCompoundAffineOnSparseDims(linalg::GenericOp op) {
-  unsigned num = 0;
-  for (OpOperand &t : op->getOpOperands())
-    num += getNumCompoundAffineOnSparseDims(op.getMatchingIndexingMap(&t),
-                                            t.get());
-  return num;
-}
-
-static bool hasCompoundAffineOnSparseOut(linalg::GenericOp op) {
-  OpOperand *out = op.getDpsInitOperand(0);
-  if (getSparseTensorType(out->get()).isAllDense())
+  case AffineExprKind::Constant:
+  case AffineExprKind::Mul:
+    // TODO: Support Mul and Constant AffineExpr for slice-based codegen.
+    if (needSlice)
+      return false;
+    [[fallthrough]];
+  case AffineExprKind::Add: {
+    auto binOp = a.cast<AffineBinaryOpExpr>();
+    // We do not set the dim level format for an affine expression like
+    // d0 + d1 on either loop index d0 or d1.
+    // We continue the recursion merely to check whether the current affine
+    // expression is admissible or not.
+    return findSliceBasedAffine(merger, tensor, dim, binOp.getLHS(), dlt,
+                                needSlice, true) &&
+           findSliceBasedAffine(merger, tensor, dim, binOp.getRHS(), dlt,
+                                needSlice, true);
+  }
+  default:
     return false;
-  return getNumCompoundAffineOnSparseDims(op.getMatchingIndexingMap(out),
-                                          out->get());
+  }
 }
 
 /// Helper method to inspect sparse encodings in the tensor types.
@@ -266,7 +328,7 @@
 /// Returns true if the sparse annotations and affine subscript
 /// expressions of all tensors are admissible. Returns false if
 /// no annotations are found or inadmissible constructs occur.
-static bool findSparseAnnotations(CodegenEnv &env) {
+static bool findSparseAnnotations(CodegenEnv &env, bool sliceBased) {
   bool annotated = false;
   unsigned filterLdx = env.merger().getFilterLoopStartingIdx();
   for (OpOperand &t : env.op()->getOpOperands()) {
@@ -274,15 +336,24 @@
     const auto enc = getSparseTensorEncoding(t.get().getType());
     if (enc)
       annotated = true;
+
     const Level lvlRank = map.getNumResults();
     assert(!enc || lvlRank == enc.getLvlRank());
-    assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
-    for (Level l = 0; l < lvlRank; l++) {
-      const unsigned tensor = t.getOperandNumber();
-      // FIXME: `toOrigDim` is deprecated.
-      const AffineExpr a = map.getResult(toOrigDim(enc, l));
-      if (!findAffine(env.merger(), tensor, l, a, enc.getLvlType(l), filterLdx))
-        return false; // inadmissible affine expression
+    assert(map.getNumResults() == env.op().getRank(&t));
+    // If the current tensor being inspected requires an affine index, it
+    // needs to be sliced.
+    bool needSlice = enc && getNumCompoundAffineOnSparseDims(map, t.get()) != 0;
+    for (unsigned l = 0; l < lvlRank; l++) {
+      unsigned tensor = t.getOperandNumber();
+      AffineExpr a = map.getResult(toOrigDim(enc, l));
+      DimLevelType dlt = enc.getLvlType(l);
+      if (sliceBased) {
+        if (!findSliceBasedAffine(env.merger(), tensor, l, a, dlt, needSlice))
+          return false; // inadmissible affine expression
+      } else {
+        if (!findAffine(env.merger(), tensor, l, a, dlt, filterLdx))
+          return false; // inadmissible affine expression
+      }
     }
   }
   assert(filterLdx == env.merger().getNumLoops());
@@ -295,8 +366,8 @@
 /// latest possible index.
 static bool topSortOptimal(CodegenEnv &env, unsigned n,
                            ArrayRef<utils::IteratorType> iteratorTypes,
-                           std::vector<unsigned> &inDegree,
-                           std::vector<std::vector<bool>> &adjM) {
+                           std::vector<unsigned> inDegree,
+                           const std::vector<std::vector<bool>> &adjM) {
   std::vector<unsigned> redIt;    // reduce iterator with 0 degree
   std::vector<unsigned> parIt;    // parallel iterator with 0 degree
   std::vector<unsigned> filterIt; // filter loop with 0 degree
@@ -399,11 +471,11 @@
   }
 }
 
-static void tryLoosenAffineDenseConstraints(linalg::GenericOp op,
-                                            std::optional<unsigned> &fldx,
-                                            AffineExpr &fa,
-                                            std::optional<unsigned> &tldx,
-                                            AffineExpr &ta) {
+static void tryLoosenAffineConstraints(linalg::GenericOp op,
+                                       std::optional<unsigned> &fldx,
+                                       AffineExpr &fa,
+                                       std::optional<unsigned> &tldx,
+                                       AffineExpr &ta) {
   // We use a heuristic here to only pick one dim expression from each
   // compound affine expression to establish the order between two dense
   // dimensions.
@@ -424,7 +496,7 @@
   }
   if (!ta.isa<AffineDimExpr>()) {
     // Heuristic: we prefer reduction loop for rhs to reduce the chance
-    // addint reduce < parallel ordering.
+    // adding reduce < parallel ordering.
     finder.setPickedIterType(utils::IteratorType::reduction);
     finder.walkPostOrder(ta);
     ta = finder.getDimExpr();
@@ -433,6 +505,181 @@
   }
 }
 
+static void sortArrayBasedOnOrder(std::vector<unsigned> &target,
+                                  ArrayRef<unsigned> order) {
+  std::sort(target.begin(), target.end(), [&order](unsigned l, unsigned r) {
+    assert(l != r);
+    int idxL = -1, idxR = -1;
+    for (int i = 0, e = order.size(); i < e; i++) {
+      if (order[i] == l)
+        idxL = i;
+      if (order[i] == r)
+        idxR = i;
+    }
+    assert(idxL >= 0 && idxR >= 0);
+    return idxL < idxR;
+  });
+}
+
+static void addFilterLoopBasedConstraints(CodegenEnv &env, OpOperand &t,
+                                          OpOperand *skip, unsigned mask,
+                                          std::vector<std::vector<bool>> &adjM,
+                                          std::vector<unsigned> &inDegree) {
+  // Get map and encoding.
+  auto map = env.op().getMatchingIndexingMap(&t);
+  auto enc = getSparseTensorEncoding(t.get().getType());
+
+  // Skip dense tensor constraints when not requested.
+  if (!(mask & SortMask::kIncludeDense) && !enc)
+    return;
+  // Each tensor expression and optional dimension ordering (row-major
+  // by default) puts an ordering constraint on the loop indices. For
+  // example, the tensor expression A_ijk forces the ordering i < j < k
+  // on the loop indices if no explicit dimension ordering is given.
+  for (unsigned l = 0, rank = map.getNumResults(); l < rank; l++) {
+    AffineExpr ta = map.getResult(toOrigDim(enc, l));
+    std::optional<unsigned> tldx =
+        env.merger().getLoopIdx(t.getOperandNumber(), l);
+    // Filter loops should be constructed after all the dependent loops,
+    // i.e., d0 + d1 < filter_loop(d0 + d1)
+    if (tldx && env.merger().isFilterLoop(*tldx)) {
+      assert(!ta.isa<AffineDimExpr>() && !isDenseDLT(enc.getDimLevelType()[l]));
+      addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt, tldx);
+      // Now that the ordering of affine expression is captured by filter
+      // loop idx, we only need to ensure the affine ordering against filter
+      // loop. Thus, we reset the affine expression to nil here to mark it as
+      // resolved.
+      ta = AffineExpr();
+    }
+
+    // Skip tensor during cycle resolution, though order between filter loop
+    // and dependent loops needs to be guaranteed unconditionally.
+    if (&t == skip)
+      continue;
+
+    if (l > 0) {
+      AffineExpr fa = map.getResult(toOrigDim(enc, l - 1));
+      std::optional<unsigned> fldx =
+          env.merger().getLoopIdx(t.getOperandNumber(), l - 1);
+
+      // Applying order constraints on every pair of dimExpr between two
+      // compound affine expressions can sometimes be too strict:
+      // E.g., for [dense, dense] -> (d0 + d1, d2 + d3).
+      // It is totally fine to have loop sequence d0->d2->d1->d3 instead of
+      // requiring d0 < d2, d1 < d2, d0 < d3, d1 < d3.
+      // We also loosen the affine constraint when using the slice-based
+      // algorithm, as there is no filter loop for affine indices on sparse
+      // dimensions.
+      if (!(mask & SortMask::kIncludeDense))
+        tryLoosenAffineConstraints(env.op(), fldx, fa, tldx, ta);
+
+      // (d0 + d1) < (d2 + d3), or
+      // filter_loop_d-1 < (d2 + d3), or
+      // (d0 + d1) < filter_loop_d, or
+      // filter_loop_d-1 < filter_loop_d depending on whether fa/ta is reset
+      // above.
+      addAffineOrderings(adjM, inDegree, fa, ta, fldx, tldx);
+    }
+  }
+}
+
+static void addSliceBasedConstraints(CodegenEnv &env, OpOperand &t,
+                                     OpOperand *skip, unsigned mask,
+                                     std::vector<std::vector<bool>> &adjM,
+                                     std::vector<unsigned> &inDegree) {
+  // Get map and encoding.
+  auto map = env.op().getMatchingIndexingMap(&t);
+  auto enc = getSparseTensorEncoding(t.get().getType());
+
+  // Skip dense tensor constraints when not requested.
+  if (!(mask & SortMask::kIncludeDense) && !enc)
+    return;
+
+  // No special treatment for simple indices.
+  if (getNumCompoundAffineOnSparseDims(map, t.get()) == 0)
+    return addFilterLoopBasedConstraints(env, t, skip, mask, adjM, inDegree);
+
+  // Skip tensor during cycle resolution, though order between filter loop
+  // and dependent loops needs to be guaranteed unconditionally.
+  if (&t == skip)
+    return;
+
+  AffineDimFinder finder(env.op());
+  finder.setPickedIterType(utils::IteratorType::reduction);
+  // To compute the iteration graph for tensor[d0 + d1 + d3, d4 + d5 + d6],
+  // we require that there exist d_x \in {d0, d1, d3} and d_y \in {d4, d5, d6},
+  // and d_x > d_y && {d0, d1, d3} - d_x > {d4, d5, d6} - d_y
+  for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
+    AffineExpr fa = map.getResult(toOrigDim(enc, d - 1));
+    AffineExpr ta = map.getResult(toOrigDim(enc, d));
+
+    // This is a heuristic: we pick an arbitrary reduction loop from lhs and
+    // rhs and use them as d_x and d_y.
+    finder.walkPostOrder(fa);
+    AffineDimExpr fexp = finder.getDimExpr();
+    unsigned fldx = fexp.getPosition();
+
+    finder.walkPostOrder(ta);
+    AffineDimExpr texp = finder.getDimExpr();
+    unsigned tldx = texp.getPosition();
+
+    // d_x > d_y
+    if (!adjM[fldx][tldx]) {
+      adjM[fldx][tldx] = true;
+      inDegree[tldx]++;
+    }
+
+    AffineDimCollector fCollector;
+    fCollector.walkPostOrder(fa);
+    AffineDimCollector tCollector;
+    tCollector.walkPostOrder(ta);
+
+    // Make sure d_x and d_y come last.
+    for (auto fd : fCollector.dims) {
+      unsigned f = fd.getPosition();
+      if (f == fldx)
+        continue;
+      if (!adjM[f][fldx]) {
+        adjM[f][fldx] = true;
+        inDegree[fldx]++;
+      }
+    }
+    for (auto td : tCollector.dims) {
+      unsigned t = td.getPosition();
+      if (t == tldx)
+        continue;
+      if (!adjM[t][tldx]) {
+        adjM[t][tldx] = true;
+        inDegree[tldx]++;
+      }
+    }
+    // Since we only support affine addition, the order between two dim
+    // expressions does not really matter.
+    // {d0, d1, d3} - d_x > {d4, d5, d6} - d_y
+    // This is to ensure that the affine expressions are reduced in sparse
+    // tensor level ordering.
+    // TODO: this ordering could probably be loosened if we support
+    // out-of-order reduction.
+    // TODO: the evaluation order needs to be ensured to support affine
+    // multiplication.
+    for (auto fd : fCollector.dims) {
+      unsigned f = fd.getPosition();
+      if (f == fldx) // skip d_x
        continue;
+
+      for (auto td : tCollector.dims) {
+        unsigned t = td.getPosition();
+        if (t == tldx) // skip d_y
+          continue;
+        if (!adjM[f][t]) {
+          adjM[f][t] = true;
+          inDegree[t]++;
+        }
+      }
+    }
+  }
+}
+
 /// Computes a topologically sorted iteration graph for the linalg
 /// operation. Ensures all tensors are visited in natural index order. This
 /// is essential for sparse storage formats since these only support access
@@ -440,7 +687,7 @@
 /// natural index order yields innermost unit-stride access with better
 /// spatial locality.
 static bool computeIterationGraph(CodegenEnv &env, unsigned mask,
-                                  OpOperand *skip = nullptr) {
+                                  OpOperand *skip, bool useSlice = false) {
   // Set up an n x n from/to adjacency matrix of the iteration graph
   // for the implicit loop indices i_0 .. i_n-1.
   const unsigned n = env.merger().getNumLoops();
@@ -449,67 +696,13 @@
   const auto iteratorTypes = env.op().getIteratorTypesArray();
   // Iterate over the indexing maps of every tensor in the tensor expression.
   for (OpOperand &t : env.op()->getOpOperands()) {
-    // Get map and encoding.
-    const auto map = env.op().getMatchingIndexingMap(&t);
-    const auto enc = getSparseTensorEncoding(t.get().getType());
-    assert(map.getNumDims() + getNumCompoundAffineOnSparseDims(env.op()) == n);
-    // Skip dense tensor constraints when not requested.
-    if (!(mask & SortMask::kIncludeDense) && !enc)
-      continue;
-    // Each tensor expression and optional dimension ordering (row-major
-    // by default) puts an ordering constraint on the loop indices. For
-    // example, the tensor expresion A_ijk forces the ordering i < j < k
-    // on the loop indices if no explicit dimension ordering is given.
-    const Level lvlRank = map.getNumResults();
-    assert(!enc || lvlRank == enc.getLvlRank());
-    for (Level l = 0; l < lvlRank; l++) {
-      // FIXME: `toOrigDim` is deprecated.
-      AffineExpr ta = map.getResult(toOrigDim(enc, l));
-      std::optional<unsigned> tldx =
-          env.merger().getLoopIdx(t.getOperandNumber(), l);
-
-      // Filter loops should be constructed after all the dependent loops,
-      // i.e., d0 + d1 < filter_loop(d0 + d1)
-      if (tldx && env.merger().isFilterLoop(*tldx)) {
-        assert(!ta.isa<AffineDimExpr>() && !isDenseDLT(enc.getLvlType(l)));
-        addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt,
-                           tldx);
-        // Now that the ordering of affine expression is captured by filter
-        // loop idx, we only need to ensure the affine ordering against filter
-        // loop. Thus, we reset the affine express to nil here to mark it as
-        // resolved.
-        ta = AffineExpr();
-      }
-
-      // Skip tensor during cycle resolution, though order between filter loop
-      // and dependent loops need to be guaranteed unconditionally.
-      if (&t == skip)
-        continue;
-
-      if (l > 0) {
-        // FIXME: `toOrigDim` is deprecated.
-        AffineExpr fa = map.getResult(toOrigDim(enc, l - 1));
-        std::optional<unsigned> fldx =
-            env.merger().getLoopIdx(t.getOperandNumber(), l - 1);
-
-        // Applying order constraints on every pair of dimExpr between two
-        // compound affine expressions can sometime too strict:
-        // E.g, for [dense, dense] -> (d0 + d1, d2 + d3).
-        // It is totally fine to have loop sequence d0->d2->d1->d3 instead of
-        // requiring d0 < d2, d1 < d2, d0 < d3, d1 < d3.
-        if (!(mask & SortMask::kIncludeDense))
-          tryLoosenAffineDenseConstraints(env.op(), fldx, fa, tldx, ta);
-
-        // (d0 + d1) < (d2 + d3), or
-        // filter_loop_d-1 < (d2 + d3), or
-        // (d0 + d1) < filter_loop_d, or
-        // filter_loop_d-1 < filter_loop_d depending on whether fa/ta is reset
-        // above.
-        addAffineOrderings(adjM, inDegree, fa, ta, fldx, tldx);
-      }
-    }
     // Push unrelated loops into sparse iteration space, so these
     // will be skipped more often.
+    if (useSlice)
+      addSliceBasedConstraints(env, t, skip, mask, adjM, inDegree);
+    else
+      addFilterLoopBasedConstraints(env, t, skip, mask, adjM, inDegree);
+
     if (mask & SortMask::kIncludeUndef) {
       unsigned tensor = t.getOperandNumber();
       for (unsigned i = 0; i < n; i++) {
@@ -1506,18 +1699,22 @@
   if (op.getNumDpsInits() != 1 || hasCompoundAffineOnSparseOut(op))
     return failure();
 
-  if (options.enableSliceBasedAffine)
-    llvm_unreachable("not yet implemented");
-
   // Sets up a code generation environment.
   unsigned numTensors = op->getNumOperands();
   unsigned numLoops = op.getNumLoops();
-  unsigned numFilterLoops = getNumCompoundAffineOnSparseDims(op);
-  CodegenEnv env(op, options, numTensors, numLoops, numFilterLoops);
+  unsigned numAffineOnSparse = getNumCompoundAffineOnSparseDims(op);
+  // TODO: we should probably always use slice-based codegen whenever
+  // possible; we can even intermix slice-based and filter-loop based codegen.
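+  // E.g., an indexing map (d0, d1) -> (d0 + d1) on a compressed level yields
+  // numAffineOnSparse == 1; with the option enabled, such a kernel is lowered
+  // along the slice-based path and needs no filter loops.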
+  bool sliceBased = options.enableSliceBasedAffine && numAffineOnSparse != 0;
+
+  // If we use the slice-based algorithm for affine indices, we do not need
+  // filter loops.
+  CodegenEnv env(op, options, numTensors, numLoops,
+                 /*numFilterLoops=*/sliceBased ? 0 : numAffineOnSparse);
 
   // Detects sparse annotations and translates the per-dimension sparsity
   // information for all tensors to loop indices in the kernel.
-  if (!findSparseAnnotations(env))
+  if (!findSparseAnnotations(env, sliceBased))
     return failure();
 
   // Constructs the tensor expressions tree from `op`, returns failure if the
@@ -1538,7 +1735,7 @@
   const auto allMask = {SortMask::kIncludeAll, SortMask::kIncludeUndef,
                         SortMask::kIncludeDense, SortMask::kSparseOnly};
   for (auto mask : allMask) {
-    if (computeIterationGraph(env, mask)) {
+    if (computeIterationGraph(env, mask, nullptr, sliceBased)) {
       hasCycle = false;
      if (env.isAdmissibleTopoOrder()) {
        isAdmissible = true;
@@ -1547,11 +1744,24 @@
       // else try a set of less strict constraints.
     }
   }
-  if (hasCycle)
-    return resolveCycle(env, rewriter); // one last shot
+  if (hasCycle) {
+    return sliceBased
+               ? failure() // TODO: should cycle be resolved differently?
+               : resolveCycle(env, rewriter); // one last shot
+  }
+
   if (!isAdmissible)
     return failure(); // inadmissible expression, reject
 
+  for (OpOperand &t : env.op()->getOpOperands()) {
+    unsigned rank = env.op().getMatchingIndexingMap(&t).getNumResults();
+    for (unsigned i = 0; i < rank; i++) {
+      sortArrayBasedOnOrder(
+          env.merger().getRelatedLoops(t.getOperandNumber(), i),
+          env.getTopSort());
+    }
+  }
+
   // Recursively generates code if admissible.
   env.startEmit();
   genBuffers(env, rewriter);
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -214,7 +214,11 @@
       loopIdxToDim(numTensors, std::vector<std::optional<unsigned>>(
                                    numLoops, std::nullopt)),
       dimToLoopIdx(numTensors, std::vector<std::optional<unsigned>>(
-                                   numLoops, std::nullopt)) {}
+                                   numLoops, std::nullopt)),
+      ldxToDependentSlice(numTensors, std::vector<std::optional<unsigned>>(
+                                          numLoops, std::nullopt)),
+      sliceToRelatedldx(numTensors, std::vector<std::vector<unsigned>>(
+                                        numLoops, std::vector<unsigned>())) {}
 
 //===----------------------------------------------------------------------===//
 // Lattice methods.
@@ -752,6 +756,8 @@
       unsigned i = index(b);
       DimLevelType dlt = dimTypes[t][i];
       llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(dlt);
+      if (ldxToDependentSlice[t][i])
+        llvm::dbgs() << "_D_" << *ldxToDependentSlice[t][i];
     }
   }
 }