Index: include/clang/Sema/Sema.h
===================================================================
--- include/clang/Sema/Sema.h
+++ include/clang/Sema/Sema.h
@@ -1009,6 +1009,10 @@
       EK_Decltype, EK_TemplateArgument, EK_Other
     } ExprContext;
 
+    /// If we are checking arguments of a template, this is the template
+    /// under check.
+    TemplateDecl *Template;
+
     ExpressionEvaluationContextRecord(ExpressionEvaluationContext Context,
                                       unsigned NumCleanupObjects,
                                       CleanupInfo ParentCleanup,
@@ -1017,7 +1021,7 @@
         : Context(Context), ParentCleanup(ParentCleanup),
           NumCleanupObjects(NumCleanupObjects), NumTypos(0),
           ManglingContextDecl(ManglingContextDecl), MangleNumbering(),
-          ExprContext(ExprContext) {}
+          ExprContext(ExprContext), Template(nullptr) {}
 
     /// Retrieve the mangling numbering context, used to consistently
     /// number constructs like lambdas for mangling.
@@ -6453,10 +6457,12 @@
 
   bool CheckTemplateArgument(TemplateTypeParmDecl *Param,
                              TypeSourceInfo *Arg);
-  ExprResult CheckTemplateArgument(NonTypeTemplateParmDecl *Param,
-                                   QualType InstantiatedParamType, Expr *Arg,
-                                   TemplateArgument &Converted,
-                                   CheckTemplateArgumentKind CTAK = CTAK_Specified);
+  ExprResult
+  CheckTemplateArgument(NonTypeTemplateParmDecl *Param,
+                        QualType InstantiatedParamType, Expr *Arg,
+                        TemplateArgument &Converted,
+                        CheckTemplateArgumentKind CTAK = CTAK_Specified,
+                        TemplateDecl *Template = nullptr);
   bool CheckTemplateTemplateArgument(TemplateParameterList *Params,
                                      TemplateArgumentLoc &Arg);
 
Index: lib/Sema/SemaCUDA.cpp
===================================================================
--- lib/Sema/SemaCUDA.cpp
+++ lib/Sema/SemaCUDA.cpp
@@ -836,9 +836,24 @@
 bool Sema::CheckCUDACall(SourceLocation Loc, FunctionDecl *Callee) {
   assert(getLangOpts().CUDA && "Should only be called during CUDA compilation");
   assert(Callee && "Callee may not be null.");
+
+  auto &ExprEvalCtx = ExprEvalContexts.back();
+  if (ExprEvalCtx.isUnevaluated())
+    return true;
+
+  // While checking template arguments, the templated function (if any) acts
+  // as the caller. getTemplatedDecl() can be null (e.g. for a template
+  // template parameter), so use dyn_cast_or_null.
+  FunctionDecl *Caller = nullptr;
+  if (TemplateDecl *Template = ExprEvalCtx.Template)
+    Caller = dyn_cast_or_null<FunctionDecl>(Template->getTemplatedDecl());
+  else if (ExprEvalCtx.isConstantEvaluated())
+    return true;
+
   // FIXME: Is bailing out early correct here? Should we instead assume that
   // the caller is a global initializer?
-  FunctionDecl *Caller = dyn_cast<FunctionDecl>(CurContext);
+  if (!Caller)
+    Caller = dyn_cast<FunctionDecl>(CurContext);
   if (!Caller)
     return true;
 
Index: lib/Sema/SemaTemplate.cpp
===================================================================
--- lib/Sema/SemaTemplate.cpp
+++ lib/Sema/SemaTemplate.cpp
@@ -4534,6 +4534,7 @@
 
   EnterExpressionEvaluationContext ConstantEvaluated(
       SemaRef, Sema::ExpressionEvaluationContext::ConstantEvaluated);
+  SemaRef.ExprEvalContexts.back().Template = Template;
   return SemaRef.SubstExpr(Param->getDefaultArgument(), TemplateArgLists);
 }
 
@@ -4784,8 +4785,8 @@
       TemplateArgument Result;
       unsigned CurSFINAEErrors = NumSFINAEErrors;
       ExprResult Res =
-        CheckTemplateArgument(NTTP, NTTPType, Arg.getArgument().getAsExpr(),
-                              Result, CTAK);
+          CheckTemplateArgument(NTTP, NTTPType, Arg.getArgument().getAsExpr(),
+                                Result, CTAK, dyn_cast<TemplateDecl>(Template));
       if (Res.isInvalid())
         return true;
       // If the current template argument causes an error, give up now.
@@ -6154,6 +6155,19 @@
   return true;
 }
 
+/// If \p Arg names a function, either directly or as the operand of the
+/// address-of operator (parentheses are ignored), return that function;
+/// otherwise return null.
+static FunctionDecl *GetFunctionDecl(Expr *Arg) {
+  Expr *E = Arg->IgnoreParens();
+  if (auto *UO = dyn_cast<UnaryOperator>(E))
+    if (UO->getOpcode() == UO_AddrOf)
+      E = UO->getSubExpr()->IgnoreParens();
+  if (auto *DRE = dyn_cast<DeclRefExpr>(E))
+    return dyn_cast<FunctionDecl>(DRE->getDecl());
+  return nullptr;
+}
+
 /// Check a template argument against its corresponding
 /// non-type template parameter.
 ///
@@ -6164,7 +6178,8 @@
 ExprResult Sema::CheckTemplateArgument(NonTypeTemplateParmDecl *Param,
                                        QualType ParamType, Expr *Arg,
                                        TemplateArgument &Converted,
-                                       CheckTemplateArgumentKind CTAK) {
+                                       CheckTemplateArgumentKind CTAK,
+                                       TemplateDecl *Template) {
   SourceLocation StartLoc = Arg->getBeginLoc();
 
   // If the parameter type somehow involves auto, deduce the type now.
@@ -6251,6 +6266,7 @@
   // a constant-evaluated context.
   EnterExpressionEvaluationContext ConstantEvaluated(
       *this, Sema::ExpressionEvaluationContext::ConstantEvaluated);
+  ExprEvalContexts.back().Template = Template;
 
   if (getLangOpts().CPlusPlus17) {
     // C++17 [temp.arg.nontype]p1:
@@ -6570,6 +6586,13 @@
       return ExprError();
     }
 
+    // CheckCUDACall may only be used in CUDA compilations; it treats the
+    // templated function being specialized as the caller of FD.
+    if (getLangOpts().CUDA)
+      if (FunctionDecl *FD = GetFunctionDecl(Arg))
+        if (!CheckCUDACall(Arg->getBeginLoc(), FD))
+          return ExprError();
+
     if (!ParamType->isMemberPointerType()) {
       if (CheckTemplateArgumentAddressOfObjectOrFunction(*this, Param,
                                                          ParamType,
Index: test/SemaCUDA/kernel-template-with-func-arg.cu
===================================================================
--- /dev/null
+++ test/SemaCUDA/kernel-template-with-func-arg.cu
@@ -0,0 +1,57 @@
+// RUN: %clang_cc1 -fsyntax-only -verify %s
+
+#include "Inputs/cuda.h"
+
+struct C {
+  __device__ void devfun() {}
+  void hostfun() {}
+  template <typename T> __device__ void devtempfun() {}
+  __device__ __host__ void devhostfun() {}
+};
+
+__device__ void devfun() {}
+__host__ void hostfun() {}
+template <typename T> __device__ void devtempfun() {}
+__device__ __host__ void devhostfun() {}
+
+template <void (*devF)()> __global__ void kernel() { devF();}
+template <typename T, void (T::*devF)()> __global__ void kernel2(T *p) { (p->*devF)(); }
+
+template<> __global__ void kernel<devfun>();
+template<> __global__ void kernel<hostfun>(); // expected-error {{no function template matches function template specialization 'kernel'}}
+// expected-note@-5 {{candidate template ignored: invalid explicitly-specified argument for template parameter 'devF'}}
+template<> __global__ void kernel<devtempfun<int> >();
+template<> __global__ void kernel<devhostfun>();
+
+template<> __global__ void kernel<&devfun>();
+template<> __global__ void kernel<&hostfun>(); // expected-error {{no function template matches function template specialization 'kernel'}}
+// expected-note@-11 {{candidate template ignored: invalid explicitly-specified argument for template parameter 'devF'}}
+template<> __global__ void kernel<&devtempfun<int> >();
+template<> __global__ void kernel<&devhostfun>();
+
+template<> __global__ void kernel2<C, &C::devfun>(C *p);
+template<> __global__ void kernel2<C, &C::hostfun>(C *p); // expected-error {{no function template matches function template specialization 'kernel2'}}
+// expected-note@-16 {{candidate template ignored: invalid explicitly-specified argument for template parameter 'devF'}}
+template<> __global__ void kernel2<C, &C::devtempfun<int> >(C *p);
+template<> __global__ void kernel2<C, &C::devhostfun>(C *p);
+
+void fun() {
+  kernel<&devfun><<<1,1>>>();
+  kernel<&hostfun><<<1,1>>>(); // expected-error {{no matching function for call to 'kernel'}}
+  // expected-note@-24 {{candidate template ignored: invalid explicitly-specified argument for template parameter 'devF'}}
+  kernel<&devtempfun<int> ><<<1,1>>>();
+  kernel<&devhostfun><<<1,1>>>();
+
+  kernel<devfun><<<1,1>>>();
+  kernel<hostfun><<<1,1>>>(); // expected-error {{no matching function for call to 'kernel'}}
+  // expected-note@-30 {{candidate template ignored: invalid explicitly-specified argument for template parameter 'devF'}}
+  kernel<devtempfun<int> ><<<1,1>>>();
+  kernel<devhostfun><<<1,1>>>();
+
+  C a;
+  kernel2<C, &C::devfun><<<1,1>>>(&a);
+  kernel2<C, &C::hostfun><<<1,1>>>(&a); // expected-error {{no matching function for call to 'kernel2'}}
+  // expected-note@-36 {{candidate template ignored: invalid explicitly-specified argument for template parameter 'devF'}}
+  kernel2<C, &C::devtempfun<int> ><<<1,1>>>(&a);
+  kernel2<C, &C::devhostfun><<<1,1>>>(&a);
+}
Index: test/SemaCXX/template-arg-function-no-cuda.cpp
===================================================================
--- /dev/null
+++ test/SemaCXX/template-arg-function-no-cuda.cpp
@@ -0,0 +1,8 @@
+// RUN: %clang_cc1 -fsyntax-only -std=c++14 -verify %s
+// expected-no-diagnostics
+
+// Function template arguments must not reach CUDA-only checks outside CUDA.
+void f() {}
+template <void (*F)()> void call() { F(); }
+template void call<f>();
+template void call<&f>();