diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -1612,21 +1612,12 @@
       // <a>  = optimal alignment for the element type; always multiple of
       //        PAL.getParamAlignment
       // size = typeallocsize of element type
-      Align OptimalAlign = getOptimalAlignForParam(ETy);
-
-      // Work around a bug in ptxas. When PTX code takes address of
-      // byval parameter with alignment < 4, ptxas generates code to
-      // spill argument into memory. Alas on sm_50+ ptxas generates
-      // SASS code that fails with misaligned access. To work around
-      // the problem, make sure that we align byval parameters by at
-      // least 4. Matching change must be made in LowerCall() where we
-      // prepare parameters for the call.
-      //
-      // TODO: this will need to be undone when we get to support multi-TU
-      // device-side compilation as it breaks ABI compatibility with nvcc.
-      // Hopefully ptxas bug is fixed by then.
-      if (!isKernelFunc && OptimalAlign < Align(4))
-        OptimalAlign = Align(4);
+      Align OptimalAlign =
+          isKernelFunc
+              ? getOptimalAlignForParam(ETy)
+              : TLI->getFunctionByValParamAlign(
+                    F, ETy, PAL.getParamAlignment(paramIndex).valueOrOne(), DL);
+
       unsigned sz = DL.getTypeAllocSize(ETy);
       O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
       printParamName(I, paramIndex, O);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -461,6 +461,11 @@
   Align getFunctionParamOptimizedAlign(const Function *F, Type *ArgTy,
                                        const DataLayout &DL) const;
 
+  /// Byval param alignment: at least \p InitialAlign and 4; \p F may be null.
+  Align getFunctionByValParamAlign(const Function *F, Type *ArgTy,
+                                   Align InitialAlign,
+                                   const DataLayout &DL) const;
+
   /// isLegalAddressingMode - Return true if the addressing mode represented
   /// by AM is legal for this target, for a load/store of the specified type
   /// Used to guide target specific optimizations, like loop strength
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1414,13 +1414,10 @@
       continue;
     }
 
-    Align ParamByValAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
-
-    // Try to increase alignment. This code matches logic in LowerCall when
-    // alignment increase is performed to increase vectorization options.
     Type *ETy = Args[i].IndirectType;
-    Align AlignCandidate = getFunctionParamOptimizedAlign(F, ETy, DL);
-    ParamByValAlign = std::max(ParamByValAlign, AlignCandidate);
+    // Indirect callee is unknown: no callee-based alignment increase.
+    Align ParamByValAlign = getFunctionByValParamAlign(
+        /*F=*/nullptr, ETy, Outs[OIdx].Flags.getNonZeroByValAlign(), DL);
 
     O << ".param .align " << ParamByValAlign.value() << " .b8 ";
     O << "_";
@@ -1560,17 +1557,9 @@
       // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
       // so we don't need to worry whether it's naturally aligned or not.
       // See TargetLowering::LowerCallTo().
-      ArgAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
-
-      // Try to increase alignment to enhance vectorization options.
-      if (const Function *DirectCallee = CB->getCalledFunction())
-        ArgAlign = std::max(
-            ArgAlign, getFunctionParamOptimizedAlign(DirectCallee, ETy, DL));
-
-      // Enforce minumum alignment of 4 to work around ptxas miscompile
-      // for sm_50+. See corresponding alignment adjustment in
-      // emitFunctionParamList() for details.
-      ArgAlign = std::max(ArgAlign, Align(4));
+      Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
+      ArgAlign = getFunctionByValParamAlign(CB->getCalledFunction(), ETy,
+                                            InitialAlign, DL);
       if (IsVAArg)
         VAOffset = alignTo(VAOffset, ArgAlign);
     } else {
@@ -4510,6 +4499,29 @@
   return Align(std::max(uint64_t(16), ABITypeAlign));
 }
 
+/// Byval param alignment for device function F (may be null if unknown).
+Align NVPTXTargetLowering::getFunctionByValParamAlign(
+    const Function *F, Type *ArgTy, Align InitialAlign,
+    const DataLayout &DL) const {
+  Align ArgAlign = InitialAlign;
+  // If F is known, try to increase alignment to enable vectorization.
+  if (F)
+    ArgAlign = std::max(ArgAlign, getFunctionParamOptimizedAlign(F, ArgTy, DL));
+
+  // Work around a bug in ptxas. When PTX code takes address of
+  // byval parameter with alignment < 4, ptxas generates code to
+  // spill argument into memory. Alas on sm_50+ ptxas generates
+  // SASS code that fails with misaligned access. To work around
+  // the problem, make sure that we align byval parameters by at
+  // least 4. Both callee decls and call sites use this helper.
+  // TODO: this will need to be undone when we get to support multi-TU
+  // device-side compilation as it breaks ABI compatibility with nvcc.
+  // Hopefully ptxas bug is fixed by then.
+  ArgAlign = std::max(ArgAlign, Align(4));
+
+  return ArgAlign;
+}
+
 /// isLegalAddressingMode - Return true if the addressing mode represented
 /// by AM is legal for this target, for a load/store of the specified type.
/// Used to guide target specific optimizations, like loop strength reduction
diff --git a/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll b/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
--- a/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
+++ b/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
@@ -37,9 +37,9 @@
   %fp = call ptr @usefp(ptr @callee)
   ; CHECK: .param .align 4 .b8 param0[4];
   ; CHECK: st.param.v2.b16 [param0+0]
-  ; CHECK: .callprototype ()_ (.param .align 2 .b8 _[4]);
+  ; CHECK: .callprototype ()_ (.param .align 4 .b8 _[4]);
   call void %fp(ptr byval(%"class.complex") null)
   ret void
 }
 
-declare %complex_half @_Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE()
+declare %complex_half @_Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE(i32, i32, ptr byval(%"class.complex"))
diff --git a/llvm/test/CodeGen/NVPTX/param-align.ll b/llvm/test/CodeGen/NVPTX/param-align.ll
--- a/llvm/test/CodeGen/NVPTX/param-align.ll
+++ b/llvm/test/CodeGen/NVPTX/param-align.ll
@@ -43,3 +43,24 @@
   call void @t4(ptr byval(i8) %x)
   ret void
 }
+
+;;; Make sure we adjust alignment for a function prototype
+;;; in case of an indirect call.
+
+declare ptr @getfp(i32 %n)
+%struct.half2 = type { half, half }
+define ptx_device void @t6() {
+; CHECK: .func t6
+  %fp = call ptr @getfp(i32 0)
+; CHECK: prototype_2 : .callprototype ()_ (.param .align 8 .b8 _[8]);
+  call void %fp(ptr byval(double) null)
+
+  %fp2 = call ptr @getfp(i32 1)
+; CHECK: prototype_4 : .callprototype ()_ (.param .align 4 .b8 _[4]);
+  call void %fp2(ptr byval(%struct.half2) null)
+
+  %fp3 = call ptr @getfp(i32 2)
+; CHECK: prototype_6 : .callprototype ()_ (.param .align 4 .b8 _[1]);
+  call void %fp3(ptr byval(i8) null)
+  ret void
+}