diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
@@ -53,6 +53,12 @@
               "any-storage-any-loop",
               "Enable sparse parallelization for any storage and loop."))};
 
+  PassOptions::Option<bool> enableSliceBasedAffine{
+      *this, "enable-slice-affine",
+      desc("Enable (experimental) slice-based algorithm to generate affine "
+           "indices on sparse inputs"),
+      init(false)};
+
   PassOptions::Option<bool> enableRuntimeLibrary{
       *this, "enable-runtime-library",
       desc("Enable runtime library for manipulating sparse tensors"),
@@ -108,7 +114,7 @@
 
   /// Projects out the options for `createSparsificationPass`.
   SparsificationOptions sparsificationOptions() const {
-    return SparsificationOptions(parallelization);
+    return SparsificationOptions(parallelization, enableSliceBasedAffine);
   }
 
   /// Projects out the options for `createSparseTensorConversionPass`.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -49,11 +49,13 @@
 
 /// Options for the Sparsification pass.
 struct SparsificationOptions {
-  SparsificationOptions(SparseParallelizationStrategy p)
-      : parallelizationStrategy(p) {}
+  SparsificationOptions(SparseParallelizationStrategy p, bool sliceBasedAffine)
+      : parallelizationStrategy(p), enableSliceBasedAffine(sliceBasedAffine) {}
   SparsificationOptions()
-      : SparsificationOptions(SparseParallelizationStrategy::kNone) {}
+      : SparsificationOptions(SparseParallelizationStrategy::kNone, false) {}
   SparseParallelizationStrategy parallelizationStrategy;
+  /// Enables the (experimental) slice-based affine index generation.
+  bool enableSliceBasedAffine;
 };
 
 /// Sets up sparsification rewriting rules with the given options.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -81,6 +81,8 @@
   ];
   // TODO(57514): These enum options are duplicated in Passes.h.
   let options = [
+    Option<"enableSliceBasedAffine", "enable-slice-affine", "bool",
+           "false", "Enable (experimental) slice-based algorithm to generate affine indices on sparse inputs">,
     Option<"parallelization", "parallelization-strategy", "mlir::SparseParallelizationStrategy",
            "mlir::SparseParallelizationStrategy::kNone",
            "Set the parallelization strategy", [{llvm::cl::values(
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -63,12 +63,13 @@
   SparsificationPass(const SparsificationPass &pass) = default;
   SparsificationPass(const SparsificationOptions &options) {
     parallelization = options.parallelizationStrategy;
+    enableSliceBasedAffine = options.enableSliceBasedAffine;
   }
 
   void runOnOperation() override {
     auto *ctx = &getContext();
     // Translate strategy flags to strategy options.
-    SparsificationOptions options(parallelization);
+    SparsificationOptions options(parallelization, enableSliceBasedAffine);
     // Apply sparsification and cleanup rewriting.
     RewritePatternSet patterns(ctx);
     populateSparsificationPatterns(patterns, options);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1449,6 +1449,12 @@
     if (op.getNumDpsInits() != 1 || hasCompoundAffineOnSparseOut(op))
       return failure();
 
+    // The slice-based algorithm is experimental and not implemented yet.
+    // The flag is user-settable, so this branch is reachable: abort with a
+    // fatal error rather than llvm_unreachable.
+    if (options.enableSliceBasedAffine)
+      llvm::report_fatal_error("slice-based affine is not yet implemented");
+
     // Sets up a code generation environment.
     unsigned numTensors = op->getNumOperands();
     unsigned numLoops = op.getNumLoops();