diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -1465,7 +1465,6 @@
   bool isKernelFunc = isKernelFunction(*F);
   bool isABI = (STI.getSmVersion() >= 20);
   bool hasImageHandles = STI.hasImageHandles();
-  MVT thePointerTy = TLI->getPointerTy(DL);
 
   if (F->arg_empty()) {
     O << "()\n";
@@ -1538,10 +1537,17 @@
       }
 
       // Just a scalar
       auto *PTy = dyn_cast<PointerType>(Ty);
+      unsigned PTySizeInBits = 0;
+      if (PTy) {
+        PTySizeInBits =
+            TLI->getPointerTy(DL, PTy->getAddressSpace()).getSizeInBits();
+        assert(PTySizeInBits && "Invalid pointer size");
+      }
+
       if (isKernelFunc) {
         if (PTy) {
           // Special handling for pointer arguments to kernel
-          O << "\t.param .u" << thePointerTy.getSizeInBits() << " ";
+          O << "\t.param .u" << PTySizeInBits << " ";
 
           if (static_cast<const NVPTXTargetMachine &>(TM).getDrvInterface() != NVPTX::CUDA) {
@@ -1584,9 +1590,10 @@
       if (isa<IntegerType>(Ty)) {
         sz = cast<IntegerType>(Ty)->getBitWidth();
         sz = promoteScalarArgumentSize(sz);
-      } else if (isa<PointerType>(Ty))
-        sz = thePointerTy.getSizeInBits();
-      else if (Ty->isHalfTy())
+      } else if (PTy) {
+        assert(PTySizeInBits && "Invalid pointer size");
+        sz = PTySizeInBits;
+      } else if (Ty->isHalfTy())
         // PTX ABI requires all scalar parameters to be at least 32
         // bits in size. fp16 normally uses .b16 as its storage type
         // in PTX, so its size must be adjusted here, too.
diff --git a/llvm/test/CodeGen/NVPTX/short-ptr.ll b/llvm/test/CodeGen/NVPTX/short-ptr.ll
--- a/llvm/test/CodeGen/NVPTX/short-ptr.ll
+++ b/llvm/test/CodeGen/NVPTX/short-ptr.ll
@@ -1,6 +1,6 @@
 ; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 | FileCheck %s --check-prefix CHECK-DEFAULT
 ; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s --check-prefix CHECK-DEFAULT-32
-; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 -nvptx-short-ptr | FileCheck %s --check-prefixes CHECK-SHORT-SHARED,CHECK-SHORT-CONST
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 -nvptx-short-ptr | FileCheck %s --check-prefixes CHECK-SHORT-SHARED,CHECK-SHORT-CONST,CHECK-SHORT-LOCAL
 
 ; RUN: %if ptxas %{ llc < %s -march=nvptx -mcpu=sm_20 | %ptxas-verify %}
 ; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_20 | %ptxas-verify %}
@@ -15,3 +15,30 @@
 ; CHECK-DEFAULT-32: .visible .const .align 8 .u32 c
 ; CHECK-SHORT-CONST: .visible .const .align 8 .u32 c
 @c = local_unnamed_addr addrspace(4) global i32 addrspace(4)* null, align 8
+
+declare void @use(i8 %arg);
+
+; CHECK-DEFAULT: .param .b64 test1_param_0
+; CHECK-DEFAULT-32: .param .b32 test1_param_0
+; CHECK-SHORT-LOCAL: .param .b32 test1_param_0
+define void @test1(i8 addrspace(5)* %local) {
+  ; CHECK-DEFAULT: ld.param.u64 %rd{{.*}}, [test1_param_0];
+  ; CHECK-DEFAULT-32: ld.param.u32 %r{{.*}}, [test1_param_0];
+  ; CHECK-SHORT-LOCAL: ld.param.u32 %r{{.*}}, [test1_param_0];
+  %v = load i8, i8 addrspace(5)* %local
+  call void @use(i8 %v)
+  ret void
+}
+
+define void @test2() {
+  %v = alloca i8
+  %cast = addrspacecast i8* %v to i8 addrspace(5)*
+  ; CHECK-DEFAULT: .param .b64 param0;
+  ; CHECK-DEFAULT: st.param.b64
+  ; CHECK-DEFAULT-32: .param .b32 param0;
+  ; CHECK-DEFAULT-32: st.param.b32
+  ; CHECK-SHORT-LOCAL: .param .b32 param0;
+  ; CHECK-SHORT-LOCAL: st.param.b32
+  call void @test1(i8 addrspace(5)* %cast)
+  ret void
+}