diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2030,6 +2030,77 @@
   let hasFolder = 1;
 }
 
+
+//===----------------------------------------------------------------------===//
+// SwitchOp
+//===----------------------------------------------------------------------===//
+
+def SwitchOp : Std_Op<"switch",
+    [AttrSizedOperandSegments,
+     DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
+     NoSideEffect, Terminator]> {
+  let summary = "switch operation";
+  let description = [{
+    The `switch` terminator operation represents a switch on a 32-bit integer
+    value. If the value matches one of the specified cases, then the
+    corresponding destination is jumped to. If the value does not match any of
+    the cases, the default destination is jumped to. The count and types of
+    operands must align with the arguments in the corresponding target blocks.
+
+    Example:
+
+    ```mlir
+    switch %flag, ^bb1(%a : i32) [
+      42: ^bb1(%b : i32),
+      43: ^bb3(%c : i32)
+    ]
+    ```
+  }];
+
+  let arguments = (ins I32:$value,
+                       Variadic<AnyType>:$defaultOperands,
+                       Variadic<AnyType>:$caseOperands,
+                       OptionalAttr<I32ElementsAttr>:$case_values,
+                       OptionalAttr<I32ElementsAttr>:$case_operand_offsets);
+  let successors = (successor
+    AnySuccessor:$defaultDestination,
+    VariadicSuccessor<AnySuccessor>:$caseDestinations);
+  let builders = [
+    OpBuilder<(ins "Value":$value,
+      "Block *":$defaultDestination,
+      "ValueRange":$defaultOperands,
+      CArg<"ArrayRef<int32_t>", "{}">:$caseValues,
+      CArg<"BlockRange", "{}">:$caseDestinations,
+      CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
+    OpBuilder<(ins "Value":$value,
+      "Block *":$defaultDestination,
+      "ValueRange":$defaultOperands,
+      CArg<"DenseIntElementsAttr", "{}">:$caseValues,
+      CArg<"BlockRange", "{}">:$caseDestinations,
+      CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>
+  ];
+
+  let assemblyFormat = [{
+    $value `,`
+    $defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)?
+    `[` `\n` custom<SwitchOpCases>($case_values, $caseDestinations,
+                                   $caseOperands, type($caseOperands),
+                                   $case_operand_offsets) `]`
+    attr-dict
+  }];
+
+  let extraClassDeclaration = [{
+    /// Return the operands for the case destination block at the given index.
+    OperandRange getCaseOperands(unsigned index);
+
+    /// Return a mutable range of operands for the case destination block at the
+    /// given index.
+    MutableOperandRange getCaseOperandsMutable(unsigned index);
+  }];
+
+  let hasCanonicalizer = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // TruncateIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -25,6 +25,7 @@
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/raw_ostream.h"
@@ -2160,6 +2161,437 @@
                  SubTensorInsertOpCastFolder>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// SwitchOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a switch from a case-value attribute. The per-case operand lists in
+/// `caseOperands` are flattened into a single operand segment, and the start
+/// index of each case's slice is recorded in `case_operand_offsets` (left
+/// unset when no case operand lists are given).
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
+                     Block *defaultDestination, ValueRange defaultOperands,
+                     DenseIntElementsAttr caseValues,
+                     BlockRange caseDestinations,
+                     ArrayRef<ValueRange> caseOperands) {
+  SmallVector<Value> flattenedCaseOperands;
+  SmallVector<int32_t> caseOperandOffsets;
+  int32_t offset = 0;
+  for (ValueRange operands : caseOperands) {
+    flattenedCaseOperands.append(operands.begin(), operands.end());
+    caseOperandOffsets.push_back(offset);
+    offset += operands.size();
+  }
+  DenseIntElementsAttr caseOperandOffsetsAttr;
+  if (!caseOperandOffsets.empty())
+    caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets);
+
+  build(builder, result, value, defaultOperands, flattenedCaseOperands,
+        caseValues, caseOperandOffsetsAttr, defaultDestination,
+        caseDestinations);
+}
+
+/// Convenience builder taking the case values as plain integers. An empty
+/// list leaves the `case_values` attribute unset.
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
+                     Block *defaultDestination, ValueRange defaultOperands,
+                     ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
+                     ArrayRef<ValueRange> caseOperands) {
+  DenseIntElementsAttr caseValuesAttr;
+  if (!caseValues.empty())
+    caseValuesAttr = builder.getI32VectorAttr(caseValues);
+  build(builder, result, value, defaultDestination, defaultOperands,
+        caseValuesAttr, caseDestinations, caseOperands);
+}
+
+/// Parses the (possibly empty) case list of a switch:
+///
+///   <cases> ::= (case (`,` case)*)?
+///   case    ::= integer `:` bb-id (`(` ssa-use-list `:` type-list `)`)?
+///
+/// On a non-empty list, `caseValues` and `caseOperandOffsets` are set to i32
+/// vectors; on an empty list both are left untouched (i.e. unset).
+static ParseResult
+parseSwitchOpCases(OpAsmParser &parser, ElementsAttr &caseValues,
+                   SmallVectorImpl<Block *> &caseDestinations,
+                   SmallVectorImpl<OpAsmParser::OperandType> &caseOperands,
+                   SmallVectorImpl<Type> &caseOperandTypes,
+                   ElementsAttr &caseOperandOffsets) {
+  SmallVector<int32_t> values;
+  SmallVector<int32_t> offsets;
+  int32_t value, offset = 0;
+  do {
+    OptionalParseResult integerParseResult = parser.parseOptionalInteger(value);
+    // No integer before the first case means the case list is empty.
+    if (values.empty() && !integerParseResult.hasValue())
+      return success();
+
+    // After a comma another case value is mandatory.
+    if (!integerParseResult.hasValue())
+      return parser.emitError(parser.getCurrentLocation(),
+                              "expected integer case value");
+    // A present-but-failed result (e.g. overflow) has already been reported.
+    if (failed(integerParseResult.getValue()))
+      return failure();
+    values.push_back(value);
+
+    Block *destination;
+    SmallVector<OpAsmParser::OperandType> operands;
+    SmallVector<Type> operandTypes;
+    if (parser.parseColon() || parser.parseSuccessor(destination))
+      return failure();
+    if (!parser.parseOptionalLParen()) {
+      llvm::SMLoc operandsLoc = parser.getCurrentLocation();
+      if (parser.parseOperandList(operands) ||
+          parser.parseColonTypeList(operandTypes) || parser.parseRParen())
+        return failure();
+      // Check the counts per case: only the totals are checked when the
+      // operands are resolved, which would let a surplus type in one case
+      // silently shift the types of the following cases.
+      if (operands.size() != operandTypes.size())
+        return parser.emitError(operandsLoc)
+               << "expected " << operands.size()
+               << " types for case operands, but got " << operandTypes.size();
+    }
+    caseDestinations.push_back(destination);
+    caseOperands.append(operands.begin(), operands.end());
+    caseOperandTypes.append(operandTypes.begin(), operandTypes.end());
+    offsets.push_back(offset);
+    offset += operands.size();
+  } while (!parser.parseOptionalComma());
+
+  Builder &builder = parser.getBuilder();
+  caseValues = builder.getI32VectorAttr(values);
+  caseOperandOffsets = builder.getI32VectorAttr(offsets);
+
+  return success();
+}
+
+/// Prints the case list in the format accepted by parseSwitchOpCases, one case
+/// per line. Prints nothing when the switch has no case values.
+static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op,
+                               ElementsAttr caseValues,
+                               SuccessorRange caseDestinations,
+                               OperandRange caseOperands,
+                               TypeRange caseOperandTypes,
+                               ElementsAttr caseOperandOffsets) {
+  if (!caseValues)
+    return;
+
+  size_t index = 0;
+  llvm::interleave(
+      llvm::zip(caseValues.cast<DenseIntElementsAttr>(), caseDestinations),
+      [&](auto i) {
+        p << " ";
+        // The parser reads case values as signed 32-bit integers, so print the
+        // sign-extended value; the raw unsigned bit pattern of a negative case
+        // value would not round-trip.
+        p << std::get<0>(i).getSExtValue();
+        p << ": ";
+        p.printSuccessorAndUseList(std::get<1>(i), op.getCaseOperands(index++));
+      },
+      [&] {
+        p << ',';
+        p.printNewline();
+      });
+  p.printNewline();
+}
+
+/// Verifies that there is exactly one case value per case destination and, if
+/// case operands are present, one operand offset per case destination (as
+/// required by getCaseOperandsMutable).
+static LogicalResult verify(SwitchOp op) {
+  auto caseValues = op.case_values();
+  int64_t numCaseDestinations = op.caseDestinations().size();
+
+  // `case_values` may be absent: count that as zero case values instead of
+  // dereferencing the empty optional.
+  int64_t numCaseValues = caseValues ? caseValues->size() : 0;
+  if (numCaseValues != numCaseDestinations)
+    return op.emitOpError()
+           << "expects number of case values (" << numCaseValues
+           << ") to match number of "
+              "case destinations ("
+           << numCaseDestinations << ")";
+
+  auto caseOperandOffsets = op.case_operand_offsets();
+  if (!caseOperandOffsets) {
+    if (!op.caseOperands().empty())
+      return op.emitOpError()
+             << "expects case operand offsets when case operands are present";
+  } else if (caseOperandOffsets->size() != numCaseDestinations) {
+    return op.emitOpError()
+           << "expects number of case operand offsets ("
+           << caseOperandOffsets->size()
+           << ") to match number of case destinations ("
+           << numCaseDestinations << ")";
+  }
+  return success();
+}
+
+OperandRange SwitchOp::getCaseOperands(unsigned index) {
+  return getCaseOperandsMutable(index);
+}
+
+MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) {
+  MutableOperandRange caseOperands = caseOperandsMutable();
+  if (!case_operand_offsets()) {
+    assert(caseOperands.size() == 0 &&
+           "non-empty case operands must have offsets");
+    return caseOperands;
+  }
+
+  ElementsAttr offsets = case_operand_offsets().getValue();
+  assert(index < offsets.size() && "invalid case operand offset index");
+
+  // Case `index` owns the operands from its own offset up to the next case's
+  // offset, or up to the end of the segment for the last case.
+  int64_t begin = offsets.getValue(index).cast<IntegerAttr>().getInt();
+  int64_t end = index + 1 == offsets.size()
+                    ? caseOperands.size()
+                    : offsets.getValue(index + 1).cast<IntegerAttr>().getInt();
+  return caseOperands.slice(begin, end - begin);
+}
+
+/// Successor 0 is the default destination; successor `i > 0` is case `i - 1`.
+Optional<MutableOperandRange>
+SwitchOp::getMutableSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  return index == 0 ? defaultOperandsMutable()
+                    : getCaseOperandsMutable(index - 1);
+}
+
+/// Returns the successor taken for the given constant operands, or nullptr if
+/// the switched-on value is not a known integer constant.
+Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
+  auto caseValues = case_values();
+  if (!caseValues)
+    return defaultDestination();
+
+  // A non-constant value (null attribute) cannot select a successor.
+  auto flag = operands.front().dyn_cast_or_null<IntegerAttr>();
+  if (!flag)
+    return nullptr;
+
+  // Compare attributes directly: both are i32 integer attributes, so equal
+  // values are the same uniqued attribute. This also handles a value of 0.
+  for (unsigned i = 0, e = caseValues->size(); i < e; ++i)
+    if (caseValues->getValue(i) == flag)
+      return caseDestinations()[i];
+  return defaultDestination();
+}
+
+/// switch %value, ^bb1 []
+/// -> br ^bb1
+static LogicalResult SimplifySwitchWithOnlyDefault(SwitchOp op,
+                                                   PatternRewriter &rewriter) {
+  if (!op.caseDestinations().empty())
+    return failure();
+
+  rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
+                                        op.defaultOperands());
+  return success();
+}
+
+/// switch %value, ^bb1 [
+///   42: ^bb1,
+///   43: ^bb2
+/// ]
+/// ->
+/// switch %value, ^bb1 [
+///   43: ^bb2
+/// ]
+static LogicalResult
+DropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
+  SmallVector<Block *> newCaseDestinations;
+  SmallVector<ValueRange> newCaseOperands;
+  SmallVector<int32_t> newCaseValues;
+  bool requiresChange = false;
+
+  for (unsigned i = 0; i < op.caseDestinations().size(); ++i) {
+    // A case with the same destination and operands as the default is
+    // redundant.
+    if (op.caseDestinations()[i] == op.defaultDestination() &&
+        op.getCaseOperands(i) == op.defaultOperands()) {
+      requiresChange = true;
+      continue;
+    }
+    newCaseDestinations.push_back(op.caseDestinations()[i]);
+    newCaseOperands.push_back(op.getCaseOperands(i));
+    newCaseValues.push_back(op.case_values()
+                                ->getValue(i)
+                                .cast<IntegerAttr>()
+                                .getValue()
+                                .getSExtValue());
+  }
+
+  if (!requiresChange)
+    return failure();
+
+  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.value(), op.defaultDestination(),
+                                        op.defaultOperands(), newCaseValues,
+                                        newCaseDestinations, newCaseOperands);
+  return success();
+}
+
+/// Helper for folding a switch with a constant value.
+/// switch %c_42, ^bb1 [
+///   42: ^bb2,
+///   43: ^bb3
+/// ]
+/// -> br ^bb2
+///
+/// A switch without any case values always folds to its default destination.
+static void FoldSwitch(SwitchOp op, PatternRewriter &rewriter,
+                       APInt caseValue) {
+  if (auto caseValues = op.case_values()) {
+    for (unsigned i = 0, e = caseValues->size(); i < e; ++i) {
+      // Compare integer values: the switch operand is I32 and the case values
+      // are built as i32 vectors, so the bit widths agree.
+      if (caseValues->getValue(i).cast<IntegerAttr>().getValue() == caseValue) {
+        rewriter.replaceOpWithNewOp<BranchOp>(op, op.caseDestinations()[i],
+                                              op.getCaseOperands(i));
+        return;
+      }
+    }
+  }
+  rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
+                                        op.defaultOperands());
+}
+
+/// switch %c_42, ^bb1 [
+///   42: ^bb2,
+///   43: ^bb3
+/// ]
+/// -> br ^bb2
+static LogicalResult SimplifyConstSwitchValue(SwitchOp op,
+                                              PatternRewriter &rewriter) {
+  APInt caseValue;
+  if (!matchPattern(op.value(), m_ConstantInt(&caseValue)))
+    return failure();
+
+  FoldSwitch(op, rewriter, caseValue);
+  return success();
+}
+
+/// switch %c_42, ^bb1 [
+///   42: ^bb2,
+/// ]
+/// ^bb2:
+///   br ^bb3
+/// ->
+/// switch %c_42, ^bb1 [
+///   42: ^bb3,
+/// ]
+static LogicalResult SimplifyPassThroughSwitch(SwitchOp op,
+                                               PatternRewriter &rewriter) {
+  SuccessorRange caseDests = op.caseDestinations();
+  unsigned numCases = caseDests.size();
+
+  SmallVector<Block *> newCaseDests;
+  SmallVector<ValueRange> newCaseOperands;
+  // The storage handed to collapseBranch may back the operand range it hands
+  // back, so every case needs its own storage that lives until the new switch
+  // is built. Reserve up front so that growing the outer vector never moves an
+  // inner vector that a ValueRange in newCaseOperands already points into.
+  SmallVector<SmallVector<Value>> caseArgStorage;
+  caseArgStorage.reserve(numCases);
+  bool requiresChange = false;
+  for (unsigned i = 0; i < numCases; ++i) {
+    Block *caseDest = caseDests[i];
+    ValueRange caseOperands = op.getCaseOperands(i);
+    caseArgStorage.emplace_back();
+    if (succeeded(
+            collapseBranch(caseDest, caseOperands, caseArgStorage.back())))
+      requiresChange = true;
+
+    newCaseDests.push_back(caseDest);
+    newCaseOperands.push_back(caseOperands);
+  }
+
+  Block *defaultDest = op.defaultDestination();
+  ValueRange defaultOperands = op.defaultOperands();
+  SmallVector<Value> defaultArgStorage;
+  if (succeeded(
+          collapseBranch(defaultDest, defaultOperands, defaultArgStorage)))
+    requiresChange = true;
+
+  if (!requiresChange)
+    return failure();
+
+  // `case_values` is unset for a switch without cases; pass a null attribute
+  // in that case instead of unwrapping the empty optional.
+  DenseIntElementsAttr caseValues;
+  if (auto caseValuesAttr = op.case_values())
+    caseValues = caseValuesAttr->cast<DenseIntElementsAttr>();
+
+  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.value(), defaultDest,
+                                        defaultOperands, caseValues,
+                                        newCaseDests, newCaseOperands);
+  return success();
+}
+
+/// switch %value, ^bb1 [
+///   42: ^bb2,
+/// ]
+/// ^bb2:
+///   switch %value, ^bb3 [
+///     42: ^bb4
+///   ]
+/// ->
+/// switch %value, ^bb1 [
+///   42: ^bb2,
+/// ]
+/// ^bb2:
+///   br ^bb4
+///
+/// switch %value, ^bb1 [
+///   42: ^bb2,
+/// ]
+/// ^bb2:
+///   switch %value, ^bb3 [
+///     43: ^bb4
+///   ]
+/// ->
+/// switch %value, ^bb1 [
+///   42: ^bb2,
+/// ]
+/// ^bb2:
+///   br ^bb3
+static LogicalResult
+SimplifySwitchFromSwitchOnSameCondition(SwitchOp op,
+                                        PatternRewriter &rewriter) {
+  // Check that this block has exactly one incoming edge, so the predecessor's
+  // terminator reaches it through exactly one of its successors.
+  Block *currentBlock = op->getBlock();
+  Block *predecessor = currentBlock->getSinglePredecessor();
+  if (!predecessor)
+    return failure();
+
+  // Check that the predecessor terminates with a switch branch to this block
+  // and that it branches on the same condition and that this branch isn't the
+  // default destination.
+  auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
+  if (!predSwitch || op.value() != predSwitch.value() ||
+      predSwitch.defaultDestination() == currentBlock)
+    return failure();
+
+  // The edge is a case edge, so the switched-on value equals that case's
+  // value here: fold this switch to an unconditional branch accordingly.
+  SuccessorRange predCaseDests = predSwitch.caseDestinations();
+  for (unsigned i = 0, e = predCaseDests.size(); i < e; ++i) {
+    if (predCaseDests[i] != currentBlock)
+      continue;
+    APInt caseValue =
+        predSwitch.case_values()->getValue(i).cast<IntegerAttr>().getValue();
+    FoldSwitch(op, rewriter, caseValue);
+    return success();
+  }
+  // Not reachable for valid IR: the single incoming edge is not the default
+  // edge, so it must be a case edge. Bail out rather than guess a successor.
+  return failure();
+}
+
+/// switch %value, ^bb1 [
+///   42: ^bb2,
+/// ]
+/// ^bb1:
+///   switch %value, ^bb3 [
+///     42: ^bb4,
+///     43: ^bb5
+///   ]
+/// ->
+/// switch %value, ^bb1 [
+///   42: ^bb2,
+/// ]
+/// ^bb1:
+///   switch %value, ^bb3 [
+///     43: ^bb5
+///   ]
+static LogicalResult
+SimplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
+                                               PatternRewriter &rewriter) {
+  // Check that this block has exactly one incoming edge.
+  Block *currentBlock = op->getBlock();
+  Block *predecessor = currentBlock->getSinglePredecessor();
+  if (!predecessor)
+    return failure();
+
+  // Check that the predecessor terminates with a switch branch to this block
+  // and that it branches on the same condition and that this branch is the
+  // default destination.
+  auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
+  if (!predSwitch || op.value() != predSwitch.value() ||
+      predSwitch.defaultDestination() != currentBlock)
+    return failure();
+
+  // Reaching this block through the default edge means the value matched none
+  // of the predecessor's case values, so those cases are impossible here.
+  llvm::DenseSet<Attribute> caseValuesToRemove;
+  SuccessorRange predCaseDests = predSwitch.caseDestinations();
+  for (unsigned i = 0, e = predCaseDests.size(); i < e; ++i)
+    if (predCaseDests[i] != currentBlock)
+      caseValuesToRemove.insert(predSwitch.case_values()->getValue(i));
+
+  SmallVector<Block *> newCaseDestinations;
+  SmallVector<ValueRange> newCaseOperands;
+  SmallVector<int32_t> newCaseValues;
+  bool requiresChange = false;
+
+  for (unsigned i = 0; i < op.caseDestinations().size(); ++i) {
+    if (caseValuesToRemove.contains(op.case_values()->getValue(i))) {
+      requiresChange = true;
+      continue;
+    }
+    newCaseDestinations.push_back(op.caseDestinations()[i]);
+    newCaseOperands.push_back(op.getCaseOperands(i));
+    newCaseValues.push_back(op.case_values()
+                                ->getValue(i)
+                                .cast<IntegerAttr>()
+                                .getValue()
+                                .getSExtValue());
+  }
+
+  if (!requiresChange)
+    return failure();
+
+  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.value(), op.defaultDestination(),
+                                        op.defaultOperands(), newCaseValues,
+                                        newCaseDestinations, newCaseOperands);
+  return success();
+}
+
+void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add(&SimplifySwitchWithOnlyDefault)
+      .add(&DropSwitchCasesThatMatchDefault)
+      .add(&SimplifyConstSwitchValue)
+      .add(&SimplifyPassThroughSwitch)
+      .add(&SimplifySwitchFromSwitchOnSameCondition)
+      .add(&SimplifySwitchFromDefaultSwitchOnSameCondition);
+}
+
 //===----------------------------------------------------------------------===//
 // TruncateIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize-cf.mlir b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
--- a/mlir/test/Dialect/Standard/canonicalize-cf.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck --dump-input-context 20 %s
 
 /// Test the folding of BranchOp.
 
@@ -139,6 +139,193 @@
   return
 }
 
+/// Test the folding of SwitchOp.
+
+// CHECK-LABEL: func @switch_fold_only_default(
+func @switch_fold_only_default(%arg0 : i32) -> i32 {
+  // CHECK-NEXT: %[[CST:.*]] = constant 1 : i32
+  // CHECK-NEXT: return %[[CST]] : i32
+  %c1_i32 = constant 1 : i32
+  switch %arg0, ^bb1(%c1_i32 : i32) []
+  ^bb1(%x : i32):
+    return %x : i32
+}
+
+/// Cases identical to the default successor and operands are dropped.
+// CHECK-LABEL: func @switch_fold_case_matching_default(
+// CHECK-SAME: %[[SWITCH_VALUE:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+func @switch_fold_case_matching_default(%switchValue : i32, %caseOperand0 : f32, %caseOperand1 : f32) -> f32 {
+  // CHECK: switch %[[SWITCH_VALUE]], ^[[BB1:.+]](%[[CASE_OPERAND_0]] : f32) [
+  // CHECK-NEXT: 10: ^[[BB2:.+]](%[[CASE_OPERAND_1]] : f32)
+  switch %switchValue, ^bb1(%caseOperand0 : f32) [
+    42: ^bb1(%caseOperand0 : f32),
+    10: ^bb2(%caseOperand1 : f32),
+    17: ^bb1(%caseOperand0 : f32)
+  ]
+  ^bb1(%x : f32):
+    return %x : f32
+  ^bb2(%y : f32):
+    return %y : f32
+}
+
+/// A switch on a constant folds to an unconditional branch.
+// CHECK-LABEL: func @switch_const_folding(
+func @switch_const_folding() -> i32 {
+  // CHECK-NEXT: %[[CST:.*]] = constant 1 : i32
+  // CHECK-NEXT: return %[[CST]] : i32
+  %c0_i32 = constant 0 : i32
+  %c1_i32 = constant 1 : i32
+  %c2_i32 = constant 2 : i32
+  %c3_i32 = constant 3 : i32
+  switch %c0_i32, ^bb1(%c1_i32 : i32) [
+    -1: ^bb2(%c2_i32 : i32),
+    1: ^bb3(%c3_i32 : i32)
+  ]
+  ^bb1(%x : i32):
+    return %x : i32
+  ^bb2(%y : i32):
+    return %y : i32
+  ^bb3(%z : i32):
+    return %z : i32
+}
+
+// CHECK-LABEL: func @switch_passthrough(
+// CHECK-SAME: %[[SWITCH_VALUE:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_3:[a-zA-Z0-9_]+]]
+func @switch_passthrough(%switchValue : i32,
+                         %caseOperand0 : f32,
+                         %caseOperand1 : f32,
+                         %caseOperand2 : f32,
+                         %caseOperand3 : f32) -> f32{
+  // CHECK: switch %[[SWITCH_VALUE]], ^[[BB4:.+]](%[[CASE_OPERAND_0]] : f32) [
+  // CHECK-NEXT: 43: ^[[BB5:.+]](%[[CASE_OPERAND_1]] : f32)
+  // CHECK-NEXT: 44: ^[[BB3:.+]](%[[CASE_OPERAND_2]] : f32)
+  // CHECK-NEXT: ]
+  switch %switchValue, ^bb1(%caseOperand0 : f32) [
+    43: ^bb2(%caseOperand1 : f32),
+    44: ^bb3(%caseOperand2 : f32)
+  ]
+  ^bb1(%bb1Arg : f32):
+    br ^bb4(%bb1Arg : f32)
+  ^bb2(%bb2Arg : f32):
+    br ^bb5(%bb2Arg : f32)
+  ^bb3(%bb3Arg : f32):
+    return %bb3Arg : f32
+
+  // CHECK: ^[[BB4]](
+  // CHECK-SAME: %[[BB4_ARG:[a-zA-Z0-9_]+]]
+  // CHECK-NEXT: %[[BB4_RES:.+]] = "foo.op1"(%[[BB4_ARG]])
+  // CHECK-NEXT: return %[[BB4_RES]]
+  ^bb4(%bb4Arg : f32):
+    %bb4Res = "foo.op1"(%bb4Arg) : (f32) -> f32
+    return %bb4Res : f32
+
+  // CHECK: ^[[BB5]](
+  // CHECK-SAME: %[[BB5_ARG:[a-zA-Z0-9_]+]]
+  // CHECK-NEXT: %[[BB5_RES:.+]] = "foo.op2"(%[[BB5_ARG]])
+  // CHECK-NEXT: return %[[BB5_RES]]
+  ^bb5(%bb5Arg : f32):
+    %bb5Res = "foo.op2"(%bb5Arg) : (f32) -> f32
+    return %bb5Res : f32
+}
+
+// CHECK-LABEL: func @switch_from_switch_with_same_value(
+// CHECK-SAME: %[[SWITCH_VALUE:[a-zA-Z0-9_]+]]
+func @switch_from_switch_with_same_value(%switchValue : i32) -> f32 {
+  // CHECK: switch %[[SWITCH_VALUE]]
+  switch %switchValue, ^bb1 [
+    42: ^bb2
+  ]
+
+  ^bb1:
+    // Extra predecessors keep other patterns from folding these blocks.
+    "foo.terminator1"() [^bb3, ^bb4] : () -> ()
+  ^bb2:
+    "foo.op2"() : () -> ()
+    // CHECK-NOT: switch %[[SWITCH_VALUE]]
+    // CHECK: br ^[[BB4:.+]]
+    switch %switchValue, ^bb3 [
+      42: ^bb4
+    ]
+
+  ^bb3:
+    "foo.terminator3"() : () -> ()
+
+  // CHECK: ^[[BB4]]:
+  // CHECK-NEXT: "foo.terminator4"
+  ^bb4:
+    "foo.terminator4"() : () -> ()
+}
+
+// CHECK-LABEL: func @switch_from_switch_with_same_value_no_match(
+// CHECK-SAME: %[[SWITCH_VALUE:[a-zA-Z0-9_]+]]
+func @switch_from_switch_with_same_value_no_match(%switchValue : i32) -> f32 {
+  // CHECK: switch %[[SWITCH_VALUE]]
+  switch %switchValue, ^bb1 [
+    42: ^bb2
+  ]
+
+  ^bb1:
+    // Extra predecessors keep other patterns from folding these blocks.
+    "foo.terminator1"() [^bb3, ^bb4] : () -> ()
+  ^bb2:
+    "foo.op2"() : () -> ()
+    // CHECK-NOT: switch %[[SWITCH_VALUE]]
+    // CHECK: br ^[[BB3:.+]]
+    switch %switchValue, ^bb3 [
+      0: ^bb5,
+      43: ^bb4
+    ]
+
+  // CHECK: ^[[BB3]]:
+  // CHECK-NEXT: "foo.terminator3"
+  ^bb3:
+    "foo.terminator3"() : () -> ()
+
+  ^bb4:
+    "foo.terminator4"() : () -> ()
+
+  ^bb5:
+    "foo.terminator5"() : () -> ()
+}
+
+// CHECK-LABEL: func @switch_from_switch_default_with_same_value(
+// CHECK-SAME: %[[SWITCH_VALUE:[a-zA-Z0-9_]+]]
+func @switch_from_switch_default_with_same_value(%switchValue : i32) -> f32 {
+  // CHECK: switch %[[SWITCH_VALUE]]
+  switch %switchValue, ^bb2 [
+    42: ^bb1
+  ]
+
+  ^bb1:
+    // Extra predecessors keep other patterns from folding these blocks.
+    "foo.terminator1"() [^bb3, ^bb4, ^bb5] : () -> ()
+  ^bb2:
+    "foo.op2"() : () -> ()
+    // CHECK: switch %[[SWITCH_VALUE]]
+    // CHECK-NOT: 42:
+    switch %switchValue, ^bb3 [
+      42: ^bb4,
+      43: ^bb5
+    ]
+
+  // NO-CHECK: ^[[BB3]]:
+  // NO-CHECK-NEXT: "foo.terminator3"
+  ^bb3:
+    "foo.terminator3"() : () -> ()
+
+  ^bb4:
+    "foo.terminator4"() : () -> ()
+
+  ^bb5:
+    "foo.terminator5"() : () -> ()
+}
+
 /// Test folding conditional branches that are successors of conditional
 /// branches with the same condition.