diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -89,6 +89,12 @@ cl::desc("NVPTX Specific: 0 use sqrt.approx, 1 use sqrt.rn."), cl::init(true)); +static cl::opt<bool> ForceMinByValParamAlign( + "nvptx-force-min-byval-param-align", cl::Hidden, + cl::desc("NVPTX Specific: force 4-byte minimal alignment for byval" + " params of device functions."), + cl::init(false)); + int NVPTXTargetLowering::getDivF32Level() const { if (UsePrecDivF32.getNumOccurrences() > 0) { // If nvptx-prec-div32=N is used on the command-line, always honor it @@ -4502,16 +4508,17 @@ if (F) ArgAlign = std::max(ArgAlign, getFunctionParamOptimizedAlign(F, ArgTy, DL)); - // Work around a bug in ptxas. When PTX code takes address of + // Old ptxas versions have a bug. When PTX code takes address of // byval parameter with alignment < 4, ptxas generates code to // spill argument into memory. Alas on sm_50+ ptxas generates // SASS code that fails with misaligned access. To work around // the problem, make sure that we align byval parameters by at - // least 4. - // TODO: this will need to be undone when we get to support multi-TU - // device-side compilation as it breaks ABI compatibility with nvcc. - // Hopefully ptxas bug is fixed by then. - ArgAlign = std::max(ArgAlign, Align(4)); + // least 4 when -nvptx-force-min-byval-param-align is set. The bug + // appears to be fixed in ptxas versions newer than 9.0. + // TODO: remove this after verifying the bug is not reproduced + // on non-deprecated ptxas versions. 
+ if (ForceMinByValParamAlign) // Opt-in; the option defaults to false.
+ ArgAlign = std::max(ArgAlign, Align(4)); return ArgAlign; } diff --git a/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll b/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll --- a/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll +++ b/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll @@ -13,8 +13,9 @@ %"class.sycl::_V1::detail::half_impl::half" = type { half } %complex_half = type { half, half } -; CHECK: .param .align 4 .b8 param2[4]; -; CHECK: st.param.v2.b16 [param2+0], {%h2, %h1}; +; CHECK: .param .align 2 .b8 param2[4]; +; CHECK: st.param.b16 [param2+0], %h1; +; CHECK: st.param.b16 [param2+2], %h2; ; CHECK: .param .align 2 .b8 retval0[4]; ; CHECK: call.uni (retval0), ; CHECK-NEXT: _Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE, @@ -29,15 +30,16 @@ ;; declare ptr @usefp(ptr %fp) ; CHECK: .func callee( -; CHECK-NEXT: .param .align 4 .b8 callee_param_0[4] +; CHECK-NEXT: .param .align 2 .b8 callee_param_0[4] define internal void @callee(ptr byval(%"class.complex") %byval_arg) { ret void } define void @boom() { %fp = call ptr @usefp(ptr @callee) - ; CHECK: .param .align 4 .b8 param0[4]; - ; CHECK: st.param.v2.b16 [param0+0] - ; CHECK: .callprototype ()_ (.param .align 4 .b8 _[4]); + ; CHECK: .param .align 2 .b8 param0[4]; + ; CHECK: st.param.b16 [param0+0], %h1; + ; CHECK: st.param.b16 [param0+2], %h2; + ; CHECK: .callprototype ()_ (.param .align 2 .b8 _[4]); call void %fp(ptr byval(%"class.complex") null) ret void } diff --git a/llvm/test/CodeGen/NVPTX/param-align.ll b/llvm/test/CodeGen/NVPTX/param-align.ll --- a/llvm/test/CodeGen/NVPTX/param-align.ll +++ b/llvm/test/CodeGen/NVPTX/param-align.ll @@ -1,5 +1,7 @@ -; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s +; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s --check-prefixes=CHECK,NOALIGN4 +; RUN: llc < %s -march=nvptx -mcpu=sm_20 -nvptx-force-min-byval-param-align | FileCheck %s --check-prefixes=CHECK,ALIGN4 ; RUN: %if ptxas %{ llc < %s -march=nvptx -mcpu=sm_20 | %ptxas-verify 
%} +; RUN: %if ptxas %{ llc < %s -march=nvptx -mcpu=sm_20 -nvptx-force-min-byval-param-align | %ptxas-verify %} ;;; Need 4-byte alignment on ptr passed byval define ptx_device void @t1(ptr byval(float) %x) { @@ -25,20 +27,21 @@ ret void } -;;; Need at least 4-byte alignment in order to avoid miscompilation by -;;; ptxas for sm_50+ define ptx_device void @t4(ptr byval(i8) %x) { ; CHECK: .func t4 -; CHECK: .param .align 4 .b8 t4_param_0[1] +; NOALIGN4: .param .align 1 .b8 t4_param_0[1] +; ALIGN4: .param .align 4 .b8 t4_param_0[1] ret void } -;;; Make sure we adjust alignment at the call site as well. +;;; Check callee and call-site param alignment; ALIGN4 raises both to 4. define ptx_device void @t5(ptr align 2 byval(i8) %x) { ; CHECK: .func t5 -; CHECK: .param .align 4 .b8 t5_param_0[1] +; NOALIGN4: .param .align 2 .b8 t5_param_0[1] +; ALIGN4: .param .align 4 .b8 t5_param_0[1] ; CHECK: { -; CHECK: .param .align 4 .b8 param0[1]; +; NOALIGN4: .param .align 1 .b8 param0[1]; +; ALIGN4: .param .align 4 .b8 param0[1]; ; CHECK: call.uni call void @t4(ptr byval(i8) %x) ret void @@ -56,11 +59,13 @@ call void %fp(ptr byval(double) null); %fp2 = call ptr @getfp(i32 1) -; CHECK: prototype_4 : .callprototype ()_ (.param .align 4 .b8 _[4]); +; NOALIGN4: prototype_4 : .callprototype ()_ (.param .align 2 .b8 _[4]); +; ALIGN4: prototype_4 : .callprototype ()_ (.param .align 4 .b8 _[4]); call void %fp(ptr byval(%struct.half2) null); %fp3 = call ptr @getfp(i32 2) -; CHECK: prototype_6 : .callprototype ()_ (.param .align 4 .b8 _[1]); +; NOALIGN4: prototype_6 : .callprototype ()_ (.param .align 1 .b8 _[1]); +; ALIGN4: prototype_6 : .callprototype ()_ (.param .align 4 .b8 _[1]); call void %fp(ptr byval(i8) null); ret void }