diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -304,6 +304,12 @@ ControlFlow operations will be replaced successfully. Otherwise a single ControlFlow switch branching to one block per return-like operation kind remains. + + This pass may need to create unreachable terminators in case of infinite + loops, which is only supported for 'func.func' for now. If you potentially + have infinite loops inside CFG regions not belonging to 'func.func', + consider using `transformCFGToSCF` function directly with corresponding + `CFGToSCFInterface::createUnreachableTerminator` implementation. }]; let dependentDialects = ["scf::SCFDialect", diff --git a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp --- a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp +++ b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp @@ -140,10 +140,11 @@ // TODO: This should create a `ub.unreachable` op. Once such an operation // exists to make the pass independent of the func dialect. For now just // return poison values. 
- auto funcOp = dyn_cast(region.getParentOp()); + Operation *parentOp = region.getParentOp(); + auto funcOp = dyn_cast(parentOp); if (!funcOp) - return emitError(loc, "Expected '") - << func::FuncOp::getOperationName() << "' as top level operation"; + return emitError(loc, "Cannot create unreachable terminator for '") + << parentOp->getName() << "'"; return builder .create( @@ -165,18 +166,29 @@ ControlFlowToSCFTransformation transformation; bool changed = false; - WalkResult result = getOperation()->walk([&](func::FuncOp funcOp) { + Operation *op = getOperation(); + WalkResult result = op->walk([&](func::FuncOp funcOp) { if (funcOp.getBody().empty()) return WalkResult::advance(); - FailureOr changedFunc = transformCFGToSCF( - funcOp.getBody(), transformation, - funcOp != getOperation() ? getChildAnalysis(funcOp) - : getAnalysis()); - if (failed(changedFunc)) + auto &domInfo = funcOp != op ? getChildAnalysis(funcOp) + : getAnalysis(); + + auto visitor = [&](Operation *innerOp) -> WalkResult { + for (Region &reg : innerOp->getRegions()) { + FailureOr changedFunc = + transformCFGToSCF(reg, transformation, domInfo); + if (failed(changedFunc)) + return WalkResult::interrupt(); + + changed |= *changedFunc; + } + return WalkResult::advance(); + }; + + if (funcOp->walk(visitor).wasInterrupted()) return WalkResult::interrupt(); - changed |= *changedFunc; return WalkResult::advance(); }); if (result.wasInterrupted()) diff --git a/mlir/test/Conversion/ControlFlowToSCF/test.mlir b/mlir/test/Conversion/ControlFlowToSCF/test.mlir --- a/mlir/test/Conversion/ControlFlowToSCF/test.mlir +++ b/mlir/test/Conversion/ControlFlowToSCF/test.mlir @@ -678,3 +678,35 @@ // CHECK: scf.yield // CHECK: call @foo(%[[WHILE]]#1) // CHECK-NEXT: return + +// ----- + +func.func @nested_region() { + scf.execute_region { + %cond = "test.test1"() : () -> i1 + cf.cond_br %cond, ^bb1, ^bb2 + ^bb1: + "test.test2"() : () -> () + cf.br ^bb3 + ^bb2: + "test.test3"() : () -> () + cf.br ^bb3 + ^bb3: 
"test.test4"() : () -> () + scf.yield + } + return +} + +// CHECK-LABEL: func @nested_region +// CHECK: scf.execute_region { +// CHECK: %[[COND:.*]] = "test.test1"() +// CHECK-NEXT: scf.if %[[COND]] +// CHECK-NEXT: "test.test2"() +// CHECK-NEXT: else +// CHECK-NEXT: "test.test3"() +// CHECK-NEXT: } +// CHECK-NEXT: "test.test4"() +// CHECK-NEXT: scf.yield +// CHECK-NEXT: } +// CHECK-NEXT: return