diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -952,8 +952,8 @@ let assemblyFormat = [{ $value `:` type($value) `,` $defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)? - `[` `\n` custom<SwitchOpCases>(ref(type($value)), $case_values, $caseDestinations, - $caseOperands, type($caseOperands)) `]` + custom<SwitchOpCases>(ref(type($value)), $case_values, $caseDestinations, + $caseOperands, type($caseOperands)) attr-dict }]; let hasVerifier = 1; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -399,22 +399,22 @@ caseValuesAttr, caseDestinations, caseOperands, branchWeights); } -/// ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)? -/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )? +/// <cases> ::= `[` (case (`,` case )* )? `]` +/// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)? 
static ParseResult parseSwitchOpCases( OpAsmParser &parser, Type flagType, DenseIntElementsAttr &caseValues, SmallVectorImpl<Block *> &caseDestinations, SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands, SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) { + if (failed(parser.parseLSquare())) + return failure(); + if (succeeded(parser.parseOptionalRSquare())) + return success(); SmallVector<APInt> values; unsigned bitWidth = flagType.getIntOrFloatBitWidth(); - do { + auto parseCase = [&]() { int64_t value = 0; - OptionalParseResult integerParseResult = parser.parseOptionalInteger(value); - if (values.empty() && !integerParseResult.has_value()) - return success(); - - if (!integerParseResult.has_value() || integerParseResult.value()) + if (failed(parser.parseInteger(value))) return failure(); values.push_back(APInt(bitWidth, value)); @@ -432,12 +432,15 @@ caseDestinations.push_back(destination); caseOperands.emplace_back(operands); caseOperandTypes.emplace_back(operandTypes); - } while (!parser.parseOptionalComma()); + return success(); + }; + if (failed(parser.parseCommaSeparatedList(parseCase))) + return failure(); ShapedType caseValueType = VectorType::get(static_cast<int64_t>(values.size()), flagType); caseValues = DenseIntElementsAttr::get(caseValueType, values); - return success(); + return parser.parseRSquare(); } static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, @@ -445,8 +448,12 @@ SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) { - if (!caseValues) + p << '['; + p.printNewline(); + if (!caseValues) { + p << ']'; return; + } size_t index = 0; llvm::interleave( @@ -462,6 +469,7 @@ p.printNewline(); }); p.printNewline(); + p << ']'; } LogicalResult SwitchOp::verify() { diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -857,6 +857,19 @@ // ----- +func.func @switch_superfluous_comma(%arg0 : i32) { + // 
expected-error@+3 {{custom op 'llvm.switch' expected integer value}} + llvm.switch %arg0 : i32, ^bb1 [ + 42: ^bb2, + ] +^bb1: + llvm.return +^bb2: + llvm.return +} + +// ----- + func.func @switch_wrong_number_of_weights(%arg0 : i32) { // expected-error@+1 {{expects number of branch weights to match number of successors: 3 vs 2}} llvm.switch %arg0 : i32, ^bb1 [