diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -931,7 +931,7 @@
     AnyInteger:$value,
     Variadic<AnyType>:$defaultOperands,
     VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
-    OptionalAttr<ElementsAttr>:$case_values,
+    OptionalAttr<AnyIntElementsAttr>:$case_values,
     DenseI32ArrayAttr:$case_operand_segments,
     OptionalAttr<ElementsAttr>:$branch_weights
   );
@@ -950,6 +950,13 @@
   let hasVerifier = 1;
 
   let builders = [
+    OpBuilder<(ins "Value":$value,
+      "Block *":$defaultDestination,
+      "ValueRange":$defaultOperands,
+      CArg<"ArrayRef<APInt>", "{}">:$caseValues,
+      CArg<"BlockRange", "{}">:$caseDestinations,
+      CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands,
+      CArg<"ArrayRef<int32_t>", "{}">:$branchWeights)>,
     OpBuilder<(ins "Value":$value,
       "Block *":$defaultDestination,
       "ValueRange":$defaultOperands,
@@ -957,6 +964,13 @@
       CArg<"BlockRange", "{}">:$caseDestinations,
       CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands,
       CArg<"ArrayRef<int32_t>", "{}">:$branchWeights)>,
+    OpBuilder<(ins "Value":$value,
+      "Block *":$defaultDestination,
+      "ValueRange":$defaultOperands,
+      CArg<"DenseIntElementsAttr", "{}">:$caseValues,
+      CArg<"BlockRange", "{}">:$caseDestinations,
+      CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands,
+      CArg<"ArrayRef<int32_t>", "{}">:$branchWeights)>,
     LLVM_TerminatorPassthroughOpBuilder
   ];
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -355,25 +355,61 @@
 
 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                      Block *defaultDestination, ValueRange defaultOperands,
-                     ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
+                     DenseIntElementsAttr caseValues,
+                     BlockRange caseDestinations,
                      ArrayRef<ValueRange> caseOperands,
                      ArrayRef<int32_t> branchWeights) {
-  ElementsAttr caseValuesAttr;
-  if (!caseValues.empty())
-    caseValuesAttr = builder.getI32VectorAttr(caseValues);
-
   ElementsAttr weightsAttr;
   if (!branchWeights.empty())
     weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights));
 
-  build(builder, result, value, defaultOperands, caseOperands, caseValuesAttr,
+  build(builder, result, value, defaultOperands, caseOperands, caseValues,
         weightsAttr, defaultDestination, caseDestinations);
 }
 
+/// Builds a switch whose case values are given as APInts. The integer type of
+/// `value` becomes the element type of the `case_values` vector attribute, so
+/// every APInt must have exactly that type's bit width.
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
+                     Block *defaultDestination, ValueRange defaultOperands,
+                     ArrayRef<APInt> caseValues, BlockRange caseDestinations,
+                     ArrayRef<ValueRange> caseOperands,
+                     ArrayRef<int32_t> branchWeights) {
+  DenseIntElementsAttr caseValuesAttr;
+  if (!caseValues.empty()) {
+    ShapedType caseValueType = VectorType::get(
+        static_cast<int64_t>(caseValues.size()), value.getType());
+    caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
+  }
+
+  build(builder, result, value, defaultDestination, defaultOperands,
+        caseValuesAttr, caseDestinations, caseOperands, branchWeights);
+}
+
+/// Builds a switch from 32-bit case values. Each value is sign-extended (or
+/// truncated, for conditions narrower than i32) to the bit width of `value`'s
+/// integer type before delegating to the APInt builder, so non-i32 conditions
+/// are supported instead of tripping DenseIntElementsAttr's bit width check.
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
+                     Block *defaultDestination, ValueRange defaultOperands,
+                     ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
+                     ArrayRef<ValueRange> caseOperands,
+                     ArrayRef<int32_t> branchWeights) {
+  unsigned bitWidth = value.getType().getIntOrFloatBitWidth();
+  SmallVector<APInt> apCaseValues;
+  apCaseValues.reserve(caseValues.size());
+  for (int32_t caseValue : caseValues)
+    apCaseValues.push_back(
+        APInt(32, caseValue, /*isSigned=*/true).sextOrTrunc(bitWidth));
+
+  build(builder, result, value, defaultDestination, defaultOperands,
+        apCaseValues, caseDestinations, caseOperands, branchWeights);
+}
+
 /// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
 ///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )?
 static ParseResult parseSwitchOpCases(
-    OpAsmParser &parser, Type flagType, ElementsAttr &caseValues,
+    OpAsmParser &parser, Type flagType, DenseIntElementsAttr &caseValues,
     SmallVectorImpl<Block *> &caseDestinations,
     SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
     SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
@@ -412,7 +448,7 @@
 }
 
 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType,
-                               ElementsAttr caseValues,
+                               DenseIntElementsAttr caseValues,
                                SuccessorRange caseDestinations,
                                OperandRangeRange caseOperands,
                                const TypeRangeRange &caseOperandTypes) {
@@ -421,7 +457,7 @@
 
   size_t index = 0;
   llvm::interleave(
-      llvm::zip(llvm::cast<DenseIntElementsAttr>(caseValues), caseDestinations),
+      llvm::zip(caseValues, caseDestinations),
       [&](auto i) {
         p << "  ";
         p << std::get<0>(i).getLimitedValue();
@@ -446,6 +482,9 @@
     return emitError("expects number of branch weights to match number of "
                      "successors: ")
            << getBranchWeights()->size() << " vs " << getNumSuccessors();
+  if (getCaseValues() &&
+      getValue().getType() != getCaseValues()->getElementType())
+    return emitError("expects case value type to match condition value type");
   return success();
 }
 
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1348,7 +1348,7 @@
     unsigned numCases = swInst->getNumCases();
     SmallVector<SmallVector<Value>> caseOperands(numCases);
     SmallVector<ValueRange> caseOperandRefs(numCases);
-    SmallVector<int32_t> caseValues(numCases);
+    SmallVector<APInt> caseValues(numCases);
     SmallVector<Block *> caseBlocks(numCases);
     for (const auto &it : llvm::enumerate(swInst->cases())) {
       const llvm::SwitchInst::CaseHandle &caseHandle = it.value();
@@ -1356,7 +1356,7 @@
       if (failed(convertBranchArgs(swInst, succBB, caseOperands[it.index()])))
         return failure();
       caseOperandRefs[it.index()] = caseOperands[it.index()];
-      caseValues[it.index()] = caseHandle.getCaseValue()->getSExtValue();
+      caseValues[it.index()] = caseHandle.getCaseValue()->getValue();
       caseBlocks[it.index()] = lookupBlock(succBB);
     }
 
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -872,6 +872,17 @@
 
 // -----
 
+func.func @switch_case_type_mismatch(%arg0 : i64) {
+  // expected-error@below {{expects case value type to match condition value type}}
+  "llvm.switch"(%arg0)[^bb1, ^bb2] <{case_operand_segments = array<i32: 0>, case_values = dense<42> : vector<1xi32>, operand_segment_sizes = array<i32: 1, 0, 0>}> : (i64) -> ()
+^bb1: // pred: ^bb0
+  llvm.return
+^bb2: // pred: ^bb0
+  llvm.return
+}
+
+// -----
+
 // expected-error@below {{expected zero value for 'common' linkage}}
 llvm.mlir.global common @non_zero_global_common_linkage(42 : i32) : i32
 
diff --git a/mlir/test/Target/LLVMIR/Import/control-flow.ll b/mlir/test/Target/LLVMIR/Import/control-flow.ll
--- a/mlir/test/Target/LLVMIR/Import/control-flow.ll
+++ b/mlir/test/Target/LLVMIR/Import/control-flow.ll
@@ -46,33 +46,33 @@
 
 ; CHECK-LABEL: @simple_switch(
 ; CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
-define i32 @simple_switch(i32 %arg1) {
+define i64 @simple_switch(i64 %arg1) {
 ; CHECK:  %[[VAL1:.+]] = llvm.add
 ; CHECK:  %[[VAL2:.+]] = llvm.sub
 ; CHECK:  %[[VAL3:.+]] = llvm.mul
-  %1 = add i32 %arg1, 42
-  %2 = sub i32 %arg1, 42
-  %3 = mul i32 %arg1, 42
-  ; CHECK: llvm.switch %[[ARG1]] : i32, ^[[BBD:.+]] [
+  %1 = add i64 %arg1, 42
+  %2 = sub i64 %arg1, 42
+  %3 = mul i64 %arg1, 42
+  ; CHECK: llvm.switch %[[ARG1]] : i64, ^[[BBD:.+]] [
   ; CHECK:   0: ^[[BB1:.+]],
   ; CHECK:   9: ^[[BB2:.+]]
   ; CHECK: ]
-  switch i32 %arg1, label %bbd [
-    i32 0, label %bb1
-    i32 9, label %bb2
+  switch i64 %arg1, label %bbd [
+    i64 0, label %bb1
+    i64 9, label %bb2
   ]
 bb1:
 ; CHECK: ^[[BB1]]:
 ; CHECK: llvm.return %[[VAL1]]
-  ret i32 %1
+  ret i64 %1
 bb2:
 ; CHECK: ^[[BB2]]:
 ; CHECK: llvm.return %[[VAL2]]
-  ret i32 %2
+  ret i64 %2
 bbd:
 ; CHECK: ^[[BBD]]:
 ; CHECK: llvm.return %[[VAL3]]
-  ret i32 %3
+  ret i64 %3
 }
 
 ; // -----