diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
@@ -60,6 +60,35 @@
           llvm_unreachable("bufferizesToMemoryWrite not implemented");
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return `true` if the given OpResult is a memory write. This is the
+          case in either of the following situations:
+
+          * The corresponding aliasing OpOperand bufferizes to a memory write.
+          * Or: There is no corresponding aliasing OpOperand.
+
+          If the OpResult has multiple aliasing OpOperands, this method
+          returns `true` if at least one of them bufferizes to a memory write.
+        }],
+        /*retType=*/"bool",
+        /*methodName=*/"isMemoryWrite",
+        /*args=*/(ins "OpResult":$opResult),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          auto bufferizableOp =
+              cast<BufferizableOpInterface>($_op.getOperation());
+          SmallVector<OpOperand *> opOperands =
+              bufferizableOp.getAliasingOpOperand(opResult);
+          if (opOperands.empty())
+            return true;
+          return llvm::any_of(
+              opOperands,
+              [&](OpOperand *operand) {
+                return bufferizableOp.bufferizesToMemoryWrite(*operand);
+              });
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Return the OpResult that aliases with a given OpOperand when
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -736,16 +736,7 @@
           return true;
         if (isa<scf::IfOp>(op))
           return true;
-
-        SmallVector<OpOperand *> opOperands =
-            bufferizableOp.getAliasingOpOperand(value.cast<OpResult>());
-        assert(opOperands.size() <= 1 &&
-               "op with multiple aliasing OpOperands not expected");
-
-        if (opOperands.empty())
-          return true;
-
-        return bufferizesToMemoryWrite(*opOperands.front());
+        return bufferizableOp.isMemoryWrite(value.cast<OpResult>());
       });
   assert(result.size() == 1 && "expected exactly one result");
   return result.front();