Index: include/clang/Sema/Sema.h =================================================================== --- include/clang/Sema/Sema.h +++ include/clang/Sema/Sema.h @@ -9403,7 +9403,8 @@ /// /// Use this rather than examining the function's attributes yourself -- you /// will get it wrong. Returns CFT_Host if D is null. - CUDAFunctionTarget IdentifyCUDATarget(const FunctionDecl *D); + CUDAFunctionTarget IdentifyCUDATarget(const FunctionDecl *D, + bool IgnoreImplicitHDAttr = false); CUDAFunctionTarget IdentifyCUDATarget(const AttributeList *Attr); /// Gets the CUDA target for the current context. Index: lib/Sema/SemaCUDA.cpp =================================================================== --- lib/Sema/SemaCUDA.cpp +++ lib/Sema/SemaCUDA.cpp @@ -84,7 +84,7 @@ if (HasGlobalAttr) return CFT_Global; - if ((HasHostAttr && HasDeviceAttr) || ForceCUDAHostDeviceDepth > 0) + if (HasHostAttr && HasDeviceAttr) return CFT_HostDevice; if (HasDeviceAttr) @@ -93,8 +93,19 @@ return CFT_Host; } +template +static bool getAttr(const FunctionDecl *D, bool IgnoreImplicitAttr) { + if (Attr *Attribute = D->getAttr()) { + if (IgnoreImplicitAttr && Attribute->isImplicit()) + return false; + return true; + } + return false; +} + /// IdentifyCUDATarget - Determine the CUDA compilation target for this function -Sema::CUDAFunctionTarget Sema::IdentifyCUDATarget(const FunctionDecl *D) { +Sema::CUDAFunctionTarget Sema::IdentifyCUDATarget(const FunctionDecl *D, + bool IgnoreImplicitHDAttr) { // Code that lives outside a function is run on the host. if (D == nullptr) return CFT_Host; @@ -105,13 +116,13 @@ if (D->hasAttr()) return CFT_Global; - if (D->hasAttr()) { - if (D->hasAttr()) + if (getAttr(D, IgnoreImplicitHDAttr)) { + if (getAttr(D, IgnoreImplicitHDAttr)) return CFT_HostDevice; return CFT_Device; - } else if (D->hasAttr()) { + } else if (getAttr(D, IgnoreImplicitHDAttr)) { return CFT_Host; - } else if (D->isImplicit()) { + } else if (D->isImplicit() && !IgnoreImplicitHDAttr) { // Some implicit declarations (like intrinsic functions) are not marked. // Set the most lenient target on them for maximal flexibility. return CFT_HostDevice; @@ -523,8 +534,10 @@ return; } - NewD->addAttr(CUDAHostAttr::CreateImplicit(Context)); - NewD->addAttr(CUDADeviceAttr::CreateImplicit(Context)); + if (!NewD->hasAttr()) + NewD->addAttr(CUDAHostAttr::CreateImplicit(Context)); + if (!NewD->hasAttr()) + NewD->addAttr(CUDADeviceAttr::CreateImplicit(Context)); } // In CUDA, there are some constructs which may appear in semantically-valid @@ -867,3 +880,21 @@ } } } + +void Sema::mergeCUDATargetAttributes(NamedDecl *New, Decl *Old) { + if (auto *OldAttr = Old->getMostRecentDecl()->getAttr()) { + auto *NewAttr = OldAttr->clone(Context); + NewAttr->setInherited(true); + New->addAttr(NewAttr); + } + if (auto *OldAttr = Old->getMostRecentDecl()->getAttr()) { + auto *NewAttr = OldAttr->clone(Context); + NewAttr->setInherited(true); + New->addAttr(NewAttr); + } + if (auto *OldAttr = Old->getMostRecentDecl()->getAttr()) { + auto *NewAttr = OldAttr->clone(Context); + NewAttr->setInherited(true); + New->addAttr(NewAttr); + } +} Index: lib/Sema/SemaDeclAttr.cpp =================================================================== --- lib/Sema/SemaDeclAttr.cpp +++ lib/Sema/SemaDeclAttr.cpp @@ -5616,15 +5616,18 @@ handleFormatArgAttr(S, D, Attr); break; case AttributeList::AT_CUDAGlobal: - handleGlobalAttr(S, D, Attr); + if (!D->hasAttr()) + handleGlobalAttr(S, D, Attr); break; case AttributeList::AT_CUDADevice: - handleSimpleAttributeWithExclusions(S, D, - Attr); + if (!D->hasAttr()) + handleSimpleAttributeWithExclusions(S, D, + Attr); break; case AttributeList::AT_CUDAHost: - handleSimpleAttributeWithExclusions(S, D, - Attr); + if (!D->hasAttr()) + handleSimpleAttributeWithExclusions(S, D, + Attr); break; case AttributeList::AT_GNUInline: handleGNUInlineAttr(S, D, Attr); Index: lib/Sema/SemaTemplate.cpp =================================================================== --- lib/Sema/SemaTemplate.cpp +++ lib/Sema/SemaTemplate.cpp @@ -7043,13 +7043,13 @@ // Filter out matches that have different target. if (LangOpts.CUDA && - IdentifyCUDATarget(Specialization) != IdentifyCUDATarget(FD)) { + IdentifyCUDATarget(Specialization, true) != + IdentifyCUDATarget(FD, true)) { FailedCandidates.addCandidate().set( I.getPair(), FunTmpl->getTemplatedDecl(), MakeDeductionFailureInfo(Context, TDK_CUDATargetMismatch, Info)); continue; } - // Record this candidate. if (ExplicitTemplateArgs) ConvertedTemplateArgs[Specialization] = std::move(Args); @@ -7164,6 +7164,8 @@ // the prior function template specialization. Previous.clear(); Previous.addDecl(Specialization); + if (LangOpts.CUDA) + mergeCUDATargetAttributes(FD, Specialization); return false; } @@ -8114,7 +8116,7 @@ // Filter out matches that have different target. if (LangOpts.CUDA && - IdentifyCUDATarget(Specialization) != IdentifyCUDATarget(Attr)) { + IdentifyCUDATarget(Specialization, true) != IdentifyCUDATarget(Attr)) { FailedCandidates.addCandidate().set( P.getPair(), FunTmpl->getTemplatedDecl(), MakeDeductionFailureInfo(Context, TDK_CUDATargetMismatch, Info)); Index: test/SemaCUDA/function-template-overload.cu =================================================================== --- test/SemaCUDA/function-template-overload.cu +++ test/SemaCUDA/function-template-overload.cu @@ -56,24 +56,51 @@ template __host__ __device__ HDType overload_h_d2(T a) { return HDType(); } template __device__ DType overload_h_d2(T1 a) { T1 x; T2 y; return DType(); } +// constexpr functions are implicitly HD, but explicit +// instantiation/specialization must use target attributes as written. +template constexpr T overload_ce_implicit_hd(T a) { return a+1; } +// expected-note@-1 3 {{candidate template ignored: target attributes do not match}} + +// These will not match the template. +template __host__ __device__ int overload_ce_implicit_hd(int a); +// expected-error@-1 {{explicit instantiation of 'overload_ce_implicit_hd' does not refer to a function template, variable template, member function, member class, or static data member}} +template <> __host__ __device__ long overload_ce_implicit_hd(long a); +// expected-error@-1 {{no function template matches function template specialization 'overload_ce_implicit_hd'}} +template <> __host__ __device__ constexpr long overload_ce_implicit_hd(long a); +// expected-error@-1 {{no function template matches function template specialization 'overload_ce_implicit_hd'}} + +// These should work. +template __host__ int overload_ce_implicit_hd(int a); +template <> __host__ long overload_ce_implicit_hd(long a); + +template float overload_ce_implicit_hd(float a); +template <> float* overload_ce_implicit_hd(float *a); +template <> constexpr double overload_ce_implicit_hd(double a) { return a + 3.0; }; + __host__ void hf() { overload_hd(13); + overload_ce_implicit_hd('h'); // Implicitly instantiated + overload_ce_implicit_hd(1.0f); // Explicitly instantiated + overload_ce_implicit_hd(2.0); // Explicitly specialized HType h = overload_h_d(10); HType h2i = overload_h_d2(11); HType h2ii = overload_h_d2(12); // These should be implicitly instantiated from __host__ template returning HType. - DType d = overload_h_d(20); // expected-error {{no viable conversion from 'HType' to 'DType'}} - DType d2i = overload_h_d2(21); // expected-error {{no viable conversion from 'HType' to 'DType'}} + DType d = overload_h_d(20); // expected-error {{no viable conversion from 'HType' to 'DType'}} + DType d2i = overload_h_d2(21); // expected-error {{no viable conversion from 'HType' to 'DType'}} DType d2ii = overload_h_d2(22); // expected-error {{no viable conversion from 'HType' to 'DType'}} } __device__ void df() { overload_hd(23); + overload_ce_implicit_hd('d'); // Implicitly instantiated + overload_ce_implicit_hd(1.0f); // Explicitly instantiated + overload_ce_implicit_hd(2.0); // Explicitly specialized // These should be implicitly instantiated from __device__ template returning DType. - HType h = overload_h_d(10); // expected-error {{no viable conversion from 'DType' to 'HType'}} - HType h2i = overload_h_d2(11); // expected-error {{no viable conversion from 'DType' to 'HType'}} + HType h = overload_h_d(10); // expected-error {{no viable conversion from 'DType' to 'HType'}} + HType h2i = overload_h_d2(11); // expected-error {{no viable conversion from 'DType' to 'HType'}} HType h2ii = overload_h_d2(12); // expected-error {{no viable conversion from 'DType' to 'HType'}} DType d = overload_h_d(20);