Index: mlir/lib/Dialect/SCF/SCF.cpp
===================================================================
--- mlir/lib/Dialect/SCF/SCF.cpp
+++ mlir/lib/Dialect/SCF/SCF.cpp
@@ -685,8 +685,9 @@
   }
 };
 
-/// Rewriting pattern that erases loops that are known not to iterate and
-/// replaces single-iteration loops with their bodies.
+/// Rewriting pattern that erases loops that are known not to iterate, replaces
+/// single-iteration loops with their bodies, and removes empty loops that
+/// iterate at least once and only return values defined outside of the loop.
 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
   using OpRewritePattern<ForOp>::OpRewritePattern;
 
@@ -728,7 +729,19 @@
       return success();
     }
 
-    return failure();
+    Block &block = op.getRegion().front();
+    if (!llvm::hasSingleElement(block))
+      return failure();
+    // If the loop is empty, iterates at least once, and only returns values
+    // defined outside of the loop, every iteration yields the same values, so
+    // the loop results can be replaced by those yielded values directly.
+    auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
+    auto yieldOperands = yieldOp.getOperands();
+    if (llvm::any_of(yieldOperands,
+                     [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
+      return failure();
+    rewriter.replaceOp(op, yieldOperands);
+    return success();
   }
 };
 
Index: mlir/test/Dialect/SCF/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/SCF/canonicalize.mlir
+++ mlir/test/Dialect/SCF/canonicalize.mlir
@@ -361,6 +361,22 @@
 // CHECK-NEXT: }
 // CHECK-NEXT: return %[[a]], %[[r1]], %[[b]] : i32, i32, i32
 
+func @for_yields_4() -> i32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %a = arith.constant 3 : i32
+  %b = call @make_i32() : () -> (i32)
+  %r = scf.for %i = %c0 to %c2 step %c1 iter_args(%0 = %a) -> i32 {
+    scf.yield %b : i32
+  }
+  return %r : i32
+}
+
+// CHECK-LABEL: func @for_yields_4
+// CHECK-NEXT: %[[b:.*]] = call @make_i32() : () -> i32
+// CHECK-NEXT: return %[[b]] : i32
+
 // -----
 
 // CHECK-LABEL: @replace_true_if