scf.for canonicalization: also forward iter args whose region argument and op result are both unused

The iter-arg forwarding canonicalization in SCF.cpp already treats a
loop-carried value as forwarded when (1) the region iter argument is
yielded unchanged, or (2) the region iter argument has no use and the
corresponding init operand is yielded. This patch adds case (3): the
region iter argument has no use and the corresponding scf.for result has
no use. To check (3), the op results are added to the llvm::zip over the
iter operands, region iter args and yielded values. This shifts the
yielded value from std::get<2> to std::get<3>, and both uses of it are
updated.

Test: fold_away_iter_and_result_with_no_use in
mlir/test/Dialect/SCF/canonicalize.mlir. There %arg3 (seeded with %cst)
and %0#1 are unused, and the CHECK lines expect the rewritten loop to
keep only the i32 iter arg seeded from %arg0.

NOTE(review): the hunks below appear to have had their line breaks and
indentation collapsed. As stored, this file will not apply with git apply
or patch. Regenerate it from the original commit (e.g. with
git format-patch) rather than editing it by hand.

diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -412,6 +412,8 @@ // 1) The op yields the iter arguments. // 2) The iter arguments have no use and the corresponding outer region // iterators (inputs) are yielded. +// 3) The iter arguments have no use and the corresponding (operation) results +// have no use. // // These arguments must be defined outside of // the ForOp region and can just be forwarded after simplifying the op inits, @@ -444,15 +446,19 @@ newResultValues.reserve(forOp.getNumResults()); for (auto it : llvm::zip(forOp.getIterOperands(), // iter from outside forOp.getRegionIterArgs(), // iter inside region + forOp.getResults(), // op results yieldOp.getOperands() // iter yield )) { // Forwarded is `true` when: // 1) The region `iter` argument is yielded. - // 2) The region `iter` argument has zero use, and the corresponding iter + // 2) The region `iter` argument has no use, and the corresponding iter // operand (input) is yielded. - bool forwarded = - ((std::get<1>(it) == std::get<2>(it)) || - (std::get<1>(it).use_empty() && std::get<0>(it) == std::get<2>(it))); + // 3) The region `iter` argument has no use, and the corresponding op + // result has no use. 
+ bool forwarded = ((std::get<1>(it) == std::get<3>(it)) || + (std::get<1>(it).use_empty() && + (std::get<0>(it) == std::get<3>(it) || + std::get<2>(it).use_empty()))); keepMask.push_back(!forwarded); canonicalize |= forwarded; if (forwarded) { @@ -461,7 +467,7 @@ continue; } newIterArgs.push_back(std::get<0>(it)); - newYieldValues.push_back(std::get<2>(it)); + newYieldValues.push_back(std::get<3>(it)); newBlockTransferArgs.push_back(Value()); // placeholder with null value newResultValues.push_back(Value()); // placeholder with null value } diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -402,3 +402,21 @@ // CHECK: return %[[FOR_RES]], %[[C32]] : i32, i32 return %0#0, %0#1 : i32, i32 } + +// ----- + +// CHECK-LABEL: fold_away_iter_and_result_with_no_use +// CHECK-SAME: %[[A0:[0-9a-z]*]]: i32 +func @fold_away_iter_and_result_with_no_use(%arg0 : i32, + %ub : index, %lb : index, %step : index) -> (i32) { + %cst = constant 32 : i32 + // CHECK: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args({{.*}} = %[[A0]]) -> (i32) { + %0:2 = scf.for %arg1 = %lb to %ub step %step iter_args(%arg2 = %arg0, %arg3 = %cst) + -> (i32, i32) { + %1 = addi %arg2, %cst : i32 + scf.yield %1, %1 : i32, i32 + } + + // CHECK: return %[[FOR_RES]] : i32 + return %0#0 : i32 +}