diff --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h
--- a/flang/include/flang/Semantics/symbol.h
+++ b/flang/include/flang/Semantics/symbol.h
@@ -14,6 +14,7 @@
 #include "flang/Common/enum-set.h"
 #include "flang/Common/reference.h"
 #include "flang/Common/visit.h"
+#include "flang/Parser/parse-tree.h"
 #include "llvm/ADT/DenseMapInfo.h"
 
 #include <array>
@@ -45,8 +46,38 @@
 using MutableSymbolRef = common::Reference<Symbol>;
 using MutableSymbolVector = std::vector<MutableSymbolRef>;
 
+// Mixin for details with OpenMP declarative constructs.
+class WithOmpDeclarative {
+  using OmpAtomicOrderType = parser::OmpAtomicDefaultMemOrderClause::Type;
+
+public:
+  ENUM_CLASS(RequiresFlag, ReverseOffload, UnifiedAddress, UnifiedSharedMemory,
+      DynamicAllocators);
+  using RequiresFlags = common::EnumSet<RequiresFlag, RequiresFlag_enumSize>;
+
+  bool has_ompRequires() const { return ompRequires_.has_value(); }
+  const RequiresFlags *ompRequires() const {
+    return ompRequires_ ? &*ompRequires_ : nullptr;
+  }
+  void set_ompRequires(RequiresFlags flags) { ompRequires_ = flags; }
+
+  bool has_ompAtomicDefaultMemOrder() const {
+    return ompAtomicDefaultMemOrder_.has_value();
+  }
+  const OmpAtomicOrderType *ompAtomicDefaultMemOrder() const {
+    return ompAtomicDefaultMemOrder_ ? &*ompAtomicDefaultMemOrder_ : nullptr;
+  }
+  void set_ompAtomicDefaultMemOrder(OmpAtomicOrderType memOrder) {
+    ompAtomicDefaultMemOrder_ = memOrder;
+  }
+
+private:
+  std::optional<RequiresFlags> ompRequires_;
+  std::optional<OmpAtomicOrderType> ompAtomicDefaultMemOrder_;
+};
+
 // A module or submodule.
-class ModuleDetails {
+class ModuleDetails : public WithOmpDeclarative {
 public:
   ModuleDetails(bool isSubmodule = false) : isSubmodule_{isSubmodule} {}
   bool isSubmodule() const { return isSubmodule_; }
@@ -63,7 +94,7 @@
   const Scope *scope_{nullptr};
 };
 
-class MainProgramDetails {
+class MainProgramDetails : public WithOmpDeclarative {
 public:
 private:
 };
@@ -114,7 +145,7 @@
 // A subroutine or function definition, or a subprogram interface defined
 // in an INTERFACE block as part of the definition of a dummy procedure
 // or a procedure pointer (with just POINTER).
-class SubprogramDetails : public WithBindName {
+class SubprogramDetails : public WithBindName, public WithOmpDeclarative {
 public:
   bool isFunction() const { return result_ != nullptr; }
   bool isInterface() const { return isInterface_; }
diff --git a/flang/lib/Semantics/resolve-directives.h b/flang/lib/Semantics/resolve-directives.h
--- a/flang/lib/Semantics/resolve-directives.h
+++ b/flang/lib/Semantics/resolve-directives.h
@@ -11,6 +11,7 @@
 
 namespace Fortran::parser {
 struct Name;
+struct Program;
 struct ProgramUnit;
 } // namespace Fortran::parser
 
@@ -21,6 +22,9 @@
 // Name resolution for OpenACC and OpenMP directives
 void ResolveAccParts(SemanticsContext &, const parser::ProgramUnit &);
 void ResolveOmpParts(SemanticsContext &, const parser::ProgramUnit &);
+// Combine OpenMP REQUIRES information across all non-module program units of
+// the program, after each unit has been resolved on its own.
+void ResolveOmpTopLevelParts(SemanticsContext &, const parser::Program &);
 
 } // namespace Fortran::semantics
 #endif
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -22,6 +22,14 @@
 #include <map>
 #include <sstream>
 
+// Returns the scope found for the source of x, or nullptr if x has no source.
+template <typename T>
+static Fortran::semantics::Scope *GetScope(
+    Fortran::semantics::SemanticsContext &context, const T &x) {
+  std::optional<Fortran::parser::CharBlock> source{GetSource(x)};
+  return source ? &context.FindScope(*source) : nullptr;
+}
+
 namespace Fortran::semantics {
 
 template <typename T> class DirectiveAttributeVisitor {
@@ -324,11 +332,6 @@
     return true;
   }
 
-  bool Pre(const parser::SpecificationPart &x) {
-    Walk(std::get<std::list<parser::OpenMPDeclarativeConstruct>>(x.t));
-    return true;
-  }
-
   bool Pre(const parser::StmtFunctionStmt &x) {
     const auto &parsedExpr{std::get<parser::Scalar<parser::Expr>>(x.t)};
     if (const auto *expr{GetExpr(context_, parsedExpr)}) {
@@ -375,7 +378,38 @@
   void Post(const parser::OpenMPDeclareSimdConstruct &) { PopContext(); }
 
   bool Pre(const parser::OpenMPRequiresConstruct &x) {
+    using Flags = WithOmpDeclarative::RequiresFlags;
+    using Requires = WithOmpDeclarative::RequiresFlag;
     PushContext(x.source, llvm::omp::Directive::OMPD_requires);
+
+    // Gather information from the clauses.
+    Flags flags;
+    std::optional<parser::OmpAtomicDefaultMemOrderClause::Type> memOrder;
+    for (const auto &clause : std::get<parser::OmpClauseList>(x.t).v) {
+      flags |= common::visit(
+          common::visitors{
+              [&memOrder](
+                  const parser::OmpClause::AtomicDefaultMemOrder &atomic) {
+                memOrder = atomic.v.v;
+                return Flags{};
+              },
+              [](const parser::OmpClause::ReverseOffload &) {
+                return Flags{Requires::ReverseOffload};
+              },
+              [](const parser::OmpClause::UnifiedAddress &) {
+                return Flags{Requires::UnifiedAddress};
+              },
+              [](const parser::OmpClause::UnifiedSharedMemory &) {
+                return Flags{Requires::UnifiedSharedMemory};
+              },
+              [](const parser::OmpClause::DynamicAllocators &) {
+                return Flags{Requires::DynamicAllocators};
+              },
+              [](const auto &) { return Flags{}; }},
+          clause.u);
+    }
+    // Merge clauses into parents' symbols details.
+    AddOmpRequiresToScope(currScope(), flags, memOrder);
     return true;
   }
   void Post(const parser::OpenMPRequiresConstruct &) { PopContext(); }
@@ -643,6 +677,11 @@
 
   bool HasSymbolInEnclosingScope(const Symbol &, Scope &);
   std::int64_t ordCollapseLevel{0};
+
+  // Merges REQUIRES information into the WithOmpDeclarative details of the
+  // symbols of the given scope and of all its enclosing non-global scopes.
+  void AddOmpRequiresToScope(Scope &, WithOmpDeclarative::RequiresFlags,
+      std::optional<parser::OmpAtomicDefaultMemOrderClause::Type>);
 };
 
 template <typename T>
@@ -2101,6 +2140,85 @@
   }
 }
 
+void ResolveOmpTopLevelParts(
+    SemanticsContext &context, const parser::Program &program) {
+  if (!context.IsEnabled(common::LanguageFeature::OpenMP)) {
+    return;
+  }
+
+  // Gather REQUIRES clauses from all non-module top-level program unit symbols,
+  // combine them together ensuring compatibility and apply them to all these
+  // program units. Modules are skipped because their REQUIRES clauses should be
+  // propagated via USE statements instead.
+  WithOmpDeclarative::RequiresFlags combinedFlags;
+  std::optional<parser::OmpAtomicDefaultMemOrderClause::Type> combinedMemOrder;
+
+  // Function to go through non-module top level program units and extract
+  // REQUIRES information to be processed by a function-like argument.
+  auto processProgramUnits{[&](auto processFn) {
+    for (const parser::ProgramUnit &unit : program.v) {
+      if (!std::holds_alternative<common::Indirection<parser::Module>>(
+              unit.u) &&
+          !std::holds_alternative<common::Indirection<parser::Submodule>>(
+              unit.u)) {
+        Symbol *symbol{common::visit(
+            [&context](auto &x) -> Symbol * {
+              Scope *scope{GetScope(context, x.value())};
+              return scope ? scope->symbol() : nullptr;
+            },
+            unit.u)};
+        // GetScope() may fail to find a scope, and a scope need not have an
+        // associated symbol; such units have nowhere to record REQUIRES
+        // information, so skip them instead of dereferencing a null pointer.
+        if (!symbol) {
+          continue;
+        }
+
+        common::visit(
+            [&](auto &details) {
+              if constexpr (std::is_convertible_v<decltype(&details),
+                                WithOmpDeclarative *>) {
+                processFn(*symbol, details);
+              }
+            },
+            symbol->details());
+      }
+    }
+  }};
+
+  // Combine global REQUIRES information from all program units except modules
+  // and submodules.
+  processProgramUnits([&](Symbol &symbol, WithOmpDeclarative &details) {
+    if (const WithOmpDeclarative::RequiresFlags *
+        flags{details.ompRequires()}) {
+      combinedFlags |= *flags;
+    }
+    if (const parser::OmpAtomicDefaultMemOrderClause::Type *
+        memOrder{details.ompAtomicDefaultMemOrder()}) {
+      if (combinedMemOrder && *combinedMemOrder != *memOrder) {
+        context.Say(symbol.scope()->sourceRange(),
+            "Conflicting '%s' REQUIRES clauses found in compilation "
+            "unit"_err_en_US,
+            parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(
+                llvm::omp::Clause::OMPC_atomic_default_mem_order)
+                                           .str()));
+      }
+      combinedMemOrder = *memOrder;
+    }
+  });
+
+  // Update all program units except modules and submodules with the combined
+  // global REQUIRES information.
+  processProgramUnits([&](Symbol &, WithOmpDeclarative &details) {
+    if (combinedFlags.any()) {
+      details.set_ompRequires(combinedFlags);
+    }
+    if (combinedMemOrder) {
+      details.set_ompAtomicDefaultMemOrder(*combinedMemOrder);
+    }
+  });
+}
+
 void OmpAttributeVisitor::CheckDataCopyingClause(
     const parser::Name &name, const Symbol &symbol, Symbol::Flag ompFlag) {
   const auto *checkSymbol{&symbol};
@@ -2248,4 +2366,44 @@
       parser::ToUpperCaseLetters(
           llvm::omp::getOpenMPDirectiveName(GetContext().directive).str()));
 }
+
+void OmpAttributeVisitor::AddOmpRequiresToScope(Scope &scope,
+    WithOmpDeclarative::RequiresFlags flags,
+    std::optional<parser::OmpAtomicDefaultMemOrderClause::Type> memOrder) {
+  Scope *scopeIter = &scope;
+  do {
+    if (Symbol * symbol{scopeIter->symbol()}) {
+      common::visit(
+          [&](auto &details) {
+            // Store clauses information into the symbol for the parent and
+            // enclosing modules, programs, functions and subroutines.
+            if constexpr (std::is_convertible_v<decltype(&details),
+                              WithOmpDeclarative *>) {
+              if (flags.any()) {
+                if (const WithOmpDeclarative::RequiresFlags *
+                    otherFlags{details.ompRequires()}) {
+                  flags |= *otherFlags;
+                }
+                details.set_ompRequires(flags);
+              }
+              if (memOrder) {
+                if (details.has_ompAtomicDefaultMemOrder() &&
+                    *details.ompAtomicDefaultMemOrder() != *memOrder) {
+                  context_.Say(scopeIter->sourceRange(),
+                      "Conflicting '%s' REQUIRES clauses found in compilation "
+                      "unit"_err_en_US,
+                      parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(
+                          llvm::omp::Clause::OMPC_atomic_default_mem_order)
+                                                     .str()));
+                }
+                details.set_ompAtomicDefaultMemOrder(*memOrder);
+              }
+            }
+          },
+          symbol->details());
+    }
+    scopeIter = &scopeIter->parent();
+  } while (!scopeIter->IsGlobal());
+}
+
 } // namespace Fortran::semantics
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -8699,11 +8699,14 @@
   }
 }
 
-void ResolveNamesVisitor::Post(const parser::Program &) {
+void ResolveNamesVisitor::Post(const parser::Program &x) {
   // ensure that all temps were deallocated
   CHECK(!attrs_);
   CHECK(!cudaDataAttr_);
   CHECK(!GetDeclTypeSpec());
+  // Top-level resolution to propagate information across program units after
+  // each of them has been resolved separately.
+  ResolveOmpTopLevelParts(context(), x);
 }
 
 // A singleton instance of the scope -> IMPLICIT rules mapping is
diff --git a/flang/test/Semantics/OpenMP/requires09.f90 b/flang/test/Semantics/OpenMP/requires09.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Semantics/OpenMP/requires09.f90
@@ -0,0 +1,14 @@
+! RUN: %python %S/../test_errors.py %s %flang -fopenmp
+! OpenMP Version 5.0
+! 2.4 Requires directive
+! All atomic_default_mem_order clauses in 'requires' directives found within a
+! compilation unit must specify the same ordering.
+
+subroutine f
+  !$omp requires atomic_default_mem_order(seq_cst)
+end subroutine f
+
+!ERROR: Conflicting 'ATOMIC_DEFAULT_MEM_ORDER' REQUIRES clauses found in compilation unit
+subroutine g
+  !$omp requires atomic_default_mem_order(relaxed)
+end subroutine g