diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -1422,6 +1422,11 @@ Align AlignCandidate = getFunctionParamOptimizedAlign(F, ETy, DL); ParamByValAlign = std::max(ParamByValAlign, AlignCandidate); + // Enforce minimum alignment of 4 to work around ptxas miscompile + // for sm_50+. See corresponding alignment adjustment in + // emitFunctionParamList() for details. + ParamByValAlign = std::max(ParamByValAlign, Align(4)); + O << ".param .align " << ParamByValAlign.value() << " .b8 "; O << "_"; O << "[" << Outs[OIdx].Flags.getByValSize() << "]"; diff --git a/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll b/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll --- a/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll +++ b/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll @@ -37,9 +37,9 @@ %fp = call ptr @usefp(ptr @callee) ; CHECK: .param .align 4 .b8 param0[4]; ; CHECK: st.param.v2.b16 [param0+0] - ; CHECK: .callprototype ()_ (.param .align 2 .b8 _[4]); + ; CHECK: .callprototype ()_ (.param .align 4 .b8 _[4]); call void %fp(ptr byval(%"class.complex") null) ret void } -declare %complex_half @_Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE() +declare %complex_half @_Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE(i32, i32, ptr byval(%"class.complex")) diff --git a/llvm/test/CodeGen/NVPTX/param-align.ll b/llvm/test/CodeGen/NVPTX/param-align.ll --- a/llvm/test/CodeGen/NVPTX/param-align.ll +++ b/llvm/test/CodeGen/NVPTX/param-align.ll @@ -43,3 +43,24 @@ call void @t4(ptr byval(i8) %x) ret void } + +;;; Make sure we adjust alignment for a function prototype +;;; in case of an indirect call. 
+
+declare ptr @getfp(i32 %n) ; body not visible here; each call below goes through the pointer it returns
+%struct.half2 = type { half, half } ; two-half aggregate used as the byval pointee in the second call
+define ptx_device void @t6() { ; three indirect byval calls, one .callprototype expected per call
+; CHECK: .func t6
+  %fp = call ptr @getfp(i32 0)
+; CHECK: prototype_2 : .callprototype ()_ (.param .align 8 .b8 _[8]);
+  call void %fp(ptr byval(double) null) ; expected to keep .align 8 (already above the minimum of 4)
+
+  %fp2 = call ptr @getfp(i32 1)
+; CHECK: prototype_4 : .callprototype ()_ (.param .align 4 .b8 _[4]);
+  call void %fp2(ptr byval(%struct.half2) null) ; call through the pointer fetched just above, not %fp
+
+  %fp3 = call ptr @getfp(i32 2)
+; CHECK: prototype_6 : .callprototype ()_ (.param .align 4 .b8 _[1]);
+  call void %fp3(ptr byval(i8) null) ; 1-byte param is still expected to be emitted with .align 4
+  ret void
+}