diff --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h --- a/flang/include/flang/Lower/OpenMP.h +++ b/flang/include/flang/Lower/OpenMP.h @@ -13,6 +13,7 @@ #ifndef FORTRAN_LOWER_OPENMP_H #define FORTRAN_LOWER_OPENMP_H +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include namespace mlir { @@ -30,10 +31,15 @@ namespace parser { struct OpenMPConstruct; struct OpenMPDeclarativeConstruct; +struct OpenMPDeclareTargetConstruct; struct OmpEndLoopDirective; struct OmpClauseList; } // namespace parser +namespace semantics { +class Symbol; +} // namespace semantics + namespace lower { class AbstractConverter; @@ -49,6 +55,12 @@ void genOpenMPConstruct(AbstractConverter &, pft::Evaluation &, const parser::OpenMPConstruct &); +/// Analyze a declarative OpenMP construct without lowering it. Sets +/// \p ompDeviceCodeFound to true (never back to false) if the construct is a +/// non-host 'declare target' that applies to a function or subroutine. +void analyzeOpenMPDeclarativeConstruct( + AbstractConverter &, pft::Evaluation &, + const parser::OpenMPDeclarativeConstruct &, bool &ompDeviceCodeFound); void genOpenMPDeclarativeConstruct(AbstractConverter &, pft::Evaluation &, const parser::OpenMPDeclarativeConstruct &); int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList); @@ -62,6 +74,21 @@ void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value, mlir::Value, fir::ConvertOp * = nullptr); void removeStoreOp(mlir::Operation *, mlir::Value); + +/// Return the device type of a 'declare target' directive if it applies to at +/// least one function or subroutine, or std::nullopt otherwise. +std::optional +getOpenMPDeclareTargetFunctionDevice( + AbstractConverter &, pft::Evaluation &, + const parser::OpenMPDeclareTargetConstruct &); +/// Return true if the construct is an OpenMP block or loop construct whose +/// directive belongs to llvm::omp::allTargetSet. 
+bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &); + +/// If the operation is an offload module, set its 'requires' attribute from +/// the OpenMP requires flags recorded on the symbol (if non-null). +void genOpenMPRequires(mlir::Operation *, const Fortran::semantics::Symbol *); + } // namespace lower } // namespace Fortran diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -50,6 +50,7 @@ #include "flang/Parser/parse-tree.h" #include "flang/Runtime/iostat.h" #include "flang/Semantics/runtime-type-info.h" +#include "flang/Semantics/symbol.h" 
#include "flang/Semantics/tools.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/IR/PatternMatch.h" @@ -62,6 +63,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/Path.h" +#include #include #define DEBUG_TYPE "flang-lower-bridge" @@ -294,12 +296,15 @@ // that they are available before lowering any function that may use // them. bool hasMainProgram = false; + const Fortran::semantics::Symbol *globalOmpRequiresSymbol = nullptr; for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) { std::visit(Fortran::common::visitors{ [&](Fortran::lower::pft::FunctionLikeUnit &f) { if (f.isMainProgram()) hasMainProgram = true; declareFunction(f); + if (!globalOmpRequiresSymbol) + globalOmpRequiresSymbol = f.getScope().symbol(); }, [&](Fortran::lower::pft::ModuleLikeUnit &m) { lowerModuleDeclScope(m); @@ -307,7 +312,10 @@ m.nestedFunctions) declareFunction(f); }, - [&](Fortran::lower::pft::BlockDataUnit &b) {}, + [&](Fortran::lower::pft::BlockDataUnit &b) { + if (!globalOmpRequiresSymbol) + globalOmpRequiresSymbol = b.symTab.symbol(); + }, [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {}, }, u); @@ -350,6 +358,11 @@ fir::runtime::genEnvironmentDefaults(*builder, toLocation(), bridge.getEnvironmentDefaults()); }); + + // Emit the OpenMP requires module attribute only if device code was found + if (ompDeviceCodeFound) + Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(), + globalOmpRequiresSymbol); } /// Declare a function. 
@@ -2299,10 +2312,16 @@ localSymbols.popScope(); builder->restoreInsertionPoint(insertPt); + + // Flag device code if this construct is an OpenMP target construct + ompDeviceCodeFound = + ompDeviceCodeFound || Fortran::lower::isOpenMPTargetConstruct(omp); } void genFIR(const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) { mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint(); + analyzeOpenMPDeclarativeConstruct(*this, getEval(), ompDecl, + ompDeviceCodeFound); genOpenMPDeclarativeConstruct(*this, getEval(), ompDecl); for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations()) genFIR(e); @@ -4726,6 +4745,10 @@ /// A counter for uniquing names in `literalNamesMap`. std::uint64_t uniqueLitId = 0; + + /// Whether an OpenMP target region or declare target function/subroutine + /// intended for device offloading has been detected (never reset once set) + bool ompDeviceCodeFound = false; }; } // namespace diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -2407,6 +2407,44 @@ reductionDeclSymbols)); } +/// Extract the list of function and variable symbols affected by the given +/// 'declare target' directive and return the intended device type for them. 
+static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo( + Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct, + llvm::SmallVectorImpl &symbolAndClause) { + + // The default device type + mlir::omp::DeclareTargetDeviceType deviceType = + mlir::omp::DeclareTargetDeviceType::any; + const auto &spec = std::get( + declareTargetConstruct.t); + + if (const auto *objectList{ + Fortran::parser::Unwrap(spec.u)}) { + // Case: declare target(func, var1, var2) + gatherFuncAndVarSyms(*objectList, mlir::omp::DeclareTargetCaptureClause::to, + symbolAndClause); + } else if (const auto *clauseList{ + Fortran::parser::Unwrap( + spec.u)}) { + if (clauseList->v.empty()) { + // Case: declare target, implicit capture of function + symbolAndClause.emplace_back( + mlir::omp::DeclareTargetCaptureClause::to, + eval.getOwningProcedure()->getSubprogramSymbol()); + } + + ClauseProcessor cp(converter, *clauseList); + cp.processTo(symbolAndClause); + cp.processLink(symbolAndClause); + cp.processDeviceType(deviceType); + } + + return deviceType; +} + //===----------------------------------------------------------------------===// // genOMP() Code generation helper functions //===----------------------------------------------------------------------===// @@ -3306,32 +3344,8 @@ &declareTargetConstruct) { llvm::SmallVector symbolAndClause; mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); - - // The default capture type - mlir::omp::DeclareTargetDeviceType deviceType = - mlir::omp::DeclareTargetDeviceType::any; - const auto &spec = std::get( - declareTargetConstruct.t); - if (const auto *objectList{ - Fortran::parser::Unwrap(spec.u)}) { - // Case: declare target(func, var1, var2) - gatherFuncAndVarSyms(*objectList, mlir::omp::DeclareTargetCaptureClause::to, - symbolAndClause); - } else if (const auto *clauseList{ - Fortran::parser::Unwrap( - spec.u)}) { - if 
(clauseList->v.empty()) { - // Case: declare target, implicit capture of function - symbolAndClause.emplace_back( - mlir::omp::DeclareTargetCaptureClause::to, - eval.getOwningProcedure()->getSubprogramSymbol()); - } - - ClauseProcessor cp(converter, *clauseList); - cp.processTo(symbolAndClause); - cp.processLink(symbolAndClause); - cp.processDeviceType(deviceType); - } + mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo( + converter, eval, declareTargetConstruct, symbolAndClause); for (const DeclareTargetCapturePair &symClause : symbolAndClause) { mlir::Operation *op = mod.lookupSymbol( @@ -3435,6 +3449,31 @@ ompConstruct.u); } +/// Inspect a declarative OpenMP construct without lowering it and set +/// \p ompDeviceCodeFound if it is a non-host 'declare target' applying to a +/// function or subroutine. The flag is only ever set, never cleared. +void Fortran::lower::analyzeOpenMPDeclarativeConstruct( + Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl, + bool &ompDeviceCodeFound) { + std::visit( + Fortran::common::visitors{ + [&](const Fortran::parser::OpenMPDeclareTargetConstruct &ompReq) { + mlir::omp::DeclareTargetDeviceType targetType = + Fortran::lower::getOpenMPDeclareTargetFunctionDevice( + converter, eval, ompReq) + .value_or(mlir::omp::DeclareTargetDeviceType::host); + + ompDeviceCodeFound = + ompDeviceCodeFound || + targetType != mlir::omp::DeclareTargetDeviceType::host; + }, + [&](const auto &) {}, + }, + ompDecl.u); +} + void Fortran::lower::genOpenMPDeclarativeConstruct( Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, @@ -3460,7 +3499,10 @@ }, [&](const Fortran::parser::OpenMPRequiresConstruct &requiresConstruct) { - TODO(converter.getCurrentLocation(), "OpenMPRequiresConstruct"); + // Requires directives are gathered and processed in semantics and + // then combined in the lowering bridge before triggering codegen + // just once. 
Hence, there is no need to lower each individual + // occurrence here. 
}, [&](const Fortran::parser::OpenMPThreadprivate &threadprivate) { // The directive is lowered when instantiating the variable to @@ -3751,3 +3793,90 @@ } } } + +/// Return the device type of the given 'declare target' directive if it +/// applies to at least one function or subroutine, or std::nullopt otherwise. +std::optional +Fortran::lower::getOpenMPDeclareTargetFunctionDevice( + Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct + &declareTargetConstruct) { + llvm::SmallVector symbolAndClause; + mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo( + converter, eval, declareTargetConstruct, symbolAndClause); + + // Return the device type only if at least one of the targets for the + // directive is a function or subroutine + mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); + for (std::pair + sym : symbolAndClause) { + mlir::Operation *op = + mod.lookupSymbol(converter.mangleName(std::get<1>(sym))); + + // lookupSymbol() returns null if the module has no such symbol (yet). + if (mlir::isa_and_nonnull<mlir::func::FuncOp>(op)) + return deviceType; + } + + return std::nullopt; +} + +/// Return true if \p omp is a block or loop construct whose directive is in +/// llvm::omp::allTargetSet. Other construct kinds are tested as OMPD_unknown. 
+bool Fortran::lower::isOpenMPTargetConstruct( + const Fortran::parser::OpenMPConstruct &omp) { + llvm::omp::Directive dir = llvm::omp::Directive::OMPD_unknown; + if (const auto *block = + std::get_if(&omp.u)) { + const auto &begin = + std::get(block->t); + dir = std::get(begin.t).v; + } else if (const auto *loop = + std::get_if(&omp.u)) { + const auto &begin = + std::get(loop->t); + dir = std::get(begin.t).v; + } + return llvm::omp::allTargetSet.test(dir); +} + +/// If \p mod is an offload module, set its 'requires' attribute from the +/// OpenMP requires flags stored in the details of \p symbol. Only the +/// reverse_offload, unified_address, unified_shared_memory and +/// dynamic_allocators flags are translated; a null \p symbol yields 'none'. +void Fortran::lower::genOpenMPRequires( + mlir::Operation *mod, const Fortran::semantics::Symbol *symbol) { + using MlirRequires = mlir::omp::ClauseRequires; + using SemaRequires = Fortran::semantics::OmpRequiresFlags; + + if (auto offloadMod = + llvm::dyn_cast(mod)) { + SemaRequires semaFlags = SemaRequires::None; + if (symbol) { + Fortran::common::visit( + [&](const auto &details) { + if constexpr (std::is_base_of_v< + Fortran::semantics::WithOmpDeclarative, + std::decay_t>) { + if (details.has_ompRequires()) + semaFlags = *details.ompRequires(); + } + }, + 
symbol->details()); + } + + MlirRequires mlirFlags = MlirRequires::none; + if (semaFlags & SemaRequires::ReverseOffload) + mlirFlags = mlirFlags | MlirRequires::reverse_offload; + if (semaFlags & SemaRequires::UnifiedAddress) + mlirFlags = mlirFlags | MlirRequires::unified_address; + if (semaFlags & SemaRequires::UnifiedSharedMemory) + mlirFlags = mlirFlags | MlirRequires::unified_shared_memory; + if (semaFlags & SemaRequires::DynamicAllocators) + mlirFlags = mlirFlags | MlirRequires::dynamic_allocators; + + offloadMod.setRequires(mlirFlags); + } +} diff --git a/flang/test/Lower/OpenMP/requires-notarget.f90 b/flang/test/Lower/OpenMP/requires-notarget.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/requires-notarget.f90 @@ -0,0 +1,11 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s + +! This test checks that requires lowering into MLIR skips creating the +! omp.requires attribute with target-related clauses if there are no device +! functions in the compilation unit + +!CHECK: module attributes { +!CHECK-NOT: omp.requires +program requires + !$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst) +end program requires diff --git a/flang/test/Lower/OpenMP/requires.f90 b/flang/test/Lower/OpenMP/requires.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/requires.f90 @@ -0,0 +1,13 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s + +! This test checks the lowering of requires into MLIR + +!CHECK: module attributes { +!CHECK-SAME: omp.requires = #omp +program requires + !$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst) +end program requires + +subroutine f + !$omp declare target +end subroutine f