diff --git a/clang/lib/CodeGen/CGCUDANV.cpp b/clang/lib/CodeGen/CGCUDANV.cpp
--- a/clang/lib/CodeGen/CGCUDANV.cpp
+++ b/clang/lib/CodeGen/CGCUDANV.cpp
@@ -49,10 +49,13 @@
     const Decl *D;
   };
   llvm::SmallVector<KernelInfo, 16> EmittedKernels;
-  // Map a device stub function to a symbol for identifying kernel in host code.
+  // Map a kernel mangled name to a symbol for identifying kernel in host code.
+  // The map is keyed by name rather than by llvm::Function * because the stub
+  // function can be replaced by a new llvm::Function with the same name, e.g.
+  // when it is first created while a parameter type is still incomplete.
   // For CUDA, the symbol for identifying the kernel is the same as the device
   // stub function. For HIP, they are different.
-  llvm::DenseMap<llvm::Function *, llvm::GlobalValue *> KernelHandles;
+  llvm::DenseMap<StringRef, llvm::GlobalValue *> KernelHandles;
   // Map a kernel handle to the kernel stub.
   llvm::DenseMap<llvm::GlobalValue *, llvm::Function *> KernelStubs;
   struct VarInfo {
@@ -310,7 +313,8 @@
 void CGNVCUDARuntime::emitDeviceStub(CodeGenFunction &CGF,
                                      FunctionArgList &Args) {
   EmittedKernels.push_back({CGF.CurFn, CGF.CurFuncDecl});
-  if (auto *GV = dyn_cast<llvm::GlobalVariable>(KernelHandles[CGF.CurFn])) {
+  if (auto *GV =
+          dyn_cast<llvm::GlobalVariable>(KernelHandles[CGF.CurFn->getName()])) {
     GV->setLinkage(CGF.CurFn->getLinkage());
     GV->setInitializer(CGF.CurFn);
   }
@@ -400,8 +404,8 @@
                                ShmemSize.getPointer(), Stream.getPointer()});
 
   // Emit the call to cudaLaunch
-  llvm::Value *Kernel =
-      CGF.Builder.CreatePointerCast(KernelHandles[CGF.CurFn], VoidPtrTy);
+  llvm::Value *Kernel = CGF.Builder.CreatePointerCast(
+      KernelHandles[CGF.CurFn->getName()], VoidPtrTy);
   CallArgList LaunchKernelArgs;
   LaunchKernelArgs.add(RValue::get(Kernel),
                        cudaLaunchKernelFD->getParamDecl(0)->getType());
@@ -456,8 +460,8 @@
 
   // Emit the call to cudaLaunch
   llvm::FunctionCallee cudaLaunchFn = getLaunchFn();
-  llvm::Value *Arg =
-      CGF.Builder.CreatePointerCast(KernelHandles[CGF.CurFn], CharPtrTy);
+  llvm::Value *Arg = CGF.Builder.CreatePointerCast(
+      KernelHandles[CGF.CurFn->getName()], CharPtrTy);
   CGF.EmitRuntimeCallOrInvoke(cudaLaunchFn, Arg);
   CGF.EmitBranch(EndBlock);
 
@@ -551,7 +555,7 @@
     llvm::Constant *NullPtr = llvm::ConstantPointerNull::get(VoidPtrTy);
     llvm::Value *Args[] = {
         &GpuBinaryHandlePtr,
-        Builder.CreateBitCast(KernelHandles[I.Kernel], VoidPtrTy),
+        Builder.CreateBitCast(KernelHandles[I.Kernel->getName()], VoidPtrTy),
         KernelName,
         KernelName,
         llvm::ConstantInt::get(IntTy, -1),
@@ -1130,7 +1134,7 @@
   StringRef Section = CGM.getLangOpts().HIP ? "hip_offloading_entries"
                                             : "cuda_offloading_entries";
   for (KernelInfo &I : EmittedKernels)
-    OMPBuilder.emitOffloadingEntry(KernelHandles[I.Kernel],
+    OMPBuilder.emitOffloadingEntry(KernelHandles[I.Kernel->getName()],
                                    getDeviceSideName(cast<NamedDecl>(I.D)), 0,
                                    DeviceVarFlags::OffloadGlobalEntry, Section);
 
@@ -1193,12 +1197,27 @@
 
 llvm::GlobalValue *CGNVCUDARuntime::getKernelHandle(llvm::Function *F,
                                                     GlobalDecl GD) {
-  auto Loc = KernelHandles.find(F);
-  if (Loc != KernelHandles.end())
-    return Loc->second;
+  auto Loc = KernelHandles.find(F->getName());
+  if (Loc != KernelHandles.end()) {
+    llvm::GlobalValue *OldHandle = Loc->second;
+    if (KernelStubs.lookup(OldHandle) == F)
+      return OldHandle;
+
+    // Same mangled name but a different llvm::Function: the stub has been
+    // re-created, so the cached entries must not keep pointing at the old one.
+    if (CGM.getLangOpts().HIP) {
+      // For HIP the handle is a separate global that stays valid; only the
+      // handle -> stub mapping needs to follow the new function.
+      KernelStubs[OldHandle] = F;
+      return OldHandle;
+    }
+    // For CUDA the handle is the stub itself, so drop the stale entry and
+    // fall through to register F as its own handle.
+    KernelStubs.erase(OldHandle);
+  }
 
   if (!CGM.getLangOpts().HIP) {
-    KernelHandles[F] = F;
+    KernelHandles[F->getName()] = F;
     KernelStubs[F] = F;
     return F;
   }
@@ -1212,7 +1231,7 @@
   Var->setDSOLocal(F->isDSOLocal());
   Var->setVisibility(F->getVisibility());
   CGM.maybeSetTrivialComdat(*GD.getDecl(), *Var);
-  KernelHandles[F] = Var;
+  KernelHandles[F->getName()] = Var;
   KernelStubs[Var] = F;
   return Var;
 }
diff --git a/clang/test/CodeGenCUDA/incomplete-func-ptr-type.cu b/clang/test/CodeGenCUDA/incomplete-func-ptr-type.cu
new file mode 100644
--- /dev/null
+++ b/clang/test/CodeGenCUDA/incomplete-func-ptr-type.cu
@@ -0,0 +1,27 @@
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-llvm -x hip %s -o - \
+// RUN: | FileCheck %s
+
+#define __global__ __attribute__((global))
+// CHECK: @_Z4kern7TempValIjE = constant ptr @_Z19__device_stub__kern7TempValIjE, align 8
+// CHECK: @0 = private unnamed_addr constant [19 x i8] c"_Z4kern7TempValIjE\00", align 1
+template <typename type>
+struct TempVal {
+  type value;
+};
+
+__global__ void kern(TempVal<unsigned int> in_val);
+
+int main(int argc, char ** argv) {
+  auto* fptr = &(kern);
+// CHECK: store ptr @_Z4kern7TempValIjE, ptr %fptr, align 8
+  return 0;
+}
+// CHECK: define dso_local void @_Z19__device_stub__kern7TempValIjE(i32 %in_val.coerce) #1 {
+// CHECK: %2 = call i32 @hipLaunchByPtr(ptr @_Z4kern7TempValIjE)
+
+// CHECK: define internal void @__hip_register_globals(ptr %0) {
+// CHECK: %1 = call i32 @__hipRegisterFunction(ptr %0, ptr @_Z4kern7TempValIjE, ptr @0, ptr @0, i32 -1, ptr null, ptr null, ptr null, ptr null, ptr null)
+
+__global__ void kern(TempVal<unsigned int> in_val) {
+}
+