diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp --- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp @@ -356,8 +356,7 @@ // PTX ABI requires all scalar return values to be at least 32 // bits in size. fp16 normally uses .b16 as its storage type in // PTX, so its size must be adjusted here, too. - if (size < 32) - size = 32; + size = promoteScalarArgumentSize(size); O << ".param .b" << size << " func_retval0"; } else if (isa(Ty)) { @@ -386,8 +385,8 @@ for (unsigned j = 0, je = elems; j != je; ++j) { unsigned sz = elemtype.getSizeInBits(); - if (elemtype.isInteger() && (sz < 32)) - sz = 32; + if (elemtype.isInteger()) + sz = promoteScalarArgumentSize(sz); O << ".reg .b" << sz << " func_retval" << idx; if (j < je - 1) O << ", "; @@ -1576,8 +1575,7 @@ unsigned sz = 0; if (isa(Ty)) { sz = cast(Ty)->getBitWidth(); - if (sz < 32) - sz = 32; + sz = promoteScalarArgumentSize(sz); } else if (isa(Ty)) sz = thePointerTy.getSizeInBits(); else if (Ty->isHalfTy()) @@ -1641,8 +1639,8 @@ for (unsigned j = 0, je = elems; j != je; ++j) { unsigned sz = elemtype.getSizeInBits(); - if (elemtype.isInteger() && (sz < 32)) - sz = 32; + if (elemtype.isInteger()) + sz = promoteScalarArgumentSize(sz); O << "\t.reg .b" << sz << " "; printParamName(I, paramIndex, O); if (j < je - 1) diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -206,6 +206,40 @@ } } +/// Computes the power-of-two MVT (i1, i8, i16, i32 or i64) that a scalar +/// integer argument/return type \p VT is widened to for passing. Scalars +/// wider than 64 bits are unsupported (llvm_unreachable). +/// +/// Returns true iff \p VT must be widened; the result is in \p PromotedVT. 
+static bool PromoteScalarIntegerPTX(const EVT &VT, MVT *PromotedVT) { + if (VT.isScalarInteger()) { + switch (PowerOf2Ceil(VT.getFixedSizeInBits())) { + default: + llvm_unreachable( + "Promotion is not suitable for scalars of size larger than 64-bits"); + case 1: + *PromotedVT = MVT::i1; + break; + case 2: + case 4: + case 8: + *PromotedVT = MVT::i8; + break; + case 16: + *PromotedVT = MVT::i16; + break; + case 32: + *PromotedVT = MVT::i32; + break; + case 64: + *PromotedVT = MVT::i64; + break; + } + return EVT(*PromotedVT) != VT; + } + return false; +} + // Check whether we can merge loads/stores of some of the pieces of a // flattened function parameter or return value into a single vector // load/store. @@ -1291,8 +1325,7 @@ // PTX ABI requires all scalar return values to be at least 32 // bits in size. fp16 normally uses .b16 as its storage type in // PTX, so its size must be adjusted here, too. - if (size < 32) - size = 32; + size = promoteScalarArgumentSize(size); O << ".param .b" << size << " _"; } else if (isa(retTy)) { @@ -1343,8 +1376,7 @@ unsigned sz = 0; if (isa(Ty)) { sz = cast(Ty)->getBitWidth(); - if (sz < 32) - sz = 32; + sz = promoteScalarArgumentSize(sz); } else if (isa(Ty)) { sz = PtrVT.getSizeInBits(); } else if (Ty->isHalfTy()) @@ -1515,11 +1547,11 @@ NeedAlign = true; } else { // declare .param .b .param; - if ((VT.isInteger() || VT.isFloatingPoint()) && TypeSize < 4) { + if (VT.isInteger() || VT.isFloatingPoint()) { // PTX ABI requires integral types to be at least 32 bits in // size. FP16 is loaded/stored using i16, so it's handled // here as well. 
- TypeSize = 4; + TypeSize = promoteScalarArgumentSize(TypeSize * 8) / 8; } SDValue DeclareScalarParamOps[] = { Chain, DAG.getConstant(ParamCount, dl, MVT::i32), @@ -1556,6 +1588,17 @@ } SDValue StVal = OutVals[OIdx]; + + MVT PromotedVT; + if (PromoteScalarIntegerPTX(EltVT, &PromotedVT)) { + EltVT = EVT(PromotedVT); + } + if (PromoteScalarIntegerPTX(StVal.getValueType(), &PromotedVT)) { + llvm::ISD::NodeType Ext = + Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND; + StVal = DAG.getNode(Ext, dl, PromotedVT, StVal); + } + if (IsByVal) { auto PtrVT = getPointerTy(DL); SDValue srcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal, @@ -1638,9 +1681,7 @@ // Plus, this behavior is consistent with nvcc's. if (RetTy->isFloatingPointTy() || RetTy->isPointerTy() || (RetTy->isIntegerTy() && !RetTy->isIntegerTy(128))) { - // Scalar needs to be at least 32bit wide - if (resultsz < 32) - resultsz = 32; + resultsz = promoteScalarArgumentSize(resultsz); SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue); SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32), DAG.getConstant(resultsz, dl, MVT::i32), @@ -1778,6 +1819,14 @@ EVT TheLoadType = VTs[i]; EVT EltType = Ins[i].VT; Align EltAlign = commonAlignment(RetAlign, Offsets[i]); + MVT PromotedVT; + + if (PromoteScalarIntegerPTX(TheLoadType, &PromotedVT)) { + TheLoadType = EVT(PromotedVT); + EltType = EVT(PromotedVT); + needTruncate = true; + } + if (ExtendIntegerRetVal) { TheLoadType = MVT::i32; EltType = MVT::i32; @@ -2558,6 +2607,13 @@ // v2f16 was loaded as an i32. Now we must bitcast it back. else if (EltVT == MVT::v2f16) Elt = DAG.getNode(ISD::BITCAST, dl, MVT::v2f16, Elt); + + // If EltVT is a promoted integer type, truncate back to EltVT. + MVT PromotedVT; + if (PromoteScalarIntegerPTX(EltVT, &PromotedVT)) { + Elt = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt); + } + // Extend the element if necessary (e.g. 
an i8 is loaded // into an i16 register) if (Ins[InsIdx].VT.isInteger() && @@ -2627,11 +2683,26 @@ return Chain; const DataLayout &DL = DAG.getDataLayout(); + SmallVector PromotedOutVals; SmallVector VTs; SmallVector Offsets; ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets); assert(VTs.size() == OutVals.size() && "Bad return value decomposition"); + for (unsigned i = 0, e = VTs.size(); i != e; ++i) { + SDValue PromotedOutVal = OutVals[i]; + MVT PromotedVT; + if (PromoteScalarIntegerPTX(VTs[i], &PromotedVT)) { + VTs[i] = EVT(PromotedVT); + } + if (PromoteScalarIntegerPTX(PromotedOutVal.getValueType(), &PromotedVT)) { + llvm::ISD::NodeType Ext = + Outs[i].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND; + PromotedOutVal = DAG.getNode(Ext, dl, PromotedVT, PromotedOutVal); + } + PromotedOutVals.push_back(PromotedOutVal); + } + auto VectorInfo = VectorizePTXValueVTs( VTs, Offsets, RetTy->isSized() ? getFunctionParamOptimizedAlign(&F, RetTy, DL) @@ -2652,12 +2723,14 @@ StoreOperands.push_back(DAG.getConstant(Offsets[i], dl, MVT::i32)); } - SDValue RetVal = OutVals[i]; + SDValue OutVal = OutVals[i]; + SDValue RetVal = PromotedOutVals[i]; + if (ExtendIntegerRetVal) { RetVal = DAG.getNode(Outs[i].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, dl, MVT::i32, RetVal); - } else if (RetVal.getValueSizeInBits() < 16) { + } else if (OutVal.getValueSizeInBits() < 16) { // Use 16-bit registers for small load-stores as it's the // smallest general purpose register size supported by NVPTX. 
RetVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, RetVal); diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.h b/llvm/lib/Target/NVPTX/NVPTXUtilities.h --- a/llvm/lib/Target/NVPTX/NVPTXUtilities.h +++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.h @@ -59,6 +59,16 @@ bool getAlign(const Function &, unsigned index, unsigned &); bool getAlign(const CallInst &, unsigned index, unsigned &); +// PTX ABI requires scalar argument/return values to be a power-of-two size +// of at least 32 bits; sizes above 64 bits are returned unchanged. +inline unsigned promoteScalarArgumentSize(unsigned size) { + if (size <= 32) + return 32; + else if (size <= 64) + return 64; + else + return size; +} } #endif diff --git a/llvm/test/CodeGen/NVPTX/param-load-store.ll b/llvm/test/CodeGen/NVPTX/param-load-store.ll --- a/llvm/test/CodeGen/NVPTX/param-load-store.ll +++ b/llvm/test/CodeGen/NVPTX/param-load-store.ll @@ -132,6 +132,40 @@ ret <5 x i1> %r; } +; CHECK: .func (.param .b32 func_retval0) +; CHECK-LABEL: test_i2( +; CHECK-NEXT: .param .b32 test_i2_param_0 +; CHECK: ld.param.u8 {{%rs[0-9]+}}, [test_i2_param_0]; +; CHECK: .param .b32 param0; +; CHECK: st.param.b32 [param0+0], {{%r[0-9]+}}; +; CHECK: .param .b32 retval0; +; CHECK: call.uni (retval0), +; CHECK: test_i2, +; CHECK: ld.param.b32 {{%r[0-9]+}}, [retval0+0]; +; CHECK: st.param.b32 [func_retval0+0], {{%r[0-9]+}}; +; CHECK-NEXT: ret; +define i2 @test_i2(i2 %a) { + %r = tail call i2 @test_i2(i2 %a); + ret i2 %r; +} + +; CHECK: .func (.param .b32 func_retval0) +; CHECK-LABEL: test_i3( +; CHECK-NEXT: .param .b32 test_i3_param_0 +; CHECK: ld.param.u8 {{%rs[0-9]+}}, [test_i3_param_0]; +; CHECK: .param .b32 param0; +; CHECK: st.param.b32 [param0+0], {{%r[0-9]+}}; +; CHECK: .param .b32 retval0; +; CHECK: call.uni (retval0), +; CHECK: test_i3, +; CHECK: ld.param.b32 {{%r[0-9]+}}, [retval0+0]; +; CHECK: st.param.b32 [func_retval0+0], {{%r[0-9]+}}; +; CHECK-NEXT: ret; +define i3 @test_i3(i3 %a) { + %r = tail call i3 @test_i3(i3 %a); + ret i3 %r; +} + ; Unsigned i8 is loaded
directly into 32-bit register. ; CHECK: .func (.param .b32 func_retval0) ; CHECK-LABEL: test_i8( @@ -234,6 +268,23 @@ ret <5 x i8> %r; } +; CHECK: .func (.param .b32 func_retval0) +; CHECK-LABEL: test_i11( +; CHECK-NEXT: .param .b32 test_i11_param_0 +; CHECK: ld.param.u16 {{%rs[0-9]+}}, [test_i11_param_0]; +; CHECK: .param .b32 param0; +; CHECK: st.param.b32 [param0+0], {{%r[0-9]+}}; +; CHECK: .param .b32 retval0; +; CHECK: call.uni (retval0), +; CHECK-NEXT: test_i11, +; CHECK: ld.param.b32 {{%r[0-9]+}}, [retval0+0]; +; CHECK: st.param.b32 [func_retval0+0], {{%r[0-9]+}}; +; CHECK-NEXT: ret; +define i11 @test_i11(i11 %a) { + %r = tail call i11 @test_i11(i11 %a); + ret i11 %r; +} + ; CHECK: .func (.param .b32 func_retval0) ; CHECK-LABEL: test_i16( ; CHECK-NEXT: .param .b32 test_i16_param_0 @@ -474,6 +525,77 @@ ret <9 x half> %r; } +; CHECK: .func (.param .b32 func_retval0) +; CHECK-LABEL: test_i19( +; CHECK-NEXT: .param .b32 test_i19_param_0 +; CHECK-DAG: ld.param.u16 {{%r[0-9]+}}, [test_i19_param_0]; +; CHECK-DAG: ld.param.u8 {{%r[0-9]+}}, [test_i19_param_0+2]; +; CHECK: .param .b32 param0; +; CHECK: st.param.b32 [param0+0], {{%r[0-9]+}}; +; CHECK: .param .b32 retval0; +; CHECK: call.uni (retval0), +; CHECK-NEXT: test_i19, +; CHECK: ld.param.b32 {{%r[0-9]+}}, [retval0+0]; +; CHECK: st.param.b32 [func_retval0+0], {{%r[0-9]+}}; +; CHECK-NEXT: ret; +define i19 @test_i19(i19 %a) { + %r = tail call i19 @test_i19(i19 %a); + ret i19 %r; +} + +; CHECK: .func (.param .b32 func_retval0) +; CHECK-LABEL: test_i23( +; CHECK-NEXT: .param .b32 test_i23_param_0 +; CHECK-DAG: ld.param.u16 {{%r[0-9]+}}, [test_i23_param_0]; +; CHECK-DAG: ld.param.u8 {{%r[0-9]+}}, [test_i23_param_0+2]; +; CHECK: .param .b32 param0; +; CHECK: st.param.b32 [param0+0], {{%r[0-9]+}}; +; CHECK: .param .b32 retval0; +; CHECK: call.uni (retval0), +; CHECK-NEXT: test_i23, +; CHECK: ld.param.b32 {{%r[0-9]+}}, [retval0+0]; +; CHECK: st.param.b32 [func_retval0+0], {{%r[0-9]+}}; +; CHECK-NEXT: ret; +define i23 @test_i23(i23 %a) { + %r = tail 
call i23 @test_i23(i23 %a); + ret i23 %r; +} + +; CHECK: .func (.param .b32 func_retval0) +; CHECK-LABEL: test_i24( +; CHECK-NEXT: .param .b32 test_i24_param_0 +; CHECK-DAG: ld.param.u8 {{%r[0-9]+}}, [test_i24_param_0+2]; +; CHECK-DAG: ld.param.u16 {{%r[0-9]+}}, [test_i24_param_0]; +; CHECK: .param .b32 param0; +; CHECK: st.param.b32 [param0+0], {{%r[0-9]+}}; +; CHECK: .param .b32 retval0; +; CHECK: call.uni (retval0), +; CHECK-NEXT: test_i24, +; CHECK: ld.param.b32 {{%r[0-9]+}}, [retval0+0]; +; CHECK: st.param.b32 [func_retval0+0], {{%r[0-9]+}}; +; CHECK-NEXT: ret; +define i24 @test_i24(i24 %a) { + %r = tail call i24 @test_i24(i24 %a); + ret i24 %r; +} + +; CHECK: .func (.param .b32 func_retval0) +; CHECK-LABEL: test_i29( +; CHECK-NEXT: .param .b32 test_i29_param_0 +; CHECK: ld.param.u32 {{%r[0-9]+}}, [test_i29_param_0]; +; CHECK: .param .b32 param0; +; CHECK: st.param.b32 [param0+0], {{%r[0-9]+}}; +; CHECK: .param .b32 retval0; +; CHECK: call.uni (retval0), +; CHECK-NEXT: test_i29, +; CHECK: ld.param.b32 {{%r[0-9]+}}, [retval0+0]; +; CHECK: st.param.b32 [func_retval0+0], {{%r[0-9]+}}; +; CHECK-NEXT: ret; +define i29 @test_i29(i29 %a) { + %r = tail call i29 @test_i29(i29 %a); + ret i29 %r; +} + ; CHECK: .func (.param .b32 func_retval0) ; CHECK-LABEL: test_i32( ; CHECK-NEXT: .param .b32 test_i32_param_0 @@ -567,6 +689,115 @@ ret float %r; } +; CHECK: .func (.param .b64 func_retval0) +; CHECK-LABEL: test_i40( +; CHECK-NEXT: .param .b64 test_i40_param_0 +; CHECK-DAG: ld.param.u8 {{%rd[0-9]+}}, [test_i40_param_0+4]; +; CHECK-DAG: ld.param.u32 {{%rd[0-9]+}}, [test_i40_param_0]; +; CHECK: .param .b64 param0; +; CHECK: st.param.b64 [param0+0], {{%rd[0-9]+}}; +; CHECK: .param .b64 retval0; +; CHECK: call.uni (retval0), +; CHECK-NEXT: test_i40, +; CHECK: ld.param.b64 {{%rd[0-9]+}}, [retval0+0]; +; CHECK: st.param.b64 [func_retval0+0], {{%rd[0-9]+}}; +; CHECK-NEXT: ret; +define i40 @test_i40(i40 %a) { + %r = tail call i40 @test_i40(i40 %a); + ret i40 %r; +} + +; CHECK: 
.func (.param .b64 func_retval0) +; CHECK-LABEL: test_i47( +; CHECK-NEXT: .param .b64 test_i47_param_0 +; CHECK-DAG: ld.param.u16 {{%rd[0-9]+}}, [test_i47_param_0+4]; +; CHECK-DAG: ld.param.u32 {{%rd[0-9]+}}, [test_i47_param_0]; +; CHECK: .param .b64 param0; +; CHECK: st.param.b64 [param0+0], {{%rd[0-9]+}}; +; CHECK: .param .b64 retval0; +; CHECK: call.uni (retval0), +; CHECK-NEXT: test_i47, +; CHECK: ld.param.b64 {{%rd[0-9]+}}, [retval0+0]; +; CHECK: st.param.b64 [func_retval0+0], {{%rd[0-9]+}}; +; CHECK-NEXT: ret; +define i47 @test_i47(i47 %a) { + %r = tail call i47 @test_i47(i47 %a); + ret i47 %r; +} + +; CHECK: .func (.param .b64 func_retval0) +; CHECK-LABEL: test_i48( +; CHECK-NEXT: .param .b64 test_i48_param_0 +; CHECK-DAG: ld.param.u16 {{%rd[0-9]+}}, [test_i48_param_0+4]; +; CHECK-DAG: ld.param.u32 {{%rd[0-9]+}}, [test_i48_param_0]; +; CHECK: .param .b64 param0; +; CHECK: st.param.b64 [param0+0], {{%rd[0-9]+}}; +; CHECK: .param .b64 retval0; +; CHECK: call.uni (retval0), +; CHECK-NEXT: test_i48, +; CHECK: ld.param.b64 {{%rd[0-9]+}}, [retval0+0]; +; CHECK: st.param.b64 [func_retval0+0], {{%rd[0-9]+}}; +; CHECK-NEXT: ret; +define i48 @test_i48(i48 %a) { + %r = tail call i48 @test_i48(i48 %a); + ret i48 %r; +} + +; CHECK: .func (.param .b64 func_retval0) +; CHECK-LABEL: test_i51( +; CHECK-NEXT: .param .b64 test_i51_param_0 +; CHECK-DAG: ld.param.u8 {{%rd[0-9]+}}, [test_i51_param_0+6]; +; CHECK-DAG: ld.param.u16 {{%rd[0-9]+}}, [test_i51_param_0+4]; +; CHECK-DAG: ld.param.u32 {{%rd[0-9]+}}, [test_i51_param_0]; +; CHECK: .param .b64 param0; +; CHECK: st.param.b64 [param0+0], {{%rd[0-9]+}}; +; CHECK: .param .b64 retval0; +; CHECK: call.uni (retval0), +; CHECK-NEXT: test_i51, +; CHECK: ld.param.b64 {{%rd[0-9]+}}, [retval0+0]; +; CHECK: st.param.b64 [func_retval0+0], {{%rd[0-9]+}}; +; CHECK-NEXT: ret; +define i51 @test_i51(i51 %a) { + %r = tail call i51 @test_i51(i51 %a); + ret i51 %r; +} + +; CHECK: .func (.param .b64 func_retval0) +; CHECK-LABEL: test_i56( +; 
CHECK-NEXT: .param .b64 test_i56_param_0 +; CHECK-DAG: ld.param.u8 {{%rd[0-9]+}}, [test_i56_param_0+6]; +; CHECK-DAG: ld.param.u16 {{%rd[0-9]+}}, [test_i56_param_0+4]; +; CHECK-DAG: ld.param.u32 {{%rd[0-9]+}}, [test_i56_param_0]; +; CHECK: .param .b64 param0; +; CHECK: st.param.b64 [param0+0], {{%rd[0-9]+}}; +; CHECK: .param .b64 retval0; +; CHECK: call.uni (retval0), +; CHECK-NEXT: test_i56, +; CHECK: ld.param.b64 {{%rd[0-9]+}}, [retval0+0]; +; CHECK: st.param.b64 [func_retval0+0], {{%rd[0-9]+}}; +; CHECK-NEXT: ret; +define i56 @test_i56(i56 %a) { + %r = tail call i56 @test_i56(i56 %a); + ret i56 %r; +} + +; CHECK: .func (.param .b64 func_retval0) +; CHECK-LABEL: test_i57( +; CHECK-NEXT: .param .b64 test_i57_param_0 +; CHECK: ld.param.u64 {{%rd[0-9]+}}, [test_i57_param_0]; +; CHECK: .param .b64 param0; +; CHECK: st.param.b64 [param0+0], {{%rd[0-9]+}}; +; CHECK: .param .b64 retval0; +; CHECK: call.uni (retval0), +; CHECK-NEXT: test_i57, +; CHECK: ld.param.b64 {{%rd[0-9]+}}, [retval0+0]; +; CHECK: st.param.b64 [func_retval0+0], {{%rd[0-9]+}}; +; CHECK-NEXT: ret; +define i57 @test_i57(i57 %a) { + %r = tail call i57 @test_i57(i57 %a); + ret i57 %r; +} + ; CHECK: .func (.param .b64 func_retval0) ; CHECK-LABEL: test_i64( ; CHECK-NEXT: .param .b64 test_i64_param_0