diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td @@ -42,9 +42,16 @@ InterfaceMethod< /*desc=*/[{ Return `true` if the given OpOperand bufferizes to a memory write. + This method will never be called on OpOperands that do not have a tensor type. + This method will never be called on OpOperands that do not have an + aliasing OpResult. Intuitively, it does not make sense for an + OpOperand to bufferize to a memory write without returning an aliasing + tensor, because the write would have no visible effect outside of the + op. + Note: It is always safe to consider an OpOperand as a memory write, even if it does actually not write; however, this can introduce unnecessary out-of-place bufferization decisions. The analysis of @@ -57,6 +64,8 @@ /*methodBody=*/"", /*defaultImplementation=*/[{ // Does not have to be implemented for ops without tensor OpOperands. + // Does not have to be implemented for OpOperands that do not have an + // aliasing OpResult. llvm_unreachable("bufferizesToMemoryWrite not implemented"); }] >, diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -239,11 +239,15 @@ /// Return true if opOperand has been decided to bufferize in-place. static bool isInplaceMemoryWrite(OpOperand &opOperand, const BufferizationAliasInfo &aliasInfo) { - // Ops that do not bufferize to a memory write, cannot be write in-place. + // OpOperands without an aliasing OpResult do not write. 
+ OpResult opResult = getAliasingOpResult(opOperand); + if (!opResult) + return false; + // OpOperands that do not bufferize to a memory write do not write in-place. if (!bufferizesToMemoryWrite(opOperand)) return false; - OpResult opResult = getAliasingOpResult(opOperand); - return opResult && aliasInfo.isInPlace(opResult); + // Check current bufferization decisions. + return aliasInfo.isInPlace(opResult); } /// Return true if, under current bufferization decisions, the buffer of `value` diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -447,14 +447,6 @@ return true; } - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { - // CallOpInterface alone doesn't bufferize to a memory write, one of the - // uses of the matching bbArg may. It is the responsibility of the caller to - // inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be - // conservative. - return true; - } - SmallVector<OpOperand *> getAliasingOpOperand(Operation *op, OpResult opResult) const { // TODO: Can we do better? diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp @@ -44,6 +44,17 @@ return true; } + // TODO: For better bufferization results, this could return `true` only if + // there is a memory write in the region. + bool isMemoryWrite(Operation *op, OpResult opResult) const { + // Similar to scf.if, results of this op are always considered memory writes + // in the analysis. 
This is a useful pattern for all ops that have tensor + // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is + // implemented in terms of `bufferizesToMemoryWrite`, which does not work on + // ops without OpOperands. + return true; + } + LogicalResult bufferize(Operation *op, OpBuilder &b, BufferizationState &state) const { // TODO: Add bufferization support when needed. scf.execute_region should be diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir @@ -1470,3 +1470,25 @@ return %r, %v2 : tensor, vector<10xf32> } +// ----- + +// CHECK-LABEL: func @some_use +func @some_use(%A : tensor<?xf32> {linalg.inplaceable = true}, + %v : vector<5xf32>) -> (tensor<?xf32>) { + %idx = arith.constant 0 : index + // CHECK: vector.transfer_write + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %0 = vector.transfer_write %v, %A[%idx] : vector<5xf32>, tensor<?xf32> + return %0 : tensor<?xf32> +} + + +// CHECK-LABEL: func @main_func +func @main_func(%A : tensor<?xf32> {linalg.inplaceable = true}, + %v : vector<5xf32>) -> (tensor<?xf32>) { + // Function calls always bufferize out-of-place at the moment. + // CHECK: call + // CHECK-SAME: {__inplace_results_attr__ = ["false"] + %0 = call @some_use(%A, %v) : (tensor<?xf32>, vector<5xf32>) -> (tensor<?xf32>) + return %0 : tensor<?xf32> +}