diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -595,6 +595,56 @@
   let printer = [{ p << getOperationName(); }];
 }
 
+// Multi-way branch on a 32-bit integer. Control transfers to the case
+// destination whose entry in `case_values` equals `value`, or to
+// `defaultDestination` when no case matches. Every destination receives its
+// own list of block operands.
+def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
+    [AttrSizedOperandSegments, DeclareOpInterfaceMethods<BranchOpInterface>,
+     NoSideEffect]> {
+  // `caseOperands` holds the operands of all case destinations back to back;
+  // `case_operand_offsets` records the index at which each case's operands
+  // start. `branch_weights`, when present, has one entry per successor.
+  let arguments = (ins LLVM_i32:$value,
+                   Variadic<LLVM_Type>:$defaultOperands,
+                   Variadic<LLVM_Type>:$caseOperands,
+                   I32ArrayAttr:$case_values,
+                   OptionalAttr<ElementsAttr>:$case_operand_offsets,
+                   OptionalAttr<ElementsAttr>:$branch_weights);
+  let successors = (successor
+    AnySuccessor:$defaultDestination,
+    VariadicSuccessor<AnySuccessor>:$caseDestinations);
+
+  let verifier = [{ return ::verify(*this); }];
+  let printer = [{ printSwitchOp(p, *this); }];
+  let parser = [{ return parseSwitchOp(parser, result); }];
+
+  let builders = [
+    OpBuilderDAG<(ins "Value":$value,
+      "Block *":$defaultDestination,
+      "ValueRange":$defaultOperands,
+      "ArrayRef<int32_t>":$case_values,
+      "BlockRange":$caseDestinations,
+      CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands,
+      CArg<"ArrayRef<int32_t>", "{}">:$branch_weights)>,
+    LLVM_TerminatorPassthroughOpBuilder
+  ];
+
+  let extraClassDeclaration = [{
+    /// Return the start index and size of the case operands for the case
+    /// destination at the given index, or zero values if there are no operands
+    /// for that case.
+    std::pair<unsigned, unsigned> getCaseOperandOffsets(unsigned index);
+
+    /// Return the operands for the case destination block at the given index.
+    OperandRange getCaseOperands(unsigned index);
+
+    /// Return a mutable range of operands for the case destination block at the
+    /// given index.
+    MutableOperandRange getCaseOperandsMutable(unsigned index);
+  }];
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // Auxiliary operations (do not appear in LLVM IR but necessary for the dialect
 // to work correctly).
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -189,6 +189,185 @@
   return index == 0 ? trueDestOperandsMutable() : falseDestOperandsMutable();
 }
 
+//===----------------------------------------------------------------------===//
+// LLVM::SwitchOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a switch from per-case operand lists. `caseOperands` must be empty
+/// or hold one operand range per case destination; the ranges are flattened
+/// into the case operand list and their start indices are recorded in the
+/// `case_operand_offsets` attribute.
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
+                     Block *defaultDestination, ValueRange defaultOperands,
+                     ArrayRef<int32_t> case_values, BlockRange caseDestinations,
+                     ArrayRef<ValueRange> caseOperands,
+                     ArrayRef<int32_t> branch_weights) {
+  assert((caseOperands.empty() ||
+          caseOperands.size() == caseDestinations.size()) &&
+         "expected either no case operands or one operand range per case");
+
+  SmallVector<Attribute, 4> caseValueAttrs;
+  for (int32_t caseValue : case_values)
+    caseValueAttrs.push_back(builder.getI32IntegerAttr(caseValue));
+
+  SmallVector<Value, 4> flattenedCaseOperands;
+  SmallVector<int32_t, 4> caseOperandOffsets;
+  int32_t offset = 0;
+  for (ValueRange operands : caseOperands) {
+    flattenedCaseOperands.append(operands.begin(), operands.end());
+    caseOperandOffsets.push_back(offset);
+    offset += operands.size();
+  }
+  ElementsAttr caseOperandOffsetsAttr;
+  if (!caseOperandOffsets.empty())
+    caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets);
+
+  ElementsAttr weightsAttr;
+  if (!branch_weights.empty())
+    weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branch_weights));
+
+  build(builder, result, value, defaultOperands, flattenedCaseOperands,
+        builder.getArrayAttr(caseValueAttrs), caseOperandOffsetsAttr,
+        weightsAttr, defaultDestination, caseDestinations);
+}
+
+/// Prints the custom form, e.g.
+///   llvm.switch %v, ^bb0 [1 : i32 = ^bb1, 2 : i32 = ^bb2(%x : !llvm.i32)]
+/// followed by every attribute except the operand segment sizes,
+/// `case_values` and `case_operand_offsets`.
+static void printSwitchOp(OpAsmPrinter &p, SwitchOp &op) {
+  p << op.getOperationName() << ' ' << op.value() << ", ";
+  p.printSuccessorAndUseList(op.defaultDestination(), op.defaultOperands());
+
+  p << " [";
+  size_t index = 0;
+  llvm::interleaveComma(
+      llvm::zip(op.case_values(), op.caseDestinations()), p, [&](auto i) {
+        p.printAttribute(std::get<0>(i));
+        p << " = ";
+        p.printSuccessorAndUseList(std::get<1>(i), op.getCaseOperands(index++));
+      });
+  p << ']';
+
+  p.printOptionalAttrDict(
+      op.getAttrs(),
+      {op.getOperandSegmentSizeAttr(), "case_values", "case_operand_offsets"});
+}
+
+/// switch-op ::=
+///   `llvm.switch` ssa-use `,` bb-id (`(` ssa-use-and-type-list `)`)? `[`
+///     integer-attribute `=` bb-id (`(` ssa-use-and-type-list `)`)?
+///     (`,` integer-attribute `=` bb-id (`(` ssa-use-and-type-list `)`)? )*
+///   `]` attribute-dict?
+static ParseResult parseSwitchOp(OpAsmParser &parser, OperationState &result) {
+  Builder &builder = parser.getBuilder();
+
+  OpAsmParser::OperandType value;
+  Block *defaultDestination;
+  SmallVector<Value, 4> defaultOperands;
+  if (parser.parseOperand(value) ||
+      parser.resolveOperand(value, LLVMType::getInt32Ty(builder.getContext()),
+                            result.operands) ||
+      parser.parseComma() ||
+      parser.parseSuccessorAndUseList(defaultDestination, defaultOperands) ||
+      parser.parseLSquare())
+    return failure();
+  result.addSuccessors(defaultDestination);
+  result.addOperands(defaultOperands);
+
+  SmallVector<Attribute, 4> caseValues;
+  SmallVector<Block *, 4> caseDestinations;
+  SmallVector<Value, 4> caseOperands;
+  SmallVector<int32_t, 4> caseOperandOffsets;
+  do {
+    IntegerAttr caseValue;
+    Block *caseDestination;
+    // `caseOperands` accumulates the operands of every case parsed so far, so
+    // the operands of this case begin at its current size.
+    caseOperandOffsets.push_back(static_cast<int32_t>(caseOperands.size()));
+    if (parser.parseAttribute(caseValue) || parser.parseEqual() ||
+        parser.parseSuccessorAndUseList(caseDestination, caseOperands))
+      return failure();
+    caseValues.push_back(caseValue);
+    caseDestinations.push_back(caseDestination);
+  } while (!parser.parseOptionalComma());
+  result.addAttribute("case_values", builder.getArrayAttr(caseValues));
+  result.addSuccessors(caseDestinations);
+  result.addOperands(caseOperands);
+  result.addAttribute("case_operand_offsets",
+                      builder.getI32VectorAttr(caseOperandOffsets));
+  result.addAttribute(
+      SwitchOp::getOperandSegmentSizeAttr(),
+      builder.getI32VectorAttr({1, static_cast<int32_t>(defaultOperands.size()),
+                                static_cast<int32_t>(caseOperands.size())}));
+
+  if (parser.parseRSquare() || parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+/// Checks that the case values, case destinations, case operand offsets and
+/// branch weights are consistent with each other.
+static LogicalResult verify(SwitchOp op) {
+  if (op.case_values().size() != op.caseDestinations().size())
+    return op.emitOpError("expects number of case values to match number of "
+                          "case destinations");
+  if (op.branch_weights() &&
+      op.branch_weights()->size() != op.getNumSuccessors())
+    return op.emitOpError("expects number of branch weights to match "
+                          "number of successors");
+  if (op.caseDestinations().empty())
+    return op.emitOpError("expects at least one case");
+  // getCaseOperandOffsets() indexes `case_operand_offsets` by case number and
+  // attributes no operands to any case when the attribute is absent.
+  if (op.case_operand_offsets()) {
+    if (op.case_operand_offsets()->size() !=
+        static_cast<int64_t>(op.caseDestinations().size()))
+      return op.emitOpError("expects number of case operand offsets to match "
+                            "number of case destinations");
+  } else if (!op.caseOperands().empty()) {
+    return op.emitOpError("expects case operand offsets when case operands "
+                          "are present");
+  }
+  return success();
+}
+
+std::pair<unsigned, unsigned> SwitchOp::getCaseOperandOffsets(unsigned index) {
+  if (!case_operand_offsets())
+    return {};
+
+  ElementsAttr offsets = case_operand_offsets().getValue();
+  assert(index < offsets.size() && "invalid case operand offset index");
+
+  // A case's operands extend to the start of the next case's operands, or to
+  // the end of the flattened case operand list for the last case.
+  int64_t begin = offsets.getValue(index).cast<IntegerAttr>().getInt();
+  int64_t end = index + 1 == offsets.size()
+                    ? static_cast<int64_t>(caseOperands().size())
+                    : offsets.getValue(index + 1).cast<IntegerAttr>().getInt();
+  return std::make_pair(static_cast<unsigned>(begin),
+                        static_cast<unsigned>(end - begin));
+}
+
+OperandRange SwitchOp::getCaseOperands(unsigned index) {
+  auto offsets = getCaseOperandOffsets(index);
+  return caseOperands().slice(offsets.first, offsets.second);
+}
+
+MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) {
+  auto offsets = getCaseOperandOffsets(index);
+  return caseOperandsMutable().slice(offsets.first, offsets.second);
+}
+
+/// Successor 0 is the default destination; successor `i > 0` is case `i - 1`.
+Optional<MutableOperandRange>
+SwitchOp::getMutableSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  return index == 0 ? defaultOperandsMutable()
+                    : getCaseOperandsMutable(index - 1);
+}
+
 //===----------------------------------------------------------------------===//
 // Builder, printer and parser for for LLVM::LoadOp.
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -323,19 +323,44 @@ if (isa(terminator)) return terminator.getOperand(index); - // For conditional branches, we need to check if the current block is reached - // through the "true" or the "false" branch and take the relevant operands. - auto condBranchOp = dyn_cast(terminator); - assert(condBranchOp && - "only branch operations can be terminators of a block that " - "has successors"); - assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) && - "successors with arguments in LLVM conditional branches must be " - "different blocks"); - - return condBranchOp.getSuccessor(0) == current - ? condBranchOp.trueDestOperands()[index] - : condBranchOp.falseDestOperands()[index]; + // For instructions that branch based on a condition value, we need to take + // the operands for the branch that was taken. + if (auto condBranchOp = dyn_cast(terminator)) { + // For conditional branches, we take the operands from either the "true" or + // the "false" branch. + assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) && + "successors with arguments in LLVM conditional branches must be " + "different blocks"); + + return condBranchOp.getSuccessor(0) == current + ? condBranchOp.trueDestOperands()[index] + : condBranchOp.falseDestOperands()[index]; + } else if (auto switchOp = dyn_cast(terminator)) { + // For switches, we take the operands from either the default case, or from + // the case branch that was taken. 
+#ifndef NDEBUG + for (size_t i = 0; i < switchOp.caseDestinations().size(); ++i) { + assert(switchOp.defaultDestination() != switchOp.caseDestinations()[i] && + "successors with arguments in LLVM switch instructions must be " + "different blocks"); + for (size_t j = i + 1; j < switchOp.caseDestinations().size(); ++j) { + assert(switchOp.caseDestinations()[i] != + switchOp.caseDestinations()[j] && + "successors with arguments in LLVM switch instructions must be " + "different blocks"); + } + } +#endif + + if (switchOp.defaultDestination() == current) + return switchOp.defaultOperands()[index]; + for (auto i : llvm::enumerate(switchOp.caseDestinations())) + if (i.value() == current) + return switchOp.getCaseOperands(i.index())[index]; + } + + llvm_unreachable("only branch or switch operations can be terminators of a " + "block that has successors"); } /// Connect the PHI nodes to the results of preceding blocks. @@ -704,6 +729,34 @@ blockMapping[condbrOp.getSuccessor(1)], branchWeights); return success(); } + if (auto switchOp = dyn_cast(opInst)) { + auto weights = switchOp.branch_weights(); + llvm::MDNode *branchWeights = nullptr; + if (weights) { + llvm::SmallVector weightValues; + weightValues.reserve(weights->size()); + for (llvm::APInt weight : weights->cast()) + weightValues.push_back(weight.getLimitedValue()); + branchWeights = llvm::MDBuilder(llvmModule->getContext()) + .createBranchWeights(weightValues); + } + + llvm::SwitchInst *switchInst = + builder.CreateSwitch(valueMapping[switchOp.value()], + blockMapping[switchOp.defaultDestination()], + switchOp.caseDestinations().size(), branchWeights); + + llvm::IntegerType *int32Type = + llvm::Type::getInt32Ty(llvmModule->getContext()); + for (auto i : + llvm::zip(switchOp.case_values(), switchOp.caseDestinations())) + switchInst->addCase( + llvm::ConstantInt::get(int32Type, + std::get<0>(i).cast().getInt()), + blockMapping[std::get<1>(i)]); + + return success(); + } // Emit addressof. 
We need to look up the global value referenced by the // operation and store it in the MLIR-to-LLVM value mapping. This does not diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -652,3 +652,17 @@ module attributes {llvm.data_layout = "#vjkr32"} { func @invalid_data_layout() } + +// ----- + +// expected-note@+1 {{prior use here}} +func @switch_on_non_i32(%arg0 : !llvm.i1) { + // expected-error@+1 {{expects different type than prior uses: '!llvm.i32' vs '!llvm.i1'}} + llvm.switch %arg0, ^bb1 [0 : i32 = ^bb2] + +^bb1: + llvm.return + +^bb2: + llvm.return +} diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -71,8 +71,8 @@ // CHECK: ^[[BB1]] ^bb1: -// CHECK: llvm.cond_br %7, ^[[BB2:.*]], ^[[BB1]] - llvm.cond_br %7, ^bb2, ^bb1 +// CHECK: llvm.cond_br %7, ^[[BB2:.*]], ^[[BB3:.*]] + llvm.cond_br %7, ^bb2, ^bb3 // CHECK: ^[[BB2]] ^bb2: @@ -80,7 +80,28 @@ // CHECK: %{{.*}} = llvm.mlir.constant(42 : i64) : !llvm.i47 %22 = llvm.mlir.undef : !llvm.struct<(i32, double, i32)> %23 = llvm.mlir.constant(42) : !llvm.i47 + // CHECK: llvm.switch %0, ^[[BB3]] [1 : i32 = ^[[BB4:.*]], 2 : i32 = ^[[BB5:.*]], 3 : i32 = ^[[BB6:.*]]] + llvm.switch %0, ^bb3 [1 : i32 = ^bb4, 2 : i32 = ^bb5, 3 : i32 = ^bb6] +// CHECK: ^[[BB3]] +^bb3: +// CHECK: llvm.br ^[[BB7:.*]] + llvm.br ^bb7 + +// CHECK: ^[[BB4]] +^bb4: + llvm.br ^bb7 + +// CHECK: ^[[BB5]] +^bb5: + llvm.br ^bb7 + +// CHECK: ^[[BB6]] +^bb6: + llvm.br ^bb7 + +// CHECK: ^[[BB7]] +^bb7: // Misc operations. 
// CHECK: %{{.*}} = llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i32 %24 = llvm.select %7, %0, %1 : !llvm.i1, !llvm.i32 diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir --- a/mlir/test/IR/traits.mlir +++ b/mlir/test/IR/traits.mlir @@ -448,7 +448,7 @@ // ----- func @failedResultSizeAttrWrongCount() { - // expected-error @+1 {{'result_segment_sizes' attribute for specifying result segments must have 4 elements}} + // expected-error @+1 {{'result_segment_sizes' attribute for specifying result segments must have 4 elements, not 3}} %0:4 = "test.attr_sized_results"() {result_segment_sizes = dense<[2, 1, 1]>: vector<3xi32>} : () -> (i32, i32, i32, i32) } diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir --- a/mlir/test/Target/llvmir.mlir +++ b/mlir/test/Target/llvmir.mlir @@ -1358,3 +1358,54 @@ llvm.return } +// ----- + +// CHECK-LABEL: @switch +llvm.func @switch(%arg0: !llvm.i32) { + %0 = llvm.mlir.constant(5 : i32) : !llvm.i32 + %1 = llvm.mlir.constant(7 : i32) : !llvm.i32 + %2 = llvm.mlir.constant(11 : i32) : !llvm.i32 + // CHECK: switch i32 %[[SWITCH_arg0:[0-9]+]], label %[[SWITCHDEFAULT_bb1:[0-9]+]] [ + // CHECK-NEXT: i32 -1, label %[[SWITCHCASE_bb2:[0-9]+]] + // CHECK-NEXT: i32 1, label %[[SWITCHCASE_bb3:[0-9]+]] + // CHECK-NEXT: ] + llvm.switch %arg0, ^bb1 [-1 : i32 = ^bb2(%0 : !llvm.i32), 1 : i32 = ^bb3(%1, %2 : !llvm.i32, !llvm.i32)] + +// CHECK: [[SWITCHDEFAULT_bb1]]: +// CHECK-NEXT: ret i32 %[[SWITCH_arg0]] +^bb1: // pred: ^bb0 + llvm.return %arg0 : !llvm.i32 + +// CHECK: [[SWITCHCASE_bb2]]: +// CHECK-NEXT: phi i32 [ 5, %1 ] +// CHECK-NEXT: ret i32 +^bb2(%3: !llvm.i32): // pred: ^bb0 + llvm.return %1 : !llvm.i32 + +// CHECK: [[SWITCHCASE_bb3]]: +// CHECK-NEXT: phi i32 [ 7, %1 ] +// CHECK-NEXT: phi i32 [ 11, %1 ] +// CHECK-NEXT: ret i32 +^bb3(%4: !llvm.i32, %5: !llvm.i32): // pred: ^bb0 + llvm.return %4 : !llvm.i32 +} + +// CHECK-LABEL: @switchWeights +llvm.func @switchWeights(%arg0: !llvm.i32) { + %0 = 
llvm.mlir.constant(19 : i32) : !llvm.i32 + %1 = llvm.mlir.constant(23 : i32) : !llvm.i32 + %2 = llvm.mlir.constant(29 : i32) : !llvm.i32 + // CHECK: !prof ![[SWITCH_WEIGHT_NODE:[0-9]+]] + llvm.switch %arg0, ^bb1(%0 : !llvm.i32) [9 : i32 = ^bb2(%1, %2 : !llvm.i32, !llvm.i32), 99 : i32 = ^bb3] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>} + +^bb1(%3: !llvm.i32): // pred: ^bb0 + llvm.return %3 : !llvm.i32 + +^bb2(%4: !llvm.i32, %5: !llvm.i32): // pred: ^bb0 + llvm.return %5 : !llvm.i32 + +^bb3: // pred: ^bb0 + llvm.return %arg0 : !llvm.i32 +} + +// CHECK: ![[SWITCH_WEIGHT_NODE]] = !{!"branch_weights", i32 13, i32 17, i32 19} diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -2211,7 +2211,7 @@ auto numElements = sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements(); if (numElements != {1}) return emitError(loc, "'{0}' attribute for specifying {2} segments " - "must have {1} elements"); + "must have {1} elements, not ") << numElements; } )";