[MLIR][OpenMP] Lower the module-level omp.requires attribute to LLVM IR

Summary of this patch:
  * OpenMPIRBuilderConfig gains a setOpenMPOffloadMandatory() setter.
  * The OpenMP dialect translation handles omp::ClauseRequiresAttr in
    amendOperation(): it asks the OpenMPIRBuilder for a registration function
    (createRegisterRequires) and, when one is returned, appends it to the
    module's global constructors.
  * ModuleTranslation fills the "requires" flags of the OpenMPIRBuilderConfig
    from the module's offload interface. A configuration is installed for
    every module; modules without offload information get the defaults.
  * A lit test checks that __tgt_register_requires is called from a global
    constructor.

NOTE(review): some angle-bracket text appears to be missing from this patch
(e.g. `std::optional ao`, `cast(opInst)`, `dyn_cast(mlirModule)`, one bare
`#include`, `#omp.flags`, `#omp`). It is reproduced as found; indentation and
blank context lines are reconstructed. Confirm against the tree before
applying.

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -159,6 +159,7 @@
 
   void setIsEmbedded(bool Value) { IsEmbedded = Value; }
   void setIsTargetCodegen(bool Value) { IsTargetCodegen = Value; }
+  void setOpenMPOffloadMandatory(bool Value) { OpenMPOffloadMandatory = Value; }
   void setFirstSeparator(StringRef FS) { FirstSeparator = FS; }
   void setSeparator(StringRef S) { Separator = S; }
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -16,6 +16,7 @@
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
 #include "mlir/Transforms/RegionUtils.h"
@@ -27,6 +28,8 @@
 #include "llvm/IR/DebugInfoMetadata.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/Support/FileSystem.h"
+#include "llvm/Transforms/Utils/ModuleUtils.h"
+#include
 
 using namespace mlir;
 
@@ -1035,7 +1038,7 @@
 }
 
 /// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
-llvm::AtomicOrdering
+static llvm::AtomicOrdering
 convertAtomicOrdering(std::optional ao) {
   if (!ao)
     return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering
@@ -1059,7 +1062,6 @@
 static LogicalResult
 convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
-
   auto readOp = cast(opInst);
 
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
@@ -1678,6 +1680,27 @@
   return bodyGenStatus;
 }
 
+/// Converts the module-level set of OpenMP requires clauses into LLVM IR using
+/// OpenMPIRBuilder.
+static LogicalResult
+convertRequiresAttr(Operation &op, omp::ClauseRequiresAttr requiresAttr,
+                    LLVM::ModuleTranslation &moduleTranslation) {
+  auto *ompBuilder = moduleTranslation.getOpenMPBuilder();
+
+  // No need to read requiresAttr here, because it has already been done in
+  // translateModuleToLLVMIR(). There, flags are stored in the
+  // OpenMPIRBuilderConfig object, available to the OpenMPIRBuilder.
+  auto *regFn =
+      ompBuilder->createRegisterRequires(ompBuilder->createPlatformSpecificName(
+          {"omp_offloading", "requires_reg"}));
+
+  // Add registration function as global constructor.
+  if (regFn)
+    llvm::appendToGlobalCtors(ompBuilder->M, regFn, /* Priority = */ 0);
+
+  return success();
+}
+
 namespace {
 
 /// Implementation of the dialect interface that converts operations belonging
@@ -1693,6 +1716,8 @@
   convertOperation(Operation *op, llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation) const final;
 
+  /// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR, runtime
+  /// calls, or operation amendments
   LogicalResult
   amendOperation(Operation *op, NamedAttribute attribute,
                  LLVM::ModuleTranslation &moduleTranslation) const final;
@@ -1700,8 +1725,6 @@
 
 } // namespace
 
-/// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR, runtime
-/// calls, or operation amendments
 LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
     Operation *op, NamedAttribute attribute,
     LLVM::ModuleTranslation &moduleTranslation) const {
@@ -1710,9 +1733,12 @@
       .Case([&](mlir::omp::FlagsAttr rtlAttr) {
         return convertFlagsAttr(op, rtlAttr, moduleTranslation);
       })
+      .Case([&](omp::ClauseRequiresAttr requiresAttr) {
+        return convertRequiresAttr(*op, requiresAttr, moduleTranslation);
+      })
       .Default([&](Attribute attr) {
-        // fall through for omp attributes that do not require lowering and/or
-        // have no concrete definition and thus no type to define a case on
+        // Fall through for omp attributes that do not require lowering and/or
+        // have no concrete definition and thus no type to define a case on.
         return success();
       });
 
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1261,20 +1261,30 @@
     ompBuilder = std::make_unique(*llvmModule);
     ompBuilder->initialize();
 
-    bool isDevice = false;
-    if (auto offloadMod =
-            dyn_cast(mlirModule))
-      isDevice = offloadMod.getIsDevice();
-
-    // TODO: set the flags when available
-    llvm::OpenMPIRBuilderConfig Config(
-        isDevice, /* IsTargetCodegen */ false,
-        /* OpenMPOffloadMandatory */ false,
-        /* HasRequiresReverseOffload */ false,
-        /* HasRequiresUnifiedAddress */ false,
-        /* HasRequiresUnifiedSharedMemory */ false,
-        /* HasRequiresDynamicAllocators */ false);
-    ompBuilder->setConfig(Config);
+    // Set OpenMP IR Builder configuration. Host defaults without any requires
+    // clauses apply unless the module provides offload information, so the
+    // builder always receives a configuration.
+    llvm::OpenMPIRBuilderConfig config(
+        /* IsEmbedded */ false, /* IsTargetCodegen */ false,
+        /* OpenMPOffloadMandatory */ false,
+        /* HasRequiresReverseOffload */ false,
+        /* HasRequiresUnifiedAddress */ false,
+        /* HasRequiresUnifiedSharedMemory */ false,
+        /* HasRequiresDynamicAllocators */ false);
+    if (auto offloadMod = dyn_cast(mlirModule)) {
+      config.setIsEmbedded(offloadMod.getIsDevice());
+
+      const auto requiresFlags = offloadMod.getRequires();
+      config.setHasRequiresReverseOffload(bitEnumContainsAll(
+          requiresFlags, omp::ClauseRequires::reverse_offload));
+      config.setHasRequiresUnifiedAddress(bitEnumContainsAll(
+          requiresFlags, omp::ClauseRequires::unified_address));
+      config.setHasRequiresUnifiedSharedMemory(bitEnumContainsAll(
+          requiresFlags, omp::ClauseRequires::unified_shared_memory));
+      config.setHasRequiresDynamicAllocators(bitEnumContainsAll(
+          requiresFlags, omp::ClauseRequires::dynamic_allocators));
+    }
+    ompBuilder->setConfig(config);
   }
   return ompBuilder.get();
 }
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -2493,3 +2493,13 @@
 // CHECK: @__omp_rtl_assume_no_thread_state = weak_odr hidden constant i32 1
 // CHECK: @__omp_rtl_assume_no_nested_parallelism = weak_odr hidden constant i32 0
 module attributes {omp.flags = #omp.flags} {}
+
+// -----
+
+// Check that OpenMP requires flags are registered by a global constructor.
+// CHECK: @llvm.global_ctors = appending global [1 x { i32, ptr, ptr }]
+// CHECK-SAME: [{ i32, ptr, ptr } { i32 0, ptr @[[REG_FN:.*]], ptr null }]
+// CHECK: define {{.*}} @[[REG_FN]]({{.*}})
+// CHECK-NOT: }
+// CHECK: call void @__tgt_register_requires(i64 10)
+module attributes {omp.requires = #omp} {}