diff --git a/clang/include/clang-c/Index.h b/clang/include/clang-c/Index.h --- a/clang/include/clang-c/Index.h +++ b/clang/include/clang-c/Index.h @@ -2568,7 +2568,11 @@ */ CXCursor_OMPScanDirective = 287, - CXCursor_LastStmt = CXCursor_OMPScanDirective, + /** OpenMP metadirective directive. + */ + CXCursor_OMPMetaDirective = 288, + + CXCursor_LastStmt = CXCursor_OMPMetaDirective, /** * Cursor that represents the translation unit itself. diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h --- a/clang/include/clang/AST/OpenMPClause.h +++ b/clang/include/clang/AST/OpenMPClause.h @@ -7862,6 +7862,75 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const OMPTraitInfo &TI); llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const OMPTraitInfo *TI); +/// This represents 'when' clause in the '#pragma omp ...' directive +/// +/// \code +/// #pragma omp metadirective when(user={condition(N<10)}: parallel) +/// \endcode +/// In this example directive '#pragma omp metadirective' has simple 'when' +/// clause with user defined condition. +class OMPWhenClause final : public OMPClause { + friend class OMPClauseReader; + + /// Context selector of the clause; empty (no trait sets) for a 'default' + /// clause. + OMPTraitInfo *TI = nullptr; + /// Kind of the directive variant selected by the clause. + OpenMPDirectiveKind DKind = llvm::omp::OMPD_unknown; + /// The directive variant statement; may be null. + Stmt *Directive = nullptr; + + /// Location of '('. + SourceLocation LParenLoc; + +public: + /// Build a 'when' clause (or, with an empty context selector, a 'default' + /// clause) of a '#pragma omp metadirective'. + /// + /// \param T Trait info holding the context selector of the clause. + /// \param DKind Kind of the directive variant associated with the clause. + /// \param D The directive variant statement associated with the clause. + /// \param StartLoc Starting location of the clause. + /// \param LParenLoc Location of '('. + /// \param EndLoc Ending location of the clause.
+ OMPWhenClause(OMPTraitInfo &T, OpenMPDirectiveKind DKind, Stmt *D, + SourceLocation StartLoc, SourceLocation LParenLoc, + SourceLocation EndLoc) + : OMPClause(llvm::omp::OMPC_when, StartLoc, EndLoc), TI(&T), DKind(DKind), + Directive(D), LParenLoc(LParenLoc) {} + + /// Build an empty clause (presumably filled in later by OMPClauseReader, a + /// friend of this class); members stay null/OMPD_unknown until then. + OMPWhenClause() + : OMPClause(llvm::omp::OMPC_when, SourceLocation(), SourceLocation()) {} + + /// Sets the location of '('. + void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; } + + /// Returns the location of '('. + SourceLocation getLParenLoc() const { return LParenLoc; } + + /// Returns the kind of the directive variant. + OpenMPDirectiveKind getDKind() const { return DKind; } + + /// Returns the directive variant statement (may be null). + Stmt *getDirective() const { return Directive; } + + /// Returns the context selector of the clause. + OMPTraitInfo &getTI() { return *TI; } + const OMPTraitInfo &getTI() const { return *TI; } + + child_range children() { + return child_range(child_iterator(), child_iterator()); + } + + const_child_range children() const { + return const_child_range(const_child_iterator(), const_child_iterator()); + } + child_range used_children() { + return child_range(child_iterator(), child_iterator()); + } + const_child_range used_children() const { + return const_child_range(const_child_iterator(), const_child_iterator()); + } + + static bool classof(const OMPClause *T) { + return T->getClauseKind() == llvm::omp::OMPC_when; + } +}; + /// Clang specific specialization of the OMPContext to lookup target features.
struct TargetOMPContext final : public llvm::omp::OMPContext { diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -2777,6 +2777,9 @@ return TraverseOMPExecutableDirective(S); } +DEF_TRAVERSE_STMT(OMPMetaDirective, + { TRY_TO(TraverseOMPExecutableDirective(S)); }) + DEF_TRAVERSE_STMT(OMPParallelDirective, { TRY_TO(TraverseOMPExecutableDirective(S)); }) @@ -3031,6 +3034,18 @@ return true; } +template +bool RecursiveASTVisitor::VisitOMPWhenClause(OMPWhenClause *C) { + // Traverse the user-condition expressions of the context selector. + // NOTE(review): ScoreOrCondition of other selector kinds and the directive + // variant statement (getDirective()) are not traversed here. + for (const OMPTraitSet &Set : C->getTI().Sets) { + for (const OMPTraitSelector &Selector : Set.Selectors) { + if (Selector.Kind == llvm::omp::TraitSelector::user_condition && + Selector.ScoreOrCondition) + TRY_TO(TraverseStmt(Selector.ScoreOrCondition)); + } + } + return true; +} + template bool RecursiveASTVisitor::VisitOMPDefaultClause(OMPDefaultClause *) { return true; diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h --- a/clang/include/clang/AST/StmtOpenMP.h +++ b/clang/include/clang/AST/StmtOpenMP.h @@ -362,6 +362,44 @@ } }; +/// This represents '#pragma omp metadirective' directive.
+/// +/// \code +/// #pragma omp metadirective when(user={condition(N>10)}: parallel for) +/// \endcode +/// In this example directive '#pragma omp metadirective' has clauses 'when' +/// with a dynamic user condition checking whether 'N > 10'. +/// +class OMPMetaDirective final : public OMPExecutableDirective { + friend class ASTStmtReader; + friend class OMPExecutableDirective; + /// Statement that code generation emits in place of the directive (see + /// CodeGenFunction::EmitOMPMetaDirective); null until setIfStmt() is called. + Stmt *IfStmt = nullptr; + + OMPMetaDirective(SourceLocation StartLoc, SourceLocation EndLoc) + : OMPExecutableDirective(OMPMetaDirectiveClass, + llvm::omp::OMPD_metadirective, StartLoc, + EndLoc) {} + explicit OMPMetaDirective() + : OMPExecutableDirective(OMPMetaDirectiveClass, + llvm::omp::OMPD_metadirective, SourceLocation(), + SourceLocation()) {} + +public: + static OMPMetaDirective *Create(const ASTContext &C, SourceLocation StartLoc, + SourceLocation EndLoc, + ArrayRef Clauses, + Stmt *AssociatedStmt, Stmt *IfStmt); + static OMPMetaDirective *CreateEmpty(const ASTContext &C, unsigned NumClauses, + EmptyShell); + + /// Sets / returns the statement emitted in place of the directive. + void setIfStmt(Stmt *S) { IfStmt = S; } + Stmt *getIfStmt() const { return IfStmt; } + + static bool classof(const Stmt *T) { + return T->getStmtClass() == OMPMetaDirectiveClass; + } +}; + /// This represents '#pragma omp parallel' directive. /// /// \code diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -10393,6 +10393,9 @@ : Note<"jump bypasses OpenMP structured block">; def note_omp_exits_structured_block : Note<"jump exits scope of OpenMP structured block">; +def err_omp_misplaced_default_clause : Error< + "misplaced default clause!
Only one default clause is allowed in " + "metadirective in the end">; } // end of OpenMP category let CategoryName = "Related Result Type Issue" in { diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td --- a/clang/include/clang/Basic/StmtNodes.td +++ b/clang/include/clang/Basic/StmtNodes.td @@ -217,6 +217,7 @@ // OpenMP Directives. def OMPExecutableDirective : StmtNode; +def OMPMetaDirective : StmtNode; def OMPLoopDirective : StmtNode; def OMPParallelDirective : StmtNode; def OMPSimdDirective : StmtNode; diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h --- a/clang/include/clang/Parse/Parser.h +++ b/clang/include/clang/Parse/Parser.h @@ -3164,6 +3164,13 @@ /// \param StmtCtx The context in which we're parsing the directive. StmtResult ParseOpenMPDeclarativeOrExecutableDirective(ParsedStmtContext StmtCtx); + /// Parses clause for metadirective + /// + /// \param DKind Kind of current directive. + /// \param CKind Kind of current clause. + /// \returns the parsed 'when'/'default' clause, or nullptr on error. + /// + OMPClause *ParseOpenMPMetaClause(OpenMPDirectiveKind DKind, + OpenMPClauseKind CKind); /// Parses clause of kind \a CKind for directive of a kind \a Kind. /// /// \param DKind Kind of current directive. diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -10137,6 +10137,11 @@ void ActOnOpenMPLoopInitialization(SourceLocation ForLoc, Stmt *Init); // OpenMP directives and clauses. + /// Called on well-formed '\#pragma omp metadirective' after parsing + /// of the associated statement. + StmtResult ActOnOpenMPMetaDirective(ArrayRef Clauses, + Stmt *AStmt, SourceLocation StartLoc, + SourceLocation EndLoc); /// Called on correct id-expression from the '#pragma omp /// threadprivate'.
ExprResult ActOnOpenMPIdExpression(Scope *CurScope, CXXScopeSpec &ScopeSpec, @@ -10632,6 +10637,12 @@ SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc); + /// Called on well-formed 'when' clause. + /// \p TI holds the clause's context selector (empty for a 'default' + /// clause); \p DKind and \p Directive describe the directive variant, and + /// \p Directive is StmtError() when the clause names no directive. + OMPClause *ActOnOpenMPWhenClause(OMPTraitInfo &TI, OpenMPDirectiveKind DKind, + StmtResult Directive, + SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation EndLoc); /// Called on well-formed 'default' clause. OMPClause *ActOnOpenMPDefaultClause(llvm::omp::DefaultKind Kind, SourceLocation KindLoc, diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h --- a/clang/include/clang/Serialization/ASTBitCodes.h +++ b/clang/include/clang/Serialization/ASTBitCodes.h @@ -1824,21 +1824,21 @@ /// A CXXBoolLiteralExpr record. EXPR_CXX_BOOL_LITERAL, - EXPR_CXX_NULL_PTR_LITERAL, // CXXNullPtrLiteralExpr - EXPR_CXX_TYPEID_EXPR, // CXXTypeidExpr (of expr). - EXPR_CXX_TYPEID_TYPE, // CXXTypeidExpr (of type). - EXPR_CXX_THIS, // CXXThisExpr - EXPR_CXX_THROW, // CXXThrowExpr - EXPR_CXX_DEFAULT_ARG, // CXXDefaultArgExpr - EXPR_CXX_DEFAULT_INIT, // CXXDefaultInitExpr - EXPR_CXX_BIND_TEMPORARY, // CXXBindTemporaryExpr + EXPR_CXX_NULL_PTR_LITERAL, // CXXNullPtrLiteralExpr + EXPR_CXX_TYPEID_EXPR, // CXXTypeidExpr (of expr). + EXPR_CXX_TYPEID_TYPE, // CXXTypeidExpr (of type).
+ EXPR_CXX_THIS, // CXXThisExpr + EXPR_CXX_THROW, // CXXThrowExpr + EXPR_CXX_DEFAULT_ARG, // CXXDefaultArgExpr + EXPR_CXX_DEFAULT_INIT, // CXXDefaultInitExpr + EXPR_CXX_BIND_TEMPORARY, // CXXBindTemporaryExpr EXPR_CXX_SCALAR_VALUE_INIT, // CXXScalarValueInitExpr EXPR_CXX_NEW, // CXXNewExpr EXPR_CXX_DELETE, // CXXDeleteExpr EXPR_CXX_PSEUDO_DESTRUCTOR, // CXXPseudoDestructorExpr - EXPR_EXPR_WITH_CLEANUPS, // ExprWithCleanups + EXPR_EXPR_WITH_CLEANUPS, // ExprWithCleanups EXPR_CXX_DEPENDENT_SCOPE_MEMBER, // CXXDependentScopeMemberExpr EXPR_CXX_DEPENDENT_SCOPE_DECL_REF, // DependentScopeDeclRefExpr @@ -1846,41 +1846,42 @@ EXPR_CXX_UNRESOLVED_MEMBER, // UnresolvedMemberExpr EXPR_CXX_UNRESOLVED_LOOKUP, // UnresolvedLookupExpr - EXPR_CXX_EXPRESSION_TRAIT, // ExpressionTraitExpr - EXPR_CXX_NOEXCEPT, // CXXNoexceptExpr + EXPR_CXX_EXPRESSION_TRAIT, // ExpressionTraitExpr + EXPR_CXX_NOEXCEPT, // CXXNoexceptExpr - EXPR_OPAQUE_VALUE, // OpaqueValueExpr - EXPR_BINARY_CONDITIONAL_OPERATOR, // BinaryConditionalOperator - EXPR_TYPE_TRAIT, // TypeTraitExpr - EXPR_ARRAY_TYPE_TRAIT, // ArrayTypeTraitIntExpr + EXPR_OPAQUE_VALUE, // OpaqueValueExpr + EXPR_BINARY_CONDITIONAL_OPERATOR, // BinaryConditionalOperator + EXPR_TYPE_TRAIT, // TypeTraitExpr + EXPR_ARRAY_TYPE_TRAIT, // ArrayTypeTraitIntExpr - EXPR_PACK_EXPANSION, // PackExpansionExpr - EXPR_SIZEOF_PACK, // SizeOfPackExpr - EXPR_SUBST_NON_TYPE_TEMPLATE_PARM, // SubstNonTypeTemplateParmExpr - EXPR_SUBST_NON_TYPE_TEMPLATE_PARM_PACK,// SubstNonTypeTemplateParmPackExpr - EXPR_FUNCTION_PARM_PACK, // FunctionParmPackExpr - EXPR_MATERIALIZE_TEMPORARY, // MaterializeTemporaryExpr - EXPR_CXX_FOLD, // CXXFoldExpr - EXPR_CONCEPT_SPECIALIZATION,// ConceptSpecializationExpr - EXPR_REQUIRES, // RequiresExpr + EXPR_PACK_EXPANSION, // PackExpansionExpr + EXPR_SIZEOF_PACK, // SizeOfPackExpr + EXPR_SUBST_NON_TYPE_TEMPLATE_PARM, // SubstNonTypeTemplateParmExpr + EXPR_SUBST_NON_TYPE_TEMPLATE_PARM_PACK, // SubstNonTypeTemplateParmPackExpr +
EXPR_FUNCTION_PARM_PACK, // FunctionParmPackExpr + EXPR_MATERIALIZE_TEMPORARY, // MaterializeTemporaryExpr + EXPR_CXX_FOLD, // CXXFoldExpr + EXPR_CONCEPT_SPECIALIZATION, // ConceptSpecializationExpr + EXPR_REQUIRES, // RequiresExpr // CUDA - EXPR_CUDA_KERNEL_CALL, // CUDAKernelCallExpr + EXPR_CUDA_KERNEL_CALL, // CUDAKernelCallExpr // OpenCL - EXPR_ASTYPE, // AsTypeExpr + EXPR_ASTYPE, // AsTypeExpr // Microsoft - EXPR_CXX_PROPERTY_REF_EXPR, // MSPropertyRefExpr + EXPR_CXX_PROPERTY_REF_EXPR, // MSPropertyRefExpr EXPR_CXX_PROPERTY_SUBSCRIPT_EXPR, // MSPropertySubscriptExpr - EXPR_CXX_UUIDOF_EXPR, // CXXUuidofExpr (of expr). - EXPR_CXX_UUIDOF_TYPE, // CXXUuidofExpr (of type). - STMT_SEH_LEAVE, // SEHLeaveStmt - STMT_SEH_EXCEPT, // SEHExceptStmt - STMT_SEH_FINALLY, // SEHFinallyStmt - STMT_SEH_TRY, // SEHTryStmt + EXPR_CXX_UUIDOF_EXPR, // CXXUuidofExpr (of expr). + EXPR_CXX_UUIDOF_TYPE, // CXXUuidofExpr (of type). + STMT_SEH_LEAVE, // SEHLeaveStmt + STMT_SEH_EXCEPT, // SEHExceptStmt + STMT_SEH_FINALLY, // SEHFinallyStmt + STMT_SEH_TRY, // SEHTryStmt // OpenMP directives + STMT_OMP_META_DIRECTIVE, // OMPMetaDirective STMT_OMP_PARALLEL_DIRECTIVE, STMT_OMP_SIMD_DIRECTIVE, STMT_OMP_FOR_DIRECTIVE, @@ -1940,10 +1941,10 @@ EXPR_OMP_ITERATOR, // ARC - EXPR_OBJC_BRIDGED_CAST, // ObjCBridgedCastExpr + EXPR_OBJC_BRIDGED_CAST, // ObjCBridgedCastExpr - STMT_MS_DEPENDENT_EXISTS, // MSDependentExistsStmt - EXPR_LAMBDA, // LambdaExpr + STMT_MS_DEPENDENT_EXISTS, // MSDependentExistsStmt + EXPR_LAMBDA, // LambdaExpr STMT_COROUTINE_BODY, STMT_CORETURN, EXPR_COAWAIT, diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp --- a/clang/lib/AST/OpenMPClause.cpp +++ b/clang/lib/AST/OpenMPClause.cpp @@ -156,6 +156,7 @@ case OMPC_exclusive: case OMPC_uses_allocators: case OMPC_affinity: + case OMPC_when: break; default: break; @@ -250,6 +251,7 @@ case OMPC_exclusive: case OMPC_uses_allocators: case OMPC_affinity: + case OMPC_when: break; default: break; @@ -1499,6 +1501,85 @@ // OpenMP clauses
printing methods //===----------------------------------------------------------------------===// +void OMPClausePrinter::VisitOMPWhenClause(OMPWhenClause *Node) { + // A clause without trait sets is the metadirective's 'default' clause. + // Both forms leave their parenthesis open: StmtPrinter appends the + // directive variant and the closing ')'. + if (Node->getTI().Sets.empty()) { + OS << "default("; + return; + } + OS << "when("; + // Reuse the canonical context-selector printer so that multiple trait + // sets, multiple selectors per set, multiple properties, scores and every + // selector kind are spelled and separated correctly (the previous + // hand-rolled switch concatenated them without separators and dropped + // most selector kinds). + Node->getTI().print(OS, Policy); + OS << ": "; +} + void OMPClausePrinter::VisitOMPIfClause(OMPIfClause *Node) { OS << "if("; if (Node->getNameModifier() != OMPD_unknown)
diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp --- a/clang/lib/AST/StmtOpenMP.cpp +++ b/clang/lib/AST/StmtOpenMP.cpp @@ -191,6 +191,25 @@ llvm::copy(A, getFinalsConditions().begin()); } +OMPMetaDirective *OMPMetaDirective::Create(const ASTContext &C, + SourceLocation StartLoc, + SourceLocation EndLoc, + ArrayRef Clauses, + Stmt *AssociatedStmt, Stmt *IfStmt) { + auto *Dir = createDirective( + C, Clauses, AssociatedStmt, /*NumChildren=*/1, StartLoc, EndLoc); + Dir->setIfStmt(IfStmt); + return Dir; +} + +OMPMetaDirective *OMPMetaDirective::CreateEmpty(const ASTContext &C, + unsigned NumClauses, + EmptyShell) { + return createEmptyDirective(C, NumClauses, + /*HasAssociatedStmt=*/true, + /*NumChildren=*/1); +} + OMPParallelDirective *OMPParallelDirective::Create( const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, ArrayRef Clauses, Stmt *AssociatedStmt, Expr *TaskRedRef, diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp --- a/clang/lib/AST/StmtPrinter.cpp +++ b/clang/lib/AST/StmtPrinter.cpp @@ -644,12 +644,25 @@ if (Clause && !Clause->isImplicit()) { OS << ' '; Printer.Visit(Clause); + if (isa<OMPMetaDirective>(S)) { + // 'when'/'default' clauses are printed with their parenthesis left + // open; emit the directive variant (if any) and close it here. + // NOTE(review): only the variant's directive name is printed, not its + // own clauses. + if (auto *When = dyn_cast<OMPWhenClause>(Clause)) { + if (When->getDKind() != llvm::omp::OMPD_unknown) + OS << getOpenMPDirectiveName(When->getDKind()); + OS << ")"; + } + } } OS << NL; if (!ForceNoStmt && S->hasAssociatedStmt()) PrintStmt(S->getRawStmt()); } +void StmtPrinter::VisitOMPMetaDirective(OMPMetaDirective *Node) { + Indent() << "#pragma omp metadirective"; + PrintOMPExecutableDirective(Node); +} + void StmtPrinter::VisitOMPParallelDirective(OMPParallelDirective *Node) { Indent() << "#pragma omp parallel";
PrintOMPExecutableDirective(Node); diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp --- a/clang/lib/AST/StmtProfile.cpp +++ b/clang/lib/AST/StmtProfile.cpp @@ -476,6 +476,11 @@ Profiler->VisitStmt(Evt); } +void OMPClauseProfiler::VisitOMPWhenClause(const OMPWhenClause *C) { + // NOTE(review): only the directive variant is profiled; the context + // selector (e.g. a user condition expression) does not contribute to the + // profile. Confirm this is intended. + if (C->getDirective()) + Profiler->VisitStmt(C->getDirective()); +} + void OMPClauseProfiler::VisitOMPDefaultClause(const OMPDefaultClause *C) { } void OMPClauseProfiler::VisitOMPProcBindClause(const OMPProcBindClause *C) { } @@ -847,6 +852,10 @@ P.Visit(*I); } +void StmtProfiler::VisitOMPMetaDirective(const OMPMetaDirective *S) { + // NOTE(review): IfStmt is a separate member (see setIfStmt()) and does not + // appear to be profiled here. Confirm. + VisitOMPExecutableDirective(S); +} + void StmtProfiler::VisitOMPLoopDirective(const OMPLoopDirective *S) { VisitOMPExecutableDirective(S); } diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp --- a/clang/lib/Basic/OpenMPKinds.cpp +++ b/clang/lib/Basic/OpenMPKinds.cpp @@ -180,6 +180,7 @@ case OMPC_exclusive: case OMPC_uses_allocators: case OMPC_affinity: + case OMPC_when: break; default: break; @@ -420,6 +421,7 @@ case OMPC_exclusive: case OMPC_uses_allocators: case OMPC_affinity: + case OMPC_when: break; default: break; @@ -578,6 +580,9 @@ OpenMPDirectiveKind DKind) { assert(unsigned(DKind) < llvm::omp::Directive_enumSize); switch (DKind) { + case OMPD_metadirective: + // NOTE(review): a metadirective records its own kind as its only capture + // region. Confirm users of CaptureRegions handle OMPD_metadirective. + CaptureRegions.push_back(OMPD_metadirective); + break; case OMPD_parallel: case OMPD_parallel_for: case OMPD_parallel_for_simd: diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp --- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp @@ -6950,6 +6950,7 @@ case OMPD_parallel_master_taskloop: case OMPD_parallel_master_taskloop_simd: case OMPD_requires: + case OMPD_metadirective: case OMPD_unknown: break; default: @@ -8983,6 +8984,7 @@ case OMPD_parallel_master_taskloop: case OMPD_parallel_master_taskloop_simd: case OMPD_requires: + case OMPD_metadirective: case OMPD_unknown:
default: llvm_unreachable("Unexpected directive."); @@ -9671,6 +9673,12 @@ if (!S) return; + // For a metadirective, scan the statement stored as its IfStmt (the same + // statement EmitOMPMetaDirective emits) instead of the directive itself. + if (isa(S)) { + const auto &M = *cast(S); + scanForTargetRegionsFunctions(M.getIfStmt(), ParentName); + return; + } + // Codegen OMP target directives that offload compute to the device. bool RequiresDeviceCodegen = isa(S) && @@ -10451,6 +10459,7 @@ case OMPD_target_parallel_for: case OMPD_target_parallel_for_simd: case OMPD_requires: + case OMPD_metadirective: case OMPD_unknown: default: llvm_unreachable("Unexpected standalone target data directive."); diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp --- a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp @@ -801,6 +801,7 @@ case OMPD_parallel_master_taskloop_simd: case OMPD_requires: case OMPD_unknown: + case OMPD_metadirective: default: llvm_unreachable("Unexpected directive."); } @@ -881,6 +882,7 @@ case OMPD_parallel_master_taskloop: case OMPD_parallel_master_taskloop_simd: case OMPD_requires: + case OMPD_metadirective: case OMPD_unknown: default: break; @@ -1056,6 +1058,7 @@ case OMPD_parallel_master_taskloop_simd: case OMPD_requires: case OMPD_unknown: + case OMPD_metadirective: default: llvm_unreachable("Unexpected directive."); } @@ -1143,6 +1146,7 @@ case OMPD_parallel_master_taskloop_simd: case OMPD_requires: case OMPD_unknown: + case OMPD_metadirective: default: break; } diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp --- a/clang/lib/CodeGen/CGStmt.cpp +++ b/clang/lib/CodeGen/CGStmt.cpp @@ -193,6 +193,9 @@ case Stmt::SEHTryStmtClass: EmitSEHTryStmt(cast(*S)); break; + case Stmt::OMPMetaDirectiveClass: + EmitOMPMetaDirective(cast(*S)); + break; case Stmt::OMPParallelDirectiveClass: EmitOMPParallelDirective(cast(*S)); break; diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp --- a/clang/lib/CodeGen/CGStmtOpenMP.cpp +++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp @@ -5373,6
+5373,7 @@ case OMPC_exclusive: case OMPC_uses_allocators: case OMPC_affinity: + case OMPC_when: default: llvm_unreachable("Clause is not allowed in 'omp atomic'."); } @@ -6561,6 +6562,10 @@ CGM.getOpenMPRuntime().emitMasterRegion(*this, CodeGen, S.getBeginLoc()); } +void CodeGenFunction::EmitOMPMetaDirective(const OMPMetaDirective &S) { + // The metadirective is lowered solely through the statement stored as its + // IfStmt; the directive's associated statement is not emitted here. + EmitStmt(S.getIfStmt()); +} + void CodeGenFunction::EmitOMPParallelMasterTaskLoopDirective( const OMPParallelMasterTaskLoopDirective &S) { auto &&CodeGen = [this, &S](CodeGenFunction &CGF, PrePostActionTy &Action) { diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -3349,6 +3349,7 @@ const RegionCodeGenTy &BodyGen, OMPTargetDataInfo &InputInfo); + void EmitOMPMetaDirective(const OMPMetaDirective &S); void EmitOMPParallelDirective(const OMPParallelDirective &S); void EmitOMPSimdDirective(const OMPSimdDirective &S); void EmitOMPForDirective(const OMPForDirective &S); diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp --- a/clang/lib/Parse/ParseOpenMP.cpp +++ b/clang/lib/Parse/ParseOpenMP.cpp @@ -2017,6 +2017,7 @@ case OMPD_target_teams_distribute_parallel_for: case OMPD_target_teams_distribute_parallel_for_simd: case OMPD_target_teams_distribute_simd: + case OMPD_metadirective: Diag(Tok, diag::err_omp_unexpected_directive) << 1 << getOpenMPDirectiveName(DKind); break; @@ -2089,6 +2090,53 @@ bool HasAssociatedStatement = true; switch (DKind) { + case OMPD_metadirective: { + ConsumeToken(); + + ParseScope OMPDirectiveScope(this, ScopeFlags); + Actions.StartOpenMPDSABlock(DKind, DirName, Actions.getCurScope(), Loc); + + while (Tok.isNot(tok::annot_pragma_openmp_end)) { + OpenMPClauseKind CKind = Tok.isAnnotation() + ?
OMPC_unknown + : getOpenMPClauseKind(PP.getSpelling(Tok)); + Actions.StartOpenMPClause(CKind); + OMPClause *Clause = ParseOpenMPMetaClause(DKind, CKind); + FirstClauses[(unsigned)CKind].setInt(true); + if (Clause) { + FirstClauses[(unsigned)CKind].setPointer(Clause); + Clauses.push_back(Clause); + } + + // Skip ',' if any. + if (Tok.is(tok::comma)) + ConsumeToken(); + Actions.EndOpenMPClause(); + // Consume trailing ')' if any + if (Tok.is(tok::r_paren)) + ConsumeAnyToken(); + } + // End location of the directive. + EndLoc = Tok.getLocation(); + // Consume final annot_pragma_openmp_end. + ConsumeAnnotationToken(); + + // The body is a block scope like in Lambdas and Blocks. + Actions.ActOnOpenMPRegionStart(DKind, getCurScope()); + ParsingOpenMPDirectiveRAII NormalScope(*this, /*Value=*/false); + StmtResult AStmt = ParseStatement(); + StmtResult AssociatedStmt = (Sema::CompoundScopeRAII(Actions), AStmt); + AssociatedStmt = Actions.ActOnOpenMPRegionEnd(AssociatedStmt, Clauses); + + Directive = Actions.ActOnOpenMPExecutableDirective( + DKind, DirName, CancelRegion, Clauses, AssociatedStmt.get(), Loc, + EndLoc); + + // Exit scope. + Actions.EndOpenMPDSABlock(Directive.get()); + OMPDirectiveScope.Exit(); + break; + } case OMPD_threadprivate: { // FIXME: Should this be permitted in C++? if ((StmtCtx & ParsedStmtContext::AllowDeclarationsInC) == @@ -2488,6 +2536,145 @@ T.getCloseLocation(), Data); } +/// Parses one clause of '#pragma omp metadirective': +/// 'when(' context-selector ':' [directive-variant] ')' or +/// 'default(' [directive-variant] ')'. +/// The statement following the pragma is parsed tentatively (and the token +/// stream rewound) to build the directive variant. Any other clause is +/// diagnosed and skipped up to the next ',' or the end of the pragma, so the +/// caller's clause loop always makes progress. +/// \returns the built clause, or nullptr on error. +OMPClause *Parser::ParseOpenMPMetaClause(OpenMPDirectiveKind DKind, + OpenMPClauseKind CKind) { + OMPClause *Clause = nullptr; + bool ErrorFound = false; + SmallVector, + llvm::omp::Clause_enumSize + 1> + FirstClauses(llvm::omp::Clause_enumSize + 1); + + // Check if it is called from metadirective. + if (DKind != OMPD_metadirective) { + Diag(Tok, diag::err_omp_unexpected_clause) + << getOpenMPClauseName(CKind) << getOpenMPDirectiveName(DKind); + ErrorFound = true; + } + + // Check if clause is allowed for the given directive.
+ if (CKind != OMPC_unknown && + !isAllowedClauseForDirective(DKind, CKind, getLangOpts().OpenMP)) { + Diag(Tok, diag::err_omp_unexpected_clause) + << getOpenMPClauseName(CKind) << getOpenMPDirectiveName(DKind); + ErrorFound = true; + } + + if (CKind == OMPC_default || CKind == OMPC_when) { + SourceLocation Loc = ConsumeToken(); + // Parse '('. + BalancedDelimiterTracker T(*this, tok::l_paren, + tok::annot_pragma_openmp_end); + if (T.expectAndConsume(diag::err_expected_lparen_after, + getOpenMPClauseName(CKind).data())) + return nullptr; + + OMPTraitInfo &TI = Actions.getASTContext().getNewOMPTraitInfo(); + if (CKind == OMPC_when) { + // Parse the context selector that decides when this variant applies. + parseOMPContextSelectors(Loc, TI); + + // Parse ':' + if (Tok.is(tok::colon)) { + ConsumeAnyToken(); + } else { + Diag(Tok, diag::warn_pragma_expected_colon) << "when clause"; + // Skip the rest of the clause, including its ')', so the caller + // resumes at the next clause. + T.skipToEnd(); + return nullptr; + } + } + + // Parse Directive + OpenMPDirectiveKind DirKind = OMPD_unknown; + SmallVector Clauses; + StmtResult AssociatedStmt; + StmtResult Directive = StmtError(); + + // The directive variant is optional ('when(...: )', 'default()'). + if (Tok.isNot(tok::r_paren) && Tok.isNot(tok::annot_pragma_openmp_end)) { + DirKind = parseOpenMPDirectiveKind(*this); + if (DirKind == OMPD_unknown) { + // Do not hand OMPD_unknown to the Sema directive entry points below; + // diagnose it and skip to the closing ')'. + Diag(Tok, diag::err_omp_unknown_directive); + ErrorFound = true; + SkipUntil(tok::r_paren, tok::annot_pragma_openmp_end, StopBeforeMatch); + } else { + ParsingOpenMPDirectiveRAII DirScope(*this); + ParenBraceBracketBalancer BalancerRAIIObj(*this); + DeclarationNameInfo DirName; + unsigned ScopeFlags = Scope::FnScope | Scope::DeclScope | + Scope::CompoundStmtScope | + Scope::OpenMPDirectiveScope; + + ConsumeToken(); + ParseScope OMPDirectiveScope(this, ScopeFlags); + Actions.StartOpenMPDSABlock(DirKind, DirName, Actions.getCurScope(), Loc); + + // Also stop at the end of the pragma so a missing ')' cannot run this + // loop past it. + int ParenDepth = 0; + while ((Tok.isNot(tok::r_paren) || ParenDepth != 0) && + Tok.isNot(tok::annot_pragma_openmp_end)) { + if (Tok.is(tok::l_paren)) + ++ParenDepth; + if (Tok.is(tok::r_paren)) + --ParenDepth; + + OpenMPClauseKind ClauseKind = Tok.isAnnotation() + ?
OMPC_unknown + : getOpenMPClauseKind(PP.getSpelling(Tok)); + Actions.StartOpenMPClause(ClauseKind); + + OMPClause *NestedClause = ParseOpenMPClause( + DirKind, ClauseKind, !FirstClauses[(unsigned)ClauseKind].getInt()); + FirstClauses[(unsigned)ClauseKind].setInt(true); + if (NestedClause) { + FirstClauses[(unsigned)ClauseKind].setPointer(NestedClause); + Clauses.push_back(NestedClause); + } + + // Skip ',' if any. + if (Tok.is(tok::comma)) + ConsumeToken(); + Actions.EndOpenMPClause(); + } + + Actions.ActOnOpenMPRegionStart(DirKind, getCurScope()); + ParsingOpenMPDirectiveRAII NormalScope(*this, /*Value=*/false); + + // Parse the statement following the pragma to build the associated + // statement of the directive variant, then rewind so the caller parses + // the same tokens again. + TentativeParsingAction TPA(*this); + while (Tok.isNot(tok::annot_pragma_openmp_end)) { + ConsumeAnyToken(); + } + ConsumeAnnotationToken(); + ParseScope InnerStmtScope(this, Scope::DeclScope, + getLangOpts().C99 || getLangOpts().CPlusPlus, + Tok.is(tok::l_brace)); + StmtResult AStmt = ParseStatement(); + InnerStmtScope.Exit(); + TPA.Revert(); + + AssociatedStmt = (Sema::CompoundScopeRAII(Actions), AStmt); + AssociatedStmt = Actions.ActOnOpenMPRegionEnd(AssociatedStmt, Clauses); + + Directive = Actions.ActOnOpenMPExecutableDirective( + DirKind, DirName, OMPD_unknown, llvm::makeArrayRef(Clauses), + AssociatedStmt.get(), Loc, Tok.getLocation()); + + Actions.EndOpenMPDSABlock(Directive.get()); + OMPDirectiveScope.Exit(); + } + } + + // Parse ')' + T.consumeClose(); + + if (ErrorFound) + return nullptr; + + Clause = Actions.ActOnOpenMPWhenClause(TI, DirKind, Directive, Loc, + T.getOpenLocation(), + T.getCloseLocation()); + } else { + // Not a 'when'/'default' clause: diagnose it unless a check above already + // did, then skip to the next ',' or the end of the pragma. + if (!ErrorFound) + Diag(Tok, diag::err_omp_unexpected_clause) + << getOpenMPClauseName(CKind) << getOpenMPDirectiveName(DKind); + ErrorFound = true; + SkipUntil(tok::comma, tok::annot_pragma_openmp_end, StopBeforeMatch); + } + + return ErrorFound ? nullptr : Clause; +} + /// Parsing of OpenMP clauses.
/// /// clause: @@ -2691,6 +2878,7 @@ case OMPC_threadprivate: case OMPC_uniform: case OMPC_match: + case OMPC_when: if (!WrongDirective) Diag(Tok, diag::err_omp_unexpected_clause) << getOpenMPClauseName(CKind) << getOpenMPDirectiveName(DKind); diff --git a/clang/lib/Sema/SemaExceptionSpec.cpp b/clang/lib/Sema/SemaExceptionSpec.cpp --- a/clang/lib/Sema/SemaExceptionSpec.cpp +++ b/clang/lib/Sema/SemaExceptionSpec.cpp @@ -1486,6 +1486,7 @@ case Stmt::OMPTeamsDistributeParallelForDirectiveClass: case Stmt::OMPTeamsDistributeParallelForSimdDirectiveClass: case Stmt::OMPTeamsDistributeSimdDirectiveClass: + case Stmt::OMPMetaDirectiveClass: case Stmt::ReturnStmtClass: case Stmt::SEHExceptStmtClass: case Stmt::SEHFinallyStmtClass: diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -3744,6 +3744,7 @@ void Sema::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind, Scope *CurScope) { switch (DKind) { + case OMPD_metadirective: case OMPD_parallel: case OMPD_parallel_for: case OMPD_parallel_for_simd: @@ -5127,6 +5128,11 @@ llvm::SmallVector AllowedNameModifiers; switch (Kind) { + case OMPD_metadirective: + Res = + ActOnOpenMPMetaDirective(ClausesWithImplicit, AStmt, StartLoc, EndLoc); + AllowedNameModifiers.push_back(OMPD_metadirective); + break; case OMPD_parallel: Res = ActOnOpenMPParallelDirective(ClausesWithImplicit, AStmt, StartLoc, EndLoc); @@ -5546,6 +5552,7 @@ case OMPC_atomic_default_mem_order: case OMPC_device_type: case OMPC_match: + case OMPC_when: default: llvm_unreachable("Unexpected clause"); } @@ -6307,6 +6314,117 @@ FD->addAttr(NewAttr); } +StmtResult Sema::ActOnOpenMPMetaDirective(ArrayRef Clauses, + Stmt *AStmt, SourceLocation StartLoc, + SourceLocation EndLoc) { + if (!AStmt) + return StmtError(); + + auto *CS = cast(AStmt); + // 1.2.2 OpenMP Language Terminology + // Structured block - An executable statement with a single entry at the + // top and a single exit 
at the bottom. + // The point of exit cannot be a branch out of the structured block. + // longjmp() and throw() must not violate the entry/exit criteria. + CS->getCapturedDecl()->setNothrow(); + + StmtResult IfStmt = StmtError(); + Stmt *ElseStmt = NULL; + + for (auto i = Clauses.rbegin(); i < Clauses.rend(); i++) { + OMPWhenClause *WhenClause = dyn_cast(*i); + Expr *WhenCondExpr = NULL; + Stmt *ThenStmt = NULL; + OpenMPDirectiveKind DKind = WhenClause->getDKind(); + + if (DKind != OMPD_unknown) + ThenStmt = CompoundStmt::Create(Context, {WhenClause->getDirective()}, + SourceLocation(), SourceLocation()); + + for (const OMPTraitSet &Set : WhenClause->getTI().Sets) { + for (const OMPTraitSelector &Selector : Set.Selectors) { + switch (Selector.Kind) { + case TraitSelector::device_arch: { + bool archMatch = false; + for (const OMPTraitProperty &Property : Selector.Properties) { + for (auto &T : getLangOpts().OMPTargetTriples) { + if (T.getArchName() == Property.RawString) { + archMatch = true; + break; + } + } + if (archMatch) + break; + } + // Create a true/false boolean expression and assign to WhenCondExpr + auto *C = new (Context) + CXXBoolLiteralExpr(archMatch, Context.BoolTy, StartLoc); + WhenCondExpr = dyn_cast(C); + break; + } + case TraitSelector::user_condition: { + assert(Selector.ScoreOrCondition && + "Ill-formed user condition, expected condition expression!"); + + WhenCondExpr = Selector.ScoreOrCondition; + break; + } + case TraitSelector::implementation_vendor: { + bool vendorMatch = false; + for (const OMPTraitProperty &Property : Selector.Properties) { + for (auto &T : getLangOpts().OMPTargetTriples) { + if (T.getVendorName() == Property.RawString) { + vendorMatch = true; + break; + } + } + if (vendorMatch) + break; + } + // Create a true/false boolean expression and assign to WhenCondExpr + auto *C = new (Context) + CXXBoolLiteralExpr(vendorMatch, Context.BoolTy, StartLoc); + WhenCondExpr = dyn_cast(C); + break; + } + case 
TraitSelector::device_isa: + case TraitSelector::device_kind: + case TraitSelector::implementation_extension: + default: + break; + } + } + } + + if (WhenCondExpr == NULL) { + if (ElseStmt != NULL) { + Diag(WhenClause->getBeginLoc(), diag::err_omp_misplaced_default_clause); + return StmtError(); + } + if (DKind == OMPD_unknown) + ElseStmt = CompoundStmt::Create(Context, {CS->getCapturedStmt()}, + SourceLocation(), SourceLocation()); + else + ElseStmt = ThenStmt; + continue; + } + + if (ThenStmt == NULL) + ThenStmt = CompoundStmt::Create(Context, {CS->getCapturedStmt()}, + SourceLocation(), SourceLocation()); + + IfStmt = + ActOnIfStmt(SourceLocation(), false, SourceLocation(), NULL, + ActOnCondition(getCurScope(), SourceLocation(), + WhenCondExpr, Sema::ConditionKind::Boolean), + SourceLocation(), ThenStmt, SourceLocation(), ElseStmt); + ElseStmt = IfStmt.get(); + } + + return OMPMetaDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt, + IfStmt.get()); +} + StmtResult Sema::ActOnOpenMPParallelDirective(ArrayRef Clauses, Stmt *AStmt, SourceLocation StartLoc, @@ -11804,6 +11922,7 @@ case OMPC_exclusive: case OMPC_uses_allocators: case OMPC_affinity: + case OMPC_when: default: llvm_unreachable("Clause is not allowed."); } @@ -11957,6 +12076,7 @@ case OMPD_atomic: case OMPD_teams_distribute: case OMPD_requires: + case OMPD_metadirective: llvm_unreachable("Unexpected OpenMP directive with if-clause"); case OMPD_unknown: default: @@ -12036,6 +12156,7 @@ case OMPD_teams_distribute: case OMPD_teams_distribute_simd: case OMPD_requires: + case OMPD_metadirective: llvm_unreachable("Unexpected OpenMP directive with num_threads-clause"); case OMPD_unknown: default: @@ -12113,6 +12234,7 @@ case OMPD_atomic: case OMPD_distribute_simd: case OMPD_requires: + case OMPD_metadirective: llvm_unreachable("Unexpected OpenMP directive with num_teams-clause"); case OMPD_unknown: default: @@ -12190,6 +12312,7 @@ case OMPD_atomic: case OMPD_distribute_simd: case OMPD_requires: + case 
OMPD_metadirective: llvm_unreachable("Unexpected OpenMP directive with thread_limit-clause"); case OMPD_unknown: default: @@ -12267,6 +12390,7 @@ case OMPD_distribute_simd: case OMPD_target_teams: case OMPD_requires: + case OMPD_metadirective: llvm_unreachable("Unexpected OpenMP directive with schedule clause"); case OMPD_unknown: default: @@ -12344,6 +12468,7 @@ case OMPD_atomic: case OMPD_target_teams: case OMPD_requires: + case OMPD_metadirective: llvm_unreachable("Unexpected OpenMP directive with schedule clause"); case OMPD_unknown: default: @@ -12421,6 +12546,7 @@ case OMPD_atomic: case OMPD_distribute_simd: case OMPD_requires: + case OMPD_metadirective: llvm_unreachable("Unexpected OpenMP directive with num_teams-clause"); case OMPD_unknown: default: @@ -12500,12 +12626,22 @@ case OMPD_atomic: case OMPD_distribute_simd: case OMPD_requires: + case OMPD_metadirective: llvm_unreachable("Unexpected OpenMP directive with grainsize-clause"); case OMPD_unknown: default: llvm_unreachable("Unknown OpenMP directive"); } break; + case OMPC_when: + if (DKind == OMPD_metadirective) { + CaptureRegion = OMPD_metadirective; + } else if (DKind == OMPD_unknown) { + llvm_unreachable("Unknown OpenMP directive"); + } else { + llvm_unreachable("Unexpected OpenMP directive with when clause"); + } + break; case OMPC_firstprivate: case OMPC_lastprivate: case OMPC_reduction: @@ -13012,6 +13148,7 @@ case OMPC_exclusive: case OMPC_uses_allocators: case OMPC_affinity: + case OMPC_when: default: llvm_unreachable("Clause is not allowed."); } @@ -13039,6 +13176,14 @@ return std::string(Out.str()); } +OMPClause * +Sema::ActOnOpenMPWhenClause(OMPTraitInfo &TI, OpenMPDirectiveKind DKind, + StmtResult Directive, SourceLocation StartLoc, + SourceLocation LParenLoc, SourceLocation EndLoc) { + return new (Context) + OMPWhenClause(TI, DKind, Directive.get(), StartLoc, LParenLoc, EndLoc); +} + OMPClause *Sema::ActOnOpenMPDefaultClause(DefaultKind Kind, SourceLocation KindKwLoc, SourceLocation 
StartLoc, @@ -13251,6 +13396,7 @@ case OMPC_exclusive: case OMPC_uses_allocators: case OMPC_affinity: + case OMPC_when: default: llvm_unreachable("Clause is not allowed."); } @@ -13489,6 +13635,7 @@ case OMPC_exclusive: case OMPC_uses_allocators: case OMPC_affinity: + case OMPC_when: default: llvm_unreachable("Clause is not allowed."); } @@ -13767,6 +13914,7 @@ case OMPC_destroy: case OMPC_detach: case OMPC_uses_allocators: + case OMPC_when: default: llvm_unreachable("Clause is not allowed."); } diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -1638,6 +1638,18 @@ EndLoc); } + /// Build a new OpenMP 'when' clause. + /// + /// By default, performs semantic analysis to build the new OpenMP clause. + /// Subclasses may override this routine to provide different behavior. + OMPClause *RebuildOMPWhenClause(OMPTraitInfo &TI, OpenMPDirectiveKind DKind, + Stmt *Directive, SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation EndLoc) { + return getSema().ActOnOpenMPWhenClause(TI, DKind, Directive, StartLoc, + LParenLoc, EndLoc); + } + /// Build a new OpenMP 'default' clause. /// /// By default, performs semantic analysis to build the new OpenMP clause. 
@@ -8376,6 +8388,17 @@
       AssociatedStmt.get(), D->getBeginLoc(), D->getEndLoc());
 }
 
+template <typename Derived>
+StmtResult
+TreeTransform<Derived>::TransformOMPMetaDirective(OMPMetaDirective *D) {
+  DeclarationNameInfo DirName;
+  getDerived().getSema().StartOpenMPDSABlock(OMPD_metadirective, DirName,
+                                             nullptr, D->getBeginLoc());
+  StmtResult Res = getDerived().TransformOMPExecutableDirective(D);
+  getDerived().getSema().EndOpenMPDSABlock(Res.get());
+  return Res;
+}
+
 template <typename Derived>
 StmtResult
 TreeTransform<Derived>::TransformOMPParallelDirective(OMPParallelDirective *D) {
@@ -9050,6 +9073,13 @@
       E.get(), C->getBeginLoc(), C->getLParenLoc(), C->getEndLoc());
 }
 
+template <typename Derived>
+OMPClause *TreeTransform<Derived>::TransformOMPWhenClause(OMPWhenClause *C) {
+  return getDerived().RebuildOMPWhenClause(C->getTI(), C->getDKind(),
+                                           C->getDirective(), C->getBeginLoc(),
+                                           C->getLParenLoc(), C->getEndLoc());
+}
+
 template <typename Derived>
 OMPClause *
 TreeTransform<Derived>::TransformOMPDefaultClause(OMPDefaultClause *C) {
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -11753,6 +11753,9 @@
 OMPClause *OMPClauseReader::readClause() {
   OMPClause *C = nullptr;
   switch (llvm::omp::Clause(Record.readInt())) {
+  case llvm::omp::OMPC_when:
+    C = new (Context) OMPWhenClause();
+    break;
   case llvm::omp::OMPC_if:
     C = new (Context) OMPIfClause();
     break;
@@ -12069,6 +12072,14 @@
   C->setLParenLoc(Record.readSourceLocation());
 }
 
+void OMPClauseReader::VisitOMPWhenClause(OMPWhenClause *C) {
+  // Keep in sync with OMPClauseWriter::VisitOMPWhenClause.
+  C->TI = Record.readOMPTraitInfo();
+  C->DKind = static_cast<OpenMPDirectiveKind>(Record.readInt());
+  C->Directive = Record.readSubStmt();
+  C->setLParenLoc(Record.readSourceLocation());
+}
+
 void OMPClauseReader::VisitOMPDefaultClause(OMPDefaultClause *C) {
   C->setDefaultKind(static_cast<llvm::omp::DefaultKind>(Record.readInt()));
   C->setLParenLoc(Record.readSourceLocation());
diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp
--- a/clang/lib/Serialization/ASTReaderStmt.cpp
+++ b/clang/lib/Serialization/ASTReaderStmt.cpp
@@ -2280,6 +2280,13 @@
   VisitOMPExecutableDirective(D);
 }
 
+void ASTStmtReader::VisitOMPMetaDirective(OMPMetaDirective *D) {
+  VisitStmt(D);
+  // The NumClauses field was read in ReadStmtFromStream.
+  Record.skipInts(1);
+  VisitOMPExecutableDirective(D);
+}
+
 void ASTStmtReader::VisitOMPParallelDirective(OMPParallelDirective *D) {
   VisitStmt(D);
   VisitOMPExecutableDirective(D);
@@ -3120,6 +3127,11 @@
                                                 nullptr);
       break;
 
+    case STMT_OMP_META_DIRECTIVE:
+      S = OMPMetaDirective::CreateEmpty(
+          Context, Record[ASTStmtReader::NumStmtFields], Empty);
+      break;
+
     case STMT_OMP_PARALLEL_DIRECTIVE:
       S =
         OMPParallelDirective::CreateEmpty(Context,
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -6192,6 +6192,16 @@
   Record.AddSourceLocation(C->getLParenLoc());
 }
 
+void OMPClauseWriter::VisitOMPWhenClause(OMPWhenClause *C) {
+  // Keep in sync with OMPClauseReader::VisitOMPWhenClause: the context
+  // selector, the variant kind and the variant directive are all part of the
+  // clause and are needed again after the AST is deserialized.
+  Record.writeOMPTraitInfo(&C->getTI());
+  Record.push_back(unsigned(C->getDKind()));
+  Record.AddStmt(C->getDirective());
+  Record.AddSourceLocation(C->getLParenLoc());
+}
+
 void OMPClauseWriter::VisitOMPDefaultClause(OMPDefaultClause *C) {
   Record.push_back(unsigned(C->getDefaultKind()));
   Record.AddSourceLocation(C->getLParenLoc());
diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp
--- a/clang/lib/Serialization/ASTWriterStmt.cpp
+++ b/clang/lib/Serialization/ASTWriterStmt.cpp
@@ -2171,6 +2171,13 @@
   Record.AddSourceLocation(E->getEndLoc());
 }
 
+void ASTStmtWriter::VisitOMPMetaDirective(OMPMetaDirective *D) {
+  VisitStmt(D);
+  Record.push_back(D->getNumClauses());
+  VisitOMPExecutableDirective(D);
+  Code = serialization::STMT_OMP_META_DIRECTIVE;
+}
+
 void ASTStmtWriter::VisitOMPLoopDirective(OMPLoopDirective *D) {
   VisitStmt(D);
   Record.writeUInt32(D->getCollapsedNumber());
diff --git a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
--- a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
+++ b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
@@ -1292,6 +1292,7 @@
     case Stmt::OMPTargetTeamsDistributeParallelForDirectiveClass:
     case Stmt::OMPTargetTeamsDistributeParallelForSimdDirectiveClass:
     case Stmt::OMPTargetTeamsDistributeSimdDirectiveClass:
+    case Stmt::OMPMetaDirectiveClass:
     case Stmt::CapturedStmtClass: {
       const ExplodedNode *node = Bldr.generateSink(S, Pred, Pred->getState());
       Engine.addAbortedBlock(node, currBldrCtx->getBlock());
diff --git a/clang/test/OpenMP/metadirective_ast_print.cpp b/clang/test/OpenMP/metadirective_ast_print.cpp
new file mode 100644
--- /dev/null
+++ b/clang/test/OpenMP/metadirective_ast_print.cpp
@@ -0,0 +1,24 @@
+// RUN: %clang_cc1 -verify -fopenmp -ast-print %s | FileCheck %s
+// expected-no-diagnostics
+
+int main() {
+  int N = 15;
+#pragma omp metadirective when(user = {condition(N > 10)} \
+                               : parallel for) default()
+  for (int i = 0; i < N; i++)
+    ;
+
+#pragma omp metadirective when(user = {condition(N < 10)} \
+                               :) default(parallel for)
+  for (int i = 0; i < N; i++)
+    ;
+
+#pragma omp metadirective when(device = {arch("nvptx64")}, user = {condition(N >= 100)} \
+                               : parallel for)
+  for (int i = 0; i < N; i++)
+    ;
+  return 0;
+}
+// CHECK: #pragma omp metadirective when(user={condition(N > 10)}: parallel for) default()
+// CHECK: #pragma omp metadirective when(user={condition(N < 10)}: ) default(parallel for)
+// CHECK: #pragma omp metadirective when(device={arch(nvptx64)}, user={condition(N >= 100)}: parallel for)
diff --git a/clang/test/OpenMP/metadirective_codegen.cpp b/clang/test/OpenMP/metadirective_codegen.cpp
new file mode 100644
--- /dev/null
+++ b/clang/test/OpenMP/metadirective_codegen.cpp
@@ -0,0 +1,35 @@
+// RUN: %clang_cc1 -verify -fopenmp -emit-llvm %s -o - | FileCheck %s
+// expected-no-diagnostics
+
+int main(int argc, char **argv) {
+  int N = 15;
+#pragma omp metadirective when(user = {condition(N <= 1)} \
+                               : parallel)                \
+    when(user = {condition(N > 10)}                       \
+         : parallel for)
+  for (int i = 0; i < N; i++)
+    ;
+  // CHECK: %cmp{{[0-9]*}} = icmp sle i32 %{{[0-9]+}}, 1
+  // CHECK: br i1 %cmp{{[0-9]*}}, label %if.then{{[0-9]*}}, label %if.else{{[0-9]*}}
+  // CHECK: if.then{{[0-9]*}}:
+  // CHECK: call {{.*}}void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* @0, i32 1, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, i32*)* [[OMP_OUTLINED:@.+]] to void (i32*, i32*, ...)*), i32* %N)
+  // CHECK: br label %if.end{{[0-9]*}}
+  // CHECK: if.else{{[0-9]*}}:
+  // CHECK: %{{[0-9]+}} = load i32, i32* %N, align 4
+  // CHECK: %cmp{{[0-9]+}} = icmp sgt i32 %{{[0-9]+}}, 10
+  // CHECK: br i1 %cmp{{[0-9]*}}, label %if.then{{[0-9]*}}, label %if.end{{[0-9]*}}
+  // CHECK: if.then{{[0-9]*}}:
+  // CHECK: call {{.*}}void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* @0, i32 1, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, i32*)* [[OMP_OUTLINED:@.+]] to void (i32*, i32*, ...)*), i32* %N)
+  // CHECK: br label %if.end
+  // CHECK: if.end{{[0-9]*}}:
+  // CHECK: br label %if.end{{[0-9]*}}
+  // CHECK: if.end{{[0-9]*}}:
+  // CHECK: ret i32 0
+
+  return 0;
+}
+// CHECK: define internal void [[OMP_OUTLINED:@.*]](i32* noalias %.global_tid., i32* noalias %.bound_tid., i32* nonnull align 4 dereferenceable(4) %N)
+// CHECK: declare !callback !{{[0-9]+}} void @__kmpc_fork_call(%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...)
+// CHECK: define internal void [[OMP_OUTLINED:@.+]](i32* noalias %.global_tid., i32* noalias %.bound_tid., i32* nonnull align 4 dereferenceable(4) %N)
+// CHECK: declare void @__kmpc_for_static_init_4(%struct.ident_t*, i32, i32, i32*, i32*, i32*, i32*, i32, i32)
+// CHECK: declare void @__kmpc_for_static_fini(%struct.ident_t*, i32)
diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp
--- a/clang/tools/libclang/CIndex.cpp
+++ b/clang/tools/libclang/CIndex.cpp
@@ -2219,6 +2219,10 @@
   Visitor->AddStmt(C->getNumForLoops());
 }
 
+void OMPClauseEnqueue::VisitOMPWhenClause(const OMPWhenClause *C) {
+  Visitor->AddStmt(C->getDirective());
+}
+
 void OMPClauseEnqueue::VisitOMPDefaultClause(const OMPDefaultClause *C) {}
 
 void OMPClauseEnqueue::VisitOMPProcBindClause(const OMPProcBindClause *C) {}
@@ -5524,6 +5528,8 @@
     return cxstring::createRef("CXXAccessSpecifier");
   case CXCursor_ModuleImportDecl:
     return cxstring::createRef("ModuleImport");
+  case CXCursor_OMPMetaDirective:
+    return cxstring::createRef("OMPMetaDirective");
   case CXCursor_OMPParallelDirective:
     return cxstring::createRef("OMPParallelDirective");
   case CXCursor_OMPSimdDirective:
diff --git a/clang/tools/libclang/CXCursor.cpp b/clang/tools/libclang/CXCursor.cpp
--- a/clang/tools/libclang/CXCursor.cpp
+++ b/clang/tools/libclang/CXCursor.cpp
@@ -639,6 +639,9 @@
   case Stmt::MSDependentExistsStmtClass:
     K = CXCursor_UnexposedStmt;
     break;
+  case Stmt::OMPMetaDirectiveClass:
+    K = CXCursor_OMPMetaDirective;
+    break;
   case Stmt::OMPParallelDirectiveClass:
     K = CXCursor_OMPParallelDirective;
     break;
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -59,6 +59,9 @@
   let clangClass = "OMPCollapseClause";
   let flangClassValue = "ScalarIntConstantExpr";
 }
+def OMPC_When : Clause<"when"> {
+  let clangClass = "OMPWhenClause";
+}
 def OMPC_Default : Clause<"default"> {
   let clangClass = "OMPDefaultClause";
   let flangClass = "OmpDefaultClause";
@@ -294,6 +297,14 @@
 // Definition of OpenMP directives
 //===----------------------------------------------------------------------===//
 
+def OMP_Metadirective : Directive<"metadirective"> {
+  let allowedClauses = [
+    VersionedClause<OMPC_When>
+  ];
+  let allowedOnceClauses = [
+    VersionedClause<OMPC_Default>
+  ];
+}
 def OMP_ThreadPrivate : Directive<"threadprivate"> {}
 def OMP_Parallel : Directive<"parallel"> {
   let allowedClauses = [
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
--- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -111,6 +111,7 @@
 __OMP_CLAUSE(uses_allocators, OMPUsesAllocatorsClause)
 __OMP_CLAUSE(affinity, OMPAffinityClause)
 __OMP_CLAUSE(use_device_addr, OMPUseDeviceAddrClause)
+__OMP_CLAUSE(when, OMPWhenClause)
 
 __OMP_CLAUSE_NO_CLASS(uniform)
 __OMP_CLAUSE_NO_CLASS(device_type)