diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h --- a/clang/include/clang/AST/OpenMPClause.h +++ b/clang/include/clang/AST/OpenMPClause.h @@ -8734,25 +8734,7 @@ template class ConstOMPClauseVisitor : public OMPClauseVisitorBase {}; - -class OMPClausePrinter final : public OMPClauseVisitor { - raw_ostream &OS; - const PrintingPolicy &Policy; - - /// Process clauses with list of variables. - template void VisitOMPClauseList(T *Node, char StartSym); - /// Process motion clauses. - template void VisitOMPMotionClause(T *Node); - -public: - OMPClausePrinter(raw_ostream &OS, const PrintingPolicy &Policy) - : OS(OS), Policy(Policy) {} - -#define GEN_CLANG_CLAUSE_CLASS -#define CLAUSE_CLASS(Enum, Str, Class) void Visit##Class(Class *S); -#include "llvm/Frontend/OpenMP/OMP.inc" -}; - + struct OMPTraitProperty { llvm::omp::TraitProperty Kind = llvm::omp::TraitProperty::invalid; @@ -8993,6 +8975,98 @@ } }; +/// This represents a 'when' or 'default' clause of '#pragma omp metadirective'. +/// \code +/// #pragma omp metadirective when(user={condition(N<100)}:parallel for) +/// \endcode +/// In the above example the 'when' clause has a user condition; when it is +/// satisfied, the code enclosed by the metadirective is executed as a +/// 'parallel for' directive. +class OMPWhenClause final : public OMPClause { + friend class OMPClauseReader; + friend class OMPExecutableDirective; + template friend class OMPDeclarativeDirective; + + OMPTraitInfo *TI = nullptr; ///< Context selector; has no sets for 'default'. + OpenMPDirectiveKind DKind = llvm::omp::OMPD_unknown; ///< Selected directive. + Stmt *Directive = nullptr; ///< Selected directive statement, may be null. + + /// Location of '('. + SourceLocation LParenLoc; + + /// Sets the location of '('. + /// + /// \param Loc Location of '('. + void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; } + +public: + /// Build a 'when' (or 'default') clause.
+ /// + /// \param T OMPTraitInfo containing information about the context selector + /// \param DirKind The directive associated with the when clause + /// \param D The statement associated with the when clause + /// \param StartLoc Starting location of the clause. + /// \param LParenLoc Location of '('. + /// \param EndLoc Ending location of the clause. + OMPWhenClause(OMPTraitInfo &T, OpenMPDirectiveKind DirKind, Stmt *D, + SourceLocation StartLoc, SourceLocation LParenLoc, + SourceLocation EndLoc) + : OMPClause(llvm::omp::OMPC_when, StartLoc, EndLoc), TI(&T), DKind(DirKind), + Directive(D), LParenLoc(LParenLoc) {} + + /// Build an empty clause (all members keep their default initializers). + OMPWhenClause() + : OMPClause(llvm::omp::OMPC_when, SourceLocation(), SourceLocation()) {} + + /// Returns the location of '('. + SourceLocation getLParenLoc() const { return LParenLoc; } + + /// Returns the kind of the directive selected by this clause. + OpenMPDirectiveKind getDKind() const { return DKind; } + + Stmt *getDirective() const { return Directive; } + + /// Returns the context selector; must not be called on an empty clause. + OMPTraitInfo &getTI() const { assert(TI && "empty 'when' clause has no trait info"); return *TI; } + + child_range children() { + return child_range(child_iterator(), child_iterator()); + } + + const_child_range children() const { + return const_child_range(const_child_iterator(), const_child_iterator()); + } + child_range used_children() { + return child_range(child_iterator(), child_iterator()); + } + const_child_range used_children() const { + return const_child_range(const_child_iterator(), const_child_iterator()); + } + + static bool classof(const OMPClause *T) { + return T->getClauseKind() == llvm::omp::OMPC_when; + } +}; + +class OMPClausePrinter final : public OMPClauseVisitor { + raw_ostream &OS; + const PrintingPolicy &Policy; + + /// Process clauses with list of variables. + template void VisitOMPClauseList(T *Node, char StartSym); + /// Process motion clauses.
+ template void VisitOMPMotionClause(T *Node); + +public: + OMPClausePrinter(raw_ostream &OS, const PrintingPolicy &Policy) + : OS(OS), Policy(Policy) {} + + void VisitOMPWhenClause(OMPWhenClause *Node); + +#define GEN_CLANG_CLAUSE_CLASS +#define CLAUSE_CLASS(Enum, Str, Class) void Visit##Class(Class *S); +#include "llvm/Frontend/OpenMP/OMP.inc" +}; } // namespace clang #endif // LLVM_CLANG_AST_OPENMPCLAUSE_H diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -499,7 +499,8 @@ /// Process clauses with pre-initis. bool VisitOMPClauseWithPreInit(OMPClauseWithPreInit *Node); bool VisitOMPClauseWithPostUpdate(OMPClauseWithPostUpdate *Node); - + bool VisitOMPWhenClause(OMPWhenClause *C); + bool PostVisitStmt(Stmt *S); }; @@ -3273,6 +3274,18 @@ return true; } +template +bool RecursiveASTVisitor::VisitOMPWhenClause(OMPWhenClause *C) { + for (const OMPTraitSet &Set : C->getTI().Sets) { + for (const OMPTraitSelector &Selector : Set.Selectors) { + if (Selector.Kind == llvm::omp::TraitSelector::user_condition && + Selector.ScoreOrCondition) + TRY_TO(TraverseStmt(Selector.ScoreOrCondition)); + } + } + return true; +} + template bool RecursiveASTVisitor::VisitOMPDefaultClause(OMPDefaultClause *) { return true; diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h --- a/clang/include/clang/AST/StmtOpenMP.h +++ b/clang/include/clang/AST/StmtOpenMP.h @@ -300,7 +300,7 @@ template static T *createDirective(const ASTContext &C, ArrayRef Clauses, Stmt *AssociatedStmt, unsigned NumChildren, - Params &&... P) { + Params &&...P) { void *Mem = C.Allocate(sizeof(T) + OMPChildren::size(Clauses.size(), AssociatedStmt, NumChildren), @@ -316,7 +316,7 @@ template static T *createEmptyDirective(const ASTContext &C, unsigned NumClauses, bool HasAssociatedStmt, unsigned NumChildren, - Params &&...
P) { + Params &&...P) { void *Mem = C.Allocate(sizeof(T) + OMPChildren::size(NumClauses, HasAssociatedStmt, NumChildren), @@ -478,8 +478,7 @@ /// Returns true if the current directive has one or more clauses of a /// specific kind. - template - bool hasClausesOfKind() const { + template bool hasClausesOfKind() const { auto Clauses = getClausesOfKind(); return Clauses.begin() != Clauses.end(); } @@ -2711,7 +2710,6 @@ static OMPTaskgroupDirective *CreateEmpty(const ASTContext &C, unsigned NumClauses, EmptyShell); - /// Returns reference to the task_reduction return variable. const Expr *getReductionRef() const { return const_cast(this)->getReductionRef(); @@ -4621,10 +4619,10 @@ /// \param AssociatedStmt Statement, associated with the directive. /// \param Exprs Helper expressions for CodeGen. /// - static OMPDistributeParallelForSimdDirective *Create( - const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, - unsigned CollapsedNum, ArrayRef Clauses, - Stmt *AssociatedStmt, const HelperExprs &Exprs); + static OMPDistributeParallelForSimdDirective * + Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, + unsigned CollapsedNum, ArrayRef Clauses, + Stmt *AssociatedStmt, const HelperExprs &Exprs); /// Creates an empty directive with the place for \a NumClauses clauses. /// @@ -4632,9 +4630,9 @@ /// \param CollapsedNum Number of collapsed nested loops. /// \param NumClauses Number of clauses.
/// - static OMPDistributeParallelForSimdDirective *CreateEmpty( - const ASTContext &C, unsigned NumClauses, unsigned CollapsedNum, - EmptyShell); + static OMPDistributeParallelForSimdDirective * + CreateEmpty(const ASTContext &C, unsigned NumClauses, unsigned CollapsedNum, + EmptyShell); static bool classof(const Stmt *T) { return T->getStmtClass() == OMPDistributeParallelForSimdDirectiveClass; @@ -4831,8 +4829,7 @@ /// static OMPTargetSimdDirective *CreateEmpty(const ASTContext &C, unsigned NumClauses, - unsigned CollapsedNum, - EmptyShell); + unsigned CollapsedNum, EmptyShell); static bool classof(const Stmt *T) { return T->getStmtClass() == OMPTargetSimdDirectiveClass; @@ -5169,11 +5166,9 @@ /// \param Clauses List of clauses. /// \param AssociatedStmt Statement, associated with the directive. /// - static OMPTargetTeamsDirective *Create(const ASTContext &C, - SourceLocation StartLoc, - SourceLocation EndLoc, - ArrayRef Clauses, - Stmt *AssociatedStmt); + static OMPTargetTeamsDirective * + Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, + ArrayRef Clauses, Stmt *AssociatedStmt); /// Creates an empty directive with the place for \a NumClauses clauses. /// @@ -5244,9 +5239,10 @@ /// \param CollapsedNum Number of collapsed nested loops. /// \param NumClauses Number of clauses.
/// - static OMPTargetTeamsDistributeDirective * - CreateEmpty(const ASTContext &C, unsigned NumClauses, unsigned CollapsedNum, - EmptyShell); + static OMPTargetTeamsDistributeDirective *CreateEmpty(const ASTContext &C, + unsigned NumClauses, + unsigned CollapsedNum, + EmptyShell); static bool classof(const Stmt *T) { return T->getStmtClass() == OMPTargetTeamsDistributeDirectiveClass; diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -11086,6 +11086,8 @@ def warn_omp_unterminated_declare_target : Warning< "expected '#pragma omp end declare target' at end of file to match '#pragma omp %0'">, InGroup; +def err_omp_misplaced_default_clause : Error< + "only one 'default' clause is allowed">; } // end of OpenMP category let CategoryName = "Related Result Type Issue" in { diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h --- a/clang/include/clang/Parse/Parser.h +++ b/clang/include/clang/Parse/Parser.h @@ -3367,6 +3367,14 @@ /// metadirective and therefore ends on the closing paren. StmtResult ParseOpenMPDeclarativeOrExecutableDirective( ParsedStmtContext StmtCtx, bool ReadDirectiveWithinMetadirective = false); + + /// Parse a 'when' or 'default' clause of a metadirective. + /// + /// \param DKind Kind of current directive + /// \param CKind Kind of current clause + /// \returns the parsed clause, or null on error. + OMPClause *ParseOpenMPMetaDirectiveClause(OpenMPDirectiveKind DKind, + OpenMPClauseKind CKind); /// Parses clause of kind \a CKind for directive of a kind \a Kind. /// /// \param DKind Kind of current directive.
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -66,6 +66,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TinyPtrVector.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" +#include "llvm/Frontend/OpenMP/OMPContext.h" #include #include #include @@ -11202,10 +11203,18 @@ /// /// \returns Statement for finished OpenMP region. StmtResult ActOnOpenMPRegionEnd(StmtResult S, ArrayRef Clauses); + + /// Called on a well-formed OpenMP executable directive. StmtResult ActOnOpenMPExecutableDirective( OpenMPDirectiveKind Kind, const DeclarationNameInfo &DirName, OpenMPDirectiveKind CancelRegion, ArrayRef Clauses, Stmt *AStmt, SourceLocation StartLoc, SourceLocation EndLoc); + + /// Called on a well-formed '\#pragma omp metadirective'. NOTE(review): no definition or caller of this function is visible in this patch, while SemaOpenMP.cpp defines Sema::ActOnOpenMPMetaDirective, which is not declared in this hunk; confirm which declaration is intended. + StmtResult ActOnOpenMPExecutableMetaDirective( + OpenMPDirectiveKind Kind, const DeclarationNameInfo &DirName, + OpenMPDirectiveKind CancelRegion, ArrayRef Clauses, + Stmt *AStmt, SourceLocation StartLoc, SourceLocation EndLoc); /// Called on well-formed '\#pragma omp parallel' after parsing /// of the associated statement. StmtResult ActOnOpenMPParallelDirective(ArrayRef Clauses, @@ -11708,7 +11717,9 @@ SourceLocation LParenLoc, SourceLocation EndLoc); /// Called on well-formed 'when' clause. - OMPClause *ActOnOpenMPWhenClause(OMPTraitInfo &TI, SourceLocation StartLoc, + OMPClause *ActOnOpenMPWhenClause(OMPTraitInfo &TI, OpenMPDirectiveKind DKind, + StmtResult Directive, + SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc); /// Called on well-formed 'default' clause.
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp --- a/clang/lib/AST/OpenMPClause.cpp +++ b/clang/lib/AST/OpenMPClause.cpp @@ -15,6 +15,7 @@ #include "clang/AST/Attr.h" #include "clang/AST/Decl.h" #include "clang/AST/DeclOpenMP.h" +#include "clang/AST/StmtOpenMP.h" #include "clang/Basic/LLVM.h" #include "clang/Basic/OpenMPKinds.h" #include "clang/Basic/TargetInfo.h" @@ -180,7 +181,8 @@ return Res ? const_cast(Res) : nullptr; } -const OMPClauseWithPostUpdate *OMPClauseWithPostUpdate::get(const OMPClause *C) { +const OMPClauseWithPostUpdate * +OMPClauseWithPostUpdate::get(const OMPClause *C) { switch (C->getClauseKind()) { case OMPC_lastprivate: return static_cast(C); @@ -610,7 +612,8 @@ unsigned NumVars) { // Allocate space for 5 lists (Vars, Inits, Updates, Finals), 2 expressions // (Step and CalcStep), list of used expression + step. - void *Mem = C.Allocate(totalSizeToAlloc(5 * NumVars + 2 + NumVars +1)); + void *Mem = + C.Allocate(totalSizeToAlloc(5 * NumVars + 2 + NumVars + 1)); return new (Mem) OMPLinearClause(NumVars); } @@ -1059,7 +1062,7 @@ Clause->setOmpAllMemoryLoc(Data.OmpAllMemoryLoc); Clause->setModifier(DepModifier); Clause->setVarRefs(VL); - for (unsigned I = 0 ; I < NumLoops; ++I) + for (unsigned I = 0; I < NumLoops; ++I) Clause->setLoopData(I, nullptr); return Clause; } @@ -1667,6 +1670,31 @@ // OpenMP clauses printing methods //===----------------------------------------------------------------------===// +void OMPClausePrinter::VisitOMPWhenClause(OMPWhenClause *Node) { + OMPTraitInfo &TI = Node->getTI(); + if (TI.Sets.empty()) + OS << "default("; + else + OS << "when("; + TI.print(OS, Policy); + + if (Stmt *S = Node->getDirective()) { + OS << ":"; + OMPExecutableDirective *D = cast(S); + auto DKind = D->getDirectiveKind(); + OS << getOpenMPDirectiveName(DKind); + + OMPClausePrinter Printer(OS, Policy); + ArrayRef Clauses = D->clauses(); + for (auto *Clause : Clauses) + if (Clause && !Clause->isImplicit()) { + OS << ' 
'; + Printer.Visit(Clause); + } + } + OS << ")"; +} + void OMPClausePrinter::VisitOMPIfClause(OMPIfClause *Node) { OS << "if("; if (Node->getNameModifier() != OMPD_unknown) @@ -2011,7 +2039,7 @@ } } -template +template void OMPClausePrinter::VisitOMPClauseList(T *Node, char StartSym) { for (typename T::varlist_iterator I = Node->varlist_begin(), E = Node->varlist_end(); @@ -2307,8 +2335,9 @@ } void OMPClausePrinter::VisitOMPDistScheduleClause(OMPDistScheduleClause *Node) { - OS << "dist_schedule(" << getOpenMPSimpleClauseTypeName( - OMPC_dist_schedule, Node->getDistScheduleKind()); + OS << "dist_schedule(" + << getOpenMPSimpleClauseTypeName(OMPC_dist_schedule, + Node->getDistScheduleKind()); if (auto *E = Node->getChunkSize()) { OS << ", "; E->printPretty(OS, nullptr, Policy); @@ -2353,7 +2382,8 @@ } } -void OMPClausePrinter::VisitOMPHasDeviceAddrClause(OMPHasDeviceAddrClause *Node) { +void OMPClausePrinter::VisitOMPHasDeviceAddrClause( + OMPHasDeviceAddrClause *Node) { if (!Node->varlist_empty()) { OS << "has_device_addr"; VisitOMPClauseList(Node, '('); @@ -2483,8 +2513,7 @@ // TODO: This might not hold once we implement SIMD properly. 
assert(Selector.Properties.size() == 1 && Selector.Properties.front().Kind == - getOpenMPContextTraitPropertyForSelector( - Selector.Kind) && + getOpenMPContextTraitPropertyForSelector(Selector.Kind) && "Ill-formed construct selector!"); } } @@ -2508,8 +2537,8 @@ bool AllowsTraitScore = false; bool RequiresProperty = false; - isValidTraitSelectorForTraitSet( - Selector.Kind, Set.Kind, AllowsTraitScore, RequiresProperty); + isValidTraitSelectorForTraitSet(Selector.Kind, Set.Kind, AllowsTraitScore, + RequiresProperty); if (!RequiresProperty) continue; @@ -2552,12 +2581,11 @@ bool AllowsTraitScore = false; bool RequiresProperty = false; - isValidTraitSelectorForTraitSet( - Selector.Kind, Set.Kind, AllowsTraitScore, RequiresProperty); + isValidTraitSelectorForTraitSet(Selector.Kind, Set.Kind, AllowsTraitScore, + RequiresProperty); OS << '$' << 's' << unsigned(Selector.Kind); - if (!RequiresProperty || - Selector.Kind == TraitSelector::user_condition) + if (!RequiresProperty || Selector.Kind == TraitSelector::user_condition) continue; for (const OMPTraitProperty &Property : Selector.Properties) diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp --- a/clang/lib/AST/StmtPrinter.cpp +++ b/clang/lib/AST/StmtPrinter.cpp @@ -719,11 +719,20 @@ bool ForceNoStmt) { OMPClausePrinter Printer(OS, Policy); ArrayRef Clauses = S->clauses(); - for (auto *Clause : Clauses) - if (Clause && !Clause->isImplicit()) { + for (auto *Clause : Clauses){ + if (Clause && !Clause->isImplicit()){ OS << ' '; Printer.Visit(Clause); - } + if (isa(S)){ + OMPWhenClause *WhenClause = dyn_cast(Clause); + if (WhenClause!=nullptr){ + if (WhenClause->getDKind() != llvm::omp::OMPD_unknown){ + Printer.VisitOMPWhenClause(WhenClause); + } + } + } + } + } OS << NL; if (!ForceNoStmt && S->hasAssociatedStmt()) PrintStmt(S->getRawStmt()); diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp --- a/clang/lib/Parse/ParseOpenMP.cpp +++ b/clang/lib/Parse/ParseOpenMP.cpp @@ 
-2536,10 +2536,12 @@ BalancedDelimiterTracker T(*this, tok::l_paren, tok::annot_pragma_openmp_end); + while (Tok.isNot(tok::annot_pragma_openmp_end)) { OpenMPClauseKind CKind = Tok.isAnnotation() ? OMPC_unknown : getOpenMPClauseKind(PP.getSpelling(Tok)); + SourceLocation Loc = ConsumeToken(); // Parse '('. @@ -2557,7 +2559,7 @@ return Directive; } - // Parse ':' + // Parse ':' // You have parsed the OpenMP Context in the meta directive if (Tok.is(tok::colon)) ConsumeAnyToken(); else { @@ -2566,6 +2568,7 @@ return Directive; } } + // Skip Directive for now. We will parse directive in the second iteration int paren = 0; while (Tok.isNot(tok::r_paren) || paren != 0) { @@ -2579,62 +2582,56 @@ TPA.Commit(); return Directive; } - ConsumeAnyToken(); - } + ConsumeAnyToken(); + } + // Parse ')' if (Tok.is(tok::r_paren)) T.consumeClose(); - + VariantMatchInfo VMI; TI.getAsVariantMatchInfo(ASTContext, VMI); - - VMIs.push_back(VMI); - } - + + if (CKind == OMPC_when ) + VMIs.push_back(VMI); + } + + // This is the end of the first iteration + // The pointer is moved back TPA.Revert(); // End of the first iteration. Parser is reset to the start of metadirective + std::function DiagUnknownTrait = [this, Loc](StringRef ISATrait) { // TODO Track the selector locations in a way that is accessible here // to improve the diagnostic location. Diag(Loc, diag::warn_unknown_declare_variant_isa_trait) << ISATrait; }; - TargetOMPContext OMPCtx(ASTContext, std::move(DiagUnknownTrait), + // TargetOMPContext OMPCtx(ASTContext, std::move(DiagUnknownTrait), + + TargetOMPContext OMPCtx(ASTContext, /* DiagUnknownTrait */ nullptr, /* CurrentFunctionDecl */ nullptr, ArrayRef()); - - // A single match is returned for OpenMP 5.0 - int BestIdx = getBestVariantMatchForContext(VMIs, OMPCtx); - - int Idx = 0; - // In OpenMP 5.0 metadirective is either replaced by another directive or - // ignored. 
- // TODO: In OpenMP 5.1 generate multiple directives based upon the matches - // found by getBestWhenMatchForContext. - while (Tok.isNot(tok::annot_pragma_openmp_end)) { - // OpenMP 5.0 implementation - Skip to the best index found. - if (Idx++ != BestIdx) { - ConsumeToken(); // Consume clause name - T.consumeOpen(); // Consume '(' - int paren = 0; - // Skip everything inside the clause - while (Tok.isNot(tok::r_paren) || paren != 0) { - if (Tok.is(tok::l_paren)) - paren++; - if (Tok.is(tok::r_paren)) - paren--; - ConsumeAnyToken(); - } - // Parse ')' - if (Tok.is(tok::r_paren)) - T.consumeClose(); - continue; - } - + + // Array SortedCluases will be used for sorting clauses + // based on the context selector score + SmallVector> SortedCluases; + + // The function will get the score for each clause and sort it + // based on the score number + + getArrayVariantMatchForContext(VMIs, OMPCtx, SortedCluases) ; + + ParseScope OMPDirectiveScope(this, ScopeFlags); + Actions.StartOpenMPDSABlock(DKind, DirName, Actions.getCurScope(), Loc); + + while(Tok.isNot(tok::annot_pragma_openmp_end)){ + OpenMPClauseKind CKind = Tok.isAnnotation() ? OMPC_unknown : getOpenMPClauseKind(PP.getSpelling(Tok)); + SourceLocation Loc = ConsumeToken(); // Parse '('. 
@@ -2662,9 +2659,71 @@ StmtCtx, /*ReadDirectiveWithinMetadirective=*/true); break; + + Actions.StartOpenMPClause(CKind); + OMPClause *Clause = ParseOpenMPMetaDirectiveClause( DKind, CKind); + + FirstClauses[(unsigned) CKind].setInt(true); + if (Clause) { + FirstClauses[(unsigned) CKind].setPointer(Clause); + Clauses.push_back(Clause); + } + + if (Tok.is(tok::comma)) + ConsumeToken(); + + Actions.EndOpenMPClause(); + + if (Tok.is(tok::r_paren)) + ConsumeAnyToken(); + } - break; - } + + // End location of the directive + EndLoc = Tok.getLocation(); + + //Consume final annot_pragma_openmp_end + ConsumeAnnotationToken(); + + SmallVector Clauses_new; + unsigned count = 0; + + // SortedClauses has index and score, and are sorted with respect to the + // the context score. The first iteration will take each element. The + // first element will have the highiest score. The element will have the + // index of the cluase for the best score. The second iteration tries to + // find that specific clause by checking the count numder with the + // index (Iteration1.first) + for ( auto &Iteration1 : SortedCluases){ + count = 0; + for ( auto &Iteration2 : Clauses){ + if ( count == Iteration1.first ){ + Clauses_new.push_back(Iteration2); + break; + } else count++; + } + } + // Adding the default clasue at the end + Clauses_new.push_back(Clauses.back()); + + // Parsing the OpenMP region which will take the + // metadirective + + Actions.ActOnOpenMPRegionStart(DKind, getCurScope()); + ParsingOpenMPDirectiveRAII NormalScope(*this, /*value=*/ false); + // This is parsing the region + StmtResult AStmt = ParseStatement(); + + StmtResult AssociatedStmt = (Sema::CompoundScopeRAII(Actions), AStmt); + // Ending of the parallel region + AssociatedStmt = Actions.ActOnOpenMPRegionEnd(AssociatedStmt, Clauses_new); + Directive = Actions.ActOnOpenMPExecutableDirective( + DKind, DirName, CancelRegion, Clauses_new, AssociatedStmt.get(), Loc, + EndLoc); + // Exit scope + 
Actions.EndOpenMPDSABlock(Directive.get()); + break; + } // end of case OMPD_metadirective: case OMPD_threadprivate: { // FIXME: Should this be permitted in C++? if ((StmtCtx & ParsedStmtContext::AllowDeclarationsInC) == @@ -3136,6 +3195,164 @@ return Actions.ActOnOpenMPUsesAllocatorClause(Loc, T.getOpenLocation(), T.getCloseLocation(), Data); } +/// Parsing of OpenMP MetaDirective Clauses + +OMPClause *Parser::ParseOpenMPMetaDirectiveClause(OpenMPDirectiveKind DKind, + OpenMPClauseKind CKind) { + OMPClause *Clause = nullptr; + bool ErrorFound = false; + bool WrongDirective = false; + SmallVector, + llvm::omp::Clause_enumSize + 1> + FirstClauses(llvm::omp::Clause_enumSize + 1); + + // Check if it is called from metadirective. + if (DKind != OMPD_metadirective) { + Diag(Tok, diag::err_omp_unexpected_clause) + << getOpenMPClauseName(CKind) << getOpenMPDirectiveName(DKind); + ErrorFound = true; + } + + // Check if clause is allowed for the given directive. + if (CKind != OMPC_unknown && + !isAllowedClauseForDirective(DKind, CKind, getLangOpts().OpenMP)) { + Diag(Tok, diag::err_omp_unexpected_clause) + << getOpenMPClauseName(CKind) << getOpenMPDirectiveName(DKind); + ErrorFound = true; + WrongDirective = true; + } + + // Check if clause is not allowed + if (CKind == OMPC_unknown) { + Diag(Tok, diag::err_omp_unexpected_clause) + << getOpenMPClauseName(CKind) << "Unknown clause: Not allowed"; + ErrorFound = true; + WrongDirective = true; + } + + if (CKind == OMPC_default || CKind == OMPC_when) { + SourceLocation Loc = ConsumeToken(); + SourceLocation DelimLoc; + // Parse '('. 
+ BalancedDelimiterTracker T(*this, tok::l_paren, + tok::annot_pragma_openmp_end); + if (T.expectAndConsume(diag::err_expected_lparen_after, + getOpenMPClauseName(CKind).data())) + return nullptr; + + OMPTraitInfo &TI = Actions.getASTContext().getNewOMPTraitInfo(); + if (CKind == OMPC_when) { + // parse and get condition expression to pass to the When clause + parseOMPContextSelectors(Loc, TI); + + // Parse ':' + if (Tok.is(tok::colon)) + ConsumeAnyToken(); + else { + Diag(Tok, diag::warn_pragma_expected_colon) << "when clause"; + return nullptr; + } + } + + // Parse Directive + OpenMPDirectiveKind DirKind = OMPD_unknown; + SmallVector Clauses; + StmtResult AssociatedStmt; + StmtResult Directive = StmtError(); + + if (Tok.isNot(tok::r_paren)) { + ParsingOpenMPDirectiveRAII DirScope(*this); + ParenBraceBracketBalancer BalancerRAIIObj(*this); + DeclarationNameInfo DirName; + unsigned ScopeFlags = Scope::FnScope | Scope::DeclScope | + Scope::CompoundStmtScope | + Scope::OpenMPDirectiveScope; + + DirKind = parseOpenMPDirectiveKind(*this); + ConsumeToken(); + ParseScope OMPDirectiveScope(this, ScopeFlags); + Actions.StartOpenMPDSABlock(DirKind, DirName, Actions.getCurScope(), Loc); + + int paren = 0; + + while (Tok.isNot(tok::r_paren) || paren != 0) { + if (Tok.is(tok::l_paren)) + paren++; + if (Tok.is(tok::r_paren)) + paren--; + + OpenMPClauseKind CKind = Tok.isAnnotation() + ? 
OMPC_unknown + : getOpenMPClauseKind(PP.getSpelling(Tok)); + + if (CKind == OMPC_unknown && + !isAllowedClauseForDirective(DirKind, CKind, getLangOpts().OpenMP)) { + Diag(Tok, diag::err_omp_unexpected_clause) + << getOpenMPClauseName(CKind) << getOpenMPDirectiveName(DKind); + ErrorFound = true; + WrongDirective = true; + } + + Actions.StartOpenMPClause(CKind); + OMPClause *Clause = ParseOpenMPClause( + DirKind, CKind, !FirstClauses[(unsigned)CKind].getInt()); + FirstClauses[(unsigned)CKind].setInt(true); + if (Clause) { + FirstClauses[(unsigned)CKind].setPointer(Clause); + Clauses.push_back(Clause); + } + + // Skip ',' if any. + if (Tok.is(tok::comma)) + ConsumeToken(); + Actions.EndOpenMPClause(); + } + + Actions.ActOnOpenMPRegionStart(DirKind, getCurScope()); + ParsingOpenMPDirectiveRAII NormalScope(*this, /*Value=*/false); + + /* Get Stmt and revert back */ + TentativeParsingAction TPA(*this); + while (Tok.isNot(tok::annot_pragma_openmp_end)) { + ConsumeAnyToken(); + } + + ConsumeAnnotationToken(); + ParseScope InnerStmtScope(this, Scope::DeclScope, + getLangOpts().C99 || getLangOpts().CPlusPlus, + Tok.is(tok::l_brace)); + + StmtResult AStmt = ParseStatement(); + InnerStmtScope.Exit(); + TPA.Revert(); + /* End Get Stmt */ + + AssociatedStmt = (Sema::CompoundScopeRAII(Actions), AStmt); + AssociatedStmt = Actions.ActOnOpenMPRegionEnd(AssociatedStmt, Clauses); + + Directive = Actions.ActOnOpenMPExecutableDirective( + DirKind, DirName, OMPD_unknown, llvm::makeArrayRef(Clauses), + AssociatedStmt.get(), Loc, Tok.getLocation()); + + Actions.EndOpenMPDSABlock(Directive.get()); + OMPDirectiveScope.Exit(); + } + // Parse ')' + T.consumeClose(); + + if (WrongDirective) + return nullptr; + + Clause = Actions.ActOnOpenMPWhenClause(TI, DirKind, Directive, Loc, + DelimLoc, Tok.getLocation()); + } else { + ErrorFound = false; + Diag(Tok, diag::err_omp_unexpected_clause) + << getOpenMPClauseName(CKind) << getOpenMPDirectiveName(DKind); + } + + return ErrorFound ? 
nullptr : Clause; +} /// Parsing of OpenMP clauses. /// diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -39,6 +39,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/Frontend/OpenMP/OMPAssume.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" +#include "llvm/Frontend/OpenMP/OMPContext.h" #include using namespace clang; @@ -4145,6 +4146,7 @@ void Sema::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind, Scope *CurScope) { switch (DKind) { + case OMPD_metadirective: case OMPD_parallel: case OMPD_parallel_for: case OMPD_parallel_for_simd: @@ -4579,7 +4581,6 @@ case OMPD_declare_variant: case OMPD_begin_declare_variant: case OMPD_end_declare_variant: - case OMPD_metadirective: llvm_unreachable("OpenMP Directive is not allowed"); case OMPD_unknown: default: @@ -4763,7 +4764,7 @@ } StmtResult Sema::ActOnOpenMPRegionEnd(StmtResult S, - ArrayRef Clauses) { + ArrayRef Clauses){ handleDeclareVariantConstructTrait(DSAStack, DSAStack->getCurrentDirective(), /* ScopeEntry */ false); if (DSAStack->getCurrentDirective() == OMPD_atomic || @@ -4854,6 +4855,7 @@ << SourceRange(OC->getBeginLoc(), OC->getEndLoc()); ErrorFound = true; } + // OpenMP 5.0, 2.9.2 Worksharing-Loop Construct, Restrictions. // If an order(concurrent) clause is present, an ordered clause may not appear // on the same directive. @@ -4865,7 +4867,7 @@ << SourceRange(OC->getBeginLoc(), OC->getEndLoc()); } ErrorFound = true; - } + } if (isOpenMPWorksharingDirective(DSAStack->getCurrentDirective()) && isOpenMPSimdDirective(DSAStack->getCurrentDirective()) && OC && OC->getNumForLoops()) { @@ -4877,8 +4879,9 @@ return StmtError(); } StmtResult SR = S; - unsigned CompletedRegions = 0; + unsigned CompletedRegions = 0; for (OpenMPDirectiveKind ThisCaptureRegion : llvm::reverse(CaptureRegions)) { + // Mark all variables in private list clauses as used in inner region. 
// Required for proper codegen of combined directives. // TODO: add processing for other clauses. @@ -4899,6 +4902,7 @@ } } } + if (ThisCaptureRegion == OMPD_target) { // Capture allocator traits in the target region. They are used implicitly // and, thus, are not captured by default. @@ -4914,6 +4918,7 @@ } } } + if (ThisCaptureRegion == OMPD_parallel) { // Capture temp arrays for inscan reductions and locals in aligned // clauses. @@ -4930,10 +4935,14 @@ } } } + if (++CompletedRegions == CaptureRegions.size()) DSAStack->setBodyComplete(); + SR = ActOnCapturedRegionEnd(SR.get()); + } + return SR; } @@ -6250,6 +6259,12 @@ llvm::SmallVector AllowedNameModifiers; switch (Kind) { + + case OMPD_metadirective: + Res = ActOnOpenMPMetaDirective(ClausesWithImplicit, AStmt, StartLoc, + EndLoc); + AllowedNameModifiers.push_back(OMPD_metadirective); + break; case OMPD_parallel: Res = ActOnOpenMPParallelDirective(ClausesWithImplicit, AStmt, StartLoc, EndLoc); @@ -7705,10 +7720,138 @@ FD->addAttr(NewAttr); } +StmtResult Sema::ActOnOpenMPMetaDirective(ArrayRef<OMPClause *> Clauses, + Stmt *AStmt, + SourceLocation StartLoc, + SourceLocation EndLoc) { + if (!AStmt) + return StmtError(); + + auto *CS = cast<CapturedStmt>(AStmt); + // 1.2.2 OpenMP Language Terminology + // Structured block - An executable statement with a single entry at the + // top and a single exit at the bottom. + // The point of exit cannot be a branch out of the structured block. + // longjmp() and throw() must not violate the entry/exit criteria.
+ CS->getCapturedDecl()->setNothrow(); + + // Wraps S into a compound statement usable as one branch of the chain. + auto WrapInCompound = [this](Stmt *S) -> Stmt * { + return CompoundStmt::Create(Context, {S}, FPOptionsOverride(), + SourceLocation(), SourceLocation()); + }; + + // Constant condition for selectors that are decided at compile time. + auto MakeBoolLiteral = [this, StartLoc](bool Value) -> Expr * { + return new (Context) CXXBoolLiteralExpr(Value, Context.BoolTy, StartLoc); + }; + + // True if a property of Selector equals the component that GetField + // extracts from one of the -fopenmp-targets triples. + // NOTE(review): this matches against every offload target rather than the + // target of the current compilation; confirm that is the intent. + auto MatchesTargetTriple = [this](const OMPTraitSelector &Selector, + auto GetField) { + for (const OMPTraitProperty &Property : Selector.Properties) + for (const llvm::Triple &T : getLangOpts().OMPTargetTriples) + if (GetField(T) == Property.RawString) + return true; + return false; + }; + + // The chain is built back to front: each clause becomes an 'if' whose else + // branch is the chain built for the clauses after it, so at run time the + // first clause (in source order) whose condition holds is executed. + Stmt *ElseStmt = nullptr; + + for (OMPClause *Clause : llvm::reverse(Clauses)) { + // Clauses implicitly added by Sema (e.g. data-sharing clauses) are not + // metadirective variants; skip them instead of dereferencing null. + auto *WhenClause = dyn_cast<OMPWhenClause>(Clause); + if (!WhenClause) + continue; + + Expr *WhenCondExpr = nullptr; + Stmt *ThenStmt = nullptr; + OpenMPDirectiveKind DKind = WhenClause->getDKind(); + if (DKind != OMPD_unknown) + ThenStmt = WrapInCompound(WhenClause->getDirective()); + + // A clause is chosen only if all of its selectors match, so the + // individual selector conditions are combined with '&&'. + for (const OMPTraitSet &Set : WhenClause->getTI().Sets) { + for (const OMPTraitSelector &Selector : Set.Selectors) { + Expr *Cond = nullptr; + switch (Selector.Kind) { + case TraitSelector::device_arch: + Cond = MakeBoolLiteral(MatchesTargetTriple( + Selector, [](const llvm::Triple &T) { return T.getArchName(); })); + break; + case TraitSelector::user_condition: + assert(Selector.ScoreOrCondition && + "Ill-formed user condition, expected condition expression!"); + Cond = Selector.ScoreOrCondition; + break; + case TraitSelector::implementation_vendor: + Cond = MakeBoolLiteral(MatchesTargetTriple( + Selector, [](const llvm::Triple &T) { return T.getVendorName(); })); + break; + case TraitSelector::device_isa: + case TraitSelector::device_kind: + case TraitSelector::implementation_extension: + default: + // Not evaluated yet; these selectors do not constrain the clause. + break; + }
+ if (!Cond) + continue; + if (!WhenCondExpr) { + WhenCondExpr = Cond; + continue; + } + ExprResult Combined = BuildBinOp(getCurScope(), StartLoc, BO_LAnd, + WhenCondExpr, Cond); + if (Combined.isInvalid()) + return StmtError(); + WhenCondExpr = Combined.get(); + } + } + + // No evaluable condition: this is the 'default' clause (or a 'when' whose + // selectors are all unevaluated), which has to come last. + if (!WhenCondExpr) { + if (ElseStmt) { + Diag(WhenClause->getBeginLoc(), diag::err_omp_misplaced_default_clause); + return StmtError(); + } + ElseStmt = DKind == OMPD_unknown ? WrapInCompound(CS->getCapturedStmt()) + : ThenStmt; + continue; + } + + if (!ThenStmt) + ThenStmt = WrapInCompound(CS->getCapturedStmt()); + + StmtResult IfResult = ActOnIfStmt( + SourceLocation(), IfStatementKind::Ordinary, SourceLocation(), nullptr, + ActOnCondition(getCurScope(), SourceLocation(), WhenCondExpr, + Sema::ConditionKind::Boolean), + SourceLocation(), ThenStmt, SourceLocation(), ElseStmt); + if (IfResult.isInvalid()) + return StmtError(); + ElseStmt = IfResult.get(); + } + + // ElseStmt holds the complete chain. This also covers the case where only + // a 'default' clause is present and no 'if' statement was built at all. + return OMPMetaDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt, + ElseStmt); +} + StmtResult Sema::ActOnOpenMPParallelDirective(ArrayRef Clauses, Stmt *AStmt, SourceLocation StartLoc, SourceLocation EndLoc) { if (!AStmt) return StmtError(); @@ -16690,6 +16833,14 @@ return std::string(Out.str()); } +OMPClause * +Sema::ActOnOpenMPWhenClause(OMPTraitInfo &TI, OpenMPDirectiveKind DKind, + StmtResult Directive, SourceLocation StartLoc, + SourceLocation LParenLoc, SourceLocation EndLoc) { + return new (Context) + OMPWhenClause(TI, DKind, Directive.get(), StartLoc, LParenLoc, EndLoc); +} + OMPClause *Sema::ActOnOpenMPDefaultClause(DefaultKind Kind, SourceLocation KindKwLoc, SourceLocation StartLoc, diff --git
a/clang/test/OpenMP/metadirective_ast_print_new_1.cpp b/clang/test/OpenMP/metadirective_ast_print_new_1.cpp new file mode 100644 --- /dev/null +++ b/clang/test/OpenMP/metadirective_ast_print_new_1.cpp @@ -0,0 +1,20 @@ +// RUN: %clang_cc1 -verify -fopenmp -ast-print %s -o - | FileCheck %s +// expected-no-diagnostics +void bar(){ + int i=0; +} + +void myfoo(void){ // A single 'when' variant followed by a 'default' variant. + + int N = 13; + int b,n; + int a[100]; + + #pragma omp metadirective when(user={condition(N>10)}: target teams ) default(parallel for) + for (int i = 0; i < N; i++) + bar(); + +} + +// CHECK: void bar() +// CHECK: #pragma omp metadirective when(user={condition(N > 10)}: target teams) default(parallel for) diff --git a/clang/test/OpenMP/metadirective_ast_print_new_2.cpp b/clang/test/OpenMP/metadirective_ast_print_new_2.cpp new file mode 100644 --- /dev/null +++ b/clang/test/OpenMP/metadirective_ast_print_new_2.cpp @@ -0,0 +1,29 @@ +// RUN: %clang_cc1 -verify -fopenmp -ast-print %s -o - | FileCheck %s +// expected-no-diagnostics + +void bar(){ + int i=0; +} + +void myfoo(void){ // NOTE(review): the last check line below names only three of the five 'when' variants and no num_teams/thread_limit clauses; confirm that matches the intended -ast-print output. + + int N = 13; + int b,n; + int a[100]; + + + #pragma omp metadirective when (user = {condition(N>10)}: target teams distribute parallel for ) \ + when (user = {condition(N==10)}: parallel for )\ + when (user = {condition(N==13)}: parallel for simd) \ + when ( device={arch("arm")}: target teams num_teams(512) thread_limit(32))\ + when ( device={arch("nvptx")}: target teams num_teams(512) thread_limit(32))\ + default ( parallel for)\ + + { for (int i = 0; i < N; i++) + bar(); + } +} + +// CHECK: bar() +// CHECK: myfoo +// CHECK: #pragma omp metadirective when(user={condition(N > 10)}: target teams distribute parallel for) when(user={condition(N == 13)}: parallel for simd) when(device={arch(nvptx)}: target teams) diff --git a/clang/test/OpenMP/metadirective_ast_print_new_3.cpp b/clang/test/OpenMP/metadirective_ast_print_new_3.cpp new file mode 100644 --- /dev/null +++ b/clang/test/OpenMP/metadirective_ast_print_new_3.cpp @@ -0,0 +1,22
@@ +// RUN: %clang_cc1 -verify -fopenmp -ast-print %s -o - | FileCheck %s +// expected-no-diagnostics + +int main() { + int N = 15; +#pragma omp metadirective when(user = {condition(N > 10)} : parallel for)\ + default(target teams) + for (int i = 0; i < N; i++) + ; + + +#pragma omp metadirective when(device = {arch("nvptx64")}, user = {condition(N >= 100)} : parallel for)\ + default(target parallel for) + for (int i = 0; i < N; i++) + ; + return 0; +} + + + +// CHECK: #pragma omp metadirective when(user={condition(N > 10)}: parallel for) default(target teams) +// CHECK: #pragma omp metadirective when(device={arch(nvptx64)}, user={condition(N >= 100)}: parallel for) default(target parallel for) diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPContext.h b/llvm/include/llvm/Frontend/OpenMP/OMPContext.h --- a/llvm/include/llvm/Frontend/OpenMP/OMPContext.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPContext.h @@ -189,6 +189,14 @@ int getBestVariantMatchForContext(const SmallVectorImpl &VMIs, const OMPContext &Ctx); +/// Compute the context score of every variant in \p VMIs, append the +/// (index, score) pairs to \p A and sort \p A by descending score. Variants +/// that are not applicable in \p Ctx score 0; variants with equal scores +/// keep their relative order. +void getArrayVariantMatchForContext( + const SmallVectorImpl<VariantMatchInfo> &VMIs, const OMPContext &Ctx, + SmallVector<std::pair<unsigned, APInt>> &A); + } // namespace omp template <> struct DenseMapInfo { diff --git a/llvm/lib/Frontend/OpenMP/OMPContext.cpp b/llvm/lib/Frontend/OpenMP/OMPContext.cpp --- a/llvm/lib/Frontend/OpenMP/OMPContext.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPContext.cpp @@ -19,6 +19,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include <algorithm> #define DEBUG_TYPE "openmp-ir-builder" using namespace llvm; @@ -336,6 +337,37 @@ return Score; }
+/// Computes the context score of every variant in \p VMIs, appends the +/// (index, score) pairs to \p VectorOfClauses and sorts that vector by +/// descending score. Variants that are not applicable in \p Ctx score 0. +/// The sort is stable, so variants with equal scores keep their relative +/// order from \p VMIs. +void llvm::omp::getArrayVariantMatchForContext( + const SmallVectorImpl<VariantMatchInfo> &VMIs, const OMPContext &Ctx, + SmallVector<std::pair<unsigned, APInt>> &VectorOfClauses) { + VectorOfClauses.reserve(VectorOfClauses.size() + VMIs.size()); + for (unsigned Idx = 0, E = VMIs.size(); Idx < E; ++Idx) { + const VariantMatchInfo &VMI = VMIs[Idx]; + SmallVector<unsigned, 8> ConstructMatches; + // A variant that is not applicable in this context scores 0. + if (!isVariantApplicableInContextHelper(VMI, Ctx, &ConstructMatches, + /* DeviceSetOnly */ false)) { + VectorOfClauses.emplace_back(Idx, APInt(64, 0)); + continue; + } + VectorOfClauses.emplace_back( + Idx, getVariantMatchScore(VMI, Ctx, ConstructMatches)); + } + + // Highest score first. The sort is stable so that variants with equal + // scores (e.g. several non-applicable ones) stay in their original order. + std::stable_sort(VectorOfClauses.begin(), VectorOfClauses.end(), + [](const std::pair<unsigned, APInt> &LHS, + const std::pair<unsigned, APInt> &RHS) { + return LHS.second.ugt(RHS.second); + }); +} + int llvm::omp::getBestVariantMatchForContext( const SmallVectorImpl &VMIs, const OMPContext &Ctx) {