diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -78,6 +78,22 @@
                           const OpFilter *opFilter = nullptr,
                           BufferizationStatistics *statistics = nullptr);
 
+/// Bufferize the signature of `block`. The type of every tensor block argument
+/// is changed to a memref type; all other block arguments keep their type.
+/// Uses of a changed block argument are redirected to a newly created
+/// `bufferization.to_tensor` op at the start of the block.
+///
+/// It is expected that the parent op of this block implements the
+/// `BufferizableOpInterface`. The buffer types of tensor block arguments are
+/// computed with `BufferizableOpInterface::getBufferType`. Returns failure if
+/// the block has no parent op, if the parent op is not bufferizable according
+/// to `options`, or if a buffer type cannot be computed.
+///
+/// Note: Operands that branching predecessors forward to this block are not
+/// updated.
+LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
+                                      const BufferizationOptions &options);
+
 BufferizationOptions getPartialBufferizationOptions();
 
 } // namespace bufferization
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -535,6 +535,60 @@
   return success();
 }
 
+LogicalResult
+bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
+                                       const BufferizationOptions &options) {
+  OpBuilder::InsertionGuard g(rewriter);
+  // The buffer types of the block arguments are computed by the parent op,
+  // so it must exist and implement the BufferizableOpInterface.
+  Operation *parentOp = block->getParentOp();
+  if (!parentOp || !options.dynCastBufferizableOp(parentOp))
+    return failure();
+
+  // Compute the new signature before mutating any block argument.
+  SmallVector<Type> newTypes;
+  for (BlockArgument &bbArg : block->getArguments()) {
+    // Non-tensor types stay the same.
+    if (!isa<TensorType>(bbArg.getType())) {
+      newTypes.push_back(bbArg.getType());
+      continue;
+    }
+
+    FailureOr<BaseMemRefType> memrefType =
+        bufferization::getBufferType(bbArg, options);
+    if (failed(memrefType))
+      return failure();
+    newTypes.push_back(*memrefType);
+  }
+
+  // Change the type of all tensor block arguments.
+  for (auto [bbArg, type] : llvm::zip(block->getArguments(), newTypes)) {
+    if (bbArg.getType() == type)
+      continue;
+
+    // Collect all uses of the bbArg before its type is changed.
+    SmallVector<OpOperand *> bbArgUses;
+    for (OpOperand &use : bbArg.getUses())
+      bbArgUses.push_back(&use);
+
+    // Change the bbArg type to memref.
+    bbArg.setType(type);
+
+    // Replace all uses of the original tensor bbArg.
+    rewriter.setInsertionPointToStart(block);
+    if (!bbArgUses.empty()) {
+      // Insert to_tensor so that the existing (not yet bufferized) users of
+      // the bbArg keep a tensor-typed operand.
+      Value toTensorOp =
+          rewriter.create<bufferization::ToTensorOp>(bbArg.getLoc(), bbArg);
+      for (OpOperand *use : bbArgUses)
+        use->set(toTensorOp);
+    }
+  }
+
+  return success();
+}
+
 BufferizationOptions bufferization::getPartialBufferizationOptions() {
   BufferizationOptions options;
   options.allowUnknownOps = true;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -374,37 +375,11 @@
     assert(returnOp && "expected func with single return op");
     Location loc = returnOp.getLoc();
 
-    // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg.
-    Block &frontBlock = funcOp.getBody().front();
-    for (BlockArgument &bbArg : frontBlock.getArguments()) {
-      auto tensorType = dyn_cast<TensorType>(bbArg.getType());
-      // Non-tensor types stay the same.
-      if (!tensorType)
-        continue;
-
-      // Collect all uses of the bbArg.
-      SmallVector<OpOperand *> bbArgUses;
-      for (OpOperand &use : bbArg.getUses())
-        bbArgUses.push_back(&use);
-
-      // Change the bbArg type to memref.
-      FailureOr<BaseMemRefType> memrefType =
-          bufferization::getBufferType(bbArg, options);
-      if (failed(memrefType))
+    // 1. Bufferize the signature of every block.
+    for (Block &block : funcOp.getBody())
+      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
+                                                        options)))
         return failure();
-      bbArg.setType(*memrefType);
-
-      // Replace all uses of the original tensor bbArg.
-      rewriter.setInsertionPointToStart(&frontBlock);
-      if (!bbArgUses.empty()) {
-        // Insert to_tensor because the remaining function body has not been
-        // bufferized yet.
-        Value toTensorOp =
-            rewriter.create<bufferization::ToTensorOp>(funcOp.getLoc(), bbArg);
-        for (OpOperand *use : bbArgUses)
-          use->set(toTensorOp);
-      }
-    }
 
     // 2. For each result, keep track of which inplace argument it reuses.
     SmallVector<Value> returnValues;