diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -1561,7 +1561,8 @@ /// bufferization is necessary. static Value getResultBuffer(OpBuilder &b, OpResult result, const BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo) { + BufferizationAliasInfo &aliasInfo, + bool skipCopy = false) { OpBuilder::InsertionGuard guard(b); Operation *op = result.getOwner(); Optional maybeOperand = getAliasingOpOperand(result); @@ -1576,7 +1577,7 @@ // Allocate the result buffer. Value resultBuffer = createNewAllocDeallocPairForShapedValue(b, loc, operand, aliasInfo); - if (!isInitTensorOp(operand)) { + if (!skipCopy && !isInitTensorOp(operand)) { // Set insertion point now that potential alloc/dealloc are introduced. b.setInsertionPoint(op); b.create(loc, operandBuffer, resultBuffer); @@ -1589,13 +1590,8 @@ } /// Helper function for LinalgOp bufferization. -/// Examines each result and determines whether it bufferizes inplace on an -/// operand. -/// If the opResult bufferizes inplace, just reuse the existing buffer. -/// Otherwise allocate a new buffer to hold the result. /// When allocating a new buffer, analyze whether `op` want to read form that -/// buffer. In such a case, insert a copy to ensure the newly allocated buffer -/// is properly initialiazed. +/// buffer. Only in that case, a copy of the result buffer may be needed. static void allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op, SmallVectorImpl &resultBuffers, BlockAndValueMapping &bvm, @@ -1607,31 +1603,11 @@ // TODO: provide the proper interface to iterate on OpResults and get the // matching OpOperands. for (OpOperand *opOperand : op.getOutputOperands()) { - Value output = opOperand->get(); - assert(output.getType().isa() && "expected tensor type"); - - // If output tensor is marked inPlace, just use the buffer. - // The following uses internal knowledge of the position of inplaceable - // operand / results. OpResult opResult = getInplaceableOpResult(*opOperand); - if (getInPlace(opResult) == InPlaceSpec::True) { - Value v = lookup(bvm, output); - assert(v && "missing buffer"); - resultBuffers.push_back(v); - continue; - } - - // Otherwise, `op` is not inplaceable and we need to allocate its result. - Value dimTensor = bvm.lookupOrDefault(output); - Value alloc = - createNewAllocDeallocPairForShapedValue(b, loc, dimTensor, aliasInfo); - resultBuffers.push_back(alloc); - - // Additionally, if the output buffer is used, clone its value for now. - if (op.payloadUsesValueFromOperand(opOperand)) { - Value v = lookup(bvm, output); - b.create(loc, v, alloc); - } + assert(opResult && "could not find correspond OpResult"); + bool skipCopy = !op.payloadUsesValueFromOperand(opOperand); + Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo, skipCopy); + resultBuffers.push_back(resultBuffer); } if (op->getNumResults())