diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -386,8 +386,13 @@
   /// Return the buffer (memref) for a given OpOperand (tensor). Allocate
   /// a new buffer and copy over data from the existing buffer if out-of-place
   /// bufferization was decided.
-  FailureOr<Value> getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
-                             bool forceInPlace = false) const;
+  ///
+  /// If `customCopyInsertionPoint` is set, the copy (if one is needed) is
+  /// inserted right before that op instead of right before the owner of
+  /// `opOperand`.
+  FailureOr<Value>
+  getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
+            bool forceInPlace = false,
+            Optional<Operation *> customCopyInsertionPoint = None) const;
 
   /// Return dialect-specific bufferization state.
   template <typename StateT>
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -377,7 +377,8 @@
 /// bufferization is necessary.
 FailureOr<Value>
 mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
-    RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace) const {
+    RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace,
+    Optional<Operation *> customCopyInsertionPoint) const {
   OpBuilder::InsertionGuard guard(rewriter);
   Operation *op = opOperand.getOwner();
   Location loc = op->getLoc();
@@ -418,9 +419,14 @@
   if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
     return resultBuffer;
 
-  // The copy happens right before the op that is bufferized.
-  rewriter.setInsertionPoint(op);
+  if (customCopyInsertionPoint) {
+    rewriter.setInsertionPoint(*customCopyInsertionPoint);
+  } else {
+    // The copy happens right before the op that is bufferized.
+    rewriter.setInsertionPoint(op);
+  }
   createMemCpy(rewriter, loc, operandBuffer, *resultBuffer);
+
   return resultBuffer;
 }
 