diff --git a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
--- a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
+++ b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
@@ -194,6 +194,18 @@
                                 PatternRewriter &rewriter) const override;
 };
 
+// Lower an `scf.execute_region` by inlining its region into the parent
+// region: the block holding the op is split into a head and a continuation,
+// the head branches to the region's entry block, and every `scf.yield` in the
+// region becomes a branch to the continuation that forwards the yielded
+// values, which then replace the op's results as continuation-block arguments.
+struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
+  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExecuteRegionOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
 struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
   using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
 
@@ -400,6 +412,45 @@
   return success();
 }
 
+LogicalResult
+ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
+                                       PatternRewriter &rewriter) const {
+  auto loc = op.getLoc();
+
+  // Split the current block at the rewriter's insertion point (expected to be
+  // at `op`); the ops from there on form the continuation block.
+  auto *currentBlock = rewriter.getInsertionBlock();
+  auto opPosition = rewriter.getInsertionPoint();
+  auto *remainingOpsBlock = rewriter.splitBlock(currentBlock, opPosition);
+
+  // Enter the region unconditionally from the end of the first half.
+  auto &region = op.region();
+  rewriter.setInsertionPointToEnd(currentBlock);
+  rewriter.create<BranchOp>(loc, &region.front());
+
+  // Replace every `scf.yield` terminator with a branch to the continuation
+  // block that forwards the yielded values as block arguments.
+  for (Block &block : region) {
+    if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
+      ValueRange terminatorOperands = terminator->getOperands();
+      rewriter.setInsertionPointToEnd(&block);
+      rewriter.create<BranchOp>(loc, remainingOpsBlock, terminatorOperands);
+      rewriter.eraseOp(terminator);
+    }
+  }
+
+  // Splice the region's blocks in front of the continuation block.
+  rewriter.inlineRegionBefore(region, remainingOpsBlock);
+
+  // The yielded values arrive as continuation-block arguments; they replace
+  // the op's results.
+  SmallVector<Value> vals;
+  for (auto arg : remainingOpsBlock->addArguments(op->getResultTypes()))
+    vals.push_back(arg);
+  rewriter.replaceOp(op, vals);
+  return success();
+}
+
 LogicalResult
 ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
                                   PatternRewriter &rewriter) const {
@@ -569,8 +620,8 @@
 }
 
 void mlir::populateLoopToStdConversionPatterns(RewritePatternSet &patterns) {
-  patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering>(
-      patterns.getContext());
+  patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
+               ExecuteRegionLowering>(patterns.getContext());
   patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
 }
 
@@ -580,7 +631,8 @@
-  // Configure conversion to lower out scf.for, scf.if, scf.parallel and
-  // scf.while. Anything else is fine.
+  // Configure conversion to lower out scf.for, scf.if, scf.parallel,
+  // scf.while and scf.execute_region. Anything else is fine.
   ConversionTarget target(getContext());
-  target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp>();
+  target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp,
+                      scf::ExecuteRegionOp>();
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
diff --git a/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir
--- a/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir
@@ -587,3 +587,36 @@
 // CHECK: return
   return
 }
+
+// CHECK-LABEL: func @func_execute_region_elim_multi_yield
+func @func_execute_region_elim_multi_yield() {
+  "test.foo"() : () -> ()
+  %v = scf.execute_region -> i64 {
+    %c = "test.cmp"() : () -> i1
+    cond_br %c, ^bb2, ^bb3
+  ^bb2:
+    %x = "test.val1"() : () -> i64
+    scf.yield %x : i64
+  ^bb3:
+    %y = "test.val2"() : () -> i64
+    scf.yield %y : i64
+  }
+  "test.bar"(%v) : (i64) -> ()
+  return
+}
+
+// CHECK: "test.foo"
+// CHECK-NOT: execute_region
+// CHECK: br ^[[rentry:.+]]
+// CHECK: ^[[rentry]]
+// CHECK: %[[cmp:.+]] = "test.cmp"
+// CHECK: cond_br %[[cmp]], ^[[bb1:.+]], ^[[bb2:.+]]
+// CHECK: ^[[bb1]]:
+// CHECK: %[[x:.+]] = "test.val1"
+// CHECK: br ^[[bb3:.+]](%[[x]] : i64)
+// CHECK: ^[[bb2]]:
+// CHECK: %[[y:.+]] = "test.val2"
+// CHECK: br ^[[bb3]](%[[y]] : i64)
+// CHECK: ^[[bb3]](%[[z:.+]]: i64):
+// CHECK: "test.bar"(%[[z]])
+// CHECK: return