diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -408,7 +408,9 @@ } namespace { -// Fold away ForOp iter arguments that are also yielded by the op. +// Fold away ForOp iter arguments that are also yielded by the op or +// iter arguments that have no use and the corresponding outer region +// iterator (input) is yielded. // These arguments must be defined outside of the ForOp region and can just be // forwarded after simplifying the op inits, yields and returns. // @@ -441,8 +443,13 @@ forOp.getRegionIterArgs(), // iter inside region yieldOp.getOperands() // iter yield )) { - // Forwarded is `true` when the region `iter` argument is yielded. - bool forwarded = (std::get<1>(it) == std::get<2>(it)); + // Forwarded is `true` when: + // 1) The region `iter` argument is yielded. + // 2) The region `iter` argument has zero use and the iterator from + // outside the region (input) is yielded. + bool forwarded = + ((std::get<1>(it) == std::get<2>(it)) || + (std::get<1>(it).use_empty() && std::get<0>(it) == std::get<2>(it))); keepMask.push_back(!forwarded); canonicalize |= forwarded; if (forwarded) { @@ -483,7 +490,7 @@ "unexpected argument size mismatch"); // No results case: the scf::ForOp builder already created a zero - // reult terminator. Merge before this terminator and just get rid of the + // result terminator. Merge before this terminator and just get rid of the // original terminator that has been merged in. if (newIterArgs.empty()) { auto newYieldOp = cast(newBlock.getTerminator()); diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -335,6 +335,7 @@ } // ----- + func private @process(%0 : memref<128x128xf32>) func private @process_tensor(%0 : tensor<128x128xf32>) -> memref<128x128xf32> @@ -382,3 +383,22 @@ // CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32> return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32> } + +// ----- + +// CHECK-LABEL: fold_away_iter_with_no_use_and_yielded_input +// CHECK-SAME: %[[A0:[0-9a-z]*]]: i32 +func @fold_away_iter_with_no_use_and_yielded_input(%arg0 : i32, + %ub : index, %lb : index, %step : index) -> (i32, i32) { + // CHECK-NEXT: %[[C32:.*]] = constant 32 : i32 + %cst = constant 32 : i32 + // CHECK-NEXT: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args({{.*}} = %[[A0]]) -> (i32) { + %0:2 = scf.for %arg1 = %lb to %ub step %step iter_args(%arg2 = %arg0, %arg3 = %cst) + -> (i32, i32) { + %1 = addi %arg2, %cst : i32 + scf.yield %1, %cst : i32, i32 + } + + // CHECK: return %[[FOR_RES]], %[[C32]] : i32, i32 + return %0#0, %0#1 : i32, i32 +}