Index: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp =================================================================== --- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -1429,22 +1429,8 @@ // Check if we have call alignment metadata if (getAlign(*CI, Idx, Alignment)) return Align(Alignment); - - const Value *CalleeV = CI->getCalledOperand(); - // Ignore any bitcast instructions - while (isa<ConstantExpr>(CalleeV)) { - const ConstantExpr *CE = cast<ConstantExpr>(CalleeV); - if (!CE->isCast()) - break; - // Look through the bitcast - CalleeV = cast<ConstantExpr>(CalleeV)->getOperand(0); - } - - // We have now looked past all of the bitcasts. Do we finally have a - // Function? - if (const auto *CalleeF = dyn_cast<Function>(CalleeV)) - DirectCallee = CalleeF; } + DirectCallee = getMaybeBitcastedCallee(CB); } // Check for function alignment information if we found that the @@ -1521,7 +1507,7 @@ // Try to increase alignment to enhance vectorization options. ArgAlign = std::max(ArgAlign, getFunctionParamOptimizedAlign( - CB->getCalledFunction(), ETy, DL)); + getMaybeBitcastedCallee(CB), ETy, DL)); // Enforce minumum alignment of 4 to work around ptxas miscompile // for sm_50+. See corresponding alignment adjustment in @@ -4341,7 +4327,7 @@ // If a function has linkage different from internal or private, we // must use default ABI alignment as external users rely on it. 
- if (!F->hasLocalLinkage()) + if (!(F && F->hasLocalLinkage())) return Align(ABITypeAlign); assert(!isKernelFunction(*F) && "Expect kernels to have non-local linkage"); Index: llvm/lib/Target/NVPTX/NVPTXUtilities.h =================================================================== --- llvm/lib/Target/NVPTX/NVPTXUtilities.h +++ llvm/lib/Target/NVPTX/NVPTXUtilities.h @@ -58,6 +58,7 @@ bool getAlign(const Function &, unsigned index, unsigned &); bool getAlign(const CallInst &, unsigned index, unsigned &); +Function *getMaybeBitcastedCallee(const CallBase *CB); // PTX ABI requires all scalar argument/return values to have // bit-size as a power of two of at least 32 bits. Index: llvm/lib/Target/NVPTX/NVPTXUtilities.cpp =================================================================== --- llvm/lib/Target/NVPTX/NVPTXUtilities.cpp +++ llvm/lib/Target/NVPTX/NVPTXUtilities.cpp @@ -324,4 +324,8 @@ return false; } +Function *getMaybeBitcastedCallee(const CallBase *CB) { + return dyn_cast<Function>(CB->getCalledOperand()->stripPointerCasts()); + } + } // namespace llvm Index: llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll =================================================================== --- /dev/null +++ llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll @@ -0,0 +1,28 @@ +; RUN: llc < %s -march=nvptx -mcpu=sm_50 -verify-machineinstrs | FileCheck %s +; RUN: %if ptxas %{ llc < %s -march=nvptx -mcpu=sm_50 -verify-machineinstrs | %ptxas-verify %} + +; calls with a bitcasted function symbol should be fine, but in combination with +; a byval attribute were causing a segfault during isel. 
This testcase was +; reduced from a SYCL kernel using aggregate types which ended up being passed +; `byval` + +target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64" +target triple = "nvptx64-nvidia-cuda" + +%"class.complex" = type { %"class.sycl::_V1::detail::half_impl::half", %"class.sycl::_V1::detail::half_impl::half" } +%"class.sycl::_V1::detail::half_impl::half" = type { half } +%complex_half = type { half, half } + +define weak_odr void @foo() { +entry: + %call.i.i.i = tail call %"class.complex" bitcast (%complex_half ()* @_Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE to %"class.complex" (i32, i32, %"class.complex"*)*)(i32 0, i32 0, %"class.complex"* byval(%"class.complex") null) + ret void +} + +declare %complex_half @_Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE() + +; CHECK: .param .align 4 .b8 param2[4]; +; CHECK: st.param.v2.b16 [param2+0], {%h2, %h1}; +; CHECK: .param .align 2 .b8 retval0[4]; +; CHECK: call.uni (retval0), +; CHECK-NEXT: _Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE,