diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
@@ -53,6 +53,13 @@
                      "any-storage-any-loop",
                      "Enable sparse parallelization for any storage and loop."))};
 
+  PassOptions::Option<bool> enableIndexReduction{
+      *this, "enable-index-reduction",
+      desc("Enable dependent index reduction based algorithm to handle "
+           "non-trivial index expressions on sparse inputs (experimental "
+           "features)"),
+      init(false)};
+
   PassOptions::Option<bool> enableRuntimeLibrary{
       *this, "enable-runtime-library",
      desc("Enable runtime library for manipulating sparse tensors"),
@@ -108,7 +115,7 @@
 
   /// Projects out the options for `createSparsificationPass`.
   SparsificationOptions sparsificationOptions() const {
-    return SparsificationOptions(parallelization);
+    return SparsificationOptions(parallelization, enableIndexReduction);
   }
 
   /// Projects out the options for `createSparseTensorConversionPass`.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -49,11 +49,12 @@
 
 /// Options for the Sparsification pass.
 struct SparsificationOptions {
-  SparsificationOptions(SparseParallelizationStrategy p)
-      : parallelizationStrategy(p) {}
+  SparsificationOptions(SparseParallelizationStrategy p, bool idxReduc)
+      : parallelizationStrategy(p), enableIndexReduction(idxReduc) {}
   SparsificationOptions()
-      : SparsificationOptions(SparseParallelizationStrategy::kNone) {}
+      : SparsificationOptions(SparseParallelizationStrategy::kNone, false) {}
   SparseParallelizationStrategy parallelizationStrategy;
+  bool enableIndexReduction;
 };
 
 /// Sets up sparsification rewriting rules with the given options.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -81,6 +81,9 @@
   ];
   // TODO(57514): These enum options are duplicated in Passes.h.
   let options = [
+    Option<"enableIndexReduction", "enable-index-reduction", "bool",
+           "false",
+           "Enable dependent index reduction based algorithm to handle non-trivial index expressions on sparse inputs (experimental features)">,
     Option<"parallelization", "parallelization-strategy", "mlir::SparseParallelizationStrategy",
            "mlir::SparseParallelizationStrategy::kNone",
            "Set the parallelization strategy", [{llvm::cl::values(
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -63,12 +63,13 @@
   SparsificationPass(const SparsificationPass &pass) = default;
   SparsificationPass(const SparsificationOptions &options) {
     parallelization = options.parallelizationStrategy;
+    enableIndexReduction = options.enableIndexReduction;
   }
 
   void runOnOperation() override {
     auto *ctx = &getContext();
     // Translate strategy flags to strategy options.
-    SparsificationOptions options(parallelization);
+    SparsificationOptions options(parallelization, enableIndexReduction);
     // Apply sparsification and cleanup rewriting.
     RewritePatternSet patterns(ctx);
     populateSparsificationPatterns(patterns, options);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1348,75 +1348,75 @@
   unsigned numloopCond = 0;
   bool hasNonUnique = false;
-  env.merger().foreachTensorLoopId(
-      li, [&, ldx](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
-                   DimLevelType dlt) {
-        if (simple.test(b)) {
-          if (isUndefDLT(dlt)) {
-            // An undefined dlt in the lattices, we probably mean to
-            // iterate based on the level of output tensor. E.g., this
-            // could be a synthetic tensor (for invariants and sparse
-            // output tensor).
-            // out[i][j] = invariant; or a broadcast
-            // out[i][j] = in[i] (j is undef for input)
-            tid = outTid;
-            lvl = outLvl;
-            // Skips invalid lvl (e.g., when this is a zero ranked tensor).
-            if (!lvl)
-              return;
-          }
-          hasNonUnique = !isUniqueDLT(dlt) || hasNonUnique;
-          tids.push_back(tid);
-          lvls.push_back(*lvl);
-          numloopCond++;
-        } else if (isDenseDLT(dlt)) {
-          tids.push_back(tid);
-          lvls.push_back(*lvl);
-        } else {
-          assert(isUndefDLT(dlt));
-          linalg::GenericOp op = env.op();
-          if (tid >= op.getNumDpsInputs())
-            // We only handle affine expression on input tensors (for now).
-            return;
-          OpOperand *operand = &op->getOpOperand(tid);
-          const auto stt = getSparseTensorType(operand->get());
-          // Non-annotated dense tensors requires no special handling.
-          if (!stt.hasEncoding())
-            return;
-
-          ArrayRef<AffineExpr> affines =
-              op.getMatchingIndexingMap(operand).getResults();
-          const Level lvlRank = stt.getLvlRank();
-          assert(affines.size() == static_cast<size_t>(lvlRank));
-          for (Level l = 0; l < lvlRank; l++) {
-            // FIXME: `toOrigDim` is deprecated.
-            AffineExpr exp = affines[toOrigDim(stt.getEncoding(), l)];
-            // Skip simple affine expression and non-dense levels (which
-            // have their own filter loop).
-            if (exp.isa<AffineDimExpr>() || !stt.isDenseLvl(l))
-              continue;
-
-            // Constant affine expression are handled in genLoop
-            if (!exp.isa<AffineConstantExpr>()) {
-              bool isAtLoop = false;
-              if (isInvariantAffine(env, exp, ldx, isAtLoop) && isAtLoop) {
-                // If the compound affine is invariant and we are right at the
-                // level. We need to generate the address according to the
-                // affine expression. This is also the best place we can do it
-                // to avoid putting it inside inner loops.
-                // NOTE: It assumes that the levels of the input tensor are
-                // initialized in order (and it is also currently guaranteed by
-                // computeIterationGraph), another more admissible approach
-                // might be accepting out-of-order access between consecutive
-                // dense levels.
-                affineTids.push_back(tid);
-                affineLvls.push_back(l);
-                exps.push_back(exp);
-              }
-            }
+  env.merger().foreachTensorLoopId(li, [&, ldx](TensorLoopId b, TensorId tid,
+                                                std::optional<Level> lvl,
+                                                DimLevelType dlt) {
+    if (simple.test(b)) {
+      if (isUndefDLT(dlt)) {
+        // An undefined dlt in the lattices, we probably mean to
+        // iterate based on the level of output tensor. E.g., this
+        // could be a synthetic tensor (for invariants and sparse
+        // output tensor).
+        // out[i][j] = invariant; or a broadcast
+        // out[i][j] = in[i] (j is undef for input)
+        tid = outTid;
+        lvl = outLvl;
+        // Skips invalid lvl (e.g., when this is a zero ranked tensor).
+        if (!lvl)
+          return;
+      }
+      hasNonUnique = !isUniqueDLT(dlt) || hasNonUnique;
+      tids.push_back(tid);
+      lvls.push_back(*lvl);
+      numloopCond++;
+    } else if (isDenseDLT(dlt)) {
+      tids.push_back(tid);
+      lvls.push_back(*lvl);
+    } else {
+      assert(isUndefDLT(dlt));
+      linalg::GenericOp op = env.op();
+      if (tid >= op.getNumDpsInputs())
+        // We only handle affine expression on input tensors (for now).
+        return;
+      OpOperand *operand = &op->getOpOperand(tid);
+      const auto stt = getSparseTensorType(operand->get());
+      // Non-annotated dense tensors requires no special handling.
+      if (!stt.hasEncoding())
+        return;
+
+      ArrayRef<AffineExpr> affines =
+          op.getMatchingIndexingMap(operand).getResults();
+      const Level lvlRank = stt.getLvlRank();
+      assert(affines.size() == static_cast<size_t>(lvlRank));
+      for (Level l = 0; l < lvlRank; l++) {
+        // FIXME: `toOrigDim` is deprecated.
+        AffineExpr exp = affines[toOrigDim(stt.getEncoding(), l)];
+        // Skip simple affine expression and non-dense levels (which
+        // have their own filter loop).
+        if (exp.isa<AffineDimExpr>() || !stt.isDenseLvl(l))
+          continue;
+
+        // Constant affine expression are handled in genLoop
+        if (!exp.isa<AffineConstantExpr>()) {
+          bool isAtLoop = false;
+          if (isInvariantAffine(env, exp, ldx, isAtLoop) && isAtLoop) {
+            // If the compound affine is invariant and we are right at the
+            // level. We need to generate the address according to the
+            // affine expression. This is also the best place we can do it
+            // to avoid putting it inside inner loops.
+            // NOTE: It assumes that the levels of the input tensor are
+            // initialized in order (and it is also currently guaranteed by
+            // computeIterationGraph), another more admissible approach
+            // might be accepting out-of-order access between consecutive
+            // dense levels.
+            affineTids.push_back(tid);
+            affineLvls.push_back(l);
+            exps.push_back(exp);
           }
         }
-      });
+      }
+    }
+  });
 
   if (isDenseDLT(env.dlt(outTid, ldx))) {
     // Note that we generate dense indices of the output tensor
@@ -1599,6 +1599,9 @@
   if (op.getNumDpsInits() != 1 || hasCompoundAffineOnSparseOut(op))
     return failure();
 
+  if (options.enableIndexReduction)
+    llvm_unreachable("not yet implemented");
+
   // Sets up a code generation environment.
   const unsigned numTensors = op->getNumOperands();
   const unsigned numLoops = op.getNumLoops();
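
A minimal sketch of driving the new option through the C++ pattern API, using only the signatures introduced above; the wrapper function and the pattern-set plumbing around it are illustrative, not part of this change:

```cpp
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"

// Hypothetical helper: populates sparsification patterns with the new flag.
static void addSparsificationPatterns(mlir::RewritePatternSet &patterns,
                                      bool enableIndexReduction) {
  // Two-argument constructor added by this patch: keep the default (serial)
  // parallelization strategy and opt in to the experimental index-reduction
  // algorithm.
  mlir::SparsificationOptions options(
      mlir::SparseParallelizationStrategy::kNone, enableIndexReduction);
  mlir::populateSparsificationPatterns(patterns, options);
}
```

Note that with this patch alone, setting the flag makes the sparsifier hit the `llvm_unreachable("not yet implemented")` guard added in `Sparsification.cpp`, so the option is a staging hook for the follow-up implementation. The same switch is surfaced as `enable-index-reduction` on both the sparsification pass options (`Passes.td`) and the sparse-compiler pipeline options (`Pipelines/Passes.h`).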