NOTE(review): this patch was recovered from a copy in which newlines,
indentation and template argument lists had been stripped. Indentation has
been restored following the file's LLVM clang-format style. The template
arguments `static_cast<bool>` and `cast<OpResult>()` were restored from
context (verify the latter). The type list of the `isa<...>` check in the
first hunk could not be recovered from this copy; take that context line
from the upstream file before applying.

NOTE(review): behavior change to confirm. Before this patch,
`bufferizesToMemoryWrite(use, InPlaceSpec::True)` returned true for operands
of ops without known bufferization aliasing behavior, because the early
`return true;` runs before the InPlaceSpec check. `isInplaceMemoryWrite`
additionally requires a non-null `getAliasingOpResult(opOperand)`. If that
helper returns a null OpResult for such ops (not visible here -- verify),
their operands are no longer reported as in-place writes by the two callers
changed below.

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -652,11 +652,7 @@
 }
 
 /// Return true if `opOperand` bufferizes to a memory write.
-/// If inPlaceSpec is different from InPlaceSpec::None, additionally require the
-/// write to match the inplace specification.
-static bool
-bufferizesToMemoryWrite(OpOperand &opOperand,
-                        InPlaceSpec inPlaceSpec = InPlaceSpec::None) {
+static bool bufferizesToMemoryWrite(OpOperand &opOperand) {
   // These terminators are not writes.
   if (isa(opOperand.getOwner()))
     return false;
@@ -675,14 +671,9 @@
   if (!hasKnownBufferizationAliasingBehavior(opOperand.getOwner()))
     return true;
   OpResult opResult = getAliasingOpResult(opOperand);
-  // Supported op without a matching result for opOperand (e.g. ReturnOp).
-  // This does not bufferize to a write.
-  if (!opResult)
-    return false;
-  // If we have a matching OpResult, this is a write.
-  // Additionally allow to restrict to only inPlace write, if so specified.
-  return inPlaceSpec == InPlaceSpec::None ||
-         getInPlace(opResult) == inPlaceSpec;
+  // Only supported ops with a matching result for opOperand bufferize to a
+  // write. E.g., ReturnOp does not bufferize to a write.
+  return static_cast<bool>(opResult);
 }
 
 /// Returns the relationship between the operand and the its corresponding
@@ -699,6 +690,15 @@
 // Bufferization-specific alias analysis.
 //===----------------------------------------------------------------------===//
 
+/// Return true if `opOperand` is a memory write decided to bufferize in-place.
+static bool isInplaceMemoryWrite(OpOperand &opOperand) {
+  // Ops that do not bufferize to a memory write cannot write in-place.
+  if (!bufferizesToMemoryWrite(opOperand))
+    return false;
+  OpResult opResult = getAliasingOpResult(opOperand);
+  return opResult && getInPlace(opResult) == InPlaceSpec::True;
+}
+
 BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
   rootOp->walk([&](Operation *op) {
     for (Value v : op->getResults())
@@ -783,7 +783,7 @@
   LDBG("-------for : " << printValueInfo(value) << '\n');
   for (Value v : getAliases(value)) {
     for (auto &use : v.getUses()) {
-      if (bufferizesToMemoryWrite(use, InPlaceSpec::True)) {
+      if (isInplaceMemoryWrite(use)) {
         LDBG("-----------wants to bufferize to inPlace write: "
              << printOperationInfo(use.getOwner()) << '\n');
         return true;
@@ -906,7 +906,7 @@
   for (Value alias : getAliases(root)) {
     for (auto &use : alias.getUses()) {
       // Inplace write to a value that aliases root.
-      if (bufferizesToMemoryWrite(use, InPlaceSpec::True)) {
+      if (isInplaceMemoryWrite(use)) {
         LDBG("------------bufferizesToMemoryWrite: "
              << use.getOwner()->getName().getStringRef() << "\n");
         res.insert(&use);
@@ -1118,6 +1118,7 @@
     if (!candidateOp)
       continue;
     OpOperand *operand = getAliasingOpOperand(mit->v.cast<OpResult>());
+    // TODO: Should we check for isInplaceMemoryWrite instead?
     if (!operand || !bufferizesToMemoryWrite(*operand))
       continue;
     LDBG("---->clobbering candidate: " << printOperationInfo(candidateOp)