diff --git a/flang/lib/Semantics/CMakeLists.txt b/flang/lib/Semantics/CMakeLists.txt
--- a/flang/lib/Semantics/CMakeLists.txt
+++ b/flang/lib/Semantics/CMakeLists.txt
@@ -36,6 +36,7 @@
   resolve-directives.cpp
   resolve-names-utils.cpp
   resolve-names.cpp
+  rewrite-directives.cpp
   rewrite-parse-tree.cpp
   runtime-type-info.cpp
   scope.cpp
diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h
--- a/flang/lib/Semantics/check-omp-structure.h
+++ b/flang/lib/Semantics/check-omp-structure.h
@@ -208,7 +208,6 @@
 
   void CheckAllowedRequiresClause(llvmOmpClause clause);
   bool deviceConstructFound_{false};
-  bool atomicDirectiveDefaultOrderFound_{false};
 
   void EnterDirectiveNest(const int index) { directiveNest_[index]++; }
   void ExitDirectiveNest(const int index) { directiveNest_[index]--; }
diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp
--- a/flang/lib/Semantics/check-omp-structure.cpp
+++ b/flang/lib/Semantics/check-omp-structure.cpp
@@ -1760,9 +1760,6 @@
   if (rightHandClauseList) {
     checkForValidMemoryOrderClause(rightHandClauseList);
   }
-  if (numMemoryOrderClause == 0) {
-    atomicDirectiveDefaultOrderFound_ = true;
-  }
 }
 
 void OmpStructureChecker::Enter(const parser::OpenMPAtomicConstruct &x) {
@@ -3024,16 +3021,7 @@
 void OmpStructureChecker::CheckAllowedRequiresClause(llvmOmpClause clause) {
   CheckAllowed(clause);
 
-  if (clause == llvm::omp::Clause::OMPC_atomic_default_mem_order) {
-    // Check that it does not appear after an atomic operation without memory
-    // order
-    if (atomicDirectiveDefaultOrderFound_) {
-      context_.Say(GetContext().clauseSource,
-          "REQUIRES directive with '%s' clause found lexically after atomic "
-          "operation without a memory order clause"_err_en_US,
-          parser::ToUpperCaseLetters(getClauseName(clause).str()));
-    }
-  } else {
+  if (clause != llvm::omp::Clause::OMPC_atomic_default_mem_order) {
     // Check that it does not appear after a device construct
     if (deviceConstructFound_) {
       context_.Say(GetContext().clauseSource,
diff --git a/flang/lib/Semantics/rewrite-directives.h b/flang/lib/Semantics/rewrite-directives.h
new file mode 100644
--- /dev/null
+++ b/flang/lib/Semantics/rewrite-directives.h
@@ -0,0 +1,27 @@
+//===-- lib/Semantics/rewrite-directives.h ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_SEMANTICS_REWRITE_DIRECTIVES_H_
+#define FORTRAN_SEMANTICS_REWRITE_DIRECTIVES_H_
+
+namespace Fortran::parser {
+struct Program;
+} // namespace Fortran::parser
+
+namespace Fortran::semantics {
+class SemanticsContext;
+} // namespace Fortran::semantics
+
+namespace Fortran::semantics {
+// Walks the parse tree of the program and rewrites OpenMP directives in place.
+// Does nothing and returns true when OpenMP is not enabled in the context;
+// otherwise returns true iff the context holds no fatal error afterwards.
+bool RewriteOmpParts(SemanticsContext &, parser::Program &);
+} // namespace Fortran::semantics
+
+#endif // FORTRAN_SEMANTICS_REWRITE_DIRECTIVES_H_
diff --git a/flang/lib/Semantics/rewrite-directives.cpp b/flang/lib/Semantics/rewrite-directives.cpp
new file mode 100644
--- /dev/null
+++ b/flang/lib/Semantics/rewrite-directives.cpp
@@ -0,0 +1,193 @@
+//===-- lib/Semantics/rewrite-directives.cpp ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "rewrite-directives.h"
+#include "flang/Parser/parse-tree-visitor.h"
+#include "flang/Parser/parse-tree.h"
+#include "flang/Semantics/semantics.h"
+#include "flang/Semantics/symbol.h"
+// NOTE(review): the two standard header names were lost in this copy of the
+// patch; the headers below cover the std:: names used in this file.
+#include <algorithm>
+#include <list>
+#include <optional>
+#include <type_traits>
+
+namespace Fortran::semantics {
+
+using namespace parser::literals;
+
+// Base class for directive-rewriting parse tree mutators: holds the semantics
+// context and supplies pass-through Pre/Post hooks for every node type.
+class DirectiveRewriteMutator {
+public:
+  explicit DirectiveRewriteMutator(SemanticsContext &context)
+      : context_{context} {}
+
+  // Default action for a parse tree node is to visit children.
+  template <typename T> bool Pre(T &) { return true; }
+  template <typename T> void Post(T &) {}
+
+protected:
+  SemanticsContext &context_;
+};
+
+// Rewrite atomic constructs to add an explicit memory ordering to all that do
+// not specify it, honoring in this way the `atomic_default_mem_order` clause of
+// the REQUIRES directive.
+class OmpRewriteMutator : public DirectiveRewriteMutator {
+public:
+  explicit OmpRewriteMutator(SemanticsContext &context)
+      : DirectiveRewriteMutator(context) {}
+
+  template <typename T> void Walk(T &x) { parser::Walk(x, *this); }
+  template <typename T> bool Pre(T &) { return true; }
+  template <typename T> void Post(T &) {}
+
+  bool Pre(parser::OpenMPAtomicConstruct &);
+  bool Pre(parser::OpenMPRequiresConstruct &);
+  void Post(parser::OmpAtomicDefaultMemOrderClause &);
+
+private:
+  // Source of the REQUIRES clause currently being walked (for diagnostics).
+  parser::CharBlock requiresClauseSource_{nullptr};
+  // Set once an atomic construct has been given the default memory order.
+  bool atomicDirectiveDefaultOrderFound_{false};
+};
+
+// Appends the `atomic_default_mem_order` recorded on the construct's top-level
+// parent symbol to an atomic construct that has no memory order clause of its
+// own. Always returns false, so the construct's children are not visited.
+bool OmpRewriteMutator::Pre(parser::OpenMPAtomicConstruct &x) {
+  // Find top-level parent of the operation.
+  Symbol *topLevelParent{common::visit(
+      [&](auto &atomic) {
+        Symbol *symbol{nullptr};
+        Scope *scope{
+            &context_.FindScope(std::get<parser::Verbatim>(atomic.t).source)};
+        do {
+          if (Symbol * parent{scope->symbol()}) {
+            symbol = parent;
+          }
+          scope = &scope->parent();
+        } while (!scope->IsGlobal());
+
+        assert(symbol &&
+            "Atomic construct must be within a scope associated with a symbol");
+        return symbol;
+      },
+      x.u)};
+
+  // Get the `atomic_default_mem_order` clause from the top-level parent.
+  std::optional<parser::OmpAtomicDefaultMemOrderClause::Type> defaultMemOrder;
+  common::visit(
+      [&](auto &details) {
+        // NOTE(review): template arguments reconstructed; confirm the mix-in
+        // class name against flang/Semantics/symbol.h.
+        if constexpr (std::is_convertible_v<decltype(&details),
+                          WithOmpDeclarative *>) {
+          if (details.has_ompAtomicDefaultMemOrder()) {
+            defaultMemOrder = *details.ompAtomicDefaultMemOrder();
+          }
+        }
+      },
+      topLevelParent->details());
+
+  if (!defaultMemOrder) {
+    return false;
+  }
+
+  auto findMemOrderClause =
+      [](const std::list<parser::OmpAtomicClause> &clauses) {
+        return std::find_if(
+                   clauses.begin(), clauses.end(), [](const auto &clause) {
+                     return std::get_if<parser::OmpMemoryOrderClause>(
+                         &clause.u);
+                   }) != clauses.end();
+      };
+
+  // Get the clause list to which the new memory order clause must be added,
+  // only if there are no other memory order clauses present for this atomic
+  // directive.
+  std::list<parser::OmpAtomicClause> *clauseList = common::visit(
+      common::visitors{[&](parser::OmpAtomic &atomicConstruct) {
+                         // OmpAtomic only has a single list of clauses.
+                         auto &clauses{std::get<parser::OmpAtomicClauseList>(
+                             atomicConstruct.t)};
+                         return !findMemOrderClause(clauses.v) ? &clauses.v
+                                                               : nullptr;
+                       },
+          [&](auto &atomicConstruct) {
+            // All other atomic constructs have two lists of clauses.
+            auto &clausesLhs{std::get<0>(atomicConstruct.t)};
+            auto &clausesRhs{std::get<2>(atomicConstruct.t)};
+            return !findMemOrderClause(clausesLhs.v) &&
+                    !findMemOrderClause(clausesRhs.v)
+                ? &clausesRhs.v
+                : nullptr;
+          }},
+      x.u);
+
+  // Add a memory order clause to the atomic directive.
+  if (clauseList) {
+    atomicDirectiveDefaultOrderFound_ = true;
+    switch (*defaultMemOrder) {
+    case parser::OmpAtomicDefaultMemOrderClause::Type::AcqRel:
+      clauseList->push_back(parser::OmpMemoryOrderClause(
+          parser::OmpClause(parser::OmpClause::AcqRel())));
+      break;
+    case parser::OmpAtomicDefaultMemOrderClause::Type::Relaxed:
+      clauseList->push_back(parser::OmpMemoryOrderClause(
+          parser::OmpClause(parser::OmpClause::Relaxed())));
+      break;
+    case parser::OmpAtomicDefaultMemOrderClause::Type::SeqCst:
+      clauseList->push_back(parser::OmpMemoryOrderClause(
+          parser::OmpClause(parser::OmpClause::SeqCst())));
+      break;
+    }
+  }
+
+  return false;
+}
+
+// Walks each clause of a REQUIRES directive, remembering the clause source
+// while it is visited so that Post() can attach a diagnostic to it.
+bool OmpRewriteMutator::Pre(parser::OpenMPRequiresConstruct &x) {
+  for (parser::OmpClause &clause : std::get<parser::OmpClauseList>(x.t).v) {
+    // Store source in case we need to emit an error
+    requiresClauseSource_ = clause.source;
+    Walk(clause.u);
+    requiresClauseSource_ = nullptr;
+  }
+  return false;
+}
+
+// Reports an error if an atomic construct was already rewritten with a default
+// memory order before this `atomic_default_mem_order` clause was visited.
+void OmpRewriteMutator::Post(parser::OmpAtomicDefaultMemOrderClause &) {
+  if (atomicDirectiveDefaultOrderFound_) {
+    context_.Say(requiresClauseSource_,
+        "REQUIRES directive with '%s' clause found lexically after atomic "
+        "operation without a memory order clause"_err_en_US,
+        parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(
+            llvm::omp::OMPC_atomic_default_mem_order)
+                                       .str()));
+  }
+}
+
+// Entry point: runs OmpRewriteMutator over the program when OpenMP is enabled.
+// Returns true if OpenMP is disabled or no fatal error has been reported.
+bool RewriteOmpParts(SemanticsContext &context, parser::Program &program) {
+  if (!context.IsEnabled(common::LanguageFeature::OpenMP)) {
+    return true;
+  }
+  OmpRewriteMutator{context}.Walk(program);
+  return !context.AnyFatalError();
+}
+
+} // namespace Fortran::semantics
diff --git a/flang/lib/Semantics/rewrite-parse-tree.cpp b/flang/lib/Semantics/rewrite-parse-tree.cpp
--- a/flang/lib/Semantics/rewrite-parse-tree.cpp
+++ b/flang/lib/Semantics/rewrite-parse-tree.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "rewrite-parse-tree.h"
+#include "rewrite-directives.h"
 #include "flang/Common/indirection.h"
 #include "flang/Parser/parse-tree-visitor.h"
 #include "flang/Parser/parse-tree.h"
@@ -175,7 +176,7 @@
 bool RewriteParseTree(SemanticsContext &context, parser::Program &program) {
   RewriteMutator mutator{context};
   parser::Walk(program, mutator);
-  return !context.AnyFatalError();
+  return !context.AnyFatalError() && RewriteOmpParts(context, program);
 }
 
 } // namespace Fortran::semantics
diff --git a/flang/test/Semantics/OpenMP/requires-rewrite.f90 b/flang/test/Semantics/OpenMP/requires-rewrite.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Semantics/OpenMP/requires-rewrite.f90
@@ -0,0 +1,76 @@
+! RUN: %flang_fc1 -fopenmp -fdebug-dump-parse-tree %s 2>&1 | FileCheck %s
+! Ensure that requires atomic_default_mem_order is used to update atomic
+! operations with no explicit memory order set.
+program requires
+  implicit none
+  !$omp requires atomic_default_mem_order(seq_cst)
+  integer :: i, j, k
+
+  ! CHECK-LABEL: OpenMPAtomicConstruct -> OmpAtomic
+  ! CHECK-NOT: OmpMemoryOrderClause -> OmpClause -> SeqCst
+  ! CHECK: OmpMemoryOrderClause -> OmpClause -> Relaxed
+  !$omp atomic relaxed
+  i = j
+
+  ! CHECK-LABEL: OpenMPAtomicConstruct -> OmpAtomic
+  ! CHECK: OmpMemoryOrderClause -> OmpClause -> SeqCst
+  !$omp atomic
+  i = j
+
+  ! CHECK-LABEL: OpenMPAtomicConstruct -> OmpAtomicUpdate
+  ! CHECK-NOT: OmpMemoryOrderClause -> OmpClause -> SeqCst
+  ! CHECK: OmpMemoryOrderClause -> OmpClause -> Relaxed
+  !$omp atomic relaxed update
+  i = j
+
+  ! CHECK-LABEL: OpenMPAtomicConstruct -> OmpAtomicUpdate
+  ! CHECK-NOT: OmpMemoryOrderClause -> OmpClause -> SeqCst
+  ! CHECK: OmpMemoryOrderClause -> OmpClause -> Relaxed
+  !$omp atomic update relaxed
+  i = j
+
+  ! CHECK-LABEL: OpenMPAtomicConstruct -> OmpAtomicUpdate
+  ! CHECK: OmpMemoryOrderClause -> OmpClause -> SeqCst
+  !$omp atomic update
+  i = j
+
+  ! CHECK-LABEL: OpenMPAtomicConstruct -> OmpAtomicCapture
+  ! CHECK-NOT: OmpMemoryOrderClause -> OmpClause -> SeqCst
+  ! CHECK: OmpMemoryOrderClause -> OmpClause -> Relaxed
+  !$omp atomic relaxed capture
+  i = j
+  j = k
+  !$omp end atomic
+
+  ! CHECK-LABEL: OpenMPAtomicConstruct -> OmpAtomicCapture
+  ! CHECK-NOT: OmpMemoryOrderClause -> OmpClause -> SeqCst
+  ! CHECK: OmpMemoryOrderClause -> OmpClause -> Relaxed
+  !$omp atomic capture relaxed
+  i = j
+  j = k
+  !$omp end atomic
+
+  ! CHECK-LABEL: OpenMPAtomicConstruct -> OmpAtomicCapture
+  ! CHECK: OmpMemoryOrderClause -> OmpClause -> SeqCst
+  !$omp atomic capture
+  i = j
+  j = k
+  !$omp end atomic
+
+  ! CHECK-LABEL: OpenMPAtomicConstruct -> OmpAtomicWrite
+  ! CHECK-NOT: OmpMemoryOrderClause -> OmpClause -> SeqCst
+  ! CHECK: OmpMemoryOrderClause -> OmpClause -> Relaxed
+  !$omp atomic relaxed write
+  i = j
+
+  ! CHECK-LABEL: OpenMPAtomicConstruct -> OmpAtomicWrite
+  ! CHECK-NOT: OmpMemoryOrderClause -> OmpClause -> SeqCst
+  ! CHECK: OmpMemoryOrderClause -> OmpClause -> Relaxed
+  !$omp atomic write relaxed
+  i = j
+
+  ! CHECK-LABEL: OpenMPAtomicConstruct -> OmpAtomicWrite
+  ! CHECK: OmpMemoryOrderClause -> OmpClause -> SeqCst
+  !$omp atomic write
+  i = j
+end program requires