[OpenMP] Add parsing/sema/serialization support for the 'nocontext' clause

Introduces OMPNocontextClause for '#pragma omp dispatch nocontext(<cond>)'
(OpenMP 5.1, 2.3.6 dispatch Construct) and wires it through the parser
(at most one per directive), Sema (boolean condition check, captured into
the 'task' region of 'dispatch'), TreeTransform, AST (de)serialization,
printing, profiling, libclang clause enqueueing, flang semantic checks and
the OMP.td clause/directive tables.

diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -7720,6 +7720,75 @@
   }
 };
 
+/// This represents 'nocontext' clause in the '#pragma omp ...' directive.
+///
+/// \code
+/// #pragma omp dispatch nocontext(a > 5)
+/// \endcode
+/// In this example directive '#pragma omp dispatch' has simple 'nocontext'
+/// clause with condition 'a > 5'.
+class OMPNocontextClause final : public OMPClause, public OMPClauseWithPreInit {
+  friend class OMPClauseReader;
+
+  /// Location of '('.
+  SourceLocation LParenLoc;
+
+  /// Condition of the 'nocontext' clause.
+  Stmt *Condition = nullptr;
+
+  /// Set condition.
+  void setCondition(Expr *Cond) { Condition = Cond; }
+
+  /// Sets the location of '('.
+  void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
+
+public:
+  /// Build 'nocontext' clause with condition \a Cond.
+  ///
+  /// \param Cond Condition of the clause.
+  /// \param HelperCond Helper condition for the construct.
+  /// \param CaptureRegion Innermost OpenMP region where expressions in this
+  /// clause must be captured.
+  /// \param StartLoc Starting location of the clause.
+  /// \param LParenLoc Location of '('.
+  /// \param EndLoc Ending location of the clause.
+  OMPNocontextClause(Expr *Cond, Stmt *HelperCond,
+                     OpenMPDirectiveKind CaptureRegion, SourceLocation StartLoc,
+                     SourceLocation LParenLoc, SourceLocation EndLoc)
+      : OMPClause(llvm::omp::OMPC_nocontext, StartLoc, EndLoc),
+        OMPClauseWithPreInit(this), LParenLoc(LParenLoc), Condition(Cond) {
+    setPreInitStmt(HelperCond, CaptureRegion);
+  }
+
+  /// Build an empty clause.
+  OMPNocontextClause()
+      : OMPClause(llvm::omp::OMPC_nocontext, SourceLocation(),
+                  SourceLocation()),
+        OMPClauseWithPreInit(this) {}
+
+  /// Returns the location of '('.
+  SourceLocation getLParenLoc() const { return LParenLoc; }
+
+  /// Returns condition.
+  Expr *getCondition() const { return cast_or_null<Expr>(Condition); }
+
+  child_range children() { return child_range(&Condition, &Condition + 1); }
+
+  const_child_range children() const {
+    return const_child_range(&Condition, &Condition + 1);
+  }
+
+  child_range used_children();
+  const_child_range used_children() const {
+    auto Children = const_cast<OMPNocontextClause *>(this)->used_children();
+    return const_child_range(Children.begin(), Children.end());
+  }
+
+  static bool classof(const OMPClause *T) {
+    return T->getClauseKind() == llvm::omp::OMPC_nocontext;
+  }
+};
+
 /// This represents 'detach' clause in the '#pragma omp task' directive.
 ///
 /// \code
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3226,6 +3226,14 @@
   return true;
 }
 
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitOMPNocontextClause(
+    OMPNocontextClause *C) {
+  TRY_TO(VisitOMPClauseWithPreInit(C));
+  TRY_TO(TraverseStmt(C->getCondition()));
+  return true;
+}
+
 template <typename Derived>
 template <typename T>
 bool RecursiveASTVisitor<Derived>::VisitOMPClauseList(T *Node) {
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -11014,6 +11014,11 @@
                                          SourceLocation StartLoc,
                                          SourceLocation LParenLoc,
                                          SourceLocation EndLoc);
+  /// Called on well-formed 'nocontext' clause.
+  OMPClause *ActOnOpenMPNocontextClause(Expr *Condition,
+                                        SourceLocation StartLoc,
+                                        SourceLocation LParenLoc,
+                                        SourceLocation EndLoc);
   /// Called on well-formed 'threads' clause.
   OMPClause *ActOnOpenMPThreadsClause(SourceLocation StartLoc,
                                       SourceLocation EndLoc);
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp
--- a/clang/lib/AST/OpenMPClause.cpp
+++ b/clang/lib/AST/OpenMPClause.cpp
@@ -98,6 +98,8 @@
     return static_cast<const OMPPriorityClause *>(C);
   case OMPC_novariants:
     return static_cast<const OMPNovariantsClause *>(C);
+  case OMPC_nocontext:
+    return static_cast<const OMPNocontextClause *>(C);
   case OMPC_default:
   case OMPC_proc_bind:
   case OMPC_safelen:
@@ -247,6 +249,7 @@
   case OMPC_order:
   case OMPC_destroy:
   case OMPC_novariants:
+  case OMPC_nocontext:
   case OMPC_detach:
   case OMPC_inclusive:
   case OMPC_exclusive:
@@ -309,6 +312,12 @@
   return child_range(&Condition, &Condition + 1);
 }
 
+OMPClause::child_range OMPNocontextClause::used_children() {
+  if (Stmt **C = getAddrOfExprAsWritten(getPreInitStmt()))
+    return child_range(C, C + 1);
+  return child_range(&Condition, &Condition + 1);
+}
+
 OMPOrderedClause *OMPOrderedClause::Create(const ASTContext &C, Expr *Num,
                                            unsigned NumLoops,
                                            SourceLocation StartLoc,
@@ -1834,6 +1843,15 @@
   }
 }
 
+void OMPClausePrinter::VisitOMPNocontextClause(OMPNocontextClause *Node) {
+  OS << "nocontext";
+  if (Expr *E = Node->getCondition()) {
+    OS << "(";
+    E->printPretty(OS, nullptr, Policy, 0);
+    OS << ")";
+  }
+}
+
 template<typename T>
 void OMPClausePrinter::VisitOMPClauseList(T *Node, char StartSym) {
   for (typename T::varlist_iterator I = Node->varlist_begin(),
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -489,6 +489,12 @@
     Profiler->VisitStmt(C->getCondition());
 }
 
+void OMPClauseProfiler::VisitOMPNocontextClause(const OMPNocontextClause *C) {
+  VistOMPClauseWithPreInit(C);
+  if (C->getCondition())
+    Profiler->VisitStmt(C->getCondition());
+}
+
 void OMPClauseProfiler::VisitOMPDefaultClause(const OMPDefaultClause *C) { }
 
 void OMPClauseProfiler::VisitOMPProcBindClause(const OMPProcBindClause *C) { }
diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp
--- a/clang/lib/Basic/OpenMPKinds.cpp
+++ b/clang/lib/Basic/OpenMPKinds.cpp
@@ -177,6 +177,7 @@
   case OMPC_nontemporal:
   case OMPC_destroy:
   case OMPC_novariants:
+  case OMPC_nocontext:
   case OMPC_detach:
   case OMPC_inclusive:
   case OMPC_exclusive:
@@ -420,6 +421,7 @@
   case OMPC_destroy:
   case OMPC_detach:
   case OMPC_novariants:
+  case OMPC_nocontext:
   case OMPC_inclusive:
   case OMPC_exclusive:
   case OMPC_uses_allocators:
diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -5613,6 +5613,7 @@
   case OMPC_link:
   case OMPC_use:
   case OMPC_novariants:
+  case OMPC_nocontext:
     llvm_unreachable("Clause is not allowed in 'omp atomic'.");
   }
 }
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -2777,6 +2777,7 @@
   case OMPC_depobj:
   case OMPC_detach:
   case OMPC_novariants:
+  case OMPC_nocontext:
     // OpenMP [2.5, Restrictions]
     //  At most one num_threads clause can appear on the directive.
     // OpenMP [2.8.1, simd construct, Restrictions]
@@ -2801,6 +2802,7 @@
     // At most one detach clause can appear on the directive.
     // OpenMP 5.1, 2.3.6 dispatch Construct, Restrictions.
     // At most one novariants clause can appear on a dispatch directive.
+    // At most one nocontext clause can appear on a dispatch directive.
     if (!FirstClause) {
       Diag(Tok, diag::err_omp_more_one_clause)
           << getOpenMPDirectiveName(DKind) << getOpenMPClauseName(CKind) << 0;
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -6174,6 +6174,7 @@
       case OMPC_final:
       case OMPC_priority:
       case OMPC_novariants:
+      case OMPC_nocontext:
         // Do not analyze if no parent parallel directive.
         if (isOpenMPParallelDirective(Kind))
           break;
@@ -12789,6 +12790,9 @@
   case OMPC_novariants:
     Res = ActOnOpenMPNovariantsClause(Expr, StartLoc, LParenLoc, EndLoc);
     break;
+  case OMPC_nocontext:
+    Res = ActOnOpenMPNocontextClause(Expr, StartLoc, LParenLoc, EndLoc);
+    break;
   case OMPC_device:
   case OMPC_if:
   case OMPC_default:
@@ -13562,12 +13566,13 @@
     }
     break;
   case OMPC_novariants:
+  case OMPC_nocontext:
     switch (DKind) {
     case OMPD_dispatch:
       CaptureRegion = OMPD_task;
       break;
     default:
-      llvm_unreachable("Unknown OpenMP directive");
+      llvm_unreachable("Unexpected OpenMP directive");
     }
     break;
   case OMPC_firstprivate:
@@ -14075,6 +14080,7 @@
   case OMPC_nontemporal:
   case OMPC_destroy:
   case OMPC_novariants:
+  case OMPC_nocontext:
   case OMPC_detach:
   case OMPC_inclusive:
   case OMPC_exclusive:
@@ -14343,6 +14349,7 @@
   case OMPC_order:
   case OMPC_destroy:
   case OMPC_novariants:
+  case OMPC_nocontext:
   case OMPC_detach:
   case OMPC_inclusive:
   case OMPC_exclusive:
@@ -14585,6 +14592,7 @@
   case OMPC_nontemporal:
   case OMPC_order:
   case OMPC_novariants:
+  case OMPC_nocontext:
   case OMPC_detach:
   case OMPC_inclusive:
   case OMPC_exclusive:
@@ -14905,6 +14913,37 @@
       ValExpr, HelperValStmt, CaptureRegion, StartLoc, LParenLoc, EndLoc);
 }
 
+OMPClause *Sema::ActOnOpenMPNocontextClause(Expr *Condition,
+                                            SourceLocation StartLoc,
+                                            SourceLocation LParenLoc,
+                                            SourceLocation EndLoc) {
+  Expr *ValExpr = Condition;
+  Stmt *HelperValStmt = nullptr;
+  OpenMPDirectiveKind CaptureRegion = OMPD_unknown;
+  if (!Condition->isValueDependent() && !Condition->isTypeDependent() &&
+      !Condition->isInstantiationDependent() &&
+      !Condition->containsUnexpandedParameterPack()) {
+    ExprResult Val = CheckBooleanCondition(StartLoc, Condition);
+    if (Val.isInvalid())
+      return nullptr;
+
+    ValExpr = MakeFullExpr(Val.get()).get();
+
+    OpenMPDirectiveKind DKind = DSAStack->getCurrentDirective();
+    CaptureRegion =
+        getOpenMPCaptureRegionForClause(DKind, OMPC_nocontext, LangOpts.OpenMP);
+    if (CaptureRegion != OMPD_unknown && !CurContext->isDependentContext()) {
+      ValExpr = MakeFullExpr(ValExpr).get();
+      llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
+      ValExpr = tryBuildCapture(*this, ValExpr, Captures).get();
+      HelperValStmt = buildPreInits(Context, Captures);
+    }
+  }
+
+  return new (Context) OMPNocontextClause(ValExpr, HelperValStmt, CaptureRegion,
+                                          StartLoc, LParenLoc, EndLoc);
+}
+
 OMPClause *Sema::ActOnOpenMPVarListClause(
     OpenMPClauseKind Kind, ArrayRef<Expr *> VarList, Expr *DepModOrTailExpr,
     const OMPVarListLocTy &Locs, SourceLocation ColonLoc,
@@ -15076,6 +15115,7 @@
   case OMPC_order:
   case OMPC_destroy:
   case OMPC_novariants:
+  case OMPC_nocontext:
   case OMPC_detach:
   case OMPC_uses_allocators:
   default:
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -2220,6 +2220,17 @@
                                                  EndLoc);
   }
 
+  /// Build a new OpenMP 'nocontext' clause.
+  ///
+  /// By default, performs semantic analysis to build the new OpenMP clause.
+  /// Subclasses may override this routine to provide different behavior.
+  OMPClause *RebuildOMPNocontextClause(Expr *Condition, SourceLocation StartLoc,
+                                       SourceLocation LParenLoc,
+                                       SourceLocation EndLoc) {
+    return getSema().ActOnOpenMPNocontextClause(Condition, StartLoc, LParenLoc,
+                                                EndLoc);
+  }
+
   /// Rebuild the operand to an Objective-C \@synchronized statement.
   ///
   /// By default, performs semantic analysis to build the new statement.
@@ -9399,6 +9410,16 @@
       Cond.get(), C->getBeginLoc(), C->getLParenLoc(), C->getEndLoc());
 }
 
+template <typename Derived>
+OMPClause *
+TreeTransform<Derived>::TransformOMPNocontextClause(OMPNocontextClause *C) {
+  ExprResult Cond = getDerived().TransformExpr(C->getCondition());
+  if (Cond.isInvalid())
+    return nullptr;
+  return getDerived().RebuildOMPNocontextClause(
+      Cond.get(), C->getBeginLoc(), C->getLParenLoc(), C->getEndLoc());
+}
+
 template <typename Derived>
 OMPClause *TreeTransform<Derived>::TransformOMPUnifiedAddressClause(
     OMPUnifiedAddressClause *C) {
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -11980,6 +11980,9 @@
   case llvm::omp::OMPC_novariants:
     C = new (Context) OMPNovariantsClause();
     break;
+  case llvm::omp::OMPC_nocontext:
+    C = new (Context) OMPNocontextClause();
+    break;
   case llvm::omp::OMPC_detach:
     C = new (Context) OMPDetachClause();
     break;
@@ -12171,6 +12174,12 @@
   C->setLParenLoc(Record.readSourceLocation());
 }
 
+void OMPClauseReader::VisitOMPNocontextClause(OMPNocontextClause *C) {
+  VisitOMPClauseWithPreInit(C);
+  C->setCondition(Record.readSubExpr());
+  C->setLParenLoc(Record.readSourceLocation());
+}
+
 void OMPClauseReader::VisitOMPUnifiedAddressClause(OMPUnifiedAddressClause *) {}
 
 void OMPClauseReader::VisitOMPUnifiedSharedMemoryClause(
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -6243,6 +6243,12 @@
   Record.AddSourceLocation(C->getLParenLoc());
 }
 
+void OMPClauseWriter::VisitOMPNocontextClause(OMPNocontextClause *C) {
+  VisitOMPClauseWithPreInit(C);
+  Record.AddStmt(C->getCondition());
+  Record.AddSourceLocation(C->getLParenLoc());
+}
+
 void OMPClauseWriter::VisitOMPPrivateClause(OMPPrivateClause *C) {
   Record.push_back(C->varlist_size());
   Record.AddSourceLocation(C->getLParenLoc());
diff --git a/clang/test/OpenMP/dispatch_ast_print.cpp b/clang/test/OpenMP/dispatch_ast_print.cpp
--- a/clang/test/OpenMP/dispatch_ast_print.cpp
+++ b/clang/test/OpenMP/dispatch_ast_print.cpp
@@ -51,22 +51,24 @@
 void test_one()
 {
   int aaa, bbb, var;
-  //PRINT: #pragma omp dispatch depend(in : var) nowait novariants(aaa > 5)
+  //PRINT: #pragma omp dispatch depend(in : var) nowait novariants(aaa > 5) nocontext(bbb > 5)
   //DUMP: OMPDispatchDirective
   //DUMP: OMPDependClause
   //DUMP: OMPNowaitClause
   //DUMP: OMPNovariantsClause
-  #pragma omp dispatch depend(in:var) nowait novariants(aaa > 5)
+  //DUMP: OMPNocontextClause
+  #pragma omp dispatch depend(in:var) nowait novariants(aaa > 5) nocontext(bbb > 5)
   foo(aaa, &bbb);
 
   int *dp = get_device_ptr();
   int dev = get_device();
-  //PRINT: #pragma omp dispatch device(dev) is_device_ptr(dp) novariants(dev > 10)
+  //PRINT: #pragma omp dispatch device(dev) is_device_ptr(dp) novariants(dev > 10) nocontext(dev > 5)
   //DUMP: OMPDispatchDirective
   //DUMP: OMPDeviceClause
   //DUMP: OMPIs_device_ptrClause
   //DUMP: OMPNovariantsClause
-  #pragma omp dispatch device(dev) is_device_ptr(dp) novariants(dev > 10)
+  //DUMP: OMPNocontextClause
+  #pragma omp dispatch device(dev) is_device_ptr(dp) novariants(dev > 10) nocontext(dev > 5)
   foo(aaa, dp);
 
   //PRINT: #pragma omp dispatch
diff --git a/clang/test/OpenMP/dispatch_messages.cpp b/clang/test/OpenMP/dispatch_messages.cpp
--- a/clang/test/OpenMP/dispatch_messages.cpp
+++ b/clang/test/OpenMP/dispatch_messages.cpp
@@ -46,6 +46,24 @@
   // expected-error@+1 {{use of undeclared identifier 'x'}}
   #pragma omp dispatch novariants(x)
   disp_call();
+
+  // expected-error@+1 {{expected '(' after 'nocontext'}}
+  #pragma omp dispatch nocontext
+  disp_call();
+
+  // expected-error@+3 {{expected expression}}
+  // expected-error@+2 {{expected ')'}}
+  // expected-note@+1 {{to match this '('}}
+  #pragma omp dispatch nocontext (
+  disp_call();
+
+  // expected-error@+1 {{cannot contain more than one 'nocontext' clause}}
+  #pragma omp dispatch nocontext(dnum> 4) nocontext(3)
+  disp_call();
+
+  // expected-error@+1 {{use of undeclared identifier 'x'}}
+  #pragma omp dispatch nocontext(x)
+  disp_call();
 }
 
 void testit_two() {
diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp
--- a/clang/tools/libclang/CIndex.cpp
+++ b/clang/tools/libclang/CIndex.cpp
@@ -2295,6 +2295,11 @@
   Visitor->AddStmt(C->getCondition());
 }
 
+void OMPClauseEnqueue::VisitOMPNocontextClause(const OMPNocontextClause *C) {
+  VisitOMPClauseWithPreInit(C);
+  Visitor->AddStmt(C->getCondition());
+}
+
 void OMPClauseEnqueue::VisitOMPUnifiedAddressClause(
     const OMPUnifiedAddressClause *) {}
 
diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp
--- a/flang/lib/Semantics/check-omp-structure.cpp
+++ b/flang/lib/Semantics/check-omp-structure.cpp
@@ -730,6 +730,7 @@
 CHECK_SIMPLE_CLAUSE(Init, OMPC_init)
 CHECK_SIMPLE_CLAUSE(Use, OMPC_use)
 CHECK_SIMPLE_CLAUSE(Novariants, OMPC_novariants)
+CHECK_SIMPLE_CLAUSE(Nocontext, OMPC_nocontext)
 
 CHECK_REQ_SCALAR_INT_CLAUSE(Allocator, OMPC_allocator)
 CHECK_REQ_SCALAR_INT_CLAUSE(Grainsize, OMPC_grainsize)
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -282,6 +282,10 @@
   let clangClass = "OMPNovariantsClause";
   let flangClass = "ScalarLogicalExpr";
 }
+def OMPC_Nocontext : Clause<"nocontext"> {
+  let clangClass = "OMPNocontextClause";
+  let flangClass = "ScalarLogicalExpr";
+}
 def OMPC_Detach : Clause<"detach"> {
   let clangClass = "OMPDetachClause";
 }
@@ -1667,7 +1671,8 @@
     VersionedClause<OMPC_IsDevicePtr>,
     VersionedClause<OMPC_NoWait>,
     VersionedClause<OMPC_Depend>,
-    VersionedClause<OMPC_Novariants>
+    VersionedClause<OMPC_Novariants>,
+    VersionedClause<OMPC_Nocontext>
   ];
 }
 def OMP_Unknown : Directive<"unknown"> {