diff --git a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
--- a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
@@ -165,18 +165,33 @@
     ControlFlowToSCFTransformation transformation;
 
     bool changed = false;
-    WalkResult result = getOperation()->walk([&](func::FuncOp funcOp) {
+    Operation *op = getOperation();
+    WalkResult result = op->walk([&](func::FuncOp funcOp) {
       if (funcOp.getBody().empty())
         return WalkResult::advance();
 
-      FailureOr<bool> changedFunc = transformCFGToSCF(
-          funcOp.getBody(), transformation,
-          funcOp != getOperation() ? getChildAnalysis<DominanceInfo>(funcOp)
-                                   : getAnalysis<DominanceInfo>());
-      if (failed(changedFunc))
+      auto &domInfo = funcOp != op ? getChildAnalysis<DominanceInfo>(funcOp)
+                                   : getAnalysis<DominanceInfo>();
+
+      // Structure every region held by an op inside the function, not just
+      // the function body. The walk is post-order, so nested regions (e.g.
+      // the body of an scf.execute_region) are transformed before their
+      // parents, and the function body itself is visited last.
+      auto visitor = [&](Operation *innerOp) -> WalkResult {
+        for (Region &reg : innerOp->getRegions()) {
+          FailureOr<bool> changedFunc =
+              transformCFGToSCF(reg, transformation, domInfo);
+          if (failed(changedFunc))
+            return WalkResult::interrupt();
+
+          changed |= *changedFunc;
+        }
+        return WalkResult::advance();
+      };
+
+      if (funcOp->walk<WalkOrder::PostOrder>(visitor).wasInterrupted())
         return WalkResult::interrupt();
 
-      changed |= *changedFunc;
       return WalkResult::advance();
     });
     if (result.wasInterrupted())
diff --git a/mlir/test/Conversion/ControlFlowToSCF/test.mlir b/mlir/test/Conversion/ControlFlowToSCF/test.mlir
--- a/mlir/test/Conversion/ControlFlowToSCF/test.mlir
+++ b/mlir/test/Conversion/ControlFlowToSCF/test.mlir
@@ -678,3 +678,35 @@
 // CHECK: scf.yield
 // CHECK: call @foo(%[[WHILE]]#1)
 // CHECK-NEXT: return
+
+// -----
+
+func.func @nested_region() {
+  scf.execute_region {
+    %cond = "test.test1"() : () -> i1
+    cf.cond_br %cond, ^bb1, ^bb2
+  ^bb1:
+    "test.test2"() : () -> ()
+    cf.br ^bb3
+  ^bb2:
+    "test.test3"() : () -> ()
+    cf.br ^bb3
+  ^bb3:
+    "test.test4"() : () -> ()
+    scf.yield
+  }
+  return
+}
+
+// CHECK-LABEL: func @nested_region
+// CHECK:      scf.execute_region {
+// CHECK:      %[[COND:.*]] = "test.test1"()
+// CHECK-NEXT: scf.if %[[COND]]
+// CHECK-NEXT:   "test.test2"()
+// CHECK-NEXT: else
+// CHECK-NEXT:   "test.test3"()
+// CHECK-NEXT: }
+// CHECK-NEXT: "test.test4"()
+// CHECK-NEXT: scf.yield
+// CHECK-NEXT: }
+// CHECK-NEXT: return