diff --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h --- a/flang/include/flang/Lower/OpenMP.h +++ b/flang/include/flang/Lower/OpenMP.h @@ -13,6 +13,7 @@ #ifndef FORTRAN_LOWER_OPENMP_H #define FORTRAN_LOWER_OPENMP_H +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include namespace mlir { @@ -30,10 +31,15 @@ namespace parser { struct OpenMPConstruct; struct OpenMPDeclarativeConstruct; +struct OpenMPDeclareTargetConstruct; struct OmpEndLoopDirective; struct OmpClauseList; } // namespace parser +namespace semantics { +class Symbol; +} // namespace semantics + namespace lower { class AbstractConverter; @@ -49,6 +55,9 @@ void genOpenMPConstruct(AbstractConverter &, pft::Evaluation &, const parser::OpenMPConstruct &); +void analyzeOpenMPDeclarativeConstruct( + Fortran::lower::AbstractConverter &, Fortran::lower::pft::Evaluation &, + const parser::OpenMPDeclarativeConstruct &, bool &); void genOpenMPDeclarativeConstruct(AbstractConverter &, pft::Evaluation &, const parser::OpenMPDeclarativeConstruct &); int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList); @@ -61,6 +70,15 @@ void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value, mlir::Value, fir::ConvertOp * = nullptr); void removeStoreOp(mlir::Operation *, mlir::Value); + +std::optional +getOpenMPDeclareTargetFunctionDevice( + Fortran::lower::AbstractConverter &, Fortran::lower::pft::Evaluation &, + const Fortran::parser::OpenMPDeclareTargetConstruct &); +bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &); + +void genOpenMPRequires(mlir::Operation *, const Fortran::semantics::Symbol *); + } // namespace lower } // namespace Fortran diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -50,6 +50,7 @@ #include "flang/Parser/parse-tree.h" #include "flang/Runtime/iostat.h" #include "flang/Semantics/runtime-type-info.h" +#include "flang/Semantics/symbol.h" 
#include "flang/Semantics/tools.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/IR/PatternMatch.h" @@ -62,6 +63,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/Path.h" +#include #include #define DEBUG_TYPE "flang-lower-bridge" @@ -288,12 +290,15 @@ // that they are available before lowering any function that may use // them. bool hasMainProgram = false; + const Fortran::semantics::Symbol *globalOmpRequiresSymbol = nullptr; for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) { std::visit(Fortran::common::visitors{ [&](Fortran::lower::pft::FunctionLikeUnit &f) { if (f.isMainProgram()) hasMainProgram = true; declareFunction(f); + if (!globalOmpRequiresSymbol) + globalOmpRequiresSymbol = f.getScope().symbol(); }, [&](Fortran::lower::pft::ModuleLikeUnit &m) { lowerModuleDeclScope(m); @@ -301,7 +306,10 @@ m.nestedFunctions) declareFunction(f); }, - [&](Fortran::lower::pft::BlockDataUnit &b) {}, + [&](Fortran::lower::pft::BlockDataUnit &b) { + if (!globalOmpRequiresSymbol) + globalOmpRequiresSymbol = b.symTab.symbol(); + }, [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {}, }, u); @@ -344,6 +352,11 @@ fir::runtime::genEnvironmentDefaults(*builder, toLocation(), bridge.getEnvironmentDefaults()); }); + + // Set the module attribute related to OpenMP requires directives + if (ompDeviceCodeFound) + Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(), + globalOmpRequiresSymbol); } /// Declare a function. 
@@ -2201,10 +2214,16 @@ localSymbols.popScope(); builder->restoreInsertionPoint(insertPt); + + // Register if a target region was found + ompDeviceCodeFound = + ompDeviceCodeFound || Fortran::lower::isOpenMPTargetConstruct(omp); } void genFIR(const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) { mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint(); + analyzeOpenMPDeclarativeConstruct(*this, getEval(), ompDecl, + ompDeviceCodeFound); genOpenMPDeclarativeConstruct(*this, getEval(), ompDecl); for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations()) genFIR(e); @@ -4560,6 +4579,10 @@ /// A counter for uniquing names in `literalNamesMap`. std::uint64_t uniqueLitId = 0; + + /// Whether an OpenMP target region or declare target function/subroutine + /// intended for device offloading has been detected + bool ompDeviceCodeFound = false; }; } // namespace diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -2676,16 +2676,14 @@ converter.bindSymbol(sym, symThreadprivateExv); } -void handleDeclareTarget(Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenMPDeclareTargetConstruct - &declareTargetConstruct) { - llvm::SmallVector, - 0> - symbolAndClause; - mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); - +/// Extract the list of function and variable symbols affected by the given +/// 'declare target' directive and return the intended device type for them. 
+static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo( + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct, + SmallVectorImpl> &symbolAndClause) { + // Gather the symbols and clauses auto findFuncAndVarSyms = [&](const Fortran::parser::OmpObjectList &objList, mlir::omp::DeclareTargetCaptureClause clause) { for (const Fortran::parser::OmpObject &ompObject : objList.v) { @@ -2710,6 +2708,7 @@ Fortran::parser::OmpDeviceTypeClause::Type::Any; const auto &spec = std::get( declareTargetConstruct.t); + if (const auto *objectList{ Fortran::parser::Unwrap(spec.u)}) { // Case: declare target(func, var1, var2) @@ -2744,6 +2743,30 @@ } } + switch (deviceType) { + case Fortran::parser::OmpDeviceTypeClause::Type::Any: + return mlir::omp::DeclareTargetDeviceType::any; + case Fortran::parser::OmpDeviceTypeClause::Type::Host: + return mlir::omp::DeclareTargetDeviceType::host; + case Fortran::parser::OmpDeviceTypeClause::Type::Nohost: + return mlir::omp::DeclareTargetDeviceType::nohost; + } + // Not reached for a valid enumerator; keeps every path returning a value. + return mlir::omp::DeclareTargetDeviceType::any; +} + +void genDeclareTarget(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct + &declareTargetConstruct) { + llvm::SmallVector, + 0> + symbolAndClause; + mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); + mlir::omp::DeclareTargetDeviceType deviceType = + getDeclareTargetInfo(eval, declareTargetConstruct, symbolAndClause); + for (std::pair symClause : symbolAndClause) { @@ -2770,35 +2793,44 @@ converter.getCurrentLocation(), "Attempt to apply declare target on unsupported operation"); - mlir::omp::DeclareTargetDeviceType newDeviceType; - switch (deviceType) { - case Fortran::parser::OmpDeviceTypeClause::Type::Nohost: - newDeviceType = mlir::omp::DeclareTargetDeviceType::nohost; - break; - case Fortran::parser::OmpDeviceTypeClause::Type::Host: - newDeviceType = mlir::omp::DeclareTargetDeviceType::host; - break; - case 
Fortran::parser::OmpDeviceTypeClause::Type::Any: - newDeviceType = mlir::omp::DeclareTargetDeviceType::any; - break; - } - // The function or global already has a declare target applied to it, // very likely through implicit capture (usage in another declare // target function/subroutine). It should be marked as any if it has // been assigned both host and nohost, else we skip, as there is no // change if (declareTargetOp.isDeclareTarget()) { - if (declareTargetOp.getDeclareTargetDeviceType() != newDeviceType) + if (declareTargetOp.getDeclareTargetDeviceType() != deviceType) declareTargetOp.setDeclareTarget( mlir::omp::DeclareTargetDeviceType::any, std::get<0>(symClause)); continue; } - declareTargetOp.setDeclareTarget(newDeviceType, std::get<0>(symClause)); + declareTargetOp.setDeclareTarget(deviceType, std::get<0>(symClause)); } } +void Fortran::lower::analyzeOpenMPDeclarativeConstruct( + Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl, + bool &ompDeviceCodeFound) { + std::visit( + Fortran::common::visitors{ + [&](const Fortran::parser::OpenMPDeclareTargetConstruct &ompReq) { + mlir::omp::DeclareTargetDeviceType targetType = + Fortran::lower::getOpenMPDeclareTargetFunctionDevice( + converter, eval, ompReq) + .value_or(mlir::omp::DeclareTargetDeviceType::host); + + ompDeviceCodeFound = + ompDeviceCodeFound || + targetType != mlir::omp::DeclareTargetDeviceType::host; + }, + [&](const auto &) {}, + }, + ompDecl.u); +} + void Fortran::lower::genOpenMPDeclarativeConstruct( Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, @@ -2821,11 +2851,14 @@ }, [&](const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct) { - handleDeclareTarget(converter, eval, declareTargetConstruct); + genDeclareTarget(converter, eval, declareTargetConstruct); }, [&](const Fortran::parser::OpenMPRequiresConstruct &requiresConstruct) { - 
TODO(converter.getCurrentLocation(), "OpenMPRequiresConstruct"); + // Requires directives are gathered and processed in semantics in + // order to support modules, and then combined in the lowering + // bridge before triggering codegen just once. Hence, there is no + // need for codegen for each individual occurrence here. }, [&](const Fortran::parser::OpenMPThreadprivate &threadprivate) { // The directive is lowered when instantiating the variable to @@ -3047,3 +3080,94 @@ } } } + +std::optional +Fortran::lower::getOpenMPDeclareTargetFunctionDevice( + Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct + &declareTargetConstruct) { + llvm::SmallVector, + 0> + symbolAndClause; + mlir::omp::DeclareTargetDeviceType deviceType = + getDeclareTargetInfo(eval, declareTargetConstruct, symbolAndClause); + + // Return the device type only if at least one of the targets for the + // directive is a function or subroutine + mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); + for (std::pair + sym : symbolAndClause) { + mlir::Operation *op = + mod.lookupSymbol(converter.mangleName(std::get<1>(sym))); + // Skip symbols for which lookupSymbol found no operation (null result). + if (op && mlir::isa(op)) + return deviceType; + } + + return std::nullopt; +} + +bool Fortran::lower::isOpenMPTargetConstruct( + const Fortran::parser::OpenMPConstruct &omp) { + if (const auto *blockDir = + std::get_if(&omp.u)) { + const auto &beginBlockDir{ + std::get(blockDir->t)}; + const auto &beginDir{ + std::get(beginBlockDir.t)}; + + switch (beginDir.v) { + case llvm::omp::Directive::OMPD_target: + case llvm::omp::Directive::OMPD_target_parallel: + case llvm::omp::Directive::OMPD_target_parallel_do: + case llvm::omp::Directive::OMPD_target_parallel_do_simd: + case llvm::omp::Directive::OMPD_target_simd: + case llvm::omp::Directive::OMPD_target_teams: + case llvm::omp::Directive::OMPD_target_teams_distribute: + case llvm::omp::Directive::OMPD_target_teams_distribute_simd: + return 
true; + default: + break; + } + } + + return false; +} + +void Fortran::lower::genOpenMPRequires( + Operation *mod, const Fortran::semantics::Symbol *symbol) { + using MlirRequires = mlir::omp::ClauseRequires; + using SemaRequires = Fortran::semantics::OmpRequiresFlags; + + if (auto offloadMod = + llvm::dyn_cast(mod)) { + SemaRequires semaFlags = SemaRequires::None; + if (symbol) { + Fortran::common::visit( + [&](const auto &details) { + if constexpr (std::is_base_of_v< + Fortran::semantics::WithOmpDeclarative, + std::decay_t>) { + if (details.has_ompRequires()) + semaFlags = *details.ompRequires(); + } + }, + symbol->details()); + } + + MlirRequires mlirFlags = MlirRequires::none; + if (semaFlags & SemaRequires::ReverseOffload) + mlirFlags = mlirFlags | MlirRequires::reverse_offload; + if (semaFlags & SemaRequires::UnifiedAddress) + mlirFlags = mlirFlags | MlirRequires::unified_address; + if (semaFlags & SemaRequires::UnifiedSharedMemory) + mlirFlags = mlirFlags | MlirRequires::unified_shared_memory; + if (semaFlags & SemaRequires::DynamicAllocators) + mlirFlags = mlirFlags | MlirRequires::dynamic_allocators; + + offloadMod.setRequires(mlirFlags); + } +} diff --git a/flang/test/Lower/OpenMP/requires-notarget.f90 b/flang/test/Lower/OpenMP/requires-notarget.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/requires-notarget.f90 @@ -0,0 +1,11 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s + +! This test checks that requires lowering into MLIR skips creating the +! omp.requires attribute with target-related clauses if there are no device +! 
functions in the compilation unit + +!CHECK: module attributes { +!CHECK-NOT: omp.requires +program requires + !$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst) +end program requires diff --git a/flang/test/Lower/OpenMP/requires.f90 b/flang/test/Lower/OpenMP/requires.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/requires.f90 @@ -0,0 +1,13 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s + +! This test checks the lowering of requires into MLIR + +!CHECK: module attributes { +!CHECK-SAME: omp.requires = #omp +program requires + !$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst) +end program requires + +subroutine f + !$omp declare target +end subroutine f