// Patch summary (reviewer annotation; everything below this comment block is the unmodified patch text).
//
// Purpose: lower the OpenMP "requires" directive in flang by collecting its clauses per compilation
// unit and attaching the combined result to the MLIR module, instead of hitting a TODO.
//
// flang/include/flang/Lower/OpenMP.h
//   - Replaces the mlir::Value / mlir::Operation forward declarations with an include of
//     "mlir/Dialect/OpenMP/OpenMPDialect.h".
//   - Declares processOpenMPRequiresClauses, genOpenMPRequires and isOpenMPTargetConstruct.
//
// flang/lib/Lower/Bridge.cpp (class FirConverter)
//   - New members: ompRequiresFlags (initialized to ClauseRequires::none in the constructor),
//     ompAtomicDefaultMemOrder, ompTargetRegionFound and ompDeclareTargetRegionFound (both false).
//   - analyzeOpenMPDeclarative(evaluationList) is called on the evaluation list of each
//     function-like unit, each module-like unit, and each function nested in a module-like unit,
//     right where those units are declared. For every evaluation in the list it:
//       * on a requires construct: ORs the clause flags into ompRequiresFlags, overwrites
//         ompAtomicDefaultMemOrder when the directive supplies a memory order, and emits an error
//         if reverse_offload / unified_address / unified_shared_memory is requested after a target
//         or declare-target region has already been recorded;
//       * on a declare target construct: sets ompDeclareTargetRegionFound;
//       * on an OpenMP construct for which isOpenMPTargetConstruct returns true: sets
//         ompTargetRegionFound.
//   - After lowering, genOpenMPRequires is called when the dyn_cast of the module operation
//     succeeds and getIsDevice() is false.
//
// flang/lib/Lower/OpenMP.cpp
//   - The OpenMPRequiresConstruct visitor becomes an empty handler (no per-directive code).
//   - processOpenMPRequiresClauses resets its output flags to ClauseRequires::none, translates each
//     clause in the list, and returns the atomic default memory order if one was specified.
//   - genOpenMPRequires calls setRequires on the module and, when a memory order is present,
//     setAtomicDefaultMemOrder. It uses an unchecked cast; the caller in Bridge.cpp guards it with
//     a dyn_cast.
//   - isOpenMPTargetConstruct returns true only when the construct holds the alternative tested by
//     std::get_if and its begin directive is one of the listed OMPD_target* values.
//
// flang/test/Lower/OpenMP/requires.f90
//   - New FileCheck test expecting omp.atomic_default_mem_order and omp.requires module attributes.
//
// NOTE(review): template argument lists (std::optional, std::get, std::get_if, llvm::dyn_cast,
//   cast, eval.getIf), one #include target and the "#omp" FileCheck attribute payloads are absent in
//   this copy. This looks like angle-bracketed text was dropped during extraction -- confirm against
//   the original patch before applying.
// NOTE(review): isOpenMPTargetConstruct lists combined loop directives (target_parallel_do,
//   target_teams_distribute, ...) under what appears to be the block-construct branch; confirm
//   those directives can actually reach this branch rather than being parsed as loop constructs.
// NOTE(review): analyzeOpenMPDeclarative only iterates the list it is given; whether target
//   regions nested inside other evaluations are detected is not visible here -- TODO confirm.
// NOTE(review): the diagnostic is emitted at toLocation() with no argument; presumably the current
//   converter location rather than the requires directive itself -- confirm.
diff --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h --- a/flang/include/flang/Lower/OpenMP.h +++ b/flang/include/flang/Lower/OpenMP.h @@ -13,13 +13,9 @@ #ifndef FORTRAN_LOWER_OPENMP_H #define FORTRAN_LOWER_OPENMP_H +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include -namespace mlir { -class Value; -class Operation; -} // namespace mlir - namespace fir { class FirOpBuilder; class ConvertOp; @@ -56,6 +52,15 @@ void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value, mlir::Value, fir::ConvertOp * = nullptr); void removeStoreOp(mlir::Operation *, mlir::Value); + +std::optional +processOpenMPRequiresClauses(const Fortran::parser::OmpClauseList &, + mlir::omp::ClauseRequires &); +void genOpenMPRequires(AbstractConverter &, mlir::omp::ClauseRequires, + std::optional); + +bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &); + } // namespace lower } // namespace Fortran diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -51,6 +51,7 @@ #include "flang/Semantics/runtime-type-info.h" #include "flang/Semantics/tools.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Parser/Parser.h" #include "mlir/Transforms/RegionUtils.h" @@ -265,7 +266,8 @@ public: explicit FirConverter(Fortran::lower::LoweringBridge &bridge) : Fortran::lower::AbstractConverter(bridge.getLoweringOptions()), - bridge{bridge}, foldingContext{bridge.createFoldingContext()} {} + bridge{bridge}, foldingContext{bridge.createFoldingContext()}, + ompRequiresFlags{mlir::omp::ClauseRequires::none} {} virtual ~FirConverter() = default; /// Convert the PFT to FIR. 
@@ -292,12 +294,16 @@ if (f.isMainProgram()) hasMainProgram = true; declareFunction(f); + analyzeOpenMPDeclarative(f.evaluationList); }, [&](Fortran::lower::pft::ModuleLikeUnit &m) { lowerModuleDeclScope(m); + analyzeOpenMPDeclarative(m.evaluationList); for (Fortran::lower::pft::FunctionLikeUnit &f : - m.nestedFunctions) + m.nestedFunctions) { declareFunction(f); + analyzeOpenMPDeclarative(f.evaluationList); + } }, [&](Fortran::lower::pft::BlockDataUnit &b) {}, [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {}, @@ -347,6 +353,16 @@ fir::runtime::genEnvironmentDefaults(*builder, toLocation(), bridge.getEnvironmentDefaults()); }); + + // Create the module attributes related to OpenMP requires directives + // TODO Only generate if there is something to be offloaded in this + // compilation unit (i.e. ompTargetRegionFound=true or + // ompDeclareTargetRegionFound=true...) + if (auto mod = llvm::dyn_cast( + getModuleOp().getOperation())) { + if (!mod.getIsDevice()) + genOpenMPRequires(*this, ompRequiresFlags, ompAtomicDefaultMemOrder); + } } /// Declare a function. @@ -1141,6 +1157,68 @@ activeConstructStack.pop_back(); } + /// Perform a pass to gather compilation unit-level data from OpenMP + /// declarative constructs. This must be done prior to lowering, to ensure + /// data is available to the lowering pass. + void analyzeOpenMPDeclarative( + const Fortran::lower::pft::EvaluationList &evaluationList) { + // Populate ompTargetRegionFound and ompDeclareTargetRegionFound during + // analysis, so that semantically necessary ordering information between + // the requires directive and other OpenMP directives is present. 
+ auto analyzeRequires = + [&](const Fortran::parser::OpenMPRequiresConstruct &ompReq) { + using mlir::omp::ClauseRequires; + + mlir::omp::ClauseRequires requiresFlags; + auto atomicDefaultMemOrder = + Fortran::lower::processOpenMPRequiresClauses( + std::get(ompReq.t), + requiresFlags); + + if (bitEnumContainsAny(requiresFlags, + ClauseRequires::reverse_offload | + ClauseRequires::unified_address | + ClauseRequires::unified_shared_memory) && + (ompTargetRegionFound || ompDeclareTargetRegionFound)) + mlir::emitError(toLocation(), + "requires directive specifies a reverse_offload, " + "unified_address or unified_shared_memory " + "requirement lexically after a device construct"); + + if (requiresFlags != ClauseRequires::none) + ompRequiresFlags = ompRequiresFlags | requiresFlags; + + if (atomicDefaultMemOrder) + ompAtomicDefaultMemOrder = atomicDefaultMemOrder; + }; + + auto analyzeDeclareTarget = + [&](const Fortran::parser::OpenMPDeclareTargetConstruct &ompReq) { + // Only register that a "declare target" region is found here + ompDeclareTargetRegionFound = true; + }; + + for (const Fortran::lower::pft::Evaluation &eval : evaluationList) { + if (const auto *ompDecl = + eval.getIf()) { + std::visit( + Fortran::common::visitors{ + analyzeRequires, + analyzeDeclareTarget, + // Add other OpenMP declarative constructs currently skipped + [&](const auto &) {}, + }, + ompDecl->u); + } else if (const auto *ompDecl = + eval.getIf()) { + // Register if a target region is found + ompTargetRegionFound = + ompTargetRegionFound || + Fortran::lower::isOpenMPTargetConstruct(*ompDecl); + } + } + } + //===--------------------------------------------------------------------===// // Termination of symbolically referenced execution units //===--------------------------------------------------------------------===// @@ -4064,6 +4142,13 @@ /// Tuple of host associated variables mlir::Value hostAssocTuple; + + /// OpenMP Requires flags + mlir::omp::ClauseRequires ompRequiresFlags; + 
std::optional ompAtomicDefaultMemOrder; + + bool ompTargetRegionFound = false; + bool ompDeclareTargetRegionFound = false; }; } // namespace diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -2211,7 +2211,11 @@ }, [&](const Fortran::parser::OpenMPRequiresConstruct &requiresConstruct) { - TODO(converter.getCurrentLocation(), "OpenMPRequiresConstruct"); + // Requires directives are analyzed before any statements are + // lowered. Then, the result of combining the set of clauses of all + // requires directives present in the compilation unit is used to + // emit code, so no code is emitted independently for each + // "requires" instance. }, [&](const Fortran::parser::OpenMPThreadprivate &threadprivate) { // The directive is lowered when instantiating the variable to @@ -2418,3 +2422,83 @@ } } } + +std::optional +Fortran::lower::processOpenMPRequiresClauses( + const Fortran::parser::OmpClauseList &clauseList, + omp::ClauseRequires &requiresFlags) { + std::optional atomicDefaultMemOrder; + requiresFlags = omp::ClauseRequires::none; + + for (const auto &clause : clauseList.v) { + if (const auto &atomicClause = + std::get_if( + &clause.u)) { + switch (atomicClause->v.v) { + case Fortran::parser::OmpAtomicDefaultMemOrderClause::Type::SeqCst: + atomicDefaultMemOrder = omp::ClauseMemoryOrderKind::Seq_cst; + break; + case Fortran::parser::OmpAtomicDefaultMemOrderClause::Type::AcqRel: + atomicDefaultMemOrder = omp::ClauseMemoryOrderKind::Acq_rel; + break; + case Fortran::parser::OmpAtomicDefaultMemOrderClause::Type::Relaxed: + atomicDefaultMemOrder = omp::ClauseMemoryOrderKind::Relaxed; + break; + } + } else if (std::get_if( + &clause.u)) { + requiresFlags = requiresFlags | omp::ClauseRequires::dynamic_allocators; + } else if (std::get_if( + &clause.u)) { + requiresFlags = requiresFlags | omp::ClauseRequires::reverse_offload; + } else if (std::get_if( + &clause.u)) { + requiresFlags = 
requiresFlags | omp::ClauseRequires::unified_address; + } else if (std::get_if( + &clause.u)) { + requiresFlags = + requiresFlags | omp::ClauseRequires::unified_shared_memory; + } + } + + return atomicDefaultMemOrder; +} + +void Fortran::lower::genOpenMPRequires( + Fortran::lower::AbstractConverter &converter, + omp::ClauseRequires requiresFlags, + std::optional atomicDefaultMemOrder) { + auto mod = + cast(converter.getModuleOp().getOperation()); + mod.setRequires(requiresFlags); + + if (atomicDefaultMemOrder) + mod.setAtomicDefaultMemOrder(*atomicDefaultMemOrder); +} + +bool Fortran::lower::isOpenMPTargetConstruct( + const Fortran::parser::OpenMPConstruct &omp) { + if (const auto *blockDir = + std::get_if(&omp.u)) { + const auto &beginBlockDir{ + std::get(blockDir->t)}; + const auto &beginDir{ + std::get(beginBlockDir.t)}; + + switch (beginDir.v) { + case llvm::omp::Directive::OMPD_target: + case llvm::omp::Directive::OMPD_target_parallel: + case llvm::omp::Directive::OMPD_target_parallel_do: + case llvm::omp::Directive::OMPD_target_parallel_do_simd: + case llvm::omp::Directive::OMPD_target_simd: + case llvm::omp::Directive::OMPD_target_teams: + case llvm::omp::Directive::OMPD_target_teams_distribute: + case llvm::omp::Directive::OMPD_target_teams_distribute_simd: + return true; + default: + break; + } + } + + return false; +} diff --git a/flang/test/Lower/OpenMP/requires.f90 b/flang/test/Lower/OpenMP/requires.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/requires.f90 @@ -0,0 +1,16 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck --check-prefix=MLIR %s + +! This test checks the lowering of requires into MLIR + +! MLIR attributes +!MLIR: module attributes { +!MLIR-SAME: omp.atomic_default_mem_order = #omp +!MLIR-SAME: omp.requires = #omp + +program requires + !$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst) + integer :: x, y + + !$omp atomic read + x = y +end program requires