Index: mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
===================================================================
--- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -480,6 +480,15 @@
   let parser = [{ return parseReturnOp(parser, result); }];
   let printer = [{ printReturnOp(p, *this); }];
 }
+// Multi-way branch. $args holds the switch value followed by one value per
+// case; $destinations holds the default successor followed by one per case.
+def LLVM_SwitchOp : LLVM_TerminatorOp<"switch", []> {
+  let arguments = (ins Variadic:$args);
+  let successors = (successor VariadicSuccessor:$destinations);
+  let parser = [{ return parseSwitchOp(parser, result); }];
+  let printer = [{ printSwitchOp(p, *this); }];
+  let verifier = [{ return ::verify(*this); }];
+}
 def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> {
   string llvmBuilder = [{ builder.CreateUnreachable(); }];
   let parser = [{ return success(); }];
Index: mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
===================================================================
--- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -813,6 +813,103 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Verifying/Printing/Parsing for LLVM::SwitchOp.
+//===----------------------------------------------------------------------===//
+
+// Verifies an llvm.switch. Operand 0 is the value switched on and successor 0
+// is the default destination; operand i and successor i (i >= 1) form a case.
+static LogicalResult verify(SwitchOp op) {
+  // Everything below, as well as the printer, indexes operand i next to
+  // successor i, so reject ops where that would read out of bounds.
+  if (op.getNumSuccessors() == 0)
+    return op.emitOpError() << "requires a default destination";
+  if (op.getNumOperands() < op.getNumSuccessors())
+    return op.emitOpError() << "requires a switch value and one case value "
+                               "per case destination";
+
+  if (!op.getOperand(0).getType().cast().isIntegerTy())
+    return op.emitOpError() << "a condition value should be an integer";
+
+  // FIXME: Prohibit any block arguments for now.
+  for (auto succ : op.destinations())
+    if (!succ->getArguments().empty())
+      return op.emitOpError() << "can't use a block with arguments as a "
+                                 "successor of SwitchOp now";
+
+  for (unsigned idx = 1, e = op.getNumSuccessors(); idx < e; ++idx) {
+    ConstantOp constantOp =
+        dyn_cast_or_null(op.getOperand(idx).getDefiningOp());
+    if (!constantOp)
+      return op.emitOpError() << "condition values should be constants";
+    // FIXME: Need to check the case values are disjoint.
+  }
+
+  return success();
+}
+
+// Prints an llvm.switch in the form parseSwitchOp accepts. The attribute
+// dictionary is emitted after the case list, just before `:`, because that is
+// where the parser reads it back.
+static void printSwitchOp(OpAsmPrinter &p, SwitchOp &op) {
+  p << op.getOperationName() << ' ' << op.getOperand(0) << ", ";
+  p.printSuccessorAndUseList(op, 0);
+  p << " [ ";
+  interleaveComma(llvm::seq(1, op.getNumSuccessors()), p, [&](int i) {
+    p << op.getOperand(i) << ", ";
+    p.printSuccessorAndUseList(op, i);
+  });
+  p << " ]";
+  p.printOptionalAttrDict(op.getAttrs());
+  p << " : " << op.getOperand(0).getType();
+}
+
+// operation ::= `llvm.switch` ssa-use `,` successor `[` case (`,` case)* `]`
+//               attribute-dict? `:` type
+// successor ::= bb-id (`[` ssa-use-and-type-list `]`)?
+// case      ::= ssa-use `,` successor
+static ParseResult parseSwitchOp(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::OperandType switchVal;
+  Type switchTy;
+  SmallVector defaultOperands;
+  Block *defaultBlock = nullptr;
+  SmallVector>, 2> fullSuccessors;
+  SmallVector conditions;
+
+  if (parser.parseOperand(switchVal) || parser.parseComma() ||
+      parser.parseSuccessorAndUseList(defaultBlock, defaultOperands) ||
+      parser.parseLSquare())
+    return failure();
+
+  // Parse one or more comma-separated `value, successor` cases.
+  do {
+    Block *dest = nullptr;
+    OpAsmParser::OperandType condVal;
+    SmallVector oper;
+    if (parser.parseOperand(condVal) || parser.parseComma() ||
+        parser.parseSuccessorAndUseList(dest, oper))
+      return failure();
+    fullSuccessors.emplace_back(dest, oper);
+    conditions.push_back(condVal);
+  } while (succeeded(parser.parseOptionalComma()));
+
+  // The switch value and every case value are resolved against the one type
+  // spelled after `:`.
+  if (parser.parseRSquare() ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(switchTy) ||
+      parser.resolveOperand(switchVal, switchTy, result.operands) ||
+      parser.resolveOperands(conditions, switchTy, result.operands))
+    return failure();
+
+  // Successor 0 is the default destination; the cases follow in order.
+  result.addSuccessor(defaultBlock, defaultOperands);
+  for (auto &succAndArgs : fullSuccessors)
+    result.addSuccessor(succAndArgs.first, succAndArgs.second);
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
// Verifier for LLVM::AddressOfOp. //===----------------------------------------------------------------------===// Index: mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp =================================================================== --- mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -744,6 +744,47 @@ v = b.create(loc, type, ops, ArrayRef()); return success(); } + case llvm::Instruction::Switch: { + llvm::SwitchInst *switchInst = cast(inst); + SmallVector ops; + SmallVector>, 4> fullSuccessors; + OperationState state(loc, "llvm.switch"); + + // Operands of llvm::SwitchInst consists of a condition, case values and + // blocks. + // + // Operand[0] = Value to switch on + // Operand[1] = Default basic block destination + // Operand[2n ] = Value to match + // Operand[2n+1] = BasicBlock to go to on match + // + // So, for LLVM::SwitchInst, successors are Operand[1, 3, 5, ...] and + // operands are Operand[0, 2, 4 ..]. + + bool isBlock = false; + for (auto *op : switchInst->operand_values()) { + if (isBlock) { + SmallVector blockArguments; + llvm::BasicBlock *dest = cast(op); + if (failed(processBranchArgs(switchInst, dest, blockArguments))) + return failure(); + + fullSuccessors.emplace_back(blocks[dest], blockArguments); + } else { + Value value = processValue(op); + if (!value) + return failure(); + ops.push_back(value); + } + isBlock = !isBlock; + } + state.addOperands(ops); + for (auto succAndArgs : fullSuccessors) + state.addSuccessor(succAndArgs.first, succAndArgs.second); + + b.createOperation(state); + return success(); + } } } Index: mlir/lib/Target/LLVMIR/ModuleTranslation.cpp =================================================================== --- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -362,6 +362,18 @@ return success(); } + if (auto switchOp = dyn_cast(opInst)) { + llvm::SwitchInst *switchInst = + 
builder.CreateSwitch(valueMapping.lookup(switchOp.getOperand(0)), + blockMapping[switchOp.getSuccessor(0)], + switchOp.getNumSuccessors() - 1); + for (unsigned i = 1; i < switchOp.getNumSuccessors(); i++) + switchInst->addCase( + cast(valueMapping.lookup(switchOp.getOperand(i))), + blockMapping[switchOp.getSuccessor(i)]); + return success(); + } + // Emit addressof. We need to look up the global value referenced by the // operation and store it in the MLIR-to-LLVM value mapping. This does not // emit any LLVM instruction. Index: mlir/test/Dialect/LLVMIR/terminator.mlir =================================================================== --- mlir/test/Dialect/LLVMIR/terminator.mlir +++ mlir/test/Dialect/LLVMIR/terminator.mlir @@ -19,3 +19,14 @@ llvm.return } +// CHECK-LABEL: @switch +// check: llvm.switch +func @switch(%n : !llvm.i32) { + %0 = llvm.mlir.constant (0: i32) : !llvm.i32 + %1 = llvm.mlir.constant (1: i32) : !llvm.i32 + llvm.switch %n, ^bb1 [ %0, ^bb1, %1, ^bb2]: !llvm.i32 +^bb1: + llvm.return +^bb2: + llvm.return +} Index: mlir/test/Target/llvmir.mlir =================================================================== --- mlir/test/Target/llvmir.mlir +++ mlir/test/Target/llvmir.mlir @@ -1171,3 +1171,28 @@ ^bb3: // pred: ^bb1 %8 = llvm.invoke @bar(%6) to ^bb2 unwind ^bb1 : (!llvm<"i8*">) -> !llvm<"i8*"> } + +// CHECK-LABEL: @switch +// CHECK-SAME: (i32 %[[arg0:[0-9]+]]) +llvm.func @switch(%arg0: !llvm.i32) -> !llvm.i32 { + %0 = llvm.mlir.constant(0 : i32) : !llvm.i32 + %1 = llvm.mlir.constant(1 : i32) : !llvm.i32 + %2 = llvm.mlir.constant(2 : i32) : !llvm.i32 + llvm.switch %arg0, ^bb1 [ %0, ^bb2, %1, ^bb3 ] : !llvm.i32 +// CHECK: switch i32 %[[arg0]], label %[[default:[0-9]+]] [ +// CHECK-NEXT: i32 0, label %[[case1:[0-9]+]] +// CHECK-NEXT: i32 1, label %[[case2:[0-9]+]] +// CHECK-NEXT: ] + +// CHECK: [[default]] +^bb1: // pred: ^bb0 + llvm.return %0 : !llvm.i32 + +// CHECK: [[case1]] +^bb2: // pred: ^bb0 + llvm.return %1 : !llvm.i32 + +// CHECK: [[case2]] 
+^bb3: // pred: ^bb0 + llvm.return %2 : !llvm.i32 +}