diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h --- a/clang/include/clang/AST/OpenMPClause.h +++ b/clang/include/clang/AST/OpenMPClause.h @@ -8570,6 +8570,18 @@ } }; +/// This represents 'when' clause in the '#pragma omp ...' directive +/// +/// \code +/// #pragma omp metadirective when(user={condition(N<10)} : parallel for) \ +/// when(user={condition(N> 10)}: parallel)\ +/// \endcode +/// In this example directive '#pragma omp metadirective' has two 'when' +/// clauses with user defined conditions. + + + + /// This class implements a simple visitor for OMPClause /// subclasses. template class Ptr, typename RetTy> @@ -8611,7 +8623,9 @@ template class ConstOMPClauseVisitor : public OMPClauseVisitorBase {}; - + + +/* class OMPClausePrinter final : public OMPClauseVisitor { raw_ostream &OS; const PrintingPolicy &Policy; @@ -8624,12 +8638,16 @@ public: OMPClausePrinter(raw_ostream &OS, const PrintingPolicy &Policy) : OS(OS), Policy(Policy) {} - + + // void VisitOMPWhenClause(OMPWhenClause *Node); + #define GEN_CLANG_CLAUSE_CLASS #define CLAUSE_CLASS(Enum, Str, Class) void Visit##Class(Class *S); #include "llvm/Frontend/OpenMP/OMP.inc" }; +*/ + struct OMPTraitProperty { llvm::omp::TraitProperty Kind = llvm::omp::TraitProperty::invalid; @@ -8728,6 +8746,11 @@ llvm::StringMap FeatureMap; }; + + + + + /// Contains data for OpenMP directives: clauses, children /// expressions/statements (helpers for codegen) and associated statement, if /// any. @@ -8872,6 +8895,94 @@ } }; +/// This captures 'when' clause in the '#pragma omp metadirective' +/// \code +/// #pragma omp metadirective when(user={consition(N<100)}:parallel for) +/// \endcode +/// In the above example, the metadirective clause has a condition which when +/// satisfied will use the parallel for directive with the code enclosed by the +/// directive. +class OMPWhenClause final : public OMPClause { + friend class OMPClauseReader; + + OMPTraitInfo *TI; + OpenMPDirectiveKind DKind; + Stmt *Directive; + + /// Location of '('. + SourceLocation LParenLoc; + +public: + /// Build 'when' clause with argument \a A ('none' or 'shared'). + /// + /// \param T TraitInfor containing information about the context selector + /// \param DKind The directive associated with the when clause + /// \param D The statement associated with the when clause + /// \param StartLoc Starting location of the clause. + /// \param LParenLoc Location of '('. + /// \param EndLoc Ending location of the clause. + OMPWhenClause(OMPTraitInfo &T, OpenMPDirectiveKind dKind, Stmt *D, + SourceLocation StartLoc, SourceLocation LParenLoc, + SourceLocation EndLoc) + : OMPClause(llvm::omp::OMPC_when, StartLoc, EndLoc), TI(&T), DKind(dKind), + Directive(D), LParenLoc(LParenLoc) {} + + /// Build an empty clause. + OMPWhenClause() + : OMPClause(llvm::omp::OMPC_when, SourceLocation(), SourceLocation()) {} + + /// Sets the location of '('. + void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; } + + /// Returns the location of '('. + SourceLocation getLParenLoc() const { return LParenLoc; } + + /// Returns the directive variant kind + OpenMPDirectiveKind getDKind() { return DKind; } + + Stmt *getDirective() const { return Directive; } + + /// Returns the OMPTraitInfo + OMPTraitInfo &getTI() { return *TI; } + + child_range children() { + return child_range(child_iterator(), child_iterator()); + } + + const_child_range children() const { + return const_child_range(const_child_iterator(), const_child_iterator()); + } + child_range used_children() { + return child_range(child_iterator(), child_iterator()); + } + const_child_range used_children() const { + return const_child_range(const_child_iterator(), const_child_iterator()); + } + + static bool classof(const OMPClause *T) { + return T->getClauseKind() == llvm::omp::OMPC_when; + } +}; + +class OMPClausePrinter final : public OMPClauseVisitor { + raw_ostream &OS; + const PrintingPolicy &Policy; + + /// Process clauses with list of variables. + template void VisitOMPClauseList(T *Node, char StartSym); + /// Process motion clauses. + template void VisitOMPMotionClause(T *Node); + +public: + OMPClausePrinter(raw_ostream &OS, const PrintingPolicy &Policy) + : OS(OS), Policy(Policy) {} + + void VisitOMPWhenClause(OMPWhenClause *Node); + +#define GEN_CLANG_CLAUSE_CLASS +#define CLAUSE_CLASS(Enum, Str, Class) void Visit##Class(Class *S); +#include "llvm/Frontend/OpenMP/OMP.inc" +}; } // namespace clang #endif // LLVM_CLANG_AST_OPENMPCLAUSE_H diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -501,7 +501,8 @@ /// Process clauses with pre-initis. bool VisitOMPClauseWithPreInit(OMPClauseWithPreInit *Node); bool VisitOMPClauseWithPostUpdate(OMPClauseWithPostUpdate *Node); - + bool VisitOMPWhenClause(OMPWhenClause *C); + bool PostVisitStmt(Stmt *S); }; @@ -3136,6 +3137,18 @@ return true; } +template +bool RecursiveASTVisitor::VisitOMPWhenClause(OMPWhenClause *C) { + for (const OMPTraitSet &Set : C->getTI().Sets) { + for (const OMPTraitSelector &Selector : Set.Selectors) { + if (Selector.Kind == llvm::omp::TraitSelector::user_condition && + Selector.ScoreOrCondition) + TRY_TO(TraverseStmt(Selector.ScoreOrCondition)); + } + } + return true; +} + template bool RecursiveASTVisitor::VisitOMPDefaultClause(OMPDefaultClause *) { return true; diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h --- a/clang/include/clang/AST/StmtOpenMP.h +++ b/clang/include/clang/AST/StmtOpenMP.h @@ -5476,6 +5476,7 @@ Stmt *AssociatedStmt, Stmt *IfStmt); static OMPMetaDirective *CreateEmpty(const ASTContext &C, unsigned NumClauses, EmptyShell); + Stmt *getIfStmt() const { return IfStmt; } static bool classof(const Stmt *T) { diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -10848,6 +10848,9 @@ "'%0' clause requires 'dispatch' context selector">; def err_omp_append_args_with_varargs : Error< "'append_args' is not allowed with varargs functions">; +def err_omp_misplaced_default_clause : Error< + "misplaced default clause! Only one default clause is allowed in" + "metadirective in the end">; } // end of OpenMP category let CategoryName = "Related Result Type Issue" in { diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h --- a/clang/include/clang/Parse/Parser.h +++ b/clang/include/clang/Parse/Parser.h @@ -3291,6 +3291,12 @@ /// \param StmtCtx The context in which we're parsing the directive. StmtResult ParseOpenMPDeclarativeOrExecutableDirective(ParsedStmtContext StmtCtx); + /// Parse clause for metadirective + /// \param Dkind Kind of current directive + /// \param CKind Kind of current clause + /// + OMPClause *ParseOpenMPMetaDirectiveClause(OpenMPDirectiveKind DKind, + OpenMPClauseKind CKind); /// Parses clause of kind \a CKind for directive of a kind \a Kind. /// /// \param DKind Kind of current directive. diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -66,6 +66,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TinyPtrVector.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" +#include "llvm/Frontend/OpenMP/OMPContext.h" #include #include #include @@ -10682,6 +10683,10 @@ OpenMPDirectiveKind Kind, const DeclarationNameInfo &DirName, OpenMPDirectiveKind CancelRegion, ArrayRef Clauses, Stmt *AStmt, SourceLocation StartLoc, SourceLocation EndLoc); + StmtResult ActOnOpenMPExecutableMetaDirective( // This might be needed later + OpenMPDirectiveKind Kind, const DeclarationNameInfo &DirName, + OpenMPDirectiveKind CancelRegion, ArrayRef Clauses, + Stmt *AStmt, SourceLocation StartLoc, SourceLocation EndLoc); /// Called on well-formed '\#pragma omp parallel' after parsing /// of the associated statement. StmtResult ActOnOpenMPParallelDirective(ArrayRef Clauses, @@ -11127,7 +11132,10 @@ SourceLocation LParenLoc, SourceLocation EndLoc); /// Called on well-formed 'when' clause. - OMPClause *ActOnOpenMPWhenClause(OMPTraitInfo &TI, SourceLocation StartLoc, + /// Abid + OMPClause *ActOnOpenMPWhenClause(OMPTraitInfo &TI, OpenMPDirectiveKind DKind, + StmtResult Directive, + SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc); /// Called on well-formed 'default' clause. diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp --- a/clang/lib/AST/OpenMPClause.cpp +++ b/clang/lib/AST/OpenMPClause.cpp @@ -1609,6 +1609,77 @@ // OpenMP clauses printing methods //===----------------------------------------------------------------------===// + +void OMPClausePrinter::VisitOMPWhenClause(OMPWhenClause *Node) { + + if (Node->getTI().Sets.size() == 0) { + OS << "default("; + return; + } + OS << "when("; + int count = 0; + for (const OMPTraitSet &Set : Node->getTI().Sets) { + if (count == 0) + count++; + else + OS << ", "; + for (const OMPTraitSelector &Selector : Set.Selectors) { + switch (Selector.Kind) { + case TraitSelector::device_kind: { + OS << "device={kind("; + for (const OMPTraitProperty &Property : Selector.Properties) { + OS << Property.RawString; + } + OS << ")}"; + break; + } + case TraitSelector::device_arch: { + OS << "device={arch("; + for (const OMPTraitProperty &Property : Selector.Properties) { + OS << Property.RawString; + } + OS << ")}"; + break; + } + case TraitSelector::device_isa: { + OS << "device={isa("; + for (const OMPTraitProperty &Property : Selector.Properties) { + OS << Property.RawString; + } + OS << ")}"; + break; + } + case TraitSelector::implementation_vendor: { + OS << "implementation={vendor("; + for (const OMPTraitProperty &Property : Selector.Properties) { + OS << Property.RawString; + } + OS << ")}"; + break; + } + case TraitSelector::implementation_extension: { + OS << "implementation={extension("; + for (const OMPTraitProperty &Property : Selector.Properties) { + OS << Property.RawString; + } + OS << ")}"; + break; + } + case TraitSelector::user_condition: { + OS << "user={condition("; + Selector.ScoreOrCondition->printPretty(OS, nullptr, Policy, 0); + OS << ")}"; + break; + } + + default: + break; + } + } + } + OS << ": "; +} + void OMPClausePrinter::VisitOMPIfClause(OMPIfClause *Node) { OS << "if("; if (Node->getNameModifier() != OMPD_unknown) diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp --- a/clang/lib/AST/StmtPrinter.cpp +++ b/clang/lib/AST/StmtPrinter.cpp @@ -655,12 +655,28 @@ void StmtPrinter::PrintOMPExecutableDirective(OMPExecutableDirective *S, bool ForceNoStmt) { + + OMPClausePrinter Printer(OS, Policy); ArrayRef Clauses = S->clauses(); - for (auto *Clause : Clauses) + for (auto *Clause : Clauses){ + if (Clause && !Clause->isImplicit()) { OS << ' '; Printer.Visit(Clause); + + if (dyn_cast(S)){ + + OMPWhenClause *c = dyn_cast(Clause); + if (c!=NULL){ + if (c->getDKind() != llvm::omp::OMPD_unknown){ + Printer.VisitOMPWhenClause(c); + OS << llvm::omp::getOpenMPDirectiveName(c->getDKind()); + } + OS << ")"; + } + } + } } OS << NL; if (!ForceNoStmt && S->hasAssociatedStmt()) @@ -668,6 +684,7 @@ } void StmtPrinter::VisitOMPMetaDirective(OMPMetaDirective *Node) { + Indent() << "#pragma omp metadirective"; PrintOMPExecutableDirective(Node); } diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp --- a/clang/lib/Parse/ParseOpenMP.cpp +++ b/clang/lib/Parse/ParseOpenMP.cpp @@ -2430,7 +2430,8 @@ /// StmtResult Parser::ParseOpenMPDeclarativeOrExecutableDirective(ParsedStmtContext StmtCtx) { - static bool ReadDirectiveWithinMetadirective = false; +// need to check about the following + static bool ReadDirectiveWithinMetadirective = false; // what is this? if (!ReadDirectiveWithinMetadirective) assert(Tok.isOneOf(tok::annot_pragma_openmp, tok::annot_attr_openmp) && "Not an OpenMP directive!"); @@ -2470,10 +2471,13 @@ BalancedDelimiterTracker T(*this, tok::l_paren, tok::annot_pragma_openmp_end); + while (Tok.isNot(tok::annot_pragma_openmp_end)) { + OpenMPClauseKind CKind = Tok.isAnnotation() ? OMPC_unknown : getOpenMPClauseKind(PP.getSpelling(Tok)); + SourceLocation Loc = ConsumeToken(); // Parse '('. @@ -2491,7 +2495,7 @@ return Directive; } - // Parse ':' + // Parse ':' // You have parsed the OpenMP Context in the meta directive if (Tok.is(tok::colon)) ConsumeAnyToken(); else { @@ -2499,8 +2503,10 @@ TPA.Commit(); return Directive; } - } + } // if (CKind == OMPC_when) statement ends + // Skip Directive for now. We will parse directive in the second iteration + // This need to be catched int paren = 0; while (Tok.isNot(tok::r_paren) || paren != 0) { if (Tok.is(tok::l_paren)) @@ -2513,86 +2519,105 @@ TPA.Commit(); return Directive; } - ConsumeAnyToken(); - } + ConsumeAnyToken(); + } // end of the while statement while (Tok.isNot(tok::r_paren) + // Parse ')' if (Tok.is(tok::r_paren)) T.consumeClose(); - + VariantMatchInfo VMI; TI.getAsVariantMatchInfo(ASTContext, VMI); - - VMIs.push_back(VMI); - } - + + if (CKind == OMPC_when ) + VMIs.push_back(VMI); + } // end of while (Tok.isNot(tok::annot_pragma_openmp_end)) + + // This is the end of the first iteration + // The pointer is moved back TPA.Revert(); // End of the first iteration. Parser is reset to the start of metadirective - + TargetOMPContext OMPCtx(ASTContext, /* DiagUnknownTrait */ nullptr, /* CurrentFunctionDecl */ nullptr, ArrayRef()); - - // A single match is returned for OpenMP 5.0 - int BestIdx = getBestVariantMatchForContext(VMIs, OMPCtx); - - int Idx = 0; - // In OpenMP 5.0 metadirective is either replaced by another directive or - // ignored. - // TODO: In OpenMP 5.1 generate multiple directives based upon the matches - // found by getBestWhenMatchForContext. - while (Tok.isNot(tok::annot_pragma_openmp_end)) { - // OpenMP 5.0 implementation - Skip to the best index found. - if (Idx++ != BestIdx) { - ConsumeToken(); // Consume clause name - T.consumeOpen(); // Consume '(' - int paren = 0; - // Skip everything inside the clause - while (Tok.isNot(tok::r_paren) || paren != 0) { - if (Tok.is(tok::l_paren)) - paren++; - if (Tok.is(tok::r_paren)) - paren--; - ConsumeAnyToken(); - } - // Parse ')' - if (Tok.is(tok::r_paren)) - T.consumeClose(); - continue; - } - + + // Array A will be used for sorting + SmallVector> A; + + // The function will get the score for each clause and sort it + // based on the score number + + getArrayVariantMatchForContext(VMIs, OMPCtx, A) ; + + ParseScope OMPDirectiveScope(this, ScopeFlags); + Actions.StartOpenMPDSABlock(DKind, DirName, Actions.getCurScope(), Loc); + + while(Tok.isNot(tok::annot_pragma_openmp_end)){ + OpenMPClauseKind CKind = Tok.isAnnotation() ? OMPC_unknown : getOpenMPClauseKind(PP.getSpelling(Tok)); - SourceLocation Loc = ConsumeToken(); - - // Parse '('. - T.consumeOpen(); - - // Skip ContextSelectors for when clause - if (CKind == OMPC_when) { - OMPTraitInfo &TI = Actions.getASTContext().getNewOMPTraitInfo(); - // parse and skip the ContextSelectors - parseOMPContextSelectors(Loc, TI); - - // Parse ':' - ConsumeAnyToken(); - } - - // If no directive is passed, skip in OpenMP 5.0. - // TODO: Generate nothing directive from OpenMP 5.1. - if (Tok.is(tok::r_paren)) { - SkipUntil(tok::annot_pragma_openmp_end); - break; - } - - // Parse Directive - ReadDirectiveWithinMetadirective = true; - Directive = ParseOpenMPDeclarativeOrExecutableDirective(StmtCtx); - ReadDirectiveWithinMetadirective = false; - break; - } - break; - } + + Actions.StartOpenMPClause(CKind); + OMPClause *Clause = ParseOpenMPMetaDirectiveClause( DKind, CKind); + + FirstClauses[(unsigned) CKind].setInt(true); + if (Clause) { + FirstClauses[(unsigned) CKind].setPointer(Clause); + Clauses.push_back(Clause); + }// end of if statement + + if (Tok.is(tok::comma)) + ConsumeToken(); + + Actions.EndOpenMPClause(); + + if (Tok.is(tok::r_paren)) + ConsumeAnyToken(); + + }// end of the while loop + + // End location of the directive + EndLoc = Tok.getLocation(); + + //Consume final annot_pragma_openmp_end + ConsumeAnnotationToken(); + + SmallVector Clauses_new; + unsigned count = 0; + + for ( auto &it1 : A){ + count = 0; + for ( auto &it2 : Clauses){ + if ( count == it1.first ){ + Clauses_new.push_back(it2); + break; + } else count++; + } + }// end of the for loop + + Clauses_new.push_back(Clauses.back()); + + // Parsing the OpenMP region which will take the + // metadirective + + Actions.ActOnOpenMPRegionStart(DKind, getCurScope()); + ParsingOpenMPDirectiveRAII NormalScope(*this, /*value=*/ false); + // This is parsing the region + StmtResult AStmt = ParseStatement(); + + StmtResult AssociatedStmt = (Sema::CompoundScopeRAII(Actions), AStmt); + // Ending of the parallel region + AssociatedStmt = Actions.ActOnOpenMPRegionEnd(AssociatedStmt, Clauses_new); + Directive = Actions.ActOnOpenMPExecutableDirective( + DKind, DirName, CancelRegion, Clauses_new, AssociatedStmt.get(), Loc, + EndLoc); + // Exit scope + Actions.EndOpenMPDSABlock(Directive.get()); + OMPDirectiveScope.Exit(); + break; + } // end of case OMPD_metadirective: case OMPD_threadprivate: { // FIXME: Should this be permitted in C++? if ((StmtCtx & ParsedStmtContext::AllowDeclarationsInC) == @@ -3050,6 +3075,174 @@ return Actions.ActOnOpenMPUsesAllocatorClause(Loc, T.getOpenLocation(), T.getCloseLocation(), Data); } +/// Parsing of OpenMP MetaDirective Clauses + +OMPClause *Parser::ParseOpenMPMetaDirectiveClause(OpenMPDirectiveKind DKind, + OpenMPClauseKind CKind) { + OMPClause *Clause = nullptr; + bool ErrorFound = false; + bool WrongDirective = false; + SmallVector, + llvm::omp::Clause_enumSize + 1> + FirstClauses(llvm::omp::Clause_enumSize + 1); + + // Check if it is called from metadirective. + if (DKind != OMPD_metadirective) { + Diag(Tok, diag::err_omp_unexpected_clause) + << getOpenMPClauseName(CKind) << getOpenMPDirectiveName(DKind); + ErrorFound = true; + } + + // Check if clause is allowed for the given directive. + if (CKind != OMPC_unknown && + !isAllowedClauseForDirective(DKind, CKind, getLangOpts().OpenMP)) { + Diag(Tok, diag::err_omp_unexpected_clause) + << getOpenMPClauseName(CKind) << getOpenMPDirectiveName(DKind); + ErrorFound = true; + WrongDirective = true; + } + + // Check if clause is not allowed + if (CKind == OMPC_unknown) { + Diag(Tok, diag::err_omp_unexpected_clause) + << getOpenMPClauseName(CKind) << "Unknown clause: Not allowed"; + ErrorFound = true; + WrongDirective = true; + } + + + if (CKind == OMPC_default || CKind == OMPC_when) { + SourceLocation Loc = ConsumeToken(); + SourceLocation DelimLoc; + // Parse '('. + BalancedDelimiterTracker T(*this, tok::l_paren, + tok::annot_pragma_openmp_end); + if (T.expectAndConsume(diag::err_expected_lparen_after, + getOpenMPClauseName(CKind).data())) + return nullptr; + + OMPTraitInfo &TI = Actions.getASTContext().getNewOMPTraitInfo(); + if (CKind == OMPC_when) { + // parse and get condition expression to pass to the When clause + parseOMPContextSelectors(Loc, TI); + + // Parse ':' + if (Tok.is(tok::colon)) + ConsumeAnyToken(); + else { + Diag(Tok, diag::warn_pragma_expected_colon) << "when clause"; + return nullptr; + } + } + + // Parse Directive + OpenMPDirectiveKind DirKind = OMPD_unknown; + SmallVector Clauses; + StmtResult AssociatedStmt; + StmtResult Directive = StmtError(); + + if (Tok.isNot(tok::r_paren)) { + ParsingOpenMPDirectiveRAII DirScope(*this); + ParenBraceBracketBalancer BalancerRAIIObj(*this); + DeclarationNameInfo DirName; + unsigned ScopeFlags = Scope::FnScope | Scope::DeclScope | + Scope::CompoundStmtScope | + Scope::OpenMPDirectiveScope; + + DirKind = parseOpenMPDirectiveKind(*this); + ConsumeToken(); + ParseScope OMPDirectiveScope(this, ScopeFlags); + Actions.StartOpenMPDSABlock(DirKind, DirName, Actions.getCurScope(), Loc); + + int paren = 0; + + while (Tok.isNot(tok::r_paren) || paren != 0) { + if (Tok.is(tok::l_paren)) + paren++; + if (Tok.is(tok::r_paren)) + paren--; + + OpenMPClauseKind CKind = Tok.isAnnotation() + ? OMPC_unknown + : getOpenMPClauseKind(PP.getSpelling(Tok)); + + if (CKind == OMPC_unknown && + !isAllowedClauseForDirective(DirKind, CKind, getLangOpts().OpenMP)) { + Diag(Tok, diag::err_omp_unexpected_clause) + << getOpenMPClauseName(CKind) << getOpenMPDirectiveName(DKind); + ErrorFound = true; + WrongDirective = true; + } + + + Actions.StartOpenMPClause(CKind); + + OMPClause *Clause = ParseOpenMPClause( + DirKind, CKind, !FirstClauses[(unsigned)CKind].getInt()); + FirstClauses[(unsigned)CKind].setInt(true); + if (Clause) { + FirstClauses[(unsigned)CKind].setPointer(Clause); + Clauses.push_back(Clause); + } + + // Skip ',' if any. + if (Tok.is(tok::comma)) + ConsumeToken(); + Actions.EndOpenMPClause(); + + + } + + Actions.ActOnOpenMPRegionStart(DirKind, getCurScope()); + ParsingOpenMPDirectiveRAII NormalScope(*this, /*Value=*/false); + + /* Get Stmt and revert back */ + TentativeParsingAction TPA(*this); + while (Tok.isNot(tok::annot_pragma_openmp_end)) { + ConsumeAnyToken(); + } + + ConsumeAnnotationToken(); + ParseScope InnerStmtScope(this, Scope::DeclScope, + getLangOpts().C99 || getLangOpts().CPlusPlus, + Tok.is(tok::l_brace)); + + StmtResult AStmt = ParseStatement(); + InnerStmtScope.Exit(); + TPA.Revert(); + /* End Get Stmt */ + + AssociatedStmt = (Sema::CompoundScopeRAII(Actions), AStmt); + AssociatedStmt = Actions.ActOnOpenMPRegionEnd(AssociatedStmt, Clauses); + + Directive = Actions.ActOnOpenMPExecutableDirective( + DirKind, DirName, OMPD_unknown, llvm::makeArrayRef(Clauses), + AssociatedStmt.get(), Loc, Tok.getLocation()); + + + Actions.EndOpenMPDSABlock(Directive.get()); + OMPDirectiveScope.Exit(); + + } + + // Parse ')' + T.consumeClose(); + + if (WrongDirective) + return nullptr; + + Clause = Actions.ActOnOpenMPWhenClause(TI, DirKind, Directive, Loc, + DelimLoc, Tok.getLocation()); + } else { + ErrorFound = false; + Diag(Tok, diag::err_omp_unexpected_clause) + << getOpenMPClauseName(CKind) << getOpenMPDirectiveName(DKind); + } + + return ErrorFound ? nullptr : Clause; +} + + /// Parsing of OpenMP clauses. /// diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -37,6 +37,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" +#include "llvm/Frontend/OpenMP/OMPContext.h" #include using namespace clang; @@ -3930,6 +3931,7 @@ void Sema::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind, Scope *CurScope) { switch (DKind) { + case OMPD_metadirective: // added case OMPD_parallel: case OMPD_parallel_for: case OMPD_parallel_for_simd: @@ -4339,8 +4341,8 @@ case OMPD_declare_variant: case OMPD_begin_declare_variant: case OMPD_end_declare_variant: - case OMPD_metadirective: - llvm_unreachable("OpenMP Directive is not allowed"); + //case OMPD_metadirective: + //llvm_unreachable("OpenMP Directive is not allowed"); case OMPD_unknown: default: llvm_unreachable("Unknown OpenMP directive"); @@ -4522,6 +4524,7 @@ StmtResult Sema::ActOnOpenMPRegionEnd(StmtResult S, ArrayRef Clauses) { + handleDeclareVariantConstructTrait(DSAStack, DSAStack->getCurrentDirective(), /* ScopeEntry */ false); if (DSAStack->getCurrentDirective() == OMPD_atomic || @@ -4590,6 +4593,7 @@ else if (Clause->getClauseKind() == OMPC_linear) LCs.push_back(cast(Clause)); } + // Capture allocator expressions if used. for (Expr *E : DSAStack->getInnerAllocators()) MarkDeclarationsReferencedInExpr(E); @@ -4611,6 +4615,7 @@ << SourceRange(OC->getBeginLoc(), OC->getEndLoc()); ErrorFound = true; } + // OpenMP 5.0, 2.9.2 Worksharing-Loop Construct, Restrictions. // If an order(concurrent) clause is present, an ordered clause may not appear // on the same directive. @@ -4623,6 +4628,7 @@ } ErrorFound = true; } + if (isOpenMPWorksharingDirective(DSAStack->getCurrentDirective()) && isOpenMPSimdDirective(DSAStack->getCurrentDirective()) && OC && OC->getNumForLoops()) { @@ -4635,7 +4641,9 @@ } StmtResult SR = S; unsigned CompletedRegions = 0; + for (OpenMPDirectiveKind ThisCaptureRegion : llvm::reverse(CaptureRegions)) { + // Mark all variables in private list clauses as used in inner region. // Required for proper codegen of combined directives. // TODO: add processing for other clauses. @@ -4656,6 +4664,7 @@ } } } + if (ThisCaptureRegion == OMPD_target) { // Capture allocator traits in the target region. They are used implicitly // and, thus, are not captured by default. @@ -4671,6 +4680,7 @@ } } } + if (ThisCaptureRegion == OMPD_parallel) { // Capture temp arrays for inscan reductions and locals in aligned // clauses. @@ -4687,10 +4697,14 @@ } } } + if (++CompletedRegions == CaptureRegions.size()) DSAStack->setBodyComplete(); + SR = ActOnCapturedRegionEnd(SR.get()); + } + return SR; } @@ -5963,6 +5977,12 @@ llvm::SmallVector AllowedNameModifiers; switch (Kind) { + + case OMPD_metadirective: + Res = ActOnOpenMPMetaDirective(ClausesWithImplicit, AStmt, StartLoc, + EndLoc); + AllowedNameModifiers.push_back(OMPD_metadirective); + break; case OMPD_parallel: Res = ActOnOpenMPParallelDirective(ClausesWithImplicit, AStmt, StartLoc, EndLoc); @@ -7341,7 +7361,120 @@ AppendArgs.size(), SR); FD->addAttr(NewAttr); } +/// + +StmtResult Sema::ActOnOpenMPMetaDirective(ArrayRef Clauses, + Stmt *AStmt, + SourceLocation StartLoc, + SourceLocation EndLoc) { + + if (!AStmt) + return StmtError(); + + auto *CS = cast(AStmt); + // + CS->getCapturedDecl()->setNothrow(); + + StmtResult IfStmt = StmtError(); + Stmt *ElseStmt = NULL; + + for (auto i = Clauses.rbegin(); i < Clauses.rend(); i++) { + OMPWhenClause *WhenClause = dyn_cast(*i); + Expr *WhenCondExpr = NULL; + Stmt *ThenStmt = NULL; + OpenMPDirectiveKind DKind = WhenClause->getDKind(); + + if (DKind != OMPD_unknown) + ThenStmt = CompoundStmt::Create(Context, {WhenClause->getDirective()}, + SourceLocation(), SourceLocation()); + + for (const OMPTraitSet &Set : WhenClause->getTI().Sets) { + for (const OMPTraitSelector &Selector : Set.Selectors) { + switch (Selector.Kind) { + case TraitSelector::device_arch: { + bool archMatch = false; + for (const OMPTraitProperty &Property : Selector.Properties) { + for (auto &T : getLangOpts().OMPTargetTriples) { + if (T.getArchName() == Property.RawString) { + archMatch = true; + break; + } + } + if (archMatch) + break; + } + // Create a true/false boolean expression and assign to WhenCondExpr + auto *C = new (Context) + CXXBoolLiteralExpr(archMatch, Context.BoolTy, StartLoc); + WhenCondExpr = dyn_cast(C); + break; + } + case TraitSelector::user_condition: { + assert(Selector.ScoreOrCondition && + "Ill-formed user condition, expected condition expression!"); + + WhenCondExpr = Selector.ScoreOrCondition; + break; + } + case TraitSelector::implementation_vendor: { + bool vendorMatch = false; + for (const OMPTraitProperty &Property : Selector.Properties) { + for (auto &T : getLangOpts().OMPTargetTriples) { + if (T.getVendorName() == Property.RawString) { + vendorMatch = true; + break; + } + } + if (vendorMatch) + break; + } + // Create a true/false boolean expression and assign to WhenCondExpr + auto *C = new (Context) + CXXBoolLiteralExpr(vendorMatch, Context.BoolTy, StartLoc); + WhenCondExpr = dyn_cast(C); + break; + } + case TraitSelector::device_isa: + case TraitSelector::device_kind: + case TraitSelector::implementation_extension: + default: + break; + } + } + } //// + + if (WhenCondExpr == NULL) { + if (ElseStmt != NULL) { + Diag(WhenClause->getBeginLoc(), diag::err_omp_misplaced_default_clause); + return StmtError(); + } + if (DKind == OMPD_unknown) + ElseStmt = CompoundStmt::Create(Context, {CS->getCapturedStmt()}, + SourceLocation(), SourceLocation()); + else + ElseStmt = ThenStmt; + continue; + } + + if (ThenStmt == NULL) + ThenStmt = CompoundStmt::Create(Context, {CS->getCapturedStmt()}, + SourceLocation(), SourceLocation()); + + IfStmt = + ActOnIfStmt(SourceLocation(), /*false*/ IfStatementKind::Ordinary, SourceLocation(), NULL, + ActOnCondition(getCurScope(), SourceLocation(), + WhenCondExpr, Sema::ConditionKind::Boolean), + SourceLocation(), ThenStmt, SourceLocation(), ElseStmt); + ElseStmt = IfStmt.get(); + } + + return OMPMetaDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt, + IfStmt.get()); + +} + +/// StmtResult Sema::ActOnOpenMPParallelDirective(ArrayRef Clauses, Stmt *AStmt, SourceLocation StartLoc, @@ -14837,6 +14970,19 @@ return std::string(Out.str()); } +/// ActOnOpenMPWheClause --- Abid + +OMPClause * +Sema::ActOnOpenMPWhenClause(OMPTraitInfo &TI, OpenMPDirectiveKind DKind, + StmtResult Directive, SourceLocation StartLoc, + SourceLocation LParenLoc, SourceLocation EndLoc) { + return new (Context) + OMPWhenClause(TI, DKind, Directive.get(), StartLoc, LParenLoc, EndLoc); +} + + + + OMPClause *Sema::ActOnOpenMPDefaultClause(DefaultKind Kind, SourceLocation KindKwLoc, SourceLocation StartLoc, diff --git a/clang/lib/Sema/SemaStmt.cpp b/clang/lib/Sema/SemaStmt.cpp --- a/clang/lib/Sema/SemaStmt.cpp +++ b/clang/lib/Sema/SemaStmt.cpp @@ -4791,8 +4791,9 @@ CapturedStmt *Res = CapturedStmt::Create( getASTContext(), S, static_cast(RSI->CapRegionKind), Captures, CaptureInits, CD, RD); - + CD->setBody(Res->getCapturedStmt()); + RD->completeDefinition(); return Res; diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPContext.h b/llvm/include/llvm/Frontend/OpenMP/OMPContext.h --- a/llvm/include/llvm/Frontend/OpenMP/OMPContext.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPContext.h @@ -189,6 +189,15 @@ int getBestVariantMatchForContext(const SmallVectorImpl &VMIs, const OMPContext &Ctx); +/// Sort array \p A of clause index with score +/// This will be used to produce AST clauses +/// in a sorted order with the clause with the highiest order +/// on the top and default clause at the bottom +void getArrayVariantMatchForContext( + const SmallVectorImpl &VMIs, const OMPContext &Ctx, + SmallVector> &A); + +// new-- } // namespace omp template <> struct DenseMapInfo { diff --git a/llvm/lib/Frontend/OpenMP/OMPContext.cpp b/llvm/lib/Frontend/OpenMP/OMPContext.cpp --- a/llvm/lib/Frontend/OpenMP/OMPContext.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPContext.cpp @@ -20,6 +20,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include #define DEBUG_TYPE "openmp-ir-builder" using namespace llvm; @@ -339,6 +340,53 @@ return Score; } +/// +/// Takes \p VMI and \p Ctx and sort the +/// scores using \p A +void llvm::omp::getArrayVariantMatchForContext(const SmallVectorImpl &VMIs, + const OMPContext &Ctx, SmallVector> &A){ + + //APInt BestScore(64, 0); + APInt Score (64, 0); + llvm::DenseMap m; + + /*for (unsigned u = 0, e = VMIs.size(); u < e; ++u) { + const VariantMatchInfo &VMI = VMIs[u]; + SmallVector ConstructMatches; + APInt Score = getVariantMatchScore(VMI, Ctx, ConstructMatches); + m.insert({u, Score}); + } + + */ + + for (unsigned u = 0, e = VMIs.size(); u < e; ++u) { + const VariantMatchInfo &VMI = VMIs[u]; + + SmallVector ConstructMatches; + // If the variant is not applicable its not the best. + if (!isVariantApplicableInContextHelper(VMI, Ctx, &ConstructMatches, + /* DeviceSetOnly */ false)){ + Score = 0; + m.insert({u, Score}); + continue; + } + + // Check if its clearly not the best. + Score = getVariantMatchScore(VMI, Ctx, ConstructMatches); + m.insert({u, Score}); + } + + for (auto& it : m) + A.push_back(it); + + std::sort(A.begin(), A.end(), [] (std::pair&a, + std::pair&b){ + return a.second.ugt(b.second); + }); +} + + + int llvm::omp::getBestVariantMatchForContext( const SmallVectorImpl &VMIs, const OMPContext &Ctx) {