diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1153,6 +1153,68 @@
   let verifier = "return ::verify(*this);";
 }
 
+def LLVM_GlobalCtorsOp : LLVM_Op<"mlir.global_ctors", [
+    DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let arguments = (ins
+    FlatSymbolRefArrayAttr:$ctors,
+    I32ArrayAttr:$priorities
+  );
+  let summary = "LLVM dialect global_ctors.";
+  let description = [{
+    Specifies a list of constructor functions and priorities. The functions
+    referenced by `ctors` will be called in ascending order of priority (i.e.
+    lowest first) when the module is loaded. The order of functions with the
+    same priority is not defined. `ctors` and `priorities` must have the same
+    number of elements, and every referenced function must have a body. This
+    operation is translated to LLVM's global_ctors global variable. The `data`
+    field present in LLVM's global_ctors variable is not modeled here.
+
+    Examples:
+
+    ```mlir
+    llvm.mlir.global_ctors {ctors = [@ctor], priorities = [0 : i32]}
+
+    llvm.func @ctor() {
+      ...
+      llvm.return
+    }
+    ```
+
+  }];
+  let verifier = [{ return ::verify(*this); }];
+  let assemblyFormat = "attr-dict";
+}
+
+def LLVM_GlobalDtorsOp : LLVM_Op<"mlir.global_dtors", [
+    DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let arguments = (ins
+    FlatSymbolRefArrayAttr:$dtors,
+    I32ArrayAttr:$priorities
+  );
+  let summary = "LLVM dialect global_dtors.";
+  let description = [{
+    Specifies a list of destructor functions and priorities. The functions
+    referenced by `dtors` will be called in descending order of priority (i.e.
+    highest first) when the module is unloaded. The order of functions with the
+    same priority is not defined. `dtors` and `priorities` must have the same
+    number of elements, and every referenced function must have a body. This
+    operation is translated to LLVM's global_dtors global variable. The `data`
+    field present in LLVM's global_dtors variable is not modeled here.
+
+    Examples:
+
+    ```mlir
+    llvm.func @dtor() {
+      llvm.return
+    }
+    llvm.mlir.global_dtors {dtors = [@dtor], priorities = [0 : i32]}
+    ```
+
+  }];
+  let verifier = [{ return ::verify(*this); }];
+  let assemblyFormat = "attr-dict";
+}
+
 def LLVM_LLVMFuncOp : LLVM_Op<"func",
     [AutomaticAllocationScope, IsolatedFromAbove, FunctionLike, Symbol]> {
   let summary = "LLVM dialect function.";
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1602,6 +1602,11 @@
   let constBuilderCall = ?;
 }
 
+def FlatSymbolRefArrayAttr :
+  TypedArrayAttrBase<FlatSymbolRefAttr, "flat symbol ref array attribute"> {
+  let constBuilderCall = ?;
+}
+
 //===----------------------------------------------------------------------===//
 // Derive attribute kinds
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1623,6 +1623,63 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// LLVM::GlobalCtorsOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that `symbol`, used by `op`, refers to a defined LLVM function.
+static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol,
+                                         Operation *op,
+                                         SymbolTableCollection &symbolTable) {
+  StringRef name = symbol.getValue();
+  auto func =
+      symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr());
+  if (!func)
+    return op->emitOpError("'")
+           << name << "' does not reference a valid LLVM function";
+  if (func.isExternal())
+    return op->emitOpError("'") << name << "' does not have a definition";
+  return success();
+}
+
+/// Verifies that every entry of `ctors` names a defined LLVM function.
+LogicalResult
+GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  for (Attribute ctor : ctors())
+    if (failed(verifySymbolAttrUse(ctor.cast<FlatSymbolRefAttr>(), *this,
+                                   symbolTable)))
+      return failure();
+  return success();
+}
+
+static LogicalResult verify(GlobalCtorsOp op) {
+  if (op.ctors().size() != op.priorities().size())
+    return op.emitOpError(
+        "mismatch between the number of ctors and the number of priorities");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// LLVM::GlobalDtorsOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that every entry of `dtors` names a defined LLVM function.
+LogicalResult
+GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  for (Attribute dtor : dtors())
+    if (failed(verifySymbolAttrUse(dtor.cast<FlatSymbolRefAttr>(), *this,
+                                   symbolTable)))
+      return failure();
+  return success();
+}
+
+static LogicalResult verify(GlobalDtorsOp op) {
+  if (op.dtors().size() != op.priorities().size())
+    return op.emitOpError(
+        "mismatch between the number of dtors and the number of priorities");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Printing/parsing for LLVM::ShuffleVectorOp.
//===----------------------------------------------------------------------===// @@ -2351,7 +2408,7 @@ op->hasTrait(); } -static constexpr const FastmathFlags FastmathFlagsList[] = { +static constexpr const FastmathFlags fastmathFlagsList[] = { // clang-format off FastmathFlags::nnan, FastmathFlags::ninf, @@ -2366,7 +2423,7 @@ void FMFAttr::print(DialectAsmPrinter &printer) const { printer << "fastmath<"; - auto flags = llvm::make_filter_range(FastmathFlagsList, [&](auto flag) { + auto flags = llvm::make_filter_range(fastmathFlagsList, [&](auto flag) { return bitEnumContains(this->getFlags(), flag); }); llvm::interleaveComma(flags, printer, diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -42,6 +42,7 @@ #include "llvm/IR/Verifier.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" using namespace mlir; using namespace mlir::LLVM; @@ -553,7 +554,7 @@ } /// Create named global variables that correspond to llvm.mlir.global -/// definitions. +/// definitions. Convert llvm.global_ctors and global_dtors ops. LogicalResult ModuleTranslation::convertGlobals() { for (auto op : getModuleBody(mlirModule).getOps()) { llvm::Type *type = convertType(op.getType()); @@ -622,6 +623,26 @@ } } + // Convert llvm.mlir.global_ctors and dtors. + for (Operation &op : getModuleBody(mlirModule)) { + auto ctorOp = dyn_cast(op); + auto dtorOp = dyn_cast(op); + if (!ctorOp && !dtorOp) + continue; + auto range = ctorOp ? llvm::zip(ctorOp.ctors(), ctorOp.priorities()) + : llvm::zip(dtorOp.dtors(), dtorOp.priorities()); + auto appendGlobalFn = + ctorOp ? 
llvm::appendToGlobalCtors : llvm::appendToGlobalDtors; + for (auto symbolAndPriority : range) { + llvm::Function *f = lookupFunction( + std::get<0>(symbolAndPriority).cast().getValue()); + appendGlobalFn( + *llvmModule.get(), f, + std::get<1>(symbolAndPriority).cast().getInt(), + /*Data=*/nullptr); + } + } + return success(); } @@ -1025,7 +1046,8 @@ // Convert other top-level operations if possible. llvm::IRBuilder<> llvmBuilder(llvmContext); for (Operation &o : getModuleBody(module).getOperations()) { - if (!isa(&o) && + if (!isa(&o) && !o.hasTrait() && failed(translator.convertOperation(o, llvmBuilder))) { return nullptr; diff --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir --- a/mlir/test/Dialect/LLVMIR/global.mlir +++ b/mlir/test/Dialect/LLVMIR/global.mlir @@ -209,3 +209,21 @@ // expected-error @+1 {{op the type must be a pointer to the type of the referenced global}} llvm.mlir.addressof @g : !llvm.ptr } + +// ----- + +llvm.func @ctor() { + llvm.return +} + +// CHECK: llvm.mlir.global_ctors {ctors = [@ctor], priorities = [0 : i32]} +llvm.mlir.global_ctors { ctors = [@ctor], priorities = [0 : i32]} + +// ----- + +llvm.func @dtor() { + llvm.return +} + +// CHECK: llvm.mlir.global_dtors {dtors = [@dtor], priorities = [0 : i32]} +llvm.mlir.global_dtors { dtors = [@dtor], priorities = [0 : i32]} diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -5,6 +5,36 @@ // ----- +llvm.func @ctor() { + llvm.return +} + +// expected-error@+1{{mismatch between the number of ctors and the number of priorities}} +llvm.mlir.global_ctors {ctors = [@ctor], priorities = []} + +// ----- + +llvm.func @dtor() { + llvm.return +} + +// expected-error@+1{{mismatch between the number of dtors and the number of priorities}} +llvm.mlir.global_dtors {dtors = [@dtor], priorities = [0 : i32, 32767 : i32]} + +// ----- + +// 
expected-error@+1{{'ctor' does not reference a valid LLVM function}} +llvm.mlir.global_ctors {ctors = [@ctor], priorities = [0 : i32]} + +// ----- + +llvm.func @dtor() + +// expected-error@+1{{'dtor' does not have a definition}} +llvm.mlir.global_dtors {dtors = [@dtor], priorities = [0 : i32]} + +// ----- + // expected-error@+1{{expected llvm.noalias argument attribute to be a unit attribute}} func @invalid_noalias(%arg0: i32 {llvm.noalias = 3}) { "llvm.return"() : () -> () diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -1381,6 +1381,24 @@ // ----- +// CHECK: @llvm.global_ctors = appending global [1 x { i32, void ()*, i8* }] [{ i32, void ()*, i8* } { i32 0, void ()* @foo, i8* null }] +llvm.mlir.global_ctors { ctors = [@foo], priorities = [0 : i32]} + +llvm.func @foo() { + llvm.return +} + +// ----- + +// CHECK: @llvm.global_dtors = appending global [1 x { i32, void ()*, i8* }] [{ i32, void ()*, i8* } { i32 0, void ()* @foo, i8* null }] +llvm.mlir.global_dtors { dtors = [@foo], priorities = [0 : i32]} + +llvm.func @foo() { + llvm.return +} + +// ----- + // Check that branch weight attributes are exported properly as metadata. llvm.func @cond_br_weights(%cond : i1, %arg0 : i32, %arg1 : i32) -> i32 { // CHECK: !prof ![[NODE:[0-9]+]]