diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h --- a/clang/include/clang/AST/OpenMPClause.h +++ b/clang/include/clang/AST/OpenMPClause.h @@ -6697,9 +6697,11 @@ /// former is a flat representation the actual main difference is that the /// latter uses clang::Expr to store the score/condition while the former is /// independent of clang. Thus, expressions and conditions are evaluated in - /// this method. + /// this method. If \p DeviceSetOnly is set, only the device selector set, if + /// present, is put into \p VMI. void getAsVariantMatchInfo(ASTContext &ASTCtx, - llvm::omp::VariantMatchInfo &VMI) const; + llvm::omp::VariantMatchInfo &VMI, + bool DeviceSetOnly) const; /// Print a human readable representation into \p OS. void print(llvm::raw_ostream &OS, const PrintingPolicy &Policy) const; diff --git a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td --- a/clang/include/clang/Basic/DiagnosticParseKinds.td +++ b/clang/include/clang/Basic/DiagnosticParseKinds.td @@ -1248,6 +1248,8 @@ "unexpected '%0' clause, '%1' is specified already">; def err_expected_end_declare_target : Error< "expected '#pragma omp end declare target'">; +def err_expected_end_declare_target_or_variant : Error< + "expected '#pragma omp end declare %select{target|variant}0'">; def err_omp_declare_target_unexpected_clause: Error< "unexpected '%0' clause, only %select{'to' or 'link'|'to', 'link' or 'device_type'}1 clauses expected">; def err_omp_expected_clause: Error< diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h --- a/clang/include/clang/Parse/Parser.h +++ b/clang/include/clang/Parse/Parser.h @@ -2954,14 +2954,31 @@ /// Parses OpenMP context selectors. bool parseOMPContextSelectors(SourceLocation Loc, OMPTraitInfo &TI); + /// Parse a `match` clause for a '#pragma omp declare variant'. Return true
+ /// if there was an error; the end of the pragma is consumed in that case. + bool ParseOMPDeclareVariantMatchClause(SourceLocation Loc, OMPTraitInfo &TI); + /// Parse clauses for '#pragma omp declare variant'. void ParseOMPDeclareVariantClauses(DeclGroupPtrTy Ptr, CachedTokens &Toks, SourceLocation Loc); + /// Parse '#pragma omp end declare variant'. + void ParseOMPEndDeclareVariantDirective(OpenMPDirectiveKind DKind, + SourceLocation Loc); /// Parse clauses for '#pragma omp declare target'. DeclGroupPtrTy ParseOMPDeclareTargetClauses(); /// Parse '#pragma omp end declare target'. void ParseOMPEndDeclareTargetDirective(OpenMPDirectiveKind DKind, SourceLocation Loc); + + /// Parse the "end" directive \p ExpectedKind matching the "begin" directive + /// \p MatchingKind. The actual kind found is \p FoundKind. Returns true if + /// the expected kind was not found (after diagnosing the mismatch). + bool parseOMPEndDirective(OpenMPDirectiveKind MatchingKind, + OpenMPDirectiveKind ExpectedKind, + OpenMPDirectiveKind FoundKind, + SourceLocation MatchingLoc, + SourceLocation FoundLoc); + /// Parses declarative OpenMP directives. DeclGroupPtrTy ParseOpenMPDeclarativeDirectiveWithExtDecl( AccessSpecifier &AS, ParsedAttributesWithRange &Attrs, diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp --- a/clang/lib/AST/OpenMPClause.cpp +++ b/clang/lib/AST/OpenMPClause.cpp @@ -1724,9 +1724,12 @@ << ")"; } -void OMPTraitInfo::getAsVariantMatchInfo( - ASTContext &ASTCtx, llvm::omp::VariantMatchInfo &VMI) const { +void OMPTraitInfo::getAsVariantMatchInfo(ASTContext &ASTCtx, + llvm::omp::VariantMatchInfo &VMI, + bool DeviceSetOnly) const { for (const OMPTraitSet &Set : Sets) { + if (DeviceSetOnly && Set.Kind != llvm::omp::TraitSet::device) + continue; for (const OMPTraitSelector &Selector : Set.Selectors) { // User conditions are special as we evaluate the condition here.
diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp --- a/clang/lib/Basic/OpenMPKinds.cpp +++ b/clang/lib/Basic/OpenMPKinds.cpp @@ -941,6 +941,8 @@ break; } break; + case OMPD_begin_declare_variant: + case OMPD_end_declare_variant: case OMPD_declare_target: case OMPD_end_declare_target: case OMPD_unknown: @@ -1202,6 +1204,8 @@ case OMPD_end_declare_target: case OMPD_requires: case OMPD_declare_variant: + case OMPD_begin_declare_variant: + case OMPD_end_declare_variant: llvm_unreachable("OpenMP Directive is not allowed"); case OMPD_unknown: llvm_unreachable("Unknown OpenMP directive"); diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp --- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp @@ -6717,6 +6717,8 @@ case OMPD_target_update: case OMPD_declare_simd: case OMPD_declare_variant: + case OMPD_begin_declare_variant: + case OMPD_end_declare_variant: case OMPD_declare_target: case OMPD_end_declare_target: case OMPD_declare_reduction: @@ -7028,6 +7030,8 @@ case OMPD_target_update: case OMPD_declare_simd: case OMPD_declare_variant: + case OMPD_begin_declare_variant: + case OMPD_end_declare_variant: case OMPD_declare_target: case OMPD_end_declare_target: case OMPD_declare_reduction: @@ -8807,6 +8811,8 @@ case OMPD_target_update: case OMPD_declare_simd: case OMPD_declare_variant: + case OMPD_begin_declare_variant: + case OMPD_end_declare_variant: case OMPD_declare_target: case OMPD_end_declare_target: case OMPD_declare_reduction: @@ -9570,6 +9576,8 @@ case OMPD_target_update: case OMPD_declare_simd: case OMPD_declare_variant: + case OMPD_begin_declare_variant: + case OMPD_end_declare_variant: case OMPD_declare_target: case OMPD_end_declare_target: case OMPD_declare_reduction: @@ -10207,6 +10215,8 @@ case OMPD_teams_distribute_parallel_for_simd: case OMPD_declare_simd: case OMPD_declare_variant: + case OMPD_begin_declare_variant: + case OMPD_end_declare_variant: case
OMPD_declare_target: case OMPD_end_declare_target: case OMPD_declare_reduction: @@ -11077,7 +11087,8 @@ for (const auto *A : FD->specific_attrs<OMPDeclareVariantAttr>()) { const OMPTraitInfo &TI = A->getTraitInfos(); VMIs.push_back(VariantMatchInfo()); - TI.getAsVariantMatchInfo(CGM.getContext(), VMIs.back()); + TI.getAsVariantMatchInfo(CGM.getContext(), VMIs.back(), + /*DeviceSetOnly=*/false); VariantExprs.push_back(A->getVariantFuncRef()); } diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp --- a/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp @@ -801,6 +801,8 @@ case OMPD_target_update: case OMPD_declare_simd: case OMPD_declare_variant: + case OMPD_begin_declare_variant: + case OMPD_end_declare_variant: case OMPD_declare_target: case OMPD_end_declare_target: case OMPD_declare_reduction: @@ -877,6 +879,8 @@ case OMPD_target_update: case OMPD_declare_simd: case OMPD_declare_variant: + case OMPD_begin_declare_variant: + case OMPD_end_declare_variant: case OMPD_declare_target: case OMPD_end_declare_target: case OMPD_declare_reduction: @@ -1046,6 +1050,8 @@ case OMPD_target_update: case OMPD_declare_simd: case OMPD_declare_variant: + case OMPD_begin_declare_variant: + case OMPD_end_declare_variant: case OMPD_declare_target: case OMPD_end_declare_target: case OMPD_declare_reduction: @@ -1128,6 +1134,8 @@ case OMPD_target_update: case OMPD_declare_simd: case OMPD_declare_variant: + case OMPD_begin_declare_variant: + case OMPD_end_declare_variant: case OMPD_declare_target: case OMPD_end_declare_target: case OMPD_declare_reduction: diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp --- a/clang/lib/Parse/ParseOpenMP.cpp +++ b/clang/lib/Parse/ParseOpenMP.cpp @@ -47,6 +47,8 @@ OMPD_target_teams_distribute_parallel, OMPD_mapper, OMPD_variant, + OMPD_begin, + OMPD_begin_declare, }; // Helper to unify the enum class OpenMPDirectiveKind with its extension @@ -100,6 +102,7 @@
.Case("update", OMPD_update) .Case("mapper", OMPD_mapper) .Case("variant", OMPD_variant) + .Case("begin", OMPD_begin) .Default(OMPD_unknown); } @@ -108,18 +111,21 @@ // E.g.: OMPD_for OMPD_simd ===> OMPD_for_simd // TODO: add other combined directives in topological order. static const OpenMPDirectiveKindExWrapper F[][3] = { + {OMPD_begin, OMPD_declare, OMPD_begin_declare}, + {OMPD_end, OMPD_declare, OMPD_end_declare}, {OMPD_cancellation, OMPD_point, OMPD_cancellation_point}, {OMPD_declare, OMPD_reduction, OMPD_declare_reduction}, {OMPD_declare, OMPD_mapper, OMPD_declare_mapper}, {OMPD_declare, OMPD_simd, OMPD_declare_simd}, {OMPD_declare, OMPD_target, OMPD_declare_target}, {OMPD_declare, OMPD_variant, OMPD_declare_variant}, + {OMPD_begin_declare, OMPD_variant, OMPD_begin_declare_variant}, + {OMPD_end_declare, OMPD_variant, OMPD_end_declare_variant}, {OMPD_distribute, OMPD_parallel, OMPD_distribute_parallel}, {OMPD_distribute_parallel, OMPD_for, OMPD_distribute_parallel_for}, {OMPD_distribute_parallel_for, OMPD_simd, OMPD_distribute_parallel_for_simd}, {OMPD_distribute, OMPD_simd, OMPD_distribute_simd}, - {OMPD_end, OMPD_declare, OMPD_end_declare}, {OMPD_end_declare, OMPD_target, OMPD_end_declare_target}, {OMPD_target, OMPD_data, OMPD_target_data}, {OMPD_target, OMPD_enter, OMPD_target_enter}, @@ -1339,6 +1345,29 @@ return; } + OMPTraitInfo TI; + if (ParseOMPDeclareVariantMatchClause(Loc, TI)) + return; + + Optional<std::pair<FunctionDecl *, Expr *>> DeclVarData = + Actions.checkOpenMPDeclareVariantFunction( + Ptr, AssociatedFunction.get(), TI, + SourceRange(Loc, Tok.getLocation())); + + // Skip last tokens. + while (Tok.isNot(tok::annot_pragma_openmp_end)) + ConsumeAnyToken(); + if (DeclVarData.hasValue() && !TI.Sets.empty()) + Actions.ActOnOpenMPDeclareVariantDirective( + DeclVarData.getValue().first, DeclVarData.getValue().second, TI, + SourceRange(Loc, Tok.getLocation())); + + // Skip the last annot_pragma_openmp_end.
+ (void)ConsumeAnnotationToken(); +} + +bool Parser::ParseOMPDeclareVariantMatchClause(SourceLocation Loc, + OMPTraitInfo &TI) { // Parse 'match'. OpenMPClauseKind CKind = Tok.isAnnotation() ? OMPC_unknown @@ -1350,7 +1379,7 @@ ; // Skip the last annot_pragma_openmp_end. (void)ConsumeAnnotationToken(); - return; + return true; } (void)ConsumeToken(); // Parse '('. @@ -1361,31 +1390,15 @@ ; // Skip the last annot_pragma_openmp_end. (void)ConsumeAnnotationToken(); - return; + return true; } // Parse inner context selectors. - OMPTraitInfo TI; parseOMPContextSelectors(Loc, TI); // Parse ')' (void)T.consumeClose(); - - Optional<std::pair<FunctionDecl *, Expr *>> DeclVarData = - Actions.checkOpenMPDeclareVariantFunction( - Ptr, AssociatedFunction.get(), TI, - SourceRange(Loc, Tok.getLocation())); - - // Skip last tokens. - while (Tok.isNot(tok::annot_pragma_openmp_end)) - ConsumeAnyToken(); - if (DeclVarData.hasValue() && !TI.Sets.empty()) - Actions.ActOnOpenMPDeclareVariantDirective( - DeclVarData.getValue().first, DeclVarData.getValue().second, TI, - SourceRange(Loc, Tok.getLocation())); - - // Skip the last annot_pragma_openmp_end. - (void)ConsumeAnnotationToken(); + return false; } /// Parsing of simple OpenMP clauses like 'default' or 'proc_bind'. @@ -1523,13 +1536,28 @@ return Actions.BuildDeclaratorGroup(Decls); } +bool Parser::parseOMPEndDirective(OpenMPDirectiveKind MatchingKind, + OpenMPDirectiveKind ExpectedKind, + OpenMPDirectiveKind FoundKind, + SourceLocation MatchingLoc, + SourceLocation FoundLoc) { + int DiagSelection = ExpectedKind == OMPD_end_declare_target ?
0 : 1; + + if (FoundKind != ExpectedKind) { + Diag(FoundLoc, diag::err_expected_end_declare_target_or_variant) + << DiagSelection; + Diag(MatchingLoc, diag::note_matching) + << ("'#pragma omp " + getOpenMPDirectiveName(MatchingKind) + "'").str(); + return true; + } + return false; +} + void Parser::ParseOMPEndDeclareTargetDirective(OpenMPDirectiveKind DKind, - SourceLocation DTLoc) { - if (DKind != OMPD_end_declare_target) { - Diag(Tok, diag::err_expected_end_declare_target); - Diag(DTLoc, diag::note_matching) << "'#pragma omp declare target'"; + SourceLocation DKLoc) { + if (parseOMPEndDirective(OMPD_declare_target, OMPD_end_declare_target, DKind, + DKLoc, Tok.getLocation())) return; - } ConsumeAnyToken(); if (Tok.isNot(tok::annot_pragma_openmp_end)) { Diag(Tok, diag::warn_omp_extra_tokens_at_eol) @@ -1741,6 +1769,57 @@ } break; } + case OMPD_begin_declare_variant: { + // The syntax is: + // { #pragma omp begin declare variant clause } + // <function-declaration-or-definition-sequence> + // { #pragma omp end declare variant } + // + SourceLocation BeginLoc = ConsumeToken(); + + OMPTraitInfo TI; + // On error the pragma end was already consumed; do not skip past it again. + if (ParseOMPDeclareVariantMatchClause(Loc, TI)) + return DeclGroupPtrTy(); + + // Skip last tokens. + while (Tok.isNot(tok::annot_pragma_openmp_end)) + ConsumeAnyToken(); + + VariantMatchInfo VMI; + ASTContext &ASTCtx = Actions.getASTContext(); + TI.getAsVariantMatchInfo(ASTCtx, VMI, /*DeviceSetOnly=*/true); + OMPContext OMPCtx(ASTCtx.getLangOpts().OpenMPIsDevice, + ASTCtx.getTargetInfo().getTriple()); + + if (isVariantApplicableInContext(VMI, OMPCtx)) + break; + + // Elide all the code till the matching end declare variant was found.
+ unsigned Nesting = 1; + SourceLocation DKLoc; + OpenMPDirectiveKind DK = OMPD_unknown; + do { + DKLoc = Tok.getLocation(); + DK = parseOpenMPDirectiveKind(*this); + if (DK == OMPD_end_declare_variant) + --Nesting; + else if (DK == OMPD_begin_declare_variant) + ++Nesting; + if (!Nesting || isEofOrEom()) + break; + ConsumeAnyToken(); + } while (true); + parseOMPEndDirective(OMPD_begin_declare_variant, OMPD_end_declare_variant, + DK, BeginLoc, DKLoc); + if (isEofOrEom()) + return DeclGroupPtrTy(); + break; + } + case OMPD_end_declare_variant: + // FIXME: Diagnose unmatched 'end declare variant' once Sema tracks nesting. + ConsumeToken(); + break; case OMPD_declare_variant: case OMPD_declare_simd: { // The syntax is: @@ -2233,6 +2312,8 @@ case OMPD_declare_target: case OMPD_end_declare_target: case OMPD_requires: + case OMPD_begin_declare_variant: + case OMPD_end_declare_variant: case OMPD_declare_variant: Diag(Tok, diag::err_omp_unexpected_directive) << 1 << getOpenMPDirectiveName(DKind); diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -3747,6 +3747,8 @@ case OMPD_end_declare_target: case OMPD_requires: case OMPD_declare_variant: + case OMPD_begin_declare_variant: + case OMPD_end_declare_variant: llvm_unreachable("OpenMP Directive is not allowed"); case OMPD_unknown: llvm_unreachable("Unknown OpenMP directive"); @@ -4946,6 +4948,8 @@ case OMPD_declare_simd: case OMPD_requires: case OMPD_declare_variant: + case OMPD_begin_declare_variant: + case OMPD_end_declare_variant: llvm_unreachable("OpenMP Directive is not allowed"); case OMPD_unknown: llvm_unreachable("Unknown OpenMP directive"); @@ -11074,6 +11078,8 @@ case OMPD_declare_mapper: case OMPD_declare_simd: case OMPD_declare_variant: + case OMPD_begin_declare_variant: + case OMPD_end_declare_variant: case OMPD_declare_target: case OMPD_end_declare_target: case OMPD_teams: @@ -11144,6 +11150,8 @@ case
OMPD_declare_mapper: case OMPD_declare_simd: case OMPD_declare_variant: + case OMPD_begin_declare_variant: + case OMPD_end_declare_variant: case OMPD_declare_target: case OMPD_end_declare_target: case OMPD_teams: @@ -11219,6 +11227,8 @@ case OMPD_declare_mapper: case OMPD_declare_simd: case OMPD_declare_variant: + case OMPD_begin_declare_variant: + case OMPD_end_declare_variant: case OMPD_declare_target: case OMPD_end_declare_target: case OMPD_simd: @@ -11291,6 +11301,8 @@ case OMPD_declare_mapper: case OMPD_declare_simd: case OMPD_declare_variant: + case OMPD_begin_declare_variant: + case OMPD_end_declare_variant: case OMPD_declare_target: case OMPD_end_declare_target: case OMPD_simd: @@ -11364,6 +11376,8 @@ case OMPD_declare_mapper: case OMPD_declare_simd: case OMPD_declare_variant: + case OMPD_begin_declare_variant: + case OMPD_end_declare_variant: case OMPD_declare_target: case OMPD_end_declare_target: case OMPD_simd: @@ -11436,6 +11450,8 @@ case OMPD_declare_mapper: case OMPD_declare_simd: case OMPD_declare_variant: + case OMPD_begin_declare_variant: + case OMPD_end_declare_variant: case OMPD_declare_target: case OMPD_end_declare_target: case OMPD_simd: @@ -11507,6 +11523,8 @@ case OMPD_declare_mapper: case OMPD_declare_simd: case OMPD_declare_variant: + case OMPD_begin_declare_variant: + case OMPD_end_declare_variant: case OMPD_declare_target: case OMPD_end_declare_target: case OMPD_simd: @@ -11581,6 +11599,8 @@ case OMPD_declare_mapper: case OMPD_declare_simd: case OMPD_declare_variant: + case OMPD_begin_declare_variant: + case OMPD_end_declare_variant: case OMPD_declare_target: case OMPD_end_declare_target: case OMPD_simd: diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def --- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def +++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def @@ -91,6 +91,8 @@ __OMP_DIRECTIVE_EXT(master_taskloop_simd, "master taskloop simd")
__OMP_DIRECTIVE_EXT(parallel_master_taskloop_simd, "parallel master taskloop simd") +__OMP_DIRECTIVE_EXT(begin_declare_variant, "begin declare variant") +__OMP_DIRECTIVE_EXT(end_declare_variant, "end declare variant") // Has to be the last because Clang implicitly expects it to be. __OMP_DIRECTIVE(unknown)