diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h --- a/clang/include/clang/AST/OpenMPClause.h +++ b/clang/include/clang/AST/OpenMPClause.h @@ -8611,25 +8611,7 @@ template class ConstOMPClauseVisitor : public OMPClauseVisitorBase {}; - -class OMPClausePrinter final : public OMPClauseVisitor { - raw_ostream &OS; - const PrintingPolicy &Policy; - - /// Process clauses with list of variables. - template void VisitOMPClauseList(T *Node, char StartSym); - /// Process motion clauses. - template void VisitOMPMotionClause(T *Node); - -public: - OMPClausePrinter(raw_ostream &OS, const PrintingPolicy &Policy) - : OS(OS), Policy(Policy) {} - -#define GEN_CLANG_CLAUSE_CLASS -#define CLAUSE_CLASS(Enum, Str, Class) void Visit##Class(Class *S); -#include "llvm/Frontend/OpenMP/OMP.inc" -}; - + struct OMPTraitProperty { llvm::omp::TraitProperty Kind = llvm::omp::TraitProperty::invalid; @@ -8872,6 +8854,98 @@ } }; +/// This captures 'when' clause in the '#pragma omp metadirective' +/// \code +/// #pragma omp metadirective when(user={condition(N<100)}:parallel for) +/// \endcode +/// In the above example, the metadirective clause has a condition which when +/// satisfied will use the parallel for directive with the code enclosed by the +/// directive. +class OMPWhenClause final : public OMPClause { + friend class OMPClauseReader; + friend class OMPExecutableDirective; + template friend class OMPDeclarativeDirective; + + OMPTraitInfo *TI = nullptr; + OpenMPDirectiveKind DKind = llvm::omp::OMPD_unknown; + Stmt *Directive = nullptr; + + /// Location of '('. + SourceLocation LParenLoc; + + /// Sets the location of '('. + /// + /// \param Loc Location of '('. + void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; } + +public: + /// Build a 'when' clause from a context selector and a directive variant. 
+ /// + /// \param T TraitInfor containing information about the context selector + /// \param DKind The directive associated with the when clause + /// \param D The statement associated with the when clause + /// \param StartLoc Starting location of the clause. + /// \param LParenLoc Location of '('. + /// \param EndLoc Ending location of the clause. + OMPWhenClause(OMPTraitInfo &T, OpenMPDirectiveKind dKind, Stmt *D, + SourceLocation StartLoc, SourceLocation LParenLoc, + SourceLocation EndLoc) + : OMPClause(llvm::omp::OMPC_when, StartLoc, EndLoc), TI(&T), DKind(dKind), + Directive(D), LParenLoc(LParenLoc) {} + + /// Build an empty clause. + OMPWhenClause() + : OMPClause(llvm::omp::OMPC_when, SourceLocation(), SourceLocation()) {} + + /// Returns the location of '('. + SourceLocation getLParenLoc() const { return LParenLoc; } + + /// Returns the directive variant kind + OpenMPDirectiveKind getDKind() const { return DKind; } + + Stmt *getDirective() const { return Directive; } + + /// Returns the OMPTraitInfo + OMPTraitInfo &getTI() const { return *TI; } + + child_range children() { + return child_range(child_iterator(), child_iterator()); + } + + const_child_range children() const { + return const_child_range(const_child_iterator(), const_child_iterator()); + } + child_range used_children() { + return child_range(child_iterator(), child_iterator()); + } + const_child_range used_children() const { + return const_child_range(const_child_iterator(), const_child_iterator()); + } + + static bool classof(const OMPClause *T) { + return T->getClauseKind() == llvm::omp::OMPC_when; + } +}; + +class OMPClausePrinter final : public OMPClauseVisitor { + raw_ostream &OS; + const PrintingPolicy &Policy; + + /// Process clauses with list of variables. + template void VisitOMPClauseList(T *Node, char StartSym); + /// Process motion clauses. 
+ template void VisitOMPMotionClause(T *Node); + +public: + OMPClausePrinter(raw_ostream &OS, const PrintingPolicy &Policy) + : OS(OS), Policy(Policy) {} + + void VisitOMPWhenClause(OMPWhenClause *Node); + +#define GEN_CLANG_CLAUSE_CLASS +#define CLAUSE_CLASS(Enum, Str, Class) void Visit##Class(Class *S); +#include "llvm/Frontend/OpenMP/OMP.inc" +}; } // namespace clang #endif // LLVM_CLANG_AST_OPENMPCLAUSE_H diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -501,7 +501,8 @@ /// Process clauses with pre-initis. bool VisitOMPClauseWithPreInit(OMPClauseWithPreInit *Node); bool VisitOMPClauseWithPostUpdate(OMPClauseWithPostUpdate *Node); - + bool VisitOMPWhenClause(OMPWhenClause *C); + bool PostVisitStmt(Stmt *S); }; @@ -3136,6 +3137,18 @@ return true; } +template +bool RecursiveASTVisitor::VisitOMPWhenClause(OMPWhenClause *C) { + for (const OMPTraitSet &Set : C->getTI().Sets) { + for (const OMPTraitSelector &Selector : Set.Selectors) { + if (Selector.Kind == llvm::omp::TraitSelector::user_condition && + Selector.ScoreOrCondition) + TRY_TO(TraverseStmt(Selector.ScoreOrCondition)); + } + } + return true; +} + template bool RecursiveASTVisitor::VisitOMPDefaultClause(OMPDefaultClause *) { return true; diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h --- a/clang/include/clang/AST/StmtOpenMP.h +++ b/clang/include/clang/AST/StmtOpenMP.h @@ -5475,7 +5475,7 @@ ArrayRef Clauses, Stmt *AssociatedStmt, Stmt *IfStmt); static OMPMetaDirective *CreateEmpty(const ASTContext &C, unsigned NumClauses, - EmptyShell); + EmptyShell); Stmt *getIfStmt() const { return IfStmt; } static bool classof(const Stmt *T) { diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ 
b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -10848,6 +10848,8 @@ "'%0' clause requires 'dispatch' context selector">; def err_omp_append_args_with_varargs : Error< "'append_args' is not allowed with varargs functions">; +def err_omp_misplaced_default_clause : Error< + "only one 'default' clause is allowed">; } // end of OpenMP category let CategoryName = "Related Result Type Issue" in { diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h --- a/clang/include/clang/Parse/Parser.h +++ b/clang/include/clang/Parse/Parser.h @@ -3291,6 +3291,13 @@ /// \param StmtCtx The context in which we're parsing the directive. StmtResult ParseOpenMPDeclarativeOrExecutableDirective(ParsedStmtContext StmtCtx); + /// Parse a 'when' or 'default' clause of a metadirective. + /// + /// \param DKind Kind of current directive + /// \param CKind Kind of current clause + /// \returns the parsed clause, or nullptr if the clause could not be parsed. + OMPClause *ParseOpenMPMetaDirectiveClause(OpenMPDirectiveKind DKind, + OpenMPClauseKind CKind); /// Parses clause of kind \a CKind for directive of a kind \a Kind. /// /// \param DKind Kind of current directive. diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -66,6 +66,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TinyPtrVector.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" +#include "llvm/Frontend/OpenMP/OMPContext.h" #include #include #include @@ -10678,10 +10679,18 @@ /// /// \returns Statement for finished OpenMP region. 
StmtResult ActOnOpenMPRegionEnd(StmtResult S, ArrayRef Clauses); + + /// Called on well-formed StmtResult ActOnOpenMPExecutableDirective( OpenMPDirectiveKind Kind, const DeclarationNameInfo &DirName, OpenMPDirectiveKind CancelRegion, ArrayRef Clauses, Stmt *AStmt, SourceLocation StartLoc, SourceLocation EndLoc); + + /// Called on meta directive + StmtResult ActOnOpenMPExecutableMetaDirective( + OpenMPDirectiveKind Kind, const DeclarationNameInfo &DirName, + OpenMPDirectiveKind CancelRegion, ArrayRef Clauses, + Stmt *AStmt, SourceLocation StartLoc, SourceLocation EndLoc); /// Called on well-formed '\#pragma omp parallel' after parsing /// of the associated statement. StmtResult ActOnOpenMPParallelDirective(ArrayRef Clauses, @@ -11127,7 +11136,9 @@ SourceLocation LParenLoc, SourceLocation EndLoc); /// Called on well-formed 'when' clause. - OMPClause *ActOnOpenMPWhenClause(OMPTraitInfo &TI, SourceLocation StartLoc, + OMPClause *ActOnOpenMPWhenClause(OMPTraitInfo &TI, OpenMPDirectiveKind DKind, + StmtResult Directive, + SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc); /// Called on well-formed 'default' clause. 
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp --- a/clang/lib/AST/OpenMPClause.cpp +++ b/clang/lib/AST/OpenMPClause.cpp @@ -16,6 +16,7 @@ #include "clang/AST/Decl.h" #include "clang/AST/DeclOpenMP.h" #include "clang/Basic/LLVM.h" +#include "clang/AST/StmtOpenMP.h" // #include "clang/Basic/OpenMPKinds.h" #include "clang/Basic/TargetInfo.h" #include "llvm/ADT/SmallPtrSet.h" @@ -1609,6 +1610,31 @@ // OpenMP clauses printing methods //===----------------------------------------------------------------------===// +void OMPClausePrinter::VisitOMPWhenClause(OMPWhenClause *Node) { + OMPTraitInfo &TI = Node->getTI(); + if (TI.Sets.empty()) + OS << "default("; + else + OS << "when("; + TI.print(OS, Policy); + Stmt *S = Node->getDirective(); + if (S) { + OS << ":"; + OMPExecutableDirective *D = cast(S); + auto DKind = D->getDirectiveKind(); + OS << getOpenMPDirectiveName(DKind); + + OMPClausePrinter Printer(OS, Policy); + ArrayRef Clauses = D->clauses(); + for (auto *Clause : Clauses) + if (Clause && !Clause->isImplicit()){ + OS << ' '; + Printer.Visit(Clause); + } + } + OS<<")"; +} + void OMPClausePrinter::VisitOMPIfClause(OMPIfClause *Node) { OS << "if("; if (Node->getNameModifier() != OMPD_unknown) @@ -2382,6 +2408,7 @@ const PrintingPolicy &Policy) const { bool FirstSet = true; for (const OMPTraitSet &Set : Sets) { + if (!FirstSet) OS << ", "; FirstSet = false; @@ -2389,6 +2416,7 @@ bool FirstSelector = true; for (const OMPTraitSelector &Selector : Set.Selectors) { + if (!FirstSelector) OS << ", "; FirstSelector = false; diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp --- a/clang/lib/AST/StmtPrinter.cpp +++ b/clang/lib/AST/StmtPrinter.cpp @@ -657,11 +657,21 @@ bool ForceNoStmt) { OMPClausePrinter Printer(OS, Policy); ArrayRef Clauses = S->clauses(); - for (auto *Clause : Clauses) - if (Clause && !Clause->isImplicit()) { + + for (auto *Clause : Clauses){ + if (Clause && !Clause->isImplicit()){ OS << ' '; 
Printer.Visit(Clause); - } + if (isa(S)){ + OMPWhenClause *WhenClause = dyn_cast(Clause); + if (WhenClause!=nullptr){ + if (WhenClause->getDKind() != llvm::omp::OMPD_unknown){ + Printer.VisitOMPWhenClause(WhenClause); + } + } + } + } + } OS << NL; if (!ForceNoStmt && S->hasAssociatedStmt()) PrintStmt(S->getRawStmt()); diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp --- a/clang/lib/Parse/ParseOpenMP.cpp +++ b/clang/lib/Parse/ParseOpenMP.cpp @@ -2458,7 +2458,7 @@ bool HasAssociatedStatement = true; switch (DKind) { - case OMPD_metadirective: { + case OMPD_metadirective:{ ConsumeToken(); SmallVector VMIs; @@ -2470,10 +2470,12 @@ BalancedDelimiterTracker T(*this, tok::l_paren, tok::annot_pragma_openmp_end); - while (Tok.isNot(tok::annot_pragma_openmp_end)) { + + while (Tok.isNot(tok::annot_pragma_openmp_end)){ OpenMPClauseKind CKind = Tok.isAnnotation() ? OMPC_unknown : getOpenMPClauseKind(PP.getSpelling(Tok)); + SourceLocation Loc = ConsumeToken(); // Parse '('. @@ -2491,7 +2493,7 @@ return Directive; } - // Parse ':' + // Parse ':' // You have parsed the OpenMP Context in the meta directive if (Tok.is(tok::colon)) ConsumeAnyToken(); else { @@ -2500,6 +2502,7 @@ return Directive; } } + // Skip Directive for now. We will parse directive in the second iteration int paren = 0; while (Tok.isNot(tok::r_paren) || paren != 0) { @@ -2513,86 +2516,111 @@ TPA.Commit(); return Directive; } - ConsumeAnyToken(); - } + ConsumeAnyToken(); + } + // Parse ')' if (Tok.is(tok::r_paren)) T.consumeClose(); - + VariantMatchInfo VMI; TI.getAsVariantMatchInfo(ASTContext, VMI); - - VMIs.push_back(VMI); - } - + + if (CKind == OMPC_when ) + VMIs.push_back(VMI); + } + + // This is the end of the first iteration + // The pointer is moved back TPA.Revert(); // End of the first iteration. 
Parser is reset to the start of metadirective - + TargetOMPContext OMPCtx(ASTContext, /* DiagUnknownTrait */ nullptr, /* CurrentFunctionDecl */ nullptr, ArrayRef()); - - // A single match is returned for OpenMP 5.0 - int BestIdx = getBestVariantMatchForContext(VMIs, OMPCtx); - - int Idx = 0; - // In OpenMP 5.0 metadirective is either replaced by another directive or - // ignored. - // TODO: In OpenMP 5.1 generate multiple directives based upon the matches - // found by getBestWhenMatchForContext. - while (Tok.isNot(tok::annot_pragma_openmp_end)) { - // OpenMP 5.0 implementation - Skip to the best index found. - if (Idx++ != BestIdx) { - ConsumeToken(); // Consume clause name - T.consumeOpen(); // Consume '(' - int paren = 0; - // Skip everything inside the clause - while (Tok.isNot(tok::r_paren) || paren != 0) { - if (Tok.is(tok::l_paren)) - paren++; - if (Tok.is(tok::r_paren)) - paren--; - ConsumeAnyToken(); - } - // Parse ')' - if (Tok.is(tok::r_paren)) - T.consumeClose(); - continue; - } - + + // Array SortedCluases will be used for sorting clauses + // based on the context selector score + SmallVector> SortedCluases; + + // The function will get the score for each clause and sort it + // based on the score number + + getArrayVariantMatchForContext(VMIs, OMPCtx, SortedCluases) ; + + ParseScope OMPDirectiveScope(this, ScopeFlags); + Actions.StartOpenMPDSABlock(DKind, DirName, Actions.getCurScope(), Loc); + + while(Tok.isNot(tok::annot_pragma_openmp_end)){ + OpenMPClauseKind CKind = Tok.isAnnotation() ? OMPC_unknown : getOpenMPClauseKind(PP.getSpelling(Tok)); - SourceLocation Loc = ConsumeToken(); - - // Parse '('. - T.consumeOpen(); - - // Skip ContextSelectors for when clause - if (CKind == OMPC_when) { - OMPTraitInfo &TI = Actions.getASTContext().getNewOMPTraitInfo(); - // parse and skip the ContextSelectors - parseOMPContextSelectors(Loc, TI); - - // Parse ':' - ConsumeAnyToken(); - } - - // If no directive is passed, skip in OpenMP 5.0. 
- // TODO: Generate nothing directive from OpenMP 5.1. - if (Tok.is(tok::r_paren)) { - SkipUntil(tok::annot_pragma_openmp_end); - break; - } - - // Parse Directive - ReadDirectiveWithinMetadirective = true; - Directive = ParseOpenMPDeclarativeOrExecutableDirective(StmtCtx); - ReadDirectiveWithinMetadirective = false; - break; + + Actions.StartOpenMPClause(CKind); + OMPClause *Clause = ParseOpenMPMetaDirectiveClause( DKind, CKind); + + FirstClauses[(unsigned) CKind].setInt(true); + if (Clause) { + FirstClauses[(unsigned) CKind].setPointer(Clause); + Clauses.push_back(Clause); + } + + if (Tok.is(tok::comma)) + ConsumeToken(); + + Actions.EndOpenMPClause(); + + if (Tok.is(tok::r_paren)) + ConsumeAnyToken(); + } - break; - } + + // End location of the directive + EndLoc = Tok.getLocation(); + + //Consume final annot_pragma_openmp_end + ConsumeAnnotationToken(); + + SmallVector Clauses_new; + unsigned count = 0; + + // SortedClauses has index and score, and are sorted with respect to the + // the context score. The first iteration will take each element. The + // first element will have the highiest score. The element will have the + // index of the cluase for the best score. 
The second iteration tries to + // find that specific clause by checking the count number with the + // index (Iteration1.first) + for ( auto &Iteration1 : SortedCluases){ + count = 0; + for ( auto &Iteration2 : Clauses){ + if ( count == Iteration1.first ){ + Clauses_new.push_back(Iteration2); + break; + } else count++; + } + } + // Adding the default clause at the end. Clauses is empty when every clause + // failed to parse, and back() on an empty vector is undefined behaviour. + if (!Clauses.empty()) + Clauses_new.push_back(Clauses.back()); + // Parsing the OpenMP region which will take the + // metadirective + + Actions.ActOnOpenMPRegionStart(DKind, getCurScope()); + ParsingOpenMPDirectiveRAII NormalScope(*this, /*value=*/ false); + // This is parsing the region + StmtResult AStmt = ParseStatement(); + + StmtResult AssociatedStmt = (Sema::CompoundScopeRAII(Actions), AStmt); + // Ending of the parallel region + AssociatedStmt = Actions.ActOnOpenMPRegionEnd(AssociatedStmt, Clauses_new); + Directive = Actions.ActOnOpenMPExecutableDirective( + DKind, DirName, CancelRegion, Clauses_new, AssociatedStmt.get(), Loc, + EndLoc); + // Exit scope + Actions.EndOpenMPDSABlock(Directive.get()); + break; + } // end of case OMPD_metadirective: case OMPD_threadprivate: { // FIXME: Should this be permitted in C++? if ((StmtCtx & ParsedStmtContext::AllowDeclarationsInC) == @@ -3050,6 +3078,164 @@ return Actions.ActOnOpenMPUsesAllocatorClause(Loc, T.getOpenLocation(), T.getCloseLocation(), Data); } +/// Parsing of OpenMP MetaDirective Clauses + +OMPClause *Parser::ParseOpenMPMetaDirectiveClause(OpenMPDirectiveKind DKind, + OpenMPClauseKind CKind) { + OMPClause *Clause = nullptr; + bool ErrorFound = false; + bool WrongDirective = false; + SmallVector, + llvm::omp::Clause_enumSize + 1> + FirstClauses(llvm::omp::Clause_enumSize + 1); + + // Check if it is called from metadirective. + if (DKind != OMPD_metadirective) { + Diag(Tok, diag::err_omp_unexpected_clause) + << getOpenMPClauseName(CKind) << getOpenMPDirectiveName(DKind); + ErrorFound = true; + } + + // Check if clause is allowed for the given directive. 
+ if (CKind != OMPC_unknown && + !isAllowedClauseForDirective(DKind, CKind, getLangOpts().OpenMP)) { + Diag(Tok, diag::err_omp_unexpected_clause) + << getOpenMPClauseName(CKind) << getOpenMPDirectiveName(DKind); + ErrorFound = true; + WrongDirective = true; + } + + // Check if clause is not allowed + if (CKind == OMPC_unknown) { + Diag(Tok, diag::err_omp_unexpected_clause) + << getOpenMPClauseName(CKind) << "Unknown clause: Not allowed"; + ErrorFound = true; + WrongDirective = true; + } + + if (CKind == OMPC_default || CKind == OMPC_when) { + SourceLocation Loc = ConsumeToken(); + SourceLocation DelimLoc; + // Parse '('. + BalancedDelimiterTracker T(*this, tok::l_paren, + tok::annot_pragma_openmp_end); + if (T.expectAndConsume(diag::err_expected_lparen_after, + getOpenMPClauseName(CKind).data())) + return nullptr; + + OMPTraitInfo &TI = Actions.getASTContext().getNewOMPTraitInfo(); + if (CKind == OMPC_when) { + // parse and get condition expression to pass to the When clause + parseOMPContextSelectors(Loc, TI); + + // Parse ':' + if (Tok.is(tok::colon)) + ConsumeAnyToken(); + else { + Diag(Tok, diag::warn_pragma_expected_colon) << "when clause"; + return nullptr; + } + } + + // Parse Directive + OpenMPDirectiveKind DirKind = OMPD_unknown; + SmallVector Clauses; + StmtResult AssociatedStmt; + StmtResult Directive = StmtError(); + + if (Tok.isNot(tok::r_paren)) { + ParsingOpenMPDirectiveRAII DirScope(*this); + ParenBraceBracketBalancer BalancerRAIIObj(*this); + DeclarationNameInfo DirName; + unsigned ScopeFlags = Scope::FnScope | Scope::DeclScope | + Scope::CompoundStmtScope | + Scope::OpenMPDirectiveScope; + + DirKind = parseOpenMPDirectiveKind(*this); + ConsumeToken(); + ParseScope OMPDirectiveScope(this, ScopeFlags); + Actions.StartOpenMPDSABlock(DirKind, DirName, Actions.getCurScope(), Loc); + + int paren = 0; + + while (Tok.isNot(tok::r_paren) || paren != 0) { + if (Tok.is(tok::l_paren)) + paren++; + if (Tok.is(tok::r_paren)) + paren--; + + OpenMPClauseKind 
CKind = Tok.isAnnotation() + ? OMPC_unknown + : getOpenMPClauseKind(PP.getSpelling(Tok)); + + if (CKind == OMPC_unknown && + !isAllowedClauseForDirective(DirKind, CKind, getLangOpts().OpenMP)) { + Diag(Tok, diag::err_omp_unexpected_clause) + << getOpenMPClauseName(CKind) << getOpenMPDirectiveName(DirKind); + ErrorFound = true; + WrongDirective = true; + } + + Actions.StartOpenMPClause(CKind); + OMPClause *Clause = ParseOpenMPClause( + DirKind, CKind, !FirstClauses[(unsigned)CKind].getInt()); + FirstClauses[(unsigned)CKind].setInt(true); + if (Clause) { + FirstClauses[(unsigned)CKind].setPointer(Clause); + Clauses.push_back(Clause); + } + + // Skip ',' if any. + if (Tok.is(tok::comma)) + ConsumeToken(); + Actions.EndOpenMPClause(); + } + + Actions.ActOnOpenMPRegionStart(DirKind, getCurScope()); + ParsingOpenMPDirectiveRAII NormalScope(*this, /*Value=*/false); + + /* Get Stmt and revert back */ + TentativeParsingAction TPA(*this); + while (Tok.isNot(tok::annot_pragma_openmp_end)) { + ConsumeAnyToken(); + } + + ConsumeAnnotationToken(); + ParseScope InnerStmtScope(this, Scope::DeclScope, + getLangOpts().C99 || getLangOpts().CPlusPlus, + Tok.is(tok::l_brace)); + + StmtResult AStmt = ParseStatement(); + InnerStmtScope.Exit(); + TPA.Revert(); + /* End Get Stmt */ + + AssociatedStmt = (Sema::CompoundScopeRAII(Actions), AStmt); + AssociatedStmt = Actions.ActOnOpenMPRegionEnd(AssociatedStmt, Clauses); + + Directive = Actions.ActOnOpenMPExecutableDirective( + DirKind, DirName, OMPD_unknown, llvm::makeArrayRef(Clauses), + AssociatedStmt.get(), Loc, Tok.getLocation()); + + Actions.EndOpenMPDSABlock(Directive.get()); + OMPDirectiveScope.Exit(); + } + // Parse ')' + T.consumeClose(); + + if (WrongDirective) + return nullptr; + + Clause = Actions.ActOnOpenMPWhenClause(TI, DirKind, Directive, Loc, + T.getOpenLocation(), T.getCloseLocation()); + } else { + Diag(Tok, diag::err_omp_unexpected_clause) + << getOpenMPClauseName(CKind) << getOpenMPDirectiveName(DKind); + // Nothing was consumed on this path; skip to the end of the pragma so the caller's clause loop cannot spin on the same token. + SkipUntil(tok::annot_pragma_openmp_end, StopBeforeMatch); + } + + 
return ErrorFound ? nullptr : Clause; +} /// Parsing of OpenMP clauses. /// diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -37,6 +37,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" +#include "llvm/Frontend/OpenMP/OMPContext.h" #include using namespace clang; @@ -3930,6 +3931,7 @@ void Sema::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind, Scope *CurScope) { switch (DKind) { + case OMPD_metadirective: case OMPD_parallel: case OMPD_parallel_for: case OMPD_parallel_for_simd: @@ -4339,7 +4341,6 @@ case OMPD_declare_variant: case OMPD_begin_declare_variant: case OMPD_end_declare_variant: - case OMPD_metadirective: llvm_unreachable("OpenMP Directive is not allowed"); case OMPD_unknown: default: @@ -4521,7 +4522,7 @@ } StmtResult Sema::ActOnOpenMPRegionEnd(StmtResult S, - ArrayRef Clauses) { + ArrayRef Clauses){ handleDeclareVariantConstructTrait(DSAStack, DSAStack->getCurrentDirective(), /* ScopeEntry */ false); if (DSAStack->getCurrentDirective() == OMPD_atomic || @@ -4611,6 +4612,7 @@ << SourceRange(OC->getBeginLoc(), OC->getEndLoc()); ErrorFound = true; } + // OpenMP 5.0, 2.9.2 Worksharing-Loop Construct, Restrictions. // If an order(concurrent) clause is present, an ordered clause may not appear // on the same directive. 
@@ -4623,6 +4625,7 @@ } ErrorFound = true; } + if (isOpenMPWorksharingDirective(DSAStack->getCurrentDirective()) && isOpenMPSimdDirective(DSAStack->getCurrentDirective()) && OC && OC->getNumForLoops()) { @@ -4635,7 +4638,9 @@ } StmtResult SR = S; unsigned CompletedRegions = 0; + for (OpenMPDirectiveKind ThisCaptureRegion : llvm::reverse(CaptureRegions)) { + // Mark all variables in private list clauses as used in inner region. // Required for proper codegen of combined directives. // TODO: add processing for other clauses. @@ -4656,6 +4661,7 @@ } } } + if (ThisCaptureRegion == OMPD_target) { // Capture allocator traits in the target region. They are used implicitly // and, thus, are not captured by default. @@ -4671,6 +4677,7 @@ } } } + if (ThisCaptureRegion == OMPD_parallel) { // Capture temp arrays for inscan reductions and locals in aligned // clauses. @@ -4687,10 +4694,14 @@ } } } + if (++CompletedRegions == CaptureRegions.size()) DSAStack->setBodyComplete(); + SR = ActOnCapturedRegionEnd(SR.get()); + } + return SR; } @@ -5963,6 +5974,12 @@ llvm::SmallVector AllowedNameModifiers; switch (Kind) { + + case OMPD_metadirective: + Res = ActOnOpenMPMetaDirective(ClausesWithImplicit, AStmt, StartLoc, + EndLoc); + AllowedNameModifiers.push_back(OMPD_metadirective); + break; case OMPD_parallel: Res = ActOnOpenMPParallelDirective(ClausesWithImplicit, AStmt, StartLoc, EndLoc); @@ -7342,10 +7359,123 @@ FD->addAttr(NewAttr); } +StmtResult Sema::ActOnOpenMPMetaDirective(ArrayRef Clauses, + Stmt *AStmt, + SourceLocation StartLoc, + SourceLocation EndLoc){ + if (!AStmt) + return StmtError(); + + auto *CS = cast(AStmt); + // 1.2.2 OpenMP Language Terminology + // Structured block - An executable statement with a single entry at the + // top and a single exit at the bottom. + // The point of exit cannot be a branch out of the structured block. + // longjmp() and throw() must not violate the entry/exit criteria. 
+ + CS->getCapturedDecl()->setNothrow(); + + StmtResult IfStmt = StmtError(); + Stmt *ElseStmt = nullptr; + + for (auto i = Clauses.rbegin(); i != Clauses.rend(); ++i) { + // Every clause of a metadirective is built as an OMPWhenClause, so assert that with cast<> rather than dereferencing an unchecked dyn_cast<> result. + auto *WhenClause = cast<OMPWhenClause>(*i); + Expr *WhenCondExpr = nullptr; + Stmt *ThenStmt = nullptr; + OpenMPDirectiveKind DKind = WhenClause->getDKind(); + + if (DKind != OMPD_unknown) + ThenStmt = CompoundStmt::Create(Context, {WhenClause->getDirective()}, + SourceLocation(), SourceLocation()); + + for (const OMPTraitSet &Set : WhenClause->getTI().Sets){ + for (const OMPTraitSelector &Selector : Set.Selectors){ + switch (Selector.Kind){ + case TraitSelector::device_arch:{ + bool archMatch = false; + for (const OMPTraitProperty &Property : Selector.Properties){ + for (auto &T : getLangOpts().OMPTargetTriples){ + if (T.getArchName() == Property.RawString){ + archMatch = true; + break; + } + } + if (archMatch) + break; + } + // Create a true/false boolean expression and assign to WhenCondExpr + auto *C = new (Context) + CXXBoolLiteralExpr(archMatch, Context.BoolTy, StartLoc); + WhenCondExpr = C; + break; + } + case TraitSelector::user_condition:{ + assert(Selector.ScoreOrCondition && + "Ill-formed user condition, expected condition expression!"); + + WhenCondExpr = Selector.ScoreOrCondition; + break; + } + case TraitSelector::implementation_vendor:{ + bool vendorMatch = false; + for (const OMPTraitProperty &Property : Selector.Properties){ + for (auto &T : getLangOpts().OMPTargetTriples){ + if (T.getVendorName() == Property.RawString){ + vendorMatch = true; + break; + } + } + if (vendorMatch) + break; + } + // Create a true/false boolean expression and assign to WhenCondExpr + auto *WhenCondition = new (Context) + CXXBoolLiteralExpr(vendorMatch, Context.BoolTy, StartLoc); + WhenCondExpr = WhenCondition; + break; + } + case TraitSelector::device_isa: + case TraitSelector::device_kind: + case TraitSelector::implementation_extension: + default: + break; + } + } + } + + if 
(WhenCondExpr == nullptr) { + if (ElseStmt != nullptr) { + Diag(WhenClause->getBeginLoc(), diag::err_omp_misplaced_default_clause); + return StmtError(); + } + if (DKind == OMPD_unknown) + ElseStmt = CompoundStmt::Create(Context, {CS->getCapturedStmt()}, + SourceLocation(), SourceLocation()); + else + ElseStmt = ThenStmt; + continue; + } + + if (ThenStmt == NULL) + ThenStmt = CompoundStmt::Create(Context, {CS->getCapturedStmt()}, + SourceLocation(), SourceLocation()); + + IfStmt = + ActOnIfStmt(SourceLocation(), /*false*/ IfStatementKind::Ordinary, SourceLocation(), nullptr, + ActOnCondition(getCurScope(), SourceLocation(), + WhenCondExpr, Sema::ConditionKind::Boolean), + SourceLocation(), ThenStmt, SourceLocation(), ElseStmt); + ElseStmt = IfStmt.get(); + } + + return OMPMetaDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt, + IfStmt.get()); +} + StmtResult Sema::ActOnOpenMPParallelDirective(ArrayRef Clauses, Stmt *AStmt, SourceLocation StartLoc, - SourceLocation EndLoc) { + SourceLocation EndLoc){ if (!AStmt) return StmtError(); @@ -14837,6 +14967,17 @@ return std::string(Out.str()); } +OMPClause * +Sema::ActOnOpenMPWhenClause(OMPTraitInfo &TI, OpenMPDirectiveKind DKind, + StmtResult Directive, SourceLocation StartLoc, + SourceLocation LParenLoc, SourceLocation EndLoc) { + return new (Context) + OMPWhenClause(TI, DKind, Directive.get(), StartLoc, LParenLoc, EndLoc); +} + + + + OMPClause *Sema::ActOnOpenMPDefaultClause(DefaultKind Kind, SourceLocation KindKwLoc, SourceLocation StartLoc, diff --git a/clang/lib/Sema/SemaStmt.cpp b/clang/lib/Sema/SemaStmt.cpp --- a/clang/lib/Sema/SemaStmt.cpp +++ b/clang/lib/Sema/SemaStmt.cpp @@ -4792,7 +4792,7 @@ getASTContext(), S, static_cast(RSI->CapRegionKind), Captures, CaptureInits, CD, RD); - CD->setBody(Res->getCapturedStmt()); + CD->setBody(Res->getCapturedStmt()); RD->completeDefinition(); return Res; diff --git a/clang/test/OpenMP/metadirective_ast_print_new_1.cpp 
b/clang/test/OpenMP/metadirective_ast_print_new_1.cpp new file mode 100644 --- /dev/null +++ b/clang/test/OpenMP/metadirective_ast_print_new_1.cpp @@ -0,0 +1,20 @@ +// RUN: %clang_cc1 -verify -fopenmp -ast-print %s -o - | FileCheck %s +// expected-no-diagnostics +void bar(){ + int i=0; +} + +void myfoo(void){ + + int N = 13; + int b,n; + int a[100]; + + #pragma omp metadirective when(user={condition(N>10)}: target teams ) default(parallel for) + for (int i = 0; i < N; i++) + bar(); + +} + +// CHECK: void bar() +// CHECK: #pragma omp metadirective when(user={condition(N > 10)}: target teams) default(parallel for) diff --git a/clang/test/OpenMP/metadirective_ast_print_new_2.cpp b/clang/test/OpenMP/metadirective_ast_print_new_2.cpp new file mode 100644 --- /dev/null +++ b/clang/test/OpenMP/metadirective_ast_print_new_2.cpp @@ -0,0 +1,29 @@ +// RUN: %clang_cc1 -verify -fopenmp -ast-print %s -o - | FileCheck %s +// expected-no-diagnostics + +void bar(){ + int i=0; +} + +void myfoo(void){ + + int N = 13; + int b,n; + int a[100]; + + + #pragma omp metadirective when (user = {condition(N>10)}: target teams distribute parallel for ) \ + when (user = {condition(N==10)}: parallel for )\ + when (user = {condition(N==13)}: parallel for simd) \ + when ( device={arch("arm")}: target teams num_teams(512) thread_limit(32))\ + when ( device={arch("nvptx")}: target teams num_teams(512) thread_limit(32))\ + default ( parallel for)\ + + { for (int i = 0; i < N; i++) + bar(); + } +} + +// CHECK: bar() +// CHECK: myfoo +// CHECK: #pragma omp metadirective when(user={condition(N > 10)}: target teams distribute parallel for) when(user={condition(N == 13)}: parallel for simd) when(device={arch(nvptx)}: target teams) diff --git a/clang/test/OpenMP/metadirective_ast_print_new_3.cpp b/clang/test/OpenMP/metadirective_ast_print_new_3.cpp new file mode 100644 --- /dev/null +++ b/clang/test/OpenMP/metadirective_ast_print_new_3.cpp @@ -0,0 +1,22 @@ +// RUN: %clang_cc1 -verify -fopenmp -ast-print %s 
-o - | FileCheck %s +// expected-no-diagnostics + +int main() { + int N = 15; +#pragma omp metadirective when(user = {condition(N > 10)} : parallel for)\ + default(target teams) + for (int i = 0; i < N; i++) + ; + + +#pragma omp metadirective when(device = {arch("nvptx64")}, user = {condition(N >= 100)} : parallel for)\ + default(target parallel for) + for (int i = 0; i < N; i++) + ; + return 0; +} + + + +// CHECK: #pragma omp metadirective when(user={condition(N > 10)}: parallel for) default(target teams) +// CHECK: #pragma omp metadirective when(device={arch(nvptx64)}, user={condition(N >= 100)}: parallel for) default(target parallel for) diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPContext.h b/llvm/include/llvm/Frontend/OpenMP/OMPContext.h --- a/llvm/include/llvm/Frontend/OpenMP/OMPContext.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPContext.h @@ -189,6 +189,15 @@ int getBestVariantMatchForContext(const SmallVectorImpl &VMIs, const OMPContext &Ctx); +/// Sort array \p A of clause index with score +/// This will be used to produce AST clauses +/// in a sorted order with the clause with the highiest order +/// on the top and default clause at the bottom +void getArrayVariantMatchForContext( + const SmallVectorImpl &VMIs, const OMPContext &Ctx, + SmallVector> &A); + +// new-- } // namespace omp template <> struct DenseMapInfo { diff --git a/llvm/lib/Frontend/OpenMP/OMPContext.cpp b/llvm/lib/Frontend/OpenMP/OMPContext.cpp --- a/llvm/lib/Frontend/OpenMP/OMPContext.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPContext.cpp @@ -20,6 +20,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include #define DEBUG_TYPE "openmp-ir-builder" using namespace llvm; @@ -339,6 +340,45 @@ return Score; } +/// Takes \p VMI and \p Ctx and sort the +/// scores using \p VectorOfClauses +void llvm::omp::getArrayVariantMatchForContext(const SmallVectorImpl &VMIs, + const OMPContext &Ctx, SmallVector> &VectorOfClauses){ + + //APInt BestScore(64, 0); + APInt Score (64, 
0); + // Collect (clause index, score) pairs directly, in clause order. Going + // through a DenseMap first would leave the relative order of equally scored + // clauses to the map's unspecified iteration order. + VectorOfClauses.reserve(VectorOfClauses.size() + VMIs.size()); + + for (unsigned u = 0, e = VMIs.size(); u < e; ++u) { + const VariantMatchInfo &VMI = VMIs[u]; + + SmallVector ConstructMatches; + // If the variant is not applicable it is not the best. + if (!isVariantApplicableInContextHelper(VMI, Ctx, &ConstructMatches, + /* DeviceSetOnly */ false)){ + Score = 0; + // adding index and its corresponding score + VectorOfClauses.emplace_back(u, Score); + continue; + } + // Else get the score + Score = getVariantMatchScore(VMI, Ctx, ConstructMatches); + VectorOfClauses.emplace_back(u, Score); + } + + // Rank by descending score. stable_sort keeps equally scored clauses in + // their original order; the comparator takes const references. + std::stable_sort(VectorOfClauses.begin(), VectorOfClauses.end(), + [](const auto &a, const auto &b) { + return a.second.ugt(b.second); + }); +} + + + int llvm::omp::getBestVariantMatchForContext( const SmallVectorImpl &VMIs, const OMPContext &Ctx) {