diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td @@ -29,6 +29,15 @@ let assemblyFormat = "`<` $CallingConv `>`"; } +//===----------------------------------------------------------------------===// +// ComdatAttr +//===----------------------------------------------------------------------===// + +def ComdatAttr : LLVM_Attr<"Comdat", "comdat"> { + let parameters = (ins "comdat::Comdat":$comdat); + let assemblyFormat = "$comdat"; +} + //===----------------------------------------------------------------------===// // LinkageAttr //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h @@ -66,6 +66,7 @@ // TODO: this shouldn't be needed after we unify the attribute generation, i.e. // --gen-attr-* and --gen-attrdef-*. 
using cconv::CConv; +using comdat::Comdat; using linkage::Linkage; } // namespace LLVM } // namespace mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td @@ -589,6 +589,42 @@ "::mlir::LLVM::LinkageAttr::get($_builder.getContext(), $0)"; } + +//===----------------------------------------------------------------------===// +// Comdat +//===----------------------------------------------------------------------===// + +def ComdatAny + : LLVM_EnumAttrCase<"Any", "any", "Any", 0>; +def ComdatExactMatch + : LLVM_EnumAttrCase<"ExactMatch", "exactmatch", "ExactMatch", 1>; +def ComdatLargest + : LLVM_EnumAttrCase<"Largest", "largest", "Largest", 2>; +def ComdatNoDeduplicate + : LLVM_EnumAttrCase<"NoDeduplicate", "nodeduplicate", "NoDeduplicate", 3>; +def ComdatSameSize + : LLVM_EnumAttrCase<"SameSize", "samesize", "SameSize", 4>; + +def ComdatEnum : LLVM_EnumAttr< + "Comdat", + "::llvm::Comdat::SelectionKind", + "LLVM Comdat Types", + [ComdatAny, ComdatExactMatch, ComdatLargest, + ComdatNoDeduplicate, ComdatSameSize]> { + let cppNamespace = "::mlir::LLVM::comdat"; +} + +def Comdat : DialectAttr< + LLVM_Dialect, + CPred<"$_self.isa<::mlir::LLVM::ComdatAttr>()">, + "LLVM Comdat selection kind"> { + let storageType = "::mlir::LLVM::ComdatAttr"; + let returnType = "::mlir::LLVM::Comdat"; + let convertFromStorage = "$_self.getComdat()"; + let constBuilderCall = + "::mlir::LLVM::ComdatAttr::get($_builder.getContext(), $0)"; +} + //===----------------------------------------------------------------------===// // UnnamedAddr //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -1316,6 +1316,7 
@@ DefaultValuedAttr, "0">:$addr_space, OptionalAttr:$unnamed_addr, OptionalAttr:$section, + OptionalAttr:$comdat, DefaultValuedAttr:$visibility_ ); let summary = "LLVM dialect global."; @@ -1424,6 +1425,7 @@ CArg<"unsigned", "0">:$addrSpace, CArg<"bool", "false">:$dsoLocal, CArg<"bool", "false">:$thread_local_, + CArg<"std::optional", "std::nullopt">:$comdat, CArg<"ArrayRef", "{}">:$attrs)> ]; @@ -1515,6 +1517,55 @@ let hasVerifier = 1; } +def LLVM_ComdatSelectorOp : LLVM_Op<"comdat_selector", [Symbol]> { + let arguments = (ins + SymbolNameAttr:$sym_name, + Comdat:$comdat + ); + + let summary = "LLVM dialect comdat selector declaration"; + + let description = [{ + Provides access to object file COMDAT section/group functionality. + + Examples: + ```mlir + llvm.comdat @__llvm_comdat { + llvm.comdat_selector @any any + } + llvm.mlir.global internal constant @has_any_comdat(1 : i64) comdat(@__llvm_comdat::@any) : i64 + ``` + }]; + let assemblyFormat = "$sym_name $comdat attr-dict"; +} + +def LLVM_ComdatOp : LLVM_Op<"comdat", [NoTerminator, NoRegionArguments, SymbolTable, Symbol]> { + let arguments = (ins + SymbolNameAttr:$sym_name + ); + let summary = "LLVM dialect comdat region"; + + let description = [{ + Provides access to object file COMDAT section/group functionality. 
+ + Examples: + ```mlir + llvm.comdat @__llvm_comdat { + llvm.comdat_selector @any any + } + llvm.mlir.global internal constant @has_any_comdat(1 : i64) comdat(@__llvm_comdat::@any) : i64 + ``` + }]; + let regions = (region SizedRegion<1>:$body); + + + let skipDefaultBuilders = 1; + let builders = [OpBuilder<(ins "StringRef":$symName)>]; + + let assemblyFormat = "$sym_name $body attr-dict"; + let hasRegionVerifier = 1; +} + def LLVM_LLVMFuncOp : LLVM_Op<"func", [ AutomaticAllocationScope, IsolatedFromAbove, FunctionOpInterface, CallableOpInterface diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -56,6 +56,9 @@ /// Converts all functions of the LLVM module to MLIR functions. LogicalResult convertFunctions(); + /// Converts all comdat selectors of the LLVM module to MLIR comdat operations. + LogicalResult convertComdats(); + /// Converts all global variables of the LLVM module to MLIR global variables. LogicalResult convertGlobals(); @@ -284,6 +287,10 @@ /// metadata that converts to MLIR operations. Creates the global metadata /// operation on the first invocation. MetadataOp getGlobalMetadataOp(); + /// Returns a global comdat operation that serves as a container for LLVM + /// comdat selectors. Creates the global comdat operation on the first + /// invocation. + ComdatOp getGlobalComdatOp(); /// Performs conversion of LLVM TBAA metadata starting from /// `node`. On exit from this function all nodes reachable /// from `node` are converted, and tbaaMapping map is updated @@ -312,6 +319,8 @@ Operation *globalInsertionOp = nullptr; /// Operation to insert metadata operations into. MetadataOp globalMetadataOp = nullptr; + /// Operation to insert comdat selector operations into. + ComdatOp globalComdatOp = nullptr; /// The current context. MLIRContext *context; /// The MLIR module being created. 
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -274,6 +274,7 @@ LogicalResult convertOperation(Operation &op, llvm::IRBuilderBase &builder); LogicalResult convertFunctionSignatures(); LogicalResult convertFunctions(); + LogicalResult convertComdats(); LogicalResult convertGlobals(); LogicalResult convertOneFunction(LLVMFuncOp func); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1593,6 +1593,25 @@ return success(); } +//===----------------------------------------------------------------------===// +// Builder and verifier for LLVM::ComdatOp. +//===----------------------------------------------------------------------===// + +/// Builds a comdat op named `symName` whose region holds one empty block. +void ComdatOp::build(OpBuilder &builder, OperationState &result, StringRef symName) { + result.addAttribute(getSymNameAttrName(result.name), builder.getStringAttr(symName)); + result.addRegion()->emplaceBlock(); +} + +/// Rejects any operation in the region that is not a comdat selector. +LogicalResult ComdatOp::verifyRegions() { + for (Operation &op : getBody().getOps()) + if (!isa<ComdatSelectorOp>(op)) + return op.emitError("only comdat selector symbols can appear in a comdat region"); + + return success(); +} + //===----------------------------------------------------------------------===// // Builder, printer and verifier for LLVM::GlobalOp.
//===----------------------------------------------------------------------===// @@ -1601,6 +1620,7 @@ bool isConstant, Linkage linkage, StringRef name, Attribute value, uint64_t alignment, unsigned addrSpace, bool dsoLocal, bool threadLocal, + std::optional comdat, ArrayRef attrs) { result.addAttribute(getSymNameAttrName(result.name), builder.getStringAttr(name)); @@ -1616,6 +1636,8 @@ if (threadLocal) result.addAttribute(getThreadLocal_AttrName(result.name), builder.getUnitAttr()); + if (comdat) + result.addAttribute(getComdatAttrName(result.name), *comdat); // Only add an alignment attribute if the "alignment" input // is different from 0. The value must also be a power of two, but @@ -1652,6 +1674,9 @@ if (auto value = getValueOrNull()) p.printAttribute(value); p << ')'; + if (auto cd = getComdat()) + p << " comdat(" << *cd << ')'; + // Note that the alignment attribute is printed using the // default syntax here, even though it is an inherent attribute // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes) @@ -1660,7 +1685,7 @@ getGlobalTypeAttrName(), getConstantAttrName(), getValueAttrName(), getLinkageAttrName(), getUnnamedAddrAttrName(), getThreadLocal_AttrName(), - getVisibility_AttrName()}); + getVisibility_AttrName(), getComdatAttrName()}); // Print the trailing type unless it's a string global. 
if (llvm::dyn_cast_or_null(getValueOrNull())) @@ -1768,6 +1793,15 @@ return failure(); } + if (succeeded(parser.parseOptionalKeyword("comdat"))) { + SymbolRefAttr comdat; + if (parser.parseLParen() || parser.parseAttribute(comdat) || + parser.parseRParen()) + return failure(); + + result.addAttribute(getComdatAttrName(result.name), comdat); + } + SmallVector types; if (parser.parseOptionalAttrDict(result.attributes) || parser.parseOptionalColonTypeList(types)) @@ -1850,6 +1884,12 @@ } } + if (getComdat()) { + auto *op = SymbolTable::lookupNearestSymbolFrom(*this, getComdatAttr()); + if (!isa_and_nonnull(op)) + return emitOpError() << "expected comdat symbol"; + } + std::optional alignAttr = getAlignment(); if (alignAttr.has_value()) { uint64_t value = alignAttr.value(); diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -29,6 +29,7 @@ #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/StringSet.h" +#include "llvm/IR/Comdat.h" #include "llvm/IR/Constants.h" #include "llvm/IR/InlineAsm.h" #include "llvm/IR/InstIterator.h" @@ -81,6 +82,10 @@ return "__llvm_global_metadata"; } +static constexpr StringRef getGlobalComdatOpName() { + return "__llvm_global_comdat"; +} + /// Converts the sync scope identifier of `inst` to the string representation /// necessary to build an atomic LLVM dialect operation. 
Returns the empty /// string if the operation has either no sync scope or the default system-level @@ -167,6 +172,16 @@ mlirModule.getLoc(), getGlobalMetadataOpName()); } +ComdatOp ModuleImport::getGlobalComdatOp() { + if (globalComdatOp) + return globalComdatOp; + + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(mlirModule.getBody()); + return globalComdatOp = builder.create<ComdatOp>( + mlirModule.getLoc(), getGlobalComdatOpName()); +} + LogicalResult ModuleImport::processTBAAMetadata(const llvm::MDNode *node) { Location loc = mlirModule.getLoc(); SmallVector workList; @@ -540,6 +555,19 @@ return success(); } +LogicalResult ModuleImport::convertComdats() { + ComdatOp cdregion = getGlobalComdatOp(); + builder.setInsertionPointToEnd(&cdregion.getBody().back()); + for (auto &kv : llvmModule->getComdatSymbolTable()) { + StringRef name = kv.getKey(); + llvm::Comdat::SelectionKind selector = kv.getValue().getSelectionKind(); + builder.create<ComdatSelectorOp>(mlirModule.getLoc(), name, convertComdatFromLLVM(selector)); + } + builder.setInsertionPointAfter(cdregion); + + return success(); +} + LogicalResult ModuleImport::convertGlobals() { for (llvm::GlobalVariable &globalVar : llvmModule->globals()) { if (globalVar.getName() == getGlobalCtorsVarName() || @@ -857,6 +885,18 @@ globalOp.setVisibility_( convertVisibilityFromLLVM(globalVar->getVisibility())); + if (globalVar->hasComdat()) { + llvm::Comdat *cd = globalVar->getComdat(); + ComdatOp cdregion = getGlobalComdatOp(); + // lookupSymbol yields null for an unknown name; use a null-tolerant cast so + // that case reaches the failure path instead of asserting inside dyn_cast. + auto selector = dyn_cast_or_null<ComdatSelectorOp>(cdregion.lookupSymbol(cd->getName())); + if (!selector) + return failure(); + globalOp.setComdatAttr(SymbolRefAttr::get(builder.getContext(), getGlobalComdatOpName(), + FlatSymbolRefAttr::get(selector.getSymNameAttr()))); + } + return success(); } @@ -1752,6 +1792,8 @@ return {}; if (failed(moduleImport.convertMetadata())) return {}; + if (failed(moduleImport.convertComdats())) + return {}; if
(failed(moduleImport.convertGlobals())) return {}; if (failed(moduleImport.convertFunctions())) diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -27,6 +27,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/RegionGraphTraits.h" #include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" #include "mlir/Target/LLVMIR/TypeToLLVM.h" @@ -714,6 +715,14 @@ : llvm::GlobalValue::NotThreadLocal, addrSpace); + if (auto comdat = op.getComdat()) { + auto name = comdat->getLeafReference().getValue(); + if (!llvmModule->getComdatSymbolTable().contains(name)) + return emitError(op.getLoc(), "global references non-existent comdat"); + auto *llcomdat = llvmModule->getOrInsertComdat(name); + var->setComdat(llcomdat); + } + if (op.getUnnamedAddr().has_value()) var->setUnnamedAddr(convertUnnamedAddrToLLVM(*op.getUnnamedAddr())); @@ -1038,6 +1047,22 @@ return success(); } +/// Creates one LLVM comdat per selector op; names must be module-unique. +LogicalResult ModuleTranslation::convertComdats() { + llvm::Module *module = getLLVMModule(); + for (auto comdatOp : getModuleBody(mlirModule).getOps<LLVM::ComdatOp>()) { + for (auto selectorOp : comdatOp.getOps<LLVM::ComdatSelectorOp>()) { + StringRef name = selectorOp.getName(); + if (module->getComdatSymbolTable().contains(name)) { + return emitError(selectorOp.getLoc()) << "comdat selection symbols must be unique even in different comdat regions"; + } + llvm::Comdat *comdat = module->getOrInsertComdat(name); + comdat->setSelectionKind(convertComdatToLLVM(selectorOp.getComdat())); + } + } + return success(); +} + LogicalResult ModuleTranslation::createAccessGroupMetadata() { return loopAnnotationTranslation->createAccessGroupMetadata(); } @@ -1369,6 +1394,8 @@ ModuleTranslation translator(module, std::move(llvmModule)); if (failed(translator.convertFunctionSignatures())) return nullptr; +
if (failed(translator.convertComdats())) + return nullptr; if (failed(translator.convertGlobals())) return nullptr; if (failed(translator.createAccessGroupMetadata())) @@ -1384,7 +1411,7 @@ llvm::IRBuilder<> llvmBuilder(llvmContext); for (Operation &o : getModuleBody(module).getOperations()) { if (!isa(&o) && + LLVM::GlobalDtorsOp, LLVM::MetadataOp, LLVM::ComdatOp>(&o) && !o.hasTrait() && failed(translator.convertOperation(o, llvmBuilder))) { return nullptr; diff --git a/mlir/test/Dialect/LLVMIR/comdat.mlir b/mlir/test/Dialect/LLVMIR/comdat.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/comdat.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s + +// CHECK: llvm.comdat @__llvm_comdat +llvm.comdat @__llvm_comdat { + // CHECK: llvm.comdat_selector @any_comdat any + llvm.comdat_selector @any_comdat any + // CHECK: llvm.comdat_selector @exactmatch_comdat exactmatch + llvm.comdat_selector @exactmatch_comdat exactmatch + // CHECK: llvm.comdat_selector @largest_comdat largest + llvm.comdat_selector @largest_comdat largest + // CHECK: llvm.comdat_selector @nodeduplicate_comdat nodeduplicate + llvm.comdat_selector @nodeduplicate_comdat nodeduplicate + // CHECK: llvm.comdat_selector @samesize_comdat samesize + llvm.comdat_selector @samesize_comdat samesize +} + diff --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir --- a/mlir/test/Dialect/LLVMIR/global.mlir +++ b/mlir/test/Dialect/LLVMIR/global.mlir @@ -64,6 +64,14 @@ // CHECK: llvm.mlir.global external @has_addr_space(32 : i64) {addr_space = 3 : i32} : i64 llvm.mlir.global external @has_addr_space(32 : i64) {addr_space = 3: i32} : i64 +// CHECK: llvm.comdat @__llvm_comdat +llvm.comdat @__llvm_comdat { + // CHECK: llvm.comdat_selector @any any + llvm.comdat_selector @any any +} +// CHECK: llvm.mlir.global external @any() comdat(@__llvm_comdat::@any) {addr_space = 0 : i32} : i64 +llvm.mlir.global @any() 
comdat(@__llvm_comdat::@any) : i64 + // CHECK-LABEL: references func.func @references() { // CHECK: llvm.mlir.addressof @".string" : !llvm.ptr diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1387,3 +1387,10 @@ // expected-error@+1 {{cannot cast pointers of different address spaces, use 'llvm.addrspacecast' instead}} %0 = llvm.bitcast %arg : !llvm.vec<4 x ptr<1>> to !llvm.vec<4 x ptr> } + +// ----- + +llvm.comdat @__llvm_comdat { + // expected-error@+1 {{only comdat selector symbols can appear in a comdat region}} + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir --- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir @@ -258,3 +258,14 @@ %0 = llvm.intr.experimental.stepvector : vector<7xf32> llvm.return %0 : vector<7xf32> } + +// ----- + +llvm.comdat @__llvm_comdat { + llvm.comdat_selector @foo any +} + +llvm.comdat @__llvm_comdat_1 { + // expected-error @below{{comdat selection symbols must be unique even in different comdat regions}} + llvm.comdat_selector @foo any +} \ No newline at end of file diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -1,5 +1,21 @@ // RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s + +// Comdat sections +llvm.comdat @__llvm_comdat { + // CHECK: $any = comdat any + llvm.comdat_selector @any any + // CHECK: $exactmatch = comdat exactmatch + llvm.comdat_selector @exactmatch exactmatch + // CHECK: $largest = comdat largest + llvm.comdat_selector @largest largest + // CHECK: $nodeduplicate = comdat nodeduplicate + llvm.comdat_selector @nodeduplicate nodeduplicate + // CHECK: $samesize = comdat samesize + llvm.comdat_selector @samesize samesize +} + + 
// CHECK: @global_aligned32 = private global i64 42, align 32 "llvm.mlir.global"() ({}) {sym_name = "global_aligned32", global_type = i64, value = 42 : i64, linkage = #llvm.linkage, alignment = 32} : () -> () @@ -162,6 +178,20 @@ // CHECK: @sectionvar = internal constant [10 x i8] c"teststring", section ".mysection" llvm.mlir.global internal constant @sectionvar("teststring") {section = ".mysection"}: !llvm.array<10 x i8> +// +// Comdat attribute. +// +// CHECK: @has_any_comdat = internal constant i64 1, comdat($any) +llvm.mlir.global internal constant @has_any_comdat(1 : i64) comdat(@__llvm_comdat::@any) : i64 +// CHECK: @has_exactmatch_comdat = internal constant i64 1, comdat($exactmatch) +llvm.mlir.global internal constant @has_exactmatch_comdat(1 : i64) comdat(@__llvm_comdat::@exactmatch) : i64 +// CHECK: @has_largest_comdat = internal constant i64 1, comdat($largest) +llvm.mlir.global internal constant @has_largest_comdat(1 : i64) comdat(@__llvm_comdat::@largest) : i64 +// CHECK: @has_nodeduplicate_comdat = internal constant i64 1, comdat($nodeduplicate) +llvm.mlir.global internal constant @has_nodeduplicate_comdat(1 : i64) comdat(@__llvm_comdat::@nodeduplicate) : i64 +// CHECK: @has_samesize_comdat = internal constant i64 1, comdat($samesize) +llvm.mlir.global internal constant @has_samesize_comdat(1 : i64) comdat(@__llvm_comdat::@samesize) : i64 + // // Declarations of the allocation functions to be linked against. These are // inserted before other functions in the module.