diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -461,10 +461,27 @@ scf::ForOp> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of - // its matching bbArg may. - auto forOp = cast(op); - return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand)); + // Tensor iter_args of scf::ForOps are always considered as a read. + // + // Note: bufferizesToMemoryRead + bufferizesToMemoryWrite could analyze the + // loop body to figure out if the bbArg is read/written. However, this is + // tricky because the loop may have zero iterations. This can trigger a + // copy elision. E.g.: + // + // %t1 = ... + // %t2 = ... + // scf.for ... iter_args(%0 = %t1) -> tensor { + // "write_only"(%0) + // scf.yield %t2 + // } + // + // If the loop has more than 0 iterations, the init_arg operand bufferizes + // to a memory write (but not a read). If the loop has 0 iterations, the + // init_arg operand bufferizes neither to a read nor to a write. It simply + // creates an alias. Since we cannot accurately determine the number of + // iterations in many cases, it is safest to consider the iter_args as both + // a read and a write. + return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, @@ -1048,10 +1065,9 @@ ForeachThreadOp> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - // scf::ForeachThreadOp alone doesn't bufferize to a memory read, one of the - // uses of its matching bbArg may. - auto foreachThreadOp = cast(op); - return state.isValueRead(foreachThreadOp.getTiedBlockArgument(&opOperand)); + // Outputs of scf::ForeachThreadOps are always considered as a read. + // See ForOpInterface for more details. + return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir @@ -163,58 +163,6 @@ // ----- -#accesses = [ - affine_map<(i) -> (i)> -] -#trait = { - indexing_maps = #accesses, - iterator_types = ["parallel"] -} - -// CHECK-LABEL: func @non_reading_scf_for -func.func @non_reading_scf_for(%t1: tensor {bufferization.writable = true}, - %s: index, %v: vector<5xf32>) -> (tensor, vector<5xf32>) { - - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %cst = arith.constant 0.0 : f32 - - // Write to %t1. - // CHECK: vector.transfer_write - // CHECK-SAME: __inplace_operands_attr__ = ["none", "true", "none"] - %t3 = vector.transfer_write %v, %t1[%s] : vector<5xf32>, tensor - - // This loop does not read from %t1. It only writes to it. - // CHECK: scf.for - %r, %v3 = scf.for %i = %c0 to %s step %c1 iter_args(%t2 = %t1, %v0 = %v) -> (tensor, vector<5xf32>) { - // Write to %t1 via %t2. (Overwrite %t3.) - // CHECK: linalg.generic - // CHECK-SAME: __inplace_operands_attr__ = ["true"] - %o2 = linalg.generic #trait outs (%t2 : tensor) { - ^bb(%0: f32) : - linalg.yield %cst : f32 - } -> (tensor) - - // Read overwritten value. This is not a read of %t1. - %v2 = vector.transfer_read %o2[%s], %cst : tensor, vector<5xf32> - scf.yield %o2, %v2 : tensor, vector<5xf32> - } - - // Use %t3 in some way without reading it, so that it does not get DCE'd. - // CHECK: linalg.generic - // CHECK-SAME: __inplace_operands_attr__ = ["true"] - %o = linalg.generic #trait outs (%t3 : tensor) { - ^bb(%0: f32) : - linalg.yield %cst : f32 - } -> (tensor) - - // CHECK: return - // CHECK-SAME: __equivalent_func_args__ = [0, -1] - return %o, %v3 : tensor, vector<5xf32> -} - -// ----- - //===----------------------------------------------------------------------===// // scf.if cases //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir @@ -38,6 +38,32 @@ // ----- +// CHECK-LABEL: func @scf_for_is_reading( +// CHECK-SAME: %[[A:.*]]: memref>, %[[B:.*]]: memref> +func.func @scf_for_is_reading(%A : tensor, %B : tensor, + %lb : index, %ub : index) + -> (f32, f32) +{ + %c1 = arith.constant 1 : index + %cst = arith.constant 0.0 : f32 + + // This is a regression test to make sure that an alloc + copy is emitted. + + // CHECK: %[[alloc:.*]] = memref.alloc + // CHECK: memref.copy %[[A]], %[[alloc]] + // CHECK: %[[clone:.*]] = bufferization.clone %[[alloc]] + // CHECK: scf.for {{.*}} iter_args(%{{.*}} = %[[clone]]) + %0 = scf.for %iv = %lb to %ub step %c1 iter_args(%1 = %A) -> tensor { + %r = linalg.fill ins(%cst : f32) outs(%1 : tensor) -> tensor + scf.yield %B : tensor + } + %1 = tensor.extract %0[%c1] : tensor + %2 = tensor.extract %A[%c1] : tensor + return %1, %2 : f32, f32 +} + +// ----- + // Ensure that the function bufferizes without error. This tests pre-order // traversal of scf.for loops during bufferization. No need to check the IR, // just want to make sure that it does not crash.