diff --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h
--- a/flang/include/flang/Semantics/symbol.h
+++ b/flang/include/flang/Semantics/symbol.h
@@ -14,6 +14,8 @@
 #include "flang/Common/enum-set.h"
 #include "flang/Common/reference.h"
 #include "flang/Common/visit.h"
+#include "flang/Parser/parse-tree.h"
+#include "llvm/ADT/BitmaskEnum.h"
 #include "llvm/ADT/DenseMapInfo.h"
 
 #include <array>
@@ -45,8 +47,47 @@
 using MutableSymbolRef = common::Reference<Symbol>;
 using MutableSymbolVector = std::vector<MutableSymbolRef>;
 
+// Flags representing OpenMP REQUIRES clauses. Not in an anonymous namespace:
+// this is a header and WithOmpDeclarative (external linkage) must refer to the
+// same enum type in every translation unit.
+LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
+enum class OmpRequiresFlags {
+  None = 0x00,
+  ReverseOffload = 0x01,
+  UnifiedAddress = 0x02,
+  UnifiedSharedMemory = 0x04,
+  DynamicAllocators = 0x08,
+  LLVM_MARK_AS_BITMASK_ENUM(DynamicAllocators)
+};
+
+// Mixin for details with OpenMP declarative constructs.
+class WithOmpDeclarative {
+  using OmpAtomicOrderType = parser::OmpAtomicDefaultMemOrderClause::Type;
+
+public:
+  bool has_ompRequires() const { return ompRequires_.has_value(); }
+  const OmpRequiresFlags *ompRequires() const {
+    return ompRequires_ ? &*ompRequires_ : nullptr;
+  }
+  void set_ompRequires(OmpRequiresFlags flags) { ompRequires_ = flags; }
+
+  bool has_ompAtomicDefaultMemOrder() const {
+    return ompAtomicDefaultMemOrder_.has_value();
+  }
+  const OmpAtomicOrderType *ompAtomicDefaultMemOrder() const {
+    return ompAtomicDefaultMemOrder_ ? &*ompAtomicDefaultMemOrder_ : nullptr;
+  }
+  void set_ompAtomicDefaultMemOrder(OmpAtomicOrderType flags) {
+    ompAtomicDefaultMemOrder_ = flags;
+  }
+
+private:
+  std::optional<OmpRequiresFlags> ompRequires_;
+  std::optional<OmpAtomicOrderType> ompAtomicDefaultMemOrder_;
+};
+
 // A module or submodule.
-class ModuleDetails {
+class ModuleDetails : public WithOmpDeclarative {
 public:
   ModuleDetails(bool isSubmodule = false) : isSubmodule_{isSubmodule} {}
   bool isSubmodule() const { return isSubmodule_; }
@@ -63,7 +104,7 @@
   const Scope *scope_{nullptr};
 };
 
-class MainProgramDetails {
+class MainProgramDetails : public WithOmpDeclarative {
 public:
 private:
 };
@@ -85,7 +126,7 @@
 // A subroutine or function definition, or a subprogram interface defined
 // in an INTERFACE block as part of the definition of a dummy procedure
 // or a procedure pointer (with just POINTER).
-class SubprogramDetails : public WithBindName {
+class SubprogramDetails : public WithBindName, public WithOmpDeclarative {
 public:
   bool isFunction() const { return result_ != nullptr; }
   bool isInterface() const { return isInterface_; }
diff --git a/flang/lib/Semantics/resolve-directives.h b/flang/lib/Semantics/resolve-directives.h
--- a/flang/lib/Semantics/resolve-directives.h
+++ b/flang/lib/Semantics/resolve-directives.h
@@ -11,6 +11,7 @@
 
 namespace Fortran::parser {
 struct Name;
+struct Program;
 struct ProgramUnit;
 } // namespace Fortran::parser
 
@@ -21,6 +22,7 @@
 // Name resolution for OpenACC and OpenMP directives
 void ResolveAccParts(SemanticsContext &, const parser::ProgramUnit &);
 void ResolveOmpParts(SemanticsContext &, const parser::ProgramUnit &);
+void ResolveOmpTopLevelParts(SemanticsContext &, const parser::Program &);
 
 } // namespace Fortran::semantics
 #endif
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -21,6 +21,13 @@
 #include <map>
 #include <sstream>
 
+template <typename T>
+static Fortran::semantics::Scope *GetScope(
+    Fortran::semantics::SemanticsContext &context, const T &x) {
+  std::optional<Fortran::parser::CharBlock> source{GetSource(x)};
+  return source ? &context.FindScope(*source) : nullptr;
+}
+
 namespace Fortran::semantics {
 
 template <typename T> class DirectiveAttributeVisitor {
@@ -320,11 +327,6 @@
     return true;
   }
 
-  bool Pre(const parser::SpecificationPart &x) {
-    Walk(std::get<std::list<parser::OpenMPDeclarativeConstruct>>(x.t));
-    return true;
-  }
-
   bool Pre(const parser::StmtFunctionStmt &x) {
     const auto &parsedExpr{std::get<parser::Scalar<parser::Expr>>(x.t)};
     if (const auto *expr{GetExpr(context_, parsedExpr)}) {
@@ -372,6 +374,35 @@
 
   bool Pre(const parser::OpenMPRequiresConstruct &x) {
     PushContext(x.source, llvm::omp::Directive::OMPD_requires);
+
+    // Gather information from the clauses.
+    OmpRequiresFlags flags{OmpRequiresFlags::None};
+    std::optional<parser::OmpAtomicDefaultMemOrderClause::Type> memOrder;
+    for (auto &clause : std::get<parser::OmpClauseList>(x.t).v) {
+      flags |= common::visit(
+          common::visitors{
+              [&memOrder](
+                  const parser::OmpClause::AtomicDefaultMemOrder &atomic) {
+                memOrder = atomic.v.v;
+                return OmpRequiresFlags::None;
+              },
+              [](const parser::OmpClause::ReverseOffload &) {
+                return OmpRequiresFlags::ReverseOffload;
+              },
+              [](const parser::OmpClause::UnifiedAddress &) {
+                return OmpRequiresFlags::UnifiedAddress;
+              },
+              [](const parser::OmpClause::UnifiedSharedMemory &) {
+                return OmpRequiresFlags::UnifiedSharedMemory;
+              },
+              [](const parser::OmpClause::DynamicAllocators &) {
+                return OmpRequiresFlags::DynamicAllocators;
+              },
+              [](const auto &) { return OmpRequiresFlags::None; }},
+          clause.u);
+    }
+    // Merge clauses into parents' symbols details.
+    AddOmpRequiresToScope(&currScope(), flags, memOrder);
     return true;
   }
   void Post(const parser::OpenMPRequiresConstruct &) { PopContext(); }
@@ -639,6 +670,9 @@
 
   bool HasSymbolInEnclosingScope(const Symbol &, Scope &);
   std::int64_t ordCollapseLevel{0};
+
+  void AddOmpRequiresToScope(Scope *, OmpRequiresFlags,
+      std::optional<parser::OmpAtomicDefaultMemOrderClause::Type>);
 };
 
 template <typename T>
@@ -2012,6 +2046,82 @@
   }
 }
 
+void ResolveOmpTopLevelParts(
+    SemanticsContext &context, const parser::Program &program) {
+  if (!context.IsEnabled(common::LanguageFeature::OpenMP)) {
+    return;
+  }
+
+  // Gather REQUIRES clauses from all non-module top-level program unit symbols,
+  // combine them together ensuring compatibility and apply them to all these
+  // program units. Modules are skipped because their REQUIRES clauses should be
+  // propagated via USE statements instead.
+  OmpRequiresFlags combinedFlags{OmpRequiresFlags::None};
+  std::optional<parser::OmpAtomicDefaultMemOrderClause::Type> combinedMemOrder;
+
+  // Function to go through non-module top level program units and extract
+  // REQUIRES information to be processed by a function-like argument.
+  auto processProgramUnits{[&](auto processFn) {
+    for (const parser::ProgramUnit &unit : program.v) {
+      if (std::holds_alternative<common::Indirection<parser::Module>>(
+              unit.u) ||
+          std::holds_alternative<common::Indirection<parser::Submodule>>(
+              unit.u)) {
+        continue;
+      }
+      // GetScope() yields null when no source range is available, and a
+      // scope need not have a symbol; such units have nowhere to record
+      // REQUIRES information, so skip them instead of dereferencing null.
+      Symbol *symbol{common::visit(
+          [&context](auto &x) -> Symbol * {
+            Scope *scope{GetScope(context, x.value())};
+            return scope ? scope->symbol() : nullptr;
+          },
+          unit.u)};
+      if (!symbol) {
+        continue;
+      }
+      common::visit(
+          [&](auto &details) {
+            if constexpr (std::is_convertible_v<decltype(&details),
+                              WithOmpDeclarative *>) {
+              processFn(symbol, details);
+            }
+          },
+          symbol->details());
+    }
+  }};
+
+  // Combine global REQUIRES information from all program units except modules
+  // and submodules.
+  processProgramUnits([&](Symbol *symbol, WithOmpDeclarative &details) {
+    if (const OmpRequiresFlags * flags{details.ompRequires()}) {
+      combinedFlags |= *flags;
+    }
+    if (const parser::OmpAtomicDefaultMemOrderClause::Type *
+            memOrder{details.ompAtomicDefaultMemOrder()}) {
+      if (combinedMemOrder && *combinedMemOrder != *memOrder) {
+        context.Say(symbol->scope()->sourceRange(),
+            "Conflicting '%s' REQUIRES clauses found in compilation "
+            "unit"_err_en_US,
+            parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(
+                llvm::omp::Clause::OMPC_atomic_default_mem_order)
+                                           .str()));
+      }
+      combinedMemOrder = *memOrder;
+    }
+  });
+
+  // Update all program units except modules and submodules with the combined
+  // global REQUIRES information.
+  processProgramUnits([&](Symbol *, WithOmpDeclarative &details) {
+    details.set_ompRequires(combinedFlags);
+    if (combinedMemOrder) {
+      details.set_ompAtomicDefaultMemOrder(*combinedMemOrder);
+    }
+  });
+}
+
 void OmpAttributeVisitor::CheckDataCopyingClause(
     const parser::Name &name, const Symbol &symbol, Symbol::Flag ompFlag) {
   const auto *checkSymbol{&symbol};
@@ -2159,4 +2269,43 @@
       parser::ToUpperCaseLetters(
           llvm::omp::getOpenMPDirectiveName(GetContext().directive).str()));
 }
+
+void OmpAttributeVisitor::AddOmpRequiresToScope(Scope *scope,
+    OmpRequiresFlags flags,
+    std::optional<parser::OmpAtomicDefaultMemOrderClause::Type> memOrder) {
+  do {
+    if (Symbol * symbol{scope->symbol()}) {
+      common::visit(
+          [&](auto &details) {
+            // Store clauses information into the symbol for the parent and
+            // enclosing modules, programs, functions and subroutines.
+            if constexpr (std::is_convertible_v<decltype(&details),
+                              WithOmpDeclarative *>) {
+              if (flags != OmpRequiresFlags::None) {
+                if (const OmpRequiresFlags *
+                        otherFlags{details.ompRequires()}) {
+                  flags |= *otherFlags;
+                }
+                details.set_ompRequires(flags);
+              }
+              if (memOrder) {
+                if (details.has_ompAtomicDefaultMemOrder() &&
+                    *details.ompAtomicDefaultMemOrder() != *memOrder) {
+                  context_.Say(scope->sourceRange(),
+                      "Conflicting '%s' REQUIRES clauses found in compilation "
+                      "unit"_err_en_US,
+                      parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(
+                          llvm::omp::Clause::OMPC_atomic_default_mem_order)
+                                                     .str()));
+                }
+                details.set_ompAtomicDefaultMemOrder(*memOrder);
+              }
+            }
+          },
+          symbol->details());
+    }
+    scope = &scope->parent();
+  } while (!scope->IsGlobal());
+}
+
 } // namespace Fortran::semantics
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -8675,11 +8675,14 @@
   }
 }
 
-void ResolveNamesVisitor::Post(const parser::Program &) {
+void ResolveNamesVisitor::Post(const parser::Program &x) {
   // ensure that all temps were deallocated
   CHECK(!attrs_);
   CHECK(!cudaDataAttr_);
   CHECK(!GetDeclTypeSpec());
+  // Top-level resolution to propagate information across program units after
+  // each of them has been resolved separately.
+  ResolveOmpTopLevelParts(context(), x);
 }
 
 // A singleton instance of the scope -> IMPLICIT rules mapping is
diff --git a/flang/test/Semantics/OpenMP/requires09.f90 b/flang/test/Semantics/OpenMP/requires09.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Semantics/OpenMP/requires09.f90
@@ -0,0 +1,14 @@
+! RUN: %python %S/../test_errors.py %s %flang -fopenmp
+! OpenMP Version 5.0
+! 2.4 Requires directive
+! All atomic_default_mem_order clauses in 'requires' directives found within a
+! compilation unit must specify the same ordering.
+
+subroutine f
+  !$omp requires atomic_default_mem_order(seq_cst)
+end subroutine f
+
+!ERROR: Conflicting 'ATOMIC_DEFAULT_MEM_ORDER' REQUIRES clauses found in compilation unit
+subroutine g
+  !$omp requires atomic_default_mem_order(relaxed)
+end subroutine g