// Patch summary: record on OMPTargetParallelDirective whether the
// '#pragma omp target parallel' region contains an inner cancel directive,
// and thread that flag through Sema, serialization and CodeGen.
//
// What each part of the patch shows:
//  * clang/include/clang/AST/StmtOpenMP.h
//      - adds a 'bool HasCancel = false' member right after the
//        'friend class ASTStmtReader;' declaration;
//      - adds setHasCancel(bool), declared before the 'public:' label;
//      - adds the public accessor hasCancel() const;
//      - extends Create() with a trailing 'bool HasCancel' parameter.
//  * clang/lib/AST/StmtOpenMP.cpp
//      - Create() forwards the new parameter via Dir->setHasCancel().
//  * clang/lib/CodeGen/CGOpenMPRuntime.cpp
//      - adds one more 'else if (... dyn_cast(&D)) HasCancel = OPD->hasCancel();'
//        branch to the existing chain that computes HasCancel.
//        NOTE(review): presumably the cast target is OMPTargetParallelDirective;
//        the template argument is not visible in this copy - confirm.
//  * clang/lib/Sema/SemaOpenMP.cpp
//      - the Create() call now also passes DSAStack->isCancelRegion().
//  * clang/lib/Serialization/ASTWriterStmt.cpp / ASTReaderStmt.cpp
//      - the writer emits the flag with Record.writeBool() right after
//        VisitOMPExecutableDirective(D); the reader consumes it with
//        Record.readBool() at the same position.
//  * clang/test/OpenMP/target_parallel_{ast_print,codegen,messages}.cpp
//      - ast_print: wraps foo() in a compound statement with
//        'cancellation point parallel' and 'cancel parallel' pragmas;
//      - codegen: adds '#pragma omp cancel parallel' and CHECK lines for a
//        __kmpc_cancel call followed by a conditional branch;
//      - messages: adds a region where cancel / cancellation point with no
//        construct kind, or with for / sections / taskgroup, expect errors,
//        while the 'parallel' forms carry no expected-error annotation.
//
// NOTE(review): in this copy the line breaks inside the hunks are collapsed
// and template argument lists appear to be missing (e.g. 'ArrayRef Clauses',
// 'dyn_cast(&D)'). This looks like an extraction artifact - restore the patch
// from the original revision before trying to apply it.
// NOTE(review): the '// CHECK...' and '// expected-error {{...}}' comments in
// the test hunks are presumably FileCheck / -verify directives; treat them as
// functional text, not as documentation.
diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h --- a/clang/include/clang/AST/StmtOpenMP.h +++ b/clang/include/clang/AST/StmtOpenMP.h @@ -2805,6 +2805,9 @@ /// class OMPTargetParallelDirective : public OMPExecutableDirective { friend class ASTStmtReader; + /// true if the construct has inner cancel directive. + bool HasCancel = false; + /// Build directive with the given start and end location. /// /// \param StartLoc Starting location of the directive kind. @@ -2827,6 +2830,9 @@ SourceLocation(), SourceLocation(), NumClauses, /*NumChildren=*/1) {} + /// Set cancel state. + void setHasCancel(bool Has) { HasCancel = Has; } + public: /// Creates directive with a list of \a Clauses. /// @@ -2835,10 +2841,11 @@ /// \param EndLoc Ending Location of the directive. /// \param Clauses List of clauses. /// \param AssociatedStmt Statement, associated with the directive. + /// \param HasCancel true if this directive has inner cancel directive. /// static OMPTargetParallelDirective * Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, - ArrayRef Clauses, Stmt *AssociatedStmt); + ArrayRef Clauses, Stmt *AssociatedStmt, bool HasCancel); /// Creates an empty directive with the place for \a NumClauses /// clauses. @@ -2849,6 +2856,9 @@ static OMPTargetParallelDirective * CreateEmpty(const ASTContext &C, unsigned NumClauses, EmptyShell); + /// Return true if current directive has inner cancel directive. 
+ bool hasCancel() const { return HasCancel; } + static bool classof(const Stmt *T) { return T->getStmtClass() == OMPTargetParallelDirectiveClass; } diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp --- a/clang/lib/AST/StmtOpenMP.cpp +++ b/clang/lib/AST/StmtOpenMP.cpp @@ -887,7 +887,7 @@ OMPTargetParallelDirective *OMPTargetParallelDirective::Create( const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, - ArrayRef Clauses, Stmt *AssociatedStmt) { + ArrayRef Clauses, Stmt *AssociatedStmt, bool HasCancel) { unsigned Size = llvm::alignTo(sizeof(OMPTargetParallelDirective), alignof(OMPClause *)); void *Mem = @@ -896,6 +896,7 @@ new (Mem) OMPTargetParallelDirective(StartLoc, EndLoc, Clauses.size()); Dir->setClauses(Clauses); Dir->setAssociatedStmt(AssociatedStmt); + Dir->setHasCancel(HasCancel); return Dir; } diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp --- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp @@ -1459,6 +1459,8 @@ bool HasCancel = false; if (const auto *OPD = dyn_cast(&D)) HasCancel = OPD->hasCancel(); + else if (const auto *OPD = dyn_cast(&D)) + HasCancel = OPD->hasCancel(); else if (const auto *OPSD = dyn_cast(&D)) HasCancel = OPSD->hasCancel(); else if (const auto *OPFD = dyn_cast(&D)) diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -9903,7 +9903,7 @@ setFunctionHasBranchProtectedScope(); return OMPTargetParallelDirective::Create(Context, StartLoc, EndLoc, Clauses, - AStmt); + AStmt, DSAStack->isCancelRegion()); } StmtResult Sema::ActOnOpenMPTargetParallelForDirective( diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp --- a/clang/lib/Serialization/ASTReaderStmt.cpp +++ b/clang/lib/Serialization/ASTReaderStmt.cpp @@ -2489,6 +2489,7 @@ VisitStmt(D); Record.skipInts(1); 
VisitOMPExecutableDirective(D); + D->setHasCancel(Record.readBool()); } void ASTStmtReader::VisitOMPTargetParallelForDirective( diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp --- a/clang/lib/Serialization/ASTWriterStmt.cpp +++ b/clang/lib/Serialization/ASTWriterStmt.cpp @@ -2336,6 +2336,7 @@ VisitStmt(D); Record.push_back(D->getNumClauses()); VisitOMPExecutableDirective(D); + Record.writeBool(D->hasCancel()); Code = serialization::STMT_OMP_TARGET_PARALLEL_DIRECTIVE; } diff --git a/clang/test/OpenMP/target_parallel_ast_print.cpp b/clang/test/OpenMP/target_parallel_ast_print.cpp --- a/clang/test/OpenMP/target_parallel_ast_print.cpp +++ b/clang/test/OpenMP/target_parallel_ast_print.cpp @@ -225,8 +225,16 @@ #pragma omp target parallel defaultmap(tofrom: scalar) reduction(task, +:argc) // CHECK-NEXT: #pragma omp target parallel defaultmap(tofrom: scalar) reduction(task, +: argc) + { foo(); +#pragma omp cancellation point parallel +#pragma omp cancel parallel + } +// CHECK-NEXT: { // CHECK-NEXT: foo(); +// CHECK-NEXT: #pragma omp cancellation point parallel +// CHECK-NEXT: #pragma omp cancel parallel +// CHECK-NEXT: } return tmain(argc, &argc) + tmain(argv[0][0], argv[0]); } diff --git a/clang/test/OpenMP/target_parallel_codegen.cpp b/clang/test/OpenMP/target_parallel_codegen.cpp --- a/clang/test/OpenMP/target_parallel_codegen.cpp +++ b/clang/test/OpenMP/target_parallel_codegen.cpp @@ -134,6 +134,7 @@ #pragma omp target parallel if(target: 1) { aa += 1; +#pragma omp cancel parallel } // CHECK: [[IF:%.+]] = icmp sgt i32 {{[^,]+}}, 10 @@ -360,6 +361,12 @@ // CHECK: store i[[SZ]] %{{.+}}, i[[SZ]]* [[AA_ADDR]], align // CHECK: [[AA_CADDR:%.+]] = bitcast i[[SZ]]* [[AA_ADDR]] to i16* // CHECK: [[AA:%.+]] = load i16, i16* [[AA_CADDR]], align +// CHECK: [[IS_CANCEL:%.+]] = call i32 @__kmpc_cancel(%struct.ident_t* @{{.+}}, i32 %{{.+}}, i32 1) +// CHECK: [[CMP:%.+]] = icmp ne i32 [[IS_CANCEL]], 0 +// CHECK: br i1 [[CMP]], label 
%[[EXIT:.+]], label %[[CONTINUE:[^,]+]] +// CHECK: [[EXIT]]: +// CHECK: br label %[[CONTINUE]] +// CHECK: [[CONTINUE]]: // CHECK: ret void // CHECK-NEXT: } diff --git a/clang/test/OpenMP/target_parallel_messages.cpp b/clang/test/OpenMP/target_parallel_messages.cpp --- a/clang/test/OpenMP/target_parallel_messages.cpp +++ b/clang/test/OpenMP/target_parallel_messages.cpp @@ -76,6 +76,20 @@ #pragma omp target parallel copyin(pvt) // expected-error {{unexpected OpenMP clause 'copyin' in directive '#pragma omp target parallel'}} foo(); + #pragma omp target parallel + { +#pragma omp cancel // expected-error {{one of 'for', 'parallel', 'sections' or 'taskgroup' is expected}} +#pragma omp cancellation point // expected-error {{one of 'for', 'parallel', 'sections' or 'taskgroup' is expected}} +#pragma omp cancel for // expected-error {{region cannot be closely nested inside 'target parallel' region}} +#pragma omp cancellation point for // expected-error {{region cannot be closely nested inside 'target parallel' region}} +#pragma omp cancel sections // expected-error {{region cannot be closely nested inside 'target parallel' region}} +#pragma omp cancellation point sections // expected-error {{region cannot be closely nested inside 'target parallel' region}} +#pragma omp cancel taskgroup // expected-error {{region cannot be closely nested inside 'target parallel' region}} +#pragma omp cancellation point taskgroup // expected-error {{region cannot be closely nested inside 'target parallel' region}} +#pragma omp cancel parallel +#pragma omp cancellation point parallel + } + return 0; }