diff --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h --- a/flang/include/flang/Lower/OpenMP.h +++ b/flang/include/flang/Lower/OpenMP.h @@ -13,13 +13,9 @@ #ifndef FORTRAN_LOWER_OPENMP_H #define FORTRAN_LOWER_OPENMP_H +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include -namespace mlir { -class Value; -class Operation; -} // namespace mlir - namespace fir { class FirOpBuilder; class ConvertOp; @@ -29,6 +25,7 @@ namespace parser { struct OpenMPConstruct; struct OpenMPDeclarativeConstruct; +struct OpenMPDeclareTargetConstruct; struct OmpEndLoopDirective; struct OmpClauseList; } // namespace parser @@ -56,6 +53,17 @@ void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value, mlir::Value, fir::ConvertOp * = nullptr); void removeStoreOp(mlir::Operation *, mlir::Value); + +std::optional +getOpenMPDeclareTargetFunctionDevice( + Fortran::lower::AbstractConverter &, Fortran::lower::pft::Evaluation &, + const Fortran::parser::OpenMPDeclareTargetConstruct &); +bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &); + +mlir::omp::ClauseRequires +extractOpenMPRequiresClauses(const Fortran::parser::OmpClauseList &); +void genOpenMPRequires(mlir::Operation *, mlir::omp::ClauseRequires); + } // namespace lower } // namespace Fortran diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -266,7 +266,8 @@ public: explicit FirConverter(Fortran::lower::LoweringBridge &bridge) : Fortran::lower::AbstractConverter(bridge.getLoweringOptions()), - bridge{bridge}, foldingContext{bridge.createFoldingContext()} {} + bridge{bridge}, foldingContext{bridge.createFoldingContext()}, + ompRequiresFlags{mlir::omp::ClauseRequires::none} {} virtual ~FirConverter() = default; /// Convert the PFT to FIR. @@ -343,6 +344,11 @@ fir::runtime::genEnvironmentDefaults(*builder, toLocation(), bridge.getEnvironmentDefaults()); }); + + // Set the module attribute related to OpenMP requires directives + if (ompDeviceCodeFound) + Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(), + ompRequiresFlags); } /// Declare a function. @@ -2053,10 +2059,43 @@ localSymbols.popScope(); builder->restoreInsertionPoint(insertPt); + + // Register if a target region was found + ompDeviceCodeFound = + ompDeviceCodeFound || Fortran::lower::isOpenMPTargetConstruct(omp); + } + + /// Extract information from OpenMP declarative constructs + void analyzeOpenMPDeclarative( + const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) { + std::visit( + Fortran::common::visitors{ + [&](const Fortran::parser::OpenMPRequiresConstruct &ompReq) { + auto requiresFlags = Fortran::lower::extractOpenMPRequiresClauses( + std::get(ompReq.t)); + + if (requiresFlags != mlir::omp::ClauseRequires::none) + ompRequiresFlags = ompRequiresFlags | requiresFlags; + }, + [&](const Fortran::parser::OpenMPDeclareTargetConstruct &ompReq) { + auto targetType = + Fortran::lower::getOpenMPDeclareTargetFunctionDevice( + *this, getEval(), ompReq); + + ompDeviceCodeFound = + ompDeviceCodeFound || + (targetType && + *targetType != mlir::omp::DeclareTargetDeviceType::host); + }, + // Add other OpenMP declarative constructs currently skipped + [&](const auto &) {}, + }, + ompDecl.u); } void genFIR(const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) { mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint(); + analyzeOpenMPDeclarative(ompDecl); genOpenMPDeclarativeConstruct(*this, getEval(), ompDecl); for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations()) genFIR(e); @@ -4098,6 +4137,13 @@ /// Tuple of host associated variables mlir::Value hostAssocTuple; + + /// OpenMP Requires flags + mlir::omp::ClauseRequires ompRequiresFlags; + + /// Whether an OpenMP target region or declare target function/subroutine + /// intended for device offloading has been detected + bool ompDeviceCodeFound = false; }; } // namespace diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -24,7 +24,9 @@ #include "flang/Semantics/tools.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" +#include using namespace mlir; @@ -2262,14 +2264,14 @@ converter.bindSymbol(sym, symThreadprivateExv); } -void handleDeclareTarget(Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenMPDeclareTargetConstruct - &declareTargetConstruct) { - llvm::SmallVector, - 0> - symbolAndClause; +/// Extract the list of function and variable symbols affected by the given +/// 'declare target' directive and return the intended device type for them. +static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo( + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct, + SmallVectorImpl> &symbolAndClause) { + // Gather the symbols and clauses auto findFuncAndVarSyms = [&](const Fortran::parser::OmpObjectList &objList, mlir::omp::DeclareTargetCaptureClause clause) { for (const auto &ompObject : objList.v) { @@ -2289,12 +2291,11 @@ } }; + // The default capture type + auto deviceType = Fortran::parser::OmpDeviceTypeClause::Type::Any; const auto &spec{std::get( declareTargetConstruct.t)}; - auto mod = converter.getFirOpBuilder().getModule(); - // The default capture type - auto deviceType = Fortran::parser::OmpDeviceTypeClause::Type::Any; if (const auto *objectList{ Fortran::parser::Unwrap(spec.u)}) { // Case: declare target(func, var1, var2) @@ -2330,6 +2331,28 @@ } } + switch (deviceType) { + case Fortran::parser::OmpDeviceTypeClause::Type::Any: + return mlir::omp::DeclareTargetDeviceType::any; + case Fortran::parser::OmpDeviceTypeClause::Type::Host: + return mlir::omp::DeclareTargetDeviceType::host; + case Fortran::parser::OmpDeviceTypeClause::Type::Nohost: + return mlir::omp::DeclareTargetDeviceType::nohost; + } +} + +void handleDeclareTarget(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct + &declareTargetConstruct) { + llvm::SmallVector, + 0> + symbolAndClause; + auto deviceType = + getDeclareTargetInfo(eval, declareTargetConstruct, symbolAndClause); + + auto mod = converter.getFirOpBuilder().getModule(); for (auto sym : symbolAndClause) { auto *op = mod.lookupSymbol(converter.mangleName(std::get<1>(sym))); @@ -2339,32 +2362,19 @@ converter.getCurrentLocation(), "Attempt to apply declare target on unsupproted operation"); - mlir::omp::DeclareTargetDeviceType newDeviceType; - switch (deviceType) { - case Fortran::parser::OmpDeviceTypeClause::Type::Nohost: - newDeviceType = mlir::omp::DeclareTargetDeviceType::nohost; - break; - case Fortran::parser::OmpDeviceTypeClause::Type::Host: - newDeviceType = mlir::omp::DeclareTargetDeviceType::host; - break; - case Fortran::parser::OmpDeviceTypeClause::Type::Any: - newDeviceType = mlir::omp::DeclareTargetDeviceType::any; - break; - } - // The function or global already has a declare target applied to it, // very likely through implicit capture (usage in another declare // target function/subroutine). It should be marked as any if it has // been assigned both host and nohost, else we skip, as there is no // change if (declareTargetOp.isDeclareTarget()) { - if (declareTargetOp.getDeclareTargetDeviceType() != newDeviceType) + if (declareTargetOp.getDeclareTargetDeviceType() != deviceType) declareTargetOp.setDeclareTarget( mlir::omp::DeclareTargetDeviceType::any, std::get<0>(sym)); continue; } - declareTargetOp.setDeclareTarget(newDeviceType, std::get<0>(sym)); + declareTargetOp.setDeclareTarget(deviceType, std::get<0>(sym)); } } @@ -2394,7 +2404,11 @@ }, [&](const Fortran::parser::OpenMPRequiresConstruct &requiresConstruct) { - TODO(converter.getCurrentLocation(), "OpenMPRequiresConstruct"); + // Requires directives are analyzed before any statements are + // lowered. Then, the result of combining the set of clauses of all + // requires directives present in the compilation unit is used to + // emit code, so no code is emitted independently for each + // "requires" instance. }, [&](const Fortran::parser::OpenMPThreadprivate &threadprivate) { // The directive is lowered when instantiating the variable to @@ -2616,3 +2630,81 @@ } } } + +std::optional +Fortran::lower::getOpenMPDeclareTargetFunctionDevice( + Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct + &declareTargetConstruct) { + llvm::SmallVector, + 0> + symbolAndClause; + auto deviceType = + getDeclareTargetInfo(eval, declareTargetConstruct, symbolAndClause); + + // Return the device type only if at least one of the targets for the + // directive is a function or subroutine + auto mod = converter.getFirOpBuilder().getModule(); + for (auto sym : symbolAndClause) { + auto *op = mod.lookupSymbol(converter.mangleName(std::get<1>(sym))); + + if (mlir::isa(op)) + return deviceType; + } + + return std::nullopt; +} + +bool Fortran::lower::isOpenMPTargetConstruct( + const Fortran::parser::OpenMPConstruct &omp) { + if (const auto *blockDir = + std::get_if(&omp.u)) { + const auto &beginBlockDir{ + std::get(blockDir->t)}; + const auto &beginDir{ + std::get(beginBlockDir.t)}; + + switch (beginDir.v) { + case llvm::omp::Directive::OMPD_target: + case llvm::omp::Directive::OMPD_target_parallel: + case llvm::omp::Directive::OMPD_target_parallel_do: + case llvm::omp::Directive::OMPD_target_parallel_do_simd: + case llvm::omp::Directive::OMPD_target_simd: + case llvm::omp::Directive::OMPD_target_teams: + case llvm::omp::Directive::OMPD_target_teams_distribute: + case llvm::omp::Directive::OMPD_target_teams_distribute_simd: + return true; + default: + break; + } + } + + return false; +} + +omp::ClauseRequires Fortran::lower::extractOpenMPRequiresClauses( + const Fortran::parser::OmpClauseList &clauseList) { + using omp::ClauseRequires, Fortran::parser::OmpClause; + auto requiresFlags = ClauseRequires::none; + + for (const auto &clause : clauseList.v) { + if (std::get_if(&clause.u)) + requiresFlags = requiresFlags | ClauseRequires::dynamic_allocators; + else if (std::get_if(&clause.u)) + requiresFlags = requiresFlags | ClauseRequires::reverse_offload; + else if (std::get_if(&clause.u)) + requiresFlags = requiresFlags | ClauseRequires::unified_address; + else if (std::get_if(&clause.u)) + requiresFlags = requiresFlags | ClauseRequires::unified_shared_memory; + } + + return requiresFlags; +} + +void Fortran::lower::genOpenMPRequires(Operation *mod, + omp::ClauseRequires flags) { + if (auto offloadMod = llvm::dyn_cast(mod)) + offloadMod.setRequires(flags); +} diff --git a/flang/test/Lower/OpenMP/requires-notarget.f90 b/flang/test/Lower/OpenMP/requires-notarget.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/requires-notarget.f90 @@ -0,0 +1,11 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s + +! This test checks that requires lowering into MLIR skips creating the +! omp.requires attribute with target-related clauses if there are no device +! functions in the compilation unit + +!CHECK: module attributes { +!CHECK-NOT: omp.requires +program requires + !$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst) +end program requires diff --git a/flang/test/Lower/OpenMP/requires.f90 b/flang/test/Lower/OpenMP/requires.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/requires.f90 @@ -0,0 +1,13 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s + +! This test checks the lowering of requires into MLIR + +!CHECK: module attributes { +!CHECK-SAME: omp.requires = #omp +program requires + !$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst) +end program requires + +subroutine f + !$omp declare target +end subroutine f diff --git a/mlir/lib/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/CMakeLists.txt --- a/mlir/lib/Dialect/OpenMP/CMakeLists.txt +++ b/mlir/lib/Dialect/OpenMP/CMakeLists.txt @@ -12,4 +12,5 @@ LINK_LIBS PUBLIC MLIRIR MLIRLLVMDialect + MLIRFuncDialect )