Index: include/clang/Basic/LangOptions.def
===================================================================
--- include/clang/Basic/LangOptions.def
+++ include/clang/Basic/LangOptions.def
@@ -190,6 +190,7 @@
 LANGOPT(CUDA , 1, 0, "CUDA")
 LANGOPT(OpenMP , 32, 0, "OpenMP support and version of OpenMP (31, 40 or 45)")
 LANGOPT(OpenMPUseTLS , 1, 0, "Use TLS for threadprivates or runtime calls")
+LANGOPT(OpenMPImplicitDeclareTarget , 1, 0, "Enable implicit declare target extension - marks automatically declarations and definitions with declare target attribute")
 LANGOPT(OpenMPIsDevice , 1, 0, "Generate code only for OpenMP target device")
 LANGOPT(RenderScript , 1, 0, "RenderScript")
 
Index: include/clang/Sema/Sema.h
===================================================================
--- include/clang/Sema/Sema.h
+++ include/clang/Sema/Sema.h
@@ -8658,6 +8658,9 @@
   bool isInOpenMPDeclareTargetContext() const {
     return IsInOpenMPDeclareTargetContext;
   }
+  /// Check and mark declarations that are implicitly used inside OpenMP target
+  /// regions.
+  void checkDeclImplicitlyUsedOpenMPTargetContext(Decl *D);
 
   /// Return the number of captured regions created for an OpenMP directive.
   static int getOpenMPCaptureLevels(OpenMPDirectiveKind Kind);
Index: include/clang/Sema/SemaInternal.h
===================================================================
--- include/clang/Sema/SemaInternal.h
+++ include/clang/Sema/SemaInternal.h
@@ -60,6 +60,16 @@
   return isDeviceSideDecl == LangOpts.CUDAIsDevice;
 }
 
+// Helper to check whether D matches the current offloading mode. When
+// compiling for an OpenMP target device (LangOpts.OpenMPIsDevice) the answer
+// is InOpenMPDeviceRegion and D is not inspected; otherwise the check is
+// delegated to DeclAttrsMatchCUDAMode.
+inline bool DeclAttrsMatchOffloadMode(const LangOptions &LangOpts, Decl *D,
+                                      bool InOpenMPDeviceRegion) {
+  return LangOpts.OpenMPIsDevice ? InOpenMPDeviceRegion
+                                 : DeclAttrsMatchCUDAMode(LangOpts, D);
+}
+
 // Directly mark a variable odr-used. Given a choice, prefer to use
 // MarkVariableReferenced since it does additional checks and then
 // calls MarkVarDeclODRUsed.
Index: lib/Parse/ParseOpenMP.cpp =================================================================== --- lib/Parse/ParseOpenMP.cpp +++ lib/Parse/ParseOpenMP.cpp @@ -757,6 +757,7 @@ if (!Actions.ActOnStartOpenMPDeclareTargetDirective(DTLoc)) return DeclGroupPtrTy(); + SmallVector Decls; DKind = ParseOpenMPDirectiveKind(*this); while (DKind != OMPD_end_declare_target && DKind != OMPD_declare_target && Tok.isNot(tok::eof) && Tok.isNot(tok::r_brace)) { @@ -780,6 +781,12 @@ else TPA.Commit(); } + + // Save the declarations so that we can create the declare target group + // later on. + if (Ptr) + for (auto *V : Ptr.get()) + Decls.push_back(V); } if (DKind == OMPD_end_declare_target) { @@ -794,8 +801,17 @@ } else { Diag(Tok, diag::err_expected_end_declare_target); Diag(DTLoc, diag::note_matching) << "'#pragma omp declare target'"; + // We have an error, so we don't have to attempt to generate code for the + // declarations. + Decls.clear(); } Actions.ActOnFinishOpenMPDeclareTargetDirective(); + + // If we have decls generate the group so that code can be generated for it + // later on. 
+ if (!Decls.empty()) + return Actions.BuildDeclaratorGroup(Decls); + return DeclGroupPtrTy(); } case OMPD_unknown: Index: lib/Sema/SemaDecl.cpp =================================================================== --- lib/Sema/SemaDecl.cpp +++ lib/Sema/SemaDecl.cpp @@ -6746,7 +6746,8 @@ case SC_Register: // Local Named register if (!Context.getTargetInfo().isValidGCCRegisterName(Label) && - DeclAttrsMatchCUDAMode(getLangOpts(), getCurFunctionDecl())) + DeclAttrsMatchOffloadMode(getLangOpts(), getCurFunctionDecl(), + IsInOpenMPDeclareTargetContext)) Diag(E->getExprLoc(), diag::err_asm_unknown_register_name) << Label; break; case SC_Static: @@ -6756,7 +6757,8 @@ } } else if (SC == SC_Register) { // Global Named register - if (DeclAttrsMatchCUDAMode(getLangOpts(), NewVD)) { + if (DeclAttrsMatchOffloadMode(getLangOpts(), NewVD, + IsInOpenMPDeclareTargetContext)) { const auto &TI = Context.getTargetInfo(); bool HasSizeMismatch; @@ -12656,6 +12658,11 @@ DiscardCleanupsInEvaluationContext(); } + // In case of OpenMPImplicitDeclareTarget, semantically parsed function body + // is visited to mark inner callexpr with OMPDeclareTargetDeclAttr attribute. + if (getLangOpts().OpenMP && getLangOpts().OpenMPImplicitDeclareTarget) + checkDeclImplicitlyUsedOpenMPTargetContext(dcl); + return dcl; } Index: lib/Sema/SemaOpenMP.cpp =================================================================== --- lib/Sema/SemaOpenMP.cpp +++ lib/Sema/SemaOpenMP.cpp @@ -19,6 +19,7 @@ #include "clang/AST/Decl.h" #include "clang/AST/DeclCXX.h" #include "clang/AST/DeclOpenMP.h" +#include "clang/AST/RecursiveASTVisitor.h" #include "clang/AST/StmtCXX.h" #include "clang/AST/StmtOpenMP.h" #include "clang/AST/StmtVisitor.h" @@ -1139,6 +1140,124 @@ return false; } +namespace { +/// Visit actual function body and its associated nested functions bodies. 
+class ImplicitDeviceFunctionChecker
+    : public RecursiveASTVisitor<ImplicitDeviceFunctionChecker> {
+  Sema &SemaRef;
+
+public:
+  explicit ImplicitDeviceFunctionChecker(Sema &S) : SemaRef(S) {}
+
+  /// Mark the closure class of \p LE with an implicit OMPDeclareTargetDeclAttr,
+  /// if it has none yet, and traverse the body of the lambda.
+  bool TraverseLambdaCapture(LambdaExpr *LE, const LambdaCapture *C,
+                             Expr *Init);
+
+  /// Mark \p F with an implicit OMPDeclareTargetDeclAttr and traverse it,
+  /// unless it already has the attribute.
+  bool VisitFunctionDecl(FunctionDecl *F);
+
+  /// Apply VisitFunctionDecl to the direct callee of \p Call, if there is one.
+  bool VisitCallExpr(CallExpr *Call);
+
+  /// Apply VisitFunctionDecl to the constructor used by \p E and to the
+  /// destructor of the constructed class, if it has one.
+  bool VisitCXXConstructExpr(CXXConstructExpr *E);
+
+  /// Apply VisitFunctionDecl to the destructor \p D.
+  bool VisitCXXDestructorDecl(CXXDestructorDecl *D);
+};
+} // end anonymous namespace
+
+/// If the OpenMPImplicitDeclareTarget language option is enabled, traverse
+/// \p D with ImplicitDeviceFunctionChecker; otherwise do nothing.
+static void ImplicitDeclareTargetCheck(Sema &SemaRef, Decl *D) {
+  if (!SemaRef.getLangOpts().OpenMPImplicitDeclareTarget)
+    return;
+
+  // The functions found while visiting D are marked by the checker itself.
+  ImplicitDeviceFunctionChecker FunctionCallChecker(SemaRef);
+  FunctionCallChecker.TraverseDecl(D);
+}
+
+/// If \p D is a valid function that has a body and one of its redeclarations
+/// has OMPDeclareTargetDeclAttr, mark \p D with an implicit
+/// OMPDeclareTargetDeclAttr and run ImplicitDeclareTargetCheck on it.
+void Sema::checkDeclImplicitlyUsedOpenMPTargetContext(Decl *D) {
+  auto *FD = dyn_cast_or_null<FunctionDecl>(D);
+  if (!FD || FD->isInvalidDecl() || !FD->hasBody())
+    return;
+
+  for (const FunctionDecl *RI : FD->redecls()) {
+    if (!RI->hasAttr<OMPDeclareTargetDeclAttr>())
+      continue;
+    // D may be the marked redeclaration itself; never attach a second copy.
+    if (!D->hasAttr<OMPDeclareTargetDeclAttr>()) {
+      Attr *A = OMPDeclareTargetDeclAttr::CreateImplicit(
+          Context, OMPDeclareTargetDeclAttr::MT_To);
+      D->addAttr(A);
+      if (ASTMutationListener *ML = Context.getASTMutationListener())
+        ML->DeclarationMarkedOpenMPDeclareTarget(D, A);
+    }
+    ImplicitDeclareTargetCheck(*this, FD);
+    return;
+  }
+}
+
+/// Attach an implicit OMPDeclareTargetDeclAttr (MT_To) to \p D and notify the
+/// AST mutation listener, if any. \p D must not carry the attribute yet.
+static void addImplicitDeclareTargetAttr(Sema &SemaRef, Decl *D) {
+  Attr *A = OMPDeclareTargetDeclAttr::CreateImplicit(
+      SemaRef.Context, OMPDeclareTargetDeclAttr::MT_To);
+  D->addAttr(A);
+  if (ASTMutationListener *ML = SemaRef.Context.getASTMutationListener())
+    ML->DeclarationMarkedOpenMPDeclareTarget(D, A);
+}
+
+bool ImplicitDeviceFunctionChecker::TraverseLambdaCapture(
+    LambdaExpr *LE, const LambdaCapture *C, Expr *Init) {
+  // Mark the closure class once, then traverse the body of the lambda.
+  CXXRecordDecl *Class = LE->getLambdaClass();
+  if (Class && !Class->hasAttr<OMPDeclareTargetDeclAttr>())
+    addImplicitDeclareTargetAttr(SemaRef, Class);
+  TraverseStmt(LE->getBody());
+  return true;
+}
+
+bool ImplicitDeviceFunctionChecker::VisitFunctionDecl(FunctionDecl *F) {
+  assert(F && "expected a function declaration");
+  if (F->hasAttr<OMPDeclareTargetDeclAttr>())
+    return true;
+  // Mark first so that visiting F again while traversing it returns early.
+  addImplicitDeclareTargetAttr(SemaRef, F);
+  TraverseDecl(F);
+  return true;
+}
+
+bool ImplicitDeviceFunctionChecker::VisitCallExpr(CallExpr *Call) {
+  // Calls without a direct callee are left alone.
+  FunctionDecl *Callee = Call->getDirectCallee();
+  return !Callee || VisitFunctionDecl(Callee);
+}
+
+bool ImplicitDeviceFunctionChecker::VisitCXXConstructExpr(CXXConstructExpr *E) {
+  // The destructor of the constructed class, if any, has to be emitted for the
+  // device as well. Not every type has a record, hence the null-safe lookup.
+  QualType Ty = SemaRef.Context.getBaseElementType(E->getType());
+  if (const CXXRecordDecl *RD = Ty->getAsCXXRecordDecl())
+    if (CXXDestructorDecl *Destructor = RD->getDestructor())
+      VisitCXXDestructorDecl(Destructor);
+
+  return VisitFunctionDecl(E->getConstructor());
+}
+
+bool ImplicitDeviceFunctionChecker::VisitCXXDestructorDecl(
+    CXXDestructorDecl *D) {
+  assert(D && "expected a destructor declaration");
+  return VisitFunctionDecl(D);
+}
+
 void Sema::InitDataSharingAttributesStack() {
   VarDataSharingAttributesStack = new DSAStackTy(*this);
 }
@@ -1304,12 +1423,12 @@
   // If we are attempting to capture a global variable in a directive with
   // 'target' we return true so that this global is also mapped to the device.
   //
-  // FIXME: If the declaration is enclosed in a 'declare target' directive,
-  // then it should not be captured. Therefore, an extra check has to be
-  // inserted here once support for 'declare target' is added.
+  // If the variable is enclosed in a declare target directive, that is not
+  // required.
// auto *VD = dyn_cast(D); - if (VD && !VD->hasLocalStorage()) { + if (VD && !VD->hasLocalStorage() && + !VD->hasAttr()) { if (isOpenMPTargetExecutionDirective(DSAStack->getCurrentDirective()) && !DSAStack->isClauseParsingMode()) return VD; @@ -6270,6 +6389,8 @@ getCurFunction()->setHasBranchProtectedScope(); + ImplicitDeclareTargetCheck(*this, CS->getCapturedDecl()); + return OMPTargetDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt); } @@ -6290,6 +6411,8 @@ getCurFunction()->setHasBranchProtectedScope(); + ImplicitDeclareTargetCheck(*this, CS->getCapturedDecl()); + return OMPTargetParallelDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt); } @@ -6334,6 +6457,9 @@ } getCurFunction()->setHasBranchProtectedScope(); + + ImplicitDeclareTargetCheck(*this, CS->getCapturedDecl()); + return OMPTargetParallelForDirective::Create(Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B, DSAStack->isCancelRegion()); @@ -6778,6 +6904,9 @@ return StmtError(); getCurFunction()->setHasBranchProtectedScope(); + + ImplicitDeclareTargetCheck(*this, CS->getCapturedDecl()); + return OMPTargetParallelForSimdDirective::Create( Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B); } @@ -6825,6 +6954,9 @@ return StmtError(); getCurFunction()->setHasBranchProtectedScope(); + + ImplicitDeclareTargetCheck(*this, CS->getCapturedDecl()); + return OMPTargetSimdDirective::Create(Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B); } @@ -6906,6 +7038,9 @@ return StmtError(); getCurFunction()->setHasBranchProtectedScope(); + + ImplicitDeclareTargetCheck(*this, CS->getCapturedDecl()); + return OMPTeamsDistributeSimdDirective::Create( Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B); } @@ -7020,6 +7155,8 @@ getCurFunction()->setHasBranchProtectedScope(); + ImplicitDeclareTargetCheck(*this, CS->getCapturedDecl()); + return OMPTargetTeamsDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt); } @@ -7054,6 +7191,9 @@ "omp target teams 
distribute loop exprs were not built"); getCurFunction()->setHasBranchProtectedScope(); + + ImplicitDeclareTargetCheck(*this, CS->getCapturedDecl()); + return OMPTargetTeamsDistributeDirective::Create( Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B); } @@ -7099,6 +7239,9 @@ } getCurFunction()->setHasBranchProtectedScope(); + + ImplicitDeclareTargetCheck(*this, CS->getCapturedDecl()); + return OMPTargetTeamsDistributeParallelForDirective::Create( Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B); } @@ -7145,6 +7288,9 @@ } getCurFunction()->setHasBranchProtectedScope(); + + ImplicitDeclareTargetCheck(*this, CS->getCapturedDecl()); + return OMPTargetTeamsDistributeParallelForSimdDirective::Create( Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B); } @@ -7178,6 +7324,9 @@ "omp target teams distribute simd loop exprs were not built"); getCurFunction()->setHasBranchProtectedScope(); + + ImplicitDeclareTargetCheck(*this, CS->getCapturedDecl()); + return OMPTargetTeamsDistributeSimdDirective::Create( Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B); } @@ -12041,17 +12190,18 @@ // target region (it can be e.g. a lambda) that is legal and we do not need // to do anything else. if (LD == D) { - Attr *A = OMPDeclareTargetDeclAttr::CreateImplicit( - SemaRef.Context, OMPDeclareTargetDeclAttr::MT_To); - D->addAttr(A); - if (ASTMutationListener *ML = SemaRef.Context.getASTMutationListener()) - ML->DeclarationMarkedOpenMPDeclareTarget(D, A); + if (!SemaRef.getLangOpts().OpenMPImplicitDeclareTarget) + if (!D->hasAttr()) + SemaRef.Diag(LD->getLocation(), diag::warn_omp_not_in_target_context); + return; } } if (!LD) LD = D; - if (LD && !LD->hasAttr() && + // The parameters of a function are considered 'declare target' declarations + // if the function itself is 'declare target'. + if (LD && !LD->hasAttr() && !isa(LD) && (isa(LD) || isa(LD))) { // Outlined declaration is not declared target. 
if (LD->isOutOfLine()) { @@ -12120,6 +12270,16 @@ return; } } + if (TemplateDecl *TD = dyn_cast(D)) { + // Mark template declarations as declare target so that they can propagate + // that information to their instances. + Attr *A = OMPDeclareTargetDeclAttr::CreateImplicit( + Context, OMPDeclareTargetDeclAttr::MT_To); + TD->addAttr(A); + if (ASTMutationListener *ML = Context.getASTMutationListener()) + ML->DeclarationMarkedOpenMPDeclareTarget(TD, A); + return; + } if (!E) { // Checking declaration inside declare target region. if (!D->hasAttr() &&