diff --git a/clang/include/clang-c/Index.h b/clang/include/clang-c/Index.h
--- a/clang/include/clang-c/Index.h
+++ b/clang/include/clang-c/Index.h
@@ -2578,7 +2578,11 @@
    */
   CXCursor_OMPDepobjDirective = 286,
 
-  CXCursor_LastStmt = CXCursor_OMPDepobjDirective,
+  /** OpenMP metadirective directive.
+   */
+  CXCursor_OMPMetaDirective = 287,
+
+  CXCursor_LastStmt = CXCursor_OMPMetaDirective,
 
   /**
    * Cursor that represents the translation unit itself.
diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -853,6 +853,83 @@
   }
 };
 
+/// This represents 'when' clause in the '#pragma omp metadirective'
+/// directive.
+///
+/// \code
+/// #pragma omp metadirective when(user={condition(N<10)}: parallel)
+/// \endcode
+/// In this example directive '#pragma omp metadirective' has simple 'when'
+/// clause with user defined condition. The 'default' clause of a
+/// metadirective is represented by this class as well, with a null condition.
+class OMPWhenClause : public OMPClause {
+  friend class OMPClauseReader;
+
+  /// Condition of the context selector; null for the 'default' clause.
+  Expr *expr = nullptr;
+  /// Directive variant, or OMPD_unknown if the clause names no directive.
+  OpenMPDirectiveKind DKind = llvm::omp::OMPD_unknown;
+  /// Clauses of the directive variant.
+  SmallVector Clauses;
+
+  /// Location of '('.
+  SourceLocation LParenLoc;
+
+  void setExpr(Expr *E) { expr = E; }
+
+public:
+  /// Build 'when' clause.
+  ///
+  /// \param expr Condition of the context selector (null for 'default').
+  /// \param dKind Directive variant, or OMPD_unknown if none was given.
+  /// \param clauses Clauses of the directive variant.
+  /// \param StartLoc Starting location of the clause.
+  /// \param LParenLoc Location of '('.
+  /// \param EndLoc Ending location of the clause.
+  OMPWhenClause(Expr *expr, OpenMPDirectiveKind dKind,
+                SmallVector clauses, SourceLocation StartLoc,
+                SourceLocation LParenLoc, SourceLocation EndLoc)
+      : OMPClause(OMPC_when, StartLoc, EndLoc), expr(expr), DKind(dKind),
+        Clauses(std::move(clauses)), LParenLoc(LParenLoc) {}
+
+  /// Build an empty clause (used when deserializing).
+  OMPWhenClause()
+      : OMPClause(OMPC_when, SourceLocation(), SourceLocation()) {}
+
+  /// Sets the location of '('.
+  void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
+
+  /// Returns the location of '('.
+  SourceLocation getLParenLoc() const { return LParenLoc; }
+
+  /// Returns the condition expression, or null for the 'default' clause.
+  Expr *getExpr() const { return expr; }
+
+  /// Returns the directive variant.
+  OpenMPDirectiveKind getDKind() const { return DKind; }
+
+  /// Returns the clauses associated with the directive variant.
+  SmallVector getClauses() const { return Clauses; }
+
+  child_range children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+
+  const_child_range children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+  child_range used_children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+  const_child_range used_children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  static bool classof(const OMPClause *T) {
+    return T->getClauseKind() == OMPC_when;
+  }
+};
+
 /// This represents 'default' clause in the '#pragma omp ...' directive.
/// /// \code diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -2784,6 +2784,9 @@ return TraverseOMPExecutableDirective(S); } +DEF_TRAVERSE_STMT(OMPMetaDirective, + { TRY_TO(TraverseOMPExecutableDirective(S)); }) + DEF_TRAVERSE_STMT(OMPParallelDirective, { TRY_TO(TraverseOMPExecutableDirective(S)); }) @@ -3037,6 +3040,12 @@ } template +bool RecursiveASTVisitor::VisitOMPWhenClause(OMPWhenClause *C) { + TRY_TO(TraverseStmt(C->getExpr())); + return true; +} + +template bool RecursiveASTVisitor::VisitOMPDefaultClause(OMPDefaultClause *) { return true; } diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h --- a/clang/include/clang/AST/StmtOpenMP.h +++ b/clang/include/clang/AST/StmtOpenMP.h @@ -345,6 +345,35 @@ } }; +class OMPMetaDirective : public OMPExecutableDirective { + friend class ASTStmtReader; + Stmt* IfStmt; + + OMPMetaDirective(SourceLocation StartLoc, SourceLocation EndLoc, + unsigned NumClauses) + : OMPExecutableDirective(this, OMPMetaDirectiveClass, + llvm::omp::OMPD_metadirective, StartLoc, EndLoc, + NumClauses, 1) {} + explicit OMPMetaDirective(unsigned NumClauses) + : OMPExecutableDirective(this, OMPMetaDirectiveClass, + llvm::omp::OMPD_metadirective, SourceLocation(), + SourceLocation(), NumClauses, 1) {} +public: + static OMPMetaDirective *Create(const ASTContext &C, SourceLocation StartLoc, + SourceLocation EndLoc, + ArrayRef Clauses, + Stmt *AssociatedStmt, Stmt* IfStmt); + static OMPMetaDirective *CreateEmpty(const ASTContext &C, + unsigned NumClauses, EmptyShell); + + void setIfStmt(Stmt* stmt) { IfStmt = stmt; } + Stmt* getIfStmt() const { return IfStmt; } + + static bool classof(const Stmt *T) { + return T->getStmtClass() == OMPMetaDirectiveClass; + } +}; + /// This represents '#pragma omp parallel' directive. 
 ///
 /// \code
diff --git a/clang/include/clang/Basic/OpenMPKinds.def b/clang/include/clang/Basic/OpenMPKinds.def
--- a/clang/include/clang/Basic/OpenMPKinds.def
+++ b/clang/include/clang/Basic/OpenMPKinds.def
@@ -212,6 +212,9 @@
 #ifndef OPENMP_DEPOBJ_CLAUSE
 #define OPENMP_DEPOBJ_CLAUSE(Name)
 #endif
+#ifndef OPENMP_METADIRECTIVE_CLAUSE
+#define OPENMP_METADIRECTIVE_CLAUSE(Name)
+#endif
 
 // OpenMP clauses.
 OPENMP_CLAUSE(allocator, OMPAllocatorClause)
@@ -277,6 +280,7 @@
 OPENMP_CLAUSE(order, OMPOrderClause)
 OPENMP_CLAUSE(depobj, OMPDepobjClause)
 OPENMP_CLAUSE(destroy, OMPDestroyClause)
+OPENMP_CLAUSE(when, OMPWhenClause)
 
 // Clauses allowed for OpenMP directive 'parallel'.
 OPENMP_PARALLEL_CLAUSE(if)
@@ -1089,6 +1093,11 @@
 OPENMP_DEPOBJ_CLAUSE(destroy)
 OPENMP_DEPOBJ_CLAUSE(update)
 
+// Clauses allowed for OpenMP directive 'metadirective'.
+OPENMP_METADIRECTIVE_CLAUSE(when)
+OPENMP_METADIRECTIVE_CLAUSE(default)
+
+#undef OPENMP_METADIRECTIVE_CLAUSE
 #undef OPENMP_DEPOBJ_CLAUSE
 #undef OPENMP_FLUSH_CLAUSE
 #undef OPENMP_ORDER_KIND
diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td
--- a/clang/include/clang/Basic/StmtNodes.td
+++ b/clang/include/clang/Basic/StmtNodes.td
@@ -212,6 +212,7 @@
 
 // OpenMP Directives.
 def OMPExecutableDirective : StmtNode;
+def OMPMetaDirective : StmtNode;
 def OMPLoopDirective : StmtNode;
 def OMPParallelDirective : StmtNode;
 def OMPSimdDirective : StmtNode;
diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h
--- a/clang/include/clang/Parse/Parser.h
+++ b/clang/include/clang/Parse/Parser.h
@@ -3002,6 +3002,10 @@
   /// \param StmtCtx The context in which we're parsing the directive.
   StmtResult
   ParseOpenMPDeclarativeOrExecutableDirective(ParsedStmtContext StmtCtx);
+  /// Parses a 'when' or 'default' clause of kind \a CKind for the
+  /// metadirective \a DKind.
+  OMPClause* ParseOpenMPMetaClause(OpenMPDirectiveKind DKind,
+                                   OpenMPClauseKind CKind);
   /// Parses clause of kind \a CKind for directive of a kind \a Kind.
   ///
   /// \param DKind Kind of current directive.
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -9886,6 +9886,12 @@
   void ActOnOpenMPLoopInitialization(SourceLocation ForLoc, Stmt *Init);
 
   // OpenMP directives and clauses.
+  /// Called on well-formed '\#pragma omp metadirective' after parsing
+  /// of the associated statement.
+  StmtResult ActOnOpenMPMetaDirective(ArrayRef Clauses,
+                                      Stmt *AStmt,
+                                      SourceLocation StartLoc,
+                                      SourceLocation EndLoc);
   /// Called on correct id-expression from the '#pragma omp
   /// threadprivate'.
   ExprResult ActOnOpenMPIdExpression(Scope *CurScope, CXXScopeSpec &ScopeSpec,
@@ -10369,6 +10375,13 @@
                                       SourceLocation StartLoc,
                                       SourceLocation LParenLoc,
                                       SourceLocation EndLoc);
+  /// Called on well-formed metadirective 'when' or 'default' clause.
+  OMPClause *ActOnOpenMPWhenClause(Expr* Expr,
+                                   OpenMPDirectiveKind DKind,
+                                   SmallVector Clauses,
+                                   SourceLocation StartLoc,
+                                   SourceLocation LParenLoc,
+                                   SourceLocation EndLoc);
   /// Called on well-formed 'default' clause.
OMPClause *ActOnOpenMPDefaultClause(llvm::omp::DefaultKind Kind, SourceLocation KindLoc, diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h --- a/clang/include/clang/Serialization/ASTBitCodes.h +++ b/clang/include/clang/Serialization/ASTBitCodes.h @@ -1807,6 +1807,7 @@ STMT_SEH_TRY, // SEHTryStmt // OpenMP directives + STMT_OMP_META_DIRECTIVE, STMT_OMP_PARALLEL_DIRECTIVE, STMT_OMP_SIMD_DIRECTIVE, STMT_OMP_FOR_DIRECTIVE, diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp --- a/clang/lib/AST/OpenMPClause.cpp +++ b/clang/lib/AST/OpenMPClause.cpp @@ -145,6 +145,7 @@ case OMPC_nontemporal: case OMPC_order: case OMPC_destroy: + case OMPC_when: break; } @@ -231,6 +232,7 @@ case OMPC_nontemporal: case OMPC_order: case OMPC_destroy: + case OMPC_when: break; } @@ -1249,6 +1251,12 @@ //===----------------------------------------------------------------------===// // OpenMP clauses printing methods //===----------------------------------------------------------------------===// +void OMPClausePrinter::VisitOMPWhenClause(OMPWhenClause *Node) { + OS << "when("; + if(Node->getExpr() != NULL) + Node->getExpr()->printPretty(OS, nullptr, Policy, 0); + OS << ")"; +} void OMPClausePrinter::VisitOMPIfClause(OMPIfClause *Node) { OS << "if("; diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp --- a/clang/lib/AST/StmtOpenMP.cpp +++ b/clang/lib/AST/StmtOpenMP.cpp @@ -159,6 +159,29 @@ llvm::copy(A, getFinalsConditions().begin()); } +OMPMetaDirective *OMPMetaDirective::Create( + const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, + ArrayRef Clauses, Stmt *AssociatedStmt, Stmt *IfStmt) { + unsigned Size = llvm::alignTo(sizeof(OMPMetaDirective), alignof(OMPClause *)); + void *Mem = + C.Allocate(Size + sizeof(OMPClause *) * Clauses.size() + sizeof(Stmt *)); + OMPMetaDirective *Dir = new (Mem) OMPMetaDirective(StartLoc, EndLoc, + Clauses.size()); + 
Dir->setClauses(Clauses); + Dir->setAssociatedStmt(AssociatedStmt); + Dir->setIfStmt(IfStmt); + return Dir; +} + +OMPMetaDirective *OMPMetaDirective::CreateEmpty(const ASTContext &C, + unsigned NumClauses, + EmptyShell) { + unsigned Size = llvm::alignTo(sizeof(OMPMetaDirective), alignof(OMPClause *)); + void *Mem = + C.Allocate(Size + sizeof(OMPClause *) * NumClauses + sizeof(Stmt *)); + return new (Mem) OMPMetaDirective(NumClauses); +} + OMPParallelDirective *OMPParallelDirective::Create( const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, ArrayRef Clauses, Stmt *AssociatedStmt, bool HasCancel) { diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp --- a/clang/lib/AST/StmtPrinter.cpp +++ b/clang/lib/AST/StmtPrinter.cpp @@ -650,6 +650,11 @@ PrintStmt(S->getInnermostCapturedStmt()->getCapturedStmt()); } +void StmtPrinter::VisitOMPMetaDirective(OMPMetaDirective *Node) { + Indent() << "#pragma omp metadirective"; + PrintOMPExecutableDirective(Node); +} + void StmtPrinter::VisitOMPParallelDirective(OMPParallelDirective *Node) { Indent() << "#pragma omp parallel"; PrintOMPExecutableDirective(Node); diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp --- a/clang/lib/AST/StmtProfile.cpp +++ b/clang/lib/AST/StmtProfile.cpp @@ -472,6 +472,11 @@ Profiler->VisitStmt(C->getNumForLoops()); } +void OMPClauseProfiler::VisitOMPWhenClause(const OMPWhenClause *C) { + if(C->getExpr()) + Profiler->VisitStmt(C->getExpr()); +} + void OMPClauseProfiler::VisitOMPDefaultClause(const OMPDefaultClause *C) { } void OMPClauseProfiler::VisitOMPProcBindClause(const OMPProcBindClause *C) { } @@ -804,6 +809,10 @@ P.Visit(*I); } +void StmtProfiler::VisitOMPMetaDirective(const OMPMetaDirective *S) { + VisitOMPExecutableDirective(S); +} + void StmtProfiler::VisitOMPLoopDirective(const OMPLoopDirective *S) { VisitOMPExecutableDirective(S); } diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp --- 
a/clang/lib/Basic/OpenMPKinds.cpp +++ b/clang/lib/Basic/OpenMPKinds.cpp @@ -207,6 +207,7 @@ case OMPC_match: case OMPC_nontemporal: case OMPC_destroy: + case OMPC_when: break; } llvm_unreachable("Invalid OpenMP simple clause kind"); @@ -432,6 +433,7 @@ case OMPC_match: case OMPC_nontemporal: case OMPC_destroy: + case OMPC_when: break; } llvm_unreachable("Invalid OpenMP simple clause kind"); @@ -449,6 +451,16 @@ if (OpenMPVersion < 50 && CKind == OMPC_order) return false; switch (DKind) { + case OMPD_metadirective: + switch(CKind) { +#define OPENMP_METADIRECTIVE_CLAUSE(Name) \ + case OMPC_##Name: \ + return true; +#include "clang/Basic/OpenMPKinds.def" + default: + break; + } + break; case OMPD_parallel: switch (CKind) { #define OPENMP_PARALLEL_CLAUSE(Name) \ @@ -1147,6 +1159,9 @@ OpenMPDirectiveKind DKind) { assert(DKind <= OMPD_unknown); switch (DKind) { + case OMPD_metadirective: + CaptureRegions.push_back(OMPD_metadirective); + break; case OMPD_parallel: case OMPD_parallel_for: case OMPD_parallel_for_simd: diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp --- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp @@ -6983,6 +6983,7 @@ case OMPD_parallel_master_taskloop: case OMPD_parallel_master_taskloop_simd: case OMPD_requires: + case OMPD_metadirective: case OMPD_unknown: break; } @@ -7295,6 +7296,7 @@ case OMPD_parallel_master_taskloop: case OMPD_parallel_master_taskloop_simd: case OMPD_requires: + case OMPD_metadirective: case OMPD_unknown: break; } @@ -9080,6 +9082,7 @@ case OMPD_parallel_master_taskloop: case OMPD_parallel_master_taskloop_simd: case OMPD_requires: + case OMPD_metadirective: case OMPD_unknown: llvm_unreachable("Unexpected directive."); } @@ -9844,6 +9847,7 @@ case OMPD_parallel_master_taskloop: case OMPD_parallel_master_taskloop_simd: case OMPD_requires: + case OMPD_metadirective: case OMPD_unknown: llvm_unreachable("Unknown target directive for OpenMP device codegen."); 
} @@ -10492,6 +10496,7 @@ case OMPD_target_parallel_for: case OMPD_target_parallel_for_simd: case OMPD_requires: + case OMPD_metadirective: case OMPD_unknown: llvm_unreachable("Unexpected standalone target data directive."); break; diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp --- a/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp @@ -813,6 +813,7 @@ case OMPD_parallel_master_taskloop: case OMPD_parallel_master_taskloop_simd: case OMPD_requires: + case OMPD_metadirective: case OMPD_unknown: llvm_unreachable("Unexpected directive."); } @@ -890,6 +891,7 @@ case OMPD_parallel_master_taskloop: case OMPD_parallel_master_taskloop_simd: case OMPD_requires: + case OMPD_metadirective: case OMPD_unknown: break; } @@ -1060,6 +1062,7 @@ case OMPD_parallel_master_taskloop: case OMPD_parallel_master_taskloop_simd: case OMPD_requires: + case OMPD_metadirective: case OMPD_unknown: llvm_unreachable("Unexpected directive."); } @@ -1143,6 +1146,7 @@ case OMPD_parallel_master_taskloop: case OMPD_parallel_master_taskloop_simd: case OMPD_requires: + case OMPD_metadirective: case OMPD_unknown: break; } diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp --- a/clang/lib/CodeGen/CGStmt.cpp +++ b/clang/lib/CodeGen/CGStmt.cpp @@ -190,6 +190,9 @@ case Stmt::SEHTryStmtClass: EmitSEHTryStmt(cast(*S)); break; + case Stmt::OMPMetaDirectiveClass: + EmitOMPMetaDirective(cast(*S)); + break; case Stmt::OMPParallelDirectiveClass: EmitOMPParallelDirective(cast(*S)); break; diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp --- a/clang/lib/CodeGen/CGStmtOpenMP.cpp +++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp @@ -4615,6 +4615,7 @@ case OMPC_nontemporal: case OMPC_order: case OMPC_destroy: + case OMPC_when: llvm_unreachable("Clause is not allowed in 'omp atomic'."); } } @@ -5757,6 +5758,11 @@ CGM.getOpenMPRuntime().emitMasterRegion(*this, CodeGen, 
S.getBeginLoc()); } +void CodeGenFunction::EmitOMPMetaDirective(const OMPMetaDirective &S) { + Stmt *I = S.getIfStmt(); + EmitIfStmt(cast(*I)); +} + void CodeGenFunction::EmitOMPParallelMasterTaskLoopDirective( const OMPParallelMasterTaskLoopDirective &S) { auto &&CodeGen = [this, &S](CodeGenFunction &CGF, PrePostActionTy &Action) { diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -3243,6 +3243,7 @@ const RegionCodeGenTy &BodyGen, OMPTargetDataInfo &InputInfo); + void EmitOMPMetaDirective(const OMPMetaDirective &S); void EmitOMPParallelDirective(const OMPParallelDirective &S); void EmitOMPSimdDirective(const OMPSimdDirective &S); void EmitOMPForDirective(const OMPForDirective &S); diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp --- a/clang/lib/Parse/ParseOpenMP.cpp +++ b/clang/lib/Parse/ParseOpenMP.cpp @@ -1888,6 +1888,7 @@ case OMPD_target_teams_distribute_parallel_for: case OMPD_target_teams_distribute_parallel_for_simd: case OMPD_target_teams_distribute_simd: + case OMPD_metadirective: Diag(Tok, diag::err_omp_unexpected_directive) << 1 << getOpenMPDirectiveName(DKind); break; @@ -1957,6 +1958,50 @@ bool HasAssociatedStatement = true; switch (DKind) { + case OMPD_metadirective: { + ConsumeToken(); + ParseScope OMPDirectiveScope(this, ScopeFlags); + Actions.StartOpenMPDSABlock(DKind, DirName, Actions.getCurScope(), Loc); + + while (Tok.isNot(tok::annot_pragma_openmp_end)) { + OpenMPClauseKind CKind = + Tok.isAnnotation() + ? OMPC_unknown + : getOpenMPClauseKind(PP.getSpelling(Tok)); + Actions.StartOpenMPClause(CKind); + OMPClause *Clause = ParseOpenMPMetaClause(DKind, CKind); + FirstClauses[CKind].setInt(true); + if (Clause) { + FirstClauses[CKind].setPointer(Clause); + Clauses.push_back(Clause); + } + + // Skip ',' if any. 
+ if (Tok.is(tok::comma)) + ConsumeToken(); + Actions.EndOpenMPClause(); + //Consume trailing ')' if any + if(Tok.is(tok::r_paren)) + ConsumeToken(); + } + // End location of the directive. + EndLoc = Tok.getLocation(); + // Consume final annot_pragma_openmp_end. + ConsumeAnnotationToken(); + + // The body is a block scope like in Lambdas and Blocks. + Actions.ActOnOpenMPRegionStart(DKind, getCurScope()); + StmtResult AssociatedStmt = (Sema::CompoundScopeRAII(Actions), ParseStatement()); + AssociatedStmt = Actions.ActOnOpenMPRegionEnd(AssociatedStmt, Clauses); + Directive = Actions.ActOnOpenMPExecutableDirective( + DKind, DirName, CancelRegion, Clauses, AssociatedStmt.get(), Loc, + EndLoc); + + // Exit scope. + Actions.EndOpenMPDSABlock(Directive.get()); + OMPDirectiveScope.Exit(); + break; + } case OMPD_threadprivate: { // FIXME: Should this be permitted in C++? if ((StmtCtx & ParsedStmtContext::AllowDeclarationsInC) == @@ -2324,6 +2369,165 @@ return !IsCorrect; } +OMPClause *Parser::ParseOpenMPMetaClause(OpenMPDirectiveKind DKind, + OpenMPClauseKind CKind) { + OMPClause *Clause = nullptr; + bool ErrorFound = false; + bool WrongDirective = false; + SmallVector, OMPC_unknown + 1> + FirstClauses(OMPC_unknown + 1); + // Check if it is called from metadirective. + if(DKind != OMPD_metadirective) { + Diag(Tok, diag::err_omp_unexpected_clause) + << getOpenMPClauseName(CKind) << getOpenMPDirectiveName(DKind); + ErrorFound = true; + } + + // Check if clause is allowed for the given directive. + if (CKind != OMPC_unknown && + !isAllowedClauseForDirective(DKind, CKind, getLangOpts().OpenMP)) { + Diag(Tok, diag::err_omp_unexpected_clause) << getOpenMPClauseName(CKind) + << getOpenMPDirectiveName(DKind); + ErrorFound = true; + WrongDirective = true; + } + switch (CKind) { + case OMPC_default: + case OMPC_when: { + SourceLocation Loc = ConsumeToken(); + SourceLocation DelimLoc; + // Parse '('. 
+ BalancedDelimiterTracker T(*this, tok::l_paren, tok::annot_pragma_openmp_end); + if (T.expectAndConsume(diag::err_expected_lparen_after, + getOpenMPClauseName(CKind))) + return nullptr; + + Expr *expr = NULL; + if(CKind == OMPC_when) { + //parse and get condition expression to pass to the When clause + OMPTraitInfo TI; + parseOMPContextSelectors(Loc, TI); + expr = TI.Sets.front().Selectors.front().ScoreOrCondition; + + // Parse ':' + if (Tok.is(tok::colon)) + ConsumeAnyToken(); + else { + Diag(Tok, diag::warn_pragma_expected_colon) << "when clause"; + return nullptr; + } + } + OpenMPDirectiveKind DirKind = OMPD_unknown; + SmallVector Clauses; + if(!Tok.is(tok::r_paren)) { + DirKind = parseOpenMPDirectiveKind(*this); + Tok.setKind(tok::identifier); + ConsumeToken(); + int paren = 0; + while (Tok.isNot(tok::r_paren) || paren !=0) { + if(Tok.is(tok::l_paren)) paren++; + if(Tok.is(tok::r_paren)) paren--; + + OpenMPClauseKind CKind = Tok.isAnnotation() + ? OMPC_unknown + : getOpenMPClauseKind(PP.getSpelling(Tok)); + Actions.StartOpenMPClause(CKind); + OMPClause *Clause = ParseOpenMPClause(DKind, CKind, + !FirstClauses[CKind].getInt()); + FirstClauses[CKind].setInt(true); + if (Clause) { + FirstClauses[CKind].setPointer(Clause); + Clauses.push_back(Clause); + } + + // Skip ',' if any. 
+ if (Tok.is(tok::comma)) + ConsumeToken(); + Actions.EndOpenMPClause(); + } + //Consume ) + ConsumeToken(); + } + + if (WrongDirective) + return nullptr; + + Clause = Actions.ActOnOpenMPWhenClause(expr, DirKind, Clauses, Loc, + DelimLoc, Tok.getLocation()); + break; + } + case OMPC_final: + case OMPC_num_threads: + case OMPC_safelen: + case OMPC_simdlen: + case OMPC_collapse: + case OMPC_ordered: + case OMPC_device: + case OMPC_num_teams: + case OMPC_thread_limit: + case OMPC_priority: + case OMPC_grainsize: + case OMPC_num_tasks: + case OMPC_hint: + case OMPC_allocator: + case OMPC_proc_bind: + case OMPC_atomic_default_mem_order: + case OMPC_order: + case OMPC_schedule: + case OMPC_dist_schedule: + case OMPC_defaultmap: + case OMPC_if: + case OMPC_nowait: + case OMPC_untied: + case OMPC_mergeable: + case OMPC_read: + case OMPC_write: + case OMPC_update: + case OMPC_capture: + case OMPC_seq_cst: + case OMPC_acq_rel: + case OMPC_acquire: + case OMPC_release: + case OMPC_relaxed: + case OMPC_threads: + case OMPC_simd: + case OMPC_nogroup: + case OMPC_unified_address: + case OMPC_unified_shared_memory: + case OMPC_reverse_offload: + case OMPC_dynamic_allocators: + case OMPC_private: + case OMPC_firstprivate: + case OMPC_lastprivate: + case OMPC_shared: + case OMPC_reduction: + case OMPC_task_reduction: + case OMPC_in_reduction: + case OMPC_linear: + case OMPC_aligned: + case OMPC_copyin: + case OMPC_copyprivate: + case OMPC_flush: + case OMPC_depend: + case OMPC_map: + case OMPC_to: + case OMPC_from: + case OMPC_use_device_ptr: + case OMPC_is_device_ptr: + case OMPC_allocate: + case OMPC_nontemporal: + case OMPC_device_type: + case OMPC_unknown: + case OMPC_threadprivate: + case OMPC_uniform: + case OMPC_match: + ErrorFound = false; + Diag(Tok, diag::err_omp_unexpected_clause) + << getOpenMPClauseName(CKind) << getOpenMPDirectiveName(DKind); + } + return ErrorFound ? nullptr : Clause; +} + /// Parsing of OpenMP clauses. 
/// /// clause: @@ -2517,6 +2721,7 @@ case OMPC_threadprivate: case OMPC_uniform: case OMPC_match: + case OMPC_when: if (!WrongDirective) Diag(Tok, diag::err_omp_unexpected_clause) << getOpenMPClauseName(CKind) << getOpenMPDirectiveName(DKind); diff --git a/clang/lib/Sema/SemaExceptionSpec.cpp b/clang/lib/Sema/SemaExceptionSpec.cpp --- a/clang/lib/Sema/SemaExceptionSpec.cpp +++ b/clang/lib/Sema/SemaExceptionSpec.cpp @@ -1473,6 +1473,7 @@ case Stmt::OMPTeamsDistributeParallelForDirectiveClass: case Stmt::OMPTeamsDistributeParallelForSimdDirectiveClass: case Stmt::OMPTeamsDistributeSimdDirectiveClass: + case Stmt::OMPMetaDirectiveClass: case Stmt::ReturnStmtClass: case Stmt::SEHExceptStmtClass: case Stmt::SEHFinallyStmtClass: diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -3460,6 +3460,7 @@ void Sema::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind, Scope *CurScope) { switch (DKind) { + case OMPD_metadirective: case OMPD_parallel: case OMPD_parallel_for: case OMPD_parallel_for_simd: @@ -4757,6 +4758,10 @@ llvm::SmallVector AllowedNameModifiers; switch (Kind) { + case OMPD_metadirective: + Res = ActOnOpenMPMetaDirective(ClausesWithImplicit, AStmt, StartLoc, EndLoc); + AllowedNameModifiers.push_back(OMPD_metadirective); + break; case OMPD_parallel: Res = ActOnOpenMPParallelDirective(ClausesWithImplicit, AStmt, StartLoc, EndLoc); @@ -5160,6 +5165,7 @@ case OMPC_atomic_default_mem_order: case OMPC_device_type: case OMPC_match: + case OMPC_when: llvm_unreachable("Unexpected clause"); } for (Stmt *CC : C->children()) { @@ -5771,6 +5777,57 @@ } } +StmtResult Sema::ActOnOpenMPMetaDirective(ArrayRef Clauses, + Stmt *AStmt, + SourceLocation StartLoc, + SourceLocation EndLoc) { + if (!AStmt) + return StmtError(); + + StmtResult IfStmt = StmtError(); + Stmt* ElseStmt = NULL; + + for(auto i = Clauses.rbegin(); i < Clauses.rend(); i++) { + Expr *expr = 
((OMPWhenClause*)*i)->getExpr(); + Stmt *stmt = NULL; + + OpenMPDirectiveKind DKind = ((OMPWhenClause*)*i)->getDKind(); + DeclarationNameInfo DirName; + SmallVector clauses = ((OMPWhenClause*)*i)->getClauses(); + + StartOpenMPDSABlock(DKind, DirName, getCurScope(), StartLoc); + if(DKind != OMPD_unknown) + stmt = ActOnOpenMPExecutableDirective(DKind, DirName, OMPD_unknown, + clauses, AStmt, + StartLoc, EndLoc).get(); + EndOpenMPDSABlock(stmt); + + if(expr == NULL) { + if(ElseStmt != NULL) { + llvm::errs() << "Misplaced default clause! Only one default clause is"; + llvm::errs() << " allowed in metadirective in the end\n"; + return StmtError(); + } + if(DKind == OMPD_unknown) + ElseStmt = AStmt; + else + ElseStmt = stmt; + continue; + } + + if(stmt == NULL) + stmt = AStmt; + + IfStmt = ActOnIfStmt(SourceLocation(), false, NULL, + ActOnCondition(getCurScope(), SourceLocation(), expr, + Sema::ConditionKind::Boolean), + stmt, SourceLocation(), ElseStmt); + } + + return OMPMetaDirective::Create(Context, StartLoc, EndLoc, Clauses, + AStmt, IfStmt.get()); +} + StmtResult Sema::ActOnOpenMPParallelDirective(ArrayRef Clauses, Stmt *AStmt, SourceLocation StartLoc, @@ -11080,6 +11137,7 @@ case OMPC_nontemporal: case OMPC_order: case OMPC_destroy: + case OMPC_when: llvm_unreachable("Clause is not allowed."); } return Res; @@ -11229,6 +11287,7 @@ case OMPD_atomic: case OMPD_teams_distribute: case OMPD_requires: + case OMPD_metadirective: llvm_unreachable("Unexpected OpenMP directive with if-clause"); case OMPD_unknown: llvm_unreachable("Unknown OpenMP directive"); @@ -11304,6 +11363,7 @@ case OMPD_teams_distribute: case OMPD_teams_distribute_simd: case OMPD_requires: + case OMPD_metadirective: llvm_unreachable("Unexpected OpenMP directive with num_threads-clause"); case OMPD_unknown: llvm_unreachable("Unknown OpenMP directive"); @@ -11377,6 +11437,7 @@ case OMPD_atomic: case OMPD_distribute_simd: case OMPD_requires: + case OMPD_metadirective: llvm_unreachable("Unexpected OpenMP 
directive with num_teams-clause"); case OMPD_unknown: llvm_unreachable("Unknown OpenMP directive"); @@ -11450,6 +11511,7 @@ case OMPD_atomic: case OMPD_distribute_simd: case OMPD_requires: + case OMPD_metadirective: llvm_unreachable("Unexpected OpenMP directive with thread_limit-clause"); case OMPD_unknown: llvm_unreachable("Unknown OpenMP directive"); @@ -11523,6 +11585,7 @@ case OMPD_distribute_simd: case OMPD_target_teams: case OMPD_requires: + case OMPD_metadirective: llvm_unreachable("Unexpected OpenMP directive with schedule clause"); case OMPD_unknown: llvm_unreachable("Unknown OpenMP directive"); @@ -11596,6 +11659,7 @@ case OMPD_atomic: case OMPD_target_teams: case OMPD_requires: + case OMPD_metadirective: llvm_unreachable("Unexpected OpenMP directive with schedule clause"); case OMPD_unknown: llvm_unreachable("Unknown OpenMP directive"); @@ -11669,6 +11733,7 @@ case OMPD_atomic: case OMPD_distribute_simd: case OMPD_requires: + case OMPD_metadirective: llvm_unreachable("Unexpected OpenMP directive with num_teams-clause"); case OMPD_unknown: llvm_unreachable("Unknown OpenMP directive"); @@ -11744,6 +11809,7 @@ case OMPD_atomic: case OMPD_distribute_simd: case OMPD_requires: + case OMPD_metadirective: llvm_unreachable("Unexpected OpenMP directive with grainsize-clause"); case OMPD_unknown: llvm_unreachable("Unknown OpenMP directive"); @@ -11807,6 +11873,78 @@ case OMPC_order: case OMPC_destroy: llvm_unreachable("Unexpected OpenMP clause."); + case OMPC_when: + switch (DKind) { + case OMPD_metadirective: + CaptureRegion = OMPD_metadirective; + break; + case OMPD_parallel_master: + case OMPD_declare_mapper: + case OMPD_allocate: + case OMPD_declare_variant: + case OMPD_master_taskloop: + case OMPD_parallel_master_taskloop: + case OMPD_master_taskloop_simd: + case OMPD_parallel_master_taskloop_simd: + case OMPD_target_parallel: + case OMPD_target_parallel_for: + case OMPD_target_parallel_for_simd: + case OMPD_target_teams_distribute_parallel_for: + case 
OMPD_target_teams_distribute_parallel_for_simd: + case OMPD_teams_distribute_parallel_for: + case OMPD_teams_distribute_parallel_for_simd: + case OMPD_target_update: + case OMPD_target_enter_data: + case OMPD_target_exit_data: + case OMPD_cancel: + case OMPD_parallel: + case OMPD_parallel_sections: + case OMPD_parallel_for: + case OMPD_parallel_for_simd: + case OMPD_target: + case OMPD_target_simd: + case OMPD_target_teams: + case OMPD_target_teams_distribute: + case OMPD_target_teams_distribute_simd: + case OMPD_distribute_parallel_for: + case OMPD_distribute_parallel_for_simd: + case OMPD_task: + case OMPD_taskloop: + case OMPD_taskloop_simd: + case OMPD_target_data: + case OMPD_threadprivate: + case OMPD_taskyield: + case OMPD_barrier: + case OMPD_taskwait: + case OMPD_cancellation_point: + case OMPD_flush: + case OMPD_declare_reduction: + case OMPD_declare_simd: + case OMPD_declare_target: + case OMPD_end_declare_target: + case OMPD_teams: + case OMPD_simd: + case OMPD_for: + case OMPD_for_simd: + case OMPD_sections: + case OMPD_section: + case OMPD_single: + case OMPD_master: + case OMPD_critical: + case OMPD_taskgroup: + case OMPD_distribute: + case OMPD_ordered: + case OMPD_atomic: + case OMPD_distribute_simd: + case OMPD_teams_distribute: + case OMPD_teams_distribute_simd: + case OMPD_requires: + case OMPD_depobj: + llvm_unreachable("Unexpected OpenMP directive with when clause"); + case OMPD_unknown: + llvm_unreachable("Unknown OpenMP directive"); + } + break; } return CaptureRegion; } @@ -12242,6 +12380,7 @@ case OMPC_match: case OMPC_nontemporal: case OMPC_destroy: + case OMPC_when: llvm_unreachable("Clause is not allowed."); } return Res; @@ -12268,6 +12407,16 @@ return std::string(Out.str()); } +OMPClause *Sema::ActOnOpenMPWhenClause(Expr *Expr, + OpenMPDirectiveKind DKind, + SmallVector Clauses, + SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation EndLoc) { + return new (Context) OMPWhenClause(Expr, DKind, Clauses, StartLoc, 
LParenLoc, + EndLoc); +} + OMPClause *Sema::ActOnOpenMPDefaultClause(DefaultKind Kind, SourceLocation KindKwLoc, SourceLocation StartLoc, @@ -12459,6 +12608,7 @@ case OMPC_nontemporal: case OMPC_order: case OMPC_destroy: + case OMPC_when: llvm_unreachable("Clause is not allowed."); } return Res; @@ -12688,6 +12838,7 @@ case OMPC_match: case OMPC_nontemporal: case OMPC_order: + case OMPC_when: llvm_unreachable("Clause is not allowed."); } return Res; @@ -12943,6 +13094,7 @@ case OMPC_match: case OMPC_order: case OMPC_destroy: + case OMPC_when: llvm_unreachable("Clause is not allowed."); } return Res; diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -1606,6 +1606,20 @@ EndLoc); } + /// Build a new OpenMP 'when' clause. + /// + /// By default, performs semantic analysis to build the new OpenMP clause. + /// Subclasses may override this routine to provide different behavior. + OMPClause *RebuildOMPWhenClause(Expr* expr, + OpenMPDirectiveKind DKind, + SmallVector Clauses, + SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation EndLoc) { + return getSema().ActOnOpenMPWhenClause(expr, DKind, Clauses, StartLoc, + LParenLoc, EndLoc); + } + /// Build a new OpenMP 'default' clause. /// /// By default, performs semantic analysis to build the new OpenMP clause. 
@@ -8076,6 +8090,17 @@ template StmtResult +TreeTransform::TransformOMPMetaDirective(OMPMetaDirective *D) { + DeclarationNameInfo DirName; + getDerived().getSema().StartOpenMPDSABlock(OMPD_metadirective, DirName, + nullptr, D->getBeginLoc()); + StmtResult Res = getDerived().TransformOMPExecutableDirective(D); + getDerived().getSema().EndOpenMPDSABlock(Res.get()); + return Res; +} + +template +StmtResult TreeTransform::TransformOMPParallelDirective(OMPParallelDirective *D) { DeclarationNameInfo DirName; getDerived().getSema().StartOpenMPDSABlock(OMPD_parallel, DirName, nullptr, @@ -8739,6 +8764,17 @@ template OMPClause * +TreeTransform::TransformOMPWhenClause(OMPWhenClause *C) { + ExprResult E = getDerived().TransformExpr(C->getExpr()); + if (E.isInvalid()) + return nullptr; + return getDerived().RebuildOMPWhenClause(E.get(), C->getDKind(), + C->getClauses(), C->getBeginLoc(), + C->getLParenLoc(), C->getEndLoc()); /* FIXME: getClauses() is forwarded as-is; the variant's clauses are not transformed. */ +} + +template +OMPClause * TreeTransform::TransformOMPDefaultClause(OMPDefaultClause *C) { return getDerived().RebuildOMPDefaultClause( C->getDefaultKind(), C->getDefaultKindKwLoc(), C->getBeginLoc(), diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp --- a/clang/lib/Serialization/ASTReader.cpp +++ b/clang/lib/Serialization/ASTReader.cpp @@ -11608,6 +11608,9 @@ OMPClause *OMPClauseReader::readClause() { OMPClause *C = nullptr; switch (Record.readInt()) { + case OMPC_when: + C = new (Context) OMPWhenClause(); + break; case OMPC_if: C = new (Context) OMPIfClause(); break; @@ -11891,6 +11894,12 @@ C->setLParenLoc(Record.readSourceLocation()); } +void OMPClauseReader::VisitOMPWhenClause(OMPWhenClause *C) { + C->DKind = static_cast<OpenMPDirectiveKind>(Record.readInt()); + C->setExpr(Record.readSubExpr()); + C->setLParenLoc(Record.readSourceLocation()); /* FIXME: the variant's clauses (getClauses()) are not deserialized. */ +} + void OMPClauseReader::VisitOMPDefaultClause(OMPDefaultClause *C) { C->setDefaultKind(static_cast(Record.readInt())); C->setLParenLoc(Record.readSourceLocation()); diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp
b/clang/lib/Serialization/ASTReaderStmt.cpp --- a/clang/lib/Serialization/ASTReaderStmt.cpp +++ b/clang/lib/Serialization/ASTReaderStmt.cpp @@ -2246,6 +2246,13 @@ D->setFinalsConditions(Sub); } +void ASTStmtReader::VisitOMPMetaDirective(OMPMetaDirective *D) { + VisitStmt(D); + // The NumClauses field was read in ReadStmtFromStream. + Record.skipInts(1); + VisitOMPExecutableDirective(D); +} + void ASTStmtReader::VisitOMPParallelDirective(OMPParallelDirective *D) { VisitStmt(D); // The NumClauses field was read in ReadStmtFromStream. @@ -3097,6 +3104,13 @@ nullptr); break; + case STMT_OMP_META_DIRECTIVE: + S = + OMPMetaDirective::CreateEmpty(Context, + Record[ASTStmtReader::NumStmtFields], + Empty); + break; + case STMT_OMP_PARALLEL_DIRECTIVE: S = OMPParallelDirective::CreateEmpty(Context, diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp --- a/clang/lib/Serialization/ASTWriter.cpp +++ b/clang/lib/Serialization/ASTWriter.cpp @@ -6096,6 +6096,12 @@ Record.AddSourceLocation(C->getLParenLoc()); } +void OMPClauseWriter::VisitOMPWhenClause(OMPWhenClause *C) { + Record.push_back(unsigned(C->getDKind())); + Record.AddStmt(C->getExpr()); + Record.AddSourceLocation(C->getLParenLoc()); /* FIXME: the variant's clauses (getClauses()) are not serialized. */ +} + void OMPClauseWriter::VisitOMPDefaultClause(OMPDefaultClause *C) { Record.push_back(unsigned(C->getDefaultKind())); Record.AddSourceLocation(C->getLParenLoc()); diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp --- a/clang/lib/Serialization/ASTWriterStmt.cpp +++ b/clang/lib/Serialization/ASTWriterStmt.cpp @@ -2134,6 +2134,13 @@ Record.AddStmt(S); } +void ASTStmtWriter::VisitOMPMetaDirective(OMPMetaDirective *D) { + VisitStmt(D); + Record.push_back(D->getNumClauses()); + VisitOMPExecutableDirective(D); + Code = serialization::STMT_OMP_META_DIRECTIVE; +} + void ASTStmtWriter::VisitOMPParallelDirective(OMPParallelDirective *D) { VisitStmt(D); Record.push_back(D->getNumClauses()); diff --git a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp --- a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp +++ b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp @@ -1291,6 +1291,7 @@ case Stmt::OMPTargetTeamsDistributeParallelForDirectiveClass: case Stmt::OMPTargetTeamsDistributeParallelForSimdDirectiveClass: case Stmt::OMPTargetTeamsDistributeSimdDirectiveClass: + case Stmt::OMPMetaDirectiveClass: /* shares the handling below with the other OpenMP directives: the path is sunk and the block is recorded as aborted */ case Stmt::CapturedStmtClass: { const ExplodedNode *node = Bldr.generateSink(S, Pred, Pred->getState()); Engine.addAbortedBlock(node, currBldrCtx->getBlock()); diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp --- a/clang/tools/libclang/CIndex.cpp +++ b/clang/tools/libclang/CIndex.cpp @@ -2207,6 +2207,10 @@ Visitor->AddStmt(C->getNumForLoops()); } +void OMPClauseEnqueue::VisitOMPWhenClause(const OMPWhenClause *C) { + Visitor->AddStmt(C->getExpr()); /* NOTE(review): only the condition expression is enqueued; the variant's clauses (getClauses()) are not visited. */ +} + void OMPClauseEnqueue::VisitOMPDefaultClause(const OMPDefaultClause *C) { } void OMPClauseEnqueue::VisitOMPProcBindClause(const OMPProcBindClause *C) { } @@ -5475,6 +5479,8 @@ return cxstring::createRef("CXXAccessSpecifier"); case CXCursor_ModuleImportDecl: return cxstring::createRef("ModuleImport"); + case CXCursor_OMPMetaDirective: + return cxstring::createRef("OMPMetaDirective"); case CXCursor_OMPParallelDirective: return cxstring::createRef("OMPParallelDirective"); case CXCursor_OMPSimdDirective: diff --git a/clang/tools/libclang/CXCursor.cpp b/clang/tools/libclang/CXCursor.cpp --- a/clang/tools/libclang/CXCursor.cpp +++ b/clang/tools/libclang/CXCursor.cpp @@ -578,6 +578,9 @@ case Stmt::MSDependentExistsStmtClass: K = CXCursor_UnexposedStmt; break; + case Stmt::OMPMetaDirectiveClass: + K = CXCursor_OMPMetaDirective; + break; case Stmt::OMPParallelDirectiveClass: K = CXCursor_OMPParallelDirective; break; diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def --- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def +++
b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def @@ -92,6 +92,7 @@ __OMP_DIRECTIVE_EXT(parallel_master_taskloop_simd, "parallel master taskloop simd") __OMP_DIRECTIVE(depobj) +__OMP_DIRECTIVE(metadirective) /* added before 'unknown', which must remain the last entry */ // Has to be the last because Clang implicitly expects it to be. __OMP_DIRECTIVE(unknown)