diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -654,11 +654,7 @@ } /// Return true if `opOperand` bufferizes to a memory write. -/// If inPlaceSpec is different from InPlaceSpec::None, additionally require the -/// write to match the inplace specification. -static bool -bufferizesToMemoryWrite(OpOperand &opOperand, - InPlaceSpec inPlaceSpec = InPlaceSpec::None) { +static bool bufferizesToMemoryWrite(OpOperand &opOperand) { // These terminators are not writes. if (isa(opOperand.getOwner())) return false; @@ -677,14 +673,9 @@ if (!hasKnownBufferizationAliasingBehavior(opOperand.getOwner())) return true; OpResult opResult = getAliasingOpResult(opOperand); - // Supported op without a matching result for opOperand (e.g. ReturnOp). - // This does not bufferize to a write. - if (!opResult) - return false; - // If we have a matching OpResult, this is a write. - // Additionally allow to restrict to only inPlace write, if so specified. - return inPlaceSpec == InPlaceSpec::None || - getInPlace(opResult) == inPlaceSpec; + // Only supported ops with a matching result for opOperand bufferize to a + // write. E.g., ReturnOp does not bufferize to a write. + return static_cast(opResult); } /// Returns the relationship between the operand and the its corresponding @@ -701,6 +692,15 @@ // Bufferization-specific alias analysis. //===----------------------------------------------------------------------===// +/// Return true if `opOperand` is a memory write marked to bufferize in-place. +static bool isInplaceMemoryWrite(OpOperand &opOperand) { + // Ops that do not bufferize to a memory write cannot be written in-place.
+ if (!bufferizesToMemoryWrite(opOperand)) + return false; + OpResult opResult = getAliasingOpResult(opOperand); + return opResult && getInPlace(opResult) == InPlaceSpec::True; +} + BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) { rootOp->walk([&](Operation *op) { for (Value v : op->getResults()) @@ -785,7 +785,7 @@ LDBG("-------for : " << printValueInfo(value) << '\n'); for (Value v : getAliases(value)) { for (auto &use : v.getUses()) { - if (bufferizesToMemoryWrite(use, InPlaceSpec::True)) { + if (isInplaceMemoryWrite(use)) { LDBG("-----------wants to bufferize to inPlace write: " << printOperationInfo(use.getOwner()) << '\n'); return true; @@ -914,7 +914,7 @@ for (Value alias : getAliases(root)) { for (auto &use : alias.getUses()) { // Inplace write to a value that aliases root. - if (bufferizesToMemoryWrite(use, InPlaceSpec::True)) { + if (isInplaceMemoryWrite(use)) { LDBG("------------bufferizesToMemoryWrite: " << use.getOwner()->getName().getStringRef() << "\n"); res.insert(&use); @@ -1135,6 +1135,7 @@ SmallVector operands = getAliasingOpOperand(mit->v.cast()); assert(operands.size() <= 1 && "more than 1 OpOperand not supported yet"); + // TODO: Should we check for isInplaceMemoryWrite instead? if (operands.empty() || !bufferizesToMemoryWrite(*operands.front())) continue; LDBG("---->clobbering candidate: " << printOperationInfo(candidateOp)