diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h @@ -158,14 +158,21 @@ std::unique_ptr createSparseTensorCodegenPass(); //===----------------------------------------------------------------------===// -// Other rewriting rules and passes. +// The SparseTensorRewriting pass. //===----------------------------------------------------------------------===// -void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT); +void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT, + bool enableForeach, bool enableConvert); std::unique_ptr createSparseTensorRewritePass(); std::unique_ptr -createSparseTensorRewritePass(const SparsificationOptions &options); +createSparseTensorRewritePass(const SparsificationOptions &options, + bool enableForeach = true, + bool enableConvert = true); + +//===----------------------------------------------------------------------===// +// Other rewriting rules and passes. +//===----------------------------------------------------------------------===// std::unique_ptr createDenseBufferizationPass( const bufferization::OneShotBufferizationOptions &options); diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -28,7 +28,11 @@ ]; let options = [ Option<"enableRuntimeLibrary", "enable-runtime-library", "bool", - "true", "Enable runtime library for manipulating sparse tensors"> + "true", "Enable runtime library for manipulating sparse tensors">, + Option<"enableForeach", "enable-foreach", "bool", + "true", "Enable rewriting rules for the foreach operator">, + Option<"enableConvert", "enable-convert", "bool", + "true", "Enable rewriting rules for the convert operator">, ]; } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -43,14 +43,18 @@ SparseTensorRewritePass() = default; SparseTensorRewritePass(const SparseTensorRewritePass &pass) = default; - SparseTensorRewritePass(const SparsificationOptions &options) { + SparseTensorRewritePass(const SparsificationOptions &options, bool foreach, + bool convert) { enableRuntimeLibrary = options.enableRuntimeLibrary; + enableForeach = foreach; + enableConvert = convert; } void runOnOperation() override { auto *ctx = &getContext(); RewritePatternSet patterns(ctx); - populateSparseTensorRewriting(patterns, enableRuntimeLibrary); + populateSparseTensorRewriting(patterns, enableRuntimeLibrary, enableForeach, + enableConvert); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; @@ -255,8 +259,10 @@ } std::unique_ptr -mlir::createSparseTensorRewritePass(const SparsificationOptions &options) { - return std::make_unique(options); +mlir::createSparseTensorRewritePass(const SparsificationOptions &options, + bool enableForeach, bool enableConvert) { + return std::make_unique(options, enableForeach, + enableConvert); } std::unique_ptr mlir::createSparsificationPass() { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -612,11 +612,14 @@ // Methods that add patterns described in this file to a pattern list. //===---------------------------------------------------------------------===// void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns, - bool enableRT) { + bool enableRT, bool enableForeach, + bool /*enableConvert*/) { patterns.add, - ReshapeRewriter, ForeachRewriter>( - patterns.getContext()); + ReshapeRewriter>(patterns.getContext()); + if (enableForeach) + patterns.add(patterns.getContext()); + // TODO: If RT not enabled, rewrite concatenate ops, etc here. if (!enableRT) patterns.add