[mlir][linalg] Make aliasesNonWritableBuffer take a Value instead of an OpOperand.

* ComprehensiveBufferize.h: BufferizationAliasInfo::aliasesNonWritableBuffer
  now takes `Value value`; its doc comment is updated accordingly.
* ComprehensiveBufferize.cpp: the definition iterates over getAliases(value),
  drops the OpOperand-specific debug line, spells "writable" consistently in
  debug output, and distinguishes the bbArg / op "notWritable" debug messages.
* Call sites: the ExtractSliceOp analysis passes extractSliceOp.source()
  (previously extractSliceOp->getOpOperand(0)); the generic OpOperand path
  passes operand.get().

NOTE(review): this patch was recovered from a copy whose newlines, whitespace
runs and C++ template arguments had been stripped. Line breaks, context-line
indentation, blank context lines, and the template arguments in
`v.dyn_cast<BlockArgument>()` / `isa<ConstantOp>(op)` were reconstructed --
confirm against the tree if a hunk is rejected (or retry with
`git apply --ignore-whitespace`).

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
@@ -56,10 +56,9 @@
   /// `alias`. Additionally, merge their equivalence classes.
   void insertNewBufferEquivalence(Value newValue, Value alias);
 
-  /// Return true if the buffer to which `operand` would bufferize aliases a
-  /// buffer that is known to not be writable. This implies that the matching
-  /// OpResult cannot be bufferized inplace.
-  bool aliasesNonWritableBuffer(OpOperand &operand) const;
+  /// Return true if, under current bufferization decisions, the buffer of
+  /// `value` is not writable.
+  bool aliasesNonWritableBuffer(Value value) const;
 
   /// Return true if the buffer to which `operand` would bufferize is equivalent
   /// to some buffer write.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -736,40 +736,36 @@
   equivalentInfo.unionSets(newValue, alias);
 }
 
-/// Return true if the buffer to which `operand` would bufferize aliases a
-/// buffer that is known to not be writable. This implies that the matching
-/// OpResult cannot be bufferized inplace.
-bool BufferizationAliasInfo::aliasesNonWritableBuffer(
-    OpOperand &operand) const {
+/// Return true if, under current bufferization decisions, the buffer of `value`
+/// is not writable.
+bool BufferizationAliasInfo::aliasesNonWritableBuffer(Value value) const {
   LDBG("----Start aliasesNonWritableBuffer\n");
-  LDBG("-------for -> #" << operand.getOperandNumber() << ": "
-                         << printOperationInfo(operand.getOwner()) << '\n');
-  for (Value v : getAliases(operand.get())) {
+  for (Value v : getAliases(value)) {
     LDBG("-----------examine: " << printValueInfo(v) << '\n');
     if (bufferizesToWritableMemory(v)) {
-      LDBG("-----------Value is known to be writeable -> skip: "
+      LDBG("-----------Value is known to be writable -> skip: "
            << printValueInfo(v) << '\n');
       continue;
     }
 
     if (auto bbArg = v.dyn_cast<BlockArgument>()) {
       if (getInPlace(bbArg) == InPlaceSpec::True) {
-        LDBG("-----------bbArg is writeable -> skip: " << printValueInfo(bbArg)
-                                                       << '\n');
+        LDBG("-----------bbArg is writable -> skip: " << printValueInfo(bbArg)
+                                                      << '\n');
         continue;
       }
-      LDBG("-----------notWriteable\n");
+      LDBG("-----------notWritable bbArg\n");
       return true;
     }
 
     if (Operation *op = v.getDefiningOp()) {
       if (isa<ConstantOp>(op) || !hasKnownBufferizationAliasingBehavior(op)) {
-        LDBG("-----------notWritable\n");
+        LDBG("-----------notWritable op\n");
         return true;
       }
     }
   }
-  LDBG("---->operand is writable\n");
+  LDBG("---->value is writable\n");
   return false;
 }
 
@@ -2239,7 +2235,7 @@
   // aliasing a write into a non-writable buffer.
   bool wouldCreateAliasingWriteToNonWritableBuffer =
       aliasInfo.aliasesInPlaceWrite(extractSliceOp.result()) &&
-      aliasInfo.aliasesNonWritableBuffer(extractSliceOp->getOpOperand(0));
+      aliasInfo.aliasesNonWritableBuffer(extractSliceOp.source());
 
   if (wouldCreateAliasingWriteToNonWritableBuffer)
     LDBG("->the corresponding buffer is not writable\n");
@@ -2292,7 +2288,7 @@
   //   2. a constant op.
   // to be considered for inplace bufferization
   bool wouldCreateAliasingWriteToNonWritableBuffer =
-      aliasInfo.aliasesNonWritableBuffer(operand);
+      aliasInfo.aliasesNonWritableBuffer(operand.get());
   if (wouldCreateAliasingWriteToNonWritableBuffer)
     LDBG("->the corresponding buffer is not writable\n");
   else