diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -290,6 +290,17 @@
   LogicalResult matchAndRewrite(WhileOp whileOp,
                                 PatternRewriter &rewriter) const override;
 };
+
+/// Lower an `scf.index_switch` operation to a `cf.switch` operation.
+/// When this pattern is used in a dialect conversion, the conversion target
+/// must not consider `scf.index_switch` legal: the conversion driver does not
+/// apply patterns to legal ops.
+struct IndexSwitchLowering : public OpRewritePattern<IndexSwitchOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IndexSwitchOp op,
+                                PatternRewriter &rewriter) const override;
+};
 } // namespace
 
 LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
@@ -615,10 +626,67 @@
   return success();
 }
 
+LogicalResult
+IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
+                                     PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+
+  // Split the block at the op: everything after it moves into the
+  // continuation block, which every region branches to on exit.
+  Block *condBlock = op->getBlock();
+  Block *continueBlock = rewriter.splitBlock(condBlock, Block::iterator(op));
+
+  // The continuation block receives the values yielded by the region that
+  // executed; its arguments replace the results of the op.
+  for (Type resultType : op.getResultTypes())
+    continueBlock->addArgument(resultType, loc);
+
+  // Rewrite the `scf.yield` terminating a region's single block into a
+  // branch to the continuation block, inline the region before it, and
+  // return the region's entry block. This step cannot fail, so the pattern
+  // never has to report failure after it has started mutating the IR.
+  auto convertRegion = [&](Region &region) -> Block * {
+    Block *block = &region.front();
+    auto yield = cast<scf::YieldOp>(block->getTerminator());
+    rewriter.setInsertionPoint(yield);
+    rewriter.replaceOpWithNewOp<cf::BranchOp>(yield, continueBlock,
+                                              yield.getOperands());
+    rewriter.inlineRegionBefore(region, continueBlock);
+    return block;
+  };
+
+  // Convert the case regions. The case values of `scf.index_switch` are
+  // 64-bit; keep them at full width so distinct cases cannot be truncated
+  // onto the same value.
+  SmallVector<Block *> caseSuccessors;
+  SmallVector<APInt> caseValues;
+  caseSuccessors.reserve(op.getCases().size());
+  caseValues.reserve(op.getCases().size());
+  for (auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) {
+    caseSuccessors.push_back(convertRegion(region));
+    caseValues.push_back(APInt(/*numBits=*/64, value, /*isSigned=*/true));
+  }
+
+  // Convert the default region.
+  Block *defaultBlock = convertRegion(op.getDefaultRegion());
+
+  // `cf.switch` requires an integer flag, but the switch argument is an
+  // `index`, so cast it to i64 to match the width of the case values.
+  rewriter.setInsertionPointToEnd(condBlock);
+  Value flag = rewriter.create<arith::IndexCastOp>(loc, rewriter.getI64Type(),
+                                                   op.getArg());
+  SmallVector<ValueRange> caseOperands(caseSuccessors.size(), {});
+  rewriter.create<cf::SwitchOp>(loc, flag, defaultBlock, ValueRange(),
+                                caseValues, caseSuccessors, caseOperands);
+  rewriter.replaceOp(op, continueBlock->getArguments());
+  return success();
+}
+
 void mlir::populateSCFToControlFlowConversionPatterns(
     RewritePatternSet &patterns) {
   patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
-               ExecuteRegionLowering>(patterns.getContext());
+               ExecuteRegionLowering, IndexSwitchLowering>(
+      patterns.getContext());
   patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
 }
 
diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
--- a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
@@ -620,3 +620,31 @@
 // CHECK: ^[[bb3]](%[[z:.+]]: i64):
 // CHECK: "test.bar"(%[[z]])
 // CHECK: return
+
+// CHECK-LABEL: @index_switch
+func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
+  // CHECK: %[[CASE:.*]] = arith.index_cast %arg0 : index to i64
+  // CHECK: cf.switch %[[CASE]] : i64
+  // CHECK-NEXT: default: ^bb3
+  // CHECK-NEXT: 0: ^bb1
+  // CHECK-NEXT: 1: ^bb2
+  %0 = scf.index_switch %i -> i32
+  // CHECK: ^bb1:
+  case 0 {
+    // CHECK-NEXT: cf.br ^bb4(%arg1
+    scf.yield %a : i32
+  }
+  // CHECK: ^bb2:
+  case 1 {
+    // CHECK-NEXT: cf.br ^bb4(%arg2
+    scf.yield %b : i32
+  }
+  // CHECK: ^bb3:
+  default {
+    // CHECK-NEXT: cf.br ^bb4(%arg3
+    scf.yield %c : i32
+  }
+  // CHECK: ^bb4(%[[V:.*]]: i32
+  // CHECK-NEXT: return %[[V]]
+  return %0 : i32
+}