diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -355,7 +355,8 @@ /// traversed any further. /// /// When reaching the end of a chain (BlockArgument or Value without aliasing - /// OpOperands), also return the last Value of that chain. + /// OpOperands), also return the last Value of that chain if + /// `alwaysIncludeLeaves` is set. /// /// Example: /// @@ -377,7 +378,8 @@ SetVector findValueInReverseUseDefChain(Value value, llvm::function_ref condition, - bool followEquivalentOnly = false) const; + bool followEquivalentOnly = false, + bool alwaysIncludeLeaves = true) const; /// Find the Values of the last preceding write of a given Value. /// diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -120,8 +120,7 @@ result also bufferizes to a memory write. 3. At least one aliasing OpOperand's value is defined inside the - defining op of the given OpResult and it is a memory write or the - reverse SSA use-def chain ends in the defining op. + defining op of the given OpResult and it is a memory write. According to this rule, an aliasing OpOperand value that is defined inside this op and is bufferizing to a memory write makes the given diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -444,7 +444,7 @@ // further. llvm::SetVector AnalysisState::findValueInReverseUseDefChain( Value value, llvm::function_ref condition, - bool followEquivalentOnly) const { + bool followEquivalentOnly, bool alwaysIncludeLeaves) const { llvm::SetVector result, workingSet; workingSet.insert(value); @@ -469,7 +469,8 @@ (followEquivalentOnly && bufferizableOp.bufferRelation(opResult, *this) != BufferRelation::Equivalent)) { - result.insert(value); + if (alwaysIncludeLeaves) + result.insert(value); continue; } @@ -640,8 +641,7 @@ return true; // Case 3: Check if a nested aliasing OpOperand value bufferizes to a memory - // write. (Or: The reverse SSA use-def chain ends inside the reigon.) In that - // case, the OpResult bufferizes to a memory write. E.g.: + // write. In that case, the OpResult bufferizes to a memory write. E.g.: // // %0 = "some_writing_op" : tensor // %r = scf.if ... -> tensor { @@ -678,7 +678,8 @@ if (!state .findValueInReverseUseDefChain(operand->get(), isMemoryWriteInsideOp, - /*followEquivalentOnly=*/false) + /*followEquivalentOnly=*/false, + /*alwaysIncludeLeaves=*/false) .empty()) return true; } diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs bufferize-function-boundaries" -drop-equivalent-buffer-results -buffer-deallocation -split-input-file | FileCheck %s +// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs bufferize-function-boundaries" -drop-equivalent-buffer-results -split-input-file | FileCheck %s --check-prefix=CHECK-NO-DEALLOC-PASS // Run fuzzer with different seeds. // RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=23 bufferize-function-boundaries" -split-input-file -o /dev/null @@ -925,3 +926,31 @@ } return } + +// ----- + +// This test does not compute anything meaningful but it tests that +// bufferizesToMemoryWrite is correctly propagated through regions. + +// CHECK-NO-DEALLOC-PASS-LABEL: func @elide_copy_of_non_writing_scf_if( +func.func @elide_copy_of_non_writing_scf_if(%c: i1, %p1: index, %p2: index, %f: f32) + -> (tensor<10xf32>, f32) +{ + %r = scf.if %c -> tensor<10xf32> { + // CHECK-NO-DEALLOC-PASS: memref.alloc + %t1 = bufferization.alloc_tensor() : tensor<10xf32> + scf.yield %t1 : tensor<10xf32> + } else { + // CHECK-NO-DEALLOC-PASS: memref.alloc + %t2 = bufferization.alloc_tensor() : tensor<10xf32> + scf.yield %t2 : tensor<10xf32> + } + // CHECK-NO-DEALLOC-PASS: memref.alloc + + // No copy should be inserted because %r does not bufferize to a memory write. + // I.e., %r does not have defined contents and the copy can be elided. + // CHECK-NO-DEALLOC-PASS-NOT: memref.copy + %r2 = tensor.insert %f into %r[%p1] : tensor<10xf32> + %r3 = tensor.extract %r[%p2] : tensor<10xf32> + return %r2, %r3 : tensor<10xf32>, f32 +}