diff --git a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td --- a/clang/include/clang/Basic/DiagnosticParseKinds.td +++ b/clang/include/clang/Basic/DiagnosticParseKinds.td @@ -1363,6 +1363,7 @@ "expected ',' or ')' after iterator specifier">; def err_omp_decl_in_declare_simd_variant : Error< "function declaration is expected after 'declare %select{simd|variant}0' directive">; +def err_omp_sink_and_source_iteration_not_allowd : Error<"'%0 %select{sink:|source:}1' must be with '%select{omp_cur_iteration - 1|omp_cur_iteration}1'">; def err_omp_unknown_map_type : Error< "incorrect map type, expected one of 'to', 'from', 'tofrom', 'alloc', 'release', or 'delete'">; def err_omp_unknown_map_type_modifier : Error< diff --git a/clang/include/clang/Basic/OpenMPKinds.def b/clang/include/clang/Basic/OpenMPKinds.def --- a/clang/include/clang/Basic/OpenMPKinds.def +++ b/clang/include/clang/Basic/OpenMPKinds.def @@ -207,6 +207,8 @@ // Modifiers for the 'doacross' clause. OPENMP_DOACROSS_MODIFIER(source) OPENMP_DOACROSS_MODIFIER(sink) +OPENMP_DOACROSS_MODIFIER(sink_omp_cur_iteration) +OPENMP_DOACROSS_MODIFIER(source_omp_cur_iteration) #undef OPENMP_NUMTASKS_MODIFIER #undef OPENMP_GRAINSIZE_MODIFIER diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp --- a/clang/lib/AST/OpenMPClause.cpp +++ b/clang/lib/AST/OpenMPClause.cpp @@ -2513,7 +2513,14 @@ void OMPClausePrinter::VisitOMPDoacrossClause(OMPDoacrossClause *Node) { OS << "doacross("; OpenMPDoacrossClauseModifier DepType = Node->getDependenceType(); - OS << (DepType == OMPC_DOACROSS_source ? 
"source:" : "sink:"); + if (DepType == OMPC_DOACROSS_source) + OS << "source:"; + else if (DepType == OMPC_DOACROSS_sink) + OS << "sink:"; + else if (DepType == OMPC_DOACROSS_source_omp_cur_iteration) + OS << "source: omp_cur_iteration"; + else if (DepType == OMPC_DOACROSS_sink_omp_cur_iteration) + OS << "sink: omp_cur_iteration - 1"; VisitOMPClauseList(Node, ' '); OS << ")"; } diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.h b/clang/lib/CodeGen/CGOpenMPRuntime.h --- a/clang/lib/CodeGen/CGOpenMPRuntime.h +++ b/clang/lib/CodeGen/CGOpenMPRuntime.h @@ -2291,10 +2291,12 @@ template <> class OMPDoacrossKind { public: bool isSource(const OMPDoacrossClause *C) { - return (C->getDependenceType() == OMPC_DOACROSS_source); + return C->getDependenceType() == OMPC_DOACROSS_source || + C->getDependenceType() == OMPC_DOACROSS_source_omp_cur_iteration; } bool isSink(const OMPDoacrossClause *C) { - return (C->getDependenceType() == OMPC_DOACROSS_sink); + return C->getDependenceType() == OMPC_DOACROSS_sink || + C->getDependenceType() == OMPC_DOACROSS_sink_omp_cur_iteration; } }; } // namespace diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp --- a/clang/lib/Parse/ParseOpenMP.cpp +++ b/clang/lib/Parse/ParseOpenMP.cpp @@ -4410,16 +4410,55 @@ } if (Tok.is(tok::colon)) { Data.ColonLoc = ConsumeToken(); - } else { + } else if (Kind != OMPC_doacross || Tok.isNot(tok::r_paren)) { Diag(Tok, DKind == OMPD_ordered ? diag::warn_pragma_expected_colon_r_paren : diag::warn_pragma_expected_colon) << (Kind == OMPC_depend ? "dependency type" : "dependence-type"); } - // Special processing for doacross(source) clause. - if (Kind == OMPC_doacross && Data.ExtraModifier == OMPC_DOACROSS_source) { - // Parse ')'. - T.consumeClose(); - return false; + if (Kind == OMPC_doacross) { + if (Tok.is(tok::identifier) && + Tok.getIdentifierInfo()->isStr("omp_cur_iteration")) { + Data.ExtraModifier = Data.ExtraModifier == OMPC_DOACROSS_source + ? 
OMPC_DOACROSS_source_omp_cur_iteration + : OMPC_DOACROSS_sink_omp_cur_iteration; + ConsumeToken(); + } + if (Data.ExtraModifier == OMPC_DOACROSS_sink_omp_cur_iteration) { + if (Tok.isNot(tok::minus)) { + Diag(Tok, diag::err_omp_sink_and_source_iteration_not_allowd) + << getOpenMPClauseName(Kind) << 0 << 0; + SkipUntil(tok::r_paren); + return false; + } else { + ConsumeToken(); + SourceLocation Loc = Tok.getLocation(); + uint64_t Value = 0; + if (Tok.isNot(tok::numeric_constant) || + !PP.parseSimpleIntegerLiteral(Tok, Value) || Value != 1) { + Diag(Loc, diag::err_omp_sink_and_source_iteration_not_allowd) + << getOpenMPClauseName(Kind) << 0 << 0; + SkipUntil(tok::r_paren); + return false; + } + } + } + if (Data.ExtraModifier == OMPC_DOACROSS_source_omp_cur_iteration) { + if (Tok.isNot(tok::r_paren)) { + Diag(Tok, diag::err_omp_sink_and_source_iteration_not_allowd) + << getOpenMPClauseName(Kind) << 1 << 1; + SkipUntil(tok::r_paren); + return false; + } + } + // Only the plain 'sink:' form carries an expression list. + if (Kind == OMPC_doacross && + (Data.ExtraModifier == OMPC_DOACROSS_source || + Data.ExtraModifier == OMPC_DOACROSS_source_omp_cur_iteration || + Data.ExtraModifier == OMPC_DOACROSS_sink_omp_cur_iteration)) { + // Parse ')'. + T.consumeClose(); + return false; + } } } else if (Kind == OMPC_linear) { // Try to parse modifier if any. 
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -9165,6 +9165,22 @@ } } +namespace { +// Classifies 'doacross' clauses; isSink() excludes the omp_cur_iteration form. +class OMPDoacrossKind { +public: + bool isSource(const OMPDoacrossClause *C) { + return C->getDependenceType() == OMPC_DOACROSS_source || + C->getDependenceType() == OMPC_DOACROSS_source_omp_cur_iteration; + } + bool isSink(const OMPDoacrossClause *C) { + return C->getDependenceType() == OMPC_DOACROSS_sink; + } + bool isSinkIter(const OMPDoacrossClause *C) { + return C->getDependenceType() == OMPC_DOACROSS_sink_omp_cur_iteration; + } +}; +} // namespace /// Called on a for stmt to check and extract its iteration space /// for further processing (such as collapsing). static bool checkOpenMPIterationSpace( @@ -9332,7 +9348,8 @@ DependC->setLoopData(CurrentNestedLoopCount, nullptr); continue; } - if (DoacrossC && DoacrossC->getDependenceType() == OMPC_DOACROSS_sink && + OMPDoacrossKind ODK; + if (DoacrossC && ODK.isSink(DoacrossC) && Pair.second.size() <= CurrentNestedLoopCount) { // Erroneous case - clause has some problems. DoacrossC->setLoopData(CurrentNestedLoopCount, nullptr); @@ -9342,12 +9359,27 @@ SourceLocation DepLoc = DependC ? 
DependC->getDependencyLoc() : DoacrossC->getDependenceLoc(); if ((DependC && DependC->getDependencyKind() == OMPC_DEPEND_source) || - (DoacrossC && DoacrossC->getDependenceType() == OMPC_DOACROSS_source)) + (DoacrossC && ODK.isSource(DoacrossC))) CntValue = ISC.buildOrderedLoopData( DSA.getCurScope(), ResultIterSpaces[CurrentNestedLoopCount].CounterVar, Captures, DepLoc); - else + else if (DoacrossC && ODK.isSinkIter(DoacrossC)) { + Expr *Cnt = SemaRef + .DefaultLvalueConversion( + ResultIterSpaces[CurrentNestedLoopCount].CounterVar) + .get(); + if (!Cnt) + continue; + // Build the loop data for 'omp_cur_iteration - 1' as CounterVar - 1. + Expr *Inc = + SemaRef.ActOnIntegerConstant(DoacrossC->getColonLoc(), /*Val=*/1) + .get(); + CntValue = ISC.buildOrderedLoopData( + DSA.getCurScope(), + ResultIterSpaces[CurrentNestedLoopCount].CounterVar, Captures, + DepLoc, Inc, clang::OO_Minus); + } else CntValue = ISC.buildOrderedLoopData( DSA.getCurScope(), ResultIterSpaces[CurrentNestedLoopCount].CounterVar, Captures, @@ -11284,8 +11316,9 @@ if (DC || DOC) { DependFound = DC ? C : nullptr; DoacrossFound = DOC ? C : nullptr; + OMPDoacrossKind ODK; if ((DC && DC->getDependencyKind() == OMPC_DEPEND_source) || - (DOC && DOC->getDependenceType() == OMPC_DOACROSS_source)) { + (DOC && (ODK.isSource(DOC)))) { if ((DC && DependSourceClause) || (DOC && DoacrossSourceClause)) { Diag(C->getBeginLoc(), diag::err_omp_more_one_clause) << getOpenMPDirectiveName(OMPD_ordered) @@ -11303,7 +11336,7 @@ ErrorFound = true; } } else if ((DC && DC->getDependencyKind() == OMPC_DEPEND_sink) || - (DOC && DOC->getDependenceType() == OMPC_DOACROSS_sink)) { + (DOC && (ODK.isSink(DOC) || ODK.isSinkIter(DOC)))) { if (DependSourceClause || DoacrossSourceClause) { Diag(C->getBeginLoc(), diag::err_omp_sink_and_source_not_allowed) << (DC ? 
"depend" : "doacross") << 1; @@ -24025,7 +24058,9 @@ SourceLocation LParenLoc, SourceLocation EndLoc) { if (DSAStack->getCurrentDirective() == OMPD_ordered && - DepType != OMPC_DOACROSS_source && DepType != OMPC_DOACROSS_sink) { + DepType != OMPC_DOACROSS_source && DepType != OMPC_DOACROSS_sink && + DepType != OMPC_DOACROSS_sink_omp_cur_iteration && + DepType != OMPC_DOACROSS_source_omp_cur_iteration) { Diag(DepLoc, diag::err_omp_unexpected_clause_value) << "'source' or 'sink'" << getOpenMPClauseName(OMPC_doacross); return nullptr; @@ -24035,7 +24070,11 @@ DSAStackTy::OperatorOffsetTy OpsOffs; llvm::APSInt TotalDepCount(/*BitWidth=*/32); DoacrossDataInfoTy VarOffset = ProcessOpenMPDoacrossClauseCommon( - *this, DepType == OMPC_DOACROSS_source, VarList, DSAStack, EndLoc); + *this, + DepType == OMPC_DOACROSS_source || + DepType == OMPC_DOACROSS_source_omp_cur_iteration || + DepType == OMPC_DOACROSS_sink_omp_cur_iteration, + VarList, DSAStack, EndLoc); Vars = VarOffset.Vars; OpsOffs = VarOffset.OpsOffs; TotalDepCount = VarOffset.TotalDepCount; diff --git a/clang/test/OpenMP/ordered_ast_print.cpp b/clang/test/OpenMP/ordered_ast_print.cpp --- a/clang/test/OpenMP/ordered_ast_print.cpp +++ b/clang/test/OpenMP/ordered_ast_print.cpp @@ -55,6 +55,8 @@ #if _OPENMP >= 202111 #pragma omp ordered doacross(source:) #pragma omp ordered doacross(sink:i+N) + #pragma omp ordered doacross(sink: omp_cur_iteration - 1) + #pragma omp ordered doacross(source: omp_cur_iteration) #else #pragma omp ordered depend(source) #pragma omp ordered depend(sink:i+N) @@ -100,6 +102,8 @@ #if _OPENMP >= 202111 // OMP52: #pragma omp ordered doacross(source:) // OMP52-NEXT: #pragma omp ordered doacross(sink: i + N) +// OMP52-NEXT: #pragma omp ordered doacross(sink: omp_cur_iteration - 1) +// OMP52-NEXT: #pragma omp ordered doacross(source: omp_cur_iteration) #else // OMP51: #pragma omp ordered depend(source) // OMP51-NEXT: #pragma omp ordered depend(sink : i + N) @@ 
-142,6 +146,8 @@ #if _OPENMP >= 202111 // OMP52: #pragma omp ordered doacross(source:) // OMP52-NEXT: #pragma omp ordered doacross(sink: i + 3) +// OMP52-NEXT: #pragma omp ordered doacross(sink: omp_cur_iteration - 1) +// OMP52-NEXT: #pragma omp ordered doacross(source: omp_cur_iteration) #else // OMP51: #pragma omp ordered depend(source) // OMP51-NEXT: #pragma omp ordered depend(sink : i + 3) @@ -189,6 +195,8 @@ #if _OPENMP >= 202111 #pragma omp ordered doacross(source:) #pragma omp ordered doacross(sink: i - 5) + #pragma omp ordered doacross(sink: omp_cur_iteration - 1) + #pragma omp ordered doacross(source: omp_cur_iteration) #else #pragma omp ordered depend(source) #pragma omp ordered depend(sink: i - 5) @@ -230,6 +238,8 @@ #if _OPENMP >= 202111 // OMP52: #pragma omp ordered doacross(source:) // OMP52-NEXT: #pragma omp ordered doacross(sink: i - 5) +// OMP52-NEXT: #pragma omp ordered doacross(sink: omp_cur_iteration - 1) +// OMP52-NEXT: #pragma omp ordered doacross(source: omp_cur_iteration) #else // OMP51: #pragma omp ordered depend(source) // OMP51-NEXT: #pragma omp ordered depend(sink : i - 5) diff --git a/clang/test/OpenMP/ordered_doacross_codegen.c b/clang/test/OpenMP/ordered_doacross_codegen.c --- a/clang/test/OpenMP/ordered_doacross_codegen.c +++ b/clang/test/OpenMP/ordered_doacross_codegen.c @@ -84,6 +84,43 @@ #pragma omp ordered depend(sink : i - 2) #endif d[i] = a[i - 2]; + foo(); +// CHECK: call void @foo() +// CHECK: load i32, ptr [[I]], +// CHECK-NEXT: sub nsw i32 %{{.+}}, 1 +// CHECK-NEXT: sub nsw i32 %{{.+}}, 0 +// CHECK-NEXT: sdiv i32 %{{.+}}, 1 +// CHECK-NEXT: sext i32 %{{.+}} to i64 +// CHECK-NEXT: [[TMP:%.+]] = getelementptr inbounds [1 x i64], ptr [[CNT:%.+]], i64 0, i64 0 +// CHECK-NEXT: store i64 %{{.+}}, ptr [[TMP]], +// CHECK-NEXT: [[TMP:%.+]] = getelementptr inbounds [1 x i64], ptr [[CNT]], i64 0, i64 0 +// CHECK-NORMAL-NEXT: call void @__kmpc_doacross_wait(ptr [[IDENT]], i32 [[GTID]], ptr [[TMP]]) +// CHECK-IRBUILDER-NEXT: 
[[GTID2:%.+]] = call i32 @__kmpc_global_thread_num(ptr [[IDENT:@.+]]) +// CHECK-IRBUILDER-NEXT: call void @__kmpc_doacross_wait(ptr [[IDENT]], i32 [[GTID2]], ptr [[TMP]]) +#ifdef OMP52 +#pragma omp ordered doacross(sink :omp_cur_iteration - 1) +#else +#pragma omp ordered depend(sink : i - 1) +#endif + d[i] = a[i - 1]; + foo(); +// CHECK: call void @foo() +// CHECK: load i32, ptr [[I:%.+]], +// CHECK-NEXT: sub nsw i32 %{{.+}}, 0 +// CHECK-NEXT: sdiv i32 %{{.+}}, 1 +// CHECK-NEXT: sext i32 %{{.+}} to i64 +// CHECK-NEXT: [[TMP:%.+]] = getelementptr inbounds [1 x i64], ptr [[CNT:%.+]], i64 0, i64 0 +// CHECK-NEXT: store i64 %{{.+}}, ptr [[TMP]], +// CHECK-NEXT: [[TMP:%.+]] = getelementptr inbounds [1 x i64], ptr [[CNT]], i64 0, i64 0 +// CHECK-NORMAL-NEXT: call void @__kmpc_doacross_post(ptr [[IDENT]], i32 [[GTID]], ptr [[TMP]]) +// CHECK-IRBUILDER-NEXT: [[GTID1:%.+]] = call i32 @__kmpc_global_thread_num(ptr [[IDENT:@.+]]) +// CHECK-IRBUILDER-NEXT: call void @__kmpc_doacross_post(ptr [[IDENT]], i32 [[GTID1]], ptr [[TMP]]) +#ifdef OMP52 +#pragma omp ordered doacross(source:omp_cur_iteration) +#else +#pragma omp ordered depend(source) +#endif + c[i] = c[i] + 1; } // CHECK: call void @__kmpc_for_static_fini( // CHECK-NORMAL: call void @__kmpc_doacross_fini(ptr [[IDENT]], i32 [[GTID]]) diff --git a/clang/test/OpenMP/ordered_messages.cpp b/clang/test/OpenMP/ordered_messages.cpp --- a/clang/test/OpenMP/ordered_messages.cpp +++ b/clang/test/OpenMP/ordered_messages.cpp @@ -132,7 +132,11 @@ { foo(); } +#if _OPENMP >= 202111 + #pragma omp ordered doacross(source:omp_cur_iteration) // expected-error {{OpenMP constructs may not be nested inside a simd region}} +#else #pragma omp ordered depend(source) // expected-error {{OpenMP constructs may not be nested inside a simd region}} +#endif } #pragma omp parallel for ordered for (int i = 0; i < 10; ++i) { @@ -155,6 +159,10 @@ #pragma omp ordered doacross(sink : // omp52-error {{expected ')'}} omp52-note {{to match this '('}} omp52-error 
{{expected expression}} omp52-error {{expected 'i' loop iteration variable}} #pragma omp ordered doacross(sink : i // omp52-error {{expected ')'}} omp52-note {{to match this '('}} omp52-error {{expected 'j' loop iteration variable}} #pragma omp ordered doacross(sink : i) // omp52-error {{expected 'j' loop iteration variable}} +#pragma omp ordered doacross(sink:omp_cur_iteration + 1) // omp52-error {{'doacross sink:' must be with 'omp_cur_iteration - 1'}} +#pragma omp ordered doacross(sink:omp_cur_iteration - 2) // omp52-error {{'doacross sink:' must be with 'omp_cur_iteration - 1'}} +#pragma omp ordered doacross(sink:omp_cur_iteration) // omp52-error {{'doacross sink:' must be with 'omp_cur_iteration - 1'}} +#pragma omp ordered doacross(source:omp_cur_iteration - 1) // omp52-error {{'doacross source:' must be with 'omp_cur_iteration'}} #pragma omp ordered doacross(source:) if (i == j) #pragma omp ordered doacross(source:) // omp52-error {{'#pragma omp ordered' with 'doacross' clause cannot be an immediate substatement}} @@ -309,7 +317,11 @@ { foo(); } +#if _OPENMP >= 202111 + #pragma omp ordered doacross(source:omp_cur_iteration) // expected-error {{OpenMP constructs may not be nested inside a simd region}} +#else #pragma omp ordered depend(source) // expected-error {{OpenMP constructs may not be nested inside a simd region}} +#endif } #pragma omp parallel for ordered for (int i = 0; i < 10; ++i) {