diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td @@ -35,6 +35,15 @@ let hasCustomAssemblyFormat = 1; } +// Attribute definition for the LLVM calling convention enum. +def CConvAttr : LLVM_Attr<"CConv"> { + let mnemonic = "cconv"; + let parameters = (ins + "CConv":$CConv + ); + let hasCustomAssemblyFormat = 1; +} + def LoopOptionsAttr : LLVM_Attr<"LoopOptions"> { let mnemonic = "loopopts"; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -39,6 +39,7 @@ // attribute definition itself. // TODO: this shouldn't be needed after we unify the attribute generation, i.e. // --gen-attr-* and --gen-attrdef-*. +using cconv::CConv; using linkage::Linkage; } // namespace LLVM } // namespace mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -234,6 +234,14 @@ string llvmClassName = llvmName; } +// LLVM_CEnumAttr is functionally identical to LLVM_EnumAttr, but to be used for +// non-class enums. +class LLVM_CEnumAttr cases> : + I64EnumAttr { + string llvmClassName = llvmNS; +} + // For every value in the list, substitutes the value in the place of "$0" in // "pattern" and stores the list of strings as "lst". class ListIntSubst values> { diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -67,6 +67,134 @@ let cppNamespace = "::mlir::LLVM"; } +// These values must match llvm::CallingConv ones. 
+// See https://llvm.org/doxygen/namespacellvm_1_1CallingConv.html for full list +// of supported calling conventions. +def CConvC : LLVM_EnumAttrCase<"C", "ccc", "C", 0>; +def CConvFast : LLVM_EnumAttrCase<"Fast", "fastcc", "Fast", 8>; +def CConvCold : LLVM_EnumAttrCase<"Cold", "coldcc", "Cold", 9>; +def CConvGHC : LLVM_EnumAttrCase<"GHC", "cc_10", "GHC", 10>; +def CConvHiPE : LLVM_EnumAttrCase<"HiPE", "cc_11", "HiPE", 11>; +def CConvWebKitJS : LLVM_EnumAttrCase<"WebKit_JS", "webkit_jscc", + "WebKit_JS", 12>; +def CConvAnyReg : LLVM_EnumAttrCase<"AnyReg", "anyregcc", "AnyReg", 13>; +def CConvPreserveMost : LLVM_EnumAttrCase<"PreserveMost", "preserve_mostcc", + "PreserveMost", 14>; +def CConvPreserveAll : LLVM_EnumAttrCase<"PreserveAll", "preserve_allcc", + "PreserveAll", 15>; +def CConvSwift : LLVM_EnumAttrCase<"Swift", "swiftcc", "Swift", 16>; +def CConvCXXFastTLS : LLVM_EnumAttrCase<"CXX_FAST_TLS", "cxx_fast_tlscc", + "CXX_FAST_TLS", 17>; +def CConvTail : LLVM_EnumAttrCase<"Tail", "tailcc", "Tail", 18>; +def CConvCFGuard_Check : LLVM_EnumAttrCase<"CFGuard_Check", + "cfguard_checkcc", + "CFGuard_Check", 19>; +def CConvSwiftTail : LLVM_EnumAttrCase<"SwiftTail", "swifttailcc", + "SwiftTail", 20>; +def CConvX86_StdCall : LLVM_EnumAttrCase<"X86_StdCall", "x86_stdcallcc", + "X86_StdCall", 64>; +def CConvX86_FastCall : LLVM_EnumAttrCase<"X86_FastCall", "x86_fastcallcc", + "X86_FastCall", 65>; +def CConvARM_APCS : LLVM_EnumAttrCase<"ARM_APCS", "arm_apcscc", "ARM_APCS", 66>; +def CConvARM_AAPCS : LLVM_EnumAttrCase<"ARM_AAPCS", "arm_aapcscc", "ARM_AAPCS", + 67>; +def CConvARM_AAPCS_VFP : LLVM_EnumAttrCase<"ARM_AAPCS_VFP", "arm_aapcs_vfpcc", + "ARM_AAPCS_VFP", 68>; +def CConvMSP430_INTR : LLVM_EnumAttrCase<"MSP430_INTR", "msp430_intrcc", + "MSP430_INTR", 69>; +def CConvX86_ThisCall : LLVM_EnumAttrCase<"X86_ThisCall", "x86_thiscallcc", + "X86_ThisCall", 70>; +def CConvPTX_Kernel : LLVM_EnumAttrCase<"PTX_Kernel", "ptx_kernelcc", + "PTX_Kernel", 71>; +def CConvPTX_Device : 
LLVM_EnumAttrCase<"PTX_Device", "ptx_devicecc", + "PTX_Device", 72>; +def CConvSPIR_FUNC : LLVM_EnumAttrCase<"SPIR_FUNC", "spir_funccc", + "SPIR_FUNC", 75>; +def CConvSPIR_KERNEL : LLVM_EnumAttrCase<"SPIR_KERNEL", "spir_kernelcc", + "SPIR_KERNEL", 76>; +def CConvIntel_OCL_BI : LLVM_EnumAttrCase<"Intel_OCL_BI", "intel_ocl_bicc", + "Intel_OCL_BI", 77>; +def CConvX86_64_SysV : LLVM_EnumAttrCase<"X86_64_SysV", "x86_64_sysvcc", + "X86_64_SysV", 78>; +def CConvWin64 : LLVM_EnumAttrCase<"Win64", "win64cc", "Win64", 79>; +def CConvX86_VectorCall : LLVM_EnumAttrCase<"X86_VectorCall", + "x86_vectorcallcc", + "X86_VectorCall", 80>; +def CConvHHVM : LLVM_EnumAttrCase<"HHVM", "hhvmcc", "HHVM", 81>; +def CConvHHVM_C : LLVM_EnumAttrCase<"HHVM_C", "hhvm_ccc", "HHVM_C", 82>; +def CConvX86_INTR : LLVM_EnumAttrCase<"X86_INTR", "x86_intrcc", "X86_INTR", 83>; +def CConvAVR_INTR : LLVM_EnumAttrCase<"AVR_INTR", "avr_intrcc", "AVR_INTR", 84>; +def CConvAVR_SIGNAL : LLVM_EnumAttrCase<"AVR_SIGNAL", "avr_signalcc", + "AVR_SIGNAL", 85>; +def CConvAVR_BUILTIN : LLVM_EnumAttrCase<"AVR_BUILTIN", "avr_builtincc", + "AVR_BUILTIN", 86>; +def CConvAMDGPU_VS : LLVM_EnumAttrCase<"AMDGPU_VS", "amdgpu_vscc", "AMDGPU_VS", + 87>; +def CConvAMDGPU_GS : LLVM_EnumAttrCase<"AMDGPU_GS", "amdgpu_gscc", "AMDGPU_GS", + 88>; +def CConvAMDGPU_PS : LLVM_EnumAttrCase<"AMDGPU_PS", "amdgpu_pscc", "AMDGPU_PS", + 89>; +def CConvAMDGPU_CS : LLVM_EnumAttrCase<"AMDGPU_CS", "amdgpu_cscc", "AMDGPU_CS", + 90>; +def CConvAMDGPU_KERNEL : LLVM_EnumAttrCase<"AMDGPU_KERNEL", "amdgpu_kernelcc", + "AMDGPU_KERNEL", 91>; +def CConvX86_RegCall : LLVM_EnumAttrCase<"X86_RegCall", "x86_regcallcc", + "X86_RegCall", 92>; +def CConvAMDGPU_HS : LLVM_EnumAttrCase<"AMDGPU_HS", "amdgpu_hscc", "AMDGPU_HS", + 93>; +def CConvMSP430_BUILTIN : LLVM_EnumAttrCase<"MSP430_BUILTIN", + "msp430_builtincc", + "MSP430_BUILTIN", 94>; +def CConvAMDGPU_LS : LLVM_EnumAttrCase<"AMDGPU_LS", "amdgpu_lscc", "AMDGPU_LS", + 95>; +def CConvAMDGPU_ES : 
LLVM_EnumAttrCase<"AMDGPU_ES", "amdgpu_escc", "AMDGPU_ES", + 96>; +def CConvAArch64_VectorCall : LLVM_EnumAttrCase<"AArch64_VectorCall", + "aarch64_vectorcallcc", + "AArch64_VectorCall", 97>; +def CConvAArch64_SVE_VectorCall : LLVM_EnumAttrCase<"AArch64_SVE_VectorCall", + "aarch64_sve_vectorcallcc", + "AArch64_SVE_VectorCall", + 98>; +def CConvWASM_EmscriptenInvoke : LLVM_EnumAttrCase<"WASM_EmscriptenInvoke", + "wasm_emscripten_invokecc", + "WASM_EmscriptenInvoke", 99>; +def CConvAMDGPU_Gfx : LLVM_EnumAttrCase<"AMDGPU_Gfx", "amdgpu_gfxcc", + "AMDGPU_Gfx", 100>; +def CConvM68k_INTR : LLVM_EnumAttrCase<"M68k_INTR", "m68k_intrcc", "M68k_INTR", + 101>; + +def CConvEnum : LLVM_CEnumAttr< + "CConv", + "::llvm::CallingConv", + "Calling Conventions", + [CConvC, CConvFast, CConvCold, CConvGHC, CConvHiPE, CConvWebKitJS, + CConvAnyReg, CConvPreserveMost, CConvPreserveAll, CConvSwift, + CConvCXXFastTLS, CConvTail, CConvCFGuard_Check, CConvSwiftTail, + CConvX86_StdCall, CConvX86_FastCall, CConvARM_APCS, + CConvARM_AAPCS, CConvARM_AAPCS_VFP, CConvMSP430_INTR, CConvX86_ThisCall, + CConvPTX_Kernel, CConvPTX_Device, CConvSPIR_FUNC, CConvSPIR_KERNEL, + CConvIntel_OCL_BI, CConvX86_64_SysV, CConvWin64, CConvX86_VectorCall, + CConvHHVM, CConvHHVM_C, CConvX86_INTR, CConvAVR_INTR, CConvAVR_SIGNAL, CConvAVR_BUILTIN, + CConvAMDGPU_VS, CConvAMDGPU_GS, CConvAMDGPU_PS, CConvAMDGPU_CS, CConvAMDGPU_KERNEL, + CConvX86_RegCall, CConvAMDGPU_HS, CConvMSP430_BUILTIN, CConvAMDGPU_LS, + CConvAMDGPU_ES, CConvAArch64_VectorCall, CConvAArch64_SVE_VectorCall, + CConvWASM_EmscriptenInvoke, CConvAMDGPU_Gfx, CConvM68k_INTR + ]> { + let cppNamespace = "::mlir::LLVM::cconv"; +} + +def CConv : DialectAttr< + LLVM_Dialect, + CPred<"$_self.isa<::mlir::LLVM::CConvAttr>()">, + "LLVM Calling Convention specification"> { + let storageType = "::mlir::LLVM::CConvAttr"; + let returnType = "::mlir::LLVM::cconv::CConv"; + let convertFromStorage = "$_self.getCConv()"; + let constBuilderCall = + "::mlir::LLVM::CConvAttr::get($_builder.getContext(), $0)"; 
+} + class LLVM_Builder { string llvmBuilder = builder; } @@ -1233,6 +1361,7 @@ TypeAttrOf:$function_type, DefaultValuedAttr:$linkage, UnitAttr:$dso_local, + DefaultValuedAttr:$CConv, OptionalAttr:$personality, OptionalAttr:$garbageCollector, OptionalAttr:$passthrough @@ -1246,6 +1375,7 @@ OpBuilder<(ins "StringRef":$name, "Type":$type, CArg<"Linkage", "Linkage::External">:$linkage, CArg<"bool", "false">:$dsoLocal, + CArg<"CConv", "CConv::C">:$cconv, CArg<"ArrayRef", "{}">:$attrs, CArg<"ArrayRef", "{}">:$argAttrs)> ]; diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -139,7 +139,8 @@ prependResAttrsToArgAttrs(rewriter, attributes, funcOp.getNumArguments()); auto wrapperFuncOp = rewriter.create( loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), - wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false, attributes); + wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false, + /*cconv*/ LLVM::CConv::C, attributes); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock()); @@ -206,7 +207,8 @@ // Create the auxiliary function. 
auto wrapperFunc = builder.create( loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), - wrapperType, LLVM::Linkage::External, /*dsoLocal*/ false, attributes); + wrapperType, LLVM::Linkage::External, /*dsoLocal*/ false, + /*cconv*/ LLVM::CConv::C, attributes); builder.setInsertionPointToStart(newFuncOp.addEntryBlock()); @@ -345,7 +347,7 @@ } auto newFuncOp = rewriter.create( funcOp.getLoc(), funcOp.getName(), llvmType, linkage, - /*dsoLocal*/ false, attributes); + /*dsoLocal*/ false, /*cconv*/ LLVM::CConv::C, attributes); rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter, diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -68,7 +68,8 @@ attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr()); auto llvmFuncOp = rewriter.create( gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType, - LLVM::Linkage::External, /*dsoLocal*/ false, attributes); + LLVM::Linkage::External, /*dsoLocal*/ false, /*cconv*/ LLVM::CConv::C, + attributes); { // Insert operations that correspond to converted workgroup and private diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -37,6 +37,7 @@ using namespace mlir; using namespace mlir::LLVM; +using mlir::LLVM::cconv::getMaxEnumValForCConv; using mlir::LLVM::linkage::getMaxEnumValForLinkage; #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc" @@ -1821,6 +1822,7 @@ REGISTER_ENUM_TYPE(Linkage); REGISTER_ENUM_TYPE(UnnamedAddr); +REGISTER_ENUM_TYPE(CConv); } // namespace /// Parse an enum from the keyword, or default to the provided default value. 
@@ -2124,7 +2126,8 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result, StringRef name, Type type, LLVM::Linkage linkage, - bool dsoLocal, ArrayRef attrs, + bool dsoLocal, CConv cconv, + ArrayRef attrs, ArrayRef argAttrs) { result.addRegion(); result.addAttribute(SymbolTable::getSymbolAttrName(), @@ -2133,6 +2136,8 @@ TypeAttr::get(type)); result.addAttribute(getLinkageAttrName(result.name), LinkageAttr::get(builder.getContext(), linkage)); + result.addAttribute(getCConvAttrName(result.name), + CConvAttr::get(builder.getContext(), cconv)); result.attributes.append(attrs.begin(), attrs.end()); if (dsoLocal) result.addAttribute("dso_local", builder.getUnitAttr()); @@ -2185,7 +2190,8 @@ // Parses an LLVM function. // -// operation ::= `llvm.func` linkage? function-signature function-attributes? +// operation ::= `llvm.func` linkage? cconv? function-signature +// function-attributes? // function-body // ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) { @@ -2196,6 +2202,12 @@ parseOptionalLLVMKeyword( parser, result, LLVM::Linkage::External))); + // Default to C Calling Convention if no keyword is provided. 
+ result.addAttribute( + getCConvAttrName(result.name), + CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword( + parser, result, LLVM::CConv::C))); + StringAttr nameAttr; SmallVector entryArgs; SmallVector resultAttrs; @@ -2239,6 +2251,9 @@ p << ' '; if (getLinkage() != LLVM::Linkage::External) p << stringifyLinkage(getLinkage()) << ' '; + if (getCConv() != LLVM::CConv::C) + p << stringifyCConv(getCConv()) << ' '; + p.printSymbolName(getName()); LLVMFunctionType fnType = getFunctionType(); @@ -2255,7 +2270,8 @@ function_interface_impl::printFunctionSignature(p, *this, argTypes, isVarArg(), resTypes); function_interface_impl::printFunctionAttributes( - p, *this, argTypes.size(), resTypes.size(), {getLinkageAttrName()}); + p, *this, argTypes.size(), resTypes.size(), + {getLinkageAttrName(), getCConvAttrName()}); // Print the body if this is not an external function. Region &body = getBody(); @@ -2645,7 +2661,7 @@ //===----------------------------------------------------------------------===// void LLVMDialect::initialize() { - addAttributes(); + addAttributes(); // clang-format off addTypes(getCConv()) <= cconv::getMaxEnumValForCConv()) + printer << stringifyEnum(getCConv()); + else + printer << "INVALID_cc_" << static_cast(getCConv()); + printer << ">"; +} + +Attribute CConvAttr::parse(AsmParser &parser, Type type) { + StringRef convName; + + if (parser.parseLess() || parser.parseKeyword(&convName) || + parser.parseGreater()) + return {}; + auto cconv = cconv::symbolizeCConv(convName); + if (!cconv) { + parser.emitError(parser.getNameLoc(), "unknown calling convention: ") + << convName; + return {}; + } + CConv cconvVal = *cconv; + return CConvAttr::get(parser.getContext(), cconvVal); +} + LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr) : options(attr.getOptions().begin(), attr.getOptions().end()) {} diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp --- 
a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -1139,10 +1139,13 @@ if (!functionType) return failure(); + bool dsoLocal = f->hasLocalLinkage(); + CConv cconv = convertCConvFromLLVM(f->getCallingConv()); + b.setInsertionPoint(module.getBody(), getFuncInsertPt()); - LLVMFuncOp fop = - b.create(UnknownLoc::get(context), f->getName(), functionType, - convertLinkageFromLLVM(f->getLinkage())); + LLVMFuncOp fop = b.create( + UnknownLoc::get(context), f->getName(), functionType, + convertLinkageFromLLVM(f->getLinkage()), dsoLocal, cconv); if (FlatSymbolRefAttr personality = getPersonalityAsAttr(f)) fop->setAttr(b.getStringAttr("personality"), personality); diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir --- a/mlir/test/Dialect/LLVMIR/func.mlir +++ b/mlir/test/Dialect/LLVMIR/func.mlir @@ -144,6 +144,21 @@ -> (!llvm.struct<(i32)> {llvm.struct_attrs = [{llvm.noalias}]}) { llvm.return %arg0 : !llvm.struct<(i32)> } + + // CHECK: llvm.func @cconv1 + llvm.func ccc @cconv1() { + llvm.return + } + + // CHECK: llvm.func weak @cconv2 + llvm.func weak ccc @cconv2() { + llvm.return + } + + // CHECK: llvm.func weak fastcc @cconv3 + llvm.func weak fastcc @cconv3() { + llvm.return + } } // ----- @@ -251,3 +266,18 @@ // expected-error@+1 {{functions cannot have 'common' linkage}} llvm.func common @common_linkage_func() } + +// ----- + +module { + // expected-error@+1 {{custom op 'llvm.func' expected valid '@'-identifier for symbol name}} + llvm.func cc_12 @unknown_calling_convention() +} + +// ----- + +module { + // expected-error@+2 {{unknown calling convention: cc_12}} + "llvm.func"() ({ + }) {sym_name = "generic_unknown_calling_convention", CConv = #llvm.cconv, function_type = !llvm.func} : () -> () +} diff --git a/mlir/test/Target/LLVMIR/Import/basic.ll b/mlir/test/Target/LLVMIR/Import/basic.ll --- a/mlir/test/Target/LLVMIR/Import/basic.ll +++ b/mlir/test/Target/LLVMIR/Import/basic.ll @@ -122,8 
+122,13 @@ ; CHECK: llvm.func @fe(i32) -> f32 declare float @fe(i32) +; CHECK: llvm.func internal spir_funccc @spir_func_internal() +define internal spir_func void @spir_func_internal() { + ret void +} + ; FIXME: function attributes. -; CHECK-LABEL: llvm.func internal @f1(%arg0: i64) -> i32 { +; CHECK-LABEL: llvm.func internal @f1(%arg0: i64) -> i32 attributes {dso_local} { ; CHECK-DAG: %[[c2:[0-9]+]] = llvm.mlir.constant(2 : i32) : i32 ; CHECK-DAG: %[[c42:[0-9]+]] = llvm.mlir.constant(42 : i32) : i32 ; CHECK-DAG: %[[c1:[0-9]+]] = llvm.mlir.constant(true) : i1 diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp --- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp @@ -210,6 +210,27 @@ return cases; } }; + +// Wrapper class around a Tablegen definition of a C-style LLVM enum attribute. +class LLVMCEnumAttr : public tblgen::EnumAttr { +public: + using tblgen::EnumAttr::EnumAttr; + + // Returns the C++ namespace holding the LLVM API enumerants (llvmClassName). + StringRef getLLVMClassName() const { + return def->getValueAsString("llvmClassName"); + } + + // Returns all associated cases viewed as LLVM-specific enum cases. 
+ std::vector getAllCases() const { + std::vector cases; + + for (auto &c : tblgen::EnumAttr::getAllCases()) + cases.emplace_back(c); + + return cases; + } +}; } // namespace // Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing @@ -242,6 +263,37 @@ os << "}\n\n"; } +// Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing +// switch-based logic to convert from the MLIR LLVM dialect enum attribute case +// (Enum) to the corresponding LLVM API C-style enumerant +static void emitOneCEnumToConversion(const llvm::Record *record, + raw_ostream &os) { + LLVMCEnumAttr enumAttr(record); + StringRef llvmClass = enumAttr.getLLVMClassName(); + StringRef cppClassName = enumAttr.getEnumClassName(); + StringRef cppNamespace = enumAttr.getCppNamespace(); + + // Emit the function converting the enum attribute to its LLVM counterpart. + os << formatv("static LLVM_ATTRIBUTE_UNUSED int64_t " + "convert{0}ToLLVM({1}::{0} value) {{\n", + cppClassName, cppNamespace); + os << " switch (value) {\n"; + + for (const auto &enumerant : enumAttr.getAllCases()) { + StringRef llvmEnumerant = enumerant.getLLVMEnumerant(); + StringRef cppEnumerant = enumerant.getSymbol(); + os << formatv(" case {0}::{1}::{2}:\n", cppNamespace, cppClassName, + cppEnumerant); + os << formatv(" return static_cast({0}::{1});\n", llvmClass, + llvmEnumerant); + } + + os << " }\n"; + os << formatv(" llvm_unreachable(\"unknown {0} type\");\n", + enumAttr.getEnumClassName()); + os << "}\n\n"; +} + // Emits conversion function "Enum convertEnumFromLLVM(LLVMClass)" and // containing switch-based logic to convert from the LLVM API enumerant to MLIR // LLVM dialect enum attribute (Enum). @@ -272,6 +324,38 @@ os << "}\n\n"; } +// Emits conversion function "Enum convertEnumFromLLVM(LLVMEnum)" and +// containing switch-based logic to convert from the LLVM API C-style enumerant +// to MLIR LLVM dialect enum attribute (Enum). 
+static void emitOneCEnumFromConversion(const llvm::Record *record, + raw_ostream &os) { + LLVMCEnumAttr enumAttr(record); + StringRef llvmClass = enumAttr.getLLVMClassName(); + StringRef cppClassName = enumAttr.getEnumClassName(); + StringRef cppNamespace = enumAttr.getCppNamespace(); + + // Emit the function converting the enum attribute from its LLVM counterpart. + os << formatv( + "inline LLVM_ATTRIBUTE_UNUSED {0}::{1} convert{1}FromLLVM(int64_t " + "value) {{\n", + cppNamespace, cppClassName); + os << " switch (value) {\n"; + + for (const auto &enumerant : enumAttr.getAllCases()) { + StringRef llvmEnumerant = enumerant.getLLVMEnumerant(); + StringRef cppEnumerant = enumerant.getSymbol(); + os << formatv(" case static_cast({0}::{1}):\n", llvmClass, + llvmEnumerant); + os << formatv(" return {0}::{1}::{2};\n", cppNamespace, cppClassName, + cppEnumerant); + } + + os << " }\n"; + os << formatv(" llvm_unreachable(\"unknown {0} type\");\n", + enumAttr.getLLVMClassName()); + os << "}\n\n"; +} + // Emits conversion functions between MLIR enum attribute case and corresponding // LLVM API enumerants for all registered LLVM dialect enum attributes. template @@ -283,6 +367,13 @@ else emitOneEnumFromConversion(def, os); + for (const auto *def : + recordKeeper.getAllDerivedDefinitions("LLVM_CEnumAttr")) + if (ConvertTo) + emitOneCEnumToConversion(def, os); + else + emitOneCEnumFromConversion(def, os); + return false; }