diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp --- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp @@ -1040,9 +1040,10 @@ /*IsTargetCodegen*/ false, CGM.getLangOpts().OpenMPOffloadMandatory, /*HasRequiresReverseOffload*/ false, /*HasRequiresUnifiedAddress*/ false, hasRequiresUnifiedSharedMemory(), /*HasRequiresDynamicAllocators*/ false); - OMPBuilder.initialize(CGM.getLangOpts().OpenMPIsDevice - ? CGM.getLangOpts().OMPHostIRFile - : StringRef{}); + OMPBuilder.initialize(); + OMPBuilder.loadOffloadInfoMetadata(CGM.getLangOpts().OpenMPIsDevice + ? CGM.getLangOpts().OMPHostIRFile + : StringRef{}); OMPBuilder.setConfig(Config); } diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -159,6 +159,7 @@ void setIsEmbedded(bool Value) { IsEmbedded = Value; } void setIsTargetCodegen(bool Value) { IsTargetCodegen = Value; } + void setOpenMPOffloadMandatory(bool Value) { OpenMPOffloadMandatory = Value; } void setFirstSeparator(StringRef FS) { FirstSeparator = FS; } void setSeparator(StringRef S) { Separator = S; } @@ -425,14 +426,9 @@ /// Initialize the internal state, this will put structures types and /// potentially other helpers into the underlying module. Must be called - /// before any other method and only once! This internal state includes - /// Types used in the OpenMPIRBuilder generated from OMPKinds.def as well - /// as loading offload metadata for device from the OpenMP host IR file - /// passed in as the HostFilePath argument. - /// \param HostFilePath The path to the host IR file, used to load in - /// offload metadata for the device, allowing host and device to - /// maintain the same metadata mapping. - void initialize(StringRef HostFilePath = {}); + /// before any other method and only once!
This internal state includes types + /// used in the OpenMPIRBuilder generated from OMPKinds.def. + void initialize(); void setConfig(OpenMPIRBuilderConfig C) { Config = C; } @@ -2243,6 +2239,15 @@ /// loaded from bitcode file, i.e, different from OpenMPIRBuilder::M module. void loadOffloadInfoMetadata(Module &M); + /// Loads all the offload entries information from the host IR + /// metadata read from the file passed in as the HostFilePath argument. This + /// function is only meant for device codegen; an empty path is a no-op. + /// + /// \param HostFilePath The path to the host IR file, + /// used to load in offload metadata for the device, allowing host and device + /// to maintain the same metadata mapping. + void loadOffloadInfoMetadata(StringRef HostFilePath); + /// Gets (if variable with the given name already exist) or creates /// internal global variable with the specified Name. The created variable has /// linkage CommonLinkage by default and is initialized by null value. diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -546,31 +546,7 @@ return Fn; } -void OpenMPIRBuilder::initialize(StringRef HostFilePath) { - initializeTypes(M); - - if (HostFilePath.empty()) - return; - - auto Buf = MemoryBuffer::getFile(HostFilePath); - if (std::error_code Err = Buf.getError()) { - report_fatal_error(("error opening host file from host file path inside of " - "OpenMPIRBuilder: " + - Err.message()) - .c_str()); - } - - LLVMContext Ctx; - auto M = expectedToErrorOrAndEmitErrors( - Ctx, parseBitcodeFile(Buf.get()->getMemBufferRef(), Ctx)); - if (std::error_code Err = M.getError()) { - report_fatal_error( - ("error parsing host file inside of OpenMPIRBuilder: " + Err.message()) - .c_str()); - } - - loadOffloadInfoMetadata(*M.get()); -} +void OpenMPIRBuilder::initialize() { initializeTypes(M); } void
OpenMPIRBuilder::finalize(Function *Fn) { SmallPtrSet ParallelRegionBlockSet; @@ -5337,6 +5313,30 @@ } } +void OpenMPIRBuilder::loadOffloadInfoMetadata(StringRef HostFilePath) { + if (HostFilePath.empty()) + return; + + auto Buf = MemoryBuffer::getFile(HostFilePath); + if (std::error_code Err = Buf.getError()) { + report_fatal_error(("error opening host file from host file path inside of " + "OpenMPIRBuilder: " + + Err.message()) + .c_str()); + } + + LLVMContext Ctx; + auto HostModule = expectedToErrorOrAndEmitErrors( + Ctx, parseBitcodeFile(Buf.get()->getMemBufferRef(), Ctx)); + if (std::error_code Err = HostModule.getError()) { + report_fatal_error( + ("error parsing host file inside of OpenMPIRBuilder: " + Err.message()) + .c_str()); + } + + loadOffloadInfoMetadata(*HostModule.get()); +} + Function *OpenMPIRBuilder::createRegisterRequires(StringRef Name) { // Skip the creation of the registration function if this is device codegen if (Config.isEmbedded()) diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -16,6 +16,7 @@ #include "mlir/IR/IRMapping.h" #include "mlir/IR/Operation.h" #include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" #include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "mlir/Transforms/RegionUtils.h" @@ -27,6 +28,8 @@ #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/IRBuilder.h" #include "llvm/Support/FileSystem.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" +#include using namespace mlir; @@ -1040,7 +1043,7 @@ } /// Convert an Atomic Ordering attribute to llvm::AtomicOrdering. -llvm::AtomicOrdering
+static llvm::AtomicOrdering convertAtomicOrdering(std::optional ao) { if (!ao) return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering @@ -1675,6 +1678,27 @@ return bodyGenStatus; } +/// Converts the module-level set of OpenMP requires clauses into LLVM IR using +/// OpenMPIRBuilder. +static LogicalResult +convertRequiresAttr(Operation &op, omp::ClauseRequiresAttr requiresAttr, + LLVM::ModuleTranslation &moduleTranslation) { + auto *ompBuilder = moduleTranslation.getOpenMPBuilder(); + + // No need to read requiresAttr here: the "omp.requires" case of + // amendOperation() has already stored the clause flags in the + // OpenMPIRBuilderConfig object, available to the OpenMPIRBuilder. + auto *regFn = + ompBuilder->createRegisterRequires(ompBuilder->createPlatformSpecificName( + {"omp_offloading", "requires_reg"})); + + // Add registration function (if one was created) as global constructor + if (regFn) + llvm::appendToGlobalCtors(ompBuilder->M, regFn, /* Priority = */ 0); + + return success(); +} + namespace { /// Implementation of the dialect interface that converts operations belonging @@ -1690,6 +1714,8 @@ convertOperation(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) const final; + /// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR, runtime + /// calls, or operation amendments. LogicalResult amendOperation(Operation *op, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const final; @@ -1697,28 +1723,72 @@ } // namespace -/// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR, runtime -/// calls, or operation amendments LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation( Operation *op, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const { - - return llvm::TypeSwitch(attribute.getValue()) - .Case([&](mlir::omp::FlagsAttr rtlAttr) { - return convertFlagsAttr(op, rtlAttr, moduleTranslation); - }) -
.Case([&](mlir::omp::VersionAttr versionAttr) { - llvm::OpenMPIRBuilder *ompBuilder = - moduleTranslation.getOpenMPBuilder(); - ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp", - versionAttr.getVersion()); - return success(); - }) - .Default([&](Attribute attr) { - // fall through for omp attributes that do not require lowering and/or - // have no concrete definition and thus no type to define a case on + return llvm::StringSwitch>( + attribute.getName()) + .Case("omp.is_device", + [&](Attribute attr) { + if (auto deviceAttr = attr.dyn_cast()) { + llvm::OpenMPIRBuilderConfig &config = + moduleTranslation.getOpenMPBuilder()->Config; + config.setIsEmbedded(deviceAttr.getValue()); + return success(); + } + return failure(); + }) + .Case("omp.host_ir_filepath", + [&](Attribute attr) { + if (auto filepathAttr = attr.dyn_cast()) { + llvm::OpenMPIRBuilder *ompBuilder = + moduleTranslation.getOpenMPBuilder(); + ompBuilder->loadOffloadInfoMetadata(filepathAttr.getValue()); + return success(); + } + return failure(); + }) + .Case("omp.flags", + [&](Attribute attr) { + if (auto rtlAttr = attr.dyn_cast()) + return convertFlagsAttr(op, rtlAttr, moduleTranslation); + return failure(); + }) + .Case("omp.version", + [&](Attribute attr) { + if (auto versionAttr = attr.dyn_cast()) { + llvm::OpenMPIRBuilder *ompBuilder = + moduleTranslation.getOpenMPBuilder(); + ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp", + versionAttr.getVersion()); + return success(); + } + return failure(); + }) + .Case( + "omp.requires", + [&](Attribute attr) { + if (auto requiresAttr = attr.dyn_cast()) { + using Requires = omp::ClauseRequires; + Requires flags = requiresAttr.getValue(); + llvm::OpenMPIRBuilderConfig &config = + moduleTranslation.getOpenMPBuilder()->Config; + config.setHasRequiresReverseOffload( + bitEnumContainsAll(flags, Requires::reverse_offload)); + config.setHasRequiresUnifiedAddress( + bitEnumContainsAll(flags, Requires::unified_address)); +
config.setHasRequiresUnifiedSharedMemory( + bitEnumContainsAll(flags, Requires::unified_shared_memory)); + config.setHasRequiresDynamicAllocators( + bitEnumContainsAll(flags, Requires::dynamic_allocators)); + return convertRequiresAttr(*op, requiresAttr, moduleTranslation); + } + return failure(); + }) + .Default([](Attribute) { + // Fall through for omp attributes that do not require lowering. return success(); - }); + })(attribute.getValue()); return failure(); } diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -20,8 +20,6 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h" #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h" -#include "mlir/Dialect/OpenMP/OpenMPDialect.h" -#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" @@ -1273,30 +1271,18 @@ llvm::OpenMPIRBuilder *ModuleTranslation::getOpenMPBuilder() { if (!ompBuilder) { ompBuilder = std::make_unique(*llvmModule); + ompBuilder->initialize(); - bool isDevice = false; - llvm::StringRef hostIRFilePath = ""; - - if (Attribute deviceAttr = mlirModule->getAttr("omp.is_device")) - if (::llvm::isa(deviceAttr)) - isDevice = ::llvm::dyn_cast(deviceAttr).getValue(); - - if (Attribute filepath = mlirModule->getAttr("omp.host_ir_filepath")) - if (::llvm::isa(filepath)) - hostIRFilePath = - ::llvm::dyn_cast(filepath).getValue(); - - ompBuilder->initialize(hostIRFilePath); - - // TODO: set the flags when available - llvm::OpenMPIRBuilderConfig config( - isDevice, /* IsTargetCodegen */ false, + // Flags represented as top-level OpenMP dialect attributes are set in + // OpenMPDialectLLVMIRTranslationInterface::amendOperation(). Here we set + // the default configuration.
+ ompBuilder->setConfig(llvm::OpenMPIRBuilderConfig( + /* IsEmbedded */ false, /* IsTargetCodegen */ false, /* OpenMPOffloadMandatory */ false, /* HasRequiresReverseOffload */ false, /* HasRequiresUnifiedAddress */ false, /* HasRequiresUnifiedSharedMemory */ false, - /* OpenMPOffloadMandatory */ false); - ompBuilder->setConfig(config); + /* HasRequiresDynamicAllocators */ false)); } return ompBuilder.get(); } @@ -1383,11 +1369,17 @@ return nullptr; if (failed(translator.createTBAAMetadata())) return nullptr; + + // Convert the module itself before any functions and operations inside, so + // that the OpenMPIRBuilder is configured from the OpenMP dialect attributes + // attached to the module (via amendOperation()) before it is used. + llvm::IRBuilder<> llvmBuilder(llvmContext); + if (failed(translator.convertOperation(*module, llvmBuilder))) + return nullptr; if (failed(translator.convertFunctions())) return nullptr; // Convert other top-level operations if possible. - llvm::IRBuilder<> llvmBuilder(llvmContext); for (Operation &o : getModuleBody(module).getOperations()) { if (!isa(&o) && @@ -1397,10 +1389,6 @@ } } - // Convert module itself. - if (failed(translator.convertOperation(*module, llvmBuilder))) - return nullptr; - if (llvm::verifyModule(*translator.llvmModule, &llvm::errs())) return nullptr; diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -2543,3 +2543,13 @@ // CHECK: @__omp_rtl_assume_no_thread_state = weak_odr hidden constant i32 1 // CHECK: @__omp_rtl_assume_no_nested_parallelism = weak_odr hidden constant i32 0 module attributes {omp.flags = #omp.flags} {} + +// ----- + +// Check that OpenMP requires flags are registered by a global constructor.
+// CHECK: @llvm.global_ctors = appending global [1 x { i32, ptr, ptr }] +// CHECK-SAME: [{ i32, ptr, ptr } { i32 0, ptr @[[REG_FN:.*]], ptr null }] +// CHECK: define {{.*}} @[[REG_FN]]({{.*}}) +// CHECK-NOT: } +// CHECK: call void @__tgt_register_requires(i64 10) +module attributes {omp.requires = #omp} {}