diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -170,14 +170,6 @@
   return returnOp;
 }
 
-/// Return true if `value` is the result of an InitTensorOp or a cast thereof.
-static bool isInitTensorOp(Value value) {
-  tensor::CastOp castOp;
-  while ((castOp = value.getDefiningOp<tensor::CastOp>()))
-    value = castOp.source();
-  return value.getDefiningOp<InitTensorOp>();
-}
-
 //===----------------------------------------------------------------------===//
 // Bufferization-specific BlockAndValueMapping support with debugging.
 //===----------------------------------------------------------------------===//
@@ -466,9 +458,8 @@
 /// Furthermore, BlockArguments are also assumed to be writes. There is no
 /// analysis across block boundaries.
 ///
-/// Note: To simplify the analysis, scf.if ops are considered writes. Treating
-/// a non-writing op as a writing op may introduce unnecessary out-of-place
-/// bufferizations, but is always safe from a correctness point of view.
+/// Note: When reaching an end of the reverse SSA use-def chain, that value
+/// is returned regardless of whether it is a memory write or not.
 static Value findLastPrecedingWrite(Value value) {
   SetVector<Value> result =
       findValueInReverseUseDefChain(value, [](Value value) {
@@ -481,6 +472,10 @@
         return bufferizableOp.isMemoryWrite(value.cast<OpResult>());
       });
 
+  // To simplify the analysis, `scf.if` ops are considered memory writes. There
+  // are currently no other ops where one OpResult may alias with multiple
+  // OpOperands. Therefore, this function should return exactly one result at
+  // the moment.
   assert(result.size() == 1 && "expected exactly one result");
   return result.front();
 }
@@ -1028,9 +1023,16 @@
     // Allocate the result buffer.
     Value resultBuffer = createNewAllocDeallocPairForShapedValue(
         b, loc, operand, aliasInfo, allocationFns);
-    // Do not copy the result of an InitTensorOp.
-    if (isInitTensorOp(operand))
-      skipCopy = true;
+    // Do not copy if the last preceding write of `operand` is an op that does
+    // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
+    // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
+    // use-def chain, it returns that value, regardless of whether it is a
+    // memory write or not.
+    Value lastWrite = findLastPrecedingWrite(operand);
+    if (auto bufferizableOp =
+            lastWrite.getDefiningOp<BufferizableOpInterface>())
+      if (!bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>()))
+        skipCopy = true;
     // Do not copy if the copied data is never read.
     if (!isValueRead(result))
       skipCopy = true;
@@ -2226,6 +2228,11 @@
     return {};
   }
 
+  bool isMemoryWrite(Operation *op, OpResult opResult) const {
+    // InitTensorOps allocate but do not write.
+    return false;
+  }
+
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BlockAndValueMapping &bvm,
                           BufferizationAliasInfo &aliasInfo,