diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h --- a/clang/include/clang/AST/OpenMPClause.h +++ b/clang/include/clang/AST/OpenMPClause.h @@ -7561,14 +7561,49 @@ }; /// This represents 'destroy' clause in the '#pragma omp depobj' -/// directive. +/// directive or the '#pragma omp interop' directive. /// /// \code /// #pragma omp depobj(a) destroy +/// #pragma omp interop destroy(obj) /// \endcode -/// In this example directive '#pragma omp depobj' has 'destroy' clause. +/// In these examples, the '#pragma omp depobj' and '#pragma omp interop' +/// directives have a 'destroy' clause; the 'interop' one names a variable. class OMPDestroyClause final : public OMPClause { + friend class OMPClauseReader; + + /// Location of '('. + SourceLocation LParenLoc; + + /// Location of the interop variable. + SourceLocation VarLoc; + + /// The interop variable, or nullptr if the clause has no argument. + Stmt *InteropVar = nullptr; + + /// Sets the interop variable. + void setInteropVar(Expr *E) { InteropVar = E; } + + /// Sets the location of '('. + void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; } + + /// Sets the location of the interop variable. + void setVarLoc(SourceLocation Loc) { VarLoc = Loc; } + public: + /// Build 'destroy' clause with an interop variable expression \a InteropVar. + /// + /// \param InteropVar The interop variable. + /// \param StartLoc Starting location of the clause. + /// \param LParenLoc Location of '('. + /// \param VarLoc Location of the interop variable. + /// \param EndLoc Ending location of the clause. + OMPDestroyClause(Expr *InteropVar, SourceLocation StartLoc, + SourceLocation LParenLoc, SourceLocation VarLoc, + SourceLocation EndLoc) + : OMPClause(llvm::omp::OMPC_destroy, StartLoc, EndLoc), + LParenLoc(LParenLoc), VarLoc(VarLoc), InteropVar(InteropVar) {} + /// Build 'destroy' clause. /// /// \param StartLoc Starting location of the clause.
@@ -7581,11 +7616,24 @@ : OMPClause(llvm::omp::OMPC_destroy, SourceLocation(), SourceLocation()) { } + /// Returns the location of '('. + SourceLocation getLParenLoc() const { return LParenLoc; } + + /// Returns the location of the interop variable. + SourceLocation getVarLoc() const { return VarLoc; } + + /// Returns the interop variable, or nullptr if none was specified. + Expr *getInteropVar() const { return cast_or_null(InteropVar); } + child_range children() { + if (InteropVar) + return child_range(&InteropVar, &InteropVar + 1); return child_range(child_iterator(), child_iterator()); } const_child_range children() const { + if (InteropVar) + return const_child_range(&InteropVar, &InteropVar + 1); return const_child_range(const_child_iterator(), const_child_iterator()); } diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -3210,7 +3210,8 @@ } template -bool RecursiveASTVisitor::VisitOMPDestroyClause(OMPDestroyClause *) { +bool RecursiveASTVisitor::VisitOMPDestroyClause(OMPDestroyClause *C) { + TRY_TO(TraverseStmt(C->getInteropVar())); return true; } diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -10998,8 +10998,11 @@ SourceLocation VarLoc, SourceLocation EndLoc); /// Called on well-formed 'destroy' clause. - OMPClause *ActOnOpenMPDestroyClause(SourceLocation StartLoc, + OMPClause *ActOnOpenMPDestroyClause(Expr *InteropVar, SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation VarLoc, SourceLocation EndLoc); + /// Called on well-formed 'threads' clause.
OMPClause *ActOnOpenMPThreadsClause(SourceLocation StartLoc, SourceLocation EndLoc); diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp --- a/clang/lib/AST/OpenMPClause.cpp +++ b/clang/lib/AST/OpenMPClause.cpp @@ -1807,8 +1807,13 @@ OS << ")"; } -void OMPClausePrinter::VisitOMPDestroyClause(OMPDestroyClause *) { +void OMPClausePrinter::VisitOMPDestroyClause(OMPDestroyClause *Node) { OS << "destroy"; + if (Expr *E = Node->getInteropVar()) { + OS << "("; + E->printPretty(OS, nullptr, Policy); + OS << ")"; + } } template diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp --- a/clang/lib/AST/StmtProfile.cpp +++ b/clang/lib/AST/StmtProfile.cpp @@ -552,7 +552,10 @@ Profiler->VisitStmt(C->getInteropVar()); } -void OMPClauseProfiler::VisitOMPDestroyClause(const OMPDestroyClause *) {} +void OMPClauseProfiler::VisitOMPDestroyClause(const OMPDestroyClause *C) { + if (C->getInteropVar()) + Profiler->VisitStmt(C->getInteropVar()); +} template void OMPClauseProfiler::VisitOMPClauseList(T *Node) { diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp --- a/clang/lib/Parse/ParseOpenMP.cpp +++ b/clang/lib/Parse/ParseOpenMP.cpp @@ -2865,7 +2865,6 @@ case OMPC_unified_shared_memory: case OMPC_reverse_offload: case OMPC_dynamic_allocators: - case OMPC_destroy: // OpenMP [2.7.1, Restrictions, p. 9] // Only one ordered clause can appear on a loop directive. // OpenMP [2.7.1, Restrictions, C/C++, p.
4] @@ -2929,6 +2928,17 @@ case OMPC_uses_allocators: Clause = ParseOpenMPUsesAllocatorClause(DKind); break; + case OMPC_destroy: + if (DKind != OMPD_interop) { + if (!FirstClause) { + Diag(Tok, diag::err_omp_more_one_clause) + << getOpenMPDirectiveName(DKind) << getOpenMPClauseName(CKind) << 0; + ErrorFound = true; + } + Clause = ParseOpenMPClause(CKind, WrongDirective); + break; + } + LLVM_FALLTHROUGH; case OMPC_init: case OMPC_use: Clause = ParseOpenMPInteropClause(CKind, WrongDirective); @@ -3160,6 +3170,10 @@ return Actions.ActOnOpenMPUseClause(InteropVarExpr.get(), Loc, T.getOpenLocation(), VarLoc, RLoc); + if (Kind == OMPC_destroy) + return Actions.ActOnOpenMPDestroyClause(InteropVarExpr.get(), Loc, + T.getOpenLocation(), VarLoc, RLoc); + llvm_unreachable("Unexpected interop variable clause."); } diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -14441,7 +14441,9 @@ Res = ActOnOpenMPDynamicAllocatorsClause(StartLoc, EndLoc); break; case OMPC_destroy: - Res = ActOnOpenMPDestroyClause(StartLoc, EndLoc); + Res = ActOnOpenMPDestroyClause(/*InteropVar=*/nullptr, StartLoc, + /*LParenLoc=*/SourceLocation(), + /*VarLoc=*/SourceLocation(), EndLoc); break; case OMPC_if: case OMPC_final: @@ -14599,19 +14601,13 @@ return new (Context) OMPDynamicAllocatorsClause(StartLoc, EndLoc); } -OMPClause *Sema::ActOnOpenMPDestroyClause(SourceLocation StartLoc, - SourceLocation EndLoc) { - return new (Context) OMPDestroyClause(StartLoc, EndLoc); -} - StmtResult Sema::ActOnOpenMPInteropDirective(ArrayRef Clauses, SourceLocation StartLoc, SourceLocation EndLoc) { // OpenMP 5.1 [2.15.1, interop Construct, Restrictions] // At least one action-clause must appear on a directive. - // TODO: also add 'destroy' here.
- if (!hasClauses(Clauses, OMPC_init, OMPC_use, OMPC_nowait)) { + if (!hasClauses(Clauses, OMPC_init, OMPC_use, OMPC_destroy, OMPC_nowait)) { StringRef Expected = "'init', 'use', 'destroy', or 'nowait'"; Diag(StartLoc, diag::err_omp_no_clause_for_directive) << Expected << getOpenMPDirectiveName(OMPD_interop); @@ -14662,8 +14658,11 @@ const auto *UC = cast(C); VarLoc = UC->getVarLoc(); DRE = dyn_cast_or_null(UC->getInteropVar()); + } else if (ClauseKind == OMPC_destroy) { + const auto *DC = cast(C); + VarLoc = DC->getVarLoc(); + DRE = dyn_cast_or_null(DC->getInteropVar()); } - // TODO: 'destroy' clause to be added here. if (!DRE) continue; @@ -14723,8 +14722,7 @@ // OpenMP 5.1 [2.15.1, interop Construct, Restrictions] // The interop-var passed to init or destroy must be non-const. - // TODO: 'destroy' clause too. - if (Kind == OMPC_init && + if ((Kind == OMPC_init || Kind == OMPC_destroy) && isConstNotMutableType(SemaRef, InteropVarExpr->getType())) { SemaRef.Diag(VarLoc, diag::err_omp_interop_variable_expected) << /*non-const*/ 1; @@ -14773,6 +14771,19 @@ OMPUseClause(InteropVar, StartLoc, LParenLoc, VarLoc, EndLoc); } +OMPClause *Sema::ActOnOpenMPDestroyClause(Expr *InteropVar, + SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation VarLoc, + SourceLocation EndLoc) { + if (InteropVar && + !isValidInteropVariable(*this, InteropVar, VarLoc, OMPC_destroy)) + return nullptr; + + return new (Context) + OMPDestroyClause(InteropVar, StartLoc, LParenLoc, VarLoc, EndLoc); +} + OMPClause *Sema::ActOnOpenMPVarListClause( OpenMPClauseKind Kind, ArrayRef VarList, Expr *DepModOrTailExpr, const OMPVarListLocTy &Locs, SourceLocation ColonLoc, diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -2196,6 +2196,18 @@ VarLoc, EndLoc); } + /// Build a new OpenMP 'destroy' clause. + /// + /// By default, performs semantic analysis to build the new OpenMP clause.
+ /// Subclasses may override this routine to provide different behavior. + OMPClause *RebuildOMPDestroyClause(Expr *InteropVar, SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation VarLoc, + SourceLocation EndLoc) { + return getSema().ActOnOpenMPDestroyClause(InteropVar, StartLoc, LParenLoc, + VarLoc, EndLoc); + } + /// Rebuild the operand to an Objective-C \@synchronized statement. /// /// By default, performs semantic analysis to build the new statement. @@ -9343,8 +9355,15 @@ template OMPClause * TreeTransform::TransformOMPDestroyClause(OMPDestroyClause *C) { - // No need to rebuild this clause, no template-dependent parameters. - return C; + ExprResult ER; + if (Expr *IV = C->getInteropVar()) { + ER = getDerived().TransformExpr(IV); + if (ER.isInvalid()) + return nullptr; + } + return getDerived().RebuildOMPDestroyClause(ER.get(), C->getBeginLoc(), + C->getLParenLoc(), C->getVarLoc(), + C->getEndLoc()); } template diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp --- a/clang/lib/Serialization/ASTReader.cpp +++ b/clang/lib/Serialization/ASTReader.cpp @@ -12156,7 +12156,11 @@ C->setVarLoc(Record.readSourceLocation()); } -void OMPClauseReader::VisitOMPDestroyClause(OMPDestroyClause *) {} +void OMPClauseReader::VisitOMPDestroyClause(OMPDestroyClause *C) { + C->setInteropVar(Record.readSubExpr()); + C->setLParenLoc(Record.readSourceLocation()); + C->setVarLoc(Record.readSourceLocation()); +} void OMPClauseReader::VisitOMPUnifiedAddressClause(OMPUnifiedAddressClause *) {} diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp --- a/clang/lib/Serialization/ASTWriter.cpp +++ b/clang/lib/Serialization/ASTWriter.cpp @@ -6231,7 +6231,11 @@ Record.AddSourceLocation(C->getVarLoc()); } -void OMPClauseWriter::VisitOMPDestroyClause(OMPDestroyClause *) {} +void OMPClauseWriter::VisitOMPDestroyClause(OMPDestroyClause *C) { + Record.AddStmt(C->getInteropVar()); +
Record.AddSourceLocation(C->getLParenLoc()); + Record.AddSourceLocation(C->getVarLoc()); +} void OMPClauseWriter::VisitOMPPrivateClause(OMPPrivateClause *C) { Record.push_back(C->varlist_size()); diff --git a/clang/test/OpenMP/interop_ast_print.cpp b/clang/test/OpenMP/interop_ast_print.cpp --- a/clang/test/OpenMP/interop_ast_print.cpp +++ b/clang/test/OpenMP/interop_ast_print.cpp @@ -41,6 +41,12 @@ //DUMP: DeclRefExpr{{.*}}'omp_interop_t'{{.*}}Var{{.*}}'I' #pragma omp interop use(I) + //PRINT: #pragma omp interop destroy(I) + //DUMP: OMPInteropDirective + //DUMP: OMPDestroyClause + //DUMP: DeclRefExpr{{.*}}'omp_interop_t'{{.*}}Var{{.*}}'I' + #pragma omp interop destroy(I) + //PRINT: #pragma omp interop init(target : IRef) //DUMP: OMPInteropDirective //DUMP: OMPInitClause @@ -53,6 +59,12 @@ //DUMP: DeclRefExpr{{.*}}'omp_interop_t'{{.*}}Var{{.*}}'IRef' #pragma omp interop use(IRef) + //PRINT: #pragma omp interop destroy(IRef) + //DUMP: OMPInteropDirective + //DUMP: OMPDestroyClause + //DUMP: DeclRefExpr{{.*}}'omp_interop_t'{{.*}}Var{{.*}}'IRef' + #pragma omp interop destroy(IRef) + const omp_interop_t CI = (omp_interop_t)0; //PRINT: #pragma omp interop use(CI) //DUMP: OMPInteropDirective @@ -80,6 +92,16 @@ //DUMP: DeclRefExpr{{.*}}'omp_interop_t'{{.*}}Var{{.*}}'I' #pragma omp interop device(dev) depend(inout:ap) use(I) + //PRINT: #pragma omp interop device(dev) depend(inout : ap) destroy(I) + //DUMP: OMPInteropDirective + //DUMP: OMPDeviceClause + //DUMP: DeclRefExpr{{.*}}'dev' 'int' + //DUMP: OMPDependClause + //DUMP: DeclRefExpr{{.*}}'ap' 'int *' + //DUMP: OMPDestroyClause + //DUMP: DeclRefExpr{{.*}}'omp_interop_t'{{.*}}Var{{.*}}'I' + #pragma omp interop device(dev) depend(inout:ap) destroy(I) + //PRINT: #pragma omp interop init(prefer_type(1,2,3,4,5,6), targetsync : I) //DUMP: OMPInteropDirective //DUMP: OMPInitClause @@ -150,6 +172,30 @@ //DUMP: OMPUseClause //DUMP: DeclRefExpr{{.*}}'omp_interop_t'{{.*}}Var{{.*}}'J' #pragma omp interop use(I) use(J) + + //PRINT:
#pragma omp interop destroy(I) destroy(J) + //DUMP: OMPInteropDirective + //DUMP: OMPDestroyClause + //DUMP: DeclRefExpr{{.*}}'omp_interop_t'{{.*}}Var{{.*}}'I' + //DUMP: OMPDestroyClause + //DUMP: DeclRefExpr{{.*}}'omp_interop_t'{{.*}}Var{{.*}}'J' + #pragma omp interop destroy(I) destroy(J) + + //PRINT: #pragma omp interop init(target : I) destroy(J) + //DUMP: OMPInteropDirective + //DUMP: OMPInitClause + //DUMP: DeclRefExpr{{.*}}'omp_interop_t'{{.*}}Var{{.*}}'I' + //DUMP: OMPDestroyClause + //DUMP: DeclRefExpr{{.*}}'omp_interop_t'{{.*}}Var{{.*}}'J' + #pragma omp interop init(target:I) destroy(J) + + //PRINT: #pragma omp interop destroy(I) use(J) + //DUMP: OMPInteropDirective + //DUMP: OMPDestroyClause + //DUMP: DeclRefExpr{{.*}}'omp_interop_t'{{.*}}Var{{.*}}'I' + //DUMP: OMPUseClause + //DUMP: DeclRefExpr{{.*}}'omp_interop_t'{{.*}}Var{{.*}}'J' + #pragma omp interop destroy(I) use(J) } //DUMP: FunctionTemplateDecl{{.*}}fooTemp @@ -200,6 +246,12 @@ //DUMP: DeclRefExpr{{.*}}ParmVar{{.*}}'t' 'T' #pragma omp interop use(t) + //PRINT: #pragma omp interop destroy(t) + //DUMP: OMPInteropDirective + //DUMP: OMPDestroyClause + //DUMP: DeclRefExpr{{.*}}ParmVar{{.*}}'t' 'T' + #pragma omp interop destroy(t) + //DUMP: FunctionDecl{{.*}}barTemp 'void (void *)' //DUMP: TemplateArgument type 'void *' //DUMP: ParmVarDecl{{.*}}t 'void *' @@ -211,6 +263,10 @@ //DUMP: OMPUseClause //DUMP: DeclRefExpr{{.*}}ParmVar{{.*}}'t' 'void *' //PRINT: #pragma omp interop use(t) + //DUMP: OMPInteropDirective + //DUMP: OMPDestroyClause + //DUMP: DeclRefExpr{{.*}}ParmVar{{.*}}'t' 'void *' + //PRINT: #pragma omp interop destroy(t) } void bar() diff --git a/clang/test/OpenMP/interop_messages.cpp b/clang/test/OpenMP/interop_messages.cpp --- a/clang/test/OpenMP/interop_messages.cpp +++ b/clang/test/OpenMP/interop_messages.cpp @@ -17,6 +17,9 @@ //expected-error@+1 {{use of undeclared identifier 'NoDeclVar'}} #pragma omp interop use(NoDeclVar) use(Another) + //expected-error@+1 {{use of undeclared
identifier 'NoDeclVar'}} + #pragma omp interop destroy(NoDeclVar) destroy(Another) + //expected-error@+2 {{expected interop type: 'target' and/or 'targetsync'}} //expected-error@+1 {{expected expression}} #pragma omp interop init(InteropVar) init(target:Another) @@ -38,6 +41,9 @@ //expected-error@+1 {{interop variable must be of type 'omp_interop_t'}} #pragma omp interop use(IntVar) use(Another) + //expected-error@+1 {{interop variable must be of type 'omp_interop_t'}} + #pragma omp interop destroy(IntVar) destroy(Another) + //expected-error@+1 {{interop variable must be of type 'omp_interop_t'}} #pragma omp interop init(prefer_type(1,"sycl",3),target:SVar) \ init(target:Another) @@ -45,6 +51,9 @@ //expected-error@+1 {{interop variable must be of type 'omp_interop_t'}} #pragma omp interop use(SVar) use(Another) + //expected-error@+1 {{interop variable must be of type 'omp_interop_t'}} + #pragma omp interop destroy(SVar) destroy(Another) + int a, b; //expected-error@+1 {{expected variable of type 'omp_interop_t'}} #pragma omp interop init(target:a+b) init(target:Another) @@ -52,10 +61,16 @@ //expected-error@+1 {{expected variable of type 'omp_interop_t'}} #pragma omp interop use(a+b) use(Another) + //expected-error@+1 {{expected variable of type 'omp_interop_t'}} + #pragma omp interop destroy(a+b) destroy(Another) + const omp_interop_t C = (omp_interop_t)5; //expected-error@+1 {{expected non-const variable of type 'omp_interop_t'}} #pragma omp interop init(target:C) init(target:Another) + //expected-error@+1 {{expected non-const variable of type 'omp_interop_t'}} + #pragma omp interop destroy(C) destroy(Another) + //expected-error@+1 {{prefer_list item must be a string literal or constant integral expression}} #pragma omp interop init(prefer_type(1.0),target:InteropVar) \ init(target:Another) @@ -79,9 +94,18 @@ //expected-error@+1 {{interop variable 'InteropVar' used in multiple action clauses}} #pragma omp interop use(InteropVar) use(InteropVar) +
//expected-error@+1 {{interop variable 'InteropVar' used in multiple action clauses}} + #pragma omp interop destroy(InteropVar) destroy(InteropVar) + //expected-error@+1 {{interop variable 'InteropVar' used in multiple action clauses}} #pragma omp interop init(target:InteropVar) use(InteropVar) + //expected-error@+1 {{interop variable 'InteropVar' used in multiple action clauses}} + #pragma omp interop init(target:InteropVar) destroy(InteropVar) + + //expected-error@+1 {{interop variable 'InteropVar' used in multiple action clauses}} + #pragma omp interop use(InteropVar) destroy(InteropVar) + //expected-error@+1 {{directive '#pragma omp interop' cannot contain more than one 'device' clause}} #pragma omp interop init(target:InteropVar) device(0) device(1) @@ -99,5 +123,7 @@ #pragma omp interop init(prefer_type(1,"sycl",3),target:InteropVar) nowait //expected-error@+1 {{'omp_interop_t' type not found; include }} #pragma omp interop use(InteropVar) nowait + //expected-error@+1 {{'omp_interop_t' type not found; include }} + #pragma omp interop destroy(InteropVar) nowait } #endif diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp --- a/clang/tools/libclang/CIndex.cpp +++ b/clang/tools/libclang/CIndex.cpp @@ -2286,7 +2286,10 @@ Visitor->AddStmt(C->getInteropVar()); } -void OMPClauseEnqueue::VisitOMPDestroyClause(const OMPDestroyClause *) {} +void OMPClauseEnqueue::VisitOMPDestroyClause(const OMPDestroyClause *C) { + if (C->getInteropVar()) + Visitor->AddStmt(C->getInteropVar()); +} void OMPClauseEnqueue::VisitOMPUnifiedAddressClause( const OMPUnifiedAddressClause *) {} diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td --- a/llvm/include/llvm/Frontend/OpenMP/OMP.td +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -1650,6 +1650,7 @@ let allowedClauses = [ VersionedClause, VersionedClause, + VersionedClause, VersionedClause, VersionedClause, VersionedClause,