diff --git a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
--- a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
+++ b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
@@ -200,6 +200,71 @@
   LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
                                 PatternRewriter &rewriter) const override;
 };
+
+/// Create a CFG subgraph for this loop construct. The regions of the loop need
+/// not be a single block anymore (for example, if other SCF constructs that
+/// they contain have been already converted to CFG), but need to be single-exit
+/// from the last block of each region. The operations following the original
+/// WhileOp are split into a new continuation block. Both regions of the WhileOp
+/// are inlined, and their terminators are rewritten to organize the control
+/// flow implementing the loop as follows.
+///
+///      +---------------------------------+
+///      |   <code before the WhileOp>     |
+///      |   br ^before(%operands...)      |
+///      +---------------------------------+
+///             |
+///  -------|   |
+///  |      v   v
+///  |   +--------------------------------+
+///  |   | ^before(%bargs...):            |
+///  |   |   %vals... = <some payload>    |
+///  |   +--------------------------------+
+///  |                   |
+///  |                  ...
+///  |                   |
+///  |   +--------------------------------+
+///  |   | ^before-last:                  |
+///  |   |   %cond = <compute condition>  |
+///  |   |   cond_br %cond,               |
+///  |   |        ^after(%vals...), ^cont |
+///  |   +--------------------------------+
+///  |          |               |
+///  |          |               -------------|
+///  |          v                            |
+///  |   +--------------------------------+  |
+///  |   | ^after(%aargs...):             |  |
+///  |   |   <body contents>              |  |
+///  |   +--------------------------------+  |
+///  |                   |                   |
+///  |                  ...                  |
+///  |                   |                   |
+///  |   +--------------------------------+  |
+///  |   | ^after-last:                   |  |
+///  |   |   %yields... = <some payload>  |  |
+///  |   |   br ^before(%yields...)       |  |
+///  |   +--------------------------------+  |
+///  |          |                            |
+///  |-----------        |--------------------
+///                      v
+///      +--------------------------------+
+///      | ^cont:                         |
+///      |   <code after the WhileOp>     |
+///      |   <%vals from 'before' region  |
+///      |          visible by dominance> |
+///      +--------------------------------+
+///
+/// Values are communicated between ex-regions through block arguments of their
+/// entry blocks, which are visible in all other dominated blocks. Similarly,
+/// the results of the WhileOp are defined in the 'before' region, which is
+/// required to have a single exiting block, and are therefore accessible in
+/// the continuation block due to dominance.
+struct WhileLowering : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp whileOp,
+                                PatternRewriter &rewriter) const override;
+};
 } // namespace
 
 LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
@@ -399,18 +464,61 @@
   return success();
 }
 
+LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
+                                             PatternRewriter &rewriter) const {
+  OpBuilder::InsertionGuard guard(rewriter);
+  Location loc = whileOp.getLoc();
+
+  // Split the current block before the WhileOp to create the inlining point.
+  Block *currentBlock = rewriter.getInsertionBlock();
+  Block *continuation =
+      rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
+
+  // Inline both regions.
+  Block *after = &whileOp.after().front();
+  Block *afterLast = &whileOp.after().back();
+  Block *before = &whileOp.before().front();
+  Block *beforeLast = &whileOp.before().back();
+  rewriter.inlineRegionBefore(whileOp.after(), continuation);
+  rewriter.inlineRegionBefore(whileOp.before(), after);
+
+  // Branch to the "before" region.
+  rewriter.setInsertionPointToEnd(currentBlock);
+  rewriter.create<BranchOp>(loc, before, whileOp.inits());
+
+  // Replace terminators with branches. Assuming bodies are SESE, which holds
+  // given only the patterns from this file, we only need to look at the last
+  // block. This should be reconsidered if we allow break/continue in SCF.
+  rewriter.setInsertionPointToEnd(beforeLast);
+  auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
+  // Replace the op with values "yielded" from the "before" region, which are
+  // visible by dominance. This must happen before `condOp` is replaced: e.g.
+  // the greedy driver's rewriter erases replaced ops immediately.
+  rewriter.replaceOp(whileOp, condOp.args());
+  rewriter.replaceOpWithNewOp<CondBranchOp>(condOp, condOp.condition(), after,
+                                            condOp.args(), continuation,
+                                            ValueRange());
+
+  rewriter.setInsertionPointToEnd(afterLast);
+  auto yieldOp = cast<scf::YieldOp>(afterLast->getTerminator());
+  rewriter.replaceOpWithNewOp<BranchOp>(yieldOp, before, yieldOp.results());
+
+  return success();
+}
+
 void mlir::populateLoopToStdConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<ForLowering, IfLowering, ParallelLowering>(ctx);
+  patterns.insert<ForLowering, IfLowering, ParallelLowering, WhileLowering>(
+      ctx);
 }
 
 void SCFToStandardPass::runOnOperation() {
   OwningRewritePatternList patterns;
   populateLoopToStdConversionPatterns(patterns, &getContext());
-  // Configure conversion to lower out scf.for, scf.if and scf.parallel.
-  // Anything else is fine.
+  // Configure conversion to lower out scf.for, scf.if, scf.parallel and
+  // scf.while. Anything else is fine.
   ConversionTarget target(getContext());
-  target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp>();
+  target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp>();
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
diff --git a/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir
--- a/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir
@@ -412,3 +412,108 @@
   }
   return
 }
+
+// CHECK-LABEL: @minimal_while
+func @minimal_while() {
+  // CHECK:   %[[COND:.*]] = "test.make_condition"() : () -> i1
+  // CHECK:   br ^[[BEFORE:.*]]
+  %0 = "test.make_condition"() : () -> i1
+  scf.while : () -> () {
+  // CHECK: ^[[BEFORE]]:
+  // CHECK:   cond_br %[[COND]], ^[[AFTER:.*]], ^[[CONT:.*]]
+    scf.condition(%0)
+  } do {
+  // CHECK: ^[[AFTER]]:
+  // CHECK:   br ^[[BEFORE]]
+    scf.yield
+  }
+  // CHECK: ^[[CONT]]:
+  // CHECK:   return
+  return
+}
+
+// CHECK-LABEL: @while_values
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: f32)
+func @while_values(%arg0: i32, %arg1: f32) {
+  // CHECK:     %[[COND:.*]] = "test.make_condition"() : () -> i1
+  %0 = "test.make_condition"() : () -> i1
+  %c0_i32 = constant 0 : i32
+  %cst = constant 0.000000e+00 : f32
+  // CHECK:     br ^[[BEFORE:.*]](%[[ARG0]], %[[ARG1]] : i32, f32)
+  %1:2 = scf.while (%arg2 = %arg0, %arg3 = %arg1) : (i32, f32) -> (i64, f64) {
+    // CHECK:   ^[[BEFORE]](%[[ARG2:.*]]: i32, %[[ARG3:.*]]: f32):
+    // CHECK:     %[[VAL1:.*]] = zexti %[[ARG0]] : i32 to i64
+    %2 = zexti %arg0 : i32 to i64
+    // CHECK:     %[[VAL2:.*]] = fpext %[[ARG3]] : f32 to f64
+    %3 = fpext %arg3 : f32 to f64
+    // CHECK:     cond_br %[[COND]], ^[[AFTER:.*]](%[[VAL1]], %[[VAL2]] : i64, f64), ^[[CONT:.*]]
+    scf.condition(%0) %2, %3 : i64, f64
+  } do {
+    // CHECK:   ^[[AFTER]](%[[ARG4:.*]]: i64, %[[ARG5:.*]]: f64):
+  ^bb0(%arg2: i64, %arg3: f64):  // no predecessors
+    // CHECK:     br ^[[BEFORE]](%{{.*}}, %{{.*}} : i32, f32)
+    scf.yield %c0_i32, %cst : i32, f32
+  }
+  // CHECK:   ^[[CONT]]:
+  // CHECK:     return
+  return
+}
+
+// CHECK-LABEL: @nested_while_ops
+func @nested_while_ops(%arg0: f32) -> i64 {
+  // CHECK:       br ^[[OUTER_BEFORE:.*]](%{{.*}} : f32)
+  %0 = scf.while(%outer = %arg0) : (f32) -> i64 {
+    // CHECK:     ^[[OUTER_BEFORE]](%{{.*}}: f32):
+    // CHECK:       %[[OUTER_COND:.*]] = "test.outer_before_pre"() : () -> i1
+    %cond = "test.outer_before_pre"() : () -> i1
+    // CHECK:       br ^[[INNER_BEFORE_BEFORE:.*]](%{{.*}} : f32)
+    %1 = scf.while(%inner = %outer) : (f32) -> i64 {
+      // CHECK:     ^[[INNER_BEFORE_BEFORE]](%{{.*}}: f32):
+      // CHECK:       %[[INNER1:.*]]:2 = "test.inner_before"(%{{.*}}) : (f32) -> (i1, i64)
+      %2:2 = "test.inner_before"(%inner) : (f32) -> (i1, i64)
+      // CHECK:       cond_br %[[INNER1]]#0, ^[[INNER_BEFORE_AFTER:.*]](%[[INNER1]]#1 : i64), ^[[OUTER_BEFORE_LAST:.*]]
+      scf.condition(%2#0) %2#1 : i64
+    } do {
+      // CHECK:     ^[[INNER_BEFORE_AFTER]](%{{.*}}: i64):
+    ^bb0(%arg1: i64):
+      // CHECK:       %[[INNER2:.*]] = "test.inner_after"(%{{.*}}) : (i64) -> f32
+      %3 = "test.inner_after"(%arg1) : (i64) -> f32
+      // CHECK:       br ^[[INNER_BEFORE_BEFORE]](%[[INNER2]] : f32)
+      scf.yield %3 : f32
+    }
+    // CHECK:     ^[[OUTER_BEFORE_LAST]]:
+    // CHECK:       "test.outer_before_post"() : () -> ()
+    "test.outer_before_post"() : () -> ()
+    // CHECK:       cond_br %[[OUTER_COND]], ^[[OUTER_AFTER:.*]](%[[INNER1]]#1 : i64), ^[[CONTINUATION:.*]]
+    scf.condition(%cond) %1 : i64
+  } do {
+    // CHECK:     ^[[OUTER_AFTER]](%{{.*}}: i64):
+  ^bb2(%arg2: i64):
+    // CHECK:       "test.outer_after_pre"(%{{.*}}) : (i64) -> ()
+    "test.outer_after_pre"(%arg2) : (i64) -> ()
+    // CHECK:       br ^[[INNER_AFTER_BEFORE:.*]](%{{.*}} : i64)
+    %4 = scf.while(%inner = %arg2) : (i64) -> f32 {
+      // CHECK:     ^[[INNER_AFTER_BEFORE]](%{{.*}}: i64):
+      // CHECK:       %[[INNER3:.*]]:2 = "test.inner2_before"(%{{.*}}) : (i64) -> (i1, f32)
+      %5:2 = "test.inner2_before"(%inner) : (i64) -> (i1, f32)
+      // CHECK:       cond_br %[[INNER3]]#0, ^[[INNER_AFTER_AFTER:.*]](%[[INNER3]]#1 : f32), ^[[OUTER_AFTER_LAST:.*]]
+      scf.condition(%5#0) %5#1 : f32
+    } do {
+      // CHECK:     ^[[INNER_AFTER_AFTER]](%{{.*}}: f32):
+    ^bb3(%arg3: f32):
+      // CHECK:       %{{.*}} = "test.inner2_after"(%{{.*}}) : (f32) -> i64
+      %6 = "test.inner2_after"(%arg3) : (f32) -> i64
+      // CHECK:       br ^[[INNER_AFTER_BEFORE]](%{{.*}} : i64)
+      scf.yield %6 : i64
+    }
+    // CHECK:     ^[[OUTER_AFTER_LAST]]:
+    // CHECK:       "test.outer_after_post"() : () -> ()
+    "test.outer_after_post"() : () -> ()
+    // CHECK:       br ^[[OUTER_BEFORE]](%[[INNER3]]#1 : f32)
+    scf.yield %4 : f32
+  }
+  // CHECK:     ^[[CONTINUATION]]:
+  // CHECK:       return %{{.*}} : i64
+  return %0 : i64
+}
+