diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -4326,8 +4326,13 @@
   const uint64_t ABITypeAlign = DL.getABITypeAlign(ArgTy).value();
 
   // If a function has linkage different from internal or private, we
-  // must use default ABI alignment as external users rely on it.
-  if (!(F && F->hasLocalLinkage()))
+  // must use default ABI alignment as external users rely on it. Same
+  // for a function that may be called through a function pointer.
+  if (!F || !F->hasLocalLinkage() ||
+      F->hasAddressTaken(/*Users=*/nullptr,
+                         /*IgnoreCallbackUses=*/false,
+                         /*IgnoreAssumeLikeCalls=*/true,
+                         /*IngoreLLVMUsed=*/true))
     return Align(ABITypeAlign);
 
   assert(!isKernelFunction(*F) && "Expect kernels to have non-local linkage");
diff --git a/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll b/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
--- a/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
+++ b/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
@@ -13,16 +13,34 @@
 %"class.sycl::_V1::detail::half_impl::half" = type { half }
 %complex_half = type { half, half }
 
+; CHECK: .param .align 4 .b8 param2[4];
+; CHECK: st.param.v2.b16 [param2+0], {%h2, %h1};
+; CHECK: .param .align 2 .b8 retval0[4];
+; CHECK: call.uni (retval0),
+; CHECK-NEXT: _Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE,
 define weak_odr void @foo() {
 entry:
   %call.i.i.i = tail call %"class.complex" bitcast (%complex_half ()* @_Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE to %"class.complex" (i32, i32, %"class.complex"*)*)(i32 0, i32 0, %"class.complex"* byval(%"class.complex") null)
   ret void
 }
 
-declare %complex_half @_Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE()
+;; Function pointers can escape, so we have to use a conservative
+;; alignment for a function that has address taken.
+;;
+declare i8* @usefp(i8* %fp)
+; CHECK: .func callee(
+; CHECK-NEXT: .param .align 4 .b8 callee_param_0[4]
+define internal void @callee(%"class.complex"* byval(%"class.complex") %byval_arg) {
+  ret void
+}
+define void @boom() {
+  %fp = call i8* @usefp(i8* bitcast (void (%"class.complex"*)* @callee to i8*))
+  %cast = bitcast i8* %fp to void (%"class.complex"*)*
+  ; CHECK: .param .align 4 .b8 param0[4];
+  ; CHECK: st.param.v2.b16 [param0+0]
+  ; CHECK: .callprototype ()_ (.param .align 2 .b8 _[4]);
+  call void %cast(%"class.complex"* byval(%"class.complex") null)
+  ret void
+}
 
-; CHECK: .param .align 4 .b8 param2[4];
-; CHECK: st.param.v2.b16 [param2+0], {%h2, %h1};
-; CHECK: .param .align 2 .b8 retval0[4];
-; CHECK: call.uni (retval0),
-; CHECK-NEXT: _Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE,
+declare %complex_half @_Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE()