diff --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h --- a/flang/include/flang/Lower/OpenMP.h +++ b/flang/include/flang/Lower/OpenMP.h @@ -34,6 +34,10 @@ struct OmpClauseList; } // namespace parser +namespace semantics { +class Symbol; +} // namespace semantics + namespace lower { class AbstractConverter; @@ -62,6 +66,13 @@ void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value, mlir::Value, fir::ConvertOp * = nullptr); void removeStoreOp(mlir::Operation *, mlir::Value); + +bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &); +bool isOpenMPDeviceDeclareTarget(Fortran::lower::AbstractConverter &, + Fortran::lower::pft::Evaluation &, + const parser::OpenMPDeclarativeConstruct &); +void genOpenMPRequires(mlir::Operation *, const Fortran::semantics::Symbol *); + } // namespace lower } // namespace Fortran diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -50,6 +50,7 @@ #include "flang/Parser/parse-tree.h" #include "flang/Runtime/iostat.h" #include "flang/Semantics/runtime-type-info.h" +#include "flang/Semantics/symbol.h" #include "flang/Semantics/tools.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/IR/PatternMatch.h" @@ -294,12 +295,15 @@ // that they are available before lowering any function that may use // them. bool hasMainProgram = false; + const Fortran::semantics::Symbol *globalOmpRequiresSymbol = nullptr; for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) { std::visit(Fortran::common::visitors{ [&](Fortran::lower::pft::FunctionLikeUnit &f) { if (f.isMainProgram()) hasMainProgram = true; declareFunction(f); + if (!globalOmpRequiresSymbol) + globalOmpRequiresSymbol = f.getScope().symbol(); }, [&](Fortran::lower::pft::ModuleLikeUnit &m) { lowerModuleDeclScope(m); @@ -307,7 +311,10 @@ m.nestedFunctions) declareFunction(f); }, - [&](Fortran::lower::pft::BlockDataUnit &b) {}, + [&](Fortran::lower::pft::BlockDataUnit &b) { + if (!globalOmpRequiresSymbol) + globalOmpRequiresSymbol = b.symTab.symbol(); + }, [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {}, }, u); @@ -352,6 +359,7 @@ }); finalizeOpenACCLowering(); + finalizeOpenMPLowering(globalOmpRequiresSymbol); } /// Declare a function. @@ -2347,10 +2355,19 @@ localSymbols.popScope(); builder->restoreInsertionPoint(insertPt); + + // Register if a target region was found + ompDeviceCodeFound = + ompDeviceCodeFound || Fortran::lower::isOpenMPTargetConstruct(omp); } void genFIR(const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) { mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint(); + // Register if a declare target construct intended for a target device was + // found + ompDeviceCodeFound = + ompDeviceCodeFound || + Fortran::lower::isOpenMPDeviceDeclareTarget(*this, getEval(), ompDecl); genOpenMPDeclarativeConstruct(*this, getEval(), ompDecl); for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations()) genFIR(e); @@ -4758,6 +4775,16 @@ accRoutineInfos); } + /// Performing OpenMP lowering actions that were deferred to the end of + /// lowering. + void finalizeOpenMPLowering( + const Fortran::semantics::Symbol *globalOmpRequiresSymbol) { + // Set the module attribute related to OpenMP requires directives + if (ompDeviceCodeFound) + Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(), + globalOmpRequiresSymbol); + } + //===--------------------------------------------------------------------===// Fortran::lower::LoweringBridge &bridge; @@ -4804,6 +4831,10 @@ /// Deferred OpenACC routine attachment. Fortran::lower::AccRoutineInfoMappingList accRoutineInfos; + + /// Whether an OpenMP target region or declare target function/subroutine + /// intended for device offloading has been detected + bool ompDeviceCodeFound = false; }; } // namespace diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -78,9 +78,7 @@ static void gatherFuncAndVarSyms( const Fortran::parser::OmpObjectList &objList, mlir::omp::DeclareTargetCaptureClause clause, - llvm::SmallVectorImpl> - &symbolAndClause) { + llvm::SmallVectorImpl &symbolAndClause) { for (const Fortran::parser::OmpObject &ompObject : objList.v) { Fortran::common::visit( Fortran::common::visitors{ @@ -2453,6 +2451,71 @@ reductionDeclSymbols)); } +/// Extract the list of function and variable symbols affected by the given +/// 'declare target' directive and return the intended device type for them. +static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo( + Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct, + llvm::SmallVectorImpl &symbolAndClause) { + + // The default capture type + mlir::omp::DeclareTargetDeviceType deviceType = + mlir::omp::DeclareTargetDeviceType::any; + const auto &spec = std::get( + declareTargetConstruct.t); + + if (const auto *objectList{ + Fortran::parser::Unwrap(spec.u)}) { + // Case: declare target(func, var1, var2) + gatherFuncAndVarSyms(*objectList, mlir::omp::DeclareTargetCaptureClause::to, + symbolAndClause); + } else if (const auto *clauseList{ + Fortran::parser::Unwrap( + spec.u)}) { + if (clauseList->v.empty()) { + // Case: declare target, implicit capture of function + symbolAndClause.emplace_back( + mlir::omp::DeclareTargetCaptureClause::to, + eval.getOwningProcedure()->getSubprogramSymbol()); + } + + ClauseProcessor cp(converter, *clauseList); + cp.processTo(symbolAndClause); + cp.processLink(symbolAndClause); + cp.processDeviceType(deviceType); + cp.processTODO( + converter.getCurrentLocation(), + llvm::omp::Directive::OMPD_declare_target); + } + + return deviceType; +} + +static std::optional +getDeclareTargetFunctionDevice( + Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct + &declareTargetConstruct) { + llvm::SmallVector symbolAndClause; + mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo( + converter, eval, declareTargetConstruct, symbolAndClause); + + // Return the device type only if at least one of the targets for the + // directive is a function or subroutine + mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); + for (const DeclareTargetCapturePair &symClause : symbolAndClause) { + mlir::Operation *op = mod.lookupSymbol( + converter.mangleName(std::get(symClause))); + + if (mlir::isa(op)) + return deviceType; + } + + return std::nullopt; +} + //===----------------------------------------------------------------------===// // genOMP() Code generation helper functions //===----------------------------------------------------------------------===// @@ -2973,35 +3036,8 @@ &declareTargetConstruct) { llvm::SmallVector symbolAndClause; mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); - - // The default capture type - mlir::omp::DeclareTargetDeviceType deviceType = - mlir::omp::DeclareTargetDeviceType::any; - const auto &spec = std::get( - declareTargetConstruct.t); - if (const auto *objectList{ - Fortran::parser::Unwrap(spec.u)}) { - // Case: declare target(func, var1, var2) - gatherFuncAndVarSyms(*objectList, mlir::omp::DeclareTargetCaptureClause::to, - symbolAndClause); - } else if (const auto *clauseList{ - Fortran::parser::Unwrap( - spec.u)}) { - if (clauseList->v.empty()) { - // Case: declare target, implicit capture of function - symbolAndClause.emplace_back( - mlir::omp::DeclareTargetCaptureClause::to, - eval.getOwningProcedure()->getSubprogramSymbol()); - } - - ClauseProcessor cp(converter, *clauseList); - cp.processTo(symbolAndClause); - cp.processLink(symbolAndClause); - cp.processDeviceType(deviceType); - cp.processTODO( - converter.getCurrentLocation(), - llvm::omp::Directive::OMPD_declare_target); - } + mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo( + converter, eval, declareTargetConstruct, symbolAndClause); for (const DeclareTargetCapturePair &symClause : symbolAndClause) { mlir::Operation *op = mod.lookupSymbol( @@ -3130,7 +3166,10 @@ }, [&](const Fortran::parser::OpenMPRequiresConstruct &requiresConstruct) { - TODO(converter.getCurrentLocation(), "OpenMPRequiresConstruct"); + // Requires directives are gathered and processed in semantics and + // then combined in the lowering bridge before triggering codegen + // just once. Hence, there is no need to lower each individual + // occurrence here. }, [&](const Fortran::parser::OpenMPThreadprivate &threadprivate) { // The directive is lowered when instantiating the variable to @@ -3444,3 +3483,72 @@ } } } + +bool Fortran::lower::isOpenMPTargetConstruct( + const Fortran::parser::OpenMPConstruct &omp) { + llvm::omp::Directive dir = llvm::omp::Directive::OMPD_unknown; + if (const auto *block = + std::get_if(&omp.u)) { + const auto &begin = + std::get(block->t); + dir = std::get(begin.t).v; + } else if (const auto *loop = + std::get_if(&omp.u)) { + const auto &begin = + std::get(loop->t); + dir = std::get(begin.t).v; + } + return llvm::omp::allTargetSet.test(dir); +} + +bool Fortran::lower::isOpenMPDeviceDeclareTarget( + Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) { + return std::visit( + Fortran::common::visitors{ + [&](const Fortran::parser::OpenMPDeclareTargetConstruct &ompReq) { + mlir::omp::DeclareTargetDeviceType targetType = + getDeclareTargetFunctionDevice(converter, eval, ompReq) + .value_or(mlir::omp::DeclareTargetDeviceType::host); + return targetType != mlir::omp::DeclareTargetDeviceType::host; + }, + [&](const auto &) { return false; }, + }, + ompDecl.u); +} + +void Fortran::lower::genOpenMPRequires( + mlir::Operation *mod, const Fortran::semantics::Symbol *symbol) { + using MlirRequires = mlir::omp::ClauseRequires; + using SemaRequires = Fortran::semantics::WithOmpDeclarative::RequiresFlag; + + if (auto offloadMod = + llvm::dyn_cast(mod)) { + Fortran::semantics::WithOmpDeclarative::RequiresFlags semaFlags; + if (symbol) { + Fortran::common::visit( + [&](const auto &details) { + if constexpr (std::is_base_of_v< + Fortran::semantics::WithOmpDeclarative, + std::decay_t>) { + if (details.has_ompRequires()) + semaFlags = *details.ompRequires(); + } + }, + symbol->details()); + } + + MlirRequires mlirFlags = MlirRequires::none; + if (semaFlags.test(SemaRequires::ReverseOffload)) + mlirFlags = mlirFlags | MlirRequires::reverse_offload; + if (semaFlags.test(SemaRequires::UnifiedAddress)) + mlirFlags = mlirFlags | MlirRequires::unified_address; + if (semaFlags.test(SemaRequires::UnifiedSharedMemory)) + mlirFlags = mlirFlags | MlirRequires::unified_shared_memory; + if (semaFlags.test(SemaRequires::DynamicAllocators)) + mlirFlags = mlirFlags | MlirRequires::dynamic_allocators; + + offloadMod.setRequires(mlirFlags); + } +} diff --git a/flang/test/Lower/OpenMP/Todo/requires-unnamed-common.f90 b/flang/test/Lower/OpenMP/Todo/requires-unnamed-common.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/Todo/requires-unnamed-common.f90 @@ -0,0 +1,23 @@ +! This test checks the lowering of REQUIRES inside of an unnamed BLOCK DATA. +! The symbol of the `symTab` scope of the `BlockDataUnit` PFT node is null in +! this case, resulting in the inability to store the REQUIRES flags gathered in +! it. + +! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s +! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s +! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s +! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-fir %s -o - | FileCheck %s +! XFAIL: * + +!CHECK: module attributes { +!CHECK-SAME: omp.requires = #omp +block data + !$omp requires unified_shared_memory + integer :: x + common /block/ x + data x / 10 / +end + +subroutine f + !$omp declare target +end subroutine f diff --git a/flang/test/Lower/OpenMP/requires-common.f90 b/flang/test/Lower/OpenMP/requires-common.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/requires-common.f90 @@ -0,0 +1,19 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s +! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s +! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s +! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-fir %s -o - | FileCheck %s + +! This test checks the lowering of requires into MLIR + +!CHECK: module attributes { +!CHECK-SAME: omp.requires = #omp +block data init + !$omp requires unified_shared_memory + integer :: x + common /block/ x + data x / 10 / +end + +subroutine f + !$omp declare target +end subroutine f diff --git a/flang/test/Lower/OpenMP/requires-notarget.f90 b/flang/test/Lower/OpenMP/requires-notarget.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/requires-notarget.f90 @@ -0,0 +1,14 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s +! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s +! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s +! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-fir %s -o - | FileCheck %s + +! This test checks that requires lowering into MLIR skips creating the +! omp.requires attribute with target-related clauses if there are no device +! functions in the compilation unit + +!CHECK: module attributes { +!CHECK-NOT: omp.requires +program requires + !$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst) +end program requires diff --git a/flang/test/Lower/OpenMP/requires.f90 b/flang/test/Lower/OpenMP/requires.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/requires.f90 @@ -0,0 +1,14 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s +! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s +! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s +! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-fir %s -o - | FileCheck %s + +! This test checks the lowering of requires into MLIR + +!CHECK: module attributes { +!CHECK-SAME: omp.requires = #omp +program requires + !$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst) + !$omp target + !$omp end target +end program requires