diff --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h --- a/flang/include/flang/Lower/OpenMP.h +++ b/flang/include/flang/Lower/OpenMP.h @@ -13,13 +13,9 @@ #ifndef FORTRAN_LOWER_OPENMP_H #define FORTRAN_LOWER_OPENMP_H +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include -namespace mlir { -class Value; -class Operation; -} // namespace mlir - namespace fir { class FirOpBuilder; class ConvertOp; @@ -29,6 +25,7 @@ namespace parser { struct OpenMPConstruct; struct OpenMPDeclarativeConstruct; +struct OpenMPDeclareTargetConstruct; struct OmpEndLoopDirective; struct OmpClauseList; } // namespace parser @@ -56,6 +53,17 @@ void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value, mlir::Value, fir::ConvertOp * = nullptr); void removeStoreOp(mlir::Operation *, mlir::Value); + +std::optional +getOpenMPDeclareTargetFunctionDevice( + Fortran::lower::AbstractConverter &, Fortran::lower::pft::Evaluation &, + const Fortran::parser::OpenMPDeclareTargetConstruct &); +std::optional +extractOpenMPRequiresClauses(const Fortran::parser::OmpClauseList &, + mlir::omp::ClauseRequires &); + +bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &); + } // namespace lower } // namespace Fortran diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -52,6 +52,7 @@ #include "flang/Semantics/runtime-type-info.h" #include "flang/Semantics/tools.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Parser/Parser.h" #include "mlir/Transforms/RegionUtils.h" @@ -266,7 +267,8 @@ public: explicit FirConverter(Fortran::lower::LoweringBridge &bridge) : Fortran::lower::AbstractConverter(bridge.getLoweringOptions()), - bridge{bridge}, foldingContext{bridge.createFoldingContext()} {} + bridge{bridge}, foldingContext{bridge.createFoldingContext()}, + 
ompRequiresFlags{mlir::omp::ClauseRequires::none} {} virtual ~FirConverter() = default; /// Convert the PFT to FIR. @@ -343,6 +345,16 @@ fir::runtime::genEnvironmentDefaults(*builder, toLocation(), bridge.getEnvironmentDefaults()); }); + + // Set the module attributes related to OpenMP requires directives + if (auto mod = llvm::dyn_cast( + getModuleOp().getOperation())) { + if (ompDeviceCodeFound) + mod.setRequires(ompRequiresFlags); + + if (ompAtomicDefaultMemOrder) + mod.setAtomicDefaultMemOrder(*ompAtomicDefaultMemOrder); + } } /// Declare a function. @@ -2053,10 +2065,66 @@ localSymbols.popScope(); builder->restoreInsertionPoint(insertPt); + + // Register if a target region was found + ompDeviceCodeFound = + ompDeviceCodeFound || Fortran::lower::isOpenMPTargetConstruct(omp); + } + + /// Extract information from OpenMP declarative constructs + void analyzeOpenMPDeclarative( + const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) { + auto analyzeRequires = + [&](const Fortran::parser::OpenMPRequiresConstruct &ompReq) { + using mlir::omp::ClauseRequires; + + mlir::omp::ClauseRequires requiresFlags; + auto atomicDefaultMemOrder = + Fortran::lower::extractOpenMPRequiresClauses( + std::get(ompReq.t), + requiresFlags); + + if (requiresFlags != ClauseRequires::none) + ompRequiresFlags = ompRequiresFlags | requiresFlags; + + if (atomicDefaultMemOrder) { + if (ompAtomicDefaultMemOrder && + ompAtomicDefaultMemOrder != atomicDefaultMemOrder) + fir::emitFatalError( + toLocation(), + "conflicting atomic_default_mem_order clause found: " + + stringifyEnum(*atomicDefaultMemOrder) + + " != " + stringifyEnum(*ompAtomicDefaultMemOrder), + /*genCrashDiag=*/false); + ompAtomicDefaultMemOrder = atomicDefaultMemOrder; + } + }; + + auto analyzeDeclareTarget = + [&](const Fortran::parser::OpenMPDeclareTargetConstruct &ompReq) { + auto targetType = + Fortran::lower::getOpenMPDeclareTargetFunctionDevice( + *this, getEval(), ompReq); + + ompDeviceCodeFound = + ompDeviceCodeFound || 
+ (targetType && + *targetType != mlir::omp::DeclareTargetDeviceType::host); + }; + + std::visit( + Fortran::common::visitors{ + analyzeRequires, + analyzeDeclareTarget, + // Add other OpenMP declarative constructs currently skipped + [&](const auto &) {}, + }, + ompDecl.u); } void genFIR(const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) { mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint(); + analyzeOpenMPDeclarative(ompDecl); genOpenMPDeclarativeConstruct(*this, getEval(), ompDecl); for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations()) genFIR(e); @@ -4098,6 +4166,17 @@ /// Tuple of host associated variables mlir::Value hostAssocTuple; + + /// OpenMP Requires flags + mlir::omp::ClauseRequires ompRequiresFlags; + + /// OpenMP Default memory order for atomic operations, as defined by a + /// 'requires' directive + std::optional ompAtomicDefaultMemOrder; + + /// Whether an OpenMP target region or declare target function/subroutine + /// intended for device offloading has been detected + bool ompDeviceCodeFound = false; }; } // namespace diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -24,7 +24,9 @@ #include "flang/Semantics/tools.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" +#include using namespace mlir; @@ -2239,11 +2241,14 @@ converter.bindSymbol(sym, symThreadprivateExv); } -void handleDeclareTarget(Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenMPDeclareTargetConstruct - &declareTargetConstruct) { - std::vector symbols; +/// Extract the list of function and variable MLIR operations affected by the +/// given 'declare target' directive and return the intended device type for +/// them. 
+static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo( + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct, + SmallVectorImpl &symbols) { + // Gather the symbols auto findFuncAndVarSyms = [&](const Fortran::parser::OmpObjectList &objList) { for (const auto &ompObject : objList.v) { Fortran::common::visit( @@ -2261,12 +2266,11 @@ } }; + // The default capture type + auto deviceType = Fortran::parser::OmpDeviceTypeClause::Type::Any; const auto &spec{std::get( declareTargetConstruct.t)}; - auto mod = converter.getFirOpBuilder().getModule(); - // The default capture type - auto deviceType = Fortran::parser::OmpDeviceTypeClause::Type::Any; if (const auto *objectList{ Fortran::parser::Unwrap(spec.u)}) { // Case: declare target(func, var1, var2) @@ -2288,8 +2292,7 @@ std::get_if( &clause.u)}) { // Case: declare target link(var1, var2)... - TODO(converter.getCurrentLocation(), - "the link clause is currently unsupported"); + TODO_NOLOC("the link clause is currently unsupported"); } else if (const auto *deviceClause{ std::get_if( &clause.u)}) { @@ -2299,6 +2302,24 @@ } } + switch (deviceType) { + case Fortran::parser::OmpDeviceTypeClause::Type::Any: + return mlir::omp::DeclareTargetDeviceType::any; + case Fortran::parser::OmpDeviceTypeClause::Type::Host: + return mlir::omp::DeclareTargetDeviceType::host; + case Fortran::parser::OmpDeviceTypeClause::Type::Nohost: + return mlir::omp::DeclareTargetDeviceType::nohost; + } +} + +void handleDeclareTarget(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct + &declareTargetConstruct) { + SmallVector symbols; + auto deviceType = getDeclareTargetInfo(eval, declareTargetConstruct, symbols); + + auto mod = converter.getFirOpBuilder().getModule(); for (auto sym : symbols) { auto *op = mod.lookupSymbol(converter.mangleName(sym)); @@ -2315,29 +2336,17 @@ // assigned both host 
and nohost, else we skip, as there is no // change if (mlir::omp::OpenMPDialect::isDeclareTarget(fOp)) { - auto currentDeclTar = + auto currentDeclareTarget = mlir::omp::OpenMPDialect::getDeclareTargetDeviceType(fOp); - if ((currentDeclTar == mlir::omp::DeclareTargetDeviceType::nohost && - deviceType == Fortran::parser::OmpDeviceTypeClause::Type::Host) || - (currentDeclTar == mlir::omp::DeclareTargetDeviceType::host && - deviceType == Fortran::parser::OmpDeviceTypeClause::Type::Nohost)) { + + if (currentDeclareTarget != deviceType) mlir::omp::OpenMPDialect::setDeclareTarget( op, mlir::omp::DeclareTargetDeviceType::any); - } continue; } - if (deviceType == Fortran::parser::OmpDeviceTypeClause::Type::Nohost) { - mlir::omp::OpenMPDialect::setDeclareTarget( - op, mlir::omp::DeclareTargetDeviceType::nohost); - } else if (deviceType == Fortran::parser::OmpDeviceTypeClause::Type::Host) { - mlir::omp::OpenMPDialect::setDeclareTarget( - op, mlir::omp::DeclareTargetDeviceType::host); - } else if (deviceType == Fortran::parser::OmpDeviceTypeClause::Type::Any) { - mlir::omp::OpenMPDialect::setDeclareTarget( - op, mlir::omp::DeclareTargetDeviceType::any); - } + mlir::omp::OpenMPDialect::setDeclareTarget(op, deviceType); } } @@ -2367,7 +2376,11 @@ }, [&](const Fortran::parser::OpenMPRequiresConstruct &requiresConstruct) { - TODO(converter.getCurrentLocation(), "OpenMPRequiresConstruct"); + // Requires directives are analyzed before any statements are + // lowered. Then, the result of combining the set of clauses of all + // requires directives present in the compilation unit is used to + // emit code, so no code is emitted independently for each + // "requires" instance. 
}, [&](const Fortran::parser::OpenMPThreadprivate &threadprivate) { // The directive is lowered when instantiating the variable to @@ -2577,3 +2590,93 @@ } } } + +std::optional +Fortran::lower::getOpenMPDeclareTargetFunctionDevice( + Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct + &declareTargetConstruct) { + SmallVector symbols; + auto deviceType = getDeclareTargetInfo(eval, declareTargetConstruct, symbols); + + // Return the device type only if at least one of the targets for the + // directive is a function or subroutine + auto mod = converter.getFirOpBuilder().getModule(); + for (auto sym : symbols) { + auto *op = mod.lookupSymbol(converter.mangleName(sym)); + + if (mlir::isa(op)) + return deviceType; + } + + return std::nullopt; +} + +std::optional +Fortran::lower::extractOpenMPRequiresClauses( + const Fortran::parser::OmpClauseList &clauseList, + omp::ClauseRequires &requiresFlags) { + std::optional atomicDefaultMemOrder; + requiresFlags = omp::ClauseRequires::none; + + for (const auto &clause : clauseList.v) { + if (const auto &atomicClause = + std::get_if( + &clause.u)) { + switch (atomicClause->v.v) { + case Fortran::parser::OmpAtomicDefaultMemOrderClause::Type::SeqCst: + atomicDefaultMemOrder = omp::ClauseMemoryOrderKind::Seq_cst; + break; + case Fortran::parser::OmpAtomicDefaultMemOrderClause::Type::AcqRel: + atomicDefaultMemOrder = omp::ClauseMemoryOrderKind::Acq_rel; + break; + case Fortran::parser::OmpAtomicDefaultMemOrderClause::Type::Relaxed: + atomicDefaultMemOrder = omp::ClauseMemoryOrderKind::Relaxed; + break; + } + } else if (std::get_if( + &clause.u)) { + requiresFlags = requiresFlags | omp::ClauseRequires::dynamic_allocators; + } else if (std::get_if( + &clause.u)) { + requiresFlags = requiresFlags | omp::ClauseRequires::reverse_offload; + } else if (std::get_if( + &clause.u)) { + requiresFlags = requiresFlags | omp::ClauseRequires::unified_address; 
+ } else if (std::get_if( &clause.u)) { + requiresFlags = + requiresFlags | omp::ClauseRequires::unified_shared_memory; + } + } + + return atomicDefaultMemOrder; +} + +bool Fortran::lower::isOpenMPTargetConstruct( + const Fortran::parser::OpenMPConstruct &omp) { + if (const auto *blockDir = + std::get_if(&omp.u)) { + const auto &beginBlockDir{ + std::get(blockDir->t)}; + const auto &beginDir{ + std::get(beginBlockDir.t)}; + + switch (beginDir.v) { + case llvm::omp::Directive::OMPD_target: + case llvm::omp::Directive::OMPD_target_parallel: + case llvm::omp::Directive::OMPD_target_parallel_do: + case llvm::omp::Directive::OMPD_target_parallel_do_simd: + case llvm::omp::Directive::OMPD_target_simd: + case llvm::omp::Directive::OMPD_target_teams: + case llvm::omp::Directive::OMPD_target_teams_distribute: + case llvm::omp::Directive::OMPD_target_teams_distribute_simd: + return true; + default: + break; + } + } + + return false; +} diff --git a/flang/test/Lower/OpenMP/requires-error.f90 b/flang/test/Lower/OpenMP/requires-error.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/requires-error.f90 @@ -0,0 +1,15 @@ +! RUN: not %flang_fc1 -emit-fir -fopenmp %s -o - 2>&1 | FileCheck %s + +! This test checks that lowering of requires directives into MLIR emits an +! error when conflicting atomic_default_mem_order clauses are found in the +! same compilation unit + +!CHECK: error: {{.*}} conflicting atomic_default_mem_order clause found: +!CHECK-SAME: acq_rel != seq_cst +program requires + !$omp requires atomic_default_mem_order(seq_cst) +end program requires + +subroutine f() + !$omp requires atomic_default_mem_order(acq_rel) +end subroutine f diff --git a/flang/test/Lower/OpenMP/requires-notarget.f90 b/flang/test/Lower/OpenMP/requires-notarget.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/requires-notarget.f90 @@ -0,0 +1,12 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s + +! 
This test checks that requires lowering into MLIR skips creating the +! omp.requires attribute with target-related clauses if there are no device +! functions in the compilation unit + +!CHECK: module attributes { +!CHECK-SAME: omp.atomic_default_mem_order = #omp +!CHECK-NOT: omp.requires +program requires + !$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst) +end program requires diff --git a/flang/test/Lower/OpenMP/requires.f90 b/flang/test/Lower/OpenMP/requires.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/requires.f90 @@ -0,0 +1,14 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s + +! This test checks the lowering of requires into MLIR + +!CHECK: module attributes { +!CHECK-SAME: omp.atomic_default_mem_order = #omp +!CHECK-SAME: omp.requires = #omp +program requires + !$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst) +end program requires + +subroutine f + !$omp declare target +end subroutine f