diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -595,6 +595,151 @@
   let printer = [{ p << getOperationName(); }];
 }
 
+def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
+    [AttrSizedOperandSegments, DeclareOpInterfaceMethods<BranchOpInterface>,
+     NoSideEffect]> {
+  let arguments = (ins LLVM_i32:$value,
+                   Variadic<LLVM_Type>:$defaultOperands,
+                   Variadic<LLVM_Type>:$caseOperands,
+                   OptionalAttr<ElementsAttr>:$caseOperandOffsets,
+                   OptionalAttr<ElementsAttr>:$branch_weights);
+  let successors = (successor
+                    AnySuccessor:$defaultDestination,
+                    VariadicSuccessor<AnySuccessor>:$caseDestinations);
+
+  let verifier = [{
+    // `case_values` is not listed in `arguments`, so ODS does not check it;
+    // do so here before anything dereferences it.
+    ArrayAttr caseValues = getCaseValuesAttr();
+    if (!caseValues)
+      return emitOpError("requires '") << getCaseValuesAttrName()
+                                        << "' array attribute";
+    llvm::SmallVector<int64_t, 8> sortedCaseValues;
+    for (Attribute caseValue : caseValues) {
+      auto intValue = caseValue.dyn_cast<IntegerAttr>();
+      if (!intValue || !intValue.getType().isSignlessIntOrIndex())
+        return emitOpError("expects case values to be integer attributes");
+      sortedCaseValues.push_back(intValue.getInt());
+    }
+    if (caseValues.size() != caseDestinations().size())
+      return emitOpError("expects number of case values to match number of "
+                         "case destinations");
+    if (caseDestinations().size() == 0)
+      return emitOpError("expects at least one case");
+    // LLVM IR rejects a switch that lists the same case value twice.
+    llvm::sort(sortedCaseValues);
+    if (std::adjacent_find(sortedCaseValues.begin(),
+                           sortedCaseValues.end()) != sortedCaseValues.end())
+      return emitOpError("expects case values to be unique");
+
+    // One weight per successor: the default destination, then each case.
+    if (auto weights = branch_weights())
+      if (weights->size() != static_cast<int64_t>(getNumSuccessors()))
+        return emitOpError("expects number of branch weights to match number "
+                           "of successors");
+
+    // getCaseOperandOffsets() slices `caseOperands` with these offsets, so
+    // there must be one non-decreasing, in-range start per case destination.
+    // Without offsets, case operands cannot be attributed to their cases.
+    auto offsets = caseOperandOffsets();
+    if (!offsets) {
+      if (!caseOperands().empty())
+        return emitOpError("expects '") << getCaseOperandOffsetsAttrName()
+                                         << "' when case operands are present";
+      return success();
+    }
+    if (offsets->size() != static_cast<int64_t>(caseDestinations().size()))
+      return emitOpError("expects one case operand offset per case "
+                         "destination");
+    int64_t previous = 0;
+    int64_t numCaseOperands = caseOperands().size();
+    for (int64_t i = 0, e = offsets->size(); i < e; ++i) {
+      auto offset = offsets->getValue(i).dyn_cast<IntegerAttr>();
+      if (!offset || !offset.getType().isSignlessInteger() ||
+          offset.getInt() < previous || offset.getInt() > numCaseOperands)
+        return emitOpError("expects case operand offsets to be non-decreasing "
+                           "integers no larger than the number of case "
+                           "operands");
+      previous = offset.getInt();
+    }
+    return success();
+  }];
+
+  let printer = [{ printSwitchOp(p, *this); }];
+  let parser = [{ return parseSwitchOp(parser, result); }];
+
+  let builders = [
+    OpBuilderDAG<(ins "Value":$value,
+      "Block *":$defaultDestination,
+      "ValueRange":$defaultOperands,
+      "ArrayRef<int32_t>":$caseValues,
+      "BlockRange":$caseDestinations,
+      CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands,
+      CArg<"ArrayRef<int32_t>", "{}">:$branch_weights),
+    [{
+      llvm::SmallVector<Attribute, 4> caseValueAttrs;
+      for (auto caseValue : caseValues)
+        caseValueAttrs.push_back($_builder.getI32IntegerAttr(caseValue));
+      $_state.addAttribute(getCaseValuesAttrName(),
+                           $_builder.getArrayAttr(caseValueAttrs));
+
+      llvm::SmallVector<Value, 8> flattenedCaseOperands;
+      llvm::SmallVector<int32_t, 4> caseOperandOffsets;
+      int32_t offset = 0;
+      for (auto operands : caseOperands) {
+        flattenedCaseOperands.append(operands.begin(), operands.end());
+        caseOperandOffsets.push_back(offset);
+        offset += operands.size();
+      }
+      ElementsAttr caseOperandOffsetsAttr;
+      if (!caseOperandOffsets.empty())
+        caseOperandOffsetsAttr =
+            $_builder.getI32VectorAttr(caseOperandOffsets);
+
+      llvm::SmallVector<int32_t, 4> weights;
+      for (auto branchWeight : branch_weights)
+        weights.push_back(static_cast<int32_t>(branchWeight));
+      ElementsAttr weightsAttr;
+      if (!weights.empty())
+        weightsAttr = $_builder.getI32VectorAttr(weights);
+
+      build($_builder, $_state, value, defaultOperands, flattenedCaseOperands,
+            caseOperandOffsetsAttr, weightsAttr, defaultDestination,
+            caseDestinations);
+    }]>,
+    LLVM_TerminatorPassthroughOpBuilder
+  ];
+
+  let extraClassDeclaration = [{
+    /// Return the name of the attribute used for case values.
+    static llvm::StringRef getCaseValuesAttrName() { return "case_values"; }
+
+    /// Return the array of case values.
+    ArrayAttr getCaseValuesAttr() {
+      return (*this)->getAttrOfType<ArrayAttr>(getCaseValuesAttrName());
+    }
+
+    /// Return the name of the attribute used to delineate ranges of case
+    /// destination operands.
+    static llvm::StringRef getCaseOperandOffsetsAttrName() {
+      return "caseOperandOffsets";
+    }
+
+    /// Return the (start, length) of the operands of the case destination at
+    /// the given index within `caseOperands`, or none if they cannot be
+    /// determined (e.g. the index is out of range).
+    Optional<std::pair<unsigned, unsigned>>
+    getCaseOperandOffsets(unsigned index);
+
+    /// Return the operands for the case destination block at the given index.
+    Optional<OperandRange> getCaseOperands(unsigned index);
+
+    /// Return a mutable range of operands for the case destination block at the
+    /// given index.
+    Optional<MutableOperandRange> getCaseOperandsMutable(unsigned index);
+  }];
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // Auxiliary operations (do not appear in LLVM IR but necessary for the dialect
 // to work correctly).
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -189,6 +189,125 @@
   return index == 0 ? trueDestOperandsMutable() : falseDestOperandsMutable();
 }
 
+//===----------------------------------------------------------------------===//
+// LLVM::SwitchOp
+//===----------------------------------------------------------------------===//
+
+static void printSwitchOp(OpAsmPrinter &p, SwitchOp &op) {
+  p << op.getOperationName() << ' ' << op.value() << ", ";
+  p.printSuccessorAndUseList(op.defaultDestination(), op.defaultOperands());
+  p << ", ";
+
+  p.printAttribute(op.getCaseValuesAttr());
+  p << ", ";
+
+  p << '[';
+  size_t index = 0;
+  llvm::interleaveComma(op.caseDestinations(), p, [&](auto *destination) {
+    if (auto operands = op.getCaseOperands(index++))
+      p.printSuccessorAndUseList(destination, operands.getValue());
+    else
+      p.printSuccessor(destination);
+  });
+  p << ']';
+
+  p.printOptionalAttrDict(op.getAttrs(), {op.getOperandSegmentSizeAttr(),
+                                          op.getCaseValuesAttrName(),
+                                          op.getCaseOperandOffsetsAttrName()});
+}
+
+/// <operation> ::= `llvm.switch`
+///   ssa-use `,` bb-id (`(` ssa-use-and-type-list `)`)? `,`
+///   `[` integer-literal (`,` integer-literal)* `]` `,`
+///   `[` bb-id (`(` ssa-use-and-type-list `)`)?
+///     (`,` bb-id (`(` ssa-use-and-type-list `)`)? )* `]`
+///   attribute-dict?
+static ParseResult parseSwitchOp(OpAsmParser &parser, OperationState &result) {
+  auto &builder = parser.getBuilder();
+
+  OpAsmParser::OperandType value;
+  Block *defaultDestination, *caseDestination;
+  llvm::SmallVector<Value, 4> defaultOperands;
+  ArrayAttr caseValuesAttr;
+  if (parser.parseOperand(value) || parser.parseComma() ||
+      parser.parseSuccessorAndUseList(defaultDestination, defaultOperands) ||
+      parser.parseComma() ||
+      parser.parseAttribute(caseValuesAttr, SwitchOp::getCaseValuesAttrName(),
+                            result.attributes) ||
+      parser.parseComma() || parser.parseLSquare())
+    return failure();
+
+  llvm::SmallVector<Block *, 4> caseDestinations;
+  llvm::SmallVector<Value, 4> caseOperands;
+  llvm::SmallVector<int32_t, 4> caseOperandOffsets;
+  do {
+    // parseSuccessorAndUseList appends to caseOperands, so this case's
+    // operands start at its current size (not at a running sum of sizes).
+    caseOperandOffsets.push_back(static_cast<int32_t>(caseOperands.size()));
+    if (parser.parseSuccessorAndUseList(caseDestination, caseOperands))
+      return failure();
+    caseDestinations.push_back(caseDestination);
+  } while (!parser.parseOptionalComma());
+
+  if (parser.parseRSquare() ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.resolveOperand(value, LLVMType::getInt32Ty(builder.getContext()),
+                            result.operands))
+    return failure();
+
+  result.addSuccessors(defaultDestination);
+  result.addOperands(defaultOperands);
+
+  result.addSuccessors(caseDestinations);
+  result.addOperands(caseOperands);
+  result.addAttribute(SwitchOp::getCaseOperandOffsetsAttrName(),
+                      builder.getI32VectorAttr(caseOperandOffsets));
+
+  result.addAttribute(
+      SwitchOp::getOperandSegmentSizeAttr(),
+      builder.getI32VectorAttr({1, static_cast<int32_t>(defaultOperands.size()),
+                                static_cast<int32_t>(caseOperands.size())}));
+
+  return success();
+}
+
+Optional<std::pair<unsigned, unsigned>>
+SwitchOp::getCaseOperandOffsets(unsigned index) {
+  auto caseOperands = caseOperandsMutable();
+  if (!caseOperandOffsets())
+    return llvm::None;
+  // The result is (start, length) within caseOperands, not (begin, end).
+  auto offsets = caseOperandOffsets().getValue();
+  if (index >= offsets.size())
+    return llvm::None;
+  // Only start offsets are stored: a case ends where the next one starts.
+  auto begin = offsets.getValue(index).cast<IntegerAttr>().getInt();
+  auto end =
+      index + 1 == offsets.size()
+          ? caseOperands.size()
+          : offsets.getValue(index + 1).cast<IntegerAttr>().getInt();
+  return std::make_pair(static_cast<unsigned>(begin),
+                        static_cast<unsigned>(end - begin));
+}
+
+Optional<OperandRange> SwitchOp::getCaseOperands(unsigned index) {
+  if (auto offsets = getCaseOperandOffsets(index))
+    return caseOperands().slice(offsets->first, offsets->second);
+  return llvm::None;
+}
+
+Optional<MutableOperandRange> SwitchOp::getCaseOperandsMutable(unsigned index) {
+  if (auto offsets = getCaseOperandOffsets(index))
+    return caseOperandsMutable().slice(offsets->first, offsets->second);
+  return llvm::None;
+}
+
+Optional<MutableOperandRange>
+SwitchOp::getMutableSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  return index == 0 ? defaultOperandsMutable()
+                    : getCaseOperandsMutable(index - 1);
+}
+
 //===----------------------------------------------------------------------===//
 // Builder, printer and parser for for LLVM::LoadOp.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -323,19 +323,44 @@
   if (isa<LLVM::BrOp>(terminator))
     return terminator.getOperand(index);
 
-  // For conditional branches, we need to check if the current block is reached
-  // through the "true" or the "false" branch and take the relevant operands.
-  auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator);
-  assert(condBranchOp &&
-         "only branch operations can be terminators of a block that "
-         "has successors");
-  assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) &&
-         "successors with arguments in LLVM conditional branches must be "
-         "different blocks");
-
-  return condBranchOp.getSuccessor(0) == current
-             ? condBranchOp.trueDestOperands()[index]
-             : condBranchOp.falseDestOperands()[index];
+  // For instructions that branch based on a condition value, we need to take
+  // the operands for the branch that was taken.
+  if (auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator)) {
+    // For conditional branches, we take the operands from either the "true" or
+    // the "false" branch.
+    assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) &&
+           "successors with arguments in LLVM conditional branches must be "
+           "different blocks");
+
+    return condBranchOp.getSuccessor(0) == current
+               ? condBranchOp.trueDestOperands()[index]
+               : condBranchOp.falseDestOperands()[index];
+  }
+
+  if (auto switchOp = dyn_cast<LLVM::SwitchOp>(terminator)) {
+    // For switches, we take the operands from either the default case, or from
+    // the case branch that was taken.
+    //
+    // Several successors of a switch may legitimately be the same block, e.g.
+    // multiple case values sharing one destination. That only matters for the
+    // block whose PHI nodes are being connected: if it were reached through
+    // more than one edge, we could not tell which edge's operands to use. So
+    // only `current` has to be a unique successor; requiring all successors to
+    // be distinct would trip on valid IR.
+    assert(llvm::count(terminator.getSuccessors(), current) == 1 &&
+           "successors with arguments in LLVM switch instructions must be "
+           "different blocks");
+
+    if (switchOp.defaultDestination() == current)
+      return switchOp.defaultOperands()[index];
+    // Otherwise, take the operands recorded for the case that targets it.
+    for (auto i : llvm::enumerate(switchOp.caseDestinations()))
+      if (i.value() == current)
+        return switchOp.getCaseOperands(i.index()).getValue()[index];
+  }
+
+  llvm_unreachable("only branch or switch operations can be terminators of a "
+                   "block that has successors");
 }
 
 /// Connect the PHI nodes to the results of preceding blocks.
@@ -704,6 +729,37 @@
                          blockMapping[condbrOp.getSuccessor(1)], branchWeights);
     return success();
   }
+  if (auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) {
+    // Optional branch weights become `!prof` metadata on the switch.
+    auto weights = switchOp.branch_weights();
+    llvm::MDNode *branchWeights = nullptr;
+    if (weights) {
+      llvm::SmallVector<uint32_t, 8> weightValues;
+      weightValues.reserve(weights->size());
+      for (int64_t i = 0; i < weights->size(); ++i)
+        weightValues.push_back(static_cast<uint32_t>(
+            weights->getValue(i).cast<IntegerAttr>().getInt()));
+      branchWeights = llvm::MDBuilder(llvmModule->getContext())
+                          .createBranchWeights(weightValues);
+    }
+
+    auto *switchInst =
+        builder.CreateSwitch(valueMapping.lookup(switchOp.value()),
+                             blockMapping[switchOp.defaultDestination()],
+                             switchOp.caseDestinations().size(), branchWeights);
+
+    // The condition is an i32, so every case value is emitted as an i32
+    // constant, in the order the cases are listed.
+    auto *int32Type = llvm::Type::getInt32Ty(llvmModule->getContext());
+    for (auto i :
+         llvm::zip(switchOp.getCaseValuesAttr(), switchOp.caseDestinations()))
+      switchInst->addCase(
+          llvm::ConstantInt::get(int32Type,
+                                 std::get<0>(i).cast<IntegerAttr>().getInt()),
+          blockMapping[std::get<1>(i)]);
+
+    return success();
+  }
 
   // Emit addressof. We need to look up the global value referenced by the
   // operation and store it in the MLIR-to-LLVM value mapping. This does not
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -71,8 +71,8 @@
 
 // CHECK: ^[[BB1]]
 ^bb1:
-// CHECK: llvm.cond_br %7, ^[[BB2:.*]], ^[[BB1]]
-  llvm.cond_br %7, ^bb2, ^bb1
+// CHECK: llvm.cond_br %7, ^[[BB2:.*]], ^[[BB3:.*]]
+  llvm.cond_br %7, ^bb2, ^bb3
 
 // CHECK: ^[[BB2]]
 ^bb2:
@@ -80,7 +80,24 @@
 // CHECK: %{{.*}} = llvm.mlir.constant(42 : i64) : !llvm.i47
   %22 = llvm.mlir.undef : !llvm.struct<(i32, double, i32)>
   %23 = llvm.mlir.constant(42) : !llvm.i47
+  // CHECK: llvm.switch %0, ^[[BB3]], [1, 2], [^[[BB4:.*]], ^[[BB5:.*]]]
+  llvm.switch %0, ^bb3, [1, 2], [^bb4, ^bb5]
 
+// CHECK: ^[[BB3]]
+^bb3:
+// CHECK: llvm.br ^[[BB6:.*]]
+  llvm.br ^bb6
+
+// CHECK: ^[[BB4]]
+^bb4:
+  llvm.br ^bb6
+
+// CHECK: ^[[BB5]]
+^bb5:
+  llvm.br ^bb6
+
+// CHECK: ^[[BB6]]
+^bb6:
 // Misc operations.
 // CHECK: %{{.*}} = llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i32
   %24 = llvm.select %7, %0, %1 : !llvm.i1, !llvm.i32
diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir
--- a/mlir/test/Target/llvmir.mlir
+++ b/mlir/test/Target/llvmir.mlir
@@ -1358,3 +1358,48 @@
   llvm.return
 }
 
+// -----
+
+// CHECK-LABEL: @switch
+llvm.func @switch(%arg0: !llvm.i32) -> !llvm.i32 {
+  %0 = llvm.mlir.constant(5 : i32) : !llvm.i32
+  %1 = llvm.mlir.constant(7 : i32) : !llvm.i32
+  %2 = llvm.mlir.constant(11 : i32) : !llvm.i32
+  // CHECK: switch i32 %{{[0-9]+}}, label %{{[0-9]+}} [
+  // CHECK-NEXT: i32 -1, label %{{[0-9]+}}
+  // CHECK-NEXT: i32 1, label %{{[0-9]+}}
+  // CHECK-NEXT: ]
+  llvm.switch %arg0, ^bb1, [-1, 1], [^bb2(%0 : !llvm.i32), ^bb3(%1, %2 : !llvm.i32, !llvm.i32)]
+
+^bb1:  // pred: ^bb0
+  llvm.return %arg0 : !llvm.i32
+
+// CHECK: phi i32 [ 5, %{{[0-9]+}} ]
+^bb2(%3: !llvm.i32): // pred: ^bb0
+  llvm.return %1 : !llvm.i32
+
+// CHECK: phi i32 [ 7, %{{[0-9]+}} ]
+// CHECK-NEXT: phi i32 [ 11, %{{[0-9]+}} ]
+^bb3(%4: !llvm.i32, %5: !llvm.i32): // pred: ^bb0
+  llvm.return %4 : !llvm.i32
+}
+
+// CHECK-LABEL: @switchWeights
+llvm.func @switchWeights(%arg0: !llvm.i32) -> !llvm.i32 {
+  %0 = llvm.mlir.constant(19 : i32) : !llvm.i32
+  %1 = llvm.mlir.constant(23 : i32) : !llvm.i32
+  %2 = llvm.mlir.constant(29 : i32) : !llvm.i32
+  // CHECK: !prof ![[SWITCH_WEIGHT_NODE:[0-9]+]]
+  llvm.switch %arg0, ^bb1(%0 : !llvm.i32), [9, 99], [^bb2(%1, %2 : !llvm.i32, !llvm.i32), ^bb3] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>}
+
+^bb1(%3: !llvm.i32):  // pred: ^bb0
+  llvm.return %3 : !llvm.i32
+
+^bb2(%4: !llvm.i32, %5: !llvm.i32): // pred: ^bb0
+  llvm.return %5 : !llvm.i32
+
+^bb3: // pred: ^bb0
+  llvm.return %arg0 : !llvm.i32
+}
+
+// CHECK: ![[SWITCH_WEIGHT_NODE]] = !{!"branch_weights", i32 13, i32 17, i32 19}