diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt --- a/mlir/lib/Target/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt @@ -38,6 +38,7 @@ MLIRDLTIDialect MLIRLLVMDialect MLIRLLVMIRTransforms + MLIROpenMPDialect MLIRTranslateLib ) diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/IRMapping.h" #include "mlir/IR/Operation.h" #include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" #include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" @@ -24,6 +25,8 @@ #include "llvm/Frontend/OpenMP/OMPIRBuilder.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" +#include using namespace mlir; @@ -1031,11 +1034,24 @@ return success(); } +static std::optional +getAtomicDefaultMemOrder(Operation &opInst) { + // Try to get the omp.atomic_default_mem_order attribute, if present + if (auto offloadModule = + opInst.getParentOfType()) + return offloadModule.getAtomicDefaultMemOrder(); + + return std::nullopt; +} + /// Convert an Atomic Ordering attribute to llvm::AtomicOrdering. 
-llvm::AtomicOrdering -convertAtomicOrdering(std::optional ao) { +static llvm::AtomicOrdering +convertAtomicOrdering(std::optional ao, + std::optional defaultAo) { + // If not specified, try using the default atomic ordering gathered from a + // requires atomic_default_mem_order clause, if present if (!ao) - return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering + ao = defaultAo.value_or(omp::ClauseMemoryOrderKind::Relaxed); switch (*ao) { case omp::ClauseMemoryOrderKind::Seq_cst: @@ -1056,13 +1072,14 @@ static LogicalResult convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { - auto readOp = cast(opInst); llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); - llvm::AtomicOrdering AO = convertAtomicOrdering(readOp.getMemoryOrderVal()); + auto defaultAO = getAtomicDefaultMemOrder(opInst); + llvm::AtomicOrdering AO = + convertAtomicOrdering(readOp.getMemoryOrderVal(), defaultAO); llvm::Value *x = moduleTranslation.lookupValue(readOp.getX()); llvm::Value *v = moduleTranslation.lookupValue(readOp.getV()); @@ -1083,7 +1100,9 @@ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); - llvm::AtomicOrdering ao = convertAtomicOrdering(writeOp.getMemoryOrderVal()); + auto defaultAO = getAtomicDefaultMemOrder(opInst); + llvm::AtomicOrdering ao = + convertAtomicOrdering(writeOp.getMemoryOrderVal(), defaultAO); llvm::Value *expr = moduleTranslation.lookupValue(writeOp.getValue()); llvm::Value *dest = moduleTranslation.lookupValue(writeOp.getAddress()); llvm::Type *ty = moduleTranslation.convertType(writeOp.getValue().getType()); @@ -1147,8 +1166,9 @@ /*isSigned=*/false, /*isVolatile=*/false}; + auto defaultAO = getAtomicDefaultMemOrder(*opInst.getOperation()); llvm::AtomicOrdering atomicOrdering = - 
convertAtomicOrdering(opInst.getMemoryOrderVal()); + convertAtomicOrdering(opInst.getMemoryOrderVal(), defaultAO); // Generate update code. LogicalResult updateGenStatus = success(); @@ -1236,8 +1256,9 @@ /*isSigned=*/false, /*isVolatile=*/false}; + auto defaultAO = getAtomicDefaultMemOrder(*atomicCaptureOp.getOperation()); llvm::AtomicOrdering atomicOrdering = - convertAtomicOrdering(atomicCaptureOp.getMemoryOrderVal()); + convertAtomicOrdering(atomicCaptureOp.getMemoryOrderVal(), defaultAO); LogicalResult updateGenStatus = success(); auto updateFn = [&](llvm::Value *atomicx, @@ -1574,6 +1595,27 @@ return success(); } +/// Converts the module-level set of OpenMP requires clauses into LLVM IR using +/// OpenMPIRBuilder. +static LogicalResult +convertRequiresAttr(Operation &op, omp::ClauseRequiresAttr requiresAttr, + LLVM::ModuleTranslation &moduleTranslation) { + auto *ompBuilder = moduleTranslation.getOpenMPBuilder(); + + // No need to read requiresAttr here, because it has already been done in + // translateModuleToLLVMIR(). There, flags are stored in the + // OpenMPIRBuilderConfig object, available to the OpenMPIRBuilder. 
+ auto *regFn = + ompBuilder->createRegisterRequires(ompBuilder->createPlatformSpecificName( + {"omp_offloading", "requires_reg"})); + + // Add registration function as global constructor + if (regFn) + llvm::appendToGlobalCtors(ompBuilder->M, regFn, /* Priority = */ 0); + + return success(); +} + namespace { /// Implementation of the dialect interface that converts operations belonging @@ -1589,6 +1631,8 @@ convertOperation(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) const final; + /// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR, runtime + /// calls, or operation amendments LogicalResult amendOperation(Operation *op, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const final; @@ -1596,8 +1640,6 @@ } // namespace -/// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR, runtime -/// calls, or operation amendments LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation( Operation *op, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const { @@ -1606,9 +1648,14 @@ .Case([&](mlir::omp::FlagsAttr rtlAttr) { return convertFlagsAttr(op, rtlAttr, moduleTranslation); }) + .Case([&](omp::ClauseRequiresAttr requiresAttr) { + return convertRequiresAttr(*op, requiresAttr, moduleTranslation); + }) .Default([&](Attribute attr) { - // fall through for omp attributes that do not require lowering and/or - // have no concrete definition and thus no type to define a case on + // Fall through for omp attributes that do not require lowering and/or + // have no concrete definition and thus no type to define a case on. 
+ // The omp.atomic_default_mem_order attribute is read directly during + // OpenMP atomic ops lowering return success(); }); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1322,6 +1322,25 @@ LLVM::ensureDistinctSuccessors(module); ModuleTranslation translator(module, std::move(llvmModule)); + + // Set OpenMP IR Builder configuration + if (auto offloadMod = dyn_cast(module)) { + llvm::OpenMPIRBuilderConfig config; + config.setIsEmbedded(offloadMod.getIsDevice()); + config.setIsTargetCodegen(false); + + const auto requiresFlags = offloadMod.getRequires(); + config.setHasRequiresReverseOffload(bitEnumContainsAll( + requiresFlags, omp::ClauseRequires::reverse_offload)); + config.setHasRequiresUnifiedAddress(bitEnumContainsAll( + requiresFlags, omp::ClauseRequires::unified_address)); + config.setHasRequiresUnifiedSharedMemory(bitEnumContainsAll( + requiresFlags, omp::ClauseRequires::unified_shared_memory)); + config.setHasRequiresDynamicAllocators(bitEnumContainsAll( + requiresFlags, omp::ClauseRequires::dynamic_allocators)); + translator.getOpenMPBuilder()->setConfig(config); + } + if (failed(translator.convertFunctionSignatures())) return nullptr; if (failed(translator.convertGlobals())) diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -2537,3 +2537,51 @@ // CHECK: @__omp_rtl_assume_no_nested_parallelism = weak_odr hidden constant i32 0 module attributes {omp.flags = #omp.flags, omp.is_device = #omp.isdevice} {} + +// ----- + +// Check that the atomic default memory order is picked up by atomic operations. 
+module attributes { + omp.atomic_default_mem_order = #omp +} { + // CHECK-LABEL: @omp_atomic_default_mem_order + // CHECK-SAME: (ptr %[[ARG0:.*]], ptr %[[ARG1:.*]], i32 %[[EXPR:.*]]) + llvm.func @omp_atomic_default_mem_order(%arg0 : !llvm.ptr, + %arg1 : !llvm.ptr, + %expr : i32) -> () { + + // CHECK: %[[X1:.*]] = load atomic i32, ptr %[[ARG0]] seq_cst, align 4 + // CHECK: store i32 %[[X1]], ptr %[[ARG1]], align 4 + omp.atomic.read %arg1 = %arg0 : !llvm.ptr, i32 + + // CHECK: store atomic i32 %[[EXPR]], ptr %[[ARG1]] seq_cst, align 4 + // CHECK: call void @__kmpc_flush(ptr @{{.*}}) + omp.atomic.write %arg1 = %expr : !llvm.ptr, i32 + + // CHECK: atomicrmw add ptr %[[ARG1]], i32 %[[EXPR]] seq_cst + omp.atomic.update %arg1 : !llvm.ptr { + ^bb0(%xval: i32): + %newval = llvm.add %xval, %expr : i32 + omp.yield(%newval : i32) + } + + // CHECK: %[[xval:.*]] = atomicrmw xchg ptr %[[ARG0]], i32 %[[EXPR]] seq_cst + // CHECK: store i32 %[[xval]], ptr %[[ARG1]] + omp.atomic.capture { + omp.atomic.read %arg1 = %arg0 : !llvm.ptr, i32 + omp.atomic.write %arg0 = %expr : !llvm.ptr, i32 + } + + llvm.return + } +} + +// ----- + +// Check that OpenMP requires flags are registered by a global constructor. +// CHECK: @llvm.global_ctors = appending global [1 x { i32, ptr, ptr }] +// CHECK-SAME: [{ i32, ptr, ptr } { i32 0, ptr @[[REG_FN:.*]], ptr null }] +// CHECK: define {{.*}} @[[REG_FN]]({{.*}}) +// CHECK-NOT: } +// CHECK: call void @__tgt_register_requires(i64 10) +module attributes {omp.requires = #omp} {}