diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h --- a/clang/include/clang/AST/OpenMPClause.h +++ b/clang/include/clang/AST/OpenMPClause.h @@ -7649,6 +7649,76 @@ } }; +/// This represents 'novariants' clause in the '#pragma omp ...' directive. +/// +/// \code +/// #pragma omp dispatch novariants(a > 5) +/// \endcode +/// In this example directive '#pragma omp dispatch' has simple 'novariants' +/// clause with condition 'a > 5'. +class OMPNovariantsClause final: public OMPClause, public OMPClauseWithPreInit { + friend class OMPClauseReader; + + /// Location of '('. + SourceLocation LParenLoc; + + /// Condition of the 'novariants' clause. + Stmt *Condition = nullptr; + + /// Set condition. + void setCondition(Expr *Cond) { Condition = Cond; } + + /// Sets the location of '('. + void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; } + +public: + /// Build 'novariants' clause with condition \a Cond. + /// + /// \param Cond Condition of the clause. + /// \param HelperCond Helper condition for the construct. + /// \param CaptureRegion Innermost OpenMP region where expressions in this + /// clause must be captured. + /// \param StartLoc Starting location of the clause. + /// \param LParenLoc Location of '('. + /// \param EndLoc Ending location of the clause. + OMPNovariantsClause(Expr *Cond, Stmt *HelperCond, + OpenMPDirectiveKind CaptureRegion, + SourceLocation StartLoc, SourceLocation LParenLoc, + SourceLocation EndLoc) + : OMPClause(llvm::omp::OMPC_novariants, StartLoc, EndLoc), + OMPClauseWithPreInit(this), LParenLoc(LParenLoc), Condition(Cond) { + setPreInitStmt(HelperCond, CaptureRegion); + } + + /// Build an empty clause. + OMPNovariantsClause() + : OMPClause(llvm::omp::OMPC_novariants, SourceLocation(), + SourceLocation()), + OMPClauseWithPreInit(this) {} + + /// Returns the location of '('. + SourceLocation getLParenLoc() const { return LParenLoc; } + + /// Returns condition.
+ Expr *getCondition() const { return cast_or_null<Expr>(Condition); } + + child_range children() { return child_range(&Condition, &Condition + 1); } + + const_child_range children() const { + return const_child_range(&Condition, &Condition + 1); + } + + child_range used_children(); + const_child_range used_children() const { + auto Children = const_cast<OMPNovariantsClause *>(this)->used_children(); + return const_child_range(Children.begin(), Children.end()); + } + + static bool classof(const OMPClause *T) { + return T->getClauseKind() == llvm::omp::OMPC_novariants; + } +}; + /// This represents 'detach' clause in the '#pragma omp task' directive. /// /// \code diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -3218,6 +3218,14 @@ return true; } +template <typename Derived> +bool RecursiveASTVisitor<Derived>::VisitOMPNovariantsClause( + OMPNovariantsClause *C) { + TRY_TO(VisitOMPClauseWithPreInit(C)); + TRY_TO(TraverseStmt(C->getCondition())); + return true; +} + template <typename Derived> template <typename T> bool RecursiveASTVisitor<Derived>::VisitOMPClauseList(T *Node) { diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -11017,7 +11017,11 @@ SourceLocation LParenLoc, SourceLocation VarLoc, SourceLocation EndLoc); - + /// Called on well-formed 'novariants' clause. + OMPClause *ActOnOpenMPNovariantsClause(Expr *Condition, + SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation EndLoc); /// Called on well-formed 'threads' clause.
OMPClause *ActOnOpenMPThreadsClause(SourceLocation StartLoc, SourceLocation EndLoc); diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp --- a/clang/lib/AST/OpenMPClause.cpp +++ b/clang/lib/AST/OpenMPClause.cpp @@ -96,6 +96,8 @@ return static_cast<const OMPFinalClause *>(C); case OMPC_priority: return static_cast<const OMPPriorityClause *>(C); + case OMPC_novariants: + return static_cast<const OMPNovariantsClause *>(C); case OMPC_default: case OMPC_proc_bind: case OMPC_safelen: @@ -244,6 +246,7 @@ case OMPC_nontemporal: case OMPC_order: case OMPC_destroy: + case OMPC_novariants: case OMPC_detach: case OMPC_inclusive: case OMPC_exclusive: @@ -300,6 +303,12 @@ return child_range(&Priority, &Priority + 1); } +OMPClause::child_range OMPNovariantsClause::used_children() { + if (Stmt **C = getAddrOfExprAsWritten(getPreInitStmt())) + return child_range(C, C + 1); + return child_range(&Condition, &Condition + 1); +} + OMPOrderedClause *OMPOrderedClause::Create(const ASTContext &C, Expr *Num, unsigned NumLoops, SourceLocation StartLoc, @@ -1816,6 +1825,15 @@ } } +void OMPClausePrinter::VisitOMPNovariantsClause(OMPNovariantsClause *Node) { + OS << "novariants"; + if (Expr *E = Node->getCondition()) { + OS << "("; + E->printPretty(OS, nullptr, Policy, 0); + OS << ")"; + } +} + template<typename T> void OMPClausePrinter::VisitOMPClauseList(T *Node, char StartSym) { for (typename T::varlist_iterator I = Node->varlist_begin(), diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp --- a/clang/lib/AST/StmtProfile.cpp +++ b/clang/lib/AST/StmtProfile.cpp @@ -483,6 +483,12 @@ Profiler->VisitStmt(Evt); } +void OMPClauseProfiler::VisitOMPNovariantsClause(const OMPNovariantsClause *C) { + VistOMPClauseWithPreInit(C); + if (C->getCondition()) + Profiler->VisitStmt(C->getCondition()); +} + void OMPClauseProfiler::VisitOMPDefaultClause(const OMPDefaultClause *C) { } void OMPClauseProfiler::VisitOMPProcBindClause(const OMPProcBindClause *C) { } diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp ---
a/clang/lib/Basic/OpenMPKinds.cpp +++ b/clang/lib/Basic/OpenMPKinds.cpp @@ -176,6 +176,7 @@ case OMPC_match: case OMPC_nontemporal: case OMPC_destroy: + case OMPC_novariants: case OMPC_detach: case OMPC_inclusive: case OMPC_exclusive: @@ -418,6 +419,7 @@ case OMPC_nontemporal: case OMPC_destroy: case OMPC_detach: + case OMPC_novariants: case OMPC_inclusive: case OMPC_exclusive: case OMPC_uses_allocators: diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp --- a/clang/lib/Parse/ParseOpenMP.cpp +++ b/clang/lib/Parse/ParseOpenMP.cpp @@ -2776,6 +2776,7 @@ case OMPC_allocator: case OMPC_depobj: case OMPC_detach: + case OMPC_novariants: // OpenMP [2.5, Restrictions] // At most one num_threads clause can appear on the directive. // OpenMP [2.8.1, simd construct, Restrictions] @@ -2798,6 +2799,8 @@ // At most one allocator clause can appear on the directive. // OpenMP 5.0, 2.10.1 task Construct, Restrictions. // At most one detach clause can appear on the directive. + // OpenMP 5.1, 2.3.6 dispatch Construct, Restrictions. + // At most one novariants clause can appear on a dispatch directive. if (!FirstClause) { Diag(Tok, diag::err_omp_more_one_clause) << getOpenMPDirectiveName(DKind) << getOpenMPClauseName(CKind) << 0; diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -6173,6 +6173,7 @@ case OMPC_num_tasks: case OMPC_final: case OMPC_priority: + case OMPC_novariants: // Do not analyze if no parent parallel directive.
if (isOpenMPParallelDirective(Kind)) break; @@ -12785,6 +12786,9 @@ case OMPC_detach: Res = ActOnOpenMPDetachClause(Expr, StartLoc, LParenLoc, EndLoc); break; + case OMPC_novariants: + Res = ActOnOpenMPNovariantsClause(Expr, StartLoc, LParenLoc, EndLoc); + break; case OMPC_device: case OMPC_if: case OMPC_default: @@ -13557,6 +13561,15 @@ llvm_unreachable("Unknown OpenMP directive"); } break; + case OMPC_novariants: + switch (DKind) { + case OMPD_dispatch: + CaptureRegion = OMPD_task; + break; + default: + llvm_unreachable("Unknown OpenMP directive"); + } + break; case OMPC_firstprivate: case OMPC_lastprivate: case OMPC_reduction: @@ -14061,6 +14074,7 @@ case OMPC_match: case OMPC_nontemporal: case OMPC_destroy: + case OMPC_novariants: case OMPC_detach: case OMPC_inclusive: case OMPC_exclusive: @@ -14317,6 +14331,7 @@ case OMPC_nontemporal: case OMPC_order: case OMPC_destroy: + case OMPC_novariants: case OMPC_detach: case OMPC_inclusive: case OMPC_exclusive: @@ -14558,6 +14573,7 @@ case OMPC_match: case OMPC_nontemporal: case OMPC_order: + case OMPC_novariants: case OMPC_detach: case OMPC_inclusive: case OMPC_exclusive: @@ -14848,6 +14864,37 @@ OMPDestroyClause(InteropVar, StartLoc, LParenLoc, VarLoc, EndLoc); } +OMPClause *Sema::ActOnOpenMPNovariantsClause(Expr *Condition, + SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation EndLoc) { + Expr *ValExpr = Condition; + Stmt *HelperValStmt = nullptr; + OpenMPDirectiveKind CaptureRegion = OMPD_unknown; + if (!Condition->isValueDependent() && !Condition->isTypeDependent() && + !Condition->isInstantiationDependent() && + !Condition->containsUnexpandedParameterPack()) { + ExprResult Val = CheckBooleanCondition(StartLoc, Condition); + if (Val.isInvalid()) + return nullptr; + + ValExpr = MakeFullExpr(Val.get()).get(); + + OpenMPDirectiveKind DKind = DSAStack->getCurrentDirective(); + CaptureRegion = getOpenMPCaptureRegionForClause(DKind, OMPC_novariants, + LangOpts.OpenMP); + if (CaptureRegion !=
OMPD_unknown && !CurContext->isDependentContext()) { + ValExpr = MakeFullExpr(ValExpr).get(); + llvm::MapVector<const Expr *, DeclRefExpr *> Captures; + ValExpr = tryBuildCapture(*this, ValExpr, Captures).get(); + HelperValStmt = buildPreInits(Context, Captures); + } + } + + return new (Context) OMPNovariantsClause( + ValExpr, HelperValStmt, CaptureRegion, StartLoc, LParenLoc, EndLoc); +} + OMPClause *Sema::ActOnOpenMPVarListClause( OpenMPClauseKind Kind, ArrayRef<Expr *> VarList, Expr *DepModOrTailExpr, const OMPVarListLocTy &Locs, SourceLocation ColonLoc, @@ -15018,6 +15065,7 @@ case OMPC_match: case OMPC_order: case OMPC_destroy: + case OMPC_novariants: case OMPC_detach: case OMPC_uses_allocators: default: diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -2208,6 +2208,18 @@ VarLoc, EndLoc); } + /// Build a new OpenMP 'novariants' clause. + /// + /// By default, performs semantic analysis to build the new OpenMP clause. + /// Subclasses may override this routine to provide different behavior. + OMPClause *RebuildOMPNovariantsClause(Expr *Condition, + SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation EndLoc) { + return getSema().ActOnOpenMPNovariantsClause(Condition, StartLoc, LParenLoc, + EndLoc); + } + /// Rebuild the operand to an Objective-C \@synchronized statement. /// /// By default, performs semantic analysis to build the new statement.
@@ -9377,6 +9389,16 @@ C->getEndLoc()); } +template <typename Derived> +OMPClause * +TreeTransform<Derived>::TransformOMPNovariantsClause(OMPNovariantsClause *C) { + ExprResult Cond = getDerived().TransformExpr(C->getCondition()); + if (Cond.isInvalid()) + return nullptr; + return getDerived().RebuildOMPNovariantsClause( + Cond.get(), C->getBeginLoc(), C->getLParenLoc(), C->getEndLoc()); +} + template <typename Derived> OMPClause *TreeTransform<Derived>::TransformOMPUnifiedAddressClause( OMPUnifiedAddressClause *C) { diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp --- a/clang/lib/Serialization/ASTReader.cpp +++ b/clang/lib/Serialization/ASTReader.cpp @@ -11977,6 +11977,9 @@ case llvm::omp::OMPC_destroy: C = new (Context) OMPDestroyClause(); break; + case llvm::omp::OMPC_novariants: + C = new (Context) OMPNovariantsClause(); + break; case llvm::omp::OMPC_detach: C = new (Context) OMPDetachClause(); break; @@ -12162,6 +12165,12 @@ C->setVarLoc(Record.readSourceLocation()); } +void OMPClauseReader::VisitOMPNovariantsClause(OMPNovariantsClause *C) { + VisitOMPClauseWithPreInit(C); + C->setCondition(Record.readSubExpr()); + C->setLParenLoc(Record.readSourceLocation()); +} + void OMPClauseReader::VisitOMPUnifiedAddressClause(OMPUnifiedAddressClause *) {} void OMPClauseReader::VisitOMPUnifiedSharedMemoryClause( diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp --- a/clang/lib/Serialization/ASTWriter.cpp +++ b/clang/lib/Serialization/ASTWriter.cpp @@ -6237,6 +6237,12 @@ Record.AddSourceLocation(C->getVarLoc()); } +void OMPClauseWriter::VisitOMPNovariantsClause(OMPNovariantsClause *C) { + VisitOMPClauseWithPreInit(C); + Record.AddStmt(C->getCondition()); + Record.AddSourceLocation(C->getLParenLoc()); +} + void OMPClauseWriter::VisitOMPPrivateClause(OMPPrivateClause *C) { Record.push_back(C->varlist_size()); Record.AddSourceLocation(C->getLParenLoc()); diff --git a/clang/test/OpenMP/dispatch_ast_print.cpp b/clang/test/OpenMP/dispatch_ast_print.cpp
--- a/clang/test/OpenMP/dispatch_ast_print.cpp +++ b/clang/test/OpenMP/dispatch_ast_print.cpp @@ -51,20 +51,22 @@ void test_one() { int aaa, bbb, var; - //PRINT: #pragma omp dispatch depend(in : var) nowait + //PRINT: #pragma omp dispatch depend(in : var) nowait novariants(aaa > 5) //DUMP: OMPDispatchDirective //DUMP: OMPDependClause //DUMP: OMPNowaitClause - #pragma omp dispatch depend(in:var) nowait + //DUMP: OMPNovariantsClause + #pragma omp dispatch depend(in:var) nowait novariants(aaa > 5) foo(aaa, &bbb); int *dp = get_device_ptr(); int dev = get_device(); - //PRINT: #pragma omp dispatch device(dev) is_device_ptr(dp) + //PRINT: #pragma omp dispatch device(dev) is_device_ptr(dp) novariants(dev > 10) //DUMP: OMPDispatchDirective //DUMP: OMPDeviceClause //DUMP: OMPIs_device_ptrClause - #pragma omp dispatch device(dev) is_device_ptr(dp) + //DUMP: OMPNovariantsClause + #pragma omp dispatch device(dev) is_device_ptr(dp) novariants(dev > 10) foo(aaa, dp); //PRINT: #pragma omp dispatch diff --git a/clang/test/OpenMP/dispatch_messages.cpp b/clang/test/OpenMP/dispatch_messages.cpp --- a/clang/test/OpenMP/dispatch_messages.cpp +++ b/clang/test/OpenMP/dispatch_messages.cpp @@ -28,6 +28,24 @@ // expected-error@+1 {{cannot contain more than one 'nowait' clause}} #pragma omp dispatch nowait device(dnum) nowait disp_call(); + + // expected-error@+1 {{expected '(' after 'novariants'}} + #pragma omp dispatch novariants + disp_call(); + + // expected-error@+3 {{expected expression}} + // expected-error@+2 {{expected ')'}} + // expected-note@+1 {{to match this '('}} + #pragma omp dispatch novariants ( + disp_call(); + + // expected-error@+1 {{cannot contain more than one 'novariants' clause}} + #pragma omp dispatch novariants(dnum> 4) novariants(3) + disp_call(); + + // expected-error@+1 {{use of undeclared identifier 'x'}} + #pragma omp dispatch novariants(x) + disp_call(); } void testit_two() { diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp ---
a/clang/tools/libclang/CIndex.cpp +++ b/clang/tools/libclang/CIndex.cpp @@ -2291,6 +2291,11 @@ Visitor->AddStmt(C->getInteropVar()); } +void OMPClauseEnqueue::VisitOMPNovariantsClause(const OMPNovariantsClause *C) { + VisitOMPClauseWithPreInit(C); + Visitor->AddStmt(C->getCondition()); +} + void OMPClauseEnqueue::VisitOMPUnifiedAddressClause( const OMPUnifiedAddressClause *) {} diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp --- a/flang/lib/Semantics/check-omp-structure.cpp +++ b/flang/lib/Semantics/check-omp-structure.cpp @@ -729,6 +729,7 @@ CHECK_SIMPLE_CLAUSE(Write, OMPC_write) CHECK_SIMPLE_CLAUSE(Init, OMPC_init) CHECK_SIMPLE_CLAUSE(Use, OMPC_use) +CHECK_SIMPLE_CLAUSE(Novariants, OMPC_novariants) CHECK_REQ_SCALAR_INT_CLAUSE(Allocator, OMPC_allocator) CHECK_REQ_SCALAR_INT_CLAUSE(Grainsize, OMPC_grainsize) diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td --- a/llvm/include/llvm/Frontend/OpenMP/OMP.td +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -276,6 +276,10 @@ def OMPC_Destroy : Clause<"destroy"> { let clangClass = "OMPDestroyClause"; } +def OMPC_Novariants : Clause<"novariants"> { + let clangClass = "OMPNovariantsClause"; + let flangClass = "ScalarLogicalExpr"; +} def OMPC_Detach : Clause<"detach"> { let clangClass = "OMPDetachClause"; } @@ -1660,7 +1664,8 @@ VersionedClause, VersionedClause, VersionedClause, - VersionedClause + VersionedClause, + VersionedClause ]; } def OMP_Unknown : Directive<"unknown"> {