diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -3887,6 +3887,8 @@
                                        SourceLocation LParenLoc,
                                        MultiExprArg Args,
                                        SourceLocation RParenLoc,
+                                       Expr *ExecConfig = nullptr,
+                                       bool IsExecConfig = false,
                                        bool AllowRecovery = false);
   ExprResult BuildCallToObjectOfClassType(Scope *S, Expr *Object,
                                           SourceLocation LParenLoc,
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -6498,7 +6498,8 @@
 
     if (Fn->getType() == Context.BoundMemberTy) {
       return BuildCallToMemberFunction(Scope, Fn, LParenLoc, ArgExprs,
-                                       RParenLoc, AllowRecovery);
+                                       RParenLoc, ExecConfig, IsExecConfig,
+                                       AllowRecovery);
     }
   }
 
@@ -6517,7 +6518,8 @@
             Scope, Fn, ULE, LParenLoc, ArgExprs, RParenLoc, ExecConfig,
             /*AllowTypoCorrection=*/true, find.IsAddressOfOperand);
       return BuildCallToMemberFunction(Scope, Fn, LParenLoc, ArgExprs,
-                                       RParenLoc, AllowRecovery);
+                                       RParenLoc, ExecConfig, IsExecConfig,
+                                       AllowRecovery);
     }
   }
 
diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp
--- a/clang/lib/Sema/SemaOverload.cpp
+++ b/clang/lib/Sema/SemaOverload.cpp
@@ -14166,6 +14166,7 @@
                                 SourceLocation LParenLoc,
                                 MultiExprArg Args,
                                 SourceLocation RParenLoc,
+                                Expr *ExecConfig, bool IsExecConfig,
                                 bool AllowRecovery) {
   assert(MemExprE->getType() == Context.BoundMemberTy ||
          MemExprE->getType() == Context.OverloadTy);
@@ -14361,8 +14362,8 @@
     // If overload resolution picked a static member, build a
     // non-member call based on that function.
     if (Method->isStatic()) {
-      return BuildResolvedCallExpr(MemExprE, Method, LParenLoc, Args,
-                                   RParenLoc);
+      return BuildResolvedCallExpr(MemExprE, Method, LParenLoc, Args, RParenLoc,
+                                   ExecConfig, IsExecConfig);
     }
 
     MemExpr = cast<MemberExpr>(MemExprE->IgnoreParens());
diff --git a/clang/test/SemaCUDA/kernel-call.cu b/clang/test/SemaCUDA/kernel-call.cu
--- a/clang/test/SemaCUDA/kernel-call.cu
+++ b/clang/test/SemaCUDA/kernel-call.cu
@@ -26,3 +26,34 @@
 
   g1<<<undeclared, 1>>>(42); // expected-error {{use of undeclared identifier 'undeclared'}}
 }
+
+// Make sure we can call static member kernels.
+template <typename> struct a0 {
+  template <typename T> static __global__ void Call(T);
+};
+struct a1 {
+  template <typename T> static __global__ void Call(T);
+};
+template <typename T> struct a2 {
+  static __global__ void Call(T);
+};
+struct a3 {
+  static __global__ void Call(int);
+  static __global__ void Call(void*);
+};
+
+struct b {
+  template <typename c> void d0(c arg) {
+    a0<c>::Call<<<0, 0>>>(arg);
+    a1::Call<<<0, 0>>>(arg);
+    a2<c>::Call<<<0, 0>>>(arg);
+    a3::Call<<<0, 0>>>(arg);
+  }
+  void d1(void* arg) {
+    a0<void*>::Call<<<0, 0>>>(arg);
+    a1::Call<<<0, 0>>>(arg);
+    a2<void*>::Call<<<0, 0>>>(arg);
+    a3::Call<<<0, 0>>>(arg);
+  }
+  void e() { d0(1); }
+};