diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -651,11 +651,7 @@ } /// Return true if `opOperand` bufferizes to a memory write. -/// If inPlaceSpec is different from InPlaceSpec::None, additionally require the -/// write to match the inplace specification. -static bool -bufferizesToMemoryWrite(OpOperand &opOperand, - InPlaceSpec inPlaceSpec = InPlaceSpec::None) { +static bool bufferizesToMemoryWrite(OpOperand &opOperand) { // These terminators are not writes. if (isa(opOperand.getOwner())) return false; @@ -674,14 +670,9 @@ if (!hasKnownBufferizationAliasingBehavior(opOperand.getOwner())) return true; OpResult opResult = getAliasingOpResult(opOperand); - // Supported op without a matching result for opOperand (e.g. ReturnOp). - // This does not bufferize to a write. - if (!opResult) - return false; - // If we have a matching OpResult, this is a write. - // Additionally allow to restrict to only inPlace write, if so specified. - return inPlaceSpec == InPlaceSpec::None || - getInPlace(opResult) == inPlaceSpec; + // Only supported ops with a matching result for opOperand bufferize to a + // write. E.g., ReturnOp does not bufferize to a write. + return static_cast<bool>(opResult); } /// Returns the relationship between the operand and the its corresponding @@ -698,6 +689,15 @@ // Bufferization-specific alias analysis. //===----------------------------------------------------------------------===// +/// Return true if `opOperand` is a memory write decided to bufferize in-place. +static bool isInplaceMemoryWrite(OpOperand &opOperand) { + // Ops that do not bufferize to a memory write cannot write in-place. 
+ if (!bufferizesToMemoryWrite(opOperand)) + return false; + OpResult opResult = getAliasingOpResult(opOperand); + return opResult && getInPlace(opResult) == InPlaceSpec::True; +} + BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) { rootOp->walk([&](Operation *op) { for (Value v : op->getResults()) @@ -782,7 +782,7 @@ LDBG("-------for : " << printValueInfo(value) << '\n'); for (Value v : getAliases(value)) { for (auto &use : v.getUses()) { - if (bufferizesToMemoryWrite(use, InPlaceSpec::True)) { + if (isInplaceMemoryWrite(use)) { LDBG("-----------wants to bufferize to inPlace write: " << printOperationInfo(use.getOwner()) << '\n'); return true; @@ -909,7 +909,7 @@ for (Value alias : getAliases(root)) { for (auto &use : alias.getUses()) { // Inplace write to a value that aliases root. - if (bufferizesToMemoryWrite(use, InPlaceSpec::True)) { + if (isInplaceMemoryWrite(use)) { LDBG("------------bufferizesToMemoryWrite: " << use.getOwner()->getName().getStringRef() << "\n"); res.insert(&use); @@ -1124,6 +1124,7 @@ if (!candidateOp) continue; OpOperand *operand = getAliasingOpOperand(mit->v.cast()); + // TODO: Should we check for isInplaceMemoryWrite instead? if (!operand || !bufferizesToMemoryWrite(*operand)) continue; LDBG("---->clobbering candidate: " << printOperationInfo(candidateOp)