diff --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h --- a/flang/include/flang/Semantics/symbol.h +++ b/flang/include/flang/Semantics/symbol.h @@ -14,6 +14,8 @@ #include "flang/Common/enum-set.h" #include "flang/Common/reference.h" #include "flang/Common/visit.h" +#include "flang/Parser/parse-tree.h" +#include "llvm/ADT/BitmaskEnum.h" #include "llvm/ADT/DenseMapInfo.h" #include @@ -45,8 +47,47 @@ using MutableSymbolRef = common::Reference; using MutableSymbolVector = std::vector; +LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE(); +// Flags representing OpenMP requires clauses. Deliberately not in an +// anonymous namespace: in a header, that would give each translation unit a +// distinct type, yet the externally visible WithOmpDeclarative stores one. +enum OmpRequiresFlags { + None = 0x00, + ReverseOffload = 0x01, + UnifiedAddress = 0x02, + UnifiedSharedMemory = 0x04, + DynamicAllocators = 0x08, + LLVM_MARK_AS_BITMASK_ENUM(DynamicAllocators) +}; + +// Mixin for details with OpenMP declarative constructs. +class WithOmpDeclarative { + using OmpAtomicOrderType = parser::OmpAtomicDefaultMemOrderClause::Type; + +public: + bool has_ompRequires() const { return ompRequires_.has_value(); } + const OmpRequiresFlags *ompRequires() const { + return ompRequires_ ? &*ompRequires_ : nullptr; + } + void set_ompRequires(OmpRequiresFlags flags) { ompRequires_ = flags; } + + bool has_ompAtomicDefaultMemOrder() const { + return ompAtomicDefaultMemOrder_.has_value(); + } + const OmpAtomicOrderType *ompAtomicDefaultMemOrder() const { + return ompAtomicDefaultMemOrder_ ? &*ompAtomicDefaultMemOrder_ : nullptr; + } + void set_ompAtomicDefaultMemOrder(OmpAtomicOrderType flags) { + ompAtomicDefaultMemOrder_ = flags; + } + +private: + std::optional ompRequires_; + std::optional ompAtomicDefaultMemOrder_; +}; + // A module or submodule.
-class ModuleDetails { +class ModuleDetails : public WithOmpDeclarative { public: ModuleDetails(bool isSubmodule = false) : isSubmodule_{isSubmodule} {} bool isSubmodule() const { return isSubmodule_; } @@ -63,7 +104,7 @@ const Scope *scope_{nullptr}; }; -class MainProgramDetails { +class MainProgramDetails : public WithOmpDeclarative { public: private: }; @@ -85,7 +126,7 @@ // A subroutine or function definition, or a subprogram interface defined // in an INTERFACE block as part of the definition of a dummy procedure // or a procedure pointer (with just POINTER). -class SubprogramDetails : public WithBindName { +class SubprogramDetails : public WithBindName, public WithOmpDeclarative { public: bool isFunction() const { return result_ != nullptr; } bool isInterface() const { return isInterface_; } diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp --- a/flang/lib/Parser/openmp-parsers.cpp +++ b/flang/lib/Parser/openmp-parsers.cpp @@ -608,7 +608,7 @@ // 2.4 Requires construct TYPE_PARSER(sourced(construct( - verbatim("REQUIRES"_tok), some(Parser{} / maybe(","_tok))))) + verbatim("REQUIRES"_tok), Parser{}))) // 2.15.2 Threadprivate directive TYPE_PARSER(sourced(construct( diff --git a/flang/lib/Semantics/CMakeLists.txt b/flang/lib/Semantics/CMakeLists.txt --- a/flang/lib/Semantics/CMakeLists.txt +++ b/flang/lib/Semantics/CMakeLists.txt @@ -36,6 +36,7 @@ resolve-directives.cpp resolve-names-utils.cpp resolve-names.cpp + rewrite-directives.cpp rewrite-parse-tree.cpp runtime-type-info.cpp scope.cpp diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h --- a/flang/lib/Semantics/check-omp-structure.h +++ b/flang/lib/Semantics/check-omp-structure.h @@ -182,6 +182,8 @@ void Enter(const parser::OmpAtomicCapture &); void Leave(const parser::OmpAtomic &); + void Enter(const parser::UseStmt &); + #define GEN_FLANG_CLAUSE_CHECK_ENTER #include "llvm/Frontend/OpenMP/OMP.inc" @@ -268,6 +270,7 @@ const
parser::OmpObjectList &ompObjectList); void CheckPredefinedAllocatorRestriction( const parser::CharBlock &source, const parser::Name &name); + void CheckAllowedRequiresClause(llvmOmpClause clause); bool isPredefinedAllocator{false}; void EnterDirectiveNest(const int index) { directiveNest_[index]++; } void ExitDirectiveNest(const int index) { directiveNest_[index]--; } @@ -281,6 +284,7 @@ LastType }; int directiveNest_[LastType + 1] = {0}; + bool deviceConstructFound_{false}; }; } // namespace Fortran::semantics #endif // FORTRAN_SEMANTICS_CHECK_OMP_STRUCTURE_H_ diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp --- a/flang/lib/Semantics/check-omp-structure.cpp +++ b/flang/lib/Semantics/check-omp-structure.cpp @@ -789,6 +789,7 @@ switch (beginDir.v) { case llvm::omp::Directive::OMPD_target: + deviceConstructFound_ = true; if (CheckTargetBlockOnlyTeams(block)) { EnterDirectiveNest(TargetBlockOnlyTeams); } @@ -1155,21 +1156,40 @@ const auto &dir{std::get(x.t)}; const auto &spec{std::get(x.t)}; if (const auto *objectList{parser::Unwrap(spec.u)}) { + deviceConstructFound_ = true; CheckSymbolNames(dir.source, *objectList); CheckIsVarPartOfAnotherVar(dir.source, *objectList); CheckThreadprivateOrDeclareTargetVar(*objectList); } else if (const auto *clauseList{ parser::Unwrap(spec.u)}) { + bool toClauseFound{false}, deviceTypeClauseFound{false}; for (const auto &clause : clauseList->v) { - if (const auto *toClause{std::get_if(&clause.u)}) { - CheckSymbolNames(dir.source, toClause->v); - CheckIsVarPartOfAnotherVar(dir.source, toClause->v); - CheckThreadprivateOrDeclareTargetVar(toClause->v); - } else if (const auto *linkClause{ - std::get_if(&clause.u)}) { - CheckSymbolNames(dir.source, linkClause->v); - CheckIsVarPartOfAnotherVar(dir.source, linkClause->v); - CheckThreadprivateOrDeclareTargetVar(linkClause->v); - } + common::visit( + common::visitors{ + [&](const parser::OmpClause::To &toClause) { + toClauseFound = true; +
CheckSymbolNames(dir.source, toClause.v); + CheckIsVarPartOfAnotherVar(dir.source, toClause.v); + CheckThreadprivateOrDeclareTargetVar(toClause.v); + }, + [&](const parser::OmpClause::Link &linkClause) { + CheckSymbolNames(dir.source, linkClause.v); + CheckIsVarPartOfAnotherVar(dir.source, linkClause.v); + CheckThreadprivateOrDeclareTargetVar(linkClause.v); + }, + [&](const parser::OmpClause::DeviceType &deviceTypeClause) { + deviceTypeClauseFound = true; + if (deviceTypeClause.v.v != + parser::OmpDeviceTypeClause::Type::Host) { + deviceConstructFound_ = true; + } + }, + [&](const auto &) {}, + }, + clause.u); } + // Checked after the loop so the outcome is independent of clause order. + if (toClauseFound && !deviceTypeClauseFound) { + deviceConstructFound_ = true; + } } @@ -1867,7 +1887,6 @@ // Following clauses do not have a separate node in parse-tree.h. CHECK_SIMPLE_CLAUSE(AcqRel, OMPC_acq_rel) CHECK_SIMPLE_CLAUSE(Acquire, OMPC_acquire) -CHECK_SIMPLE_CLAUSE(AtomicDefaultMemOrder, OMPC_atomic_default_mem_order) CHECK_SIMPLE_CLAUSE(Affinity, OMPC_affinity) CHECK_SIMPLE_CLAUSE(Allocate, OMPC_allocate) CHECK_SIMPLE_CLAUSE(Capture, OMPC_capture) @@ -1877,7 +1896,6 @@ CHECK_SIMPLE_CLAUSE(Detach, OMPC_detach) CHECK_SIMPLE_CLAUSE(DeviceType, OMPC_device_type) CHECK_SIMPLE_CLAUSE(DistSchedule, OMPC_dist_schedule) -CHECK_SIMPLE_CLAUSE(DynamicAllocators, OMPC_dynamic_allocators) CHECK_SIMPLE_CLAUSE(Exclusive, OMPC_exclusive) CHECK_SIMPLE_CLAUSE(Final, OMPC_final) CHECK_SIMPLE_CLAUSE(Flush, OMPC_flush) @@ -1890,7 +1908,6 @@ CHECK_SIMPLE_CLAUSE(Nontemporal, OMPC_nontemporal) CHECK_SIMPLE_CLAUSE(Order, OMPC_order) CHECK_SIMPLE_CLAUSE(Read, OMPC_read) -CHECK_SIMPLE_CLAUSE(ReverseOffload, OMPC_reverse_offload) CHECK_SIMPLE_CLAUSE(Threadprivate, OMPC_threadprivate) CHECK_SIMPLE_CLAUSE(Threads, OMPC_threads) CHECK_SIMPLE_CLAUSE(Inbranch, OMPC_inbranch) @@ -1911,8 +1928,6 @@ CHECK_SIMPLE_CLAUSE(Sizes, OMPC_sizes) CHECK_SIMPLE_CLAUSE(TaskReduction, OMPC_task_reduction) CHECK_SIMPLE_CLAUSE(To, OMPC_to) -CHECK_SIMPLE_CLAUSE(UnifiedAddress,
OMPC_unified_address) -CHECK_SIMPLE_CLAUSE(UnifiedSharedMemory, OMPC_unified_shared_memory) CHECK_SIMPLE_CLAUSE(Uniform, OMPC_uniform) CHECK_SIMPLE_CLAUSE(Unknown, OMPC_unknown) CHECK_SIMPLE_CLAUSE(Untied, OMPC_untied) @@ -2836,4 +2851,55 @@ clause.u); } +void OmpStructureChecker::Enter( + const parser::OmpClause::AtomicDefaultMemOrder &x) { + CheckAllowed(llvm::omp::Clause::OMPC_atomic_default_mem_order); +} + +void OmpStructureChecker::Enter(const parser::OmpClause::DynamicAllocators &x) { + CheckAllowedRequiresClause(llvm::omp::Clause::OMPC_dynamic_allocators); +} + +void OmpStructureChecker::Enter(const parser::OmpClause::ReverseOffload &x) { + CheckAllowedRequiresClause(llvm::omp::Clause::OMPC_reverse_offload); +} + +void OmpStructureChecker::Enter(const parser::OmpClause::UnifiedAddress &x) { + CheckAllowedRequiresClause(llvm::omp::Clause::OMPC_unified_address); +} + +void OmpStructureChecker::Enter( + const parser::OmpClause::UnifiedSharedMemory &x) { + CheckAllowedRequiresClause(llvm::omp::Clause::OMPC_unified_shared_memory); +} + +void OmpStructureChecker::CheckAllowedRequiresClause(llvmOmpClause clause) { + CheckAllowed(clause); + + // Check that it does not appear after a device construct + if (deviceConstructFound_) { + context_.Say(GetContext().clauseSource, + "'%s' REQUIRES clause found lexically after device " + "construct"_err_en_US, + parser::ToUpperCaseLetters(getClauseName(clause).str())); + } +} + +void OmpStructureChecker::Enter(const parser::UseStmt &x) { + semantics::Symbol *symbol{x.moduleName.symbol}; + if (!symbol) { + // Cannot check used module if it wasn't resolved.
+ return; + } + + // detailsIf: skip, rather than fail on, a symbol that is not a module. + const auto *details{symbol->detailsIf<ModuleDetails>()}; + if (details && details->has_ompRequires() && deviceConstructFound_) { + context_.Say(x.moduleName.source, + "'%s' module containing device-related REQUIRES directive imported " + "lexically after device construct"_err_en_US, + x.moduleName.ToString()); + } +} + } // namespace Fortran::semantics diff --git a/flang/lib/Semantics/mod-file.cpp b/flang/lib/Semantics/mod-file.cpp --- a/flang/lib/Semantics/mod-file.cpp +++ b/flang/lib/Semantics/mod-file.cpp @@ -57,6 +57,8 @@ static llvm::raw_ostream &PutAttr(llvm::raw_ostream &, Attr); static llvm::raw_ostream &PutType(llvm::raw_ostream &, const DeclTypeSpec &); static llvm::raw_ostream &PutLower(llvm::raw_ostream &, std::string_view); +static llvm::raw_ostream &PutOmpRequires( + llvm::raw_ostream &, const WithOmpDeclarative &); static std::error_code WriteFile( const std::string &, const std::string &, bool = true); static bool FileContentsMatch( @@ -162,6 +164,7 @@ uses_.str().clear(); all << useExtraAttrs_.str(); useExtraAttrs_.str().clear(); + PutOmpRequires(all, details); all << decls_.str(); decls_.str().clear(); auto str{contains_.str()}; @@ -505,6 +508,8 @@ } } os << '\n'; + // print OpenMP requires + PutOmpRequires(os, details); // walk symbols, collect ones needed for interface const Scope &scope{ details.entryScope() ?
*details.entryScope() : DEREF(symbol.scope())}; @@ -874,6 +879,48 @@ return os; } +llvm::raw_ostream &PutOmpRequires( + llvm::raw_ostream &os, const WithOmpDeclarative &details) { + // Emit nothing unless there is at least one clause to print, so that a + // clause-less "!$omp requires" line is never written. + const OmpRequiresFlags *flags{details.ompRequires()}; + bool hasFlags{flags && *flags != OmpRequiresFlags::None}; + if (hasFlags || details.has_ompAtomicDefaultMemOrder()) { + os << "!$omp requires"; + if (hasFlags) { + if (*flags & OmpRequiresFlags::ReverseOffload) { + os << " reverse_offload"; + } + if (*flags & OmpRequiresFlags::UnifiedAddress) { + os << " unified_address"; + } + if (*flags & OmpRequiresFlags::UnifiedSharedMemory) { + os << " unified_shared_memory"; + } + if (*flags & OmpRequiresFlags::DynamicAllocators) { + os << " dynamic_allocators"; + } + } + if (auto *memOrder{details.ompAtomicDefaultMemOrder()}) { + os << " atomic_default_mem_order("; + switch (*memOrder) { + case parser::OmpAtomicDefaultMemOrderClause::Type::SeqCst: + os << "seq_cst"; + break; + case parser::OmpAtomicDefaultMemOrderClause::Type::AcqRel: + os << "acq_rel"; + break; + case parser::OmpAtomicDefaultMemOrderClause::Type::Relaxed: + os << "relaxed"; + break; + } + os << ')'; + } + os << '\n'; + } + return os; +} + struct Temp { Temp(int fd, std::string path) : fd{fd}, path{path} {} Temp(Temp &&t) : fd{std::exchange(t.fd, -1)}, path{std::move(t.path)} {} diff --git a/flang/lib/Semantics/resolve-directives.h b/flang/lib/Semantics/resolve-directives.h --- a/flang/lib/Semantics/resolve-directives.h +++ b/flang/lib/Semantics/resolve-directives.h @@ -11,6 +11,7 @@ namespace Fortran::parser { struct Name; +struct Program; struct ProgramUnit; } // namespace Fortran::parser @@ -21,6 +22,7 @@ // Name resolution for OpenACC and OpenMP directives void ResolveAccParts(SemanticsContext &, const parser::ProgramUnit &); void ResolveOmpParts(SemanticsContext &, const parser::ProgramUnit &); +void ResolveOmpTopLevelParts(SemanticsContext &, const parser::Program &); } // namespace Fortran::semantics #endif diff --git a/flang/lib/Semantics/resolve-directives.cpp
b/flang/lib/Semantics/resolve-directives.cpp --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp @@ -21,6 +21,13 @@ #include #include +template +static Fortran::semantics::Scope *GetScope( + Fortran::semantics::SemanticsContext &context, const T &x) { + std::optional source{GetSource(x)}; + return source ? &context.FindScope(*source) : nullptr; +} + namespace Fortran::semantics { template class DirectiveAttributeVisitor { @@ -57,6 +64,12 @@ dirContext_.emplace_back(source, dir, context_.FindScope(source)); } void PopContext() { dirContext_.pop_back(); } + Scope *GetDeclScope() { + CHECK(!declScope_.empty()); + return declScope_.back(); + } + void PushDeclScope(Scope *scope) { declScope_.push_back(scope); } + void PopDeclScope() { declScope_.pop_back(); } void SetContextDirectiveSource(parser::CharBlock &dir) { GetContext().directiveSource = dir; } @@ -108,6 +121,7 @@ UnorderedSymbolSet dataSharingAttributeObjects_; // on one directive SemanticsContext &context_; std::vector dirContext_; // used as a stack + std::vector declScope_; // used as a stack }; class AccAttributeVisitor : DirectiveAttributeVisitor { @@ -274,11 +288,6 @@ return true; } - bool Pre(const parser::SpecificationPart &x) { - Walk(std::get>(x.t)); - return true; - } - bool Pre(const parser::StmtFunctionStmt &x) { const auto &parsedExpr{std::get>(x.t)}; if (const auto *expr{GetExpr(context_, parsedExpr)}) { @@ -326,6 +335,39 @@ bool Pre(const parser::OpenMPRequiresConstruct &x) { PushContext(x.source, llvm::omp::Directive::OMPD_requires); + + // Gather information from the clauses.
+ OmpRequiresFlags flags{OmpRequiresFlags::None}; + std::optional memOrder; + for (auto &clause : std::get(x.t).v) { + flags |= common::visit( + common::visitors{ + [&memOrder]( + const parser::OmpClause::AtomicDefaultMemOrder &atomic) { + memOrder = atomic.v.v; + return OmpRequiresFlags::None; + }, + [](const parser::OmpClause::ReverseOffload &) { + return OmpRequiresFlags::ReverseOffload; + }, + [](const parser::OmpClause::UnifiedAddress &) { + return OmpRequiresFlags::UnifiedAddress; + }, + [](const parser::OmpClause::UnifiedSharedMemory &) { + return OmpRequiresFlags::UnifiedSharedMemory; + }, + [](const parser::OmpClause::DynamicAllocators &) { + return OmpRequiresFlags::DynamicAllocators; + }, + [&](const auto &) { + context_.Say(x.source, + "Unexpected clause in REQUIRES construct"_err_en_US); + return OmpRequiresFlags::None; + }}, + clause.u); + } + // Merge clauses into parents' symbols details. + AddOmpRequiresToScope(&currScope(), flags, memOrder); return true; } void Post(const parser::OpenMPRequiresConstruct &) { PopContext(); } @@ -333,6 +375,33 @@ bool Pre(const parser::OpenMPDeclareTargetConstruct &); void Post(const parser::OpenMPDeclareTargetConstruct &) { PopContext(); } + void Post(const parser::UseStmt &x) { + Symbol *symbol{x.moduleName.symbol}; + if (!symbol) { + return; + } + + // Gather information from the imported module's symbol details. + OmpRequiresFlags flags{OmpRequiresFlags::None}; + std::optional memOrder; + common::visit( + [&](auto &details) { + if constexpr (std::is_base_of_v>) { + if (details.has_ompRequires()) { + flags = *details.ompRequires(); + } + if (details.has_ompAtomicDefaultMemOrder()) { + memOrder = *details.ompAtomicDefaultMemOrder(); + } + } + }, + symbol->details()); + + // Merge requires clauses into `use` statement parents.
+ AddOmpRequiresToScope(GetDeclScope(), flags, memOrder); + } + bool Pre(const parser::OpenMPThreadprivate &); void Post(const parser::OpenMPThreadprivate &) { PopContext(); } @@ -451,6 +520,59 @@ void Post(const parser::Name &); + // Keep track of contexts inside of which `SpecificationPart`s can be found + // to allow matching `Use` statements with their parent scopes. + bool Pre(const parser::MainProgram &x) { + PushDeclScope(GetScope(context_, x)); + return true; + } + bool Pre(const parser::Module &x) { + PushDeclScope(GetScope(context_, x)); + return true; + } + bool Pre(const parser::Submodule &x) { + PushDeclScope(GetScope(context_, x)); + return true; + } + bool Pre(const parser::FunctionSubprogram &x) { + PushDeclScope(GetScope(context_, x)); + return true; + } + bool Pre(const parser::InterfaceBody::Function &x) { + PushDeclScope(GetScope(context_, x)); + return true; + } + bool Pre(const parser::SubroutineSubprogram &x) { + PushDeclScope(GetScope(context_, x)); + return true; + } + bool Pre(const parser::InterfaceBody::Subroutine &x) { + PushDeclScope(GetScope(context_, x)); + return true; + } + bool Pre(const parser::SeparateModuleSubprogram &x) { + PushDeclScope(GetScope(context_, x)); + return true; + } + bool Pre(const parser::BlockConstruct &x) { + PushDeclScope(GetScope(context_, x)); + return true; + } + bool Pre(const parser::BlockData &x) { + PushDeclScope(GetScope(context_, x)); + return true; + } + void Post(const parser::MainProgram &) { PopDeclScope(); } + void Post(const parser::Module &) { PopDeclScope(); } + void Post(const parser::Submodule &) { PopDeclScope(); } + void Post(const parser::FunctionSubprogram &) { PopDeclScope(); } + void Post(const parser::InterfaceBody::Function &) { PopDeclScope(); } + void Post(const parser::SubroutineSubprogram &) { PopDeclScope(); } + void Post(const parser::InterfaceBody::Subroutine &) { PopDeclScope(); } + void Post(const parser::SeparateModuleSubprogram &) { PopDeclScope(); } + void Post(const
parser::BlockConstruct &) { PopDeclScope(); } + void Post(const parser::BlockData &) { PopDeclScope(); } + // Keep track of labels in the statements that causes jumps to target labels void Post(const parser::GotoStmt &gotoStmt) { CheckSourceLabel(gotoStmt.v); } void Post(const parser::ComputedGotoStmt &computedGotoStmt) { @@ -586,6 +708,9 @@ bool HasSymbolInEnclosingScope(const Symbol &, Scope &); std::int64_t ordCollapseLevel{0}; + + void AddOmpRequiresToScope(Scope *, OmpRequiresFlags, + std::optional); }; template @@ -1875,6 +2000,78 @@ } } +void ResolveOmpTopLevelParts( + SemanticsContext &context, const parser::Program &program) { + if (!context.IsEnabled(common::LanguageFeature::OpenMP)) { + return; + } + + // Gather `requires` clauses from all non-module top-level program unit + // symbols, combine them together ensuring compatibility and apply them to all + // these program units. Modules are skipped because their `requires` clauses + // are propagated via `USE` statements instead. + OmpRequiresFlags combinedFlags{OmpRequiresFlags::None}; + std::optional combinedMemOrder; + + // Function to go through non-module top level program units and extract + // `requires` information to be processed by a function-like argument. + auto processProgramUnits{[&](auto processFn) { + for (const parser::ProgramUnit &unit : program.v) { + if (!std::get_if>(&unit.u) && + !std::get_if>(&unit.u)) { + Scope *scope{common::visit( + [&context](auto &x) { return GetScope(context, x.value()); }, + unit.u)}; + // The unit may lack a symbol (presumably e.g. a main program with no + // PROGRAM statement); there is then nowhere to record REQUIRES data. + Symbol *symbol{scope ? scope->symbol() : nullptr}; + if (!symbol) { + continue; + } + + common::visit( + [&](auto &details) { + if constexpr (std::is_convertible_v) { + processFn(symbol, details); + } + }, + symbol->details()); + } + } + }}; + + // Combine global `requires` information from all program units except + // modules and submodules.
+ processProgramUnits([&](Symbol *symbol, WithOmpDeclarative &details) { + if (const OmpRequiresFlags * flags{details.ompRequires()}) { + combinedFlags |= *flags; + } + if (const parser::OmpAtomicDefaultMemOrderClause::Type * + memOrder{details.ompAtomicDefaultMemOrder()}) { + if (combinedMemOrder && *combinedMemOrder != *memOrder) { + context.Say(symbol->scope()->sourceRange(), + "Conflicting '%s' REQUIRES clauses found in " + "compilation " + "unit"_err_en_US, + parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName( + llvm::omp::Clause::OMPC_atomic_default_mem_order) + .str())); + } + combinedMemOrder = *memOrder; + } + }); + + // Update all program units except modules and submodules with the + // combined global `requires` information. + processProgramUnits([&](Symbol *, WithOmpDeclarative &details) { + details.set_ompRequires(combinedFlags); + if (combinedMemOrder) { + details.set_ompAtomicDefaultMemOrder(*combinedMemOrder); + } + }); +} + void OmpAttributeVisitor::CheckDataCopyingClause( const parser::Name &name, const Symbol &symbol, Symbol::Flag ompFlag) { const auto *checkSymbol{&symbol}; @@ -1987,4 +2184,44 @@ return llvm::is_contained(symbols, symbol); } +void OmpAttributeVisitor::AddOmpRequiresToScope(Scope *scope, + OmpRequiresFlags flags, + std::optional memOrder) { + // Walk outwards from `scope`, stopping before the global scope; a null + // `scope` (GetScope() may return one) leaves nothing to update. + for (; scope && !scope->IsGlobal(); scope = &scope->parent()) { + if (Symbol * symbol{scope->symbol()}) { + common::visit( + [&](auto &details) { + // Store clauses information into the symbol for the parent and + // enclosing modules, programs, functions and subroutines.
+ if constexpr (std::is_convertible_v) { + if (flags != OmpRequiresFlags::None) { + if (const OmpRequiresFlags * + otherFlags{details.ompRequires()}) { + flags |= *otherFlags; + } + details.set_ompRequires(flags); + } + if (memOrder) { + if (details.has_ompAtomicDefaultMemOrder() && + *details.ompAtomicDefaultMemOrder() != *memOrder) { + context_.Say(scope->sourceRange(), + "Conflicting '%s' REQUIRES clauses found in " + "compilation " + "unit"_err_en_US, + parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName( + llvm::omp::Clause::OMPC_atomic_default_mem_order) + .str())); + } + details.set_ompAtomicDefaultMemOrder(*memOrder); + } + } + }, + symbol->details()); + } + } +} + } // namespace Fortran::semantics diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp --- a/flang/lib/Semantics/resolve-names.cpp +++ b/flang/lib/Semantics/resolve-names.cpp @@ -8585,11 +8585,14 @@ } } -void ResolveNamesVisitor::Post(const parser::Program &) { +void ResolveNamesVisitor::Post(const parser::Program &x) { // ensure that all temps were deallocated CHECK(!attrs_); CHECK(!cudaDataAttr_); CHECK(!GetDeclTypeSpec()); + // Top-level resolution to propagate information across program units after + // each of them has been resolved separately.
+ ResolveOmpTopLevelParts(context(), x); } // A singleton instance of the scope -> IMPLICIT rules mapping is diff --git a/flang/lib/Semantics/resolve-directives.h b/flang/lib/Semantics/rewrite-directives.h copy from flang/lib/Semantics/resolve-directives.h copy to flang/lib/Semantics/rewrite-directives.h --- a/flang/lib/Semantics/resolve-directives.h +++ b/flang/lib/Semantics/rewrite-directives.h @@ -1,4 +1,4 @@ -//===----------------------------------------------------------------------===// +//===-- lib/Semantics/rewrite-directives.h ----------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,21 +6,19 @@ // //===----------------------------------------------------------------------===// -#ifndef FORTRAN_SEMANTICS_RESOLVE_DIRECTIVES_H_ -#define FORTRAN_SEMANTICS_RESOLVE_DIRECTIVES_H_ +#ifndef FORTRAN_SEMANTICS_REWRITE_DIRECTIVES_H_ +#define FORTRAN_SEMANTICS_REWRITE_DIRECTIVES_H_ namespace Fortran::parser { -struct Name; -struct ProgramUnit; +struct Program; } // namespace Fortran::parser namespace Fortran::semantics { - class SemanticsContext; +} // namespace Fortran::semantics -// Name resolution for OpenACC and OpenMP directives -void ResolveAccParts(SemanticsContext &, const parser::ProgramUnit &); -void ResolveOmpParts(SemanticsContext &, const parser::ProgramUnit &); - +namespace Fortran::semantics { +bool RewriteOmpParts(SemanticsContext &, parser::Program &); } // namespace Fortran::semantics -#endif + +#endif // FORTRAN_SEMANTICS_REWRITE_DIRECTIVES_H_ diff --git a/flang/lib/Semantics/rewrite-directives.cpp b/flang/lib/Semantics/rewrite-directives.cpp new file mode 100644 --- /dev/null +++ b/flang/lib/Semantics/rewrite-directives.cpp @@ -0,0 +1,193 @@ +//===-- lib/Semantics/rewrite-directives.cpp ------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "rewrite-directives.h" +#include "flang/Parser/parse-tree-visitor.h" +#include "flang/Parser/parse-tree.h" +#include "flang/Semantics/semantics.h" +#include "flang/Semantics/symbol.h" +#include + +namespace Fortran::semantics { + +using namespace parser::literals; + +class DirectiveRewriteMutator { +public: + explicit DirectiveRewriteMutator(SemanticsContext &context) + : context_{context} {} + + // Default action for a parse tree node is to visit children. + template bool Pre(T &) { return true; } + template void Post(T &) {} + +protected: + SemanticsContext &context_; +}; + +// Rewrite atomic constructs to add an explicit memory ordering to all that do +// not specify it, honoring in this way the `atomic_default_mem_order` clause of +// the `requires` directive. +class OmpRewriteMutator : public DirectiveRewriteMutator { +public: + explicit OmpRewriteMutator(SemanticsContext &context) + : DirectiveRewriteMutator(context) {} + + template void Walk(T &x) { parser::Walk(x, *this); } + template bool Pre(T &) { return true; } + template void Post(T &) {} + + bool Pre(parser::OpenMPAtomicConstruct &); + bool Pre(parser::OpenMPRequiresConstruct &); + void Post(parser::OmpClause::AtomicDefaultMemOrder &); + void Post(parser::UseStmt &); + +private: + bool atomicDirectiveDefaultOrderFound_{false}; + parser::CharBlock requiresClauseSource_; +}; + +bool OmpRewriteMutator::Pre(parser::OpenMPAtomicConstruct &x) { + // Find top-level parent of the operation.
+ Symbol *topLevelParent{common::visit( + [&](auto &atomic) { + Symbol *symbol{nullptr}; + Scope *scope{ + &context_.FindScope(std::get(atomic.t).source)}; + do { + if (Symbol * parent{scope->symbol()}) { + symbol = parent; + } + scope = &scope->parent(); + } while (!scope->IsGlobal()); + return symbol; + }, + x.u)}; + // Without a symbol-bearing enclosing scope there is no default to apply. + if (!topLevelParent) { + return false; + } + // Get the `atomic_default_mem_order` clause from the top-level parent. + std::optional defaultMemOrder; + common::visit( + [&](auto &details) { + if constexpr (std::is_convertible_v) { + if (details.has_ompAtomicDefaultMemOrder()) { + defaultMemOrder = *details.ompAtomicDefaultMemOrder(); + } + } + }, + topLevelParent->details()); + + if (!defaultMemOrder) { + return false; + } + + auto findMemOrderClause = + [](const std::list &clauses) { + return std::find_if( + clauses.begin(), clauses.end(), [](const auto &clause) { + return std::get_if( + &clause.u); + }) != clauses.end(); + }; + + // Get the clause list to which the new memory order clause must be added, + // only if there are no other memory order clauses present for this atomic + // directive. + std::list *clauseList = common::visit( + common::visitors{[&](parser::OmpAtomic &atomicConstruct) { + // OmpAtomic only has a single list of clauses. + auto &clauses{std::get( + atomicConstruct.t)}; + return !findMemOrderClause(clauses.v) ? &clauses.v + : nullptr; + }, + [&](auto &atomicConstruct) { + // All other atomic constructs have two lists of clauses. + auto &clausesLhs{std::get<0>(atomicConstruct.t)}; + auto &clausesRhs{std::get<2>(atomicConstruct.t)}; + return !findMemOrderClause(clausesLhs.v) && + !findMemOrderClause(clausesRhs.v) + ? &clausesRhs.v + : nullptr; + }}, + x.u); + + // Add a memory order clause to the atomic directive.
+ if (clauseList) { + atomicDirectiveDefaultOrderFound_ = true; + switch (*defaultMemOrder) { + case parser::OmpAtomicDefaultMemOrderClause::Type::AcqRel: + clauseList->push_back(parser::OmpMemoryOrderClause( + parser::OmpClause(parser::OmpClause::AcqRel()))); + break; + case parser::OmpAtomicDefaultMemOrderClause::Type::Relaxed: + clauseList->push_back(parser::OmpMemoryOrderClause( + parser::OmpClause(parser::OmpClause::Relaxed()))); + break; + case parser::OmpAtomicDefaultMemOrderClause::Type::SeqCst: + clauseList->push_back(parser::OmpMemoryOrderClause( + parser::OmpClause(parser::OmpClause::SeqCst()))); + break; + } + } + + return false; +} + +bool OmpRewriteMutator::Pre(parser::OpenMPRequiresConstruct &x) { + for (parser::OmpClause &clause : std::get(x.t).v) { + requiresClauseSource_ = clause.source; + Walk(clause.u); + requiresClauseSource_ = {}; + } + return false; +} + +// Check that the `atomic_default_mem_order` clause does not appear after an +// atomic operation without `memory_order` defined. +void OmpRewriteMutator::Post(parser::OmpClause::AtomicDefaultMemOrder &x) { + if (atomicDirectiveDefaultOrderFound_) { + context_.Say(requiresClauseSource_, + "'ATOMIC_DEFAULT_MEM_ORDER' REQUIRES clause found lexically after " + "atomic operation with the 'MEMORY_ORDER' clause not defined"_err_en_US); + } +} + +// Check that a module containing a `requires` statement with the +// `atomic_default_mem_order` clause is not `USE`d after an atomic operation +// without `memory_order` defined. +void OmpRewriteMutator::Post(parser::UseStmt &x) { + semantics::Symbol *symbol{x.moduleName.symbol}; + if (!symbol) { + // Cannot check used module if it wasn't resolved.
+ return; + } + + auto *details = symbol->detailsIf(); + if (atomicDirectiveDefaultOrderFound_ && details && + details->has_ompAtomicDefaultMemOrder()) { + context_.Say(x.moduleName.source, + "'%s' module containing 'ATOMIC_DEFAULT_MEM_ORDER' REQUIRES clause " + "imported lexically after atomic operation with the 'MEMORY_ORDER' " + "clause not defined"_err_en_US, + x.moduleName.ToString()); + } +} + +bool RewriteOmpParts(SemanticsContext &context, parser::Program &program) { + if (!context.IsEnabled(common::LanguageFeature::OpenMP)) { + return true; + } + OmpRewriteMutator{context}.Walk(program); + return !context.AnyFatalError(); +} + +} // namespace Fortran::semantics diff --git a/flang/lib/Semantics/rewrite-parse-tree.cpp b/flang/lib/Semantics/rewrite-parse-tree.cpp --- a/flang/lib/Semantics/rewrite-parse-tree.cpp +++ b/flang/lib/Semantics/rewrite-parse-tree.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "rewrite-parse-tree.h" +#include "rewrite-directives.h" #include "flang/Common/indirection.h" #include "flang/Parser/parse-tree-visitor.h" #include "flang/Parser/parse-tree.h" @@ -175,7 +176,7 @@ bool RewriteParseTree(SemanticsContext &context, parser::Program &program) { RewriteMutator mutator{context}; parser::Walk(program, mutator); - return !context.AnyFatalError(); + return !context.AnyFatalError() && RewriteOmpParts(context, program); } } // namespace Fortran::semantics diff --git a/flang/test/Semantics/Inputs/requires_module.f90 b/flang/test/Semantics/Inputs/requires_module.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Semantics/Inputs/requires_module.f90 @@ -0,0 +1,3 @@ +module requires_module + !$omp requires atomic_default_mem_order(seq_cst), unified_shared_memory +end module diff --git a/flang/test/Semantics/OpenMP/requires-rewrite.f90 b/flang/test/Semantics/OpenMP/requires-rewrite.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Semantics/OpenMP/requires-rewrite.f90 @@
-0,0 +1,76 @@ +! RUN: %flang_fc1 -fopenmp -fdebug-dump-parse-tree %s 2>&1 | FileCheck %s +! Ensure that requires atomic_default_mem_order is used to update atomic +! operations with no explicit memory order set. +program requires + implicit none + !$omp requires atomic_default_mem_order(seq_cst) + integer :: i, j, k + + ! CHECK-LABEL: OpenMPAtomicConstruct -> OmpAtomic + ! CHECK-NOT: OmpMemoryOrderClause -> OmpClause -> SeqCst + ! CHECK: OmpMemoryOrderClause -> OmpClause -> Relaxed + !$omp atomic relaxed + i = j + + ! CHECK-LABEL: OpenMPAtomicConstruct -> OmpAtomic + ! CHECK: OmpMemoryOrderClause -> OmpClause -> SeqCst + !$omp atomic + i = j + + ! CHECK-LABEL: OpenMPAtomicConstruct -> OmpAtomicUpdate + ! CHECK-NOT: OmpMemoryOrderClause -> OmpClause -> SeqCst + ! CHECK: OmpMemoryOrderClause -> OmpClause -> Relaxed + !$omp atomic relaxed update + i = j + + ! CHECK-LABEL: OpenMPAtomicConstruct -> OmpAtomicUpdate + ! CHECK-NOT: OmpMemoryOrderClause -> OmpClause -> SeqCst + ! CHECK: OmpMemoryOrderClause -> OmpClause -> Relaxed + !$omp atomic update relaxed + i = j + + ! CHECK-LABEL: OpenMPAtomicConstruct -> OmpAtomicUpdate + ! CHECK: OmpMemoryOrderClause -> OmpClause -> SeqCst + !$omp atomic update + i = j + + ! CHECK-LABEL: OpenMPAtomicConstruct -> OmpAtomicCapture + ! CHECK-NOT: OmpMemoryOrderClause -> OmpClause -> SeqCst + ! CHECK: OmpMemoryOrderClause -> OmpClause -> Relaxed + !$omp atomic relaxed capture + i = j + j = k + !$omp end atomic + + ! CHECK-LABEL: OpenMPAtomicConstruct -> OmpAtomicCapture + ! CHECK-NOT: OmpMemoryOrderClause -> OmpClause -> SeqCst + ! CHECK: OmpMemoryOrderClause -> OmpClause -> Relaxed + !$omp atomic capture relaxed + i = j + j = k + !$omp end atomic + + ! CHECK-LABEL: OpenMPAtomicConstruct -> OmpAtomicCapture + ! CHECK: OmpMemoryOrderClause -> OmpClause -> SeqCst + !$omp atomic capture + i = j + j = k + !$omp end atomic + + ! CHECK-LABEL: OpenMPAtomicConstruct -> OmpAtomicWrite + !
CHECK-NOT: OmpMemoryOrderClause -> OmpClause -> SeqCst + ! CHECK: OmpMemoryOrderClause -> OmpClause -> Relaxed + !$omp atomic relaxed write + i = j + + ! CHECK-LABEL: OpenMPAtomicConstruct -> OmpAtomicWrite + ! CHECK-NOT: OmpMemoryOrderClause -> OmpClause -> SeqCst + ! CHECK: OmpMemoryOrderClause -> OmpClause -> Relaxed + !$omp atomic write relaxed + i = j + + ! CHECK-LABEL: OpenMPAtomicConstruct -> OmpAtomicWrite + ! CHECK: OmpMemoryOrderClause -> OmpClause -> SeqCst + !$omp atomic write + i = j +end program requires diff --git a/flang/test/Semantics/OpenMP/requires01.f90 b/flang/test/Semantics/OpenMP/requires01.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Semantics/OpenMP/requires01.f90 @@ -0,0 +1,14 @@ +! RUN: %python %S/../test_errors.py %s %flang -fopenmp +! OpenMP Version 5.0 +! 2.4 Requires directive +! All atomic_default_mem_order clauses in 'requires' directives found within a +! compilation unit must specify the same ordering. + +subroutine f + !$omp requires atomic_default_mem_order(seq_cst) +end subroutine f + +!ERROR: Conflicting 'ATOMIC_DEFAULT_MEM_ORDER' REQUIRES clauses found in compilation unit +subroutine g + !$omp requires atomic_default_mem_order(relaxed) +end subroutine g diff --git a/flang/test/Semantics/OpenMP/requires02.f90 b/flang/test/Semantics/OpenMP/requires02.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Semantics/OpenMP/requires02.f90 @@ -0,0 +1,17 @@ +! RUN: %python %S/../test_errors.py %s %flang -fopenmp +! OpenMP Version 5.0 +! 2.4 Requires directive +! All atomic_default_mem_order clauses in 'requires' directives must come +! strictly before any atomic directives on which the memory_order clause is not +! specified.
+ +subroutine f + integer :: a = 0 + !$omp atomic + a = a + 1 +end subroutine f + +subroutine g + !ERROR: 'ATOMIC_DEFAULT_MEM_ORDER' REQUIRES clause found lexically after atomic operation with the 'MEMORY_ORDER' clause not defined + !$omp requires atomic_default_mem_order(relaxed) +end subroutine g diff --git a/flang/test/Semantics/OpenMP/requires03.f90 b/flang/test/Semantics/OpenMP/requires03.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Semantics/OpenMP/requires03.f90 @@ -0,0 +1,21 @@ +! RUN: %python %S/../test_errors.py %s %flang -fopenmp +! OpenMP Version 5.0 +! 2.4 Requires directive +! Target-related clauses in 'requires' directives must come strictly before any +! device constructs, such as target regions. + +subroutine f + !$omp target + !$omp end target +end subroutine f + +subroutine g + !ERROR: 'DYNAMIC_ALLOCATORS' REQUIRES clause found lexically after device construct + !$omp requires dynamic_allocators + !ERROR: 'REVERSE_OFFLOAD' REQUIRES clause found lexically after device construct + !$omp requires reverse_offload + !ERROR: 'UNIFIED_ADDRESS' REQUIRES clause found lexically after device construct + !$omp requires unified_address + !ERROR: 'UNIFIED_SHARED_MEMORY' REQUIRES clause found lexically after device construct + !$omp requires unified_shared_memory +end subroutine g diff --git a/flang/test/Semantics/OpenMP/requires04.f90 b/flang/test/Semantics/OpenMP/requires04.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Semantics/OpenMP/requires04.f90 @@ -0,0 +1,20 @@ +! RUN: %python %S/../test_errors.py %s %flang -fopenmp +! OpenMP Version 5.0 +! 2.4 Requires directive +! Target-related clauses in 'requires' directives must come strictly before any +! device constructs, such as declare target with device_type=nohost|any.
+ +subroutine f + !$omp declare target device_type(nohost) +end subroutine f + +subroutine g + !ERROR: 'DYNAMIC_ALLOCATORS' REQUIRES clause found lexically after device construct + !$omp requires dynamic_allocators + !ERROR: 'REVERSE_OFFLOAD' REQUIRES clause found lexically after device construct + !$omp requires reverse_offload + !ERROR: 'UNIFIED_ADDRESS' REQUIRES clause found lexically after device construct + !$omp requires unified_address + !ERROR: 'UNIFIED_SHARED_MEMORY' REQUIRES clause found lexically after device construct + !$omp requires unified_shared_memory +end subroutine g diff --git a/flang/test/Semantics/OpenMP/requires05.f90 b/flang/test/Semantics/OpenMP/requires05.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Semantics/OpenMP/requires05.f90 @@ -0,0 +1,20 @@ +! RUN: %python %S/../test_errors.py %s %flang -fopenmp +! OpenMP Version 5.0 +! 2.4 Requires directive +! Target-related clauses in 'requires' directives must come strictly before any +! device constructs, such as declare target with 'to' clause and no device_type. + +subroutine f + !$omp declare target to(f) +end subroutine f + +subroutine g + !ERROR: 'DYNAMIC_ALLOCATORS' REQUIRES clause found lexically after device construct + !$omp requires dynamic_allocators + !ERROR: 'REVERSE_OFFLOAD' REQUIRES clause found lexically after device construct + !$omp requires reverse_offload + !ERROR: 'UNIFIED_ADDRESS' REQUIRES clause found lexically after device construct + !$omp requires unified_address + !ERROR: 'UNIFIED_SHARED_MEMORY' REQUIRES clause found lexically after device construct + !$omp requires unified_shared_memory +end subroutine g diff --git a/flang/test/Semantics/OpenMP/requires06.f90 b/flang/test/Semantics/OpenMP/requires06.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Semantics/OpenMP/requires06.f90 @@ -0,0 +1,20 @@ +! RUN: %python %S/../test_errors.py %s %flang -fopenmp +! OpenMP Version 5.0 +! 2.4 Requires directive +! 
Target-related clauses in 'requires' directives must come strictly before any +! device constructs, such as declare target with extended list. + +subroutine f + !$omp declare target (f) +end subroutine f + +subroutine g + !ERROR: 'DYNAMIC_ALLOCATORS' REQUIRES clause found lexically after device construct + !$omp requires dynamic_allocators + !ERROR: 'REVERSE_OFFLOAD' REQUIRES clause found lexically after device construct + !$omp requires reverse_offload + !ERROR: 'UNIFIED_ADDRESS' REQUIRES clause found lexically after device construct + !$omp requires unified_address + !ERROR: 'UNIFIED_SHARED_MEMORY' REQUIRES clause found lexically after device construct + !$omp requires unified_shared_memory +end subroutine g diff --git a/flang/test/Semantics/OpenMP/requires07.f90 b/flang/test/Semantics/OpenMP/requires07.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Semantics/OpenMP/requires07.f90 @@ -0,0 +1,13 @@ +! RUN: rm -rf %t && mkdir %t +! RUN: %flang_fc1 -fsyntax-only -fopenmp -module-dir %t '%S/../Inputs/requires_module.f90' +! RUN: %python %S/../test_errors.py %s %flang -fopenmp -module-dir %t +! OpenMP Version 5.0 +! 2.4 Requires directive +! All atomic_default_mem_order clauses in 'requires' directives found within a +! compilation unit must specify the same ordering. Test that this is propagated +! from imported modules + +!ERROR: Conflicting 'ATOMIC_DEFAULT_MEM_ORDER' REQUIRES clauses found in compilation unit +use requires_module +!$omp requires atomic_default_mem_order(relaxed) +end program diff --git a/flang/test/Semantics/OpenMP/requires08.f90 b/flang/test/Semantics/OpenMP/requires08.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Semantics/OpenMP/requires08.f90 @@ -0,0 +1,17 @@ +! RUN: rm -rf %t && mkdir %t +! RUN: %flang_fc1 -fsyntax-only -fopenmp -module-dir %t '%S/../Inputs/requires_module.f90' +! RUN: %python %S/../test_errors.py %s %flang -fopenmp -module-dir %t +! OpenMP Version 5.0 +! 2.4 Requires directive +! 
Target-related clauses in 'requires' directives must come strictly before any +! device constructs, such as declare target with extended list. Test that this +! is propagated from imported modules. + +subroutine f + !$omp declare target (f) +end subroutine f + +program requires + !ERROR: 'requires_module' module containing device-related REQUIRES directive imported lexically after device construct + use requires_module +end program requires