diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
@@ -127,6 +127,9 @@
 FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp,
                                          uint64_t alignment);
 
+/// Remove bufferization attributes from all FuncOp arguments in the ModuleOp.
+void removeBufferizationAttributesInModule(ModuleOp moduleOp);
+
 } // namespace bufferization
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -11,7 +11,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/IR/Operation.h"
@@ -196,3 +198,23 @@
   global->moveBefore(&moduleOp.front());
   return global;
 }
+
+//===----------------------------------------------------------------------===//
+// BufferizationAttributesCleanUp
+//===----------------------------------------------------------------------===//
+
+/// Remove the buffer layout and writable attributes from a FuncOp argument.
+static void removeBufferizationAttributes(BlockArgument bbArg) {
+  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
+  funcOp.removeArgAttr(bbArg.getArgNumber(),
+                       BufferizationDialect::kBufferLayoutAttrName);
+  funcOp.removeArgAttr(bbArg.getArgNumber(),
+                       BufferizationDialect::kWritableAttrName);
+}
+
+void bufferization::removeBufferizationAttributesInModule(ModuleOp moduleOp) {
+  moduleOp.walk([&](func::FuncOp op) {
+    for (BlockArgument bbArg : op.getArguments())
+      removeBufferizationAttributes(bbArg);
+  });
+}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -61,6 +61,7 @@
 
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
@@ -201,15 +202,6 @@
 }
 } // namespace
 
-/// Remove bufferization attributes on FuncOp arguments.
-static void removeBufferizationAttributes(BlockArgument bbArg) {
-  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
-  funcOp.removeArgAttr(bbArg.getArgNumber(),
-                       BufferizationDialect::kBufferLayoutAttrName);
-  funcOp.removeArgAttr(bbArg.getArgNumber(),
-                       BufferizationDialect::kWritableAttrName);
-}
-
 /// Return the func::FuncOp called by `callOp`.
 static func::FuncOp getCalledFunction(CallOpInterface callOp) {
   SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
@@ -405,10 +397,7 @@
   }
 
   // Post-pass cleanup of function argument attributes.
-  moduleOp.walk([&](func::FuncOp op) {
-    for (BlockArgument bbArg : op.getArguments())
-      removeBufferizationAttributes(bbArg);
-  });
+  removeBufferizationAttributesInModule(moduleOp);
 
   return success();
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
@@ -75,9 +76,14 @@
       }
       return true;
     });
-    return bufferization::bufferizeOp(getOperation(), bufferizationOptions,
-                                      /*copyBeforeWrite=*/false,
-                                      &denseOpFilter);
+
+    if (failed(bufferization::bufferizeOp(getOperation(), bufferizationOptions,
+                                          /*copyBeforeWrite=*/false,
+                                          &denseOpFilter)))
+      return failure();
+
+    bufferization::removeBufferizationAttributesInModule(getOperation());
+    return success();
   }
 
   void runOnOperation() override {