[mlir][sparse] Add zero-argument factories for the codegen and buffer-rewrite passes

This patch removes the `bool enableBufferInitialization = false` default argument
from createSparseTensorCodegenPass and createSparseBufferRewritePass. Each gets an
explicit zero-argument overload instead. The new overloads default-construct the
pass (`std::make_unique<...>()`), so the option value now comes from the pass's own
default constructor rather than from a literal `false` in the header. The
one-argument overloads are unchanged and still forward the flag explicitly.

NOTE(review): confirm that the pass definition's default for
enableBufferInitialization is still `false`. Existing no-argument callers should
keep their previous behavior.
NOTE(review): this patch was reconstructed from a copy that was collapsed onto one
line with all template arguments stripped. Confirm the following against the tree
before applying:
  - the leading context line of the .cpp hunk (assumed to be the
    SparseTensorConversionPass `(options)` factory);
  - the continuation-line alignment of the header context lines.

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -133,8 +133,9 @@
                                          RewritePatternSet &patterns,
                                          bool enableBufferInitialization);
 
+std::unique_ptr<Pass> createSparseTensorCodegenPass();
 std::unique_ptr<Pass>
-createSparseTensorCodegenPass(bool enableBufferInitialization = false);
+createSparseTensorCodegenPass(bool enableBufferInitialization);
 
 //===----------------------------------------------------------------------===//
 // The SparseTensorRewriting pass.
@@ -157,8 +158,10 @@
 
 void populateSparseBufferRewriting(RewritePatternSet &patterns,
                                    bool enableBufferInitialization);
+
+std::unique_ptr<Pass> createSparseBufferRewritePass();
 std::unique_ptr<Pass>
-createSparseBufferRewritePass(bool enableBufferInitialization = false);
+createSparseBufferRewritePass(bool enableBufferInitialization);
 
 //===----------------------------------------------------------------------===//
 // Registration.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -283,11 +283,19 @@
   return std::make_unique<SparseTensorConversionPass>(options);
 }
 
+std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
+  return std::make_unique<SparseTensorCodegenPass>();
+}
+
 std::unique_ptr<Pass>
 mlir::createSparseTensorCodegenPass(bool enableBufferInitialization) {
   return std::make_unique<SparseTensorCodegenPass>(enableBufferInitialization);
 }
 
+std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
+  return std::make_unique<SparseBufferRewritePass>();
+}
+
 std::unique_ptr<Pass>
 mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
   return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);