diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -8612,6 +8612,85 @@
   }
 };
 
+class OMPTraitInfo;
+
+/// This represents 'when' clause in the '#pragma omp metadirective'
+/// directive.
+///
+/// \code
+/// #pragma omp metadirective when(user={condition(N<10)}: parallel)
+/// \endcode
+/// In this example directive '#pragma omp metadirective' has simple 'when'
+/// clause with user defined condition.
+class OMPWhenClause final : public OMPClause {
+  friend class OMPClauseReader;
+
+  /// Context selector of the clause; it has no trait sets for 'default'.
+  OMPTraitInfo *TI = nullptr;
+
+  /// Directive variant of the clause, or null if there is none.
+  Stmt *Directive = nullptr;
+
+  /// Location of '('.
+  SourceLocation LParenLoc;
+
+public:
+  /// Build 'when' clause with arguments \a T for traits, \a D for the
+  /// associated directive.
+  ///
+  /// \param T TraitInfo containing information about the context selector
+  /// \param D The statement associated with the when clause
+  /// \param StartLoc Starting location of the clause.
+  /// \param LParenLoc Location of '('.
+  /// \param EndLoc Ending location of the clause.
+  OMPWhenClause(OMPTraitInfo &T, Stmt *D, SourceLocation StartLoc,
+                SourceLocation LParenLoc, SourceLocation EndLoc)
+      : OMPClause(llvm::omp::OMPC_when, StartLoc, EndLoc), TI(&T), Directive(D),
+        LParenLoc(LParenLoc) {}
+
+  /// Build an empty clause; trait info and directive stay null until set.
+  OMPWhenClause()
+      : OMPClause(llvm::omp::OMPC_when, SourceLocation(), SourceLocation()) {}
+
+  /// Sets the location of '('.
+  void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
+
+  /// Returns the location of '('.
+  SourceLocation getLParenLoc() const { return LParenLoc; }
+
+  /// Returns the associated OpenMP directive, or null if there is none.
+  Stmt *getDirective() const { return Directive; }
+
+  /// Set the associated OpenMP directive.
+  void setDirective(Stmt *S) { Directive = S; }
+
+  /// Returns the OMPTraitInfo of the clause.
+  OMPTraitInfo &getTraitInfo() const { return *TI; }
+
+  /// Set the OMPTraitInfo of the clause.
+  void setTraitInfo(OMPTraitInfo *T) { TI = T; }
+
+  /// The directive variant is not exposed as a child statement.
+  child_range children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+
+  const_child_range children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  child_range used_children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+  const_child_range used_children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  static bool classof(const OMPClause *T) {
+    return T->getClauseKind() == llvm::omp::OMPC_when;
+  }
+};
+
 /// This class implements a simple visitor for OMPClause
 /// subclasses.
 template <class ImplClass, template <typename> class Ptr, typename RetTy>
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3700,6 +3700,22 @@
   return true;
 }
 
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitOMPWhenClause(OMPWhenClause *C) {
+  // Traverse the user-condition expressions of the context selector.
+  for (const OMPTraitSet &Set : C->getTraitInfo().Sets) {
+    for (const OMPTraitSelector &Selector : Set.Selectors) {
+      if (Selector.Kind == llvm::omp::TraitSelector::user_condition &&
+          Selector.ScoreOrCondition)
+        TRY_TO(TraverseStmt(Selector.ScoreOrCondition));
+    }
+  }
+  // The directive variant is not one of the clause's children(), so it has to
+  // be traversed explicitly.
+  TRY_TO(TraverseStmt(C->getDirective()));
+  return true;
+}
+
 // FIXME: look at the following tricky-seeming exprs to see if we
 // need to recurse on anything.  These are ones that have methods
 // returning decls or qualtypes or nestednamespecifier -- though I'm
diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h
--- a/clang/include/clang/AST/StmtOpenMP.h
+++ b/clang/include/clang/AST/StmtOpenMP.h
@@ -5456,7 +5456,6 @@
 class OMPMetaDirective final : public OMPExecutableDirective {
   friend class ASTStmtReader;
   friend class OMPExecutableDirective;
-  Stmt *IfStmt;
 
   OMPMetaDirective(SourceLocation StartLoc, SourceLocation EndLoc)
       : OMPExecutableDirective(OMPMetaDirectiveClass,
@@ -5467,16 +5466,13 @@
                                llvm::omp::OMPD_metadirective, SourceLocation(),
                                SourceLocation()) {}
 
-  void setIfStmt(Stmt *S) { IfStmt = S; }
-
 public:
   static OMPMetaDirective *Create(const ASTContext &C, SourceLocation StartLoc,
                                   SourceLocation EndLoc,
-                                  ArrayRef<OMPClause *> Clauses,
-                                  Stmt *AssociatedStmt, Stmt *IfStmt);
+                                  ArrayRef<OMPClause *> Clauses);
+
   static OMPMetaDirective *CreateEmpty(const ASTContext &C, unsigned NumClauses,
                                        EmptyShell);
-  Stmt *getIfStmt() const { return IfStmt; }
 
   static bool classof(const Stmt *T) {
     return T->getStmtClass() == OMPMetaDirectiveClass;
diff --git a/clang/include/clang/Lex/Preprocessor.h b/clang/include/clang/Lex/Preprocessor.h
--- a/clang/include/clang/Lex/Preprocessor.h
+++ b/clang/include/clang/Lex/Preprocessor.h
@@ -899,6 +899,9 @@
   /// Cached tokens are stored here when we do backtracking or
   /// lookahead. They are "lexed" by the CachingLex() method.
   CachedTokensTy CachedTokens;
+  /// Every token appended to CachedTokens by CachingLex() or PeekAhead() is
+  /// also recorded here, so annotations made in CachedTokens can be undone.
+  CachedTokensTy UnannotatedCachedTokens;
 
   /// The position of the cached token that CachingLex() should
   /// "lex" next.
@@ -1458,7 +1461,10 @@
 
   /// Make Preprocessor re-lex the tokens that were lexed since
   /// EnableBacktrackAtThisPos() was previously called.
-  void Backtrack();
+  ///
+  /// If \p RevertAnnotations is true, the cached tokens from the backtrack
+  /// position onward are replaced by their unannotated copies.
+  void Backtrack(bool RevertAnnotations = false);
 
   /// True if EnableBacktrackAtThisPos() was called and
   /// caching of tokens is on.
diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h
--- a/clang/include/clang/Parse/Parser.h
+++ b/clang/include/clang/Parse/Parser.h
@@ -964,9 +964,11 @@
       P.PP.CommitBacktrackedTokens();
       isActive = false;
     }
-    void Revert() {
+    /// Undo the tentative parse. \p RevertAnnotations is forwarded to
+    /// Preprocessor::Backtrack().
+    void Revert(bool RevertAnnotations = false) {
       assert(isActive && "Parsing action was finished!");
-      P.PP.Backtrack();
+      P.PP.Backtrack(RevertAnnotations);
       P.PreferredType = PrevPreferredType;
       P.Tok = PrevTok;
       P.TentativelyDeclaredIdentifiers.resize(
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -10543,7 +10543,7 @@
-  /// Called on well-formed '\#pragma omp metadirective' after parsing
-  /// of the associated statement.
+  /// Called on well-formed '\#pragma omp metadirective'; \p Clauses are the
+  /// 'when'/'default' clauses to attach to the directive.
   StmtResult ActOnOpenMPMetaDirective(ArrayRef<OMPClause *> Clauses,
-                                      Stmt *AStmt, SourceLocation StartLoc,
+                                      SourceLocation StartLoc,
                                       SourceLocation EndLoc);
 
   // OpenMP directives and clauses.
@@ -11139,7 +11139,8 @@
                                  SourceLocation LParenLoc,
                                  SourceLocation EndLoc);
   /// Called on well-formed 'when' clause.
-  OMPClause *ActOnOpenMPWhenClause(OMPTraitInfo &TI, SourceLocation StartLoc,
+  OMPClause *ActOnOpenMPWhenClause(OMPTraitInfo &TI, StmtResult Directive,
+                                   SourceLocation StartLoc,
                                    SourceLocation LParenLoc,
                                    SourceLocation EndLoc);
   /// Called on well-formed 'default' clause.
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp
--- a/clang/lib/AST/OpenMPClause.cpp
+++ b/clang/lib/AST/OpenMPClause.cpp
@@ -15,6 +15,7 @@
 #include "clang/AST/Attr.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/DeclOpenMP.h"
+#include "clang/AST/StmtOpenMP.h"
 #include "clang/Basic/LLVM.h"
 #include "clang/Basic/OpenMPKinds.h"
 #include "clang/Basic/TargetInfo.h"
@@ -2334,6 +2335,29 @@
      << ")";
 }
 
+void OMPClausePrinter::VisitOMPWhenClause(OMPWhenClause *Node) {
+  OMPTraitInfo &TI = Node->getTraitInfo();
+  // A clause without trait sets is the 'default' clause.
+  bool IsDefault = TI.Sets.empty();
+  OS << (IsDefault ? "default(" : "when(");
+  TI.print(OS, Policy);
+  if (const Stmt *S = Node->getDirective()) {
+    // Only 'when' separates the context selector from the directive variant.
+    if (!IsDefault)
+      OS << ":";
+    const auto *D = cast<OMPExecutableDirective>(S);
+    OS << getOpenMPDirectiveName(D->getDirectiveKind());
+
+    OMPClausePrinter Printer(OS, Policy);
+    for (OMPClause *Clause : D->clauses())
+      if (Clause && !Clause->isImplicit()) {
+        OS << ' ';
+        Printer.Visit(Clause);
+      }
+  }
+  OS << ")";
+}
+
 void OMPTraitInfo::getAsVariantMatchInfo(ASTContext &ASTCtx,
                                          VariantMatchInfo &VMI) const {
   for (const OMPTraitSet &Set : Sets) {
@@ -2348,13 +2372,17 @@
                    TraitProperty::user_condition_unknown &&
                "Ill-formed user condition, expected unknown trait property!");
 
+        // If the condition is a constant expression record it as a trait.
+        // Otherwise add nothing: metadirective 'when' clauses with such a
+        // condition are evaluated at run time by CodeGen.
+        // NOTE(review): other users of the match info (e.g. 'declare variant')
+        // used to see user_condition_false here; confirm they reject
+        // non-constant conditions before reaching this point.
         if (Optional<APSInt> CondVal =
                 Selector.ScoreOrCondition->getIntegerConstantExpr(ASTCtx))
           VMI.addTrait(CondVal->isZero() ? TraitProperty::user_condition_false
                                          : TraitProperty::user_condition_true,
                        "");
-        else
-          VMI.addTrait(TraitProperty::user_condition_false, "");
         continue;
       }
 
diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp
--- a/clang/lib/AST/StmtOpenMP.cpp
+++ b/clang/lib/AST/StmtOpenMP.cpp
@@ -262,11 +262,12 @@
 OMPMetaDirective *OMPMetaDirective::Create(const ASTContext &C,
                                            SourceLocation StartLoc,
                                            SourceLocation EndLoc,
-                                           ArrayRef<OMPClause *> Clauses,
-                                           Stmt *AssociatedStmt, Stmt *IfStmt) {
-  auto *Dir = createDirective<OMPMetaDirective>(
-      C, Clauses, AssociatedStmt, /*NumChildren=*/1, StartLoc, EndLoc);
-  Dir->setIfStmt(IfStmt);
+                                           ArrayRef<OMPClause *> Clauses) {
+  // NOTE(review): confirm CreateEmpty() allocates the same layout (no
+  // associated statement, no children) so deserialization matches.
+  auto *Dir =
+      createDirective<OMPMetaDirective>(C, Clauses, /*AssociatedStmt*/ nullptr,
+                                        /*NumChildren=*/0, StartLoc, EndLoc);
   return Dir;
 }
 
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -886,6 +886,17 @@
 }
 void OMPClauseProfiler::VisitOMPOrderClause(const OMPOrderClause *C) {}
 void OMPClauseProfiler::VisitOMPBindClause(const OMPBindClause *C) {}
+
+void OMPClauseProfiler::VisitOMPWhenClause(const OMPWhenClause *C) {
+  // Include the selector scores/conditions so that clauses differing only in
+  // those expressions get different profiles.
+  for (const OMPTraitSet &Set : C->getTraitInfo().Sets)
+    for (const OMPTraitSelector &Selector : Set.Selectors)
+      if (Selector.ScoreOrCondition)
+        Profiler->VisitStmt(Selector.ScoreOrCondition);
+  if (C->getDirective())
+    Profiler->VisitStmt(C->getDirective());
+}
 } // namespace
 
 void
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -10762,6 +10762,14 @@
   }
 
   if (const auto *E = dyn_cast<OMPExecutableDirective>(S)) {
+    // A metadirective has no associated statement; its directive variants are
+    // held by its 'when' clauses and may contain target regions.
+    if (E->getDirectiveKind() == OMPD_metadirective) {
+      for (const auto *C : E->getClausesOfKind<OMPWhenClause>())
+        if (C->getDirective())
+          scanForTargetRegionsFunctions(C->getDirective(), ParentName);
+    }
+
     if (!E->hasAssociatedStmt() || !E->getAssociatedStmt())
       return;
 
diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -24,6 +24,7 @@
 #include "clang/AST/StmtVisitor.h"
+#include "clang/Basic/DiagnosticParse.h"
 #include "clang/Basic/OpenMPKinds.h"
 #include "clang/Basic/PrettyStackTrace.h"
 #include "llvm/BinaryFormat/Dwarf.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
@@ -1788,8 +1789,77 @@
   checkForLastprivateConditionalUpdate(*this, S);
 }
 
-void CodeGenFunction::EmitOMPMetaDirective(const OMPMetaDirective &S) {
-  EmitStmt(S.getIfStmt());
+void CodeGenFunction::EmitOMPMetaDirective(const OMPMetaDirective &D) {
+  llvm::BasicBlock *AfterBlock =
+      createBasicBlock("omp.meta.user.condition.after");
+
+  // 'when' clauses without a run-time user condition (including the 'default'
+  // clause); the best match among them is selected statically below.
+  SmallVector<const OMPWhenClause *, 4> StaticWhenClauses;
+  SmallVector<VariantMatchInfo, 4> StaticVMIs;
+
+  for (const auto *C : D.getClausesOfKind<OMPWhenClause>()) {
+    OMPTraitInfo &TI = C->getTraitInfo();
+
+    // Find the first user condition that is not a constant expression; it
+    // has to be evaluated at run time.
+    Expr *DynamicCond = nullptr;
+    TI.anyScoreOrCondition([&](Expr *&E, bool IsScore) {
+      if (IsScore || !E || E->getIntegerConstantExpr(getContext()))
+        return false;
+      DynamicCond = E;
+      return true;
+    });
+
+    // Without a dynamic condition the clause is resolved statically, together
+    // with the other static clauses.
+    if (!DynamicCond) {
+      StaticWhenClauses.push_back(C);
+      VariantMatchInfo VMI;
+      TI.getAsVariantMatchInfo(getContext(), VMI);
+      StaticVMIs.push_back(VMI);
+      continue;
+    }
+
+    // Emit the directive variant if the condition holds, otherwise fall
+    // through to the next clause. The blocks are created only here so that
+    // static clauses do not leave unused basic blocks behind.
+    llvm::BasicBlock *TrueBlock = createBasicBlock("omp.meta.user.condition");
+    llvm::BasicBlock *ExitBlock =
+        createBasicBlock("omp.meta.user.condition.exit");
+    EmitBranchOnBoolExpr(DynamicCond, TrueBlock, ExitBlock,
+                         getProfileCount(C->getDirective()));
+    EmitBlock(TrueBlock);
+    if (const Stmt *Directive = C->getDirective())
+      EmitStmt(Directive);
+    EmitBranch(AfterBlock);
+    EmitBlock(ExitBlock);
+  }
+
+  // Emit code for the best matching static clause, if any.
+  if (!StaticWhenClauses.empty()) {
+    std::function<void(StringRef)> DiagUnknownTrait = [&](StringRef ISATrait) {
+      CGM.getDiags().Report(D.getBeginLoc(),
+                            diag::warn_unknown_declare_variant_isa_trait)
+          << ISATrait;
+    };
+
+    TargetOMPContext OMPCtx(
+        getContext(), std::move(DiagUnknownTrait),
+        /* CurrentFunctionDecl */ nullptr,
+        /* ConstructTraits */ ArrayRef<llvm::omp::TraitProperty>());
+
+    // getBestVariantMatchForContext() returns -1 if no variant is applicable;
+    // nothing is emitted in that case.
+    int BestIdx = getBestVariantMatchForContext(StaticVMIs, OMPCtx);
+    if (BestIdx >= 0) {
+      if (const Stmt *Directive = StaticWhenClauses[BestIdx]->getDirective())
+        EmitStmt(Directive);
+      EmitBranch(AfterBlock);
+    }
+  }
+
+  EmitBlock(AfterBlock);
 }
 
 namespace {
diff --git a/clang/lib/Lex/PPCaching.cpp b/clang/lib/Lex/PPCaching.cpp
--- a/clang/lib/Lex/PPCaching.cpp
+++ b/clang/lib/Lex/PPCaching.cpp
@@ -37,12 +37,25 @@
 
 // Make Preprocessor re-lex the tokens that were lexed since
 // EnableBacktrackAtThisPos() was previously called.
-void Preprocessor::Backtrack() {
-  assert(!BacktrackPositions.empty()
-         && "EnableBacktrackAtThisPos was not called!");
+void Preprocessor::Backtrack(bool RevertAnnotations) {
+  assert(!BacktrackPositions.empty() &&
+         "EnableBacktrackAtThisPos was not called!");
   CachedLexPos = BacktrackPositions.back();
   BacktrackPositions.pop_back();
   recomputeCurLexerKind();
+
+  if (RevertAnnotations) {
+    // Replace the tokens from the backtrack position onward, including any
+    // annotation tokens formed since then, with the tokens as originally lexed.
+    // FIXME: This assumes indices in CachedTokens and UnannotatedCachedTokens
+    // agree up to CachedLexPos, i.e. that no cached token before the backtrack
+    // position was annotated or inserted.
+    assert(CachedLexPos <= UnannotatedCachedTokens.size() &&
+           "Backtrack position beyond the unannotated token cache!");
+    CachedTokens.erase(CachedTokens.begin() + CachedLexPos, CachedTokens.end());
+    CachedTokens.append(UnannotatedCachedTokens.begin() + CachedLexPos,
+                        UnannotatedCachedTokens.end());
+  }
 }
 
 void Preprocessor::CachingLex(Token &Result) {
@@ -66,6 +79,7 @@
     // Cache the lexed token.
     EnterCachingLexModeUnchecked();
     CachedTokens.push_back(Result);
+    UnannotatedCachedTokens.push_back(Result);
     ++CachedLexPos;
     return;
   }
@@ -75,6 +89,7 @@
   } else {
     // All cached tokens were consumed.
     CachedTokens.clear();
+    UnannotatedCachedTokens.clear();
     CachedLexPos = 0;
   }
 }
@@ -106,8 +121,10 @@
   assert(CachedLexPos + N > CachedTokens.size() && "Confused caching.");
   ExitCachingLexMode();
   for (size_t C = CachedLexPos + N - CachedTokens.size(); C > 0; --C) {
-    CachedTokens.push_back(Token());
-    Lex(CachedTokens.back());
+    Token Result;
+    Lex(Result);
+    CachedTokens.push_back(Result);
+    UnannotatedCachedTokens.push_back(Result);
   }
   EnterCachingLexMode();
   return CachedTokens.back();
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -2488,12 +2488,15 @@
-    // First iteration of parsing all clauses of metadirective.
-    // This iteration only parses and collects all context selector ignoring the
-    // associated directives.
-    TentativeParsingAction TPA(*this);
+    // Each clause is parsed tentatively, including its directive variant, and
+    // then the parser is rewound and the clause skipped, so that every variant
+    // is parsed from the same tokens.
     ASTContext &ASTContext = Actions.getASTContext();
 
     BalancedDelimiterTracker T(*this, tok::l_paren,
                                tok::annot_pragma_openmp_end);
+    // Reset HasAssociatedStatement: it must only be set if a parsed directive
+    // variant has an associated statement that needs to be skipped below.
+    HasAssociatedStatement = false;
     while (Tok.isNot(tok::annot_pragma_openmp_end)) {
+      TentativeParsingAction TPA(*this);
       OpenMPClauseKind CKind = Tok.isAnnotation()
                                    ? OMPC_unknown
                                    : getOpenMPClauseKind(PP.getSpelling(Tok));
@@ -2505,121 +2508,107 @@
         return Directive;
 
       OMPTraitInfo &TI = Actions.getASTContext().getNewOMPTraitInfo();
-      if (CKind == OMPC_when) {
-        // parse and get OMPTraitInfo to pass to the When clause
-        parseOMPContextSelectors(Loc, TI);
-        if (TI.Sets.size() == 0) {
-          Diag(Tok, diag::err_omp_expected_context_selector) << "when clause";
-          TPA.Commit();
-          return Directive;
-        }
-
-        // Parse ':'
-        if (Tok.is(tok::colon))
-          ConsumeAnyToken();
-        else {
-          Diag(Tok, diag::err_omp_expected_colon) << "when clause";
-          TPA.Commit();
-          return Directive;
-        }
-      }
-      // Skip Directive for now. We will parse directive in the second iteration
-      int paren = 0;
-      while (Tok.isNot(tok::r_paren) || paren != 0) {
-        if (Tok.is(tok::l_paren))
-          paren++;
-        if (Tok.is(tok::r_paren))
-          paren--;
-        if (Tok.is(tok::annot_pragma_openmp_end)) {
-          Diag(Tok, diag::err_omp_expected_punc)
-              << getOpenMPClauseName(CKind) << 0;
-          TPA.Commit();
-          return Directive;
-        }
-        ConsumeAnyToken();
-      }
-      // Parse ')'
-      if (Tok.is(tok::r_paren))
-        T.consumeClose();
-
-      VariantMatchInfo VMI;
-      TI.getAsVariantMatchInfo(ASTContext, VMI);
-
-      VMIs.push_back(VMI);
-    }
-
-    TPA.Revert();
-    // End of the first iteration. Parser is reset to the start of metadirective
+      if (CKind == OMPC_when || CKind == OMPC_default) {
+        // If it is a "when" clause parse the context selectors for the trait
+        // info, a "default" clause will have empty trait info.
+        if (CKind == OMPC_when) {
+          // parse and get OMPTraitInfo to pass to the When clause
+          parseOMPContextSelectors(Loc, TI);
+          if (TI.Sets.size() == 0) {
+            Diag(Tok, diag::err_omp_expected_context_selector) << "when clause";
+            TPA.Commit();
+            return Directive;
+          }
 
-    std::function<void(StringRef)> DiagUnknownTrait = [this, Loc](
-                                                          StringRef ISATrait) {
-      // TODO Track the selector locations in a way that is accessible here to
-      // improve the diagnostic location.
-      Diag(Loc, diag::warn_unknown_declare_variant_isa_trait) << ISATrait;
-    };
-    TargetOMPContext OMPCtx(ASTContext, std::move(DiagUnknownTrait),
-                            /* CurrentFunctionDecl */ nullptr,
-                            ArrayRef<llvm::omp::TraitProperty>());
+          // Parse ':'
+          if (Tok.is(tok::colon))
+            ConsumeAnyToken();
+          else {
+            Diag(Tok, diag::err_omp_expected_colon) << "when clause";
+            TPA.Commit();
+            return Directive;
+          }
+        }
 
-    // A single match is returned for OpenMP 5.0
-    int BestIdx = getBestVariantMatchForContext(VMIs, OMPCtx);
-
-    int Idx = 0;
-    // In OpenMP 5.0 metadirective is either replaced by another directive or
-    // ignored.
-    // TODO: In OpenMP 5.1 generate multiple directives based upon the matches
-    // found by getBestWhenMatchForContext.
-    while (Tok.isNot(tok::annot_pragma_openmp_end)) {
-      // OpenMP 5.0 implementation - Skip to the best index found.
-      if (Idx++ != BestIdx) {
+        // TODO: currently expects a directive, OpenMP 5.1 specifies nothing
+        // directive when there is none.
+        ReadDirectiveWithinMetadirective = true;
+        StmtResult WhenDirective =
+            ParseOpenMPDeclarativeOrExecutableDirective(StmtCtx);
+        ReadDirectiveWithinMetadirective = false;
+        // A failed parse yields no directive, hence dyn_cast_or_null.
+        if (const auto *D =
+                dyn_cast_or_null<OMPExecutableDirective>(WhenDirective.get()))
+          if (D->hasAssociatedStmt())
+            HasAssociatedStatement = true;
+
+        // Revert back to the beginning of the clause, also undoing annotation
+        // tokens formed while parsing the directive variant.
+        TPA.Revert(/*RevertAnnotations=*/true);
+
+        // Skip until the end of this clause.
         ConsumeToken(); // Consume clause name
         T.consumeOpen(); // Consume '('
         int paren = 0;
-        // Skip everything inside the clause
         while (Tok.isNot(tok::r_paren) || paren != 0) {
           if (Tok.is(tok::l_paren))
             paren++;
           if (Tok.is(tok::r_paren))
             paren--;
+          if (Tok.is(tok::annot_pragma_openmp_end)) {
+            Diag(Tok, diag::err_omp_expected_punc)
+                << getOpenMPClauseName(CKind) << 0;
+            return Directive;
+          }
           ConsumeAnyToken();
         }
-        // Parse ')'
-        if (Tok.is(tok::r_paren))
-          T.consumeClose();
-        continue;
-      }
-
-      OpenMPClauseKind CKind = Tok.isAnnotation()
-                                   ? OMPC_unknown
-                                   : getOpenMPClauseKind(PP.getSpelling(Tok));
-      SourceLocation Loc = ConsumeToken();
-
-      // Parse '('.
-      T.consumeOpen();
-      // Skip ContextSelectors for when clause
-      if (CKind == OMPC_when) {
-        OMPTraitInfo &TI = Actions.getASTContext().getNewOMPTraitInfo();
-        // parse and skip the ContextSelectors
-        parseOMPContextSelectors(Loc, TI);
-
-        // Parse ':'
-        ConsumeAnyToken();
-      }
+        assert(Tok.is(tok::r_paren) &&
+               "Expected right paren ending the clause.");
+        T.consumeClose();
+
+        // Build the clause only now so that its end location is the ')'.
+        OMPClause *WhenClause = Actions.ActOnOpenMPWhenClause(
+            TI, WhenDirective, Loc, T.getOpenLocation(), T.getCloseLocation());
+        Clauses.push_back(WhenClause);
+      } else {
+        // Only 'when' and 'default' clauses are allowed on a metadirective.
+        Diag(Loc, diag::err_omp_unexpected_clause)
+            << getOpenMPClauseName(CKind)
+            << getOpenMPDirectiveName(OMPD_metadirective);
+        TPA.Commit();
+        return Directive;
+      }
+    }
 
-      // If no directive is passed, skip in OpenMP 5.0.
-      // TODO: Generate nothing directive from OpenMP 5.1.
-      if (Tok.is(tok::r_paren)) {
-        SkipUntil(tok::annot_pragma_openmp_end);
-        break;
-      }
+    // Skip until the end of the metadirective.
+    SkipUntil(tok::annot_pragma_openmp_end);
+    // Skip any associated statement.
+    if (HasAssociatedStatement)
+      ParseStatement();
+
+    std::function<void(StringRef)> DiagUnknownTrait =
+        [this, Loc](StringRef ISATrait) {
+          // TODO Track the selector locations in a way that is accessible here
+          // to improve the diagnostic location.
+          Diag(Loc, diag::warn_unknown_declare_variant_isa_trait) << ISATrait;
+        };
+    TargetOMPContext OMPCtx(ASTContext, std::move(DiagUnknownTrait),
+                            /* CurrentFunctionDecl */ nullptr,
+                            ArrayRef<llvm::omp::TraitProperty>());
 
-      // Parse Directive
-      ReadDirectiveWithinMetadirective = true;
-      Directive = ParseOpenMPDeclarativeOrExecutableDirective(StmtCtx);
-      ReadDirectiveWithinMetadirective = false;
-      break;
+    // Keep only the clauses whose context selectors can match; dynamic user
+    // conditions are left to CodeGen.
+    SmallVector<OMPClause *, 4> ApplicableClauses;
+    for (OMPClause *C : Clauses) {
+      OMPTraitInfo &TI = cast<OMPWhenClause>(C)->getTraitInfo();
+      VariantMatchInfo VMI;
+      TI.getAsVariantMatchInfo(ASTContext, VMI);
+      if (isVariantApplicableInContext(VMI, OMPCtx, /*DeviceOnly*/ false))
+        ApplicableClauses.push_back(C);
     }
-    break;
+
+    return Actions.ActOnOpenMPMetaDirective(ApplicableClauses, Loc,
+                                            Tok.getLastLoc());
   }
   case OMPD_threadprivate: {
     // FIXME: Should this be permitted in C++?
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -21631,3 +21631,17 @@
   return OMPBindClause::Create(Context, Kind, KindLoc, StartLoc, LParenLoc,
                                EndLoc);
 }
+
+OMPClause *Sema::ActOnOpenMPWhenClause(OMPTraitInfo &TI, StmtResult Directive,
+                                       SourceLocation StartLoc,
+                                       SourceLocation LParenLoc,
+                                       SourceLocation EndLoc) {
+  return new (Context)
+      OMPWhenClause(TI, Directive.get(), StartLoc, LParenLoc, EndLoc);
+}
+
+StmtResult Sema::ActOnOpenMPMetaDirective(ArrayRef<OMPClause *> Clauses,
+                                          SourceLocation StartLoc,
+                                          SourceLocation EndLoc) {
+  return OMPMetaDirective::Create(Context, StartLoc, EndLoc, Clauses);
+}
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -2284,6 +2284,18 @@
     return getSema().ActOnOpenMPAlignClause(A, StartLoc, LParenLoc, EndLoc);
   }
 
+  /// Build a new OpenMP 'when' clause.
+  ///
+  /// By default, performs semantic analysis to build the new OpenMP clause.
+  /// Subclasses may override this routine to provide different behavior.
+  OMPClause *RebuildOMPWhenClause(OMPTraitInfo &TI, Stmt *Directive,
+                                  SourceLocation StartLoc,
+                                  SourceLocation LParenLoc,
+                                  SourceLocation EndLoc) {
+    return getSema().ActOnOpenMPWhenClause(TI, Directive, StartLoc, LParenLoc,
+                                           EndLoc);
+  }
+
   /// Rebuild the operand to an Objective-C \@synchronized statement.
   ///
   /// By default, performs semantic analysis to build the new statement.
@@ -10323,6 +10335,38 @@
                                          C->getLParenLoc(), C->getEndLoc());
 }
 
+template <typename Derived>
+OMPClause *TreeTransform<Derived>::TransformOMPWhenClause(OMPWhenClause *C) {
+  // Transform the directive variant so that template-dependent code in it is
+  // instantiated as well.
+  StmtResult Directive;
+  if (Stmt *D = C->getDirective()) {
+    Directive = getDerived().TransformStmt(D);
+    if (Directive.isInvalid())
+      return nullptr;
+  }
+
+  // Transform the score and user-condition expressions on a fresh copy of
+  // the trait info so that the original clause is left untouched.
+  OMPTraitInfo &TI = getSema().getASTContext().getNewOMPTraitInfo();
+  TI = C->getTraitInfo();
+  bool Failed = TI.anyScoreOrCondition([&](Expr *&E, bool /*IsScore*/) {
+    if (!E)
+      return false;
+    ExprResult Res = getDerived().TransformExpr(E);
+    if (!Res.isUsable())
+      return true;
+    E = Res.get();
+    return false;
+  });
+  if (Failed)
+    return nullptr;
+
+  return getDerived().RebuildOMPWhenClause(TI, Directive.get(),
+                                           C->getBeginLoc(), C->getLParenLoc(),
+                                           C->getEndLoc());
+}
+
 //===----------------------------------------------------------------------===//
 // Expression transformation
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -12994,6 +12994,13 @@
   C->setBindKindLoc(Record.readSourceLocation());
 }
 
+void OMPClauseReader::VisitOMPWhenClause(OMPWhenClause *C) {
+  // Must stay in sync with OMPClauseWriter::VisitOMPWhenClause.
+  C->setTraitInfo(Record.readOMPTraitInfo());
+  C->setDirective(Record.readSubStmt());
+  C->setLParenLoc(Record.readSourceLocation());
+}
+
 void OMPClauseReader::VisitOMPAlignClause(OMPAlignClause *C) {
   C->setAlignment(Record.readExpr());
   C->setLParenLoc(Record.readSourceLocation());
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -6864,6 +6864,13 @@
   Record.AddSourceLocation(C->getBindKindLoc());
 }
 
+void OMPClauseWriter::VisitOMPWhenClause(OMPWhenClause *C) {
+  // Must stay in sync with OMPClauseReader::VisitOMPWhenClause.
+  Record.writeOMPTraitInfo(&C->getTraitInfo());
+  Record.AddStmt(C->getDirective());
+  Record.AddSourceLocation(C->getLParenLoc());
+}
+
 void ASTRecordWriter::writeOMPTraitInfo(const OMPTraitInfo *TI) {
   writeUInt32(TI->Sets.size());
   for (const auto &Set : TI->Sets) {
diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp
--- a/clang/tools/libclang/CIndex.cpp
+++ b/clang/tools/libclang/CIndex.cpp
@@ -2591,6 +2591,15 @@
 }
 void OMPClauseEnqueue::VisitOMPBindClause(const OMPBindClause *C) {}
 
+void OMPClauseEnqueue::VisitOMPWhenClause(const OMPWhenClause *C) {
+  // Visit the context-selector expressions and the directive variant.
+  for (const OMPTraitSet &Set : C->getTraitInfo().Sets)
+    for (const OMPTraitSelector &Selector : Set.Selectors)
+      if (Selector.ScoreOrCondition)
+        Visitor->AddStmt(Selector.ScoreOrCondition);
+  Visitor->AddStmt(C->getDirective());
+}
+
 } // namespace
 
 void EnqueueVisitor::EnqueueChildren(const OMPClause *S) {
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -368,7 +368,9 @@
 def OMPC_Align : Clause<"align"> {
   let clangClass = "OMPAlignClause";
 }
-def OMPC_When: Clause<"when"> {}
+def OMPC_When: Clause<"when"> {
+  let clangClass = "OMPWhenClause";
+}
 
 def OMPC_Bind : Clause<"bind"> {
   let clangClass = "OMPBindClause";