diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -649,11 +649,7 @@ } /// Return true if `opOperand` bufferizes to a memory write. -/// If inPlaceSpec is different from InPlaceSpec::None, additionally require the -/// write to match the inplace specification. -static bool -bufferizesToMemoryWrite(OpOperand &opOperand, - InPlaceSpec inPlaceSpec = InPlaceSpec::None) { +static bool bufferizesToMemoryWrite(OpOperand &opOperand) { // These terminators are not writes. if (isa(opOperand.getOwner())) return false; @@ -672,14 +668,9 @@ if (!hasKnownBufferizationAliasingBehavior(opOperand.getOwner())) return true; OpResult opResult = getAliasingOpResult(opOperand); - // Supported op without a matching result for opOperand (e.g. ReturnOp). - // This does not bufferize to a write. - if (!opResult) - return false; - // If we have a matching OpResult, this is a write. - // Additionally allow to restrict to only inPlace write, if so specified. - return inPlaceSpec == InPlaceSpec::None || - getInPlace(opResult) == inPlaceSpec; + // Only supported ops with a matching OpResult for opOperand bufferize to a + // write. E.g., ReturnOp does not bufferize to a write. + return static_cast(opResult); } /// Specify fine-grain relationship between buffers to enable more analysis. @@ -704,6 +695,15 @@ // Bufferization-specific alias analysis. //===----------------------------------------------------------------------===// +/// Return true if `opOperand` bufferizes to a memory write and its aliasing +/// OpResult has been marked to bufferize in-place (InPlaceSpec::True). +static bool isInplaceMemoryWrite(OpOperand &opOperand) { + // Ops that do not bufferize to a memory write cannot write in-place.
+ if (!bufferizesToMemoryWrite(opOperand)) + return false; + OpResult opResult = getAliasingOpResult(opOperand); + return opResult && getInPlace(opResult) == InPlaceSpec::True; +} + namespace { /// The BufferizationAliasInfo class maintains a list of buffer aliases and @@ -966,7 +966,7 @@ LDBG("-------for : " << printValueInfo(value) << '\n'); for (Value v : getAliases(value)) { for (auto &use : v.getUses()) { - if (bufferizesToMemoryWrite(use, InPlaceSpec::True)) { + if (isInplaceMemoryWrite(use)) { LDBG("-----------wants to bufferize to inPlace write: " << printOperationInfo(use.getOwner()) << '\n'); return true; @@ -1093,7 +1093,7 @@ for (Value alias : getAliases(root)) { for (auto &use : alias.getUses()) { // Inplace write to a value that aliases root. - if (bufferizesToMemoryWrite(use, InPlaceSpec::True)) { + if (isInplaceMemoryWrite(use)) { LDBG("------------bufferizesToMemoryWrite: " << use.getOwner()->getName().getStringRef() << "\n"); res.insert(&use); @@ -1276,6 +1276,7 @@ if (!candidateOp) continue; OpOperand *operand = getAliasingOpOperand(mit->v.cast()); + // TODO: Should we check for isInplaceMemoryWrite instead? if (!operand || !bufferizesToMemoryWrite(*operand)) continue; LDBG("---->clobbering candidate: " << printOperationInfo(candidateOp)