diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h --- a/flang/lib/Semantics/check-omp-structure.h +++ b/flang/lib/Semantics/check-omp-structure.h @@ -153,6 +153,13 @@ #define GEN_FLANG_CLAUSE_CHECK_ENTER #include "llvm/Frontend/OpenMP/OMP.inc" + // Map parser clause class A to its llvm::omp::Clause kind via OMP.inc. + template + llvm::omp::Clause GetClauseKindForParserClass(const A &) { +#define GEN_FLANG_CLAUSE_PARSER_KIND_MAP +#include "llvm/Frontend/OpenMP/OMP.inc" + } + private: bool HasInvalidWorksharingNesting( const parser::CharBlock &, const OmpDirectiveSet &); @@ -197,6 +204,7 @@ const parser::Name &name, const llvm::omp::Clause clause); void CheckMultipleAppearanceAcrossContext( const parser::OmpObjectList &ompObjectList); + const parser::OmpObjectList *GetOmpObjectList(const parser::OmpClause &); }; } // namespace Fortran::semantics #endif // FORTRAN_SEMANTICS_CHECK_OMP_STRUCTURE_H_ diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp --- a/flang/lib/Semantics/check-omp-structure.cpp +++ b/flang/lib/Semantics/check-omp-structure.cpp @@ -840,7 +840,6 @@ void OmpStructureChecker::CheckMultipleAppearanceAcrossContext( const parser::OmpObjectList &redObjectList) { - const parser::OmpObjectList *objList{nullptr}; // TODO: Verify the assumption here that the immediately enclosing region is // the parallel region to which the worksharing construct having reduction // binds to.
@@ -848,43 +847,29 @@ for (auto it : enclosingContext->clauseInfo) { llvmOmpClause type = it.first; const auto *clause = it.second; - if (type == llvm::omp::Clause::OMPC_private) { - const auto &pClause{std::get(clause->u)}; - objList = &pClause.v; - } else if (type == llvm::omp::Clause::OMPC_firstprivate) { - const auto &fpClause{ - std::get(clause->u)}; - objList = &fpClause.v; - } else if (type == llvm::omp::Clause::OMPC_lastprivate) { - const auto &lpClause{ - std::get(clause->u)}; - objList = &lpClause.v; - } else if (type == llvm::omp::Clause::OMPC_reduction) { - const auto &rClause{std::get(clause->u)}; - const auto &olist{std::get<1>(rClause.v.t)}; - objList = &olist; - } - if (objList) { - for (const auto &ompObject : objList->v) { - if (const auto *name{parser::Unwrap(ompObject)}) { - if (const auto *symbol{name->symbol}) { - for (const auto &redOmpObject : redObjectList.v) { - if (const auto *rname{ - parser::Unwrap(redOmpObject)}) { - if (const auto *rsymbol{rname->symbol}) { - if (rsymbol->name() == symbol->name()) { - context_.Say(GetContext().clauseSource, - "%s variable '%s' is %s in outer context must" - " be shared in the parallel regions to which any" - " of the worksharing regions arising from the " - "worksharing" - " construct bind."_err_en_US, - parser::ToUpperCaseLetters( - getClauseName(llvm::omp::Clause::OMPC_reduction) - .str()), - symbol->name(), - parser::ToUpperCaseLetters( - getClauseName(type).str())); + if (llvm::omp::privateReductionSet.test(type)) { + if (const auto *objList{GetOmpObjectList(*clause)}) { + for (const auto &ompObject : objList->v) { + if (const auto *name{parser::Unwrap(ompObject)}) { + if (const auto *symbol{name->symbol}) { + for (const auto &redOmpObject : redObjectList.v) { + if (const auto *rname{ + parser::Unwrap(redOmpObject)}) { + if (const auto *rsymbol{rname->symbol}) { + if (rsymbol->name() == symbol->name()) { + context_.Say(GetContext().clauseSource, + "%s variable '%s' is %s in outer context must" + "
be shared in the parallel regions to which any" + " of the worksharing regions arising from the " + "worksharing" + " construct bind."_err_en_US, + parser::ToUpperCaseLetters( + getClauseName(llvm::omp::Clause::OMPC_reduction) + .str()), + symbol->name(), + parser::ToUpperCaseLetters( + getClauseName(type).str())); + } } } } @@ -1213,7 +1198,7 @@ DirectivesClauseTriple dirClauseTriple; SymbolSourceMap currSymbols; GetSymbolsInObjectList(x.v, currSymbols); - CheckDefinableObjects(currSymbols, llvm::omp::Clause::OMPC_lastprivate); + CheckDefinableObjects(currSymbols, GetClauseKindForParserClass(x)); // Check lastprivate variables in worksharing constructs dirClauseTriple.emplace(llvm::omp::Directive::OMPD_do, @@ -1224,7 +1209,7 @@ llvm::omp::Directive::OMPD_parallel, llvm::omp::privateReductionSet)); CheckPrivateSymbolsInOuterCxt( - currSymbols, dirClauseTriple, llvm::omp::Clause::OMPC_lastprivate); + currSymbols, dirClauseTriple, GetClauseKindForParserClass(x)); } llvm::StringRef OmpStructureChecker::getClauseName(llvm::omp::Clause clause) { @@ -1368,40 +1353,11 @@ if (auto *enclosingContext{GetEnclosingContextWithDir(enclosingDir)}) { for (auto it{enclosingContext->clauseInfo.begin()}; it != enclosingContext->clauseInfo.end(); ++it) { - // TODO: Replace the hard-coded clause names by using autogen checks or - // a function which maps parser::OmpClause:: to the corresponding - // llvm::omp::Clause::OMPC_ - std::visit(common::visitors{ - [&](const parser::OmpClause::Private &x) { - if (enclosingClauseSet.test( - llvm::omp::Clause::OMPC_private)) { - GetSymbolsInObjectList(x.v, enclosingSymbols); - } - }, - [&](const parser::OmpClause::Firstprivate &x) { - if (enclosingClauseSet.test( - llvm::omp::Clause::OMPC_firstprivate)) { - GetSymbolsInObjectList(x.v, enclosingSymbols); - } - }, - [&](const parser::OmpClause::Lastprivate &x) { - if (enclosingClauseSet.test( - llvm::omp::Clause::OMPC_lastprivate)) { - GetSymbolsInObjectList(x.v, enclosingSymbols); - } - }, -
[&](const parser::OmpClause::Reduction &x) { - if (enclosingClauseSet.test( - llvm::omp::Clause::OMPC_reduction)) { - const auto &ompObjectList{ - std::get(x.v.t)}; - GetSymbolsInObjectList( - ompObjectList, enclosingSymbols); - } - }, - [&](const auto &) {}, - }, - it->second->u); + if (enclosingClauseSet.test(it->first)) { + if (const auto *ompObjectList{GetOmpObjectList(*it->second)}) { + GetSymbolsInObjectList(*ompObjectList, enclosingSymbols); + } + } } // Check if the symbols in current context are private in outer context @@ -1497,4 +1453,37 @@ } } +const parser::OmpObjectList *OmpStructureChecker::GetOmpObjectList( + const parser::OmpClause &clause) { + + // Clauses holding an OmpObjectList directly as their data member `v` + using MemberObjectListClauses = std::tuple; + + // Clauses holding an OmpObjectList inside their tuple member `v.t` + using TupleObjectListClauses = std::tuple; + + // TODO: Generate the tuples using TableGen. + // Handle other constructs with OmpObjectList such as OpenMPThreadprivate. + return std::visit( + common::visitors{ + [](const auto &x) -> const parser::OmpObjectList * { + using Ty = std::decay_t; + if constexpr (common::HasMember) { + return &x.v; + } else if constexpr (common::HasMember) { + return &(std::get(x.v.t)); + } else { + return nullptr; + } + }, + }, + clause.u); +} + } // namespace Fortran::semantics diff --git a/llvm/test/TableGen/directive1.td b/llvm/test/TableGen/directive1.td --- a/llvm/test/TableGen/directive1.td +++ b/llvm/test/TableGen/directive1.td @@ -256,3 +256,23 @@ // GEN-NEXT: } // GEN-EMPTY: // GEN-NEXT: #endif // GEN_FLANG_CLAUSE_UNPARSE +// GEN-EMPTY: +// GEN-NEXT: #ifdef GEN_FLANG_CLAUSE_CHECK_ENTER +// GEN-NEXT: #undef GEN_FLANG_CLAUSE_CHECK_ENTER +// GEN-EMPTY: +// GEN-NEXT: void Enter(const parser::TdlClause::Clausea &); +// GEN-NEXT: void Enter(const parser::TdlClause::Clauseb &); +// GEN-EMPTY: +// GEN-NEXT: #endif // GEN_FLANG_CLAUSE_CHECK_ENTER +// GEN-EMPTY: +// GEN-NEXT: #ifdef GEN_FLANG_CLAUSE_PARSER_KIND_MAP +// GEN-NEXT: #undef
GEN_FLANG_CLAUSE_PARSER_KIND_MAP +// GEN-EMPTY: +// GEN-NEXT: if constexpr (std::is_same_v) +// GEN-NEXT: return llvm::tdl::Clause::TDLC_clausea; +// GEN-NEXT: if constexpr (std::is_same_v) +// GEN-NEXT: return llvm::tdl::Clause::TDLC_clauseb; +// GEN-NEXT: llvm_unreachable("Invalid Tdl Parser clause"); +// GEN-EMPTY: +// GEN-NEXT: #endif // GEN_FLANG_CLAUSE_PARSER_KIND_MAP +// GEN-EMPTY: diff --git a/llvm/utils/TableGen/DirectiveEmitter.cpp b/llvm/utils/TableGen/DirectiveEmitter.cpp --- a/llvm/utils/TableGen/DirectiveEmitter.cpp +++ b/llvm/utils/TableGen/DirectiveEmitter.cpp @@ -647,6 +647,29 @@ } } +// Generate the mapping from each clause's parser class to the +// corresponding clause kind, ending with an llvm_unreachable fallback. +void GenerateFlangClauseParserKindMap(const DirectiveLanguage &DirLang, + raw_ostream &OS) { + + IfDefScope Scope("GEN_FLANG_CLAUSE_PARSER_KIND_MAP", OS); + + OS << "\n"; + for (const auto &C : DirLang.getClauses()) { + Clause Clause{C}; + OS << "if constexpr (std::is_same_v)\n"; + OS << " return llvm::" << DirLang.getCppNamespace() + << "::Clause::" << DirLang.getClausePrefix() << Clause.getFormattedName() + << ";\n"; + } + + OS << "llvm_unreachable(\"Invalid " << DirLang.getName() + << " Parser clause\");\n"; +} + // Generate the implementation section for the enumeration in the directive // language void EmitDirectivesFlangImpl(const DirectiveLanguage &DirLang, @@ -665,6 +688,8 @@ GenerateFlangClauseUnparse(DirLang, OS); GenerateFlangClauseCheckPrototypes(DirLang, OS); + + GenerateFlangClauseParserKindMap(DirLang, OS); } void GenerateClauseClassMacro(const DirectiveLanguage &DirLang,