diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -991,13 +991,13 @@ static Value getResultBuffer(OpBuilder &b, OpResult result, const BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo, - AllocationCallbacks allocationFns, - bool skipCopy = false) { + AllocationCallbacks allocationFns) { OpBuilder::InsertionGuard guard(b); Operation *op = result.getOwner(); SmallVector aliasingOperands = getAliasingOpOperand(result); assert(!aliasingOperands.empty() && "could not get aliasing OpOperand"); - Value operand = aliasingOperands.front()->get(); + OpOperand *opOperand = aliasingOperands.front(); + Value operand = opOperand->get(); Value operandBuffer = lookup(bvm, operand); assert(operandBuffer && "operand buffer not found"); // Make sure that all OpOperands are the same buffer. If this is not the case, @@ -1023,6 +1023,7 @@ // Allocate the result buffer. Value resultBuffer = createNewAllocDeallocPairForShapedValue( b, loc, operand, aliasInfo, allocationFns); + bool skipCopy = false; // Do not copy if the last preceding write of `operand` is an op that does // not write (skipping ops that merely create aliases). E.g., InitTensorOp. // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA @@ -1036,6 +1037,10 @@ // Do not copy if the copied data is never read. if (!isValueRead(result)) skipCopy = true; + // Do not copy if this op does not read the data, but writes it. + if (bufferizesToMemoryWrite(*opOperand) && + !bufferizesToMemoryRead(*opOperand)) + skipCopy = true; if (!skipCopy) { // Set insertion point now that potential alloc/dealloc are introduced. b.setInsertionPoint(op); @@ -2108,9 +2113,8 @@ OpResult opResult = cast(op.getOperation()) .getAliasingOpResult(*opOperand); assert(opResult && "could not find correspond OpResult"); - bool skipCopy = !op.payloadUsesValueFromOperand(opOperand); Value resultBuffer = - getResultBuffer(b, opResult, bvm, aliasInfo, allocationFns, skipCopy); + getResultBuffer(b, opResult, bvm, aliasInfo, allocationFns); if (!resultBuffer) return failure(); resultBuffers.push_back(resultBuffer); @@ -2175,8 +2179,9 @@ OpTy> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { auto genericOp = cast(op); - return genericOp.isInputTensor(&opOperand) || - genericOp.isInitTensor(&opOperand); + return (genericOp.isInputTensor(&opOperand) || + genericOp.isInitTensor(&opOperand)) && + genericOp.payloadUsesValueFromOperand(&opOperand); } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {