diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -595,6 +595,50 @@ let printer = [{ p << getOperationName(); }]; } +def LLVM_SwitchOp : LLVM_TerminatorOp<"switch", + [AttrSizedOperandSegments, DeclareOpInterfaceMethods, + NoSideEffect]> { + let arguments = (ins LLVM_i32:$value, + Variadic:$defaultOperands, + Variadic:$caseOperands, + OptionalAttr:$case_values, + OptionalAttr:$case_operand_offsets, + OptionalAttr:$branch_weights); + let successors = (successor + AnySuccessor:$defaultDestination, + VariadicSuccessor:$caseDestinations); + + let verifier = [{ return ::verify(*this); }]; + let assemblyFormat = [{ + $value `,` + $defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)? + `[` `\n` custom($case_values, $caseDestinations, + $caseOperands, type($caseOperands), + $case_operand_offsets) `]` + attr-dict + }]; + + let builders = [ + OpBuilderDAG<(ins "Value":$value, + "Block *":$defaultDestination, + "ValueRange":$defaultOperands, + "BlockRange":$caseDestinations, + CArg<"ArrayRef", "{}">:$caseValues, + CArg<"ArrayRef", "{}">:$caseOperands, + CArg<"ArrayRef", "{}">:$branchWeights)>, + LLVM_TerminatorPassthroughOpBuilder + ]; + + let extraClassDeclaration = [{ + /// Return the operands for the case destination block at the given index. + OperandRange getCaseOperands(unsigned index); + + /// Return a mutable range of operands for the case destination block at the + /// given index. + MutableOperandRange getCaseOperandsMutable(unsigned index); + }]; +} + //////////////////////////////////////////////////////////////////////////////// // Auxiliary operations (do not appear in LLVM IR but necessary for the dialect // to work correctly). diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -189,6 +189,149 @@ return index == 0 ? trueDestOperandsMutable() : falseDestOperandsMutable(); } +//===----------------------------------------------------------------------===// +// LLVM::SwitchOp +//===----------------------------------------------------------------------===// + +void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, + Block *defaultDestination, ValueRange defaultOperands, + BlockRange caseDestinations, ArrayRef caseValues, + ArrayRef caseOperands, + ArrayRef branchWeights) { + SmallVector flattenedCaseOperands; + SmallVector caseOperandOffsets; + int32_t offset = 0; + for (ValueRange operands : caseOperands) { + flattenedCaseOperands.append(operands.begin(), operands.end()); + caseOperandOffsets.push_back(offset); + offset += operands.size(); + } + ElementsAttr caseValuesAttr; + if (!caseValues.empty()) + caseValuesAttr = builder.getI32VectorAttr(caseValues); + ElementsAttr caseOperandOffsetsAttr; + if (!caseOperandOffsets.empty()) + caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets); + + ElementsAttr weightsAttr; + if (!branchWeights.empty()) + weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights)); + + build(builder, result, value, defaultOperands, flattenedCaseOperands, + caseValuesAttr, caseOperandOffsetsAttr, weightsAttr, defaultDestination, + caseDestinations); +} + +/// ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)? +/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )? +static ParseResult +parseSwitchOpCases(OpAsmParser &parser, ElementsAttr &caseValues, + SmallVectorImpl &caseDestinations, + SmallVectorImpl &caseOperands, + SmallVectorImpl &caseOperandTypes, + ElementsAttr &caseOperandOffsets) { + SmallVector values; + SmallVector offsets; + int32_t value, offset = 0; + do { + OptionalParseResult integerParseResult = parser.parseOptionalInteger(value); + if (values.empty() && !integerParseResult.hasValue()) + return success(); + + if (!integerParseResult.hasValue() || integerParseResult.getValue()) + return failure(); + values.push_back(value); + + Block *destination; + SmallVector operands; + if (parser.parseColon() || parser.parseSuccessor(destination)) + return failure(); + if (!parser.parseOptionalLParen()) { + if (parser.parseRegionArgumentList(operands) || + parser.parseColonTypeList(caseOperandTypes) || parser.parseRParen()) + return failure(); + } + caseDestinations.push_back(destination); + caseOperands.append(operands.begin(), operands.end()); + offsets.push_back(offset); + offset += operands.size(); + } while (!parser.parseOptionalComma()); + + Builder &builder = parser.getBuilder(); + caseValues = builder.getI32VectorAttr(values); + caseOperandOffsets = builder.getI32VectorAttr(offsets); + + return success(); +} + +static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, + ElementsAttr caseValues, + SuccessorRange caseDestinations, + OperandRange caseOperands, + TypeRange caseOperandTypes, + ElementsAttr caseOperandOffsets) { + if (!caseValues) + return; + + size_t index = 0; + llvm::interleave( + llvm::zip(caseValues.cast(), caseDestinations), + [&](auto i) { + p << std::get<0>(i).getLimitedValue(); + p << ": "; + p.printSuccessorAndUseList(std::get<1>(i), op.getCaseOperands(index++)); + }, + [&] { + p << ','; + p.printNewline(); + }); + p.printNewline(); +} + +static LogicalResult verify(SwitchOp op) { + if ((!op.case_values() && !op.caseDestinations().empty()) || + (op.case_values() && + op.case_values()->size() != + static_cast(op.caseDestinations().size()))) + return op.emitOpError("expects number of case values to match number of " + "case destinations"); + if (op.branch_weights() && + op.branch_weights()->size() != op.getNumSuccessors()) + return op.emitError("expects number of branch weights to match number of " + "successors: ") + << op.branch_weights()->size() << " vs " << op.getNumSuccessors(); + return success(); +} + +OperandRange SwitchOp::getCaseOperands(unsigned index) { + return getCaseOperandsMutable(index); +} + +MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) { + MutableOperandRange caseOperands = caseOperandsMutable(); + if (!case_operand_offsets()) { + assert(caseOperands.size() == 0 && + "non-empty case operands must have offsets"); + return caseOperands; + } + + ElementsAttr offsets = case_operand_offsets().getValue(); + assert(index < offsets.size() && "invalid case operand offset index"); + + int64_t begin = offsets.getValue(index).cast().getInt(); + int64_t end = index + 1 == offsets.size() + ? caseOperands.size() + : offsets.getValue(index + 1).cast().getInt(); + return caseOperandsMutable().slice(begin, end - begin); +} + +Optional +SwitchOp::getMutableSuccessorOperands(unsigned index) { + assert(index < getNumSuccessors() && "invalid successor index"); + return index == 0 ? defaultOperandsMutable() + : getCaseOperandsMutable(index - 1); +} + //===----------------------------------------------------------------------===// // Builder, printer and parser for for LLVM::LoadOp. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -323,19 +323,31 @@ if (isa(terminator)) return terminator.getOperand(index); - // For conditional branches, we need to check if the current block is reached - // through the "true" or the "false" branch and take the relevant operands. - auto condBranchOp = dyn_cast(terminator); - assert(condBranchOp && - "only branch operations can be terminators of a block that " - "has successors"); - assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) && - "successors with arguments in LLVM conditional branches must be " - "different blocks"); - - return condBranchOp.getSuccessor(0) == current - ? condBranchOp.trueDestOperands()[index] - : condBranchOp.falseDestOperands()[index]; + SuccessorRange successors = terminator.getSuccessors(); + assert(std::adjacent_find(successors.begin(), successors.end()) == + successors.end() && + "successors with arguments in LLVM branches must be different blocks"); + + // For instructions that branch based on a condition value, we need to take + // the operands for the branch that was taken. + if (auto condBranchOp = dyn_cast(terminator)) { + // For conditional branches, we take the operands from either the "true" or + // the "false" branch. + return condBranchOp.getSuccessor(0) == current + ? condBranchOp.trueDestOperands()[index] + : condBranchOp.falseDestOperands()[index]; + } else if (auto switchOp = dyn_cast(terminator)) { + // For switches, we take the operands from either the default case, or from + // the case branch that was taken. + if (switchOp.defaultDestination() == current) + return switchOp.defaultOperands()[index]; + for (auto i : llvm::enumerate(switchOp.caseDestinations())) + if (i.value() == current) + return switchOp.getCaseOperands(i.index())[index]; + } + + llvm_unreachable("only branch or switch operations can be terminators of a " + "block that has successors"); } /// Connect the PHI nodes to the results of preceding blocks. @@ -717,6 +729,34 @@ branchMapping.try_emplace(&opInst, branch); return success(); } + if (auto switchOp = dyn_cast(opInst)) { + llvm::MDNode *branchWeights = nullptr; + if (auto weights = switchOp.branch_weights()) { + llvm::SmallVector weightValues; + weightValues.reserve(weights->size()); + for (llvm::APInt weight : weights->cast()) + weightValues.push_back(weight.getLimitedValue()); + branchWeights = llvm::MDBuilder(llvmModule->getContext()) + .createBranchWeights(weightValues); + } + + llvm::SwitchInst *switchInst = + builder.CreateSwitch(valueMapping[switchOp.value()], + blockMapping[switchOp.defaultDestination()], + switchOp.caseDestinations().size(), branchWeights); + + auto *ty = llvm::cast( + convertType(switchOp.value().getType().cast())); + for (auto i : + llvm::zip(switchOp.case_values()->cast(), + switchOp.caseDestinations())) + switchInst->addCase( + llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()), + blockMapping[std::get<1>(i)]); + + branchMapping.try_emplace(&opInst, switchInst); + return success(); + } // Emit addressof. We need to look up the global value referenced by the // operation and store it in the MLIR-to-LLVM value mapping. This does not diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -652,3 +652,18 @@ module attributes {llvm.data_layout = "#vjkr32"} { func @invalid_data_layout() } + +// ----- + +func @switch_wrong_number_of_weights(%arg0 : !llvm.i32) { + // expected-error@+1 {{expects number of branch weights to match number of successors: 3 vs 2}} + llvm.switch %arg0, ^bb1 [ + 42: ^bb2(%arg0, %arg0 : !llvm.i32, !llvm.i32) + ] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>} + +^bb1: + llvm.return + +^bb2(%1: !llvm.i32, %2: !llvm.i32): // pred: *bb0 + llvm.return +} diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -71,8 +71,8 @@ // CHECK: ^[[BB1]] ^bb1: -// CHECK: llvm.cond_br %7, ^[[BB2:.*]], ^[[BB1]] - llvm.cond_br %7, ^bb2, ^bb1 +// CHECK: llvm.cond_br %7, ^[[BB2:.*]], ^[[BB3:.*]] + llvm.cond_br %7, ^bb2, ^bb3 // CHECK: ^[[BB2]] ^bb2: @@ -80,7 +80,41 @@ // CHECK: %{{.*}} = llvm.mlir.constant(42 : i64) : !llvm.i47 %22 = llvm.mlir.undef : !llvm.struct<(i32, double, i32)> %23 = llvm.mlir.constant(42) : !llvm.i47 + // CHECK: llvm.switch %0, ^[[BB3]] [ + // CHECK-NEXT: 1: ^[[BB4:.*]], + // CHECK-NEXT: 2: ^[[BB5:.*]], + // CHECK-NEXT: 3: ^[[BB6:.*]] + // CHECK-NEXT: ] + llvm.switch %0, ^bb3 [ + 1: ^bb4, + 2: ^bb5, + 3: ^bb6 + ] + +// CHECK: ^[[BB3]] +^bb3: +// CHECK: llvm.switch %0, ^[[BB7:.*]] [ +// CHECK-NEXT: ] + llvm.switch %0, ^bb7 [ + ] + +// CHECK: ^[[BB4]] +^bb4: + llvm.switch %0, ^bb7 [ + ] + +// CHECK: ^[[BB5]] +^bb5: + llvm.switch %0, ^bb7 [ + ] + +// CHECK: ^[[BB6]] +^bb6: + llvm.switch %0, ^bb7 [ + ] +// CHECK: ^[[BB7]] +^bb7: // Misc operations. // CHECK: %{{.*}} = llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i32 %24 = llvm.select %7, %0, %1 : !llvm.i1, !llvm.i32 diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir --- a/mlir/test/Target/llvmir.mlir +++ b/mlir/test/Target/llvmir.mlir @@ -1358,3 +1358,60 @@ llvm.return } +// ----- + +// CHECK-LABEL: @switch_args +llvm.func @switch_args(%arg0: !llvm.i32) { + %0 = llvm.mlir.constant(5 : i32) : !llvm.i32 + %1 = llvm.mlir.constant(7 : i32) : !llvm.i32 + %2 = llvm.mlir.constant(11 : i32) : !llvm.i32 + // CHECK: switch i32 %[[SWITCH_arg0:[0-9]+]], label %[[SWITCHDEFAULT_bb1:[0-9]+]] [ + // CHECK-NEXT: i32 -1, label %[[SWITCHCASE_bb2:[0-9]+]] + // CHECK-NEXT: i32 1, label %[[SWITCHCASE_bb3:[0-9]+]] + // CHECK-NEXT: ] + llvm.switch %arg0, ^bb1 [ + -1: ^bb2(%0 : !llvm.i32), + 1: ^bb3(%1, %2 : !llvm.i32, !llvm.i32) + ] + +// CHECK: [[SWITCHDEFAULT_bb1]]: +// CHECK-NEXT: ret i32 %[[SWITCH_arg0]] +^bb1: // pred: ^bb0 + llvm.return %arg0 : !llvm.i32 + +// CHECK: [[SWITCHCASE_bb2]]: +// CHECK-NEXT: phi i32 [ 5, %1 ] +// CHECK-NEXT: ret i32 +^bb2(%3: !llvm.i32): // pred: ^bb0 + llvm.return %1 : !llvm.i32 + +// CHECK: [[SWITCHCASE_bb3]]: +// CHECK-NEXT: phi i32 [ 7, %1 ] +// CHECK-NEXT: phi i32 [ 11, %1 ] +// CHECK-NEXT: ret i32 +^bb3(%4: !llvm.i32, %5: !llvm.i32): // pred: ^bb0 + llvm.return %4 : !llvm.i32 +} + +// CHECK-LABEL: @switch_weights +llvm.func @switch_weights(%arg0: !llvm.i32) { + %0 = llvm.mlir.constant(19 : i32) : !llvm.i32 + %1 = llvm.mlir.constant(23 : i32) : !llvm.i32 + %2 = llvm.mlir.constant(29 : i32) : !llvm.i32 + // CHECK: !prof ![[SWITCH_WEIGHT_NODE:[0-9]+]] + llvm.switch %arg0, ^bb1(%0 : !llvm.i32) [ + 9: ^bb2(%1, %2 : !llvm.i32, !llvm.i32), + 99: ^bb3 + ] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>} + +^bb1(%3: !llvm.i32): // pred: ^bb0 + llvm.return %3 : !llvm.i32 + +^bb2(%4: !llvm.i32, %5: !llvm.i32): // pred: ^bb0 + llvm.return %5 : !llvm.i32 + +^bb3: // pred: ^bb0 + llvm.return %arg0 : !llvm.i32 +} + +// CHECK: ![[SWITCH_WEIGHT_NODE]] = !{!"branch_weights", i32 13, i32 17, i32 19}