[mlir][LLVMIR] Add COMDAT support to the LLVM dialect.

* ComdatAttr (mnemonic "comdat") wraps the comdat::Comdat enum, whose cases
  (any, exactmatch, largest, nodeduplicate, samesize) correspond to
  ::llvm::Comdat::SelectionKind.
* llvm.comdat is a symbol table that holds llvm.comdat_selector ops; each
  selector pairs a symbol name with a selection kind.
* llvm.mlir.global gains an optional comdat symbol reference, printed and
  parsed as `comdat(@comdat::@selector)` and verified to resolve to a comdat
  symbol.
* ModuleTranslation::convertComdats runs before convertGlobals and emits one
  LLVM comdat per selector, named "<comdat>_<selector>"; globals are then
  attached to that comdat.

NOTE(review): text between angle brackets appears to have been stripped from
this copy of the patch (e.g. `isa(op)`, `OptionalAttr:$comdat`, `getOps()`,
bare `#include` lines) and its line structure appears collapsed. Restore both
from the original patch before applying.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -29,6 +29,15 @@
   let assemblyFormat = "`<` $CallingConv `>`";
 }
 
+//===----------------------------------------------------------------------===//
+// ComdatAttr
+//===----------------------------------------------------------------------===//
+
+def ComdatAttr : LLVM_Attr<"Comdat", "comdat"> {
+  let parameters = (ins "comdat::Comdat":$comdat);
+  let assemblyFormat = "$comdat";
+}
+
 //===----------------------------------------------------------------------===//
 // LinkageAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
@@ -66,6 +66,7 @@
 // TODO: this shouldn't be needed after we unify the attribute generation, i.e.
 // --gen-attr-* and --gen-attrdef-*.
using cconv::CConv;
+using comdat::Comdat;
 using linkage::Linkage;
 } // namespace LLVM
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
@@ -589,6 +589,42 @@
     "::mlir::LLVM::LinkageAttr::get($_builder.getContext(), $0)";
 }
 
+//===----------------------------------------------------------------------===//
+// Comdat
+//===----------------------------------------------------------------------===//
+
+// COMDAT selection kinds, mirroring ::llvm::Comdat::SelectionKind.
+def ComdatAny
+    : LLVM_EnumAttrCase<"Any", "any", "Any", 0>;
+def ComdatExactMatch
+    : LLVM_EnumAttrCase<"ExactMatch", "exactmatch", "ExactMatch", 1>;
+def ComdatLargest
+    : LLVM_EnumAttrCase<"Largest", "largest", "Largest", 2>;
+def ComdatNoDeduplicate
+    : LLVM_EnumAttrCase<"NoDeduplicate", "nodeduplicate", "NoDeduplicate", 3>;
+def ComdatSameSize
+    : LLVM_EnumAttrCase<"SameSize", "samesize", "SameSize", 4>;
+
+def ComdatEnum : LLVM_EnumAttr<
+    "Comdat",
+    "::llvm::Comdat::SelectionKind",
+    "LLVM Comdat Types",
+    [ComdatAny, ComdatExactMatch, ComdatLargest,
+     ComdatNoDeduplicate, ComdatSameSize]> {
+  let cppNamespace = "::mlir::LLVM::comdat";
+}
+
+def Comdat : DialectAttr<
+    LLVM_Dialect,
+    CPred<"$_self.isa<::mlir::LLVM::ComdatAttr>()">,
+    "LLVM Comdat selection kind"> {
+  let storageType = "::mlir::LLVM::ComdatAttr";
+  let returnType = "::mlir::LLVM::Comdat";
+  let convertFromStorage = "$_self.getComdat()";
+  let constBuilderCall =
+    "::mlir::LLVM::ComdatAttr::get($_builder.getContext(), $0)";
+}
+
 //===----------------------------------------------------------------------===//
 // UnnamedAddr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1316,6 +1316,7 @@
     DefaultValuedAttr, "0">:$addr_space,
     OptionalAttr:$unnamed_addr,
     OptionalAttr:$section,
+    OptionalAttr:$comdat,
     DefaultValuedAttr:$visibility_
   );
   let summary = "LLVM dialect global.";
@@ -1424,6 +1425,7 @@
       CArg<"unsigned", "0">:$addrSpace,
       CArg<"bool", "false">:$dsoLocal,
       CArg<"bool", "false">:$thread_local_,
+      CArg<"std::optional", "std::nullopt">:$comdat,
       CArg<"ArrayRef", "{}">:$attrs)>
   ];
 
@@ -1515,6 +1517,59 @@
   let hasVerifier = 1;
 }
 
+def LLVM_ComdatSelectorOp : LLVM_Op<"comdat_selector",
+    [Symbol, HasParent<"ComdatOp">]> {
+  let arguments = (ins
+    SymbolNameAttr:$sym_name,
+    Comdat:$comdat
+  );
+
+  let summary = "LLVM dialect comdat selector declaration";
+
+  let description = [{
+    Declares a COMDAT selector: a symbol paired with a selection kind (`any`,
+    `exactmatch`, `largest`, `nodeduplicate` or `samesize`). Selectors must be
+    nested directly in an `llvm.comdat` region, and globals refer to them
+    through a nested symbol reference.
+
+    Examples:
+    ```mlir
+    llvm.comdat @__llvm_comdat {
+      llvm.comdat_selector @any any
+    }
+    llvm.mlir.global internal constant @has_any_comdat(1 : i64) comdat(@__llvm_comdat::@any) : i64
+    ```
+  }];
+  let assemblyFormat = "$sym_name $comdat attr-dict";
+}
+
+def LLVM_ComdatOp : LLVM_Op<"comdat", [NoTerminator, NoRegionArguments, SymbolTable, Symbol]> {
+  let arguments = (ins
+    SymbolNameAttr:$sym_name
+  );
+  let summary = "LLVM dialect comdat region";
+
+  let description = [{
+    Provides access to object file COMDAT section/group functionality.
+
+    A symbol table whose single-block region may only contain
+    `llvm.comdat_selector` operations. When translated to LLVM IR, each
+    selector becomes a comdat named `<comdat>_<selector>`.
+
+    Examples:
+    ```mlir
+    llvm.comdat @__llvm_comdat {
+      llvm.comdat_selector @any any
+    }
+    llvm.mlir.global internal constant @has_any_comdat(1 : i64) comdat(@__llvm_comdat::@any) : i64
+    ```
+  }];
+  let regions = (region SizedRegion<1>:$body);
+
+  let assemblyFormat = "$sym_name $body attr-dict";
+  let hasRegionVerifier = 1;
+}
+
 def LLVM_LLVMFuncOp : LLVM_Op<"func", [
   AutomaticAllocationScope, IsolatedFromAbove, FunctionOpInterface,
   CallableOpInterface
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -274,6 +274,7 @@
   LogicalResult convertOperation(Operation &op, llvm::IRBuilderBase &builder);
   LogicalResult convertFunctionSignatures();
   LogicalResult convertFunctions();
+  LogicalResult convertComdats();
   LogicalResult convertGlobals();
   LogicalResult convertOneFunction(LLVMFuncOp func);
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -38,6 +38,7 @@
 
 #include
 #include
+#include
 
 using namespace mlir;
 using namespace mlir::LLVM;
@@ -1593,6 +1594,21 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Verifier for LLVM::ComdatOp.
+//===----------------------------------------------------------------------===//
+
+LogicalResult ComdatOp::verifyRegions() {
+  Region &body = getBody();
+  for (Operation &op : body.getOps()) {
+    if (!isa(op))
+      return emitOpError(
+          "only comdat selector symbols can appear in a comdat region");
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Builder, printer and verifier for LLVM::GlobalOp.
//===----------------------------------------------------------------------===//
@@ -1601,6 +1617,7 @@
                      bool isConstant, Linkage linkage, StringRef name,
                      Attribute value, uint64_t alignment, unsigned addrSpace,
                      bool dsoLocal, bool threadLocal,
+                     std::optional comdat,
                      ArrayRef attrs) {
   result.addAttribute(getSymNameAttrName(result.name),
                       builder.getStringAttr(name));
@@ -1616,6 +1633,8 @@
   if (threadLocal)
     result.addAttribute(getThreadLocal_AttrName(result.name),
                         builder.getUnitAttr());
+  if (comdat) // the attribute is omitted when no comdat reference is given
+    result.addAttribute(getComdatAttrName(result.name), *comdat);
 
   // Only add an alignment attribute if the "alignment" input
   // is different from 0. The value must also be a power of two, but
@@ -1652,6 +1671,9 @@
   if (auto value = getValueOrNull())
     p.printAttribute(value);
   p << ')';
+  if (auto cd = getComdat()) // prints e.g. ` comdat(@__llvm_comdat::@any)`
+    p << " comdat(" << *cd << ')';
+
   // Note that the alignment attribute is printed using the
   // default syntax here, even though it is an inherent attribute
   // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes)
@@ -1660,7 +1682,7 @@
                            getGlobalTypeAttrName(), getConstantAttrName(),
                            getValueAttrName(), getLinkageAttrName(),
                            getUnnamedAddrAttrName(), getThreadLocal_AttrName(),
-                           getVisibility_AttrName()});
+                           getVisibility_AttrName(), getComdatAttrName()});
 
   // Print the trailing type unless it's a string global.
if (llvm::dyn_cast_or_null(getValueOrNull()))
@@ -1768,6 +1790,15 @@
     return failure();
   }
 
+  if (succeeded(parser.parseOptionalKeyword("comdat"))) {
+    SymbolRefAttr comdat;
+    if (parser.parseLParen() || parser.parseAttribute(comdat) ||
+        parser.parseRParen())
+      return failure();
+
+    result.addAttribute(getComdatAttrName(result.name), comdat);
+  }
+
   SmallVector types;
   if (parser.parseOptionalAttrDict(result.attributes) ||
       parser.parseOptionalColonTypeList(types))
@@ -1850,6 +1881,12 @@
     }
   }
 
+  if (getComdat()) {
+    auto *op = SymbolTable::lookupNearestSymbolFrom(*this, getComdatAttr());
+    if (!llvm::isa_and_nonnull(op))
+      return emitOpError() << "expected comdat symbol";
+  }
+
   std::optional alignAttr = getAlignment();
   if (alignAttr.has_value()) {
     uint64_t value = alignAttr.value();
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -27,6 +27,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/RegionGraphTraits.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
 
@@ -714,6 +715,17 @@
                              : llvm::GlobalValue::NotThreadLocal,
         addrSpace);
 
+    if (auto comdat = op.getComdat()) {
+      // Must match the "<comdat>_<selector>" naming used by convertComdats().
+      std::string cname = comdat->getRootReference().getValue().str();
+      std::string sname = comdat->getLeafReference().getValue().str();
+      std::string name = cname + "_" + sname;
+      if (!llvmModule->getComdatSymbolTable().contains(name))
+        return emitError(op.getLoc(), "global references non-existent comdat");
+      llvm::Comdat *llvmComdat = llvmModule->getOrInsertComdat(name);
+      var->setComdat(llvmComdat);
+    }
+
     if (op.getUnnamedAddr().has_value())
       var->setUnnamedAddr(convertUnnamedAddrToLLVM(*op.getUnnamedAddr()));
 
@@ -1038,6 +1050,27 @@
   return success();
 }
 
+LogicalResult ModuleTranslation::convertComdats() {
+  llvm::Module *module = getLLVMModule();
+  for (auto comdatOp : getModuleBody(mlirModule).getOps<ComdatOp>()) {
+    std::string cname = comdatOp.getName().str();
+    for (auto selectorOp : comdatOp.getOps<ComdatSelectorOp>()) {
+      // Selectors are flattened to "<comdat>_<selector>"; convertGlobals()
+      // relies on the same scheme. The mapping is not injective (@a_b::@c and
+      // @a::@b_c both yield "a_b_c"), so diagnose a clash instead of silently
+      // overwriting the selection kind of the earlier comdat.
+      std::string name = cname + "_" + selectorOp.getName().str();
+      if (module->getComdatSymbolTable().contains(name))
+        return emitError(selectorOp.getLoc())
+               << "comdat selector maps to LLVM comdat name '" << name
+               << "', which is already in use";
+      llvm::Comdat *llvmComdat = module->getOrInsertComdat(name);
+      llvmComdat->setSelectionKind(convertComdatToLLVM(selectorOp.getComdat()));
+    }
+  }
+  return success();
+}
+
 LogicalResult ModuleTranslation::createAccessGroupMetadata() {
   return loopAnnotationTranslation->createAccessGroupMetadata();
 }
@@ -1369,6 +1402,8 @@
   ModuleTranslation translator(module, std::move(llvmModule));
   if (failed(translator.convertFunctionSignatures()))
     return nullptr;
+  if (failed(translator.convertComdats()))
+    return nullptr;
   if (failed(translator.convertGlobals()))
     return nullptr;
   if (failed(translator.createAccessGroupMetadata()))
@@ -1384,7 +1419,8 @@
   llvm::IRBuilder<> llvmBuilder(llvmContext);
   for (Operation &o : getModuleBody(module).getOperations()) {
     if (!isa(&o) &&
+             LLVM::GlobalDtorsOp, LLVM::MetadataOp, LLVM::ComdatOp,
+             LLVM::ComdatSelectorOp>(&o) &&
         !o.hasTrait() &&
         failed(translator.convertOperation(o, llvmBuilder))) {
       return nullptr;
diff --git a/mlir/test/Dialect/LLVMIR/comdat.mlir b/mlir/test/Dialect/LLVMIR/comdat.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/comdat.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK: llvm.comdat @__llvm_comdat
+llvm.comdat @__llvm_comdat {
+  // CHECK: llvm.comdat_selector @any_comdat any
+  llvm.comdat_selector @any_comdat any
+  // CHECK: llvm.comdat_selector @exactmatch_comdat exactmatch
+  llvm.comdat_selector @exactmatch_comdat exactmatch
+  // CHECK: llvm.comdat_selector @largest_comdat largest
+  llvm.comdat_selector @largest_comdat largest
+  // CHECK: llvm.comdat_selector @nodeduplicate_comdat nodeduplicate
+  llvm.comdat_selector @nodeduplicate_comdat nodeduplicate
+  // CHECK: llvm.comdat_selector @samesize_comdat samesize
+
llvm.comdat_selector @samesize_comdat samesize
+}
diff --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir
--- a/mlir/test/Dialect/LLVMIR/global.mlir
+++ b/mlir/test/Dialect/LLVMIR/global.mlir
@@ -64,6 +64,14 @@
 // CHECK: llvm.mlir.global external @has_addr_space(32 : i64) {addr_space = 3 : i32} : i64
 llvm.mlir.global external @has_addr_space(32 : i64) {addr_space = 3: i32} : i64
 
+// CHECK: llvm.comdat @__llvm_comdat
+llvm.comdat @__llvm_comdat {
+  // CHECK: llvm.comdat_selector @any any
+  llvm.comdat_selector @any any
+}
+// CHECK: llvm.mlir.global external @any() comdat(@__llvm_comdat::@any) {addr_space = 0 : i32} : i64
+llvm.mlir.global @any() comdat(@__llvm_comdat::@any) : i64
+
 // CHECK-LABEL: references
 func.func @references() {
   // CHECK: llvm.mlir.addressof @".string" : !llvm.ptr
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1,5 +1,21 @@
 // RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+// Comdat sections
+llvm.comdat @__llvm_comdat {
+  // CHECK: $__llvm_comdat_any = comdat any
+  llvm.comdat_selector @any any
+  // CHECK: $__llvm_comdat_exactmatch = comdat exactmatch
+  llvm.comdat_selector @exactmatch exactmatch
+  // CHECK: $__llvm_comdat_largest = comdat largest
+  llvm.comdat_selector @largest largest
+  // CHECK: $__llvm_comdat_nodeduplicate = comdat nodeduplicate
+  llvm.comdat_selector @nodeduplicate nodeduplicate
+  // CHECK: $__llvm_comdat_samesize = comdat samesize
+  llvm.comdat_selector @samesize samesize
+}
+
+
 
 // CHECK: @global_aligned32 = private global i64 42, align 32
 "llvm.mlir.global"() ({}) {sym_name = "global_aligned32", global_type = i64, value = 42 : i64, linkage = #llvm.linkage, alignment = 32} : () -> ()
 
@@ -162,6 +178,20 @@
 // CHECK: @sectionvar = internal constant [10 x i8] c"teststring", section ".mysection"
 llvm.mlir.global
internal constant @sectionvar("teststring") {section = ".mysection"}: !llvm.array<10 x i8>
 
+//
+// Globals attached to the @__llvm_comdat selectors via `comdat(...)`.
+//
+// CHECK: @has_any_comdat = internal constant i64 1, comdat($__llvm_comdat_any)
+llvm.mlir.global internal constant @has_any_comdat(1 : i64) comdat(@__llvm_comdat::@any) : i64
+// CHECK: @has_exactmatch_comdat = internal constant i64 1, comdat($__llvm_comdat_exactmatch)
+llvm.mlir.global internal constant @has_exactmatch_comdat(1 : i64) comdat(@__llvm_comdat::@exactmatch) : i64
+// CHECK: @has_largest_comdat = internal constant i64 1, comdat($__llvm_comdat_largest)
+llvm.mlir.global internal constant @has_largest_comdat(1 : i64) comdat(@__llvm_comdat::@largest) : i64
+// CHECK: @has_nodeduplicate_comdat = internal constant i64 1, comdat($__llvm_comdat_nodeduplicate)
+llvm.mlir.global internal constant @has_nodeduplicate_comdat(1 : i64) comdat(@__llvm_comdat::@nodeduplicate) : i64
+// CHECK: @has_samesize_comdat = internal constant i64 1, comdat($__llvm_comdat_samesize)
+llvm.mlir.global internal constant @has_samesize_comdat(1 : i64) comdat(@__llvm_comdat::@samesize) : i64
+
 //
 // Declarations of the allocation functions to be linked against. These are
 // inserted before other functions in the module.