diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h @@ -138,16 +138,25 @@ createSparseTensorCodegenPass(bool enableBufferInitialization); //===----------------------------------------------------------------------===// -// The SparseTensorRewriting pass. +// The SparseTensorPreRewriting pass. //===----------------------------------------------------------------------===// -void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT, - bool enableForeach, bool enableConvert); +void populateSparseTensorPreRewriting(RewritePatternSet &patterns); -std::unique_ptr<Pass> createSparseTensorRewritePass(); -std::unique_ptr<Pass> createSparseTensorRewritePass(bool enableRT, - bool enableForeach = true, - bool enableConvert = true); +std::unique_ptr<Pass> createSparseTensorPreRewritePass(); + +//===----------------------------------------------------------------------===// +// The SparseTensorPostRewriting pass. +//===----------------------------------------------------------------------===// + +void populateSparseTensorPostRewriting(RewritePatternSet &patterns, + bool enableRT, bool enableForeach, + bool enableConvert); + +std::unique_ptr<Pass> createSparseTensorPostRewritePass(); +std::unique_ptr<Pass> +createSparseTensorPostRewritePass(bool enableRT, bool enableForeach = true, + bool enableConvert = true); //===----------------------------------------------------------------------===// // Other rewriting rules and passes. diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -11,13 +11,13 @@ include "mlir/Pass/PassBase.td" -def SparseTensorRewrite : Pass<"sparse-tensor-rewrite", "ModuleOp"> { +def SparseTensorPreRewrite : Pass<"sparse-tensor-pre-rewrite", "ModuleOp"> { let summary = "Applies sparse tensor rewriting rules prior to sparsification"; let description = [{ A pass that applies rewriting rules to sparse tensor operations prior to running the actual sparsification pass. }]; - let constructor = "mlir::createSparseTensorRewritePass()"; + let constructor = "mlir::createSparseTensorPreRewritePass()"; let dependentDialects = [ "arith::ArithDialect", "bufferization::BufferizationDialect", @@ -26,14 +26,6 @@ "scf::SCFDialect", "sparse_tensor::SparseTensorDialect", ]; - let options = [ - Option<"enableRuntimeLibrary", "enable-runtime-library", "bool", - "true", "Enable runtime library for manipulating sparse tensors">, - Option<"enableForeach", "enable-foreach", "bool", - "true", "Enable rewriting rules for the foreach operator">, - Option<"enableConvert", "enable-convert", "bool", - "true", "Enable rewriting rules for the convert operator">, - ]; } def SparsificationPass : Pass<"sparsification", "ModuleOp"> { @@ -109,6 +101,31 @@ ]; } +def SparseTensorPostRewrite : Pass<"sparse-tensor-post-rewrite", "ModuleOp"> { + let summary = "Applies sparse tensor rewriting rules after sparsification"; + let description = [{ + A pass that applies rewriting rules to sparse tensor operations after + running the actual sparsification pass. + }]; + let constructor = "mlir::createSparseTensorPostRewritePass()"; + let dependentDialects = [ + "arith::ArithDialect", + "bufferization::BufferizationDialect", + "linalg::LinalgDialect", + "memref::MemRefDialect", + "scf::SCFDialect", + "sparse_tensor::SparseTensorDialect", + ]; + let options = [ + Option<"enableRuntimeLibrary", "enable-runtime-library", "bool", + "true", "Enable runtime library for manipulating sparse tensors">, + Option<"enableForeach", "enable-foreach", "bool", + "true", "Enable rewriting rules for the foreach operator">, + Option<"enableConvert", "enable-convert", "bool", + "true", "Enable rewriting rules for the convert operator">, + ]; +} + def SparseTensorConversionPass : Pass<"sparse-tensor-conversion", "ModuleOp"> { let summary = "Convert sparse tensors and primitives to library calls"; let description = [{ diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp --- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp +++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp @@ -57,8 +57,9 @@ /*analysisOnly=*/options.testBufferizationAnalysisOnly))); if (options.testBufferizationAnalysisOnly) return; - pm.addPass(createSparseTensorRewritePass(options.enableRuntimeLibrary)); + pm.addPass(createSparseTensorPreRewritePass()); pm.addPass(createSparsificationPass(options.sparsificationOptions())); + pm.addPass(createSparseTensorPostRewritePass(options.enableRuntimeLibrary)); if (options.enableRuntimeLibrary) { pm.addPass(createSparseTensorConversionPass( options.sparseTensorConversionOptions())); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -21,8 +21,9 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { -#define GEN_PASS_DEF_SPARSETENSORREWRITE +#define GEN_PASS_DEF_SPARSETENSORPREREWRITE #define GEN_PASS_DEF_SPARSIFICATIONPASS +#define GEN_PASS_DEF_SPARSETENSORPOSTREWRITE #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS #define GEN_PASS_DEF_SPARSETENSORCODEGEN #define GEN_PASS_DEF_SPARSEBUFFERREWRITE @@ -38,22 +39,16 @@ // Passes implementation. //===----------------------------------------------------------------------===// -struct SparseTensorRewritePass - : public impl::SparseTensorRewriteBase<SparseTensorRewritePass> { +struct SparseTensorPreRewritePass + : public impl::SparseTensorPreRewriteBase<SparseTensorPreRewritePass> { - SparseTensorRewritePass() = default; - SparseTensorRewritePass(const SparseTensorRewritePass &pass) = default; - SparseTensorRewritePass(bool enableRT, bool foreach, bool convert) { - enableRuntimeLibrary = enableRT; - enableForeach = foreach; - enableConvert = convert; - } + SparseTensorPreRewritePass() = default; + SparseTensorPreRewritePass(const SparseTensorPreRewritePass &pass) = default; void runOnOperation() override { auto *ctx = &getContext(); RewritePatternSet patterns(ctx); - populateSparseTensorRewriting(patterns, enableRuntimeLibrary, enableForeach, - enableConvert); + populateSparseTensorPreRewriting(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; @@ -80,6 +75,27 @@ } }; +struct SparseTensorPostRewritePass + : public impl::SparseTensorPostRewriteBase<SparseTensorPostRewritePass> { + + SparseTensorPostRewritePass() = default; + SparseTensorPostRewritePass(const SparseTensorPostRewritePass &pass) = + default; + SparseTensorPostRewritePass(bool enableRT, bool foreach, bool convert) { + enableRuntimeLibrary = enableRT; + enableForeach = foreach; + enableConvert = convert; + } + + void runOnOperation() override { + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + populateSparseTensorPostRewriting(patterns, enableRuntimeLibrary, + enableForeach, enableConvert); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + struct SparseTensorConversionPass : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> { @@ -254,15 +270,8 @@ // Pass creation methods. //===----------------------------------------------------------------------===// -std::unique_ptr<Pass> mlir::createSparseTensorRewritePass() { - return std::make_unique<SparseTensorRewritePass>(); -} - -std::unique_ptr<Pass> mlir::createSparseTensorRewritePass(bool enableRT, - bool enableForeach, - bool enableConvert) { - return std::make_unique<SparseTensorRewritePass>(enableRT, enableForeach, - enableConvert); +std::unique_ptr<Pass> mlir::createSparseTensorPreRewritePass() { + return std::make_unique<SparseTensorPreRewritePass>(); } std::unique_ptr<Pass> mlir::createSparsificationPass() { @@ -274,6 +283,17 @@ return std::make_unique<SparsificationPass>(options); } +std::unique_ptr<Pass> mlir::createSparseTensorPostRewritePass() { + return std::make_unique<SparseTensorPostRewritePass>(); +} + +std::unique_ptr<Pass> +mlir::createSparseTensorPostRewritePass(bool enableRT, bool enableForeach, + bool enableConvert) { + return std::make_unique<SparseTensorPostRewritePass>(enableRT, enableForeach, + enableConvert); +} + std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() { return std::make_unique<SparseTensorConversionPass>(); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -1019,11 +1019,15 @@ //===---------------------------------------------------------------------===// // Methods that add patterns described in this file to a pattern list. //===---------------------------------------------------------------------===// -void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns, - bool enableRT, bool enableForeach, - bool enableConvert) { - patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, - ReshapeRewriter<tensor::ExpandShapeOp>, +void mlir::populateSparseTensorPreRewriting(RewritePatternSet &patterns) { + patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd>( + patterns.getContext()); +} + +void mlir::populateSparseTensorPostRewriting(RewritePatternSet &patterns, + bool enableRT, bool enableForeach, + bool enableConvert) { + patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>, ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext()); if (enableForeach) patterns.add<ForeachRewriter>(patterns.getContext()); diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir --- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir +++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize --cse | FileCheck %s -// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-foreach=false" \ +// RUN: mlir-opt %s --sparse-tensor-post-rewrite="enable-runtime-library=false enable-foreach=false" \ // RUN: --canonicalize --cse | FileCheck %s --check-prefix=CHECK-RWT #SparseVector = #sparse_tensor.encoding<{ diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir --- a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir +++ b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir @@ -1,6 +1,6 @@ // RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize --cse | FileCheck %s -// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-foreach=false" \ +// RUN: mlir-opt %s --sparse-tensor-post-rewrite="enable-runtime-library=false enable-foreach=false" \ // RUN: --canonicalize --cse | FileCheck %s --check-prefix=CHECK-RWT #SparseVector = #sparse_tensor.encoding<{ diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir --- a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir +++ b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir @@ -6,7 +6,7 @@ // RUN: mlir-opt %s --sparse-tensor-conversion="s2s-strategy=0" \ // RUN: --canonicalize --cse | FileCheck %s -check-prefixes=CHECK-AUTO,CHECK -// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-foreach=false" \ +// RUN: mlir-opt %s --sparse-tensor-post-rewrite="enable-runtime-library=false enable-foreach=false" \ // RUN: --canonicalize --cse | FileCheck %s --check-prefix=CHECK-RWT #SparseVector64 = #sparse_tensor.encoding<{ diff --git a/mlir/test/Dialect/SparseTensor/rewriting.mlir b/mlir/test/Dialect/SparseTensor/rewriting.mlir --- a/mlir/test/Dialect/SparseTensor/rewriting.mlir +++ b/mlir/test/Dialect/SparseTensor/rewriting.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -sparse-tensor-rewrite | FileCheck %s +// RUN: mlir-opt %s -sparse-tensor-post-rewrite | FileCheck %s #SparseVector = #sparse_tensor.encoding<{ dimLevelType = ["compressed"] diff --git a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir --- a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir +++ b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -sparse-tensor-rewrite="enable-runtime-library=false enable-convert=false" |\ +// RUN: mlir-opt %s -sparse-tensor-post-rewrite="enable-runtime-library=false enable-convert=false" |\ // RUN: FileCheck %s #CSR = #sparse_tensor.encoding<{ diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-convert=false" \ +// RUN: mlir-opt %s --sparse-tensor-post-rewrite="enable-runtime-library=false enable-convert=false" \ // RUN: --sparsification | FileCheck %s #DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}> diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --linalg-generalize-named-ops --sparse-tensor-rewrite --sparsification --sparse-tensor-conversion --canonicalize --cse | FileCheck %s +// RUN: mlir-opt %s --linalg-generalize-named-ops --sparse-tensor-pre-rewrite --sparsification --sparse-tensor-conversion --canonicalize --cse | FileCheck %s #DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }> diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir @@ -1,6 +1,6 @@ // RUN: mlir-opt %s | mlir-opt | FileCheck %s --check-prefix=CHECK-ROUND // RUN: mlir-opt %s --sparse-tensor-conversion --cse --canonicalize | FileCheck %s --check-prefix=CHECK-CONV -// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-convert=false" \ +// RUN: mlir-opt %s --sparse-tensor-post-rewrite="enable-runtime-library=false enable-convert=false" \ // RUN: --cse --canonicalize | FileCheck %s --check-prefix=CHECK-RWT #SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }> diff --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --tensor-copy-insertion --sparse-tensor-rewrite --sparsification --cse | FileCheck %s +// RUN: mlir-opt %s --tensor-copy-insertion --sparse-tensor-pre-rewrite --sparsification --cse | FileCheck %s #SM = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>