diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
@@ -73,6 +73,16 @@
       *this, "enable-buffer-initialization",
       desc("Enable zero-initialization of memory buffers"),
       init(false)};
+  PassOptions::Option<bool> createSparseDeallocs{
+      *this, "create-sparse-deallocs",
+      desc("Specify if the temporary sparse buffers created by the sparse "
+           "compiler should be deallocated, for compatibility with core "
+           "bufferization passes. This option only takes effect when "
+           "enable-runtime-library=false; otherwise the memory storage for "
+           "sparse tensors is managed by the runtime library. See also "
+           "create-deallocs for BufferizationOption."),
+      init(true)};
+
   PassOptions::Option<uint32_t> vectorLength{
       *this, "vl", desc("Set the vector length (0 disables vectorization)"),
       init(0)};
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -132,11 +132,13 @@
 /// Sets up sparse tensor conversion rules.
 void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                                          RewritePatternSet &patterns,
+                                         bool createSparseDeallocs,
                                          bool enableBufferInitialization);
 
 std::unique_ptr<Pass> createSparseTensorCodegenPass();
 std::unique_ptr<Pass>
-createSparseTensorCodegenPass(bool enableBufferInitialization);
+createSparseTensorCodegenPass(bool createSparseDeallocs,
+                              bool enableBufferInitialization);
 
 //===----------------------------------------------------------------------===//
 // The PreSparsificationRewriting pass.
@@ -180,8 +182,9 @@
     const bufferization::OneShotBufferizationOptions &bufferizationOptions,
     const SparsificationOptions &sparsificationOptions,
     const SparseTensorConversionOptions &sparseTensorConversionOptions,
-    bool enableRuntimeLibrary, bool enableBufferInitialization,
-    unsigned vectorLength, bool enableVLAVectorization, bool enableSIMDIndex32);
+    bool createSparseDeallocs, bool enableRuntimeLibrary,
+    bool enableBufferInitialization, unsigned vectorLength,
+    bool enableVLAVectorization, bool enableSIMDIndex32);
 
 void populateSparseBufferRewriting(RewritePatternSet &patterns,
                                    bool enableBufferInitialization);
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -220,6 +220,13 @@
   let options = [
     Option<"enableBufferInitialization", "enable-buffer-initialization",
            "bool", "false", "Enable zero-initialization of the memory buffers">,
+    Option<"createSparseDeallocs", "create-sparse-deallocs", "bool",
+           "true", "Specify if the temporary sparse buffers created by the "
+                   "sparse compiler should be deallocated, for compatibility "
+                   "with core bufferization passes. This option only takes "
+                   "effect when enable-runtime-library=false; otherwise the "
+                   "memory storage for sparse tensors is managed by the runtime "
+                   "library. See also create-deallocs for BufferizationOption.">,
   ];
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -56,8 +56,8 @@
   pm.addPass(createSparsificationAndBufferizationPass(
       getBufferizationOptions(options.testBufferizationAnalysisOnly),
       options.sparsificationOptions(), options.sparseTensorConversionOptions(),
-      options.enableRuntimeLibrary, options.enableBufferInitialization,
-      options.vectorLength,
+      options.createSparseDeallocs, options.enableRuntimeLibrary,
+      options.enableBufferInitialization, options.vectorLength,
       /*enableVLAVectorization=*/options.armSVE,
       /*enableSIMDIndex32=*/options.force32BitVectorIndices));
   if (options.testBufferizationAnalysisOnly)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -780,6 +780,11 @@
     : public OpConversionPattern<bufferization::DeallocTensorOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
+  SparseTensorDeallocConverter(TypeConverter &typeConverter,
+                               MLIRContext *context, bool createDeallocs)
+      : OpConversionPattern(typeConverter, context),
+        createDeallocs(createDeallocs) {}
+
   LogicalResult
   matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -787,16 +792,22 @@
     if (!enc)
       return failure();
 
-    // Replace the sparse tensor deallocation with field deallocations.
-    Location loc = op.getLoc();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
-    for (auto input : desc.getMemRefFields())
-      // Deallocate every buffer used to store the sparse tensor handler.
-      rewriter.create<memref::DeallocOp>(loc, input);
-
+    // If the user requests not to deallocate sparse tensors, simply erase the
+    // operation.
+    if (createDeallocs) {
+      // Replace the sparse tensor deallocation with field deallocations.
+      Location loc = op.getLoc();
+      auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+      for (auto input : desc.getMemRefFields())
+        // Deallocate every buffer used to store the sparse tensor handler.
+        rewriter.create<memref::DeallocOp>(loc, input);
+    }
     rewriter.eraseOp(op);
     return success();
   }
+
+private:
+  bool createDeallocs;
 };
 
 /// Sparse codegen rule for tensor rematerialization.
@@ -1492,13 +1503,12 @@
 /// the sparsification of linear algebra operations.
 void mlir::populateSparseTensorCodegenPatterns(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
-    bool enableBufferInitialization) {
+    bool createSparseDeallocs, bool enableBufferInitialization) {
   patterns.add<..., SparseSliceGetterOpConverter<...>>(typeConverter,
                                                        patterns.getContext());
+  patterns.add<SparseTensorDeallocConverter>(
+      typeConverter, patterns.getContext(), createSparseDeallocs);
   patterns.add<...>(typeConverter, patterns.getContext(),
                     enableBufferInitialization);
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -181,7 +181,8 @@
 
   SparseTensorCodegenPass() = default;
   SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
-  SparseTensorCodegenPass(bool enableInit) {
+  SparseTensorCodegenPass(bool createDeallocs, bool enableInit) {
+    createSparseDeallocs = createDeallocs;
     enableBufferInitialization = enableInit;
   }
 
@@ -232,8 +233,8 @@
                                                          converter);
     scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                          target);
-    populateSparseTensorCodegenPatterns(converter, patterns,
-                                        enableBufferInitialization);
+    populateSparseTensorCodegenPatterns(
+        converter, patterns, createSparseDeallocs, enableBufferInitialization);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       signalPassFailure();
@@ -378,8 +379,10 @@
 }
 
 std::unique_ptr<Pass>
-mlir::createSparseTensorCodegenPass(bool enableBufferInitialization) {
-  return std::make_unique<SparseTensorCodegenPass>(enableBufferInitialization);
+mlir::createSparseTensorCodegenPass(bool createSparseDeallocs,
+                                    bool enableBufferInitialization) {
+  return std::make_unique<SparseTensorCodegenPass>(createSparseDeallocs,
+                                                   enableBufferInitialization);
 }
 
 std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -56,12 +56,13 @@
       const bufferization::OneShotBufferizationOptions &bufferizationOptions,
       const SparsificationOptions &sparsificationOptions,
       const SparseTensorConversionOptions &sparseTensorConversionOptions,
-      bool enableRuntimeLibrary, bool enableBufferInitialization,
-      unsigned vectorLength, bool enableVLAVectorization,
-      bool enableSIMDIndex32)
+      bool createSparseDeallocs, bool enableRuntimeLibrary,
+      bool enableBufferInitialization, unsigned vectorLength,
+      bool enableVLAVectorization, bool enableSIMDIndex32)
       : bufferizationOptions(bufferizationOptions),
         sparsificationOptions(sparsificationOptions),
         sparseTensorConversionOptions(sparseTensorConversionOptions),
+        createSparseDeallocs(createSparseDeallocs),
         enableRuntimeLibrary(enableRuntimeLibrary),
         enableBufferInitialization(enableBufferInitialization),
         vectorLength(vectorLength),
@@ -147,7 +148,8 @@
       pm.addPass(
           createSparseTensorConversionPass(sparseTensorConversionOptions));
     } else {
-      pm.addPass(createSparseTensorCodegenPass(enableBufferInitialization));
+      pm.addPass(createSparseTensorCodegenPass(createSparseDeallocs,
+                                               enableBufferInitialization));
       pm.addPass(createSparseBufferRewritePass(enableBufferInitialization));
       pm.addPass(createStorageSpecifierToLLVMPass());
     }
@@ -164,6 +166,7 @@
   bufferization::OneShotBufferizationOptions bufferizationOptions;
   SparsificationOptions sparsificationOptions;
   SparseTensorConversionOptions sparseTensorConversionOptions;
+  bool createSparseDeallocs;
   bool enableRuntimeLibrary;
   bool enableBufferInitialization;
   unsigned vectorLength;
@@ -178,13 +181,13 @@
     const bufferization::OneShotBufferizationOptions &bufferizationOptions,
     const SparsificationOptions &sparsificationOptions,
     const SparseTensorConversionOptions &sparseTensorConversionOptions,
-    bool enableRuntimeLibrary, bool enableBufferInitialization,
-    unsigned vectorLength, bool enableVLAVectorization,
-    bool enableSIMDIndex32) {
+    bool createSparseDeallocs, bool enableRuntimeLibrary,
+    bool enableBufferInitialization, unsigned vectorLength,
+    bool enableVLAVectorization, bool enableSIMDIndex32) {
   return std::make_unique<
       mlir::sparse_tensor::SparsificationAndBufferizationPass>(
       bufferizationOptions, sparsificationOptions,
-      sparseTensorConversionOptions, enableRuntimeLibrary,
+      sparseTensorConversionOptions, createSparseDeallocs, enableRuntimeLibrary,
       enableBufferInitialization, vectorLength, enableVLAVectorization,
       enableSIMDIndex32);
 }
diff --git a/mlir/test/Dialect/SparseTensor/codegen_sparse_dealloc.mlir b/mlir/test/Dialect/SparseTensor/codegen_sparse_dealloc.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/codegen_sparse_dealloc.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false" \
+// RUN:   --sparse-tensor-codegen=create-sparse-deallocs=false \
+// RUN:   --canonicalize --cse | FileCheck %s -check-prefix=CHECK-NO-DEALLOC
+
+// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false" \
+// RUN:   --sparse-tensor-codegen=create-sparse-deallocs=true \
+// RUN:   --canonicalize --cse | FileCheck %s -check-prefix=CHECK-DEALLOC
+
+#CSR = #sparse_tensor.encoding<{ dimLevelType = ["dense", "compressed"]}>
+#CSC = #sparse_tensor.encoding<{
+  dimLevelType = ["dense", "compressed"],
+  dimOrdering = affine_map<(i,j) -> (j,i)>
+}>
+
+//
+// No memref.dealloc is created when the user requests no deallocation, so
+// CHECK-NO-DEALLOC-LABEL: @sparse_convert_permuted
+// CHECK-NO-DEALLOC-NOT: memref.dealloc
+//
+// Otherwise memref.dealloc is created to free temporary sparse buffers.
+// CHECK-DEALLOC-LABEL: @sparse_convert_permuted
+// CHECK-DEALLOC: memref.dealloc
+//
+func.func @sparse_convert_permuted(%arg0: tensor) -> tensor {
+  %0 = sparse_tensor.convert %arg0 : tensor to tensor
+  return %0 : tensor
+}
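
Usage sketch (hypothetical, not part of the patch): with create-sparse-deallocs=false the sparse compiler no longer frees the temporary sparse buffers it creates, so a downstream pipeline is expected to schedule a core bufferization deallocation pass itself. The helper below only illustrates that pairing; buildSparseCodegenWithCoreDeallocs is an invented name, the pass ordering mirrors the else-branch in SparsificationAndBufferizationPass.cpp above, and the use of bufferization::createBufferDeallocationPass() to recover the deallocs is an assumption, not something this patch wires up.

// Hypothetical downstream pipeline sketch; see the hedging note above.
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"

static void buildSparseCodegenWithCoreDeallocs(mlir::OpPassManager &pm) {
  using namespace mlir;
  // Lower sparse tensors to their storage buffers, but do not emit
  // memref.dealloc for the temporary sparse buffers
  // (create-sparse-deallocs=false).
  pm.addPass(createSparseTensorCodegenPass(
      /*createSparseDeallocs=*/false, /*enableBufferInitialization=*/false));
  pm.addPass(createSparseBufferRewritePass(/*enableBufferInitialization=*/false));
  pm.addPass(createStorageSpecifierToLLVMPass());
  // Let the core bufferization deallocation pass place the deallocs instead
  // (assumption: it runs on each function after sparse codegen).
  pm.addNestedPass<func::FuncOp>(bufferization::createBufferDeallocationPass());
}

The command-line equivalent of the first half is what the new test exercises: --sparse-tensor-codegen=create-sparse-deallocs=false, followed by whatever deallocation strategy the surrounding pipeline provides.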