// NOTE(review): This file is a unified diff (git patch), not a C++ source.
// It adds the OpenMP 5.1 'use' clause for '#pragma omp interop'. In order, it
// touches:
//   - the OMPUseClause AST node (OpenMPClause.h);
//   - RecursiveASTVisitor;
//   - the Sema::ActOnOpenMPUseClause declaration and definition;
//   - the clause printer and StmtProfile;
//   - the parser, where OMPC_use shares ParseOpenMPInteropClause with OMPC_init;
//   - the SemaOpenMP interop-directive checks;
//   - TreeTransform;
//   - AST serialization (reader and writer);
//   - the interop ast-print and diagnostics tests;
//   - libclang CIndex;
//   - flang's clause checker (CHECK_SIMPLE_CLAUSE);
//   - OMP.td.
//
// Semantic changes in the SemaOpenMP hunks:
//  - 'use' now satisfies the "at least one action clause" requirement.
//    'destroy' is still left as a TODO there.
//  - err_omp_interop_bad_depend_clause is now issued only when at least one
//    'init' clause is present and none of them is 'targetsync'. The new
//    ast-print test exercises 'depend' combined with only 'use'.
//  - The clause scan breaks as soon as a targetsync 'init' is seen, so a
//    later 'depend' is never recorded. That is harmless, because the
//    diagnostic also requires !IsTargetSync.
//  - Variables named in 'use' take part in the "used in multiple action
//    clauses" check, both on their own and when mixed with 'init'.
//  - The tests expect 'use' to accept a const omp_interop_t (use(CI)).
//    'init' on a const variable is still diagnosed as non-const.
//
// Serialization invariant: OMPClauseWriter::VisitOMPUseClause writes the
// interop expression, then LParenLoc, then VarLoc.
// OMPClauseReader::VisitOMPUseClause reads them back in that same order.
// Keep the two in sync.
//
// NOTE(review): This copy appears to have been flattened, so it will not apply
// with git apply or patch as-is. TODO: restore it from the upstream patch.
//   - Newlines are collapsed.
//   - Angle-bracketed text that resembles an HTML tag looks stripped, for
//     example "cast(InteropVar)", bare "template" lines, "dyn_cast(C)",
//     "ArrayRef VarList", "VersionedClause," and "include }}" in expected
//     diagnostics.
//   - By contrast, Clause<"init"> survived.
diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h --- a/clang/include/clang/AST/OpenMPClause.h +++ b/clang/include/clang/AST/OpenMPClause.h @@ -7489,6 +7489,77 @@ } }; +/// This represents the 'use' clause in '#pragma omp ...' directives. +/// +/// \code +/// #pragma omp interop use(obj) +/// \endcode +class OMPUseClause final : public OMPClause { + friend class OMPClauseReader; + + /// Location of '('. + SourceLocation LParenLoc; + + /// Location of interop variable. + SourceLocation VarLoc; + + /// The interop variable. + Stmt *InteropVar = nullptr; + + /// Set the interop variable. + void setInteropVar(Expr *E) { InteropVar = E; } + + /// Sets the location of '('. + void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; } + + /// Sets the location of the interop variable. + void setVarLoc(SourceLocation Loc) { VarLoc = Loc; } + +public: + /// Build 'use' clause with an interop variable expression \a InteropVar. + /// + /// \param InteropVar The interop variable. + /// \param StartLoc Starting location of the clause. + /// \param LParenLoc Location of '('. + /// \param VarLoc Location of the interop variable. + /// \param EndLoc Ending location of the clause. + OMPUseClause(Expr *InteropVar, SourceLocation StartLoc, + SourceLocation LParenLoc, SourceLocation VarLoc, + SourceLocation EndLoc) + : OMPClause(llvm::omp::OMPC_use, StartLoc, EndLoc), LParenLoc(LParenLoc), + VarLoc(VarLoc), InteropVar(InteropVar) {} + + /// Build an empty clause. + OMPUseClause() + : OMPClause(llvm::omp::OMPC_use, SourceLocation(), SourceLocation()) {} + + /// Returns the location of '('. + SourceLocation getLParenLoc() const { return LParenLoc; } + + /// Returns the location of the interop variable. + SourceLocation getVarLoc() const { return VarLoc; } + + /// Returns the interop variable.
+ Expr *getInteropVar() const { return cast(InteropVar); } + + child_range children() { return child_range(&InteropVar, &InteropVar + 1); } + + const_child_range children() const { + return const_child_range(&InteropVar, &InteropVar + 1); + } + + child_range used_children() { + return child_range(child_iterator(), child_iterator()); + } + const_child_range used_children() const { + return const_child_range(const_child_iterator(), const_child_iterator()); + } + + static bool classof(const OMPClause *T) { + return T->getClauseKind() == llvm::omp::OMPC_use; + } +}; + /// This represents 'destroy' clause in the '#pragma omp depobj' /// directive. /// diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -3203,6 +3203,12 @@ return true; } +template +bool RecursiveASTVisitor::VisitOMPUseClause(OMPUseClause *C) { + TRY_TO(TraverseStmt(C->getInteropVar())); + return true; +} + template bool RecursiveASTVisitor::VisitOMPDestroyClause(OMPDestroyClause *) { return true; diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -10992,6 +10992,11 @@ SourceLocation VarLoc, SourceLocation EndLoc); + /// Called on well-formed 'use' clause. + OMPClause *ActOnOpenMPUseClause(Expr *InteropVar, SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation VarLoc, SourceLocation EndLoc); + /// Called on well-formed 'destroy' clause.
OMPClause *ActOnOpenMPDestroyClause(SourceLocation StartLoc, SourceLocation EndLoc); diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp --- a/clang/lib/AST/OpenMPClause.cpp +++ b/clang/lib/AST/OpenMPClause.cpp @@ -1801,6 +1801,12 @@ OS << ")"; } +void OMPClausePrinter::VisitOMPUseClause(OMPUseClause *Node) { + OS << "use("; + Node->getInteropVar()->printPretty(OS, nullptr, Policy); + OS << ")"; +} + void OMPClausePrinter::VisitOMPDestroyClause(OMPDestroyClause *) { OS << "destroy"; } diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp --- a/clang/lib/AST/StmtProfile.cpp +++ b/clang/lib/AST/StmtProfile.cpp @@ -547,6 +547,11 @@ VisitOMPClauseList(C); } +void OMPClauseProfiler::VisitOMPUseClause(const OMPUseClause *C) { + if (C->getInteropVar()) + Profiler->VisitStmt(C->getInteropVar()); +} + void OMPClauseProfiler::VisitOMPDestroyClause(const OMPDestroyClause *) {} template diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp --- a/clang/lib/Parse/ParseOpenMP.cpp +++ b/clang/lib/Parse/ParseOpenMP.cpp @@ -2930,6 +2930,7 @@ Clause = ParseOpenMPUsesAllocatorClause(DKind); break; case OMPC_init: + case OMPC_use: Clause = ParseOpenMPInteropClause(CKind, WrongDirective); break; case OMPC_device_type: @@ -3155,6 +3156,9 @@ return Actions.ActOnOpenMPInitClause(InteropVarExpr.get(), Prefs, IsTarget, IsTargetSync, Loc, T.getOpenLocation(), VarLoc, RLoc); + if (Kind == OMPC_use) + return Actions.ActOnOpenMPUseClause(InteropVarExpr.get(), Loc, + T.getOpenLocation(), VarLoc, RLoc); llvm_unreachable("Unexpected interop variable clause."); } diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -14610,8 +14610,8 @@ // OpenMP 5.1 [2.15.1, interop Construct, Restrictions] // At least one action-clause must appear on a directive. - // TODO: also add 'use' and 'destroy' here.
- if (!hasClauses(Clauses, OMPC_init, OMPC_nowait)) { + // TODO: also add 'destroy' here. + if (!hasClauses(Clauses, OMPC_init, OMPC_use, OMPC_nowait)) { StringRef Expected = "'init', 'use', 'destroy', or 'nowait'"; Diag(StartLoc, diag::err_omp_no_clause_for_directive) << Expected << getOpenMPDirectiveName(OMPD_interop); @@ -14627,16 +14627,20 @@ // interop-type of 'targetsync'. Cases involving other directives cannot be // diagnosed. const OMPDependClause *DependClause = nullptr; + bool HasInitClause = false; bool IsTargetSync = false; for (const OMPClause *C : Clauses) { if (IsTargetSync) break; - if (const auto *InitClause = dyn_cast(C)) - IsTargetSync = InitClause->getIsTargetSync(); - else if (const auto *DC = dyn_cast(C)) + if (const auto *InitClause = dyn_cast(C)) { + HasInitClause = true; + if (InitClause->getIsTargetSync()) + IsTargetSync = true; + } else if (const auto *DC = dyn_cast(C)) { DependClause = DC; + } } - if (DependClause && !IsTargetSync) { + if (DependClause && HasInitClause && !IsTargetSync) { Diag(DependClause->getBeginLoc(), diag::err_omp_interop_bad_depend_clause); return StmtError(); } @@ -14654,8 +14658,12 @@ const auto *IC = cast(C); VarLoc = IC->getVarLoc(); DRE = dyn_cast_or_null(IC->getInteropVar()); + } else if (ClauseKind == OMPC_use) { + const auto *UC = cast(C); + VarLoc = UC->getVarLoc(); + DRE = dyn_cast_or_null(UC->getInteropVar()); } - // TODO: 'use' and 'destroy' clauses to be added here. + // TODO: 'destroy' clause to be added here.
if (!DRE) continue; @@ -14753,6 +14761,18 @@ EndLoc); } +OMPClause *Sema::ActOnOpenMPUseClause(Expr *InteropVar, SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation VarLoc, + SourceLocation EndLoc) { + + if (!isValidInteropVariable(*this, InteropVar, VarLoc, OMPC_use)) + return nullptr; + + return new (Context) + OMPUseClause(InteropVar, StartLoc, LParenLoc, VarLoc, EndLoc); +} + OMPClause *Sema::ActOnOpenMPVarListClause( OpenMPClauseKind Kind, ArrayRef VarList, Expr *DepModOrTailExpr, const OMPVarListLocTy &Locs, SourceLocation ColonLoc, diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -2185,6 +2185,17 @@ VarLoc, EndLoc); } + /// Build a new OpenMP 'use' clause. + /// + /// By default, performs semantic analysis to build the new OpenMP clause. + /// Subclasses may override this routine to provide different behavior. + OMPClause *RebuildOMPUseClause(Expr *InteropVar, SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation VarLoc, SourceLocation EndLoc) { + return getSema().ActOnOpenMPUseClause(InteropVar, StartLoc, LParenLoc, + VarLoc, EndLoc); + } + /// Rebuild the operand to an Objective-C \@synchronized statement. /// /// By default, performs semantic analysis to build the new statement.
@@ -9319,6 +9330,16 @@ C->getBeginLoc(), C->getLParenLoc(), C->getVarLoc(), C->getEndLoc()); } +template +OMPClause *TreeTransform::TransformOMPUseClause(OMPUseClause *C) { + ExprResult ER = getDerived().TransformExpr(C->getInteropVar()); + if (ER.isInvalid()) + return nullptr; + return getDerived().RebuildOMPUseClause(ER.get(), C->getBeginLoc(), + C->getLParenLoc(), C->getVarLoc(), + C->getEndLoc()); +} + template OMPClause * TreeTransform::TransformOMPDestroyClause(OMPDestroyClause *C) { diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp --- a/clang/lib/Serialization/ASTReader.cpp +++ b/clang/lib/Serialization/ASTReader.cpp @@ -11971,6 +11971,9 @@ case llvm::omp::OMPC_init: C = OMPInitClause::CreateEmpty(Context, Record.readInt()); break; + case llvm::omp::OMPC_use: + C = new (Context) OMPUseClause(); + break; case llvm::omp::OMPC_destroy: C = new (Context) OMPDestroyClause(); break; @@ -12147,6 +12150,12 @@ C->setVarLoc(Record.readSourceLocation()); } +void OMPClauseReader::VisitOMPUseClause(OMPUseClause *C) { + C->setInteropVar(Record.readSubExpr()); + C->setLParenLoc(Record.readSourceLocation()); + C->setVarLoc(Record.readSourceLocation()); +} + void OMPClauseReader::VisitOMPDestroyClause(OMPDestroyClause *) {} void OMPClauseReader::VisitOMPUnifiedAddressClause(OMPUnifiedAddressClause *) {} diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp --- a/clang/lib/Serialization/ASTWriter.cpp +++ b/clang/lib/Serialization/ASTWriter.cpp @@ -6225,6 +6225,12 @@ Record.AddSourceLocation(C->getVarLoc()); } +void OMPClauseWriter::VisitOMPUseClause(OMPUseClause *C) { + Record.AddStmt(C->getInteropVar()); + Record.AddSourceLocation(C->getLParenLoc()); + Record.AddSourceLocation(C->getVarLoc()); +} + void OMPClauseWriter::VisitOMPDestroyClause(OMPDestroyClause *) {} void OMPClauseWriter::VisitOMPPrivateClause(OMPPrivateClause *C) { diff --git a/clang/test/OpenMP/interop_ast_print.cpp
b/clang/test/OpenMP/interop_ast_print.cpp --- a/clang/test/OpenMP/interop_ast_print.cpp +++ b/clang/test/OpenMP/interop_ast_print.cpp @@ -35,12 +35,31 @@ //DUMP: DeclRefExpr{{.*}}'omp_interop_t'{{.*}}Var{{.*}}'I' #pragma omp interop init(target:I) + //PRINT: #pragma omp interop use(I) + //DUMP: OMPInteropDirective + //DUMP: OMPUseClause + //DUMP: DeclRefExpr{{.*}}'omp_interop_t'{{.*}}Var{{.*}}'I' + #pragma omp interop use(I) + //PRINT: #pragma omp interop init(target : IRef) //DUMP: OMPInteropDirective //DUMP: OMPInitClause //DUMP: DeclRefExpr{{.*}}'omp_interop_t'{{.*}}Var{{.*}}'IRef' #pragma omp interop init(target:IRef) + //PRINT: #pragma omp interop use(IRef) + //DUMP: OMPInteropDirective + //DUMP: OMPUseClause + //DUMP: DeclRefExpr{{.*}}'omp_interop_t'{{.*}}Var{{.*}}'IRef' + #pragma omp interop use(IRef) + + const omp_interop_t CI = (omp_interop_t)0; + //PRINT: #pragma omp interop use(CI) + //DUMP: OMPInteropDirective + //DUMP: OMPUseClause + //DUMP: DeclRefExpr{{.*}}'const omp_interop_t'{{.*}}Var{{.*}}'CI' + #pragma omp interop use(CI) + //PRINT: #pragma omp interop device(dev) depend(inout : ap) init(targetsync : I) //DUMP: OMPInteropDirective //DUMP: OMPDeviceClause @@ -51,6 +70,16 @@ //DUMP: DeclRefExpr{{.*}}'omp_interop_t'{{.*}}Var{{.*}}'I' #pragma omp interop device(dev) depend(inout:ap) init(targetsync:I) + //PRINT: #pragma omp interop device(dev) depend(inout : ap) use(I) + //DUMP: OMPInteropDirective + //DUMP: OMPDeviceClause + //DUMP: DeclRefExpr{{.*}}'dev' 'int' + //DUMP: OMPDependClause + //DUMP: DeclRefExpr{{.*}}'ap' 'int *' + //DUMP: OMPUseClause + //DUMP: DeclRefExpr{{.*}}'omp_interop_t'{{.*}}Var{{.*}}'I' + #pragma omp interop device(dev) depend(inout:ap) use(I) + //PRINT: #pragma omp interop init(prefer_type(1,2,3,4,5,6), targetsync : I) //DUMP: OMPInteropDirective //DUMP: OMPInitClause @@ -106,6 +135,21 @@ //DUMP: DeclRefExpr{{.*}}'omp_interop_t'{{.*}}Var{{.*}}'J' #pragma omp interop init(target:I) init(targetsync:J) + //PRINT: #pragma omp
interop init(target : I) use(J) + //DUMP: OMPInteropDirective + //DUMP: OMPInitClause + //DUMP: DeclRefExpr{{.*}}'omp_interop_t'{{.*}}Var{{.*}}'I' + //DUMP: OMPUseClause + //DUMP: DeclRefExpr{{.*}}'omp_interop_t'{{.*}}Var{{.*}}'J' + #pragma omp interop init(target:I) use(J) + + //PRINT: #pragma omp interop use(I) use(J) + //DUMP: OMPInteropDirective + //DUMP: OMPUseClause + //DUMP: DeclRefExpr{{.*}}'omp_interop_t'{{.*}}Var{{.*}}'I' + //DUMP: OMPUseClause + //DUMP: DeclRefExpr{{.*}}'omp_interop_t'{{.*}}Var{{.*}}'J' + #pragma omp interop use(I) use(J) } //DUMP: FunctionTemplateDecl{{.*}}fooTemp @@ -150,6 +194,12 @@ //DUMP: StringLiteral{{.*}}"level_one" #pragma omp interop init(prefer_type(4,"level_one"), target: t) + //PRINT: #pragma omp interop use(t) + //DUMP: OMPInteropDirective + //DUMP: OMPUseClause + //DUMP: DeclRefExpr{{.*}}ParmVar{{.*}}'t' 'T' + #pragma omp interop use(t) + //DUMP: FunctionDecl{{.*}}barTemp 'void (void *)' //DUMP: TemplateArgument type 'void *' //DUMP: ParmVarDecl{{.*}}t 'void *' @@ -157,6 +207,10 @@ //DUMP: OMPInitClause //DUMP: DeclRefExpr{{.*}}ParmVar{{.*}}'t' 'void *' //PRINT: #pragma omp interop init(prefer_type(4,"level_one"), target : t) + //DUMP: OMPInteropDirective + //DUMP: OMPUseClause + //DUMP: DeclRefExpr{{.*}}ParmVar{{.*}}'t' 'void *' + //PRINT: #pragma omp interop use(t) } void bar() diff --git a/clang/test/OpenMP/interop_messages.cpp b/clang/test/OpenMP/interop_messages.cpp --- a/clang/test/OpenMP/interop_messages.cpp +++ b/clang/test/OpenMP/interop_messages.cpp @@ -14,6 +14,9 @@ //expected-error@+1 {{use of undeclared identifier 'NoDeclVar'}} #pragma omp interop init(target:NoDeclVar) init(target:Another) + //expected-error@+1 {{use of undeclared identifier 'NoDeclVar'}} + #pragma omp interop use(NoDeclVar) use(Another) + //expected-error@+2 {{expected interop type: 'target' and/or 'targetsync'}} //expected-error@+1 {{expected expression}} #pragma omp interop init(InteropVar) init(target:Another) @@ -32,14 +35,23 @@ #pragma
omp interop init(prefer_type(1,"sycl",3),target:IntVar) \ init(target:Another) + //expected-error@+1 {{interop variable must be of type 'omp_interop_t'}} + #pragma omp interop use(IntVar) use(Another) + //expected-error@+1 {{interop variable must be of type 'omp_interop_t'}} #pragma omp interop init(prefer_type(1,"sycl",3),target:SVar) \ init(target:Another) + //expected-error@+1 {{interop variable must be of type 'omp_interop_t'}} + #pragma omp interop use(SVar) use(Another) + int a, b; //expected-error@+1 {{expected variable of type 'omp_interop_t'}} #pragma omp interop init(target:a+b) init(target:Another) + //expected-error@+1 {{expected variable of type 'omp_interop_t'}} + #pragma omp interop use(a+b) use(Another) + const omp_interop_t C = (omp_interop_t)5; //expected-error@+1 {{expected non-const variable of type 'omp_interop_t'}} #pragma omp interop init(target:C) init(target:Another) @@ -64,6 +76,12 @@ //expected-error@+1 {{interop variable 'InteropVar' used in multiple action clauses}} #pragma omp interop init(target:InteropVar) init(target:InteropVar) + //expected-error@+1 {{interop variable 'InteropVar' used in multiple action clauses}} + #pragma omp interop use(InteropVar) use(InteropVar) + + //expected-error@+1 {{interop variable 'InteropVar' used in multiple action clauses}} + #pragma omp interop init(target:InteropVar) use(InteropVar) + //expected-error@+1 {{directive '#pragma omp interop' cannot contain more than one 'device' clause}} #pragma omp interop init(target:InteropVar) device(0) device(1) @@ -79,5 +97,7 @@ int InteropVar; //expected-error@+1 {{'omp_interop_t' type not found; include }} #pragma omp interop init(prefer_type(1,"sycl",3),target:InteropVar) nowait + //expected-error@+1 {{'omp_interop_t' type not found; include }} + #pragma omp interop use(InteropVar) nowait } #endif diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp --- a/clang/tools/libclang/CIndex.cpp +++ b/clang/tools/libclang/CIndex.cpp @@ -2282,6
+2282,10 @@ VisitOMPClauseList(C); } +void OMPClauseEnqueue::VisitOMPUseClause(const OMPUseClause *C) { + Visitor->AddStmt(C->getInteropVar()); +} + void OMPClauseEnqueue::VisitOMPDestroyClause(const OMPDestroyClause *) {} void OMPClauseEnqueue::VisitOMPUnifiedAddressClause( diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp --- a/flang/lib/Semantics/check-omp-structure.cpp +++ b/flang/lib/Semantics/check-omp-structure.cpp @@ -718,6 +718,7 @@ CHECK_SIMPLE_CLAUSE(UseDeviceAddr, OMPC_use_device_addr) CHECK_SIMPLE_CLAUSE(Write, OMPC_write) CHECK_SIMPLE_CLAUSE(Init, OMPC_init) +CHECK_SIMPLE_CLAUSE(Use, OMPC_use) CHECK_REQ_SCALAR_INT_CLAUSE(Allocator, OMPC_allocator) CHECK_REQ_SCALAR_INT_CLAUSE(Grainsize, OMPC_grainsize) diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td --- a/llvm/include/llvm/Frontend/OpenMP/OMP.td +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -271,6 +271,9 @@ def OMPC_Init : Clause<"init"> { let clangClass = "OMPInitClause"; } +def OMPC_Use : Clause<"use"> { + let clangClass = "OMPUseClause"; +} def OMPC_Destroy : Clause<"destroy"> { let clangClass = "OMPDestroyClause"; } @@ -1649,6 +1652,7 @@ VersionedClause, VersionedClause, VersionedClause, + VersionedClause, ]; } def OMP_Unknown : Directive<"unknown"> {