diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
@@ -63,6 +63,10 @@
       *this, "test-bufferization-analysis-only",
       desc("Run only the inplacability analysis"), init(false)};
 
+  PassOptions::Option<bool> enableBufferInitialization{
+      *this, "enable-buffer-initialization",
+      desc("Enable zero-initialization of memory buffers"), init(false)};
+
   /// Projects out the options for `createSparsificationPass`.
   SparsificationOptions sparsificationOptions() const {
     return SparsificationOptions(parallelization);
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -153,8 +153,13 @@
 std::unique_ptr<Pass> createDenseBufferizationPass(
     const bufferization::OneShotBufferizationOptions &options);
 
-void populateSparseBufferRewriting(RewritePatternSet &patterns);
-std::unique_ptr<Pass> createSparseBufferRewritePass();
+/// Adds the sparse buffer rewriting patterns to `patterns`. When
+/// `enableBufferInitialization` is set, the push_back rewriting zero-fills
+/// the unused tail of every buffer it reallocates.
+void populateSparseBufferRewriting(RewritePatternSet &patterns,
+                                   bool enableBufferInitialization = false);
+std::unique_ptr<Pass>
+createSparseBufferRewritePass(bool enableBufferInitialization = false);
 
 //===----------------------------------------------------------------------===//
 // Registration.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -198,6 +198,10 @@
     "scf::SCFDialect",
     "sparse_tensor::SparseTensorDialect",
   ];
+  let options = [
+    Option<"enableBufferInitialization", "enable-buffer-initialization", "bool",
+      "false", "Enable zero-initialization of the memory buffers">,
+  ];
 }
 
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -65,7 +65,7 @@
         options.sparseTensorConversionOptions()));
   else
     pm.addPass(createSparseTensorCodegenPass());
-  pm.addPass(createSparseBufferRewritePass());
+  pm.addPass(createSparseBufferRewritePass(options.enableBufferInitialization));
   pm.addPass(createDenseBufferizationPass(
       getBufferizationOptions(/*analysisOnly=*/false)));
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -635,6 +635,8 @@
 struct PushBackRewriter : OpRewritePattern<PushBackOp> {
 public:
   using OpRewritePattern<PushBackOp>::OpRewritePattern;
+  PushBackRewriter(MLIRContext *context, bool enableInit)
+      : OpRewritePattern(context), enableBufferInitialization(enableInit) {}
   LogicalResult matchAndRewrite(PushBackOp op,
                                 PatternRewriter &rewriter) const override {
     // Rewrite push_back(buffer, value, n) to:
@@ -705,6 +707,17 @@
 
     Value newBuffer =
         rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
+    if (enableBufferInitialization) {
+      // Zero-fill [newSize, capacity) of the reallocated buffer. constantZero
+      // (unlike Builder::getZeroAttr) also handles complex element types.
+      Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
+      Value fillValue = constantZero(rewriter, loc, value.getType());
+      Value subBuffer = rewriter.create<memref::SubViewOp>(
+          loc, newBuffer, /*offset=*/ValueRange{newSize},
+          /*size=*/ValueRange{fillSize},
+          /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
+      rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer);
+    }
     rewriter.create<scf::YieldOp>(loc, newBuffer);
 
     // False branch.
@@ -731,6 +744,11 @@
     rewriter.replaceOp(op, buffer);
     return success();
   }
+
+private:
+  // Default-initialized because the inherited OpRewritePattern constructors
+  // do not set it.
+  bool enableBufferInitialization = false;
 };
 
 /// Sparse rewriting rule for the sort operator.
@@ -777,6 +795,9 @@
 // Methods that add patterns described in this file to a pattern list.
 //===---------------------------------------------------------------------===//
 
-void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns) {
-  patterns.add<PushBackRewriter, SortRewriter>(patterns.getContext());
+void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
+                                         bool enableBufferInitialization) {
+  patterns.add<PushBackRewriter>(patterns.getContext(),
+                                 enableBufferInitialization);
+  patterns.add<SortRewriter>(patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -215,11 +215,14 @@
 
   SparseBufferRewritePass() = default;
   SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
+  explicit SparseBufferRewritePass(bool enableInit) {
+    enableBufferInitialization = enableInit;
+  }
 
   void runOnOperation() override {
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
-    populateSparseBufferRewriting(patterns);
+    populateSparseBufferRewriting(patterns, enableBufferInitialization);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
@@ -279,6 +282,7 @@
   return std::make_unique<SparseTensorCodegenPass>();
 }
 
-std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
-  return std::make_unique<SparseBufferRewritePass>();
+std::unique_ptr<Pass>
+mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
+  return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
 }