diff --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h --- a/flang/include/flang/Lower/OpenMP.h +++ b/flang/include/flang/Lower/OpenMP.h @@ -13,13 +13,9 @@ #ifndef FORTRAN_LOWER_OPENMP_H #define FORTRAN_LOWER_OPENMP_H +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include -namespace mlir { -class Value; -class Operation; -} // namespace mlir - namespace fir { class FirOpBuilder; class ConvertOp; @@ -29,6 +25,7 @@ namespace parser { struct OpenMPConstruct; struct OpenMPDeclarativeConstruct; +struct OpenMPDeclareTargetConstruct; struct OmpEndLoopDirective; struct OmpClauseList; } // namespace parser @@ -56,6 +53,17 @@ void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value, mlir::Value, fir::ConvertOp * = nullptr); void removeStoreOp(mlir::Operation *, mlir::Value); + +std::optional +getOpenMPDeclareTargetFunctionDevice( + Fortran::lower::AbstractConverter &, Fortran::lower::pft::Evaluation &, + const Fortran::parser::OpenMPDeclareTargetConstruct &); +std::optional +extractOpenMPRequiresClauses(const Fortran::parser::OmpClauseList &, + mlir::omp::ClauseRequires &); + +bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &); + } // namespace lower } // namespace Fortran diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -52,6 +52,7 @@ #include "flang/Semantics/runtime-type-info.h" #include "flang/Semantics/tools.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Parser/Parser.h" #include "mlir/Transforms/RegionUtils.h" @@ -266,7 +267,8 @@ public: explicit FirConverter(Fortran::lower::LoweringBridge &bridge) : Fortran::lower::AbstractConverter(bridge.getLoweringOptions()), - bridge{bridge}, foldingContext{bridge.createFoldingContext()} {} + bridge{bridge}, foldingContext{bridge.createFoldingContext()}, + ompRequiresFlags{mlir::omp::ClauseRequires::none} {} virtual ~FirConverter() = default; /// Convert the PFT to FIR. @@ -343,6 +345,16 @@ fir::runtime::genEnvironmentDefaults(*builder, toLocation(), bridge.getEnvironmentDefaults()); }); + + // Set the module attributes related to OpenMP requires directives + if (auto mod = llvm::dyn_cast( + getModuleOp().getOperation())) { + if (ompDeviceCodeFound) + mod.setRequires(ompRequiresFlags); + + if (ompAtomicDefaultMemOrder) + mod.setAtomicDefaultMemOrder(*ompAtomicDefaultMemOrder); + } } /// Declare a function. @@ -2053,10 +2065,57 @@ localSymbols.popScope(); builder->restoreInsertionPoint(insertPt); + + // Register if a target region was found + ompDeviceCodeFound = + ompDeviceCodeFound || Fortran::lower::isOpenMPTargetConstruct(omp); + } + + /// Extract information from OpenMP declarative constructs + void analyzeOpenMPDeclarative( + const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) { + auto analyzeRequires = + [&](const Fortran::parser::OpenMPRequiresConstruct &ompReq) { + using mlir::omp::ClauseRequires; + + mlir::omp::ClauseRequires requiresFlags; + auto atomicDefaultMemOrder = + Fortran::lower::extractOpenMPRequiresClauses( + std::get(ompReq.t), + requiresFlags); + + if (requiresFlags != ClauseRequires::none) + ompRequiresFlags = ompRequiresFlags | requiresFlags; + + if (atomicDefaultMemOrder) + ompAtomicDefaultMemOrder = atomicDefaultMemOrder; + }; + + auto analyzeDeclareTarget = + [&](const Fortran::parser::OpenMPDeclareTargetConstruct &ompReq) { + auto targetType = + Fortran::lower::getOpenMPDeclareTargetFunctionDevice( + *this, getEval(), ompReq); + + ompDeviceCodeFound = + ompDeviceCodeFound || + (targetType && + *targetType != mlir::omp::DeclareTargetDeviceType::host); + }; + + std::visit( + Fortran::common::visitors{ + analyzeRequires, + analyzeDeclareTarget, + // Add other OpenMP declarative constructs currently skipped + [&](const auto &) {}, + }, + ompDecl.u); } void genFIR(const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) { mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint(); + analyzeOpenMPDeclarative(ompDecl); genOpenMPDeclarativeConstruct(*this, getEval(), ompDecl); for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations()) genFIR(e); @@ -4098,6 +4157,17 @@ /// Tuple of host associated variables mlir::Value hostAssocTuple; + + /// OpenMP Requires flags + mlir::omp::ClauseRequires ompRequiresFlags; + + /// OpenMP Default memory order for atomic operations, as defined by a + /// 'requires' directive + std::optional ompAtomicDefaultMemOrder; + + /// Whether an OpenMP target region or declare target function/subroutine + /// intended for device offloading has been detected + bool ompDeviceCodeFound = false; }; } // namespace diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -24,7 +24,9 @@ #include "flang/Semantics/tools.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" +#include using namespace mlir; @@ -2262,14 +2264,14 @@ converter.bindSymbol(sym, symThreadprivateExv); } -void handleDeclareTarget(Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenMPDeclareTargetConstruct - &declareTargetConstruct) { - llvm::SmallVector, - 0> - symbolAndClause; +/// Extract the list of function and variable symbols affected by the given +/// 'declare target' directive and return the intended device type for them. +static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo( + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct, + SmallVectorImpl> &symbolAndClause) { + // Gather the symbols and clauses auto findFuncAndVarSyms = [&](const Fortran::parser::OmpObjectList &objList, mlir::omp::DeclareTargetCaptureClause clause) { for (const auto &ompObject : objList.v) { @@ -2289,12 +2291,11 @@ } }; + // The default capture type + auto deviceType = Fortran::parser::OmpDeviceTypeClause::Type::Any; const auto &spec{std::get( declareTargetConstruct.t)}; - auto mod = converter.getFirOpBuilder().getModule(); - // The default capture type - auto deviceType = Fortran::parser::OmpDeviceTypeClause::Type::Any; if (const auto *objectList{ Fortran::parser::Unwrap(spec.u)}) { // Case: declare target(func, var1, var2) @@ -2330,22 +2331,31 @@ } } + switch (deviceType) { + case Fortran::parser::OmpDeviceTypeClause::Type::Any: + return mlir::omp::DeclareTargetDeviceType::any; + case Fortran::parser::OmpDeviceTypeClause::Type::Host: + return mlir::omp::DeclareTargetDeviceType::host; + case Fortran::parser::OmpDeviceTypeClause::Type::Nohost: + return mlir::omp::DeclareTargetDeviceType::nohost; + } +} + +void handleDeclareTarget(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct + &declareTargetConstruct) { + llvm::SmallVector, + 0> + symbolAndClause; + auto deviceType = + getDeclareTargetInfo(eval, declareTargetConstruct, symbolAndClause); + + auto mod = converter.getFirOpBuilder().getModule(); for (auto sym : symbolAndClause) { auto *op = mod.lookupSymbol(converter.mangleName(std::get<1>(sym))); - mlir::omp::DeclareTargetDeviceType newDeviceType; - switch (deviceType) { - case Fortran::parser::OmpDeviceTypeClause::Type::Nohost: - newDeviceType = mlir::omp::DeclareTargetDeviceType::nohost; - break; - case Fortran::parser::OmpDeviceTypeClause::Type::Host: - newDeviceType = mlir::omp::DeclareTargetDeviceType::host; - break; - case Fortran::parser::OmpDeviceTypeClause::Type::Any: - newDeviceType = mlir::omp::DeclareTargetDeviceType::any; - break; - } - // The function or global already has a declare target applied to it, // very likely through implicit capture (usage in another declare // target function/subroutine). It should be marked as any if it has @@ -2353,13 +2363,13 @@ // change if (mlir::omp::OpenMPDialect::isDeclareTarget(op)) { if (mlir::omp::OpenMPDialect::getDeclareTargetDeviceType(op) != - newDeviceType) + deviceType) mlir::omp::OpenMPDialect::setDeclareTarget( op, mlir::omp::DeclareTargetDeviceType::any, std::get<0>(sym)); continue; } - mlir::omp::OpenMPDialect::setDeclareTarget(op, newDeviceType, + mlir::omp::OpenMPDialect::setDeclareTarget(op, deviceType, std::get<0>(sym)); } } @@ -2390,7 +2400,11 @@ }, [&](const Fortran::parser::OpenMPRequiresConstruct &requiresConstruct) { - TODO(converter.getCurrentLocation(), "OpenMPRequiresConstruct"); + // Requires directives are analyzed before any statements are + // lowered. Then, the result of combining the set of clauses of all + // requires directives present in the compilation unit is used to + // emit code, so no code is emitted independently for each + // "requires" instance. }, [&](const Fortran::parser::OpenMPThreadprivate &threadprivate) { // The directive is lowered when instantiating the variable to @@ -2612,3 +2626,97 @@ } } } + +std::optional +Fortran::lower::getOpenMPDeclareTargetFunctionDevice( + Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct + &declareTargetConstruct) { + llvm::SmallVector, + 0> + symbolAndClause; + auto deviceType = + getDeclareTargetInfo(eval, declareTargetConstruct, symbolAndClause); + + // Return the device type only if at least one of the targets for the + // directive is a function or subroutine + auto mod = converter.getFirOpBuilder().getModule(); + for (auto sym : symbolAndClause) { + auto *op = mod.lookupSymbol(converter.mangleName(std::get<1>(sym))); + + if (mlir::isa(op)) + return deviceType; + } + + return std::nullopt; +} + +std::optional +Fortran::lower::extractOpenMPRequiresClauses( + const Fortran::parser::OmpClauseList &clauseList, + omp::ClauseRequires &requiresFlags) { + std::optional atomicDefaultMemOrder; + requiresFlags = omp::ClauseRequires::none; + + for (const auto &clause : clauseList.v) { + if (const auto &atomicClause = + std::get_if( + &clause.u)) { + switch (atomicClause->v.v) { + case Fortran::parser::OmpAtomicDefaultMemOrderClause::Type::SeqCst: + atomicDefaultMemOrder = omp::ClauseMemoryOrderKind::Seq_cst; + break; + case Fortran::parser::OmpAtomicDefaultMemOrderClause::Type::AcqRel: + atomicDefaultMemOrder = omp::ClauseMemoryOrderKind::Acq_rel; + break; + case Fortran::parser::OmpAtomicDefaultMemOrderClause::Type::Relaxed: + atomicDefaultMemOrder = omp::ClauseMemoryOrderKind::Relaxed; + break; + } + } else if (std::get_if( + &clause.u)) { + requiresFlags = requiresFlags | omp::ClauseRequires::dynamic_allocators; + } else if (std::get_if( + &clause.u)) { + requiresFlags = requiresFlags | omp::ClauseRequires::reverse_offload; + } else if (std::get_if( + &clause.u)) { + requiresFlags = requiresFlags | omp::ClauseRequires::unified_address; + } else if (std::get_if( + &clause.u)) { + requiresFlags = + requiresFlags | omp::ClauseRequires::unified_shared_memory; + } + } + + return atomicDefaultMemOrder; +} + +bool Fortran::lower::isOpenMPTargetConstruct( + const Fortran::parser::OpenMPConstruct &omp) { + if (const auto *blockDir = + std::get_if(&omp.u)) { + const auto &beginBlockDir{ + std::get(blockDir->t)}; + const auto &beginDir{ + std::get(beginBlockDir.t)}; + + switch (beginDir.v) { + case llvm::omp::Directive::OMPD_target: + case llvm::omp::Directive::OMPD_target_parallel: + case llvm::omp::Directive::OMPD_target_parallel_do: + case llvm::omp::Directive::OMPD_target_parallel_do_simd: + case llvm::omp::Directive::OMPD_target_simd: + case llvm::omp::Directive::OMPD_target_teams: + case llvm::omp::Directive::OMPD_target_teams_distribute: + case llvm::omp::Directive::OMPD_target_teams_distribute_simd: + return true; + default: + break; + } + } + + return false; +} diff --git a/flang/test/Lower/OpenMP/omp-declare-target-data.f90 b/flang/test/Lower/OpenMP/omp-declare-target-data.f90 --- a/flang/test/Lower/OpenMP/omp-declare-target-data.f90 +++ b/flang/test/Lower/OpenMP/omp-declare-target-data.f90 @@ -1,7 +1,8 @@ !RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s !RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-device %s -o - | FileCheck %s -!RUN: %flang_fc1 -emit-llvm-bc -fopenmp -o %t.bc %s | llvm-dis %t.bc -o - | FileCheck %s --check-prefix=HOST -!RUN: %flang_fc1 -emit-llvm -fopenmp -fopenmp-is-device -fopenmp-host-ir-file-path %t.bc -o - %s 2>&1 | FileCheck %s --check-prefix=DEVICE +!COM: TODO Uncomment following commands, once all required declare target lowering work lands +!COM: %flang_fc1 -emit-llvm-bc -fopenmp -o %t.bc %s | llvm-dis %t.bc -o - | FileCheck %s --check-prefix=HOST +!COM: %flang_fc1 -emit-llvm -fopenmp -fopenmp-is-device -fopenmp-host-ir-file-path %t.bc -o - %s 2>&1 | FileCheck %s --check-prefix=DEVICE !HOST-DAG: %struct.__tgt_offload_entry = type { ptr, ptr, i64, i32, i32 } !HOST-DAG: !omp_offload.info = !{!{{.*}}} diff --git a/flang/test/Lower/OpenMP/requires-notarget.f90 b/flang/test/Lower/OpenMP/requires-notarget.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/requires-notarget.f90 @@ -0,0 +1,12 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s + +! This test checks that requires lowering into MLIR skips creating the +! omp.requires attribute with target-related clauses if there are no device +! functions in the compilation unit + +!CHECK: module attributes { +!CHECK-SAME: omp.atomic_default_mem_order = #omp +!CHECK-NOT: omp.requires +program requires + !$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst) +end program requires diff --git a/flang/test/Lower/OpenMP/requires.f90 b/flang/test/Lower/OpenMP/requires.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/requires.f90 @@ -0,0 +1,14 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s + +! This test checks the lowering of requires into MLIR + +!CHECK: module attributes { +!CHECK-SAME: omp.atomic_default_mem_order = #omp +!CHECK-SAME: omp.requires = #omp +program requires + !$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst) +end program requires + +subroutine f + !$omp declare target +end subroutine f