diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2030,6 +2030,83 @@
   let hasFolder = 1;
 }
 
+
+//===----------------------------------------------------------------------===//
+// SwitchOp
+//===----------------------------------------------------------------------===//
+
+def SwitchOp : Std_Op<"switch",
+    [AttrSizedOperandSegments,
+     DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
+     NoSideEffect, Terminator]> {
+  let summary = "switch operation";
+  let description = [{
+    The `switch` terminator operation represents a switch on an integer value
+    of any bit width. If the flag matches one of the specified cases, then the
+    corresponding destination is jumped to. If the flag does not match any of
+    the cases, the default destination is jumped to. The count and types of
+    operands must align with the arguments in the corresponding target blocks.
+
+    Example:
+
+    ```mlir
+    switch %flag : i32, [
+      default: ^bb1(%a : i32),
+      42: ^bb1(%b : i32),
+      43: ^bb3(%c : i32)
+    ]
+    ```
+  }];
+
+  let arguments = (ins AnyInteger:$flag,
+    Variadic<AnyType>:$defaultOperands,
+    Variadic<AnyType>:$caseOperands,
+    OptionalAttr<AnyIntElementsAttr>:$case_values,
+    OptionalAttr<I32ElementsAttr>:$case_operand_offsets);
+  let successors = (successor
+    AnySuccessor:$defaultDestination,
+    VariadicSuccessor<AnySuccessor>:$caseDestinations);
+  let builders = [
+    OpBuilder<(ins "Value":$flag,
+      "Block *":$defaultDestination,
+      "ValueRange":$defaultOperands,
+      CArg<"ArrayRef<APInt>", "{}">:$caseValues,
+      CArg<"BlockRange", "{}">:$caseDestinations,
+      CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
+    OpBuilder<(ins "Value":$flag,
+      "Block *":$defaultDestination,
+      "ValueRange":$defaultOperands,
+      CArg<"DenseIntElementsAttr", "{}">:$caseValues,
+      CArg<"BlockRange", "{}">:$caseDestinations,
+      CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>
+  ];
+
+  let assemblyFormat = [{
+    $flag `:` type($flag) `,` `[` `\n`
+      custom<SwitchOpCases>(ref(type($flag)),$defaultDestination,
+        $defaultOperands,
+        type($defaultOperands),
+        $case_values,
+        $caseDestinations,
+        $caseOperands,
+        type($caseOperands),
+        $case_operand_offsets)
+    `]`
+    attr-dict
+  }];
+
+  let extraClassDeclaration = [{
+    /// Return the operands for the case destination block at the given index.
+    OperandRange getCaseOperands(unsigned index);
+
+    /// Return a mutable range of operands for the case destination block at the
+    /// given index.
+    MutableOperandRange getCaseOperandsMutable(unsigned index);
+  }];
+
+  let hasCanonicalizer = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // TruncateIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1333,13 +1333,15 @@
                                       .isIndex()}]>,
                           "index elements attribute">;
 
-class AnyIntElementsAttr<int width> : IntElementsAttrBase<
+def AnyIntElementsAttr : IntElementsAttrBase<CPred<"true">, "integer elements attribute">;
+
+class IntElementsAttrOf<int width> : IntElementsAttrBase<
   CPred<"$_self.cast<::mlir::DenseIntElementsAttr>().getType()."
         "getElementType().isInteger(" # width # ")">,
   width # "-bit integer elements attribute">;
 
-def AnyI32ElementsAttr : AnyIntElementsAttr<32>;
-def AnyI64ElementsAttr : AnyIntElementsAttr<64>;
+def AnyI32ElementsAttr : IntElementsAttrOf<32>;
+def AnyI64ElementsAttr : IntElementsAttrOf<64>;
 
 class SignlessIntElementsAttr<int width> : IntElementsAttrBase<
   CPred<"$_self.cast<::mlir::DenseIntElementsAttr>().getType()."
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2160,6 +2160,571 @@
               SubTensorInsertOpCastFolder>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// SwitchOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a switch whose case values are given as a `DenseIntElementsAttr`.
+/// The per-case operand lists are flattened into the single `caseOperands`
+/// operand group and the start index of each case's slice is recorded in the
+/// `case_operand_offsets` attribute, which is omitted when `caseOperands` is
+/// empty.
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
+                     Block *defaultDestination, ValueRange defaultOperands,
+                     DenseIntElementsAttr caseValues,
+                     BlockRange caseDestinations,
+                     ArrayRef<ValueRange> caseOperands) {
+  SmallVector<Value> flattenedCaseOperands;
+  SmallVector<int32_t> caseOperandOffsets;
+  int32_t offset = 0;
+  for (ValueRange operands : caseOperands) {
+    flattenedCaseOperands.append(operands.begin(), operands.end());
+    caseOperandOffsets.push_back(offset);
+    offset += operands.size();
+  }
+  DenseIntElementsAttr caseOperandOffsetsAttr;
+  if (!caseOperandOffsets.empty())
+    caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets);
+
+  build(builder, result, value, defaultOperands, flattenedCaseOperands,
+        caseValues, caseOperandOffsetsAttr, defaultDestination,
+        caseDestinations);
+}
+
+/// Builds a switch from a list of case values. The values are stored in a
+/// vector attribute whose element type is the type of `value`, so they are
+/// expected to share that type's bit width.
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
+                     Block *defaultDestination, ValueRange defaultOperands,
+                     ArrayRef<APInt> caseValues, BlockRange caseDestinations,
+                     ArrayRef<ValueRange> caseOperands) {
+  DenseIntElementsAttr caseValuesAttr;
+  if (!caseValues.empty()) {
+    ShapedType caseValueType = VectorType::get(
+        static_cast<int64_t>(caseValues.size()), value.getType());
+    caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
+  }
+  build(builder, result, value, defaultDestination, defaultOperands,
+        caseValuesAttr, caseDestinations, caseOperands);
+}
+
+/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
+///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
+static ParseResult
+parseSwitchOpCases(OpAsmParser &parser, Type &flagType,
+                   Block *&defaultDestination,
+                   SmallVectorImpl<OpAsmParser::OperandType> &defaultOperands,
+                   SmallVectorImpl<Type> &defaultOperandTypes,
+                   DenseIntElementsAttr &caseValues,
+                   SmallVectorImpl<Block *> &caseDestinations,
+                   SmallVectorImpl<OpAsmParser::OperandType> &caseOperands,
+                   SmallVectorImpl<Type> &caseOperandTypes,
+                   DenseIntElementsAttr &caseOperandOffsets) {
+  if (failed(parser.parseKeyword("default")) || failed(parser.parseColon()) ||
+      failed(parser.parseSuccessor(defaultDestination)))
+    return failure();
+  if (succeeded(parser.parseOptionalLParen())) {
+    if (failed(parser.parseRegionArgumentList(defaultOperands)) ||
+        failed(parser.parseColonTypeList(defaultOperandTypes)) ||
+        failed(parser.parseRParen()))
+      return failure();
+  }
+
+  SmallVector<APInt> values;
+  SmallVector<int32_t> offsets;
+  unsigned bitWidth = flagType.getIntOrFloatBitWidth();
+  int64_t offset = 0;
+  while (succeeded(parser.parseOptionalComma())) {
+    int64_t value = 0;
+    if (failed(parser.parseInteger(value)))
+      return failure();
+    // The literal is parsed as a signed 64-bit value; build the APInt as
+    // signed so negative case values are sign-extended for flags wider than
+    // 64 bits instead of being zero-extended.
+    values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
+
+    Block *destination;
+    SmallVector<OpAsmParser::OperandType> operands;
+    if (failed(parser.parseColon()) ||
+        failed(parser.parseSuccessor(destination)))
+      return failure();
+    if (succeeded(parser.parseOptionalLParen())) {
+      if (failed(parser.parseRegionArgumentList(operands)) ||
+          failed(parser.parseColonTypeList(caseOperandTypes)) ||
+          failed(parser.parseRParen()))
+        return failure();
+    }
+    caseDestinations.push_back(destination);
+    caseOperands.append(operands.begin(), operands.end());
+    offsets.push_back(offset);
+    offset += operands.size();
+  }
+
+  if (values.empty())
+    return success();
+
+  Builder &builder = parser.getBuilder();
+  ShapedType caseValueType =
+      VectorType::get(static_cast<int64_t>(values.size()), flagType);
+  caseValues = DenseIntElementsAttr::get(caseValueType, values);
+  caseOperandOffsets = builder.getI32VectorAttr(offsets);
+
+  return success();
+}
+
+/// Prints the cases of a switch in the form accepted by `parseSwitchOpCases`.
+static void printSwitchOpCases(
+    OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
+    OperandRange defaultOperands, TypeRange defaultOperandTypes,
+    DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
+    OperandRange caseOperands, TypeRange caseOperandTypes,
+    ElementsAttr caseOperandOffsets) {
+  p << "  default: ";
+  p.printSuccessorAndUseList(defaultDestination, defaultOperands);
+
+  if (!caseValues)
+    return;
+
+  for (int64_t i = 0, size = caseValues.size(); i < size; ++i) {
+    p << ',';
+    p.printNewline();
+    p << "  ";
+    // Print case values as signed integers so they round-trip through the
+    // int64_t-based parser above; printing the unsigned value would turn -1
+    // on an i64 flag into 18446744073709551615, which does not fit in an
+    // int64_t. i1 is printed unsigned since only 0 and 1 are possible.
+    APInt caseValue = caseValues.getValue<APInt>(i);
+    if (caseValue.getBitWidth() == 1)
+      p << caseValue.getZExtValue();
+    else
+      p << caseValue.getSExtValue();
+    p << ": ";
+    p.printSuccessorAndUseList(caseDestinations[i], op.getCaseOperands(i));
+  }
+  p.printNewline();
+}
+
+/// Checks that the case values, case destinations and case operand offsets of
+/// a switch are consistent with each other and with the flag type.
+static LogicalResult verify(SwitchOp op) {
+  auto caseValues = op.case_values();
+  auto caseDestinations = op.caseDestinations();
+
+  // Without case values the only valid form is a switch with no cases. Report
+  // anything else rather than dereferencing the missing attribute below.
+  if (!caseValues) {
+    if (caseDestinations.empty())
+      return success();
+    return op.emitOpError() << "number of case values (0) should match number "
+                               "of case destinations ("
+                            << caseDestinations.size() << ")";
+  }
+
+  Type flagType = op.flag().getType();
+  Type caseValueType = caseValues->getType().getElementType();
+  if (caseValueType != flagType)
+    return op.emitOpError()
+           << "'flag' type (" << flagType << ") should match case value type ("
+           << caseValueType << ")";
+
+  if (caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
+    return op.emitOpError() << "number of case values (" << caseValues->size()
+                            << ") should match number of "
+                               "case destinations ("
+                            << caseDestinations.size() << ")";
+
+  // `getCaseOperandsMutable` slices the flattened case operands using
+  // `case_operand_offsets`, so when present it must provide one
+  // non-decreasing, in-bounds offset per case destination.
+  auto caseOperandOffsets = op.case_operand_offsets();
+  int64_t numCaseOperands = op.caseOperands().size();
+  if (!caseOperandOffsets) {
+    if (numCaseOperands != 0)
+      return op.emitOpError() << "requires 'case_operand_offsets' when case "
+                                 "operands are present";
+    return success();
+  }
+  if (caseOperandOffsets->size() !=
+      static_cast<int64_t>(caseDestinations.size()))
+    return op.emitOpError()
+           << "number of case operand offsets (" << caseOperandOffsets->size()
+           << ") should match number of case destinations ("
+           << caseDestinations.size() << ")";
+  int64_t previousOffset = 0;
+  for (int64_t i = 0, e = caseOperandOffsets->size(); i < e; ++i) {
+    int64_t offset = caseOperandOffsets->getValue<APInt>(i).getSExtValue();
+    if (offset < previousOffset || offset > numCaseOperands)
+      return op.emitOpError() << "invalid case operand offset (" << offset
+                              << ") for case " << i;
+    previousOffset = offset;
+  }
+  return success();
+}
+
+OperandRange SwitchOp::getCaseOperands(unsigned index) {
+  return getCaseOperandsMutable(index);
+}
+
+MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) {
+  MutableOperandRange caseOperands = caseOperandsMutable();
+  if (!case_operand_offsets()) {
+    assert(caseOperands.size() == 0 &&
+           "non-empty case operands must have offsets");
+    return caseOperands;
+  }
+
+  ElementsAttr offsets = case_operand_offsets().getValue();
+  assert(index < offsets.size() && "invalid case operand offset index");
+
+  // Case `index` owns the operands from its own offset up to the next case's
+  // offset, or up to the end of the group for the last case.
+  int64_t begin = offsets.getValue(index).cast<IntegerAttr>().getInt();
+  int64_t end = index + 1 == offsets.size()
+                    ? caseOperands.size()
+                    : offsets.getValue(index + 1).cast<IntegerAttr>().getInt();
+  return caseOperands.slice(begin, end - begin);
+}
+
+Optional<MutableOperandRange>
+SwitchOp::getMutableSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  // Successor 0 is the default destination; successor `i + 1` is case `i`.
+  return index == 0 ? defaultOperandsMutable()
+                    : getCaseOperandsMutable(index - 1);
+}
+
+Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
+  Optional<DenseIntElementsAttr> caseValues = case_values();
+  if (!caseValues)
+    return defaultDestination();
+
+  // The successor can only be determined statically when the flag is a known
+  // integer constant. Compare APInts so that a zero flag and signless flag
+  // types are handled correctly.
+  auto flagAttr = operands.front().dyn_cast_or_null<IntegerAttr>();
+  if (!flagAttr)
+    return nullptr;
+
+  APInt flagValue = flagAttr.getValue();
+  SuccessorRange caseDests = caseDestinations();
+  for (int64_t i = 0, size = caseValues->size(); i < size; ++i)
+    if (caseValues->getValue<APInt>(i) == flagValue)
+      return caseDests[i];
+  return defaultDestination();
+}
+
+/// switch %flag : i32, [
+///   default: ^bb1
+/// ]
+///  -> br ^bb1
+static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
+                                                   PatternRewriter &rewriter) {
+  if (!op.caseDestinations().empty())
+    return failure();
+
+  rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
+                                        op.defaultOperands());
+  return success();
+}
+
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb1,
+///   43: ^bb2
+/// ]
+/// ->
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   43: ^bb2
+/// ]
+static LogicalResult
+dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
+  auto caseValues = op.case_values();
+  if (!caseValues)
+    return failure();
+
+  SmallVector<Block *> newCaseDestinations;
+  SmallVector<ValueRange> newCaseOperands;
+  SmallVector<APInt> newCaseValues;
+  bool requiresChange = false;
+  auto caseDests = op.caseDestinations();
+
+  for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
+    if (caseDests[i] == op.defaultDestination() &&
+        op.getCaseOperands(i) == op.defaultOperands()) {
+      requiresChange = true;
+      continue;
+    }
+    newCaseDestinations.push_back(caseDests[i]);
+    newCaseOperands.push_back(op.getCaseOperands(i));
+    newCaseValues.push_back(caseValues->getValue<APInt>(i));
+  }
+
+  if (!requiresChange)
+    return failure();
+
+  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), op.defaultDestination(),
+                                        op.defaultOperands(), newCaseValues,
+                                        newCaseDestinations, newCaseOperands);
+  return success();
+}
+
+/// Helper for folding a switch with a known flag value: branches to the
+/// destination of the matching case, or to the default destination when no
+/// case matches.
+/// switch %c_42 : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+///   43: ^bb3
+/// ]
+/// -> br ^bb2
+static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
+                       const APInt &caseValue) {
+  auto caseValues = op.case_values();
+  if (caseValues) {
+    for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
+      if (caseValues->getValue<APInt>(i) == caseValue) {
+        rewriter.replaceOpWithNewOp<BranchOp>(op, op.caseDestinations()[i],
+                                              op.getCaseOperands(i));
+        return;
+      }
+    }
+  }
+  rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
+                                        op.defaultOperands());
+}
+
+/// switch %c_42 : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+///   43: ^bb3
+/// ]
+/// -> br ^bb2
+static LogicalResult simplifyConstSwitchValue(SwitchOp op,
+                                              PatternRewriter &rewriter) {
+  APInt caseValue;
+  if (!matchPattern(op.flag(), m_ConstantInt(&caseValue)))
+    return failure();
+
+  foldSwitch(op, rewriter, caseValue);
+  return success();
+}
+
+/// switch %c_42 : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+/// ]
+/// ^bb2:
+///   br ^bb3
+/// ->
+/// switch %c_42 : i32, [
+///   default: ^bb1,
+///   42: ^bb3,
+/// ]
+static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
+                                               PatternRewriter &rewriter) {
+  auto caseValues = op.case_values();
+  if (!caseValues)
+    return failure();
+
+  SmallVector<Block *> newCaseDests;
+  SmallVector<ValueRange> newCaseOperands;
+  // `collapseBranch` is handed `argStorage` as backing storage for the operand
+  // range it rewrites, so every storage vector must stay alive (and must not
+  // move) until the replacement switch below has been built. Reserving up
+  // front keeps `emplace_back` from reallocating and invalidating the ranges.
+  SmallVector<SmallVector<Value>> argStorage;
+  argStorage.reserve(caseValues->size() + 1);
+  auto caseDests = op.caseDestinations();
+  bool requiresChange = false;
+  for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
+    Block *caseDest = caseDests[i];
+    ValueRange caseOperands = op.getCaseOperands(i);
+    argStorage.emplace_back();
+    if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
+      requiresChange = true;
+
+    newCaseDests.push_back(caseDest);
+    newCaseOperands.push_back(caseOperands);
+  }
+
+  Block *defaultDest = op.defaultDestination();
+  ValueRange defaultOperands = op.defaultOperands();
+  argStorage.emplace_back();
+  if (succeeded(
+          collapseBranch(defaultDest, defaultOperands, argStorage.back())))
+    requiresChange = true;
+
+  if (!requiresChange)
+    return failure();
+
+  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), defaultDest,
+                                        defaultOperands, *caseValues,
+                                        newCaseDests, newCaseOperands);
+  return success();
+}
+
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+/// ]
+/// ^bb2:
+///   switch %flag : i32, [
+///     default: ^bb3,
+///     42: ^bb4
+///   ]
+/// ->
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+/// ]
+/// ^bb2:
+///   br ^bb4
+///
+///  and
+///
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+/// ]
+/// ^bb2:
+///   switch %flag : i32, [
+///     default: ^bb3,
+///     43: ^bb4
+///   ]
+/// ->
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+/// ]
+/// ^bb2:
+///   br ^bb3
+static LogicalResult
+simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
+                                        PatternRewriter &rewriter) {
+  // Check that we have a single predecessor.
+  Block *currentBlock = op->getBlock();
+  Block *predecessor = currentBlock->getSinglePredecessor();
+  if (!predecessor)
+    return failure();
+
+  // Check that the predecessor terminates with a switch branch to this block
+  // and that it branches on the same condition and that this branch isn't the
+  // default destination.
+  auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
+  if (!predSwitch || op.flag() != predSwitch.flag() ||
+      predSwitch.defaultDestination() == currentBlock)
+    return failure();
+
+  auto predCaseValues = predSwitch.case_values();
+  if (!predCaseValues)
+    return failure();
+
+  // Fold this switch to an unconditional branch.
+  APInt caseValue;
+  bool isDefault = true;
+  auto predDests = predSwitch.caseDestinations();
+  for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) {
+    if (currentBlock == predDests[i]) {
+      caseValue = predCaseValues->getValue<APInt>(i);
+      isDefault = false;
+      break;
+    }
+  }
+  if (isDefault)
+    rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
+                                          op.defaultOperands());
+  else
+    foldSwitch(op, rewriter, caseValue);
+  return success();
+}
+
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb2
+/// ]
+/// ^bb1:
+///   switch %flag : i32, [
+///     default: ^bb3,
+///     42: ^bb4,
+///     43: ^bb5
+///   ]
+/// ->
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+/// ]
+/// ^bb1:
+///   switch %flag : i32, [
+///     default: ^bb3,
+///     43: ^bb5
+///   ]
+static LogicalResult
+simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
+                                               PatternRewriter &rewriter) {
+  // Check that we have a single predecessor.
+  Block *currentBlock = op->getBlock();
+  Block *predecessor = currentBlock->getSinglePredecessor();
+  if (!predecessor)
+    return failure();
+
+  // Check that the predecessor terminates with a switch branch to this block
+  // and that it branches on the same condition and that this branch is the
+  // default destination.
+  auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
+  if (!predSwitch || op.flag() != predSwitch.flag() ||
+      predSwitch.defaultDestination() != currentBlock)
+    return failure();
+
+  // Nothing to prune if either switch has no cases.
+  auto caseValues = op.case_values();
+  auto predCaseValues = predSwitch.case_values();
+  if (!caseValues || !predCaseValues)
+    return failure();
+
+  // Delete case values that are not possible here.
+  DenseSet<APInt> caseValuesToRemove;
+  auto predDests = predSwitch.caseDestinations();
+  for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) {
+    if (currentBlock != predDests[i])
+      caseValuesToRemove.insert(predCaseValues->getValue<APInt>(i));
+  }
+
+  SmallVector<Block *> newCaseDestinations;
+  SmallVector<ValueRange> newCaseOperands;
+  SmallVector<APInt> newCaseValues;
+  bool requiresChange = false;
+
+  auto caseDests = op.caseDestinations();
+  for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
+    if (caseValuesToRemove.contains(caseValues->getValue<APInt>(i))) {
+      requiresChange = true;
+      continue;
+    }
+    newCaseDestinations.push_back(caseDests[i]);
+    newCaseOperands.push_back(op.getCaseOperands(i));
+    newCaseValues.push_back(caseValues->getValue<APInt>(i));
+  }
+
+  if (!requiresChange)
+    return failure();
+
+  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), op.defaultDestination(),
+                                        op.defaultOperands(), newCaseValues,
+                                        newCaseDestinations, newCaseOperands);
+  return success();
+}
+
+void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add(&simplifySwitchWithOnlyDefault)
+      .add(&dropSwitchCasesThatMatchDefault)
+      .add(&simplifyConstSwitchValue)
+      .add(&simplifyPassThroughSwitch)
+      .add(&simplifySwitchFromSwitchOnSameCondition)
+      .add(&simplifySwitchFromDefaultSwitchOnSameCondition);
+}
+
 //===----------------------------------------------------------------------===//
 // TruncateIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize-cf.mlir b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
--- a/mlir/test/Dialect/Standard/canonicalize-cf.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
@@ -139,6 +139,205 @@
   return
 }
 
+/// Test the folding of SwitchOp
+
+// CHECK-LABEL: func @switch_fold_only_default(
+func @switch_fold_only_default(%arg0 : i32) -> i32 {
+  // CHECK-NEXT: %[[CST:.*]] = constant 1 : i32
+  // CHECK-NEXT: return %[[CST]] : i32
+  %c1_i32 = constant 1 : i32
+  switch %arg0 : i32, [
+    default: ^bb1(%c1_i32 : i32)
+  ]
+  ^bb1(%x : i32):
+    return %x : i32
+}
+
+
+// CHECK-LABEL: func @switch_fold_case_matching_default(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+func @switch_fold_case_matching_default(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32) -> f32 {
+  // CHECK: switch %[[FLAG]]
+  // CHECK-NEXT: default: ^[[BB1:.+]](%[[CASE_OPERAND_0]] : f32)
+  // CHECK-NEXT: 10: ^[[BB2:.+]](%[[CASE_OPERAND_1]] : f32)
+  // CHECK-NEXT: ]
+  switch %flag : i32, [
+    default: ^bb1(%caseOperand0 : f32),
+    42: ^bb1(%caseOperand0 : f32),
+    10: ^bb2(%caseOperand1 : f32),
+    17: ^bb1(%caseOperand0 : f32)
+  ]
+  ^bb1(%x : f32):
+    return %x : f32
+  ^bb2(%y : f32):
+    return %y : f32
+}
+
+
+// CHECK-LABEL: func @switch_const_folding(
+func @switch_const_folding() -> i32 {
+  // CHECK-NEXT: %[[CST:.*]] = constant 1 : i32
+  // CHECK-NEXT: return %[[CST]] : i32
+  %c0_i32 = constant 0 : i32
+  %c1_i32 = constant 1 : i32
+  %c2_i32 = constant 2 : i32
+  %c3_i32 = constant 3 : i32
+  switch %c0_i32 : i32, [
+    default: ^bb1(%c1_i32 : i32),
+    -1: ^bb2(%c2_i32 : i32),
+    1: ^bb3(%c3_i32 : i32)
+  ]
+  ^bb1(%x : i32):
+    return %x : i32
+  ^bb2(%y : i32):
+    return %y : i32
+  ^bb3(%z : i32):
+    return %z : i32
+}
+
+// CHECK-LABEL: func @switch_passthrough(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_3:[a-zA-Z0-9_]+]]
+func @switch_passthrough(%flag : i32,
+                         %caseOperand0 : f32,
+                         %caseOperand1 : f32,
+                         %caseOperand2 : f32,
+                         %caseOperand3 : f32) -> f32{
+  // CHECK: switch %[[FLAG]]
+  // CHECK-NEXT: default: ^[[BB4:.+]](%[[CASE_OPERAND_0]] : f32)
+  // CHECK-NEXT: 43: ^[[BB5:.+]](%[[CASE_OPERAND_1]] : f32)
+  // CHECK-NEXT: 44: ^[[BB3:.+]](%[[CASE_OPERAND_2]] : f32)
+  // CHECK-NEXT: ]
+  switch %flag : i32, [
+    default: ^bb1(%caseOperand0 : f32),
+    43: ^bb2(%caseOperand1 : f32),
+    44: ^bb3(%caseOperand2 : f32)
+  ]
+  ^bb1(%bb1Arg : f32):
+    br ^bb4(%bb1Arg : f32)
+  ^bb2(%bb2Arg : f32):
+    br ^bb5(%bb2Arg : f32)
+  ^bb3(%bb3Arg : f32):
+    return %bb3Arg : f32
+
+  // CHECK: ^[[BB4]](
+  // CHECK-SAME: %[[BB4_ARG:[a-zA-Z0-9_]+]]
+  // CHECK-NEXT: %[[BB4_RES:.+]] = "foo.op1"(%[[BB4_ARG]])
+  // CHECK-NEXT: return %[[BB4_RES]]
+  ^bb4(%bb4Arg : f32):
+    %bb4Res = "foo.op1"(%bb4Arg) : (f32) -> f32
+    return %bb4Res : f32
+
+  // CHECK: ^[[BB5]](
+  // CHECK-SAME: %[[BB5_ARG:[a-zA-Z0-9_]+]]
+  // CHECK-NEXT: %[[BB5_RES:.+]] = "foo.op2"(%[[BB5_ARG]])
+  // CHECK-NEXT: return %[[BB5_RES]]
+  ^bb5(%bb5Arg : f32):
+    %bb5Res = "foo.op2"(%bb5Arg) : (f32) -> f32
+    return %bb5Res : f32
+}
+
+// CHECK-LABEL: func @switch_from_switch_with_same_value(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+func @switch_from_switch_with_same_value(%flag : i32) -> f32 {
+  // CHECK: switch %[[FLAG]]
+  switch %flag : i32, [
+    default: ^bb1,
+    42: ^bb2
+  ]
+
+  ^bb1:
+    // Add predecessors to avoid these getting folded by other patterns.
+    "foo.terminator1"() [^bb3, ^bb4] : () -> ()
+  ^bb2:
+    "foo.op2"() : () -> ()
+    // CHECK-NOT: switch %[[FLAG]]
+    // CHECK: br ^[[BB4:.+]]
+    switch %flag : i32, [
+      default: ^bb3,
+      42: ^bb4
+    ]
+
+  ^bb3:
+    "foo.terminator3"() : () -> ()
+
+  // CHECK: ^[[BB4]]:
+  // CHECK-NEXT: "foo.terminator4"
+  ^bb4:
+    "foo.terminator4"() : () -> ()
+}
+
+// CHECK-LABEL: func @switch_from_switch_with_same_value_no_match(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+func @switch_from_switch_with_same_value_no_match(%flag : i32) -> f32 {
+  // CHECK: switch %[[FLAG]]
+  switch %flag : i32, [
+    default: ^bb1,
+    42: ^bb2
+  ]
+
+  ^bb1:
+    // Add predecessors to avoid these getting folded by other patterns.
+    "foo.terminator1"() [^bb3, ^bb4] : () -> ()
+  ^bb2:
+    "foo.op2"() : () -> ()
+    // CHECK-NOT: switch %[[FLAG]]
+    // CHECK: br ^[[BB3:.+]]
+    switch %flag : i32, [
+      default: ^bb3,
+      0: ^bb5,
+      43: ^bb4
+    ]
+
+  // CHECK: ^[[BB3]]:
+  // CHECK-NEXT: "foo.terminator3"
+  ^bb3:
+    "foo.terminator3"() : () -> ()
+
+  ^bb4:
+    "foo.terminator4"() : () -> ()
+
+  ^bb5:
+    "foo.terminator5"() : () -> ()
+}
+
+// CHECK-LABEL: func @switch_from_switch_default_with_same_value(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+func @switch_from_switch_default_with_same_value(%flag : i32) -> f32 {
+  // CHECK: switch %[[FLAG]]
+  switch %flag : i32, [
+    default: ^bb2,
+    42: ^bb1
+  ]
+
+  ^bb1:
+    // Add predecessors to avoid these getting folded by other patterns.
+    "foo.terminator1"() [^bb3, ^bb4, ^bb5] : () -> ()
+  ^bb2:
+    "foo.op2"() : () -> ()
+    // CHECK: switch %[[FLAG]]
+    // CHECK-NOT: 42:
+    switch %flag : i32, [
+      default: ^bb3,
+      42: ^bb4,
+      43: ^bb5
+    ]
+
+  ^bb3:
+    "foo.terminator3"() : () -> ()
+
+  ^bb4:
+    "foo.terminator4"() : () -> ()
+
+  ^bb5:
+    "foo.terminator5"() : () -> ()
+}
+
 /// Test folding conditional branches that are successors of conditional
 /// branches with the same condition.
 
diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -96,3 +96,49 @@
   %1 = memref.tensor_load %0 : memref<2xf32>
   return
 }
+
+// CHECK-LABEL: func @switch(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND:[a-zA-Z0-9_]+]]
+func @switch(%flag : i32, %caseOperand : i32) {
+  // CHECK: switch %[[FLAG]] : i32, [
+  // CHECK-NEXT: default: ^[[BB1:.+]](%[[CASE_OPERAND]] : i32),
+  // CHECK-NEXT: 42: ^[[BB2:.+]](%[[CASE_OPERAND]] : i32),
+  // CHECK-NEXT: 43: ^[[BB3:.+]](%[[CASE_OPERAND]] : i32)
+  // CHECK-NEXT: ]
+  switch %flag : i32, [
+    default: ^bb1(%caseOperand : i32),
+    42: ^bb2(%caseOperand : i32),
+    43: ^bb3(%caseOperand : i32)
+  ]
+
+  ^bb1(%bb1arg : i32):
+    return
+  ^bb2(%bb2arg : i32):
+    return
+  ^bb3(%bb3arg : i32):
+    return
+}
+
+// CHECK-LABEL: func @switch_i64(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND:[a-zA-Z0-9_]+]]
+func @switch_i64(%flag : i64, %caseOperand : i32) {
+  // CHECK: switch %[[FLAG]] : i64, [
+  // CHECK-NEXT: default: ^[[BB1:.+]](%[[CASE_OPERAND]] : i32),
+  // CHECK-NEXT: 42: ^[[BB2:.+]](%[[CASE_OPERAND]] : i32),
+  // CHECK-NEXT: 43: ^[[BB3:.+]](%[[CASE_OPERAND]] : i32)
+  // CHECK-NEXT: ]
+  switch %flag : i64, [
+    default: ^bb1(%caseOperand : i32),
+    42: ^bb2(%caseOperand : i32),
+    43: ^bb3(%caseOperand : i32)
+  ]
+
+  ^bb1(%bb1arg : i32):
+    return
+  ^bb2(%bb2arg : i32):
+    return
+  ^bb3(%bb3arg : i32):
+    return
+}
diff --git a/mlir/test/Dialect/Standard/parser.mlir b/mlir/test/Dialect/Standard/parser.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Standard/parser.mlir
@@ -0,0 +1,69 @@
+// RUN: mlir-opt -verify-diagnostics -split-input-file %s
+
+func @switch_missing_case_value(%flag : i32, %caseOperand : i32) {
+  switch %flag : i32, [
+    default: ^bb1(%caseOperand : i32),
+    45: ^bb2(%caseOperand : i32),
+    // expected-error@+1 {{expected integer value}}
+    : ^bb3(%caseOperand : i32)
+  ]
+
+  ^bb1(%bb1arg : i32):
+    return
+  ^bb2(%bb2arg : i32):
+    return
+  ^bb3(%bb3arg : i32):
+    return
+}
+
+// -----
+
+func @switch_wrong_type_case_value(%flag : i32, %caseOperand : i32) {
+  switch %flag : i32, [
+    default: ^bb1(%caseOperand : i32),
+    // expected-error@+1 {{expected integer value}}
+    "hello": ^bb2(%caseOperand : i32)
+  ]
+
+  ^bb1(%bb1arg : i32):
+    return
+  ^bb2(%bb2arg : i32):
+    return
+  ^bb3(%bb3arg : i32):
+    return
+}
+
+// -----
+
+func @switch_missing_comma(%flag : i32, %caseOperand : i32) {
+  switch %flag : i32, [
+    default: ^bb1(%caseOperand : i32),
+    45: ^bb2(%caseOperand : i32)
+    // expected-error@+1 {{expected ']'}}
+    43: ^bb3(%caseOperand : i32)
+  ]
+
+  ^bb1(%bb1arg : i32):
+    return
+  ^bb2(%bb2arg : i32):
+    return
+  ^bb3(%bb3arg : i32):
+    return
+}
+
+// -----
+
+func @switch_missing_default(%flag : i32, %caseOperand : i32) {
+  switch %flag : i32, [
+    // expected-error@+1 {{expected 'default'}}
+    45: ^bb2(%caseOperand : i32)
+    43: ^bb3(%caseOperand : i32)
+  ]
+
+  ^bb1(%bb1arg : i32):
+    return
+  ^bb2(%bb2arg : i32):
+    return
+  ^bb3(%bb3arg : i32):
+    return
+}