Index: clang/include/clang/Sema/Sema.h =================================================================== --- clang/include/clang/Sema/Sema.h +++ clang/include/clang/Sema/Sema.h @@ -3319,12 +3319,16 @@ void ActOnReenterFunctionContext(Scope* S, Decl* D); void ActOnExitFunctionContext(); - DeclContext *getFunctionLevelDeclContext(); - - /// getCurFunctionDecl - If inside of a function body, this returns a pointer - /// to the function decl for the function being parsed. If we're currently - /// in a 'block', this returns the containing context. - FunctionDecl *getCurFunctionDecl(); + /// If \p AllowLambda is true, treat a lambda as a function. + DeclContext *getFunctionLevelDeclContext(bool AllowLambda = false); + + /// getCurFunctionDecl - If parsing a lambda, then return the lambda + /// declaration if \p AllowLambda is true, otherwise return the function + /// declaration enclosing the lambda. If in 'block' context, return the + /// enclosing function or lambda declaration if \p AllowLambda is true, + /// otherwise return the enclosing function declaration. For regular + /// functions, return the function declaration. + FunctionDecl *getCurFunctionDecl(bool AllowLambda = false); /// getCurMethodDecl - If inside of a method body, this returns a pointer to /// the method decl for the method being parsed. If we're currently Index: clang/lib/Sema/Sema.cpp =================================================================== --- clang/lib/Sema/Sema.cpp +++ clang/lib/Sema/Sema.cpp @@ -1415,19 +1415,18 @@ // Helper functions.
//===----------------------------------------------------------------------===// -DeclContext *Sema::getFunctionLevelDeclContext() { +DeclContext *Sema::getFunctionLevelDeclContext(bool AllowLambda) { DeclContext *DC = CurContext; while (true) { if (isa<BlockDecl>(DC) || isa<EnumDecl>(DC) || isa<CapturedDecl>(DC) || isa<RequiresExprBodyDecl>(DC)) { DC = DC->getParent(); - } else if (isa<CXXMethodDecl>(DC) && + } else if (!AllowLambda && isa<CXXMethodDecl>(DC) && cast<CXXMethodDecl>(DC)->getOverloadedOperator() == OO_Call && cast<CXXRecordDecl>(DC->getParent())->isLambda()) { DC = DC->getParent()->getParent(); - } - else break; + } else break; } return DC; @@ -1436,8 +1435,8 @@ /// getCurFunctionDecl - If inside of a function body, this returns a pointer /// to the function decl for the function being parsed. If we're currently /// in a 'block', this returns the containing context. -FunctionDecl *Sema::getCurFunctionDecl() { - DeclContext *DC = getFunctionLevelDeclContext(); +FunctionDecl *Sema::getCurFunctionDecl(bool AllowLambda) { + DeclContext *DC = getFunctionLevelDeclContext(AllowLambda); return dyn_cast<FunctionDecl>(DC); } Index: clang/lib/Sema/SemaCUDA.cpp =================================================================== --- clang/lib/Sema/SemaCUDA.cpp +++ clang/lib/Sema/SemaCUDA.cpp @@ -730,8 +730,9 @@ Sema::SemaDiagnosticBuilder Sema::CUDADiagIfDeviceCode(SourceLocation Loc, unsigned DiagID) { assert(getLangOpts().CUDA && "Should only be called during CUDA compilation"); + FunctionDecl *CurFunContext = getCurFunctionDecl(/*AllowLambda=*/true); SemaDiagnosticBuilder::Kind DiagKind = [&] { - if (!isa<FunctionDecl>(CurContext)) + if (!CurFunContext) return SemaDiagnosticBuilder::K_Nop; switch (CurrentCUDATarget()) { case CFT_Global: @@ -745,7 +746,7 @@ return SemaDiagnosticBuilder::K_Nop; if (IsLastErrorImmediate && Diags.getDiagnosticIDs()->isBuiltinNote(DiagID)) return SemaDiagnosticBuilder::K_Immediate; - return (getEmissionStatus(cast<FunctionDecl>(CurContext)) == + return (getEmissionStatus(CurFunContext) == FunctionEmissionStatus::Emitted) ?
SemaDiagnosticBuilder::K_ImmediateWithCallStack : SemaDiagnosticBuilder::K_Deferred; @@ -753,15 +754,15 @@ return SemaDiagnosticBuilder::K_Nop; } }(); - return SemaDiagnosticBuilder(DiagKind, Loc, DiagID, - dyn_cast<FunctionDecl>(CurContext), *this); + return SemaDiagnosticBuilder(DiagKind, Loc, DiagID, CurFunContext, *this); } Sema::SemaDiagnosticBuilder Sema::CUDADiagIfHostCode(SourceLocation Loc, unsigned DiagID) { assert(getLangOpts().CUDA && "Should only be called during CUDA compilation"); + FunctionDecl *CurFunContext = getCurFunctionDecl(/*AllowLambda=*/true); SemaDiagnosticBuilder::Kind DiagKind = [&] { - if (!isa<FunctionDecl>(CurContext)) + if (!CurFunContext) return SemaDiagnosticBuilder::K_Nop; switch (CurrentCUDATarget()) { case CFT_Host: @@ -774,7 +775,7 @@ return SemaDiagnosticBuilder::K_Nop; if (IsLastErrorImmediate && Diags.getDiagnosticIDs()->isBuiltinNote(DiagID)) return SemaDiagnosticBuilder::K_Immediate; - return (getEmissionStatus(cast<FunctionDecl>(CurContext)) == + return (getEmissionStatus(CurFunContext) == FunctionEmissionStatus::Emitted) ? SemaDiagnosticBuilder::K_ImmediateWithCallStack : SemaDiagnosticBuilder::K_Deferred; @@ -782,8 +783,7 @@ return SemaDiagnosticBuilder::K_Nop; } }(); - return SemaDiagnosticBuilder(DiagKind, Loc, DiagID, - dyn_cast<FunctionDecl>(CurContext), *this); + return SemaDiagnosticBuilder(DiagKind, Loc, DiagID, CurFunContext, *this); } bool Sema::CheckCUDACall(SourceLocation Loc, FunctionDecl *Callee) { @@ -796,7 +796,7 @@ // FIXME: Is bailing out early correct here? Should we instead assume that // the caller is a global initializer? - FunctionDecl *Caller = dyn_cast<FunctionDecl>(CurContext); + FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true); if (!Caller) return true; @@ -862,7 +862,7 @@ // File-scope lambda can only do init captures for global variables, which // results in passing by value for these global variables.
- FunctionDecl *Caller = dyn_cast<FunctionDecl>(CurContext); + FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true); if (!Caller) return; Index: clang/lib/Sema/SemaOverload.cpp =================================================================== --- clang/lib/Sema/SemaOverload.cpp +++ clang/lib/Sema/SemaOverload.cpp @@ -6473,7 +6473,7 @@ // (CUDA B.1): Check for invalid calls between targets. if (getLangOpts().CUDA) - if (const FunctionDecl *Caller = dyn_cast<FunctionDecl>(CurContext)) + if (const FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true)) // Skip the check for callers that are implicit members, because in this // case we may not yet know what the member's target is; the target is // inferred for the member automatically, based on the bases and fields of @@ -6983,7 +6983,7 @@ // (CUDA B.1): Check for invalid calls between targets. if (getLangOpts().CUDA) - if (const FunctionDecl *Caller = dyn_cast<FunctionDecl>(CurContext)) + if (const FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true)) if (!IsAllowedCUDACall(Caller, Method)) { Candidate.Viable = false; Candidate.FailureKind = ovl_fail_bad_target; @@ -9639,7 +9639,7 @@ // overloading resolution diagnostics. if (S.getLangOpts().CUDA && Cand1.Function && Cand2.Function && S.getLangOpts().GPUExcludeWrongSideOverloads) { - if (FunctionDecl *Caller = dyn_cast<FunctionDecl>(S.CurContext)) { + if (FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true)) { bool IsCallerImplicitHD = Sema::isCUDAImplicitHostDeviceFunction(Caller); bool IsCand1ImplicitHD = Sema::isCUDAImplicitHostDeviceFunction(Cand1.Function); @@ -9922,7 +9922,7 @@ // If other rules cannot determine which is better, CUDA preference is used // to determine which is better.
if (S.getLangOpts().CUDA && Cand1.Function && Cand2.Function) { - FunctionDecl *Caller = dyn_cast<FunctionDecl>(S.CurContext); + FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true); return S.IdentifyCUDAPreference(Caller, Cand1.Function) > S.IdentifyCUDAPreference(Caller, Cand2.Function); } @@ -10043,7 +10043,7 @@ // -fgpu-exclude-wrong-side-overloads is on, all candidates are compared // uniformly in isBetterOverloadCandidate. if (S.getLangOpts().CUDA && !S.getLangOpts().GPUExcludeWrongSideOverloads) { - const FunctionDecl *Caller = dyn_cast<FunctionDecl>(S.CurContext); + const FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true); bool ContainsSameSideCandidate = llvm::any_of(Candidates, [&](OverloadCandidate *Cand) { // Check viable function only. @@ -11068,7 +11068,7 @@ /// CUDA: diagnose an invalid call across targets. static void DiagnoseBadTarget(Sema &S, OverloadCandidate *Cand) { - FunctionDecl *Caller = cast<FunctionDecl>(S.CurContext); + FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true); FunctionDecl *Callee = Cand->Function; Sema::CUDAFunctionTarget CallerTarget = S.IdentifyCUDATarget(Caller), @@ -12127,7 +12127,7 @@ if (FunctionDecl *FunDecl = dyn_cast<FunctionDecl>(Fn)) { if (S.getLangOpts().CUDA) - if (FunctionDecl *Caller = dyn_cast<FunctionDecl>(S.CurContext)) + if (FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true)) if (!Caller->isImplicit() && !S.IsAllowedCUDACall(Caller, FunDecl)) return false; if (FunDecl->isMultiVersion()) { @@ -12244,7 +12244,8 @@ } void EliminateSuboptimalCudaMatches() { - S.EraseUnwantedCUDAMatches(dyn_cast<FunctionDecl>(S.CurContext), Matches); + S.EraseUnwantedCUDAMatches(S.getCurFunctionDecl(/*AllowLambda=*/true), + Matches); } public: Index: clang/test/CodeGenCUDA/openmp-parallel.cu =================================================================== --- /dev/null +++ clang/test/CodeGenCUDA/openmp-parallel.cu @@ -0,0 +1,28 @@ +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu \ +// RUN: -fopenmp -emit-llvm -o - -x hip %s | FileCheck %s +
+#include "Inputs/cuda.h" + +void foo(double) {} +__device__ void foo(int) {} + +// Check foo resolves to the host function. +// CHECK-LABEL: define {{.*}}@_Z5test1v +// CHECK: call void @_Z3food(double noundef 1.000000e+00) +void test1() { + #pragma omp parallel + for (int i = 0; i < 100; i++) + foo(1); +} + +// Check foo resolves to the host function. +// CHECK-LABEL: define {{.*}}@_Z5test2v +// CHECK: call void @_Z3food(double noundef 1.000000e+00) +void test2() { + auto Lambda = []() { + #pragma omp parallel + for (int i = 0; i < 100; i++) + foo(1); + }; + Lambda(); +} Index: clang/test/SemaCUDA/openmp-parallel.cu =================================================================== --- /dev/null +++ clang/test/SemaCUDA/openmp-parallel.cu @@ -0,0 +1,19 @@ +// RUN: %clang_cc1 -fopenmp -fsyntax-only -verify %s + +#include "Inputs/cuda.h" + +__device__ void foo(int) {} // expected-note {{candidate function not viable: call to __device__ function from __host__ function}} +// expected-note@-1 {{'foo' declared here}} + +int main() { + #pragma omp parallel + for (int i = 0; i < 100; i++) + foo(1); // expected-error {{no matching function for call to 'foo'}} + + auto Lambda = []() { + #pragma omp parallel + for (int i = 0; i < 100; i++) + foo(1); // expected-error {{reference to __device__ function 'foo' in __host__ __device__ function}} + }; + Lambda(); // expected-note {{called by 'main'}} +}