diff --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h --- a/flang/include/flang/Lower/OpenMP.h +++ b/flang/include/flang/Lower/OpenMP.h @@ -13,6 +13,7 @@ #ifndef FORTRAN_LOWER_OPENMP_H #define FORTRAN_LOWER_OPENMP_H +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include namespace mlir { @@ -30,6 +31,7 @@ namespace parser { struct OpenMPConstruct; struct OpenMPDeclarativeConstruct; +struct OpenMPDeclareTargetConstruct; struct OmpEndLoopDirective; struct OmpClauseList; } // namespace parser @@ -49,6 +51,9 @@ void genOpenMPConstruct(AbstractConverter &, pft::Evaluation &, const parser::OpenMPConstruct &); +void analyzeOpenMPDeclarativeConstruct( + Fortran::lower::AbstractConverter &, Fortran::lower::pft::Evaluation &, + const parser::OpenMPDeclarativeConstruct &, bool &); void genOpenMPDeclarativeConstruct(AbstractConverter &, pft::Evaluation &, const parser::OpenMPDeclarativeConstruct &); int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList); @@ -61,6 +66,17 @@ void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value, mlir::Value, fir::ConvertOp * = nullptr); void removeStoreOp(mlir::Operation *, mlir::Value); + +std::optional +getOpenMPDeclareTargetFunctionDevice( + Fortran::lower::AbstractConverter &, Fortran::lower::pft::Evaluation &, + const Fortran::parser::OpenMPDeclareTargetConstruct &); +bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &); + +mlir::omp::ClauseRequires +extractOpenMPRequiresClauses(const Fortran::parser::OmpClauseList &); +void genOpenMPRequires(mlir::Operation *, mlir::omp::ClauseRequires); + } // namespace lower } // namespace Fortran diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -50,6 +50,7 @@ #include "flang/Parser/parse-tree.h" #include "flang/Runtime/iostat.h" #include "flang/Semantics/runtime-type-info.h" +#include "flang/Semantics/symbol.h" #include "flang/Semantics/tools.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/IR/PatternMatch.h" @@ -62,6 +63,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/Path.h" +#include #include #define DEBUG_TYPE "flang-lower-bridge" @@ -288,20 +290,34 @@ // that they are available before lowering any function that may use // them. bool hasMainProgram = false; + Fortran::semantics::OmpRequiresFlags ompRequiresFlags = + Fortran::semantics::OmpRequiresFlags::None; + std::optional + ompAtomicDefaultMemOrder; for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) { std::visit(Fortran::common::visitors{ [&](Fortran::lower::pft::FunctionLikeUnit &f) { if (f.isMainProgram()) hasMainProgram = true; declareFunction(f); + ompProcessTopLevelSymbol(f.getScope().symbol(), + ompRequiresFlags, + ompAtomicDefaultMemOrder); }, [&](Fortran::lower::pft::ModuleLikeUnit &m) { lowerModuleDeclScope(m); for (Fortran::lower::pft::FunctionLikeUnit &f : m.nestedFunctions) declareFunction(f); + ompProcessTopLevelSymbol(m.getScope().symbol(), + ompRequiresFlags, + ompAtomicDefaultMemOrder); + }, + [&](Fortran::lower::pft::BlockDataUnit &b) { + ompProcessTopLevelSymbol(b.symTab.symbol(), + ompRequiresFlags, + ompAtomicDefaultMemOrder); }, - [&](Fortran::lower::pft::BlockDataUnit &b) {}, [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {}, }, u); @@ -344,6 +360,24 @@ fir::runtime::genEnvironmentDefaults(*builder, toLocation(), bridge.getEnvironmentDefaults()); }); + + // Set the module attribute related to OpenMP requires directives + if (ompDeviceCodeFound) { + using MlirRequires = mlir::omp::ClauseRequires; + using SemaRequires = Fortran::semantics::OmpRequiresFlags; + MlirRequires flags = MlirRequires::none; + + if (ompRequiresFlags & SemaRequires::ReverseOffload) + flags = flags | MlirRequires::reverse_offload; + if (ompRequiresFlags & SemaRequires::UnifiedAddress) + flags = flags | MlirRequires::unified_address; + if (ompRequiresFlags & SemaRequires::UnifiedSharedMemory) + flags = flags | MlirRequires::unified_shared_memory; + if (ompRequiresFlags & SemaRequires::DynamicAllocators) + flags = flags | MlirRequires::dynamic_allocators; + + Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(), flags); + } } /// Declare a function. @@ -1191,6 +1225,47 @@ activeConstructStack.pop_back(); } + void ompProcessTopLevelSymbol( + const Fortran::semantics::Symbol *symbol, + Fortran::semantics::OmpRequiresFlags &ompRequiresFlags, + std::optional + &ompAtomicDefaultMemOrder) { + if (!symbol) + return; + + Fortran::common::visit( + [&](const auto &details) { + if constexpr (std::is_base_of_v< + Fortran::semantics::WithOmpDeclarative, + std::decay_t>) { + // Collect OpenMP 'requires' clauses. + if (details.has_ompRequires()) + ompRequiresFlags |= *details.ompRequires(); + + // Make sure any atomic_default_mem_order OpenMP 'requires' clauses + // obtained for different top-level symbols match. + if (details.has_ompAtomicDefaultMemOrder()) { + Fortran::parser::OmpAtomicDefaultMemOrderClause::Type memOrder{ + *details.ompAtomicDefaultMemOrder()}; + if (ompAtomicDefaultMemOrder && + memOrder != *ompAtomicDefaultMemOrder) + fir::emitFatalError( + getCurrentLocation(), + llvm::StringRef{ + "incompatible OpenMP requires atomic_default_mem_order " + "clauses found: '"} + + Fortran::parser::OmpAtomicDefaultMemOrderClause:: + EnumToString(memOrder) + + llvm::StringRef{"' and '"} + + Fortran::parser::OmpAtomicDefaultMemOrderClause:: + EnumToString(*ompAtomicDefaultMemOrder)); + ompAtomicDefaultMemOrder = memOrder; + } + } + }, + symbol->details()); + } + //===--------------------------------------------------------------------===// // Termination of symbolically referenced execution units //===--------------------------------------------------------------------===// @@ -2201,10 +2276,16 @@ localSymbols.popScope(); builder->restoreInsertionPoint(insertPt); + + // Register if a target region was found + ompDeviceCodeFound = + ompDeviceCodeFound || Fortran::lower::isOpenMPTargetConstruct(omp); } void genFIR(const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) { mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint(); + analyzeOpenMPDeclarativeConstruct(*this, getEval(), ompDecl, + ompDeviceCodeFound); genOpenMPDeclarativeConstruct(*this, getEval(), ompDecl); for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations()) genFIR(e); @@ -4530,6 +4611,10 @@ /// A counter for uniquing names in `literalNamesMap`. std::uint64_t uniqueLitId = 0; + + /// Whether an OpenMP target region or declare target function/subroutine + /// intended for device offloading has been detected + bool ompDeviceCodeFound = false; }; } // namespace diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -2603,16 +2603,14 @@ converter.bindSymbol(sym, symThreadprivateExv); } -void handleDeclareTarget(Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenMPDeclareTargetConstruct - &declareTargetConstruct) { - llvm::SmallVector, - 0> - symbolAndClause; - mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); - +/// Extract the list of function and variable symbols affected by the given +/// 'declare target' directive and return the intended device type for them. +static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo( + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct, + SmallVectorImpl> &symbolAndClause) { + // Gather the symbols and clauses auto findFuncAndVarSyms = [&](const Fortran::parser::OmpObjectList &objList, mlir::omp::DeclareTargetCaptureClause clause) { for (const Fortran::parser::OmpObject &ompObject : objList.v) { @@ -2637,6 +2635,7 @@ Fortran::parser::OmpDeviceTypeClause::Type::Any; const auto &spec = std::get( declareTargetConstruct.t); + if (const auto *objectList{ Fortran::parser::Unwrap(spec.u)}) { // Case: declare target(func, var1, var2) @@ -2671,6 +2670,28 @@ } } + switch (deviceType) { + case Fortran::parser::OmpDeviceTypeClause::Type::Any: + return mlir::omp::DeclareTargetDeviceType::any; + case Fortran::parser::OmpDeviceTypeClause::Type::Host: + return mlir::omp::DeclareTargetDeviceType::host; + case Fortran::parser::OmpDeviceTypeClause::Type::Nohost: + return mlir::omp::DeclareTargetDeviceType::nohost; + } +} + +void genDeclareTarget(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct + &declareTargetConstruct) { + llvm::SmallVector, + 0> + symbolAndClause; + mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); + mlir::omp::DeclareTargetDeviceType deviceType = + getDeclareTargetInfo(eval, declareTargetConstruct, symbolAndClause); + for (std::pair symClause : symbolAndClause) { @@ -2697,35 +2718,44 @@ converter.getCurrentLocation(), "Attempt to apply declare target on unsupported operation"); - mlir::omp::DeclareTargetDeviceType newDeviceType; - switch (deviceType) { - case Fortran::parser::OmpDeviceTypeClause::Type::Nohost: - newDeviceType = mlir::omp::DeclareTargetDeviceType::nohost; - break; - case Fortran::parser::OmpDeviceTypeClause::Type::Host: - newDeviceType = mlir::omp::DeclareTargetDeviceType::host; - break; - case Fortran::parser::OmpDeviceTypeClause::Type::Any: - newDeviceType = mlir::omp::DeclareTargetDeviceType::any; - break; - } - // The function or global already has a declare target applied to it, // very likely through implicit capture (usage in another declare // target function/subroutine). It should be marked as any if it has // been assigned both host and nohost, else we skip, as there is no // change if (declareTargetOp.isDeclareTarget()) { - if (declareTargetOp.getDeclareTargetDeviceType() != newDeviceType) + if (declareTargetOp.getDeclareTargetDeviceType() != deviceType) declareTargetOp.setDeclareTarget( mlir::omp::DeclareTargetDeviceType::any, std::get<0>(symClause)); continue; } - declareTargetOp.setDeclareTarget(newDeviceType, std::get<0>(symClause)); + declareTargetOp.setDeclareTarget(deviceType, std::get<0>(symClause)); } } +void Fortran::lower::analyzeOpenMPDeclarativeConstruct( + Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl, + bool &ompDeviceCodeFound) { + std::visit( + Fortran::common::visitors{ + [&](const Fortran::parser::OpenMPDeclareTargetConstruct &ompReq) { + mlir::omp::DeclareTargetDeviceType targetType = + Fortran::lower::getOpenMPDeclareTargetFunctionDevice( + converter, eval, ompReq) + .value_or(mlir::omp::DeclareTargetDeviceType::host); + + ompDeviceCodeFound = + ompDeviceCodeFound || + targetType != mlir::omp::DeclareTargetDeviceType::host; + }, + [&](const auto &) {}, + }, + ompDecl.u); +} + void Fortran::lower::genOpenMPDeclarativeConstruct( Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, @@ -2748,11 +2778,14 @@ }, [&](const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct) { - handleDeclareTarget(converter, eval, declareTargetConstruct); + genDeclareTarget(converter, eval, declareTargetConstruct); }, [&](const Fortran::parser::OpenMPRequiresConstruct &requiresConstruct) { - TODO(converter.getCurrentLocation(), "OpenMPRequiresConstruct"); + // Requires directives are gathered and processed in semantics in + // order to support modules, and then combined in the lowering + // bridge before triggering codegen just once. Hence, there is no + // need for codegen for each individual occurrence here. }, [&](const Fortran::parser::OpenMPThreadprivate &threadprivate) { // The directive is lowered when instantiating the variable to @@ -2974,3 +3007,84 @@ } } } + +std::optional +Fortran::lower::getOpenMPDeclareTargetFunctionDevice( + Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct + &declareTargetConstruct) { + llvm::SmallVector, + 0> + symbolAndClause; + mlir::omp::DeclareTargetDeviceType deviceType = + getDeclareTargetInfo(eval, declareTargetConstruct, symbolAndClause); + + // Return the device type only if at least one of the targets for the + // directive is a function or subroutine + mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); + for (std::pair + sym : symbolAndClause) { + mlir::Operation *op = + mod.lookupSymbol(converter.mangleName(std::get<1>(sym))); + + if (mlir::isa(op)) + return deviceType; + } + + return std::nullopt; +} + +bool Fortran::lower::isOpenMPTargetConstruct( + const Fortran::parser::OpenMPConstruct &omp) { + if (const auto *blockDir = + std::get_if(&omp.u)) { + const auto &beginBlockDir{ + std::get(blockDir->t)}; + const auto &beginDir{ + std::get(beginBlockDir.t)}; + + switch (beginDir.v) { + case llvm::omp::Directive::OMPD_target: + case llvm::omp::Directive::OMPD_target_parallel: + case llvm::omp::Directive::OMPD_target_parallel_do: + case llvm::omp::Directive::OMPD_target_parallel_do_simd: + case llvm::omp::Directive::OMPD_target_simd: + case llvm::omp::Directive::OMPD_target_teams: + case llvm::omp::Directive::OMPD_target_teams_distribute: + case llvm::omp::Directive::OMPD_target_teams_distribute_simd: + return true; + default: + break; + } + } + + return false; +} + +omp::ClauseRequires Fortran::lower::extractOpenMPRequiresClauses( + const Fortran::parser::OmpClauseList &clauseList) { + using omp::ClauseRequires, Fortran::parser::OmpClause; + auto requiresFlags = ClauseRequires::none; + + for (const OmpClause &clause : clauseList.v) { + if (std::get_if(&clause.u)) + requiresFlags = requiresFlags | ClauseRequires::dynamic_allocators; + else if (std::get_if(&clause.u)) + requiresFlags = requiresFlags | ClauseRequires::reverse_offload; + else if (std::get_if(&clause.u)) + requiresFlags = requiresFlags | ClauseRequires::unified_address; + else if (std::get_if(&clause.u)) + requiresFlags = requiresFlags | ClauseRequires::unified_shared_memory; + } + + return requiresFlags; +} + +void Fortran::lower::genOpenMPRequires(Operation *mod, + omp::ClauseRequires flags) { + if (auto offloadMod = llvm::dyn_cast(mod)) + offloadMod.setRequires(flags); +} diff --git a/flang/test/Lower/OpenMP/requires-notarget.f90 b/flang/test/Lower/OpenMP/requires-notarget.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/requires-notarget.f90 @@ -0,0 +1,11 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s + +! This test checks that requires lowering into MLIR skips creating the +! omp.requires attribute with target-related clauses if there are no device +! functions in the compilation unit + +!CHECK: module attributes { +!CHECK-NOT: omp.requires +program requires + !$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst) +end program requires diff --git a/flang/test/Lower/OpenMP/requires.f90 b/flang/test/Lower/OpenMP/requires.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/requires.f90 @@ -0,0 +1,13 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s + +! This test checks the lowering of requires into MLIR + +!CHECK: module attributes { +!CHECK-SAME: omp.requires = #omp +program requires + !$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst) +end program requires + +subroutine f + !$omp declare target +end subroutine f