Index: clang/include/clang/AST/OpenMPClause.h =================================================================== --- clang/include/clang/AST/OpenMPClause.h +++ clang/include/clang/AST/OpenMPClause.h @@ -8405,6 +8405,96 @@ } }; +/// This represents 'bind' clause in the '#pragma omp ...' directives. +/// +/// \code +/// #pragma omp loop bind(parallel) +/// \endcode +class OMPBindClause final : public OMPClause { + friend class OMPClauseReader; + + /// Location of '('. + SourceLocation LParenLoc; + + /// The binding kind of 'bind' clause. + OpenMPBindClauseKind Kind = OMPC_BIND_unknown; + + /// Start location of the kind in source code. + SourceLocation KindLoc; + + /// Sets the location of '('. + void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; } + + /// Set the binding kind. + void setBindKind(OpenMPBindClauseKind K) { Kind = K; } + + /// Set the binding kind location. + void setBindKindLoc(SourceLocation KLoc) { KindLoc = KLoc; } + + /// Build 'bind' clause with kind \a K ('teams', 'parallel', or 'thread'). + /// + /// \param K Binding kind of the clause ('teams', 'parallel' or 'thread'). + /// \param KLoc Starting location of the binding kind. + /// \param StartLoc Starting location of the clause. + /// \param LParenLoc Location of '('. + /// \param EndLoc Ending location of the clause. + OMPBindClause(OpenMPBindClauseKind K, SourceLocation KLoc, + SourceLocation StartLoc, SourceLocation LParenLoc, + SourceLocation EndLoc) + : OMPClause(llvm::omp::OMPC_bind, StartLoc, EndLoc), LParenLoc(LParenLoc), + Kind(K), KindLoc(KLoc) {} + + /// Build an empty clause. + OMPBindClause() + : OMPClause(llvm::omp::OMPC_bind, SourceLocation(), SourceLocation()) {} + +public: + /// Build 'bind' clause with kind \a K ('teams', 'parallel', or 'thread'). + /// + /// \param C AST context + /// \param K Binding kind of the clause ('teams', 'parallel' or 'thread'). + /// \param KLoc Starting location of the binding kind. + /// \param StartLoc Starting location of the clause. + /// \param LParenLoc Location of '('. + /// \param EndLoc Ending location of the clause. + static OMPBindClause *Create(const ASTContext &C, OpenMPBindClauseKind K, + SourceLocation KLoc, SourceLocation StartLoc, + SourceLocation LParenLoc, SourceLocation EndLoc); + + /// Build an empty 'bind' clause. + /// + /// \param C AST context + static OMPBindClause *CreateEmpty(const ASTContext &C); + + /// Returns the location of '('. + SourceLocation getLParenLoc() const { return LParenLoc; } + + /// Returns kind of the clause. + OpenMPBindClauseKind getBindKind() const { return Kind; } + + /// Returns location of clause kind. + SourceLocation getBindKindLoc() const { return KindLoc; } + + child_range children() { + return child_range(child_iterator(), child_iterator()); + } + + const_child_range children() const { + return const_child_range(const_child_iterator(), const_child_iterator()); + } + + child_range used_children() { + return child_range(child_iterator(), child_iterator()); + } + const_child_range used_children() const { + return const_child_range(const_child_iterator(), const_child_iterator()); + } + + static bool classof(const OMPClause *T) { + return T->getClauseKind() == llvm::omp::OMPC_bind; + } +}; + /// This class implements a simple visitor for OMPClause /// subclasses. template class Ptr, typename RetTy> Index: clang/include/clang/AST/RecursiveASTVisitor.h =================================================================== --- clang/include/clang/AST/RecursiveASTVisitor.h +++ clang/include/clang/AST/RecursiveASTVisitor.h @@ -3676,6 +3676,11 @@ return true; } +template +bool RecursiveASTVisitor::VisitOMPBindClause(OMPBindClause *C) { + return true; +} + // FIXME: look at the following tricky-seeming exprs to see if we // need to recurse on anything. These are ones that have methods // returning decls or qualtypes or nestednamespecifier -- though I'm Index: clang/include/clang/Basic/DiagnosticSemaKinds.td =================================================================== --- clang/include/clang/Basic/DiagnosticSemaKinds.td +++ clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -10808,6 +10808,9 @@ def err_omp_lastprivate_loop_var_non_loop_iteration : Error< "only loop iteration variables are allowed in 'lastprivate' clause in " "'omp loop' directives">; +def err_omp_loop_directive_without_bind : Error< + "'omp loop' directive without 'bind' clause must be nested in another " + "construct">; def err_omp_interop_variable_expected : Error< "expected%select{| non-const}0 variable of type 'omp_interop_t'">; def err_omp_interop_variable_wrong_type : Error< Index: clang/include/clang/Basic/OpenMPKinds.h =================================================================== --- clang/include/clang/Basic/OpenMPKinds.h +++ clang/include/clang/Basic/OpenMPKinds.h @@ -174,6 +174,13 @@ OMPC_ADJUST_ARGS_unknown, }; +/// OpenMP bindings for the 'bind' clause. +enum OpenMPBindClauseKind { +#define OPENMP_BIND_KIND(Name) OMPC_BIND_##Name, +#include "clang/Basic/OpenMPKinds.def" + OMPC_BIND_unknown +}; + unsigned getOpenMPSimpleClauseType(OpenMPClauseKind Kind, llvm::StringRef Str, const LangOptions &LangOpts); const char *getOpenMPSimpleClauseTypeName(OpenMPClauseKind Kind, unsigned Type); Index: clang/include/clang/Basic/OpenMPKinds.def =================================================================== --- clang/include/clang/Basic/OpenMPKinds.def +++ clang/include/clang/Basic/OpenMPKinds.def @@ -62,6 +62,9 @@ #ifndef OPENMP_ADJUST_ARGS_KIND #define OPENMP_ADJUST_ARGS_KIND(Name) #endif +#ifndef OPENMP_BIND_KIND +#define OPENMP_BIND_KIND(Name) +#endif // Static attributes for 'schedule' clause. OPENMP_SCHEDULE_KIND(static) @@ -156,6 +159,12 @@ OPENMP_ADJUST_ARGS_KIND(nothing) OPENMP_ADJUST_ARGS_KIND(need_device_ptr) +// Binding kinds for the 'bind' clause. +OPENMP_BIND_KIND(teams) +OPENMP_BIND_KIND(parallel) +OPENMP_BIND_KIND(thread) + +#undef OPENMP_BIND_KIND #undef OPENMP_ADJUST_ARGS_KIND #undef OPENMP_REDUCTION_MODIFIER #undef OPENMP_DEVICE_MODIFIER Index: clang/include/clang/Sema/Sema.h =================================================================== --- clang/include/clang/Sema/Sema.h +++ clang/include/clang/Sema/Sema.h @@ -11416,6 +11416,12 @@ SourceLocation ColonLoc, SourceLocation EndLoc, Expr *Modifier, ArrayRef Locators); + /// Called on a well-formed 'bind' clause. + OMPClause *ActOnOpenMPBindClause(OpenMPBindClauseKind Kind, + SourceLocation KindLoc, + SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation EndLoc); /// The kind of conversion being performed. enum CheckedConversionKind { Index: clang/lib/AST/OpenMPClause.cpp =================================================================== --- clang/lib/AST/OpenMPClause.cpp +++ clang/lib/AST/OpenMPClause.cpp @@ -161,6 +161,7 @@ case OMPC_uses_allocators: case OMPC_affinity: case OMPC_when: + case OMPC_bind: break; default: break; @@ -259,6 +260,7 @@ case OMPC_uses_allocators: case OMPC_affinity: case OMPC_when: + case OMPC_bind: break; default: break; @@ -1586,6 +1588,16 @@ return new (Mem) OMPInitClause(N); } +OMPBindClause * +OMPBindClause::Create(const ASTContext &C, OpenMPBindClauseKind K, + SourceLocation KLoc, SourceLocation StartLoc, + SourceLocation LParenLoc, SourceLocation EndLoc) { + return new (C) OMPBindClause(K, KLoc, StartLoc, LParenLoc, EndLoc); +} + +OMPBindClause *OMPBindClause::CreateEmpty(const ASTContext &C) { + return new (C) OMPBindClause(); +} //===----------------------------------------------------------------------===// // OpenMP clauses printing methods //===----------------------------------------------------------------------===// @@ -2297,6 +2309,12 @@ OS << ")"; } +void OMPClausePrinter::VisitOMPBindClause(OMPBindClause *Node) { + OS << "bind(" + << getOpenMPSimpleClauseTypeName(OMPC_bind, unsigned(Node->getBindKind())) + << ")"; +} + void OMPTraitInfo::getAsVariantMatchInfo(ASTContext &ASTCtx, VariantMatchInfo &VMI) const { for (const OMPTraitSet &Set : Sets) { Index: clang/lib/AST/StmtProfile.cpp =================================================================== --- clang/lib/AST/StmtProfile.cpp +++ clang/lib/AST/StmtProfile.cpp @@ -878,6 +878,7 @@ Profiler->VisitStmt(E); } void OMPClauseProfiler::VisitOMPOrderClause(const OMPOrderClause *C) {} +void OMPClauseProfiler::VisitOMPBindClause(const OMPBindClause *C) {} } // namespace void Index: clang/lib/Basic/OpenMPKinds.cpp =================================================================== --- clang/lib/Basic/OpenMPKinds.cpp +++ clang/lib/Basic/OpenMPKinds.cpp @@ -130,6 +130,11 @@ #define OPENMP_ADJUST_ARGS_KIND(Name) .Case(#Name, OMPC_ADJUST_ARGS_##Name) #include "clang/Basic/OpenMPKinds.def" .Default(OMPC_ADJUST_ARGS_unknown); + case OMPC_bind: + return llvm::StringSwitch(Str) +#define OPENMP_BIND_KIND(Name) .Case(#Name, OMPC_BIND_##Name) +#include "clang/Basic/OpenMPKinds.def" + .Default(OMPC_BIND_unknown); case OMPC_unknown: case OMPC_threadprivate: case OMPC_if: @@ -385,6 +390,16 @@ #include "clang/Basic/OpenMPKinds.def" } llvm_unreachable("Invalid OpenMP 'adjust_args' clause kind"); + case OMPC_bind: + switch (Type) { + case OMPC_BIND_unknown: + return "unknown"; +#define OPENMP_BIND_KIND(Name) \ + case OMPC_BIND_##Name: \ + return #Name; +#include "clang/Basic/OpenMPKinds.def" + } + llvm_unreachable("Invalid OpenMP 'bind' clause type"); case OMPC_unknown: case OMPC_threadprivate: case OMPC_if: Index: clang/lib/CodeGen/CGStmtOpenMP.cpp =================================================================== --- clang/lib/CodeGen/CGStmtOpenMP.cpp +++ clang/lib/CodeGen/CGStmtOpenMP.cpp @@ -5992,6 +5992,7 @@ case OMPC_adjust_args: case OMPC_append_args: case OMPC_memory_order: + case OMPC_bind: llvm_unreachable("Clause is not allowed in 'omp atomic'."); } } Index: clang/lib/Parse/ParseOpenMP.cpp =================================================================== --- clang/lib/Parse/ParseOpenMP.cpp +++ clang/lib/Parse/ParseOpenMP.cpp @@ -3056,7 +3056,7 @@ /// clause: /// if-clause | final-clause | num_threads-clause | safelen-clause | /// default-clause | private-clause | firstprivate-clause | shared-clause -/// | linear-clause | aligned-clause | collapse-clause | +/// | linear-clause | aligned-clause | collapse-clause | bind-clause | /// lastprivate-clause | reduction-clause | proc_bind-clause | /// schedule-clause | copyin-clause | copyprivate-clause | untied-clause | /// mergeable-clause | flush-clause | read-clause | write-clause | @@ -3146,6 +3146,7 @@ case OMPC_proc_bind: case OMPC_atomic_default_mem_order: case OMPC_order: + case OMPC_bind: // OpenMP [2.14.3.1, Restrictions] // Only a single default clause may be specified on a parallel, task or // teams directive. @@ -3154,6 +3155,8 @@ // OpenMP [5.0, Requires directive, Restrictions] // At most one atomic_default_mem_order clause can appear // on the directive + // OpenMP 5.1, 2.11.7 loop Construct, Restrictions. + // At most one bind clause can appear on a loop directive. if (!FirstClause && CKind != OMPC_order) { Diag(Tok, diag::err_omp_more_one_clause) << getOpenMPDirectiveName(DKind) << getOpenMPClauseName(CKind) << 0; @@ -3500,6 +3503,9 @@ /// proc_bind-clause: /// 'proc_bind' '(' 'master' | 'close' | 'spread' ')' /// +/// bind-clause: +/// 'bind' '(' 'teams' | 'parallel' | 'thread' ')' +/// /// update-clause: /// 'update' '(' 'in' | 'out' | 'inout' | 'mutexinoutset' ')' /// Index: clang/lib/Sema/SemaOpenMP.cpp =================================================================== --- clang/lib/Sema/SemaOpenMP.cpp +++ clang/lib/Sema/SemaOpenMP.cpp @@ -4687,6 +4687,7 @@ OpenMPDirectiveKind CurrentRegion, const DeclarationNameInfo &CurrentName, OpenMPDirectiveKind CancelRegion, + OpenMPBindClauseKind BindKind, SourceLocation StartLoc) { if (Stack->getCurScope()) { OpenMPDirectiveKind ParentRegion = Stack->getParentDirective(); @@ -4702,6 +4703,14 @@ ShouldBeInTeamsRegion, ShouldBeInLoopSimdRegion, } Recommend = NoRecommend; + // OpenMP 5.1 [2.11.7, loop Construct, Restrictions] + // If a loop construct is not nested inside another OpenMP construct and it + // appears in a procedure, the bind clause must be present. + if (CurrentRegion == OMPD_loop && ParentRegion == OMPD_unknown && + BindKind == OMPC_BIND_unknown) { + SemaRef.Diag(StartLoc, diag::err_omp_loop_directive_without_bind); + return true; + } if (isOpenMPSimdDirective(ParentRegion) && ((SemaRef.LangOpts.OpenMP <= 45 && CurrentRegion != OMPD_ordered) || (SemaRef.LangOpts.OpenMP >= 50 && CurrentRegion != OMPD_ordered && @@ -4897,6 +4906,16 @@ CurrentRegion != OMPD_loop; Recommend = ShouldBeInParallelRegion; } + if (!NestingProhibited && CurrentRegion == OMPD_loop) { + // OpenMP [5.1, 2.11.7, loop Construct, Restrictions] + // If the bind clause is present on the loop construct and binding is + // teams then the corresponding loop region must be strictly nested inside + // a teams region. + NestingProhibited = BindKind == OMPC_BIND_teams && + ParentRegion != OMPD_teams && + ParentRegion != OMPD_target_teams; + Recommend = ShouldBeInTeamsRegion; + } if (!NestingProhibited && isOpenMPNestingDistributeDirective(CurrentRegion)) { // OpenMP 4.5 [2.17 Nesting of Regions] @@ -5770,10 +5789,14 @@ OpenMPDirectiveKind CancelRegion, ArrayRef Clauses, Stmt *AStmt, SourceLocation StartLoc, SourceLocation EndLoc) { StmtResult Res = StmtError(); + OpenMPBindClauseKind BindKind = OMPC_BIND_unknown; + if (const OMPBindClause *BC = + OMPExecutableDirective::getSingleClause(Clauses)) + BindKind = BC->getBindKind(); // First check CancelRegion which is then used in checkNestingOfRegions. if (checkCancelRegion(*this, Kind, CancelRegion, StartLoc) || checkNestingOfRegions(*this, DSAStack, Kind, DirName, CancelRegion, - StartLoc)) + BindKind, StartLoc)) return StmtError(); llvm::SmallVector ClausesWithImplicit; @@ -6352,6 +6375,7 @@ case OMPC_exclusive: case OMPC_uses_allocators: case OMPC_affinity: + case OMPC_bind: continue; case OMPC_allocator: case OMPC_flush: @@ -13460,6 +13484,7 @@ case OMPC_uses_allocators: case OMPC_affinity: case OMPC_when: + case OMPC_bind: default: llvm_unreachable("Clause is not allowed."); } @@ -14290,6 +14315,7 @@ case OMPC_exclusive: case OMPC_uses_allocators: case OMPC_affinity: + case OMPC_bind: default: llvm_unreachable("Unexpected OpenMP clause."); } @@ -14681,6 +14707,10 @@ Res = ActOnOpenMPUpdateClause(static_cast(Argument), ArgumentLoc, StartLoc, LParenLoc, EndLoc); break; + case OMPC_bind: + Res = ActOnOpenMPBindClause(static_cast(Argument), + ArgumentLoc, StartLoc, LParenLoc, EndLoc); + break; case OMPC_if: case OMPC_final: case OMPC_num_threads: @@ -15047,6 +15077,7 @@ case OMPC_uses_allocators: case OMPC_affinity: case OMPC_when: + case OMPC_bind: default: llvm_unreachable("Clause is not allowed."); } @@ -15840,6 +15871,7 @@ case OMPC_detach: case OMPC_uses_allocators: case OMPC_when: + case OMPC_bind: default: llvm_unreachable("Clause is not allowed."); } @@ -21521,3 +21553,20 @@ return OMPAffinityClause::Create(Context, StartLoc, LParenLoc, ColonLoc, EndLoc, Modifier, Vars); } + +OMPClause *Sema::ActOnOpenMPBindClause(OpenMPBindClauseKind Kind, + SourceLocation KindLoc, + SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation EndLoc) { + if (Kind == OMPC_BIND_unknown) { + Diag(KindLoc, diag::err_omp_unexpected_clause_value) + << getListOfPossibleValues(OMPC_bind, /*First=*/0, + /*Last=*/unsigned(OMPC_BIND_unknown)) + << getOpenMPClauseName(OMPC_bind); + return nullptr; + } + + return OMPBindClause::Create(Context, Kind, KindLoc, StartLoc, LParenLoc, + EndLoc); +} Index: clang/lib/Sema/TreeTransform.h =================================================================== --- clang/lib/Sema/TreeTransform.h +++ clang/lib/Sema/TreeTransform.h @@ -2256,6 +2256,19 @@ EndLoc); } + /// Build a new OpenMP 'bind' clause. + /// + /// By default, performs semantic analysis to build the new OpenMP clause. + /// Subclasses may override this routine to provide different behavior. + OMPClause *RebuildOMPBindClause(OpenMPBindClauseKind Kind, + SourceLocation KindLoc, + SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation EndLoc) { + return getSema().ActOnOpenMPBindClause(Kind, KindLoc, StartLoc, LParenLoc, + EndLoc); + } + /// Rebuild the operand to an Objective-C \@synchronized statement. /// /// By default, performs semantic analysis to build the new statement. @@ -10242,6 +10255,13 @@ C->getEndLoc()); } +template +OMPClause *TreeTransform::TransformOMPBindClause(OMPBindClause *C) { + return getDerived().RebuildOMPBindClause( + C->getBindKind(), C->getBindKindLoc(), C->getBeginLoc(), + C->getLParenLoc(), C->getEndLoc()); +} + //===----------------------------------------------------------------------===// // Expression transformation //===----------------------------------------------------------------------===// Index: clang/lib/Serialization/ASTReader.cpp =================================================================== --- clang/lib/Serialization/ASTReader.cpp +++ clang/lib/Serialization/ASTReader.cpp @@ -11969,6 +11969,9 @@ case llvm::omp::OMPC_filter: C = new (Context) OMPFilterClause(); break; + case llvm::omp::OMPC_bind: + C = OMPBindClause::CreateEmpty(Context); + break; #define OMP_CLAUSE_NO_CLASS(Enum, Str) \ case llvm::omp::Enum: \ break; @@ -12953,6 +12956,12 @@ C->setLParenLoc(Record.readSourceLocation()); } +void OMPClauseReader::VisitOMPBindClause(OMPBindClause *C) { + C->setBindKind(static_cast(Record.readInt())); + C->setLParenLoc(Record.readSourceLocation()); + C->setBindKindLoc(Record.readSourceLocation()); +} + OMPTraitInfo *ASTRecordReader::readOMPTraitInfo() { OMPTraitInfo &TI = getContext().getNewOMPTraitInfo(); TI.Sets.resize(readUInt32()); Index: clang/lib/Serialization/ASTWriter.cpp =================================================================== --- clang/lib/Serialization/ASTWriter.cpp +++ clang/lib/Serialization/ASTWriter.cpp @@ -6721,6 +6721,12 @@ Record.AddStmt(E); } +void OMPClauseWriter::VisitOMPBindClause(OMPBindClause *C) { + Record.push_back(unsigned(C->getBindKind())); + Record.AddSourceLocation(C->getLParenLoc()); + Record.AddSourceLocation(C->getBindKindLoc()); +} + void ASTRecordWriter::writeOMPTraitInfo(const OMPTraitInfo *TI) { writeUInt32(TI->Sets.size()); for (const auto &Set : TI->Sets) { Index: clang/test/OpenMP/generic_loop_ast_print.cpp =================================================================== --- clang/test/OpenMP/generic_loop_ast_print.cpp +++ clang/test/OpenMP/generic_loop_ast_print.cpp @@ -23,7 +23,7 @@ //PRINT: template void templ_foo(T t) { //PRINT: T j, z; -//PRINT: #pragma omp loop collapse(C) reduction(+: z) lastprivate(j) +//PRINT: #pragma omp loop collapse(C) reduction(+: z) lastprivate(j) bind(thread) //PRINT: for (T i = 0; i < t; ++i) //PRINT: for (j = 0; j < t; ++j) //PRINT: z += i + j; @@ -38,12 +38,13 @@ //DUMP: DeclRefExpr{{.*}}'z' 'T' //DUMP: OMPLastprivateClause //DUMP: DeclRefExpr{{.*}}'j' 'T' +//DUMP: OMPBindClause //DUMP: ForStmt //DUMP: ForStmt //PRINT: template<> void templ_foo(int t) { //PRINT: int j, z; -//PRINT: #pragma omp loop collapse(2) reduction(+: z) lastprivate(j) +//PRINT: #pragma omp loop collapse(2) reduction(+: z) lastprivate(j) bind(thread) //PRINT: for (int i = 0; i < t; ++i) //PRINT: for (j = 0; j < t; ++j) //PRINT: z += i + j; @@ -60,12 +61,13 @@ //DUMP: DeclRefExpr{{.*}}'z' 'int':'int' //DUMP: OMPLastprivateClause //DUMP: DeclRefExpr{{.*}}'j' 'int':'int' +//DUMP: OMPBindClause //DUMP: ForStmt template void templ_foo(T t) { T j,z; - #pragma omp loop collapse(C) reduction(+:z) lastprivate(j) + #pragma omp loop collapse(C) reduction(+:z) lastprivate(j) bind(thread) for (T i = 0; ivarlists()) Visitor->AddStmt(E); } +void OMPClauseEnqueue::VisitOMPBindClause(const OMPBindClause *C) {} + } // namespace void EnqueueVisitor::EnqueueChildren(const OMPClause *S) { Index: flang/lib/Semantics/check-omp-structure.cpp =================================================================== --- flang/lib/Semantics/check-omp-structure.cpp +++ flang/lib/Semantics/check-omp-structure.cpp @@ -1480,6 +1480,7 @@ CHECK_SIMPLE_CLAUSE(AdjustArgs, OMPC_adjust_args) CHECK_SIMPLE_CLAUSE(AppendArgs, OMPC_append_args) CHECK_SIMPLE_CLAUSE(MemoryOrder, OMPC_memory_order) +CHECK_SIMPLE_CLAUSE(Bind, OMPC_bind) CHECK_REQ_SCALAR_INT_CLAUSE(Grainsize, OMPC_grainsize) CHECK_REQ_SCALAR_INT_CLAUSE(NumTasks, OMPC_num_tasks) Index: llvm/include/llvm/Frontend/OpenMP/OMP.td =================================================================== --- llvm/include/llvm/Frontend/OpenMP/OMP.td +++ llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -365,6 +365,10 @@ } def OMPC_When: Clause<"when"> {} +def OMPC_bind : Clause<"bind"> { + let clangClass = "OMPBindClause"; +} + //===----------------------------------------------------------------------===// // Definition of OpenMP directives //===----------------------------------------------------------------------===// @@ -1739,6 +1743,7 @@ VersionedClause, ]; let allowedOnceClauses = [ + VersionedClause, VersionedClause, VersionedClause, ];