diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h @@ -30,6 +30,9 @@ LogicalResult bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options); +/// Remove bufferization attributes on all FuncOp arguments in the ModuleOp. +void removeBufferizationAttributesInModule(ModuleOp moduleOp); + /// Run One-Shot Module Bufferization on the given module. Performs a simple /// function call analysis to determine which function arguments are /// inplaceable. Then analyzes and bufferizes FuncOps one-by-one with One-Shot diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -377,6 +377,14 @@ return success(); } +void mlir::bufferization::removeBufferizationAttributesInModule( + ModuleOp moduleOp) { + moduleOp.walk([&](func::FuncOp op) { + for (BlockArgument bbArg : op.getArguments()) + removeBufferizationAttributes(bbArg); + }); +} + LogicalResult mlir::bufferization::bufferizeModuleOp( ModuleOp moduleOp, const OneShotBufferizationOptions &options) { assert(options.bufferizeFunctionBoundaries && @@ -405,10 +413,7 @@ } // Post-pass cleanup of function argument attributes. 
- moduleOp.walk([&](func::FuncOp op) { - for (BlockArgument bbArg : op.getArguments()) - removeBufferizationAttributes(bbArg); - }); + removeBufferizationAttributesInModule(moduleOp); return success(); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" @@ -75,9 +76,14 @@ } return true; }); - return bufferization::bufferizeOp(getOperation(), bufferizationOptions, - /*copyBeforeWrite=*/false, - &denseOpFilter); + + if (failed(bufferization::bufferizeOp(getOperation(), bufferizationOptions, + /*copyBeforeWrite=*/false, + &denseOpFilter))) + return failure(); + + bufferization::removeBufferizationAttributesInModule(getOperation()); + return success(); } void runOnOperation() override {