diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2030,6 +2030,89 @@
 
   let hasFolder = 1;
 }
+
+//===----------------------------------------------------------------------===//
+// SwitchOp
+//===----------------------------------------------------------------------===//
+
+def SwitchOp : Std_Op<"switch",
+    [AttrSizedOperandSegments,
+     DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
+     NoSideEffect, Terminator]> {
+  let summary = "switch operation";
+  let description = [{
+    The `switch` terminator operation represents a switch on an integer value
+    of any width. If the flag matches one of the specified cases, then the
+    corresponding destination is jumped to. If the flag does not match any of
+    the cases, the default destination is jumped to. The count and types of
+    operands must align with the arguments in the corresponding target blocks.
+
+    Example:
+
+    ```mlir
+    switch %flag : i32, [
+      default: ^bb1(%a : i32),
+      42: ^bb1(%b : i32),
+      43: ^bb3(%c : i32)
+    ]
+    ```
+  }];
+
+  let arguments = (ins AnyInteger:$flag,
+                       Variadic<AnyType>:$defaultOperands,
+                       Variadic<AnyType>:$caseOperands,
+                       OptionalAttr<AnyIntElementsAttr>:$case_values,
+                       OptionalAttr<I32ElementsAttr>:$case_operand_offsets);
+  let successors = (successor
+      AnySuccessor:$defaultDestination,
+      VariadicSuccessor<AnySuccessor>:$caseDestinations);
+  let builders = [
+    OpBuilder<(ins "Value":$flag,
+      "Block *":$defaultDestination,
+      "ValueRange":$defaultOperands,
+      CArg<"ArrayRef<APInt>", "{}">:$caseValues,
+      CArg<"BlockRange", "{}">:$caseDestinations,
+      CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
+    OpBuilder<(ins "Value":$flag,
+      "Block *":$defaultDestination,
+      "ValueRange":$defaultOperands,
+      CArg<"ArrayRef<int32_t>", "{}">:$caseValues,
+      CArg<"BlockRange", "{}">:$caseDestinations,
+      CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
+    OpBuilder<(ins "Value":$flag,
+      "Block *":$defaultDestination,
+      "ValueRange":$defaultOperands,
+      CArg<"DenseIntElementsAttr", "{}">:$caseValues,
+      CArg<"BlockRange", "{}">:$caseDestinations,
+      CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>
+  ];
+
+  let assemblyFormat = [{
+    $flag `:` type($flag) `,` `[` `\n`
+      custom<SwitchOpCases>(ref(type($flag)),$defaultDestination,
+                            $defaultOperands,
+                            type($defaultOperands),
+                            $case_values,
+                            $caseDestinations,
+                            $caseOperands,
+                            type($caseOperands),
+                            $case_operand_offsets)
+   `]`
+    attr-dict
+  }];
+
+  let extraClassDeclaration = [{
+    /// Return the operands for the case destination block at the given index.
+    OperandRange getCaseOperands(unsigned index);
+
+    /// Return a mutable range of operands for the case destination block at the
+    /// given index.
+    MutableOperandRange getCaseOperandsMutable(unsigned index);
+  }];
+
+  let hasCanonicalizer = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // TruncateIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1333,13 +1333,15 @@
                       .isIndex()}]>,
       "index elements attribute">;
 
-class AnyIntElementsAttr<int width> : IntElementsAttrBase<
+def AnyIntElementsAttr : IntElementsAttrBase<CPred<"true">, "integer elements attribute">;
+
+class IntElementsAttrOf<int width> : IntElementsAttrBase<
   CPred<"$_self.cast<::mlir::DenseIntElementsAttr>().getType()."
         "getElementType().isInteger(" # width # ")">,
   width # "-bit integer elements attribute">;
 
-def AnyI32ElementsAttr : AnyIntElementsAttr<32>;
-def AnyI64ElementsAttr : AnyIntElementsAttr<64>;
+def AnyI32ElementsAttr : IntElementsAttrOf<32>;
+def AnyI64ElementsAttr : IntElementsAttrOf<64>;
 
 class SignlessIntElementsAttr<int width> : IntElementsAttrBase<
   CPred<"$_self.cast<::mlir::DenseIntElementsAttr>().getType()."
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -441,8 +441,9 @@
 /// Given a successor, try to collapse it to a new destination if it only
 /// contains a passthrough unconditional branch. If the successor is
 /// collapsable, `successor` and `successorOperands` are updated to reference
-/// the new destination and values. `argStorage` is an optional storage to use
-/// if operands to the collapsed successor need to be remapped.
+/// the new destination and values. `argStorage` is used as storage if operands
+/// to the collapsed successor need to be remapped. It must outlive uses of
+/// successorOperands.
 static LogicalResult collapseBranch(Block *&successor,
                                     ValueRange &successorOperands,
                                     SmallVectorImpl<Value> &argStorage) {
@@ -2160,6 +2161,514 @@
               SubTensorInsertOpCastFolder>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// SwitchOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a switch from per-case operand ranges. The ranges are flattened into
+/// a single operand list and the start offset of each case's operands is
+/// recorded in the `case_operand_offsets` attribute.
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
+                     Block *defaultDestination, ValueRange defaultOperands,
+                     DenseIntElementsAttr caseValues,
+                     BlockRange caseDestinations,
+                     ArrayRef<ValueRange> caseOperands) {
+  SmallVector<Value> flattenedCaseOperands;
+  SmallVector<int32_t> caseOperandOffsets;
+  int32_t offset = 0;
+  for (ValueRange operands : caseOperands) {
+    flattenedCaseOperands.append(operands.begin(), operands.end());
+    caseOperandOffsets.push_back(offset);
+    offset += operands.size();
+  }
+  DenseIntElementsAttr caseOperandOffsetsAttr;
+  if (!caseOperandOffsets.empty())
+    caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets);
+
+  build(builder, result, value, defaultOperands, flattenedCaseOperands,
+        caseValues, caseOperandOffsetsAttr, defaultDestination,
+        caseDestinations);
+}
+
+/// Builds a switch from raw case values by packing them into a dense
+/// attribute whose element type is the type of the flag.
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
+                     Block *defaultDestination, ValueRange defaultOperands,
+                     ArrayRef<APInt> caseValues, BlockRange caseDestinations,
+                     ArrayRef<ValueRange> caseOperands) {
+  DenseIntElementsAttr caseValuesAttr;
+  if (!caseValues.empty()) {
+    ShapedType caseValueType = VectorType::get(
+        static_cast<int64_t>(caseValues.size()), value.getType());
+    caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
+  }
+  build(builder, result, value, defaultDestination, defaultOperands,
+        caseValuesAttr, caseDestinations, caseOperands);
+}
+
+/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
+///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
+static ParseResult
+parseSwitchOpCases(OpAsmParser &parser, Type &flagType,
+                   Block *&defaultDestination,
+                   SmallVectorImpl<OpAsmParser::OperandType> &defaultOperands,
+                   SmallVectorImpl<Type> &defaultOperandTypes,
+                   DenseIntElementsAttr &caseValues,
+                   SmallVectorImpl<Block *> &caseDestinations,
+                   SmallVectorImpl<OpAsmParser::OperandType> &caseOperands,
+                   SmallVectorImpl<Type> &caseOperandTypes,
+                   DenseIntElementsAttr &caseOperandOffsets) {
+
+  if (failed(parser.parseKeyword("default")) || failed(parser.parseColon()) ||
+      failed(parser.parseSuccessor(defaultDestination)))
+    return failure();
+  if (succeeded(parser.parseOptionalLParen())) {
+    if (failed(parser.parseRegionArgumentList(defaultOperands)) ||
+        failed(parser.parseColonTypeList(defaultOperandTypes)) ||
+        failed(parser.parseRParen()))
+      return failure();
+  }
+
+  SmallVector<APInt> values;
+  SmallVector<int32_t> offsets;
+  unsigned bitWidth = flagType.getIntOrFloatBitWidth();
+  int64_t offset = 0;
+  while (succeeded(parser.parseOptionalComma())) {
+    int64_t value = 0;
+    if (failed(parser.parseInteger(value)))
+      return failure();
+    values.push_back(APInt(bitWidth, value));
+
+    Block *destination;
+    SmallVector<OpAsmParser::OperandType> operands;
+    if (failed(parser.parseColon()) ||
+        failed(parser.parseSuccessor(destination)))
+      return failure();
+    if (succeeded(parser.parseOptionalLParen())) {
+      if (failed(parser.parseRegionArgumentList(operands)) ||
+          failed(parser.parseColonTypeList(caseOperandTypes)) ||
+          failed(parser.parseRParen()))
+        return failure();
+    }
+    caseDestinations.push_back(destination);
+    caseOperands.append(operands.begin(), operands.end());
+    offsets.push_back(offset);
+    offset += operands.size();
+  }
+
+  if (values.empty())
+    return success();
+
+  Builder &builder = parser.getBuilder();
+  ShapedType caseValueType =
+      VectorType::get(static_cast<int64_t>(values.size()), flagType);
+  caseValues = DenseIntElementsAttr::get(caseValueType, values);
+  caseOperandOffsets = builder.getI32VectorAttr(offsets);
+
+  return success();
+}
+
+/// Prints the default destination followed by one `value: destination` entry
+/// per case, each on its own line.
+static void printSwitchOpCases(
+    OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
+    OperandRange defaultOperands, TypeRange defaultOperandTypes,
+    DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
+    OperandRange caseOperands, TypeRange caseOperandTypes,
+    ElementsAttr caseOperandOffsets) {
+  p << " default: ";
+  p.printSuccessorAndUseList(defaultDestination, defaultOperands);
+
+  if (!caseValues)
+    return;
+
+  for (int64_t i = 0, size = caseValues.size(); i < size; ++i) {
+    p << ',';
+    p.printNewline();
+    p << " ";
+    p << caseValues.getValue<APInt>(i).getLimitedValue();
+    p << ": ";
+    p.printSuccessorAndUseList(caseDestinations[i], op.getCaseOperands(i));
+  }
+  p.printNewline();
+}
+
+/// Verifies that there is one case value per case destination and that the
+/// case values have the same type as the flag.
+static LogicalResult verify(SwitchOp op) {
+  auto caseValues = op.case_values();
+  auto caseDestinations = op.caseDestinations();
+
+  // Check the counts first: `case_values` is absent when there are no cases.
+  int64_t numValues = caseValues ? caseValues->size() : 0;
+  if (numValues != static_cast<int64_t>(caseDestinations.size()))
+    return op.emitOpError() << "number of case values (" << numValues
+                            << ") should match number of "
+                               "case destinations ("
+                            << caseDestinations.size() << ")";
+  if (!caseValues)
+    return success();
+
+  Type flagType = op.flag().getType();
+  Type caseValueType = caseValues->getType().getElementType();
+  if (caseValueType != flagType)
+    return op.emitOpError()
+           << "'flag' type (" << flagType << ") should match case value type ("
+           << caseValueType << ")";
+  return success();
+}
+
+OperandRange SwitchOp::getCaseOperands(unsigned index) {
+  return getCaseOperandsMutable(index);
+}
+
+MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) {
+  MutableOperandRange caseOperands = caseOperandsMutable();
+  if (!case_operand_offsets()) {
+    assert(caseOperands.size() == 0 &&
+           "non-empty case operands must have offsets");
+    return caseOperands;
+  }
+
+  ElementsAttr offsets = case_operand_offsets().getValue();
+  assert(index < offsets.size() && "invalid case operand offset index");
+
+  int64_t begin = offsets.getValue(index).cast<IntegerAttr>().getInt();
+  int64_t end = index + 1 == offsets.size()
+                    ? caseOperands.size()
+                    : offsets.getValue(index + 1).cast<IntegerAttr>().getInt();
+  return caseOperandsMutable().slice(begin, end - begin);
+}
+
+Optional<MutableOperandRange>
+SwitchOp::getMutableSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  return index == 0 ? defaultOperandsMutable()
+                    : getCaseOperandsMutable(index - 1);
+}
+
+/// Returns the successor selected by a constant flag, or null if the flag is
+/// not a known constant.
+Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
+  if (!case_values())
+    return defaultDestination();
+
+  auto caseDests = caseDestinations();
+  auto caseValues = case_values();
+  // A null attribute means that the flag is not a known constant.
+  if (auto value = operands.front().dyn_cast_or_null<IntegerAttr>()) {
+    for (int64_t i = 0, size = caseValues->size(); i < size; ++i)
+      if (caseValues->getValue<APInt>(i) == value.getValue())
+        return caseDests[i];
+    return defaultDestination();
+  }
+  return nullptr;
+}
+
+/// switch %flag : i32, [
+///   default: ^bb1
+/// ]
+///  -> br ^bb1
+static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
+                                                   PatternRewriter &rewriter) {
+  if (!op.caseDestinations().empty())
+    return failure();
+
+  rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
+                                        op.defaultOperands());
+  return success();
+}
+
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb1,
+///   43: ^bb2
+/// ]
+/// ->
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   43: ^bb2
+/// ]
+static LogicalResult
+dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
+  SmallVector<Block *> newCaseDestinations;
+  SmallVector<ValueRange> newCaseOperands;
+  SmallVector<APInt> newCaseValues;
+  bool requiresChange = false;
+  auto caseValues = op.case_values();
+  auto caseDests = op.caseDestinations();
+
+  // Iterate over the destinations: `case_values` is absent without cases.
+  for (int64_t i = 0, size = caseDests.size(); i < size; ++i) {
+    if (caseDests[i] == op.defaultDestination() &&
+        op.getCaseOperands(i) == op.defaultOperands()) {
+      requiresChange = true;
+      continue;
+    }
+    newCaseDestinations.push_back(caseDests[i]);
+    newCaseOperands.push_back(op.getCaseOperands(i));
+    newCaseValues.push_back(caseValues->getValue<APInt>(i));
+  }
+
+  if (!requiresChange)
+    return failure();
+
+  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), op.defaultDestination(),
+                                        op.defaultOperands(), newCaseValues,
+                                        newCaseDestinations, newCaseOperands);
+  return success();
+}
+
+/// Helper for folding a switch with a constant value.
+/// switch %c_42 : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+///   43: ^bb3
+/// ]
+/// -> br ^bb2
+static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
+                       APInt caseValue) {
+  auto caseValues = op.case_values();
+  auto caseDests = op.caseDestinations();
+  // Iterate over the destinations: `case_values` is absent without cases.
+  for (int64_t i = 0, size = caseDests.size(); i < size; ++i) {
+    if (caseValues->getValue<APInt>(i) == caseValue) {
+      rewriter.replaceOpWithNewOp<BranchOp>(op, caseDests[i],
+                                            op.getCaseOperands(i));
+      return;
+    }
+  }
+  rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
+                                        op.defaultOperands());
+}
+
+/// switch %c_42 : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+///   43: ^bb3
+/// ]
+/// -> br ^bb2
+static LogicalResult simplifyConstSwitchValue(SwitchOp op,
+                                              PatternRewriter &rewriter) {
+  APInt caseValue;
+  if (!matchPattern(op.flag(), m_ConstantInt(&caseValue)))
+    return failure();
+
+  foldSwitch(op, rewriter, caseValue);
+  return success();
+}
+
+/// switch %c_42 : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+/// ]
+/// ^bb2:
+///   br ^bb3
+/// ->
+/// switch %c_42 : i32, [
+///   default: ^bb1,
+///   42: ^bb3,
+/// ]
+static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
+                                               PatternRewriter &rewriter) {
+
+  SmallVector<Block *> newCaseDests;
+  SmallVector<ValueRange> newCaseOperands;
+  SmallVector<SmallVector<Value>> argStorage;
+  auto caseValues = op.case_values();
+  auto caseDests = op.caseDestinations();
+  // Reserve every slot up front: growing `argStorage` would move the inner
+  // vectors and may leave the ValueRanges collected below dangling.
+  argStorage.reserve(caseDests.size() + 1);
+  bool requiresChange = false;
+  for (int64_t i = 0, size = caseDests.size(); i < size; ++i) {
+    Block *caseDest = caseDests[i];
+    ValueRange caseOperands = op.getCaseOperands(i);
+    argStorage.emplace_back();
+    if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
+      requiresChange = true;
+
+    newCaseDests.push_back(caseDest);
+    newCaseOperands.push_back(caseOperands);
+  }
+
+  Block *defaultDest = op.defaultDestination();
+  ValueRange defaultOperands = op.defaultOperands();
+  argStorage.emplace_back();
+
+  if (succeeded(
+          collapseBranch(defaultDest, defaultOperands, argStorage.back())))
+    requiresChange = true;
+
+  if (!requiresChange)
+    return failure();
+
+  // `case_values` is absent on a switch that only has a default destination.
+  DenseIntElementsAttr newCaseValues =
+      caseValues ? *caseValues : DenseIntElementsAttr();
+  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), defaultDest,
+                                        defaultOperands, newCaseValues,
+                                        newCaseDests, newCaseOperands);
+  return success();
+}
+
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+/// ]
+/// ^bb2:
+///   switch %flag : i32, [
+///     default: ^bb3,
+///     42: ^bb4
+///   ]
+/// ->
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+/// ]
+/// ^bb2:
+///   br ^bb4
+///
+///  and
+///
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+/// ]
+/// ^bb2:
+///   switch %flag : i32, [
+///     default: ^bb3,
+///     43: ^bb4
+///   ]
+/// ->
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+/// ]
+/// ^bb2:
+///   br ^bb3
+static LogicalResult
+simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
+                                        PatternRewriter &rewriter) {
+  // Check that we have a single distinct predecessor.
+  Block *currentBlock = op->getBlock();
+  Block *predecessor = currentBlock->getSinglePredecessor();
+  if (!predecessor)
+    return failure();
+
+  // Check that the predecessor terminates with a switch branch to this block
+  // and that it branches on the same condition and that this branch isn't the
+  // default destination.
+  auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
+  if (!predSwitch || op.flag() != predSwitch.flag() ||
+      predSwitch.defaultDestination() == currentBlock)
+    return failure();
+
+  // Fold this switch to an unconditional branch.
+  APInt caseValue;
+  bool isDefault = true;
+  auto predDests = predSwitch.caseDestinations();
+  auto predCaseValues = predSwitch.case_values();
+  // Iterate over the destinations: `case_values` is absent without cases.
+  for (int64_t i = 0, size = predDests.size(); i < size; ++i) {
+    if (currentBlock == predDests[i]) {
+      caseValue = predCaseValues->getValue<APInt>(i);
+      isDefault = false;
+      break;
+    }
+  }
+  if (isDefault)
+    rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
+                                          op.defaultOperands());
+  else
+    foldSwitch(op, rewriter, caseValue);
+  return success();
+}
+
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb2
+/// ]
+/// ^bb1:
+///   switch %flag : i32, [
+///     default: ^bb3,
+///     42: ^bb4,
+///     43: ^bb5
+///   ]
+/// ->
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+/// ]
+/// ^bb1:
+///   switch %flag : i32, [
+///     default: ^bb3,
+///     43: ^bb5
+///   ]
+static LogicalResult
+simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
+                                               PatternRewriter &rewriter) {
+  // Check that we have a single distinct predecessor.
+  Block *currentBlock = op->getBlock();
+  Block *predecessor = currentBlock->getSinglePredecessor();
+  if (!predecessor)
+    return failure();
+
+  // Check that the predecessor terminates with a switch branch to this block
+  // and that it branches on the same condition and that this branch is the
+  // default destination.
+  auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
+  if (!predSwitch || op.flag() != predSwitch.flag() ||
+      predSwitch.defaultDestination() != currentBlock)
+    return failure();
+
+  // Delete case values that are not possible here. Both loops iterate over the
+  // destinations because `case_values` is absent on a switch without cases.
+  DenseSet<APInt> caseValuesToRemove;
+  auto predDests = predSwitch.caseDestinations();
+  auto predCaseValues = predSwitch.case_values();
+  for (int64_t i = 0, size = predDests.size(); i < size; ++i) {
+    if (currentBlock != predDests[i]) {
+      caseValuesToRemove.insert(predCaseValues->getValue<APInt>(i));
+    }
+  }
+
+  SmallVector<Block *> newCaseDestinations;
+  SmallVector<ValueRange> newCaseOperands;
+  SmallVector<APInt> newCaseValues;
+  bool requiresChange = false;
+
+  auto caseValues = op.case_values();
+  auto caseDests = op.caseDestinations();
+  for (int64_t i = 0, size = caseDests.size(); i < size; ++i) {
+    if (caseValuesToRemove.contains(caseValues->getValue<APInt>(i))) {
+      requiresChange = true;
+      continue;
+    }
+    newCaseDestinations.push_back(caseDests[i]);
+    newCaseOperands.push_back(op.getCaseOperands(i));
+    newCaseValues.push_back(caseValues->getValue<APInt>(i));
+  }
+
+  if (!requiresChange)
+    return failure();
+
+  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), op.defaultDestination(),
+                                        op.defaultOperands(), newCaseValues,
+                                        newCaseDestinations, newCaseOperands);
+  return success();
+}
+
+void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add(&simplifySwitchWithOnlyDefault)
+      .add(&dropSwitchCasesThatMatchDefault)
+      .add(&simplifyConstSwitchValue)
+      .add(&simplifyPassThroughSwitch)
+      .add(&simplifySwitchFromSwitchOnSameCondition)
+      .add(&simplifySwitchFromDefaultSwitchOnSameCondition);
+}
+
 //===----------------------------------------------------------------------===//
 // TruncateIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize-cf.mlir b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
--- 
a/mlir/test/Dialect/Standard/canonicalize-cf.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck --dump-input-context 20 %s
 
 /// Test the folding of BranchOp.
 
@@ -139,6 +139,268 @@
     return
 }
 
+
+/// Test the folding of SwitchOp.
+
+// CHECK-LABEL: func @switch_only_default(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+func @switch_only_default(%flag : i32, %caseOperand0 : f32) {
+  // Add predecessors for all blocks to avoid other canonicalizations.
+  "foo.pred"() [^bb1, ^bb2] : () -> ()
+  ^bb1:
+    // CHECK-NOT: switch
+    // CHECK: br ^[[BB2:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_0]]
+    switch %flag : i32, [
+      default: ^bb2(%caseOperand0 : f32)
+    ]
+  // CHECK: ^[[BB2]]({{.*}}):
+  ^bb2(%bb2Arg : f32):
+    // CHECK-NEXT: "foo.bb2Terminator"
+    "foo.bb2Terminator"(%bb2Arg) : (f32) -> ()
+}
+
+
+// CHECK-LABEL: func @switch_case_matching_default(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+func @switch_case_matching_default(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32) {
+  // Add predecessors for all blocks to avoid other canonicalizations.
+  "foo.pred"() [^bb1, ^bb2, ^bb3] : () -> ()
+  ^bb1:
+    // CHECK: switch %[[FLAG]]
+    // CHECK-NEXT: default: ^[[BB1:.+]](%[[CASE_OPERAND_0]] : f32)
+    // CHECK-NEXT: 10: ^[[BB2:.+]](%[[CASE_OPERAND_1]] : f32)
+    // CHECK-NEXT: ]
+    switch %flag : i32, [
+      default: ^bb2(%caseOperand0 : f32),
+      42: ^bb2(%caseOperand0 : f32),
+      10: ^bb3(%caseOperand1 : f32),
+      17: ^bb2(%caseOperand0 : f32)
+    ]
+  ^bb2(%bb2Arg : f32):
+    "foo.bb2Terminator"(%bb2Arg) : (f32) -> ()
+  ^bb3(%bb3Arg : f32):
+    "foo.bb3Terminator"(%bb3Arg) : (f32) -> ()
+}
+
+
+// CHECK-LABEL: func @switch_on_const_no_match(
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
+func @switch_on_const_no_match(%caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
+  // Add predecessors for all blocks to avoid other canonicalizations.
+  "foo.pred"() [^bb1, ^bb2, ^bb3, ^bb4] : () -> ()
+  ^bb1:
+    // CHECK-NOT: switch
+    // CHECK: br ^[[BB2:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_0]]
+    %c0_i32 = constant 0 : i32
+    switch %c0_i32 : i32, [
+      default: ^bb2(%caseOperand0 : f32),
+      -1: ^bb3(%caseOperand1 : f32),
+      1: ^bb4(%caseOperand2 : f32)
+    ]
+  // CHECK: ^[[BB2]]({{.*}}):
+  // CHECK-NEXT: "foo.bb2Terminator"
+  ^bb2(%bb2Arg : f32):
+    "foo.bb2Terminator"(%bb2Arg) : (f32) -> ()
+  ^bb3(%bb3Arg : f32):
+    "foo.bb3Terminator"(%bb3Arg) : (f32) -> ()
+  ^bb4(%bb4Arg : f32):
+    "foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
+}
+
+// CHECK-LABEL: func @switch_on_const_with_match(
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
+func @switch_on_const_with_match(%caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
+  // Add predecessors for all blocks to avoid other canonicalizations.
+  "foo.pred"() [^bb1, ^bb2, ^bb3, ^bb4] : () -> ()
+  ^bb1:
+    // CHECK-NOT: switch
+    // CHECK: br ^[[BB4:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_2]]
+    %c0_i32 = constant 1 : i32
+    switch %c0_i32 : i32, [
+      default: ^bb2(%caseOperand0 : f32),
+      -1: ^bb3(%caseOperand1 : f32),
+      1: ^bb4(%caseOperand2 : f32)
+    ]
+  ^bb2(%bb2Arg : f32):
+    "foo.bb2Terminator"(%bb2Arg) : (f32) -> ()
+  ^bb3(%bb3Arg : f32):
+    "foo.bb3Terminator"(%bb3Arg) : (f32) -> ()
+  // CHECK: ^[[BB4]]({{.*}}):
+  // CHECK-NEXT: "foo.bb4Terminator"
+  ^bb4(%bb4Arg : f32):
+    "foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
+}
+
+// CHECK-LABEL: func @switch_passthrough(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_3:[a-zA-Z0-9_]+]]
+func @switch_passthrough(%flag : i32,
+                         %caseOperand0 : f32,
+                         %caseOperand1 : f32,
+                         %caseOperand2 : f32,
+                         %caseOperand3 : f32) {
+  // Add predecessors for all blocks to avoid other canonicalizations.
+  "foo.pred"() [^bb1, ^bb2, ^bb3, ^bb4, ^bb5, ^bb6] : () -> ()
+
+  ^bb1:
+    // CHECK: switch %[[FLAG]]
+    // CHECK-NEXT: default: ^[[BB5:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_0]]
+    // CHECK-NEXT: 43: ^[[BB6:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_1]]
+    // CHECK-NEXT: 44: ^[[BB4:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_2]]
+    // CHECK-NEXT: ]
+    switch %flag : i32, [
+      default: ^bb2(%caseOperand0 : f32),
+      43: ^bb3(%caseOperand1 : f32),
+      44: ^bb4(%caseOperand2 : f32)
+    ]
+  ^bb2(%bb2Arg : f32):
+    br ^bb5(%bb2Arg : f32)
+  ^bb3(%bb3Arg : f32):
+    br ^bb6(%bb3Arg : f32)
+  ^bb4(%bb4Arg : f32):
+    "foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
+
+  // CHECK: ^[[BB5]]({{.*}}):
+  // CHECK-NEXT: "foo.bb5Terminator"
+  ^bb5(%bb5Arg : f32):
+    "foo.bb5Terminator"(%bb5Arg) : (f32) -> ()
+
+  // CHECK: ^[[BB6]]({{.*}}):
+  // CHECK-NEXT: "foo.bb6Terminator"
+  ^bb6(%bb6Arg : f32):
+    "foo.bb6Terminator"(%bb6Arg) : (f32) -> ()
+}
+
+// CHECK-LABEL: func @switch_from_switch_with_same_value_with_match(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+func @switch_from_switch_with_same_value_with_match(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32) {
+  // Add predecessors for all blocks except ^bb3 to avoid other canonicalizations.
+  "foo.pred"() [^bb1, ^bb2, ^bb4, ^bb5] : () -> ()
+
+  ^bb1:
+    // CHECK: switch %[[FLAG]]
+    switch %flag : i32, [
+      default: ^bb2,
+      42: ^bb3
+    ]
+
+  ^bb2:
+    "foo.bb2Terminator"() : () -> ()
+  ^bb3:
+    // Prevent this block from being simplified away.
+    "foo.op"() : () -> ()
+    // CHECK-NOT: switch %[[FLAG]]
+    // CHECK: br ^[[BB5:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_1]]
+    switch %flag : i32, [
+      default: ^bb4(%caseOperand0 : f32),
+      42: ^bb5(%caseOperand1 : f32)
+    ]
+
+  ^bb4(%bb4Arg : f32):
+    "foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
+
+  // CHECK: ^[[BB5]]({{.*}}):
+  // CHECK-NEXT: "foo.bb5Terminator"
+  ^bb5(%bb5Arg : f32):
+    "foo.bb5Terminator"(%bb5Arg) : (f32) -> ()
+}
+
+// CHECK-LABEL: func @switch_from_switch_with_same_value_no_match(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
+func @switch_from_switch_with_same_value_no_match(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
+  // Add predecessors for all blocks except ^bb3 to avoid other canonicalizations.
+  "foo.pred"() [^bb1, ^bb2, ^bb4, ^bb5, ^bb6] : () -> ()
+
+  ^bb1:
+    // CHECK: switch %[[FLAG]]
+    switch %flag : i32, [
+      default: ^bb2,
+      42: ^bb3
+    ]
+
+  ^bb2:
+    "foo.bb2Terminator"() : () -> ()
+  ^bb3:
+    "foo.op"() : () -> ()
+    // CHECK-NOT: switch %[[FLAG]]
+    // CHECK: br ^[[BB4:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_0]]
+    switch %flag : i32, [
+      default: ^bb4(%caseOperand0 : f32),
+      0: ^bb5(%caseOperand1 : f32),
+      43: ^bb6(%caseOperand2 : f32)
+    ]
+
+  // CHECK: ^[[BB4]]({{.*}})
+  // CHECK-NEXT: "foo.bb4Terminator"
+  ^bb4(%bb4Arg : f32):
+    "foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
+
+  ^bb5(%bb5Arg : f32):
+    "foo.bb5Terminator"(%bb5Arg) : (f32) -> ()
+
+  ^bb6(%bb6Arg : f32):
+    "foo.bb6Terminator"(%bb6Arg) : (f32) -> ()
+}
+
+// CHECK-LABEL: func @switch_from_switch_default_with_same_value(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
+func @switch_from_switch_default_with_same_value(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
+  // Add predecessors for all blocks except ^bb3 to avoid other canonicalizations.
+  "foo.pred"() [^bb1, ^bb2, ^bb4, ^bb5, ^bb6] : () -> ()
+
+  ^bb1:
+    // CHECK: switch %[[FLAG]]
+    switch %flag : i32, [
+      default: ^bb3,
+      42: ^bb2
+    ]
+
+  ^bb2:
+    "foo.bb2Terminator"() : () -> ()
+  ^bb3:
+    "foo.op"() : () -> ()
+    // CHECK: switch %[[FLAG]]
+    // CHECK-NEXT: default: ^[[BB4:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_0]]
+    // CHECK-NEXT: 43: ^[[BB6:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_2]]
+    // CHECK-NOT: 42
+    switch %flag : i32, [
+      default: ^bb4(%caseOperand0 : f32),
+      42: ^bb5(%caseOperand1 : f32),
+      43: ^bb6(%caseOperand2 : f32)
+    ]
+
+  // CHECK: ^[[BB4]]({{.*}}):
+  // CHECK-NEXT: "foo.bb4Terminator"
+  ^bb4(%bb4Arg : f32):
+    "foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
+
+  ^bb5(%bb5Arg : f32):
+    "foo.bb5Terminator"(%bb5Arg) : (f32) -> ()
+
+  // CHECK: ^[[BB6]]({{.*}}):
+  // CHECK-NEXT: "foo.bb6Terminator"
+  ^bb6(%bb6Arg : f32):
+    "foo.bb6Terminator"(%bb6Arg) : (f32) -> ()
+}
+
 /// Test folding conditional branches that are successors of conditional
 /// branches with the same condition.
 
diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -96,3 +96,35 @@
   %1 = memref.tensor_load %0 : memref<2xf32>
   return
 }
+
+// CHECK-LABEL: func @switch(
+func @switch(%flag : i32, %caseOperand : i32) {
+  switch %flag : i32, [
+    default: ^bb1(%caseOperand : i32),
+    42: ^bb2(%caseOperand : i32),
+    43: ^bb3(%caseOperand : i32)
+  ]
+
+  ^bb1(%bb1arg : i32):
+    return
+  ^bb2(%bb2arg : i32):
+    return
+  ^bb3(%bb3arg : i32):
+    return
+}
+
+// CHECK-LABEL: func @switch_i64(
+func @switch_i64(%flag : i64, %caseOperand : i32) {
+  switch %flag : i64, [
+    default: ^bb1(%caseOperand : i32),
+    42: ^bb2(%caseOperand : i32),
+    43: ^bb3(%caseOperand : i32)
+  ]
+
+  ^bb1(%bb1arg : i32):
+    return
+  ^bb2(%bb2arg : i32):
+    return
+  ^bb3(%bb3arg : i32):
+    return
+}
diff --git a/mlir/test/Dialect/Standard/parser.mlir b/mlir/test/Dialect/Standard/parser.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Standard/parser.mlir
@@ -0,0 +1,69 @@
+// RUN: mlir-opt -verify-diagnostics -split-input-file %s
+
+func @switch_missing_case_value(%flag : i32, %caseOperand : i32) {
+  switch %flag : i32, [
+    default: ^bb1(%caseOperand : i32),
+    45: ^bb2(%caseOperand : i32),
+    // expected-error@+1 {{expected integer value}}
+    : ^bb3(%caseOperand : i32)
+  ]
+
+  ^bb1(%bb1arg : i32):
+    return
+  ^bb2(%bb2arg : i32):
+    return
+  ^bb3(%bb3arg : i32):
+    return
+}
+
+// -----
+
+func @switch_wrong_type_case_value(%flag : i32, %caseOperand : i32) {
+  switch %flag : i32, [
+    default: ^bb1(%caseOperand : i32),
+    // expected-error@+1 {{expected integer value}}
+    "hello": ^bb2(%caseOperand : i32)
+  ]
+
+  ^bb1(%bb1arg : i32):
+    return
+  ^bb2(%bb2arg : i32):
+    return
+  ^bb3(%bb3arg : i32):
+    return
+}
+
+// -----
+
+func @switch_missing_comma(%flag : i32, %caseOperand : i32) {
+  switch %flag : i32, [
+    default: ^bb1(%caseOperand : i32),
+    45: ^bb2(%caseOperand : i32)
+    // expected-error@+1 {{expected ']'}}
+    43: ^bb3(%caseOperand : i32)
+  ]
+
+  ^bb1(%bb1arg : i32):
+    return
+  ^bb2(%bb2arg : i32):
+    return
+  ^bb3(%bb3arg : i32):
+    return
+}
+
+// -----
+
+func @switch_missing_default(%flag : i32, %caseOperand : i32) {
+  switch %flag : i32, [
+    // expected-error@+1 {{expected 'default'}}
+    45: ^bb2(%caseOperand : i32)
+    43: ^bb3(%caseOperand : i32)
+  ]
+
+  ^bb1(%bb1arg : i32):
+    return
+  ^bb2(%bb2arg : i32):
+    return
+  ^bb3(%bb3arg : i32):
+    return
+}