diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -454,6 +454,15 @@
       yieldedRanked.getMemorySpaceAsInt());
 }
 
+/// Return `true` if the given loop may have 0 iterations.
+bool mayHaveZeroIterations(scf::ForOp forOp) {
+  Optional<int64_t> lb = getConstantIntValue(forOp.getLowerBound());
+  Optional<int64_t> ub = getConstantIntValue(forOp.getUpperBound());
+  if (!lb.has_value() || !ub.has_value())
+    return true;
+  return *ub <= *lb;
+}
+
 /// Bufferization of scf.for. Replace with a new scf.for that operates on
 /// memrefs.
 struct ForOpInterface
@@ -461,9 +470,15 @@
                                                   scf::ForOp> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
+    auto forOp = cast<scf::ForOp>(op);
+
+    // If the loop has zero iterations, the results of the op are their
+    // corresponding init_args, meaning that the init_args bufferize to a read.
+    if (mayHaveZeroIterations(forOp))
+      return true;
+
     // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
     // its matching bbArg may.
-    auto forOp = cast<scf::ForOp>(op);
     return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
   }
 
@@ -1039,6 +1054,19 @@
   }
 };
 
+/// Return `true` if the given loop may have 0 iterations.
+bool mayHaveZeroIterations(scf::ForeachThreadOp foreachThreadOp) {
+  int64_t p = 1;
+  for (Value v : foreachThreadOp.getNumThreads()) {
+    if (Optional<int64_t> c = getConstantIntValue(v)) {
+      p *= *c;
+    } else {
+      return true;
+    }
+  }
+  return p == 0;
+}
+
 /// Bufferization of ForeachThreadOp. This also bufferizes the terminator of the
 /// region. There are op interfaces for the terminators (PerformConcurrentlyOp
 /// and ParallelInsertSliceOp), but these are only used during analysis.
@@ -1048,9 +1076,16 @@
                                                     ForeachThreadOp> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
+    auto foreachThreadOp = cast<ForeachThreadOp>(op);
+
+    // If the loop has zero iterations, the results of the op are their
+    // corresponding shared_outs, meaning that the shared_outs bufferize to a
+    // read.
+    if (mayHaveZeroIterations(foreachThreadOp))
+      return true;
+
     // scf::ForeachThreadOp alone doesn't bufferize to a memory read, one of the
     // uses of its matching bbArg may.
-    auto foreachThreadOp = cast<ForeachThreadOp>(op);
     return state.isValueRead(foreachThreadOp.getTiedBlockArgument(&opOperand));
   }
 
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir
--- a/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir
@@ -177,6 +177,7 @@
 
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
   %cst = arith.constant 0.0 : f32
 
   // Write to %t1.
@@ -186,7 +187,7 @@
 
   // This loop does not read from %t1. It only writes to it.
   // CHECK: scf.for
-  %r, %v3 = scf.for %i = %c0 to %s step %c1 iter_args(%t2 = %t1, %v0 = %v) -> (tensor<?xf32>, vector<5xf32>) {
+  %r, %v3 = scf.for %i = %c0 to %c10 step %c1 iter_args(%t2 = %t1, %v0 = %v) -> (tensor<?xf32>, vector<5xf32>) {
     // Write to %t1 via %t2. (Overwrite %t3.)
     // CHECK: linalg.generic
     // CHECK-SAME: __inplace_operands_attr__ = ["true"]
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -38,6 +38,32 @@
 
 // -----
 
+// CHECK-LABEL: func @scf_for_is_reading(
+//  CHECK-SAME:     %[[A:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[B:.*]]: memref<?xf32, strided<[?], offset: ?>>
+func.func @scf_for_is_reading(%A : tensor<?xf32>, %B : tensor<?xf32>,
+                              %lb : index, %ub : index)
+  -> (f32, f32)
+{
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.0 : f32
+
+  // This is a regression test to make sure that an alloc + copy is emitted.
+
+  // CHECK: %[[alloc:.*]] = memref.alloc
+  // CHECK: memref.copy %[[A]], %[[alloc]]
+  // CHECK: %[[clone:.*]] = bufferization.clone %[[alloc]]
+  // CHECK: scf.for {{.*}} iter_args(%{{.*}} = %[[clone]])
+  %0 = scf.for %iv = %lb to %ub step %c1 iter_args(%1 = %A) -> tensor<?xf32> {
+    %r = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf32>) -> tensor<?xf32>
+    scf.yield %B : tensor<?xf32>
+  }
+  %1 = tensor.extract %0[%c1] : tensor<?xf32>
+  %2 = tensor.extract %A[%c1] : tensor<?xf32>
+  return %1, %2 : f32, f32
+}
+
+// -----
+
 // Ensure that the function bufferizes without error. This tests pre-order
 // traversal of scf.for loops during bufferization. No need to check the IR,
 // just want to make sure that it does not crash.