diff --git a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp --- a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp +++ b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp @@ -266,6 +266,17 @@ LogicalResult matchAndRewrite(WhileOp whileOp, PatternRewriter &rewriter) const override; }; + +/// Optimized version of the above for the case of the "after" region merely +/// forwarding its arguments back to the "before" region (i.e., a "do-while" +/// loop). This avoids inlining the "after" region completely and branches back +/// to the "before" entry instead. +struct DoWhileLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(WhileOp whileOp, + PatternRewriter &rewriter) const override; +}; } // namespace LogicalResult ForLowering::matchAndRewrite(ForOp forOp, @@ -507,10 +518,60 @@ return success(); } +LogicalResult +DoWhileLowering::matchAndRewrite(WhileOp whileOp, + PatternRewriter &rewriter) const { + if (!llvm::hasSingleElement(whileOp.after())) + return rewriter.notifyMatchFailure(whileOp, + "do-while simplification applicable to " + "single-block 'after' region only"); + + Block &afterBlock = whileOp.after().front(); + if (!llvm::hasSingleElement(afterBlock)) + return rewriter.notifyMatchFailure(whileOp, + "do-while simplification applicable " + "only if 'after' region has no payload"); + + auto yield = dyn_cast(&afterBlock.front()); + if (!yield || yield.results() != afterBlock.getArguments()) + return rewriter.notifyMatchFailure(whileOp, + "do-while simplification applicable " + "only to forwarding 'after' regions"); + + // Split the current block before the WhileOp to create the inlining point. + OpBuilder::InsertionGuard guard(rewriter); + Block *currentBlock = rewriter.getInsertionBlock(); + Block *continuation = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + + // Only the "before" region should be inlined. 
+ Block *before = &whileOp.before().front(); + Block *beforeLast = &whileOp.before().back(); + rewriter.inlineRegionBefore(whileOp.before(), continuation); + + // Branch to the "before" region. + rewriter.setInsertionPointToEnd(currentBlock); + rewriter.create(whileOp.getLoc(), before, whileOp.inits()); + + // Loop around the "before" region based on condition. + rewriter.setInsertionPointToEnd(beforeLast); + auto condOp = cast(beforeLast->getTerminator()); + rewriter.replaceOpWithNewOp(condOp, condOp.condition(), before, + condOp.args(), continuation, + ValueRange()); + + // Replace the op with values "yielded" from the "before" region, which are + // visible by dominance. + rewriter.replaceOp(whileOp, condOp.args()); + + return success(); +} + void mlir::populateLoopToStdConversionPatterns( OwningRewritePatternList &patterns, MLIRContext *ctx) { patterns.insert( ctx); + patterns.insert(ctx, /*benefit=*/2); } void SCFToStandardPass::runOnOperation() { diff --git a/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir --- a/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir +++ b/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir @@ -424,6 +424,8 @@ scf.condition(%0) } do { // CHECK: ^[[AFTER]]: + // CHECK: "test.some_payload"() : () -> () + "test.some_payload"() : () -> () // CHECK: br ^[[BEFORE]] scf.yield } @@ -432,6 +434,25 @@ return } +// CHECK-LABEL: @do_while +func @do_while(%arg0: f32) { + // CHECK: br ^[[BEFORE:.*]]({{.*}}: f32) + scf.while (%arg1 = %arg0) : (f32) -> (f32) { + // CHECK: ^[[BEFORE]](%[[VAL:.*]]: f32): + // CHECK: %[[COND:.*]] = "test.make_condition"() : () -> i1 + %0 = "test.make_condition"() : () -> i1 + // CHECK: cond_br %[[COND]], ^[[BEFORE]](%[[VAL]] : f32), ^[[CONT:.*]] + scf.condition(%0) %arg1 : f32 + } do { + ^bb0(%arg2: f32): + // CHECK-NOT: br ^[[BEFORE]] + scf.yield %arg2 : f32 + } + // CHECK: ^[[CONT]]: + // CHECK: return + return +} + // CHECK-LABEL: 
@while_values // CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: f32) func @while_values(%arg0: i32, %arg1: f32) {