diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -522,10 +522,13 @@ /// Create alloc/dealloc ops as specified in the bufferization options. If /// `onlyLeakingAlloc`, only those buffer allocations are processed for which no -/// buffer deallocation can be created. +/// buffer deallocation can be created. `changed` is set to `true` if the IR was +/// modified. LogicalResult createAllocDeallocOps(Operation *op, const BufferizationOptions &options, - bool onlyLeakingAllocs = false); + bool onlyLeakingAllocs = false, + bool *changed = nullptr); + } // namespace bufferization } // namespace mlir diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -512,8 +512,10 @@ LogicalResult bufferization::createAllocDeallocOps(Operation *op, const BufferizationOptions &options, - bool onlyLeakingAllocs) { + bool onlyLeakingAllocs, bool *changed) { IRRewriter rewriter(op->getContext()); + if (changed) + *changed = false; // Bufferization creates memref.alloca ops. After bufferization, these must be // rewritten to alloc/dealloc ops as specified in the bufferization options. @@ -536,6 +538,8 @@ if (failed(alloc)) return WalkResult::interrupt(); rewriter.replaceOp(allocaOp, *alloc); + if (changed) + *changed = true; // Stop here if the buffer should not be deallocated. if (skipDealloc) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -304,12 +304,21 @@ LogicalResult bufferization::finalizeBuffers(Operation *op, const BufferizationOptions &options) { + // Hoist buffers. if (failed(hoistBufferAllocations(op, options))) return failure(); - if (failed(createAllocDeallocOps(op, options, /*onlyLeakingAllocs=*/true))) + + // Deallocate buffers that escape block boundaries ("leaking buffers") with + // the buffer deallocation pass. + bool hasLeakingAlloc = false; + if (failed(createAllocDeallocOps(op, options, /*onlyLeakingAllocs=*/true, + &hasLeakingAlloc))) return failure(); - if (options.createDeallocs && failed(deallocateBuffers(op))) + if (options.createDeallocs && hasLeakingAlloc && + failed(deallocateBuffers(op))) return failure(); + + // Deallocate all remaining buffers at the end of the block. if (failed(createAllocDeallocOps(op, options))) return failure();