diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h @@ -163,6 +163,10 @@ void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT); +std::unique_ptr<Pass> createSparseTensorRewritePass(); +std::unique_ptr<Pass> +createSparseTensorRewritePass(const SparsificationOptions &options); + std::unique_ptr<Pass> createDenseBufferizationPass( const bufferization::OneShotBufferizationOptions &options); diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -11,6 +11,27 @@ include "mlir/Pass/PassBase.td" +def SparseTensorRewrite : Pass<"sparse-tensor-rewrite", "ModuleOp"> { + let summary = "Applies sparse tensor rewriting rules prior to sparsification"; + let description = [{ + A pass that applies rewriting rules to sparse tensor operations prior + to running the actual sparsification pass.
+ }]; + let constructor = "mlir::createSparseTensorRewritePass()"; + let dependentDialects = [ + "arith::ArithDialect", + "bufferization::BufferizationDialect", + "linalg::LinalgDialect", + "memref::MemRefDialect", + "scf::SCFDialect", + "sparse_tensor::SparseTensorDialect", + ]; + let options = [ + Option<"enableRuntimeLibrary", "enable-runtime-library", "bool", + "true", "Enable runtime library for manipulating sparse tensors"> + ]; +} + def SparsificationPass : Pass<"sparsification", "ModuleOp"> { let summary = "Automatically generate sparse tensor code from sparse tensor types"; let description = [{ @@ -57,6 +78,7 @@ "arith::ArithDialect", "bufferization::BufferizationDialect", "LLVM::LLVMDialect", + "linalg::LinalgDialect", "memref::MemRefDialect", "scf::SCFDialect", "sparse_tensor::SparseTensorDialect", @@ -193,4 +215,5 @@ "sparse_tensor::SparseTensorDialect", ]; } + #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp --- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp +++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp @@ -58,6 +58,7 @@ /*analysisOnly=*/options.testBufferizationAnalysisOnly))); if (options.testBufferizationAnalysisOnly) return; + pm.addPass(createSparseTensorRewritePass(options.sparsificationOptions())); pm.addPass(createSparsificationPass(options.sparsificationOptions())); if (options.enableRuntimeLibrary) pm.addPass(createSparseTensorConversionPass( diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -21,6 +21,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { +#define 
GEN_PASS_DEF_SPARSETENSORREWRITE #define GEN_PASS_DEF_SPARSIFICATIONPASS #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS #define GEN_PASS_DEF_SPARSETENSORCODEGEN @@ -37,6 +38,26 @@ // Passes implementation. //===----------------------------------------------------------------------===// +/// Pass that greedily applies the sparse tensor rewriting patterns prior to +/// running the actual sparsification pass. +struct SparseTensorRewritePass + : public impl::SparseTensorRewriteBase<SparseTensorRewritePass> { + + SparseTensorRewritePass() = default; + SparseTensorRewritePass(const SparseTensorRewritePass &pass) = default; + // Only the runtime-library flag is forwarded to the rewriting rules. + explicit SparseTensorRewritePass(const SparsificationOptions &options) { + enableRuntimeLibrary = options.enableRuntimeLibrary; + } + + void runOnOperation() override { + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + populateSparseTensorRewriting(patterns, enableRuntimeLibrary); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + struct SparsificationPass : public impl::SparsificationPassBase<SparsificationPass> { @@ -53,14 +74,10 @@ void runOnOperation() override { auto *ctx = &getContext(); - RewritePatternSet prePatterns(ctx); // Translate strategy flags to strategy options. SparsificationOptions options(parallelization, vectorization, vectorLength, enableSIMDIndex32, enableVLAVectorization, enableRuntimeLibrary); - // Apply pre-rewriting. - populateSparseTensorRewriting(prePatterns, options.enableRuntimeLibrary); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(prePatterns)); // Apply sparsification and vector cleanup rewriting. RewritePatternSet patterns(ctx); populateSparsificationPatterns(patterns, options); @@ -236,6 +253,15 @@ // Pass creation methods.
//===----------------------------------------------------------------------===// +std::unique_ptr<Pass> mlir::createSparseTensorRewritePass() { + return std::make_unique<SparseTensorRewritePass>(); +} + +std::unique_ptr<Pass> +mlir::createSparseTensorRewritePass(const SparsificationOptions &options) { + return std::make_unique<SparseTensorRewritePass>(options); +} + std::unique_ptr<Pass> mlir::createSparsificationPass() { return std::make_unique<SparsificationPass>(); } diff --git a/mlir/test/Dialect/SparseTensor/rewriting.mlir b/mlir/test/Dialect/SparseTensor/rewriting.mlir --- a/mlir/test/Dialect/SparseTensor/rewriting.mlir +++ b/mlir/test/Dialect/SparseTensor/rewriting.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -sparsification | FileCheck %s +// RUN: mlir-opt %s -sparse-tensor-rewrite | FileCheck %s #SparseVector = #sparse_tensor.encoding<{ dimLevelType = ["compressed"] diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --sparsification=enable-runtime-library=false | FileCheck %s +// RUN: mlir-opt %s --sparse-tensor-rewrite=enable-runtime-library=false --sparsification | FileCheck %s #DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}> diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --linalg-generalize-named-ops --sparsification --sparse-tensor-conversion --canonicalize --cse | FileCheck %s +// RUN: mlir-opt %s --linalg-generalize-named-ops --sparse-tensor-rewrite --sparsification --sparse-tensor-conversion --canonicalize --cse | FileCheck %s #DCSR = #sparse_tensor.encoding<{ 
dimLevelType = [ "compressed", "compressed" ] }> diff --git
a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --tensor-copy-insertion --sparsification --cse | FileCheck %s +// RUN: mlir-opt %s --tensor-copy-insertion --sparse-tensor-rewrite --sparsification --cse | FileCheck %s #SM = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>