diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -353,6 +353,8 @@
       // PTX, so its size must be adjusted here, too.
       if (size < 32)
         size = 32;
+      if (size > 32 && size < 64)
+        size = 64;
 
       O << ".param .b" << size << " func_retval0";
     } else if (isa<PointerType>(Ty)) {
@@ -381,8 +383,12 @@
 
       for (unsigned j = 0, je = elems; j != je; ++j) {
         unsigned sz = elemtype.getSizeInBits();
-        if (elemtype.isInteger() && (sz < 32))
-          sz = 32;
+        if (elemtype.isInteger()) {
+          if (sz <= 32)
+            sz = 32;
+          else if (sz <= 64)
+            sz = 64;
+        }
         O << ".reg .b" << sz << " func_retval" << idx;
         if (j < je - 1)
           O << ", ";
@@ -1491,8 +1497,10 @@
       unsigned sz = 0;
       if (isa<IntegerType>(Ty)) {
         sz = cast<IntegerType>(Ty)->getBitWidth();
-        if (sz < 32)
+        if (sz <= 32)
           sz = 32;
+        else if (sz <= 64)
+          sz = 64;
       } else if (isa<PointerType>(Ty))
         sz = thePointerTy.getSizeInBits();
       else if (Ty->isHalfTy())
@@ -1556,8 +1564,12 @@
 
       for (unsigned j = 0, je = elems; j != je; ++j) {
         unsigned sz = elemtype.getSizeInBits();
-        if (elemtype.isInteger() && (sz < 32))
-          sz = 32;
+        if (elemtype.isInteger()) {
+          if (sz <= 32)
+            sz = 32;
+          else if (sz <= 64)
+            sz = 64;
+        }
         O << "\t.reg .b" << sz << " ";
         printParamName(I, paramIndex, O);
         if (j < je - 1)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -206,6 +206,30 @@
   }
 }
 
+/// PromoteScalarIntegerPTX
+/// Used to make sure the arguments/returns are suitable for passing
+/// and promote them to a larger size if they're not.
+///
+/// The promoted type is placed in \p PromotedVT if the function returns true.
+static bool PromoteScalarIntegerPTX(const EVT &VT, MVT *PromotedVT) {
+  if (VT.isScalarInteger()) {
+    if (1 < VT.getFixedSizeInBits() && VT.getFixedSizeInBits() < 8) {
+      *PromotedVT = MVT::i8;
+      return true;
+    } else if (8 < VT.getFixedSizeInBits() && VT.getFixedSizeInBits() < 16) {
+      *PromotedVT = MVT::i16;
+      return true;
+    } else if (16 < VT.getFixedSizeInBits() && VT.getFixedSizeInBits() < 32) {
+      *PromotedVT = MVT::i32;
+      return true;
+    } else if (32 < VT.getFixedSizeInBits() && VT.getFixedSizeInBits() < 64) {
+      *PromotedVT = MVT::i64;
+      return true;
+    }
+  }
+  return false;
+}
+
 // Check whether we can merge loads/stores of some of the pieces of a
 // flattened function parameter or return value into a single vector
 // load/store.
@@ -1293,6 +1317,8 @@
       // PTX, so its size must be adjusted here, too.
       if (size < 32)
         size = 32;
+      else if (size > 32 && size < 64)
+        size = 64;
 
       O << ".param .b" << size << " _";
     } else if (isa<PointerType>(retTy)) {
@@ -1345,6 +1371,8 @@
         sz = cast<IntegerType>(Ty)->getBitWidth();
         if (sz < 32)
           sz = 32;
+        else if (sz > 32 && sz < 64)
+          sz = 64;
       } else if (isa<PointerType>(Ty)) {
         sz = PtrVT.getSizeInBits();
       } else if (Ty->isHalfTy())
@@ -1520,6 +1548,8 @@
           // size. FP16 is loaded/stored using i16, so it's handled
           // here as well.
           TypeSize = 4;
+        } else if (VT.isInteger() && TypeSize > 4 && TypeSize < 8) {
+          TypeSize = 8;
         }
         SDValue DeclareScalarParamOps[] = {
             Chain, DAG.getConstant(ParamCount, dl, MVT::i32),
@@ -1556,6 +1586,15 @@
       }
 
       SDValue StVal = OutVals[OIdx];
+
+      MVT PromotedVT;
+      if (PromoteScalarIntegerPTX(EltVT, &PromotedVT)) {
+        llvm::ISD::NodeType Ext =
+            Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+        StVal = DAG.getNode(Ext, dl, PromotedVT, StVal);
+        EltVT = EVT(PromotedVT);
+      }
+
       if (IsByVal) {
         auto PtrVT = getPointerTy(DL);
         SDValue srcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
@@ -1641,6 +1680,8 @@
       // Scalar needs to be at least 32bit wide
       if (resultsz < 32)
         resultsz = 32;
+      else if (resultsz > 32 && resultsz < 64)
+        resultsz = 64;
       SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
       SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32),
                                   DAG.getConstant(resultsz, dl, MVT::i32),
@@ -1778,6 +1819,14 @@
       EVT TheLoadType = VTs[i];
       EVT EltType = Ins[i].VT;
       Align EltAlign = commonAlignment(RetAlign, Offsets[i]);
+      MVT PromotedVT;
+
+      if (PromoteScalarIntegerPTX(TheLoadType, &PromotedVT)) {
+        TheLoadType = EVT(PromotedVT);
+        EltType = EVT(PromotedVT);
+        needTruncate = true;
+      }
+
       if (ExtendIntegerRetVal) {
         TheLoadType = MVT::i32;
         EltType = MVT::i32;
@@ -2558,6 +2607,13 @@
           // v2f16 was loaded as an i32. Now we must bitcast it back.
           else if (EltVT == MVT::v2f16)
             Elt = DAG.getNode(ISD::BITCAST, dl, MVT::v2f16, Elt);
+
+          // If a promoted integer type is used, truncate down to the original
+          MVT PromotedVT;
+          if (PromoteScalarIntegerPTX(EltVT, &PromotedVT)) {
+            Elt = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt);
+          }
+
           // Extend the element if necessary (e.g. an i8 is loaded
           // into an i16 register)
           if (Ins[InsIdx].VT.isInteger() &&
@@ -2627,11 +2683,24 @@
     return Chain;
 
   const DataLayout &DL = DAG.getDataLayout();
+  SmallVector<SDValue, 16> PromotedOutVals;
   SmallVector<EVT, 16> VTs;
   SmallVector<uint64_t, 16> Offsets;
   ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
   assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
 
+  for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
+    SDValue PromotedOutVal = OutVals[i];
+    MVT PromotedVT;
+    if (PromoteScalarIntegerPTX(VTs[i], &PromotedVT)) {
+      llvm::ISD::NodeType Ext =
+          Outs[i].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+      PromotedOutVal = DAG.getNode(Ext, dl, PromotedVT, PromotedOutVal);
+      VTs[i] = EVT(PromotedVT);
+    }
+    PromotedOutVals.push_back(PromotedOutVal);
+  }
+
   auto VectorInfo = VectorizePTXValueVTs(
       VTs, Offsets,
       RetTy->isSized() ? getFunctionParamOptimizedAlign(&F, RetTy, DL)
@@ -2652,12 +2721,14 @@
         StoreOperands.push_back(DAG.getConstant(Offsets[i], dl, MVT::i32));
       }
 
-      SDValue RetVal = OutVals[i];
+      SDValue OutVal = OutVals[i];
+      SDValue RetVal = PromotedOutVals[i];
+
       if (ExtendIntegerRetVal) {
         RetVal = DAG.getNode(Outs[i].Flags.isSExt() ? ISD::SIGN_EXTEND
                                                     : ISD::ZERO_EXTEND,
                              dl, MVT::i32, RetVal);
-      } else if (RetVal.getValueSizeInBits() < 16) {
+      } else if (OutVal.getValueSizeInBits() < 16) {
         // Use 16-bit registers for small load-stores as it's the
         // smallest general purpose register size supported by NVPTX.
         RetVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, RetVal);
diff --git a/llvm/test/CodeGen/NVPTX/i24-param.ll b/llvm/test/CodeGen/NVPTX/i24-param.ll
new file
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/i24-param.ll
@@ -0,0 +1,27 @@
+; RUN: llc < %s -march=nvptx -mcpu=sm_20 -verify-machineinstrs | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx -mcpu=sm_20 -verify-machineinstrs | %ptxas-verify %}
+
+target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
+
+; CHECK: .visible .func (.param .b32 func_retval0) callee
+; CHECK: .param .b32 callee_param_0
+define i24 @callee(i24 %a) {
+  %val = alloca i24, align 4
+  store i24 %a, i24* %val, align 4
+  %ret = load i24, i24* %val, align 1
+; CHECK: ld.param.u8
+; CHECK: ld.param.u16
+; CHECK: st.param.b32
+  ret i24 %ret
+}
+
+; CHECK: .visible .func caller
+define void @caller(i24* %a) {
+  %val = load i24, i24* %a
+  %ret = call i24 @callee(i24 %val)
+; CHECK: ld.param.b32
+; CHECK: st.u16
+; CHECK: st.u8
+  store i24 %ret, i24* %a
+  ret void
+}
diff --git a/llvm/test/CodeGen/NVPTX/i40-param.ll b/llvm/test/CodeGen/NVPTX/i40-param.ll
new file
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/i40-param.ll
@@ -0,0 +1,27 @@
+; RUN: llc < %s -march=nvptx -mcpu=sm_20 -verify-machineinstrs | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx -mcpu=sm_20 -verify-machineinstrs | %ptxas-verify %}
+
+target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
+
+; CHECK: .visible .func (.param .b64 func_retval0) callee
+; CHECK: .param .b64 callee_param_0
+define i40 @callee(i40 %a) {
+  %val = alloca i40, align 8
+  store i40 %a, i40* %val, align 8
+  %ret = load i40, i40* %val, align 1
+; CHECK: ld.param.u8
+; CHECK: ld.param.u32
+; CHECK: st.param.b64
+  ret i40 %ret
+}
+
+; CHECK: .visible .func caller
+define void @caller(i40* %a) {
+  %val = load i40, i40* %a
+  %ret = call i40 @callee(i40 %val)
+; CHECK: ld.param.b64
+; CHECK: st.u32
+; CHECK: st.u8
+  store i40 %ret, i40* %a
+  ret void
+}
diff --git a/llvm/test/CodeGen/NVPTX/i48-param.ll b/llvm/test/CodeGen/NVPTX/i48-param.ll
new file
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/i48-param.ll
@@ -0,0 +1,27 @@
+; RUN: llc < %s -march=nvptx -mcpu=sm_20 -verify-machineinstrs | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx -mcpu=sm_20 -verify-machineinstrs | %ptxas-verify %}
+
+target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
+
+; CHECK: .visible .func (.param .b64 func_retval0) callee
+; CHECK: .param .b64 callee_param_0
+define i48 @callee(i48 %a) {
+  %val = alloca i48, align 8
+  store i48 %a, i48* %val, align 8
+  %ret = load i48, i48* %val, align 1
+; CHECK: ld.param.u16
+; CHECK: ld.param.u32
+; CHECK: st.param.b64
+  ret i48 %ret
+}
+
+; CHECK: .visible .func caller
+define void @caller(i48* %a) {
+  %val = load i48, i48* %a
+  %ret = call i48 @callee(i48 %val)
+; CHECK: ld.param.b64
+; CHECK: st.u32
+; CHECK: st.u16
+  store i48 %ret, i48* %a
+  ret void
+}
diff --git a/llvm/test/CodeGen/NVPTX/i56-param.ll b/llvm/test/CodeGen/NVPTX/i56-param.ll
new file
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/i56-param.ll
@@ -0,0 +1,29 @@
+; RUN: llc < %s -march=nvptx -mcpu=sm_20 -verify-machineinstrs | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx -mcpu=sm_20 -verify-machineinstrs | %ptxas-verify %}
+
+target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
+
+; CHECK: .visible .func (.param .b64 func_retval0) callee
+; CHECK: .param .b64 callee_param_0
+define i56 @callee(i56 %a) {
+  %val = alloca i56, align 8
+  store i56 %a, i56* %val, align 8
+  %ret = load i56, i56* %val, align 1
+; CHECK: ld.param.u8
+; CHECK: ld.param.u16
+; CHECK: ld.param.u32
+; CHECK: st.param.b64
+  ret i56 %ret
+}
+
+; CHECK: .visible .func caller
+define void @caller(i56* %a) {
+  %val = load i56, i56* %a
+  %ret = call i56 @callee(i56 %val)
+; CHECK: ld.param.b64
+; CHECK: st.u32
+; CHECK: st.u8
+; CHECK: st.u16
+  store i56 %ret, i56* %a
+  ret void
+}
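
Illustration only, not part of the patch: a minimal standalone C++ sketch of the promotion rule that PromoteScalarIntegerPTX implements above, with the scalar widths exercised by the new i24/i40/i48/i56 tests. The helper name promotedPTXIntegerBits and the small driver are hypothetical and exist only to restate the rule without LLVM types.

// Sketch of the PTX scalar-integer promotion rule: a width strictly between
// two supported register widths is rounded up to the next one (8/16/32/64).
#include <cassert>
#include <cstdio>

// Hypothetical helper: returns the promoted width in bits for a scalar
// integer of 'Bits' bits, or 'Bits' unchanged when no promotion is needed.
static unsigned promotedPTXIntegerBits(unsigned Bits) {
  if (Bits > 1 && Bits < 8)
    return 8;
  if (Bits > 8 && Bits < 16)
    return 16;
  if (Bits > 16 && Bits < 32)
    return 32;
  if (Bits > 32 && Bits < 64)
    return 64;
  return Bits; // Already a supported width, or outside the promoted range.
}

int main() {
  // The widths covered by the new tests and the .param sizes they expect.
  assert(promotedPTXIntegerBits(24) == 32); // i24 -> .param .b32
  assert(promotedPTXIntegerBits(40) == 64); // i40 -> .param .b64
  assert(promotedPTXIntegerBits(48) == 64); // i48 -> .param .b64
  assert(promotedPTXIntegerBits(56) == 64); // i56 -> .param .b64
  std::printf("promotion rule OK\n");
  return 0;
}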