Index: lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.cpp =================================================================== --- lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.cpp +++ lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.cpp @@ -61,6 +61,9 @@ case 6: OS << "%fd"; break; + case 7: + OS << "%h"; + break; } unsigned VReg = RegNo & 0x0FFFFFFF; @@ -247,8 +250,12 @@ O << "s"; else if (Imm == NVPTX::PTXLdStInstCode::Unsigned) O << "u"; - else + else if (Imm == NVPTX::PTXLdStInstCode::Untyped) + O << "b"; + else if (Imm == NVPTX::PTXLdStInstCode::Float) O << "f"; + else + llvm_unreachable("Unknown register type"); } else if (!strcmp(Modifier, "vec")) { if (Imm == NVPTX::PTXLdStInstCode::V2) O << ".v2"; Index: lib/Target/NVPTX/NVPTX.h =================================================================== --- lib/Target/NVPTX/NVPTX.h +++ lib/Target/NVPTX/NVPTX.h @@ -108,7 +108,8 @@ enum FromType { Unsigned = 0, Signed, - Float + Float, + Untyped }; enum VecType { Scalar = 1, Index: lib/Target/NVPTX/NVPTXAsmPrinter.cpp =================================================================== --- lib/Target/NVPTX/NVPTXAsmPrinter.cpp +++ lib/Target/NVPTX/NVPTXAsmPrinter.cpp @@ -279,6 +279,10 @@ switch (Cnt->getType()->getTypeID()) { default: report_fatal_error("Unsupported FP type"); break; + case Type::HalfTyID: + MCOp = MCOperand::createExpr( + NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext)); + break; case Type::FloatTyID: MCOp = MCOperand::createExpr( NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext)); @@ -316,6 +320,8 @@ Ret = (5 << 28); } else if (RC == &NVPTX::Float64RegsRegClass) { Ret = (6 << 28); + } else if (RC == &NVPTX::Float16RegsRegClass) { + Ret = (7 << 28); } else { report_fatal_error("Bad register class"); } @@ -355,12 +361,12 @@ unsigned size = 0; if (auto *ITy = dyn_cast(Ty)) { size = ITy->getBitWidth(); - if (size < 32) - size = 32; } else { assert(Ty->isFloatingPointTy() && "Floating point type expected here"); size = 
Ty->getPrimitiveSizeInBits(); } + if (size < 32) + size = 32; O << ".param .b" << size << " func_retval0"; } else if (isa(Ty)) { @@ -1342,6 +1348,8 @@ } break; } + case Type::HalfTyID: + return "b16"; case Type::FloatTyID: return "f32"; case Type::DoubleTyID: @@ -1569,6 +1577,8 @@ sz = 32; } else if (isa(Ty)) sz = thePointerTy.getSizeInBits(); + else if (Ty->isHalfTy()) + sz = 32; else sz = Ty->getPrimitiveSizeInBits(); if (isABI) Index: lib/Target/NVPTX/NVPTXISelDAGToDAG.h =================================================================== --- lib/Target/NVPTX/NVPTXISelDAGToDAG.h +++ lib/Target/NVPTX/NVPTXISelDAGToDAG.h @@ -69,6 +69,7 @@ bool tryTextureIntrinsic(SDNode *N); bool trySurfaceIntrinsic(SDNode *N); bool tryBFE(SDNode *N); + bool tryConstantFP16(SDNode *N); inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) { return CurDAG->getTargetConstant(Imm, DL, MVT::i32); Index: lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp =================================================================== --- lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -42,7 +42,6 @@ cl::desc("NVPTX Specific: Flush f32 subnormals to sign-preserving zero."), cl::init(false)); - /// createNVPTXISelDag - This pass converts a legalized DAG into a /// NVPTX-specific DAG, ready for instruction scheduling. FunctionPass *llvm::createNVPTXISelDag(NVPTXTargetMachine &TM, @@ -515,6 +514,10 @@ case ISD::ADDRSPACECAST: SelectAddrSpaceCast(N); return; + case ISD::ConstantFP: + if (tryConstantFP16(N)) + return; + break; default: break; } @@ -536,6 +539,19 @@ } } +// There's no way to specify FP16 constants in .f16 ops, so we have to +// load them into an .f16 register first. 
+bool NVPTXDAGToDAGISel::tryConstantFP16(SDNode *N) { + if (N->getValueType(0) != MVT::f16) + return false; + SDValue Val = CurDAG->getTargetConstantFP( + cast(N)->getValueAPF(), SDLoc(N), MVT::f16); + SDNode *LoadConstF16 = + CurDAG->getMachineNode(NVPTX::LOAD_CONST_F16, SDLoc(N), MVT::f16, Val); + ReplaceNode(N, LoadConstF16); + return true; +} + static unsigned int getCodeAddrSpace(MemSDNode *N) { const Value *Src = N->getMemOperand()->getValue(); @@ -735,7 +751,8 @@ if ((LD->getExtensionType() == ISD::SEXTLOAD)) fromType = NVPTX::PTXLdStInstCode::Signed; else if (ScalarVT.isFloatingPoint()) - fromType = NVPTX::PTXLdStInstCode::Float; + fromType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped + : NVPTX::PTXLdStInstCode::Float; else fromType = NVPTX::PTXLdStInstCode::Unsigned; @@ -761,6 +778,9 @@ case MVT::i64: Opcode = NVPTX::LD_i64_avar; break; + case MVT::f16: + Opcode = NVPTX::LD_f16_avar; + break; case MVT::f32: Opcode = NVPTX::LD_f32_avar; break; @@ -789,6 +809,9 @@ case MVT::i64: Opcode = NVPTX::LD_i64_asi; break; + case MVT::f16: + Opcode = NVPTX::LD_f16_asi; + break; case MVT::f32: Opcode = NVPTX::LD_f32_asi; break; @@ -818,6 +841,9 @@ case MVT::i64: Opcode = NVPTX::LD_i64_ari_64; break; + case MVT::f16: + Opcode = NVPTX::LD_f16_ari_64; + break; case MVT::f32: Opcode = NVPTX::LD_f32_ari_64; break; @@ -841,6 +867,9 @@ case MVT::i64: Opcode = NVPTX::LD_i64_ari; break; + case MVT::f16: + Opcode = NVPTX::LD_f16_ari; + break; case MVT::f32: Opcode = NVPTX::LD_f32_ari; break; @@ -870,6 +899,9 @@ case MVT::i64: Opcode = NVPTX::LD_i64_areg_64; break; + case MVT::f16: + Opcode = NVPTX::LD_f16_areg_64; + break; case MVT::f32: Opcode = NVPTX::LD_f32_areg_64; break; @@ -893,6 +925,9 @@ case MVT::i64: Opcode = NVPTX::LD_i64_areg; break; + case MVT::f16: + Opcode = NVPTX::LD_f16_areg; + break; case MVT::f32: Opcode = NVPTX::LD_f32_areg; break; @@ -968,7 +1003,8 @@ if (ExtensionType == ISD::SEXTLOAD) FromType = NVPTX::PTXLdStInstCode::Signed; else if 
(ScalarVT.isFloatingPoint()) - FromType = NVPTX::PTXLdStInstCode::Float; + FromType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped + : NVPTX::PTXLdStInstCode::Float; else FromType = NVPTX::PTXLdStInstCode::Unsigned; @@ -2168,7 +2204,8 @@ unsigned toTypeWidth = ScalarVT.getSizeInBits(); unsigned int toType; if (ScalarVT.isFloatingPoint()) - toType = NVPTX::PTXLdStInstCode::Float; + toType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped + : NVPTX::PTXLdStInstCode::Float; else toType = NVPTX::PTXLdStInstCode::Unsigned; @@ -2195,6 +2232,9 @@ case MVT::i64: Opcode = NVPTX::ST_i64_avar; break; + case MVT::f16: + Opcode = NVPTX::ST_f16_avar; + break; case MVT::f32: Opcode = NVPTX::ST_f32_avar; break; @@ -2224,6 +2264,9 @@ case MVT::i64: Opcode = NVPTX::ST_i64_asi; break; + case MVT::f16: + Opcode = NVPTX::ST_f16_asi; + break; case MVT::f32: Opcode = NVPTX::ST_f32_asi; break; @@ -2254,6 +2297,9 @@ case MVT::i64: Opcode = NVPTX::ST_i64_ari_64; break; + case MVT::f16: + Opcode = NVPTX::ST_f16_ari_64; + break; case MVT::f32: Opcode = NVPTX::ST_f32_ari_64; break; @@ -2277,6 +2323,9 @@ case MVT::i64: Opcode = NVPTX::ST_i64_ari; break; + case MVT::f16: + Opcode = NVPTX::ST_f16_ari; + break; case MVT::f32: Opcode = NVPTX::ST_f32_ari; break; @@ -2307,6 +2356,9 @@ case MVT::i64: Opcode = NVPTX::ST_i64_areg_64; break; + case MVT::f16: + Opcode = NVPTX::ST_f16_areg_64; + break; case MVT::f32: Opcode = NVPTX::ST_f32_areg_64; break; @@ -2330,6 +2382,9 @@ case MVT::i64: Opcode = NVPTX::ST_i64_areg; break; + case MVT::f16: + Opcode = NVPTX::ST_f16_areg; + break; case MVT::f32: Opcode = NVPTX::ST_f32_areg; break; @@ -2781,6 +2836,9 @@ case MVT::i64: Opc = NVPTX::LoadParamMemI64; break; + case MVT::f16: + Opc = NVPTX::LoadParamMemF16; + break; case MVT::f32: Opc = NVPTX::LoadParamMemF32; break; @@ -2916,6 +2974,9 @@ case MVT::i64: Opcode = NVPTX::StoreRetvalI64; break; + case MVT::f16: + Opcode = NVPTX::StoreRetvalF16; + break; case MVT::f32: Opcode = 
NVPTX::StoreRetvalF32; break; @@ -3049,6 +3110,9 @@ case MVT::i64: Opcode = NVPTX::StoreParamI64; break; + case MVT::f16: + Opcode = NVPTX::StoreParamF16; + break; case MVT::f32: Opcode = NVPTX::StoreParamF32; break; Index: lib/Target/NVPTX/NVPTXISelLowering.h =================================================================== --- lib/Target/NVPTX/NVPTXISelLowering.h +++ lib/Target/NVPTX/NVPTXISelLowering.h @@ -527,6 +527,7 @@ SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const; SDValue LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerSTOREf16(SDValue Op, SelectionDAG &DAG) const; SDValue LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const; SDValue LowerShiftRightParts(SDValue Op, SelectionDAG &DAG) const; Index: lib/Target/NVPTX/NVPTXISelLowering.cpp =================================================================== --- lib/Target/NVPTX/NVPTXISelLowering.cpp +++ lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -140,8 +140,13 @@ addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass); addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass); addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass); + addRegisterClass(MVT::f16, &NVPTX::Float16RegsRegClass); // Operations not directly supported by NVPTX. + setOperationAction(ISD::SETCC, MVT::f16, + STI.allowFP16Math() ? Legal : Promote); + setOperationAction(ISD::SELECT_CC, MVT::f16, + STI.allowFP16Math() ? Expand : Promote); setOperationAction(ISD::SELECT_CC, MVT::f32, Expand); setOperationAction(ISD::SELECT_CC, MVT::f64, Expand); setOperationAction(ISD::SELECT_CC, MVT::i1, Expand); @@ -149,6 +154,8 @@ setOperationAction(ISD::SELECT_CC, MVT::i16, Expand); setOperationAction(ISD::SELECT_CC, MVT::i32, Expand); setOperationAction(ISD::SELECT_CC, MVT::i64, Expand); + setOperationAction(ISD::BR_CC, MVT::f16, + STI.allowFP16Math() ? 
Expand : Promote); setOperationAction(ISD::BR_CC, MVT::f32, Expand); setOperationAction(ISD::BR_CC, MVT::f64, Expand); setOperationAction(ISD::BR_CC, MVT::i1, Expand); @@ -235,6 +242,7 @@ // This is legal in NVPTX setOperationAction(ISD::ConstantFP, MVT::f64, Legal); setOperationAction(ISD::ConstantFP, MVT::f32, Legal); + setOperationAction(ISD::ConstantFP, MVT::f16, Legal); // TRAP can be lowered to PTX trap setOperationAction(ISD::TRAP, MVT::Other, Legal); @@ -281,18 +289,35 @@ setTargetDAGCombine(ISD::SREM); setTargetDAGCombine(ISD::UREM); + if (!STI.allowFP16Math()) { + // Promote to FP32 ops, even if FP16 hw is available. While + // sm_53+ GPUs have some sort of FP16 support in hardware, only + // sm_53 and sm_60 have a full implementation. Others only have a + // token amount of hardware and are likely to run faster by using + // fp32 units instead. + setOperationAction(ISD::FADD, MVT::f16, Promote); + setOperationAction(ISD::FMUL, MVT::f16, Promote); + setOperationAction(ISD::FSUB, MVT::f16, Promote); + setOperationAction(ISD::FMA, MVT::f16, Promote); + } + // Library functions. These default to Expand, but we have instructions // for them. 
+ setOperationAction(ISD::FCEIL, MVT::f16, Legal); setOperationAction(ISD::FCEIL, MVT::f32, Legal); setOperationAction(ISD::FCEIL, MVT::f64, Legal); + setOperationAction(ISD::FFLOOR, MVT::f16, Legal); setOperationAction(ISD::FFLOOR, MVT::f32, Legal); setOperationAction(ISD::FFLOOR, MVT::f64, Legal); setOperationAction(ISD::FNEARBYINT, MVT::f32, Legal); setOperationAction(ISD::FNEARBYINT, MVT::f64, Legal); + setOperationAction(ISD::FRINT, MVT::f16, Legal); setOperationAction(ISD::FRINT, MVT::f32, Legal); setOperationAction(ISD::FRINT, MVT::f64, Legal); + setOperationAction(ISD::FROUND, MVT::f16, Legal); setOperationAction(ISD::FROUND, MVT::f32, Legal); setOperationAction(ISD::FROUND, MVT::f64, Legal); + setOperationAction(ISD::FTRUNC, MVT::f16, Legal); setOperationAction(ISD::FTRUNC, MVT::f32, Legal); setOperationAction(ISD::FTRUNC, MVT::f64, Legal); setOperationAction(ISD::FMINNUM, MVT::f32, Legal); @@ -300,6 +325,24 @@ setOperationAction(ISD::FMAXNUM, MVT::f32, Legal); setOperationAction(ISD::FMAXNUM, MVT::f64, Legal); + // 'Expand' implements FCOPYSIGN w/o need for the library. + setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand); + setOperationAction(ISD::FCOPYSIGN, MVT::f32, Expand); + setOperationAction(ISD::FCOPYSIGN, MVT::f64, Expand); + + // FP16 does not support these nodes in hardware, but we can perform + // these ops using single-precision hardware. + setOperationAction(ISD::FDIV, MVT::f16, Promote); + setOperationAction(ISD::FREM, MVT::f16, Promote); + setOperationAction(ISD::FSQRT, MVT::f16, Promote); + setOperationAction(ISD::FSIN, MVT::f16, Promote); + setOperationAction(ISD::FCOS, MVT::f16, Promote); + setOperationAction(ISD::FABS, MVT::f16, Promote); + setOperationAction(ISD::FMINNUM, MVT::f16, Promote); + setOperationAction(ISD::FMAXNUM, MVT::f16, Promote); + setOperationAction(ISD::FMINNAN, MVT::f16, Promote); + setOperationAction(ISD::FMAXNAN, MVT::f16, Promote); + // No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate. 
// No FPOW or FREM in PTX. @@ -943,19 +986,18 @@ unsigned size = 0; if (auto *ITy = dyn_cast(retTy)) { size = ITy->getBitWidth(); - if (size < 32) - size = 32; } else { assert(retTy->isFloatingPointTy() && "Floating point type expected here"); size = retTy->getPrimitiveSizeInBits(); } + if (size < 32) + size = 32; O << ".param .b" << size << " _"; } else if (isa(retTy)) { O << ".param .b" << PtrVT.getSizeInBits() << " _"; - } else if ((retTy->getTypeID() == Type::StructTyID) || - isa(retTy)) { + } else if (retTy->isAggregateType() || retTy->isVectorTy()) { auto &DL = CS->getCalledFunction()->getParent()->getDataLayout(); O << ".param .align " << retAlignment << " .b8 _[" << DL.getTypeAllocSize(retTy) << "]"; @@ -994,7 +1036,7 @@ OIdx += len - 1; continue; } - // i8 types in IR will be i16 types in SDAG + // i8 types in IR will be i16 types in SDAG assert((getValueType(DL, Ty) == Outs[OIdx].VT || (getValueType(DL, Ty) == MVT::i8 && Outs[OIdx].VT == MVT::i16)) && "type mismatch between callee prototype and arguments"); @@ -1004,8 +1046,10 @@ sz = cast(Ty)->getBitWidth(); if (sz < 32) sz = 32; - } else if (isa(Ty)) + } else if (isa(Ty)) { sz = PtrVT.getSizeInBits(); + } else if (Ty->isHalfTy()) + sz = 32; else sz = Ty->getPrimitiveSizeInBits(); O << ".param .b" << sz << " "; @@ -1316,7 +1360,8 @@ needExtend = true; if (sz < 32) sz = 32; - } + } else if (VT.isFloatingPoint() && sz < 32) + sz = 32; SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); SDValue DeclareParamOps[] = { Chain, DAG.getConstant(paramCount, dl, MVT::i32), @@ -1932,12 +1977,17 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const { EVT ValVT = Op.getOperand(1).getValueType(); - if (ValVT == MVT::i1) + switch (ValVT.getSimpleVT().SimpleTy) { + case MVT::i1: return LowerSTOREi1(Op, DAG); - else if (ValVT.isVector()) - return LowerSTOREVector(Op, DAG); - else - return SDValue(); + case MVT::f16: + return LowerSTOREf16(Op, DAG); + default: + if (ValVT.isVector()) 
+ return LowerSTOREVector(Op, DAG); + else + return SDValue(); + } } SDValue @@ -2055,6 +2105,22 @@ return Result; } +SDValue NVPTXTargetLowering::LowerSTOREf16(SDValue Op, + SelectionDAG &DAG) const { + SDNode *Node = Op.getNode(); + SDLoc dl(Node); + StoreSDNode *ST = cast(Node); + SDValue Tmp1 = ST->getChain(); + SDValue Tmp2 = ST->getBasePtr(); + SDValue Tmp3 = ST->getValue(); + assert(Tmp3.getValueType() == MVT::f16 && + "Custom lowering for f16 store only"); + SDValue Result = + DAG.getTruncStore(Tmp1, dl, Tmp3, Tmp2, ST->getPointerInfo(), MVT::i16, + ST->getAlignment(), ST->getMemOperand()->getFlags()); + return Result; +} + SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx, EVT v) const { std::string ParamSym; @@ -2539,8 +2605,9 @@ // specifically not for aggregates. TmpVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, TmpVal); TheStoreType = MVT::i32; - } - else if (TmpVal.getValueSizeInBits() < 16) + } else if (RetTy->isHalfTy()) { + TheStoreType = MVT::f16; + } else if (TmpVal.getValueSizeInBits() < 16) TmpVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, TmpVal); SDValue Ops[] = { Index: lib/Target/NVPTX/NVPTXInstrInfo.cpp =================================================================== --- lib/Target/NVPTX/NVPTXInstrInfo.cpp +++ lib/Target/NVPTX/NVPTXInstrInfo.cpp @@ -52,6 +52,9 @@ } else if (DestRC == &NVPTX::Int64RegsRegClass) { Op = (SrcRC == &NVPTX::Int64RegsRegClass ? NVPTX::IMOV64rr : NVPTX::BITCONVERT_64_F2I); + } else if (DestRC == &NVPTX::Float16RegsRegClass) { + Op = (SrcRC == &NVPTX::Float16RegsRegClass ? NVPTX::FMOV16rr + : NVPTX::BITCONVERT_16_I2F); } else if (DestRC == &NVPTX::Float32RegsRegClass) { Op = (SrcRC == &NVPTX::Float32RegsRegClass ? 
NVPTX::FMOV32rr : NVPTX::BITCONVERT_32_I2F); Index: lib/Target/NVPTX/NVPTXInstrInfo.td =================================================================== --- lib/Target/NVPTX/NVPTXInstrInfo.td +++ lib/Target/NVPTX/NVPTXInstrInfo.td @@ -18,6 +18,10 @@ def NOP : NVPTXInst<(outs), (ins), "", []>; } +let OperandType = "OPERAND_IMMEDIATE" in { + def f16imm : Operand; +} + // List of vector specific properties def isVecLD : VecInstTypeEnum<1>; def isVecST : VecInstTypeEnum<2>; @@ -148,6 +152,7 @@ def hasPTX31 : Predicate<"Subtarget->getPTXVersion() >= 31">; +def hasFP16Math: Predicate<"Subtarget->allowFP16Math()">; //===----------------------------------------------------------------------===// // Some Common Instruction Class Templates @@ -286,6 +291,19 @@ [(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>, Requires<[allowFMA]>; + def f16rr_ftz : + NVPTXInst<(outs Float16Regs:$dst), + (ins Float16Regs:$a, Float16Regs:$b), + !strconcat(OpcStr, ".ftz.f16 \t$dst, $a, $b;"), + [(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>, + Requires<[hasFP16Math, allowFMA, doF32FTZ]>; + def f16rr : + NVPTXInst<(outs Float16Regs:$dst), + (ins Float16Regs:$a, Float16Regs:$b), + !strconcat(OpcStr, ".f16 \t$dst, $a, $b;"), + [(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>, + Requires<[hasFP16Math, allowFMA]>; + // These have strange names so we don't perturb existing mir tests. 
def _rnf64rr : NVPTXInst<(outs Float64Regs:$dst), @@ -323,6 +341,18 @@ !strconcat(OpcStr, ".rn.f32 \t$dst, $a, $b;"), [(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>, Requires<[noFMA]>; + def _rnf16rr_ftz : + NVPTXInst<(outs Float16Regs:$dst), + (ins Float16Regs:$a, Float16Regs:$b), + !strconcat(OpcStr, ".rn.ftz.f16 \t$dst, $a, $b;"), + [(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>, + Requires<[hasFP16Math, noFMA, doF32FTZ]>; + def _rnf16rr : + NVPTXInst<(outs Float16Regs:$dst), + (ins Float16Regs:$a, Float16Regs:$b), + !strconcat(OpcStr, ".rn.f16 \t$dst, $a, $b;"), + [(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>, + Requires<[hasFP16Math, noFMA]>; } // Template for operations which take two f32 or f64 operands. Provides three @@ -374,11 +404,6 @@ (ins Int16Regs:$src, CvtMode:$mode), !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", FromName, ".u16\t$dst, $src;"), []>; - def _f16 : - NVPTXInst<(outs RC:$dst), - (ins Int16Regs:$src, CvtMode:$mode), - !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", - FromName, ".f16\t$dst, $src;"), []>; def _s32 : NVPTXInst<(outs RC:$dst), (ins Int32Regs:$src, CvtMode:$mode), @@ -399,6 +424,11 @@ (ins Int64Regs:$src, CvtMode:$mode), !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", FromName, ".u64\t$dst, $src;"), []>; + def _f16 : + NVPTXInst<(outs RC:$dst), + (ins Float16Regs:$src, CvtMode:$mode), + !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", + FromName, ".f16\t$dst, $src;"), []>; def _f32 : NVPTXInst<(outs RC:$dst), (ins Float32Regs:$src, CvtMode:$mode), @@ -416,11 +446,11 @@ defm CVT_u8 : CVT_FROM_ALL<"u8", Int16Regs>; defm CVT_s16 : CVT_FROM_ALL<"s16", Int16Regs>; defm CVT_u16 : CVT_FROM_ALL<"u16", Int16Regs>; - defm CVT_f16 : CVT_FROM_ALL<"f16", Int16Regs>; defm CVT_s32 : CVT_FROM_ALL<"s32", Int32Regs>; defm CVT_u32 : CVT_FROM_ALL<"u32", Int32Regs>; defm CVT_s64 : CVT_FROM_ALL<"s64", Int64Regs>; defm CVT_u64 : CVT_FROM_ALL<"u64", Int64Regs>; + defm CVT_f16 
: CVT_FROM_ALL<"f16", Float16Regs>; defm CVT_f32 : CVT_FROM_ALL<"f32", Float32Regs>; defm CVT_f64 : CVT_FROM_ALL<"f64", Float64Regs>; @@ -748,6 +778,14 @@ N->getValueAPF().convertToDouble() == 1.0; }]>; +// Loads FP16 constant into a register. +// ptxas does not have hex representation for fp16, so we can't use fp16 +// immediate values in .f16 instructions and have to load the +// constant into a register using mov.b16. +def LOAD_CONST_F16 : + NVPTXInst<(outs Float16Regs:$dst), (ins f16imm:$a), + "mov.b16 \t$dst, $a;", []>; + defm FADD : F3_fma_component<"add", fadd>; defm FSUB : F3_fma_component<"sub", fsub>; defm FMUL : F3_fma_component<"mul", fmul>; @@ -942,6 +980,15 @@ Requires<[Pred]>; } +multiclass FMA_F16 { + def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c), + !strconcat(OpcStr, " \t$dst, $a, $b, $c;"), + [(set RC:$dst, (fma RC:$a, RC:$b, RC:$c))]>, + Requires<[hasFP16Math, Pred]>; +} + +defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", Float16Regs, f16imm, doF32FTZ>; +defm FMA16 : FMA_F16<"fma.rn.f16", Float16Regs, f16imm, true>; defm FMA32_ftz : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>; defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, true>; defm FMA64 : FMA<"fma.rn.f64", Float64Regs, f64imm, true>; @@ -1317,6 +1364,11 @@ defm SETP_u64 : SETP<"u64", Int64Regs, i64imm>; defm SETP_f32 : SETP<"f32", Float32Regs, f32imm>; defm SETP_f64 : SETP<"f64", Float64Regs, f64imm>; +def SETP_f16rr : + NVPTXInst<(outs Int1Regs:$dst), + (ins Float16Regs:$a, Float16Regs:$b, CmpMode:$cmp), + "setp${cmp:base}${cmp:ftz}.f16 $dst, $a, $b;", + []>, Requires<[hasFP16Math]>; // FIXME: This doesn't appear to be correct. 
The "set" mnemonic has the form // "set.CmpOp{.ftz}.dtype.stype", where dtype is the type of the destination @@ -1345,6 +1397,7 @@ defm SET_b64 : SET<"b64", Int64Regs, i64imm>; defm SET_s64 : SET<"s64", Int64Regs, i64imm>; defm SET_u64 : SET<"u64", Int64Regs, i64imm>; +defm SET_f16 : SET<"f16", Float16Regs, f16imm>; defm SET_f32 : SET<"f32", Float32Regs, f32imm>; defm SET_f64 : SET<"f64", Float64Regs, f64imm>; @@ -1408,6 +1461,7 @@ defm SELP_b64 : SELP_PATTERN<"b64", Int64Regs, i64imm, imm>; defm SELP_s64 : SELP<"s64", Int64Regs, i64imm>; defm SELP_u64 : SELP<"u64", Int64Regs, i64imm>; +defm SELP_f16 : SELP_PATTERN<"b16", Float16Regs, f16imm, fpimm>; defm SELP_f32 : SELP_PATTERN<"f32", Float32Regs, f32imm, fpimm>; defm SELP_f64 : SELP_PATTERN<"f64", Float64Regs, f64imm, fpimm>; @@ -1472,6 +1526,9 @@ def IMOV64rr : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$sss), "mov.u64 \t$dst, $sss;", []>; + def FMOV16rr : NVPTXInst<(outs Float16Regs:$dst), (ins Float16Regs:$src), + // We have to use .b16 here as there's no mov.f16. 
+ "mov.b16 \t$dst, $src;", []>; def FMOV32rr : NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src), "mov.f32 \t$dst, $src;", []>; def FMOV64rr : NVPTXInst<(outs Float64Regs:$dst), (ins Float64Regs:$src), @@ -1633,6 +1690,26 @@ multiclass FSET_FORMAT { + // f16 -> pred + def : Pat<(i1 (OpNode Float16Regs:$a, Float16Regs:$b)), + (SETP_f16rr Float16Regs:$a, Float16Regs:$b, ModeFTZ)>, + Requires<[hasFP16Math,doF32FTZ]>; + def : Pat<(i1 (OpNode Float16Regs:$a, Float16Regs:$b)), + (SETP_f16rr Float16Regs:$a, Float16Regs:$b, Mode)>, + Requires<[hasFP16Math]>; + def : Pat<(i1 (OpNode Float16Regs:$a, fpimm:$b)), + (SETP_f16rr Float16Regs:$a, (LOAD_CONST_F16 fpimm:$b), ModeFTZ)>, + Requires<[hasFP16Math,doF32FTZ]>; + def : Pat<(i1 (OpNode Float16Regs:$a, fpimm:$b)), + (SETP_f16rr Float16Regs:$a, (LOAD_CONST_F16 fpimm:$b), Mode)>, + Requires<[hasFP16Math]>; + def : Pat<(i1 (OpNode fpimm:$a, Float16Regs:$b)), + (SETP_f16rr (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, ModeFTZ)>, + Requires<[hasFP16Math,doF32FTZ]>; + def : Pat<(i1 (OpNode fpimm:$a, Float16Regs:$b)), + (SETP_f16rr (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, Mode)>, + Requires<[hasFP16Math]>; + // f32 -> pred def : Pat<(i1 (OpNode Float32Regs:$a, Float32Regs:$b)), (SETP_f32rr Float32Regs:$a, Float32Regs:$b, ModeFTZ)>, @@ -1658,6 +1735,26 @@ def : Pat<(i1 (OpNode fpimm:$a, Float64Regs:$b)), (SETP_f64ir fpimm:$a, Float64Regs:$b, Mode)>; + // f16 -> i32 + def : Pat<(i32 (OpNode Float16Regs:$a, Float16Regs:$b)), + (SET_f16rr Float16Regs:$a, Float16Regs:$b, ModeFTZ)>, + Requires<[hasFP16Math, doF32FTZ]>; + def : Pat<(i32 (OpNode Float16Regs:$a, Float16Regs:$b)), + (SET_f16rr Float16Regs:$a, Float16Regs:$b, Mode)>, + Requires<[hasFP16Math]>; + def : Pat<(i32 (OpNode Float16Regs:$a, fpimm:$b)), + (SET_f16rr Float16Regs:$a, (LOAD_CONST_F16 fpimm:$b), ModeFTZ)>, + Requires<[hasFP16Math, doF32FTZ]>; + def : Pat<(i32 (OpNode Float16Regs:$a, fpimm:$b)), + (SET_f16rr Float16Regs:$a, (LOAD_CONST_F16 fpimm:$b), Mode)>, + 
Requires<[hasFP16Math]>; + def : Pat<(i32 (OpNode fpimm:$a, Float16Regs:$b)), + (SET_f16ir (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, ModeFTZ)>, + Requires<[hasFP16Math, doF32FTZ]>; + def : Pat<(i32 (OpNode fpimm:$a, Float16Regs:$b)), + (SET_f16ir (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, Mode)>, + Requires<[hasFP16Math]>; + // f32 -> i32 def : Pat<(i32 (OpNode Float32Regs:$a, Float32Regs:$b)), (SET_f32rr Float32Regs:$a, Float32Regs:$b, ModeFTZ)>, @@ -1941,6 +2038,7 @@ def LoadParamMemV4I32 : LoadParamV4MemInst; def LoadParamMemV4I16 : LoadParamV4MemInst; def LoadParamMemV4I8 : LoadParamV4MemInst; +def LoadParamMemF16 : LoadParamMemInst; def LoadParamMemF32 : LoadParamMemInst; def LoadParamMemF64 : LoadParamMemInst; def LoadParamMemV2F32 : LoadParamV2MemInst; @@ -1961,6 +2059,7 @@ def StoreParamV4I16 : StoreParamV4Inst; def StoreParamV4I8 : StoreParamV4Inst; +def StoreParamF16 : StoreParamInst; def StoreParamF32 : StoreParamInst; def StoreParamF64 : StoreParamInst; def StoreParamV2F32 : StoreParamV2Inst; @@ -1981,6 +2080,7 @@ def StoreRetvalF64 : StoreRetvalInst; def StoreRetvalF32 : StoreRetvalInst; +def StoreRetvalF16 : StoreRetvalInst; def StoreRetvalV2F64 : StoreRetvalV2Inst; def StoreRetvalV2F32 : StoreRetvalV2Inst; def StoreRetvalV4F32 : StoreRetvalV4Inst; @@ -2068,6 +2168,7 @@ [(set Int16Regs:$dst, (MoveParam Int16Regs:$src))]>; def MoveParamF64 : MoveParamInst; def MoveParamF32 : MoveParamInst; +def MoveParamF16 : MoveParamInst; class PseudoUseParamInst : NVPTXInst<(outs), (ins regclass:$src), @@ -2128,6 +2229,7 @@ defm LD_i16 : LD; defm LD_i32 : LD; defm LD_i64 : LD; + defm LD_f16 : LD; defm LD_f32 : LD; defm LD_f64 : LD; } @@ -2176,6 +2278,7 @@ defm ST_i16 : ST; defm ST_i32 : ST; defm ST_i64 : ST; + defm ST_f16 : ST; defm ST_f32 : ST; defm ST_f64 : ST; } @@ -2368,6 +2471,8 @@ !strconcat("mov.b", !strconcat(SzStr, " \t $d, $a;")), [(set regclassOut:$d, (bitconvert regclassIn:$a))]>; +def BITCONVERT_16_I2F : F_BITCONVERT<"16", Int16Regs, Float16Regs>; +def 
BITCONVERT_16_F2I : F_BITCONVERT<"16", Float16Regs, Int16Regs>; def BITCONVERT_32_I2F : F_BITCONVERT<"32", Int32Regs, Float32Regs>; def BITCONVERT_32_F2I : F_BITCONVERT<"32", Float32Regs, Int32Regs>; def BITCONVERT_64_I2F : F_BITCONVERT<"64", Int64Regs, Float64Regs>; @@ -2377,6 +2482,26 @@ // we cannot specify floating-point literals in isel patterns. Therefore, we // use an integer selp to select either 1 or 0 and then cvt to floating-point. +// sint -> f16 +def : Pat<(f16 (sint_to_fp Int1Regs:$a)), + (CVT_f16_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>; +def : Pat<(f16 (sint_to_fp Int16Regs:$a)), + (CVT_f16_s16 Int16Regs:$a, CvtRN)>; +def : Pat<(f16 (sint_to_fp Int32Regs:$a)), + (CVT_f16_s32 Int32Regs:$a, CvtRN)>; +def : Pat<(f16 (sint_to_fp Int64Regs:$a)), + (CVT_f16_s64 Int64Regs:$a, CvtRN)>; + +// uint -> f16 +def : Pat<(f16 (uint_to_fp Int1Regs:$a)), + (CVT_f16_u32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>; +def : Pat<(f16 (uint_to_fp Int16Regs:$a)), + (CVT_f16_u16 Int16Regs:$a, CvtRN)>; +def : Pat<(f16 (uint_to_fp Int32Regs:$a)), + (CVT_f16_u32 Int32Regs:$a, CvtRN)>; +def : Pat<(f16 (uint_to_fp Int64Regs:$a)), + (CVT_f16_u64 Int64Regs:$a, CvtRN)>; + // sint -> f32 def : Pat<(f32 (sint_to_fp Int1Regs:$a)), (CVT_f32_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>; @@ -2418,6 +2543,38 @@ (CVT_f64_u64 Int64Regs:$a, CvtRN)>; +// f16 -> sint +def : Pat<(i1 (fp_to_sint Float16Regs:$a)), + (SETP_b16ri (BITCONVERT_16_F2I Float16Regs:$a), 0, CmpEQ)>; +def : Pat<(i16 (fp_to_sint Float16Regs:$a)), + (CVT_s16_f16 Float16Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(i16 (fp_to_sint Float16Regs:$a)), + (CVT_s16_f16 Float16Regs:$a, CvtRZI)>; +def : Pat<(i32 (fp_to_sint Float16Regs:$a)), + (CVT_s32_f16 Float16Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(i32 (fp_to_sint Float16Regs:$a)), + (CVT_s32_f16 Float16Regs:$a, CvtRZI)>; +def : Pat<(i64 (fp_to_sint Float16Regs:$a)), + (CVT_s64_f16 Float16Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(i64 
(fp_to_sint Float16Regs:$a)), + (CVT_s64_f16 Float16Regs:$a, CvtRZI)>; + +// f16 -> uint +def : Pat<(i1 (fp_to_uint Float16Regs:$a)), + (SETP_b16ri (BITCONVERT_16_F2I Float16Regs:$a), 0, CmpEQ)>; +def : Pat<(i16 (fp_to_uint Float16Regs:$a)), + (CVT_u16_f16 Float16Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(i16 (fp_to_uint Float16Regs:$a)), + (CVT_u16_f16 Float16Regs:$a, CvtRZI)>; +def : Pat<(i32 (fp_to_uint Float16Regs:$a)), + (CVT_u32_f16 Float16Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(i32 (fp_to_uint Float16Regs:$a)), + (CVT_u32_f16 Float16Regs:$a, CvtRZI)>; +def : Pat<(i64 (fp_to_uint Float16Regs:$a)), + (CVT_u64_f16 Float16Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(i64 (fp_to_uint Float16Regs:$a)), + (CVT_u64_f16 Float16Regs:$a, CvtRZI)>; + // f32 -> sint def : Pat<(i1 (fp_to_sint Float32Regs:$a)), (SETP_b32ri (BITCONVERT_32_F2I Float32Regs:$a), 0, CmpEQ)>; @@ -2647,12 +2804,36 @@ def : Pat<(ctpop Int16Regs:$a), (CVT_u16_u32 (POPCr32 (CVT_u32_u16 Int16Regs:$a, CvtNONE)), CvtNONE)>; +// fpround f32 -> f16 +def : Pat<(f16 (fpround Float32Regs:$a)), + (CVT_f16_f32 Float32Regs:$a, CvtRN_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(f16 (fpround Float32Regs:$a)), + (CVT_f16_f32 Float32Regs:$a, CvtRN)>; + +// fpround f64 -> f16 +def : Pat<(f16 (fpround Float64Regs:$a)), + (CVT_f16_f64 Float64Regs:$a, CvtRN_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(f16 (fpround Float64Regs:$a)), + (CVT_f16_f64 Float64Regs:$a, CvtRN)>; + // fpround f64 -> f32 def : Pat<(f32 (fpround Float64Regs:$a)), (CVT_f32_f64 Float64Regs:$a, CvtRN_FTZ)>, Requires<[doF32FTZ]>; def : Pat<(f32 (fpround Float64Regs:$a)), (CVT_f32_f64 Float64Regs:$a, CvtRN)>; +// fpextend f16 -> f32 +def : Pat<(f32 (fpextend Float16Regs:$a)), + (CVT_f32_f16 Float16Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(f32 (fpextend Float16Regs:$a)), + (CVT_f32_f16 Float16Regs:$a, CvtNONE)>; + +// fpextend f16 -> f64 +def : Pat<(f64 (fpextend Float16Regs:$a)), + (CVT_f64_f16 
Float16Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(f64 (fpextend Float16Regs:$a)), + (CVT_f64_f16 Float16Regs:$a, CvtNONE)>; + // fpextend f32 -> f64 def : Pat<(f64 (fpextend Float32Regs:$a)), (CVT_f64_f32 Float32Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>; @@ -2664,6 +2845,10 @@ // fceil, ffloor, fround, ftrunc. +def : Pat<(fceil Float16Regs:$a), + (CVT_f16_f16 Float16Regs:$a, CvtRPI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(fceil Float16Regs:$a), + (CVT_f16_f16 Float16Regs:$a, CvtRPI)>, Requires<[doNoF32FTZ]>; def : Pat<(fceil Float32Regs:$a), (CVT_f32_f32 Float32Regs:$a, CvtRPI_FTZ)>, Requires<[doF32FTZ]>; def : Pat<(fceil Float32Regs:$a), @@ -2671,6 +2856,10 @@ def : Pat<(fceil Float64Regs:$a), (CVT_f64_f64 Float64Regs:$a, CvtRPI)>; +def : Pat<(ffloor Float16Regs:$a), + (CVT_f16_f16 Float16Regs:$a, CvtRMI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(ffloor Float16Regs:$a), + (CVT_f16_f16 Float16Regs:$a, CvtRMI)>, Requires<[doNoF32FTZ]>; def : Pat<(ffloor Float32Regs:$a), (CVT_f32_f32 Float32Regs:$a, CvtRMI_FTZ)>, Requires<[doF32FTZ]>; def : Pat<(ffloor Float32Regs:$a), @@ -2678,6 +2867,10 @@ def : Pat<(ffloor Float64Regs:$a), (CVT_f64_f64 Float64Regs:$a, CvtRMI)>; +def : Pat<(fround Float16Regs:$a), + (CVT_f16_f16 Float16Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(f16 (fround Float16Regs:$a)), + (CVT_f16_f16 Float16Regs:$a, CvtRNI)>, Requires<[doNoF32FTZ]>; def : Pat<(fround Float32Regs:$a), (CVT_f32_f32 Float32Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>; def : Pat<(f32 (fround Float32Regs:$a)), @@ -2685,6 +2878,10 @@ def : Pat<(f64 (fround Float64Regs:$a)), (CVT_f64_f64 Float64Regs:$a, CvtRNI)>; +def : Pat<(ftrunc Float16Regs:$a), + (CVT_f16_f16 Float16Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(ftrunc Float16Regs:$a), + (CVT_f16_f16 Float16Regs:$a, CvtRZI)>, Requires<[doNoF32FTZ]>; def : Pat<(ftrunc Float32Regs:$a), (CVT_f32_f32 Float32Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; def : Pat<(ftrunc Float32Regs:$a), @@ 
-2696,6 +2893,10 @@ // strictly correct, because it causes us to ignore the rounding mode. But it // matches what CUDA's "libm" does. +def : Pat<(fnearbyint Float16Regs:$a), + (CVT_f16_f16 Float16Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(fnearbyint Float16Regs:$a), + (CVT_f16_f16 Float16Regs:$a, CvtRNI)>, Requires<[doNoF32FTZ]>; def : Pat<(fnearbyint Float32Regs:$a), (CVT_f32_f32 Float32Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>; def : Pat<(fnearbyint Float32Regs:$a), @@ -2703,6 +2904,10 @@ def : Pat<(fnearbyint Float64Regs:$a), (CVT_f64_f64 Float64Regs:$a, CvtRNI)>; +def : Pat<(frint Float16Regs:$a), + (CVT_f16_f16 Float16Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(frint Float16Regs:$a), + (CVT_f16_f16 Float16Regs:$a, CvtRNI)>, Requires<[doNoF32FTZ]>; def : Pat<(frint Float32Regs:$a), (CVT_f32_f32 Float32Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>; def : Pat<(frint Float32Regs:$a), Index: lib/Target/NVPTX/NVPTXIntrinsics.td =================================================================== --- lib/Target/NVPTX/NVPTXIntrinsics.td +++ lib/Target/NVPTX/NVPTXIntrinsics.td @@ -803,49 +803,13 @@ (CVT_f64_u64 Int64Regs:$a, CvtRP)>; -// FIXME: Ideally, we could use these patterns instead of the scope-creating -// patterns, but ptxas does not like these since .s16 is not compatible with -// .f16. The solution is to use .bXX for all integer register types, but we -// are not there yet. 
-//def : Pat<(int_nvvm_f2h_rn_ftz Float32Regs:$a), -// (CVT_f16_f32 Float32Regs:$a, CvtRN_FTZ)>; -//def : Pat<(int_nvvm_f2h_rn Float32Regs:$a), -// (CVT_f16_f32 Float32Regs:$a, CvtRN)>; -// -//def : Pat<(int_nvvm_h2f Int16Regs:$a), -// (CVT_f32_f16 Int16Regs:$a, CvtNONE)>; +def : Pat<(int_nvvm_f2h_rn_ftz Float32Regs:$a), + (BITCONVERT_16_F2I (CVT_f16_f32 Float32Regs:$a, CvtRN_FTZ))>; +def : Pat<(int_nvvm_f2h_rn Float32Regs:$a), + (BITCONVERT_16_F2I (CVT_f16_f32 Float32Regs:$a, CvtRN))>; -def INT_NVVM_F2H_RN_FTZ : F_MATH_1; -def INT_NVVM_F2H_RN : F_MATH_1; - -def INT_NVVM_H2F : F_MATH_1; - -def : Pat<(f32 (f16_to_fp Int16Regs:$a)), - (CVT_f32_f16 Int16Regs:$a, CvtNONE)>; -def : Pat<(i16 (fp_to_f16 Float32Regs:$a)), - (CVT_f16_f32 Float32Regs:$a, CvtRN_FTZ)>, Requires<[doF32FTZ]>; -def : Pat<(i16 (fp_to_f16 Float32Regs:$a)), - (CVT_f16_f32 Float32Regs:$a, CvtRN)>; - -def : Pat<(f64 (f16_to_fp Int16Regs:$a)), - (CVT_f64_f16 Int16Regs:$a, CvtNONE)>; -def : Pat<(i16 (fp_to_f16 Float64Regs:$a)), - (CVT_f16_f64 Float64Regs:$a, CvtRN)>; +def : Pat<(int_nvvm_h2f Int16Regs:$a), + (CVT_f32_f16 (BITCONVERT_16_I2F Int16Regs:$a), CvtNONE)>; // // Bitcast Index: lib/Target/NVPTX/NVPTXMCExpr.h =================================================================== --- lib/Target/NVPTX/NVPTXMCExpr.h +++ lib/Target/NVPTX/NVPTXMCExpr.h @@ -22,8 +22,9 @@ public: enum VariantKind { VK_NVPTX_None, - VK_NVPTX_SINGLE_PREC_FLOAT, // FP constant in single-precision - VK_NVPTX_DOUBLE_PREC_FLOAT // FP constant in double-precision + VK_NVPTX_HALF_PREC_FLOAT, // FP constant in half-precision + VK_NVPTX_SINGLE_PREC_FLOAT, // FP constant in single-precision + VK_NVPTX_DOUBLE_PREC_FLOAT // FP constant in double-precision }; private: @@ -40,6 +41,11 @@ static const NVPTXFloatMCExpr *create(VariantKind Kind, const APFloat &Flt, MCContext &Ctx); + static const NVPTXFloatMCExpr *createConstantFPHalf(const APFloat &Flt, + MCContext &Ctx) { + return create(VK_NVPTX_HALF_PREC_FLOAT, Flt, Ctx); + } + static 
const NVPTXFloatMCExpr *createConstantFPSingle(const APFloat &Flt, MCContext &Ctx) { return create(VK_NVPTX_SINGLE_PREC_FLOAT, Flt, Ctx); Index: lib/Target/NVPTX/NVPTXMCExpr.cpp =================================================================== --- lib/Target/NVPTX/NVPTXMCExpr.cpp +++ lib/Target/NVPTX/NVPTXMCExpr.cpp @@ -27,6 +27,13 @@ switch (Kind) { default: llvm_unreachable("Invalid kind!"); + case VK_NVPTX_HALF_PREC_FLOAT: + // ptxas does not have a way to specify half-precision floats. + // Instead we have to print and load fp16 constants as .b16 + OS << "0x"; + NumHex = 4; + APF.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &Ignored); + break; case VK_NVPTX_SINGLE_PREC_FLOAT: OS << "0f"; NumHex = 8; Index: lib/Target/NVPTX/NVPTXRegisterInfo.cpp =================================================================== --- lib/Target/NVPTX/NVPTXRegisterInfo.cpp +++ lib/Target/NVPTX/NVPTXRegisterInfo.cpp @@ -30,6 +30,13 @@ if (RC == &NVPTX::Float32RegsRegClass) { return ".f32"; } + if (RC == &NVPTX::Float16RegsRegClass) { + // Ideally fp16 registers should be .f16, but this syntax is only + // supported on sm_53+. On the other hand, .b16 registers are + // accepted for all supported fp16 instructions on all GPU + // variants, so we can use them instead. 
+ return ".b16"; + } if (RC == &NVPTX::Float64RegsRegClass) { return ".f64"; } else if (RC == &NVPTX::Int64RegsRegClass) { @@ -67,9 +74,10 @@ } std::string getNVPTXRegClassStr(TargetRegisterClass const *RC) { - if (RC == &NVPTX::Float32RegsRegClass) { + if (RC == &NVPTX::Float32RegsRegClass) return "%f"; - } + if (RC == &NVPTX::Float16RegsRegClass) + return "%h"; if (RC == &NVPTX::Float64RegsRegClass) { return "%fd"; } else if (RC == &NVPTX::Int64RegsRegClass) { Index: lib/Target/NVPTX/NVPTXRegisterInfo.td =================================================================== --- lib/Target/NVPTX/NVPTXRegisterInfo.td +++ lib/Target/NVPTX/NVPTXRegisterInfo.td @@ -36,6 +36,7 @@ def RS#i : NVPTXReg<"%rs"#i>; // 16-bit def R#i : NVPTXReg<"%r"#i>; // 32-bit def RL#i : NVPTXReg<"%rd"#i>; // 64-bit + def H#i : NVPTXReg<"%h"#i>; // 16-bit float def F#i : NVPTXReg<"%f"#i>; // 32-bit float def FL#i : NVPTXReg<"%fd"#i>; // 64-bit float @@ -57,6 +58,7 @@ def Int16Regs : NVPTXRegClass<[i16], 16, (add (sequence "RS%u", 0, 4))>; def Int32Regs : NVPTXRegClass<[i32], 32, (add (sequence "R%u", 0, 4))>; def Int64Regs : NVPTXRegClass<[i64], 64, (add (sequence "RL%u", 0, 4))>; +def Float16Regs : NVPTXRegClass<[f16], 16, (add (sequence "H%u", 0, 4))>; def Float32Regs : NVPTXRegClass<[f32], 32, (add (sequence "F%u", 0, 4))>; def Float64Regs : NVPTXRegClass<[f64], 64, (add (sequence "FL%u", 0, 4))>; def Int32ArgRegs : NVPTXRegClass<[i32], 32, (add (sequence "ia%u", 0, 4))>; Index: lib/Target/NVPTX/NVPTXSubtarget.h =================================================================== --- lib/Target/NVPTX/NVPTXSubtarget.h +++ lib/Target/NVPTX/NVPTXSubtarget.h @@ -101,6 +101,8 @@ inline bool hasROT32() const { return hasHWROT32() || hasSWROT32(); } inline bool hasROT64() const { return SmVersion >= 20; } bool hasImageHandles() const; + bool hasFP16() const { return SmVersion >= 53; } + bool allowFP16Math() const; unsigned int getSmVersion() const { return SmVersion; } std::string getTargetName() 
const { return TargetName; } Index: lib/Target/NVPTX/NVPTXSubtarget.cpp =================================================================== --- lib/Target/NVPTX/NVPTXSubtarget.cpp +++ lib/Target/NVPTX/NVPTXSubtarget.cpp @@ -23,6 +23,11 @@ #define GET_SUBTARGETINFO_CTOR #include "NVPTXGenSubtargetInfo.inc" +static cl::opt<bool> + NoF16Math("nvptx-no-f16-math", cl::ZeroOrMore, cl::Hidden, + cl::desc("NVPTX Specific: Disable generation of f16 math ops."), + cl::init(false)); + // Pin the vtable to this file. void NVPTXSubtarget::anchor() {} @@ -57,3 +62,7 @@ // Disabled, otherwise return false; } + +bool NVPTXSubtarget::allowFP16Math() const { + return hasFP16() && NoF16Math == false; +} Index: test/CodeGen/NVPTX/f16-instructions.ll =================================================================== --- /dev/null +++ test/CodeGen/NVPTX/f16-instructions.ll @@ -0,0 +1,1033 @@ +; ## Full FP16 support enabled by default. +; RUN: llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_53 -asm-verbose=false \ +; RUN: -O0 -disable-post-ra -disable-fp-elim \ +; RUN: | FileCheck -check-prefixes CHECK,CHECK-F16 %s +; ## FP16 support explicitly disabled. +; RUN: llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_53 -asm-verbose=false \ +; RUN: -O0 -disable-post-ra -disable-fp-elim --nvptx-no-f16-math \ +; RUN: | FileCheck -check-prefixes CHECK,CHECK-NOF16 %s +; ## FP16 is not supported by hardware.
+; RUN: llc < %s -O0 -mtriple=nvptx64-nvidia-cuda -mcpu=sm_52 -asm-verbose=false \ +; RUN: -disable-post-ra -disable-fp-elim \ +; RUN: | FileCheck -check-prefixes CHECK,CHECK-NOF16 %s + +target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128" + +; CHECK-LABEL: test_ret_const( +; CHECK: mov.b16 [[R:%h[0-9]+]], 0x3C00; +; CHECK-NEXT: st.param.b16 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define half @test_ret_const() #0 { + ret half 1.0 +} + +; CHECK-LABEL: test_fadd( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_fadd_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_fadd_param_1]; +; CHECK-F16-NEXT: add.rn.f16 [[R:%h[0-9]+]], [[A]], [[B]]; +; CHECK-NOF16-DAG: cvt.f32.f16 [[A32:%f[0-9]+]], [[A]] +; CHECK-NOF16-DAG: cvt.f32.f16 [[B32:%f[0-9]+]], [[B]] +; CHECK-NOF16-NEXT: add.rn.f32 [[R32:%f[0-9]+]], [[A32]], [[B32]]; +; CHECK-NOF16-NEXT: cvt.rn.f16.f32 [[R:%h[0-9]+]], [[R32]] +; CHECK-NEXT: st.param.b16 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define half @test_fadd(half %a, half %b) #0 { + %r = fadd half %a, %b + ret half %r +} + +; Check that we can lower fadd with immediate arguments. 
+; CHECK-LABEL: test_fadd_imm_0( +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_fadd_imm_0_param_0]; +; CHECK-F16-DAG: mov.b16 [[A:%h[0-9]+]], 0x3C00; +; CHECK-F16-NEXT: add.rn.f16 [[R:%h[0-9]+]], [[B]], [[A]]; +; CHECK-NOF16-DAG: cvt.f32.f16 [[B32:%f[0-9]+]], [[B]] +; CHECK-NOF16-NEXT: add.rn.f32 [[R32:%f[0-9]+]], [[B32]], 0f3F800000; +; CHECK-NOF16-NEXT: cvt.rn.f16.f32 [[R:%h[0-9]+]], [[R32]] +; CHECK-NEXT: st.param.b16 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define half @test_fadd_imm_0(half %b) #0 { + %r = fadd half 1.0, %b + ret half %r +} + +; CHECK-LABEL: test_fadd_imm_1( +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_fadd_imm_1_param_0]; +; CHECK-F16-DAG: mov.b16 [[A:%h[0-9]+]], 0x3C00; +; CHECK-F16-NEXT: add.rn.f16 [[R:%h[0-9]+]], [[B]], [[A]]; +; CHECK-NOF16-DAG: cvt.f32.f16 [[B32:%f[0-9]+]], [[B]] +; CHECK-NOF16-NEXT: add.rn.f32 [[R32:%f[0-9]+]], [[B32]], 0f3F800000; +; CHECK-NOF16-NEXT: cvt.rn.f16.f32 [[R:%h[0-9]+]], [[R32]] +; CHECK-NEXT: st.param.b16 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define half @test_fadd_imm_1(half %a) #0 { + %r = fadd half %a, 1.0 + ret half %r +} + +; CHECK-LABEL: test_fsub( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_fsub_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_fsub_param_1]; +; CHECK-F16-NEXT: sub.rn.f16 [[R:%h[0-9]+]], [[A]], [[B]]; +; CHECK-NOF16-DAG: cvt.f32.f16 [[A32:%f[0-9]+]], [[A]] +; CHECK-NOF16-DAG: cvt.f32.f16 [[B32:%f[0-9]+]], [[B]] +; CHECK-NOF16-NEXT: sub.rn.f32 [[R32:%f[0-9]+]], [[A32]], [[B32]]; +; CHECK-NOF16-NEXT: cvt.rn.f16.f32 [[R:%h[0-9]+]], [[R32]] +; CHECK-NEXT: st.param.b16 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define half @test_fsub(half %a, half %b) #0 { + %r = fsub half %a, %b + ret half %r +} + +; CHECK-LABEL: test_fmul( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_fmul_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_fmul_param_1]; +; CHECK-F16-NEXT: mul.rn.f16 [[R:%h[0-9]+]], [[A]], [[B]]; +; CHECK-NOF16-DAG: cvt.f32.f16 
[[A32:%f[0-9]+]], [[A]] +; CHECK-NOF16-DAG: cvt.f32.f16 [[B32:%f[0-9]+]], [[B]] +; CHECK-NOF16-NEXT: mul.rn.f32 [[R32:%f[0-9]+]], [[A32]], [[B32]]; +; CHECK-NOF16-NEXT: cvt.rn.f16.f32 [[R:%h[0-9]+]], [[R32]] +; CHECK-NEXT: st.param.b16 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define half @test_fmul(half %a, half %b) #0 { + %r = fmul half %a, %b + ret half %r +} + +; CHECK-LABEL: test_fdiv( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_fdiv_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_fdiv_param_1]; +; CHECK-DAG: cvt.f32.f16 [[F0:%f[0-9]+]], [[A]]; +; CHECK-DAG: cvt.f32.f16 [[F1:%f[0-9]+]], [[B]]; +; CHECK-NEXT: div.rn.f32 [[FR:%f[0-9]+]], [[F0]], [[F1]]; +; CHECK-NEXT: cvt.rn.f16.f32 [[R:%h[0-9]+]], [[FR]]; +; CHECK-NEXT: st.param.b16 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define half @test_fdiv(half %a, half %b) #0 { + %r = fdiv half %a, %b + ret half %r +} + +; CHECK-LABEL: test_frem( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_frem_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_frem_param_1]; +; CHECK-DAG: cvt.f32.f16 [[F0:%f[0-9]+]], [[A]]; +; CHECK-DAG: cvt.f32.f16 [[F1:%f[0-9]+]], [[B]]; +; CHECK-NEXT: div.rn.f32 [[F2:%f[0-9]+]], [[F0]], [[F1]]; +; CHECK-NEXT: cvt.rmi.f32.f32 [[F3:%f[0-9]+]], [[F2]]; +; CHECK-NEXT: mul.f32 [[F4:%f[0-9]+]], [[F3]], [[F1]]; +; CHECK-NEXT: sub.f32 [[F5:%f[0-9]+]], [[F0]], [[F4]]; +; CHECK-NEXT: cvt.rn.f16.f32 [[R:%h[0-9]+]], [[F5]]; +; CHECK-NEXT: st.param.b16 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define half @test_frem(half %a, half %b) #0 { + %r = frem half %a, %b + ret half %r +} + +; CHECK-LABEL: test_store( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_store_param_0]; +; CHECK-DAG: ld.param.u64 %[[PTR:rd[0-9]+]], [test_store_param_1]; +; CHECK-NEXT: st.b16 [%[[PTR]]], [[A]]; +; CHECK-NEXT: ret; +define void @test_store(half %a, half* %b) #0 { + store half %a, half* %b + ret void +} + +; CHECK-LABEL: test_load( +; CHECK: ld.param.u64 %[[PTR:rd[0-9]+]], 
[test_load_param_0]; +; CHECK-NEXT: ld.b16 [[R:%h[0-9]+]], [%[[PTR]]]; +; CHECK-NEXT: st.param.b16 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define half @test_load(half* %a) #0 { + %r = load half, half* %a + ret half %r +} + +declare half @test_callee(half %a, half %b) #0 + +; CHECK-LABEL: test_call( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_call_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_call_param_1]; +; CHECK: { +; CHECK-DAG: .param .b32 param0; +; CHECK-DAG: .param .b32 param1; +; CHECK-DAG: st.param.b16 [param0+0], [[A]]; +; CHECK-DAG: st.param.b16 [param1+0], [[B]]; +; CHECK-DAG: .param .b32 retval0; +; CHECK: call.uni (retval0), +; CHECK-NEXT: test_callee, +; CHECK: ); +; CHECK-NEXT: ld.param.b16 [[R:%h[0-9]+]], [retval0+0]; +; CHECK-NEXT: } +; CHECK-NEXT: st.param.b16 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define half @test_call(half %a, half %b) #0 { + %r = call half @test_callee(half %a, half %b) + ret half %r +} + +; CHECK-LABEL: test_call_flipped( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_call_flipped_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_call_flipped_param_1]; +; CHECK: { +; CHECK-DAG: .param .b32 param0; +; CHECK-DAG: .param .b32 param1; +; CHECK-DAG: st.param.b16 [param0+0], [[B]]; +; CHECK-DAG: st.param.b16 [param1+0], [[A]]; +; CHECK-DAG: .param .b32 retval0; +; CHECK: call.uni (retval0), +; CHECK-NEXT: test_callee, +; CHECK: ); +; CHECK-NEXT: ld.param.b16 [[R:%h[0-9]+]], [retval0+0]; +; CHECK-NEXT: } +; CHECK-NEXT: st.param.b16 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define half @test_call_flipped(half %a, half %b) #0 { + %r = call half @test_callee(half %b, half %a) + ret half %r +} + +; CHECK-LABEL: test_tailcall_flipped( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_tailcall_flipped_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_tailcall_flipped_param_1]; +; CHECK: { +; CHECK-DAG: .param .b32 param0; +; CHECK-DAG: .param .b32 param1; +; CHECK-DAG: st.param.b16 
[param0+0], [[B]]; +; CHECK-DAG: st.param.b16 [param1+0], [[A]]; +; CHECK-DAG: .param .b32 retval0; +; CHECK: call.uni (retval0), +; CHECK-NEXT: test_callee, +; CHECK: ); +; CHECK-NEXT: ld.param.b16 [[R:%h[0-9]+]], [retval0+0]; +; CHECK-NEXT: } +; CHECK-NEXT: st.param.b16 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define half @test_tailcall_flipped(half %a, half %b) #0 { + %r = tail call half @test_callee(half %b, half %a) + ret half %r +} + +; CHECK-LABEL: test_select( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_select_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_select_param_1]; +; CHECK: setp.eq.b16 [[PRED:%p[0-9]+]], %rs{{.*}}, 1; +; CHECK-NEXT: selp.b16 [[R:%h[0-9]+]], [[A]], [[B]], [[PRED]]; +; CHECK-NEXT: st.param.b16 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define half @test_select(half %a, half %b, i1 zeroext %c) #0 { + %r = select i1 %c, half %a, half %b + ret half %r +} + +; CHECK-LABEL: test_select_cc( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_select_cc_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_select_cc_param_1]; +; CHECK-DAG: ld.param.b16 [[C:%h[0-9]+]], [test_select_cc_param_2]; +; CHECK-DAG: ld.param.b16 [[D:%h[0-9]+]], [test_select_cc_param_3]; +; CHECK-F16: setp.neu.f16 [[PRED:%p[0-9]+]], [[C]], [[D]] +; CHECK-NOF16-DAG: cvt.f32.f16 [[DF:%f[0-9]+]], [[D]]; +; CHECK-NOF16-DAG: cvt.f32.f16 [[CF:%f[0-9]+]], [[C]]; +; CHECK-NOF16: setp.neu.f32 [[PRED:%p[0-9]+]], [[CF]], [[DF]] +; CHECK: selp.b16 [[R:%h[0-9]+]], [[A]], [[B]], [[PRED]]; +; CHECK-NEXT: st.param.b16 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define half @test_select_cc(half %a, half %b, half %c, half %d) #0 { + %cc = fcmp une half %c, %d + %r = select i1 %cc, half %a, half %b + ret half %r +} + +; CHECK-LABEL: test_select_cc_f32_f16( +; CHECK-DAG: ld.param.f32 [[A:%f[0-9]+]], [test_select_cc_f32_f16_param_0]; +; CHECK-DAG: ld.param.f32 [[B:%f[0-9]+]], [test_select_cc_f32_f16_param_1]; +; CHECK-DAG: ld.param.b16 [[C:%h[0-9]+]], 
[test_select_cc_f32_f16_param_2]; +; CHECK-DAG: ld.param.b16 [[D:%h[0-9]+]], [test_select_cc_f32_f16_param_3]; +; CHECK-F16: setp.neu.f16 [[PRED:%p[0-9]+]], [[C]], [[D]] +; CHECK-NOF16-DAG: cvt.f32.f16 [[DF:%f[0-9]+]], [[D]]; +; CHECK-NOF16-DAG: cvt.f32.f16 [[CF:%f[0-9]+]], [[C]]; +; CHECK-NOF16: setp.neu.f32 [[PRED:%p[0-9]+]], [[CF]], [[DF]] +; CHECK-NEXT: selp.f32 [[R:%f[0-9]+]], [[A]], [[B]], [[PRED]]; +; CHECK-NEXT: st.param.f32 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define float @test_select_cc_f32_f16(float %a, float %b, half %c, half %d) #0 { + %cc = fcmp une half %c, %d + %r = select i1 %cc, float %a, float %b + ret float %r +} + +; CHECK-LABEL: test_select_cc_f16_f32( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_select_cc_f16_f32_param_0]; +; CHECK-DAG: ld.param.f32 [[C:%f[0-9]+]], [test_select_cc_f16_f32_param_2]; +; CHECK-DAG: ld.param.f32 [[D:%f[0-9]+]], [test_select_cc_f16_f32_param_3]; +; CHECK-DAG: setp.neu.f32 [[PRED:%p[0-9]+]], [[C]], [[D]] +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_select_cc_f16_f32_param_1]; +; CHECK-NEXT: selp.b16 [[R:%h[0-9]+]], [[A]], [[B]], [[PRED]]; +; CHECK-NEXT: st.param.b16 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define half @test_select_cc_f16_f32(half %a, half %b, float %c, float %d) #0 { + %cc = fcmp une float %c, %d + %r = select i1 %cc, half %a, half %b + ret half %r +} + +; CHECK-LABEL: test_fcmp_une( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_fcmp_une_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_fcmp_une_param_1]; +; CHECK-F16: setp.neu.f16 [[PRED:%p[0-9]+]], [[A]], [[B]] +; CHECK-NOF16-DAG: cvt.f32.f16 [[AF:%f[0-9]+]], [[A]]; +; CHECK-NOF16-DAG: cvt.f32.f16 [[BF:%f[0-9]+]], [[B]]; +; CHECK-NOF16: setp.neu.f32 [[PRED:%p[0-9]+]], [[AF]], [[BF]] +; CHECK-NEXT: selp.u32 [[R:%r[0-9]+]], 1, 0, [[PRED]]; +; CHECK-NEXT: st.param.b32 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define i1 @test_fcmp_une(half %a, half %b) #0 { + %r = fcmp une half %a, %b + ret i1 %r +} + +; 
CHECK-LABEL: test_fcmp_ueq( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_fcmp_ueq_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_fcmp_ueq_param_1]; +; CHECK-F16: setp.equ.f16 [[PRED:%p[0-9]+]], [[A]], [[B]] +; CHECK-NOF16-DAG: cvt.f32.f16 [[AF:%f[0-9]+]], [[A]]; +; CHECK-NOF16-DAG: cvt.f32.f16 [[BF:%f[0-9]+]], [[B]]; +; CHECK-NOF16: setp.equ.f32 [[PRED:%p[0-9]+]], [[AF]], [[BF]] +; CHECK-NEXT: selp.u32 [[R:%r[0-9]+]], 1, 0, [[PRED]]; +; CHECK-NEXT: st.param.b32 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define i1 @test_fcmp_ueq(half %a, half %b) #0 { + %r = fcmp ueq half %a, %b + ret i1 %r +} + +; CHECK-LABEL: test_fcmp_ugt( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_fcmp_ugt_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_fcmp_ugt_param_1]; +; CHECK-F16: setp.gtu.f16 [[PRED:%p[0-9]+]], [[A]], [[B]] +; CHECK-NOF16-DAG: cvt.f32.f16 [[AF:%f[0-9]+]], [[A]]; +; CHECK-NOF16-DAG: cvt.f32.f16 [[BF:%f[0-9]+]], [[B]]; +; CHECK-NOF16: setp.gtu.f32 [[PRED:%p[0-9]+]], [[AF]], [[BF]] +; CHECK-NEXT: selp.u32 [[R:%r[0-9]+]], 1, 0, [[PRED]]; +; CHECK-NEXT: st.param.b32 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define i1 @test_fcmp_ugt(half %a, half %b) #0 { + %r = fcmp ugt half %a, %b + ret i1 %r +} + +; CHECK-LABEL: test_fcmp_uge( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_fcmp_uge_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_fcmp_uge_param_1]; +; CHECK-F16: setp.geu.f16 [[PRED:%p[0-9]+]], [[A]], [[B]] +; CHECK-NOF16-DAG: cvt.f32.f16 [[AF:%f[0-9]+]], [[A]]; +; CHECK-NOF16-DAG: cvt.f32.f16 [[BF:%f[0-9]+]], [[B]]; +; CHECK-NOF16: setp.geu.f32 [[PRED:%p[0-9]+]], [[AF]], [[BF]] +; CHECK-NEXT: selp.u32 [[R:%r[0-9]+]], 1, 0, [[PRED]]; +; CHECK-NEXT: st.param.b32 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define i1 @test_fcmp_uge(half %a, half %b) #0 { + %r = fcmp uge half %a, %b + ret i1 %r +} + +; CHECK-LABEL: test_fcmp_ult( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_fcmp_ult_param_0]; +; CHECK-DAG: ld.param.b16 
[[B:%h[0-9]+]], [test_fcmp_ult_param_1]; +; CHECK-F16: setp.ltu.f16 [[PRED:%p[0-9]+]], [[A]], [[B]] +; CHECK-NOF16-DAG: cvt.f32.f16 [[AF:%f[0-9]+]], [[A]]; +; CHECK-NOF16-DAG: cvt.f32.f16 [[BF:%f[0-9]+]], [[B]]; +; CHECK-NOF16: setp.ltu.f32 [[PRED:%p[0-9]+]], [[AF]], [[BF]] +; CHECK-NEXT: selp.u32 [[R:%r[0-9]+]], 1, 0, [[PRED]]; +; CHECK-NEXT: st.param.b32 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define i1 @test_fcmp_ult(half %a, half %b) #0 { + %r = fcmp ult half %a, %b + ret i1 %r +} + +; CHECK-LABEL: test_fcmp_ule( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_fcmp_ule_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_fcmp_ule_param_1]; +; CHECK-F16: setp.leu.f16 [[PRED:%p[0-9]+]], [[A]], [[B]] +; CHECK-NOF16-DAG: cvt.f32.f16 [[AF:%f[0-9]+]], [[A]]; +; CHECK-NOF16-DAG: cvt.f32.f16 [[BF:%f[0-9]+]], [[B]]; +; CHECK-NOF16: setp.leu.f32 [[PRED:%p[0-9]+]], [[AF]], [[BF]] +; CHECK-NEXT: selp.u32 [[R:%r[0-9]+]], 1, 0, [[PRED]]; +; CHECK-NEXT: st.param.b32 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define i1 @test_fcmp_ule(half %a, half %b) #0 { + %r = fcmp ule half %a, %b + ret i1 %r +} + + +; CHECK-LABEL: test_fcmp_uno( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_fcmp_uno_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_fcmp_uno_param_1]; +; CHECK-F16: setp.nan.f16 [[PRED:%p[0-9]+]], [[A]], [[B]] +; CHECK-NOF16-DAG: cvt.f32.f16 [[AF:%f[0-9]+]], [[A]]; +; CHECK-NOF16-DAG: cvt.f32.f16 [[BF:%f[0-9]+]], [[B]]; +; CHECK-NOF16: setp.nan.f32 [[PRED:%p[0-9]+]], [[AF]], [[BF]] +; CHECK-NEXT: selp.u32 [[R:%r[0-9]+]], 1, 0, [[PRED]]; +; CHECK-NEXT: st.param.b32 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define i1 @test_fcmp_uno(half %a, half %b) #0 { + %r = fcmp uno half %a, %b + ret i1 %r +} + +; CHECK-LABEL: test_fcmp_one( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_fcmp_one_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_fcmp_one_param_1]; +; CHECK-F16: setp.ne.f16 [[PRED:%p[0-9]+]], [[A]], [[B]] +; CHECK-NOF16-DAG: 
cvt.f32.f16 [[AF:%f[0-9]+]], [[A]]; +; CHECK-NOF16-DAG: cvt.f32.f16 [[BF:%f[0-9]+]], [[B]]; +; CHECK-NOF16: setp.ne.f32 [[PRED:%p[0-9]+]], [[AF]], [[BF]] +; CHECK-NEXT: selp.u32 [[R:%r[0-9]+]], 1, 0, [[PRED]]; +; CHECK-NEXT: st.param.b32 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define i1 @test_fcmp_one(half %a, half %b) #0 { + %r = fcmp one half %a, %b + ret i1 %r +} + +; CHECK-LABEL: test_fcmp_oeq( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_fcmp_oeq_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_fcmp_oeq_param_1]; +; CHECK-F16: setp.eq.f16 [[PRED:%p[0-9]+]], [[A]], [[B]] +; CHECK-NOF16-DAG: cvt.f32.f16 [[AF:%f[0-9]+]], [[A]]; +; CHECK-NOF16-DAG: cvt.f32.f16 [[BF:%f[0-9]+]], [[B]]; +; CHECK-NOF16: setp.eq.f32 [[PRED:%p[0-9]+]], [[AF]], [[BF]] +; CHECK-NEXT: selp.u32 [[R:%r[0-9]+]], 1, 0, [[PRED]]; +; CHECK-NEXT: st.param.b32 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define i1 @test_fcmp_oeq(half %a, half %b) #0 { + %r = fcmp oeq half %a, %b + ret i1 %r +} + +; CHECK-LABEL: test_fcmp_ogt( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_fcmp_ogt_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_fcmp_ogt_param_1]; +; CHECK-F16: setp.gt.f16 [[PRED:%p[0-9]+]], [[A]], [[B]] +; CHECK-NOF16-DAG: cvt.f32.f16 [[AF:%f[0-9]+]], [[A]]; +; CHECK-NOF16-DAG: cvt.f32.f16 [[BF:%f[0-9]+]], [[B]]; +; CHECK-NOF16: setp.gt.f32 [[PRED:%p[0-9]+]], [[AF]], [[BF]] +; CHECK-NEXT: selp.u32 [[R:%r[0-9]+]], 1, 0, [[PRED]]; +; CHECK-NEXT: st.param.b32 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define i1 @test_fcmp_ogt(half %a, half %b) #0 { + %r = fcmp ogt half %a, %b + ret i1 %r +} + +; CHECK-LABEL: test_fcmp_oge( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_fcmp_oge_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_fcmp_oge_param_1]; +; CHECK-F16: setp.ge.f16 [[PRED:%p[0-9]+]], [[A]], [[B]] +; CHECK-NOF16-DAG: cvt.f32.f16 [[AF:%f[0-9]+]], [[A]]; +; CHECK-NOF16-DAG: cvt.f32.f16 [[BF:%f[0-9]+]], [[B]]; +; CHECK-NOF16: setp.ge.f32 
[[PRED:%p[0-9]+]], [[AF]], [[BF]] +; CHECK-NEXT: selp.u32 [[R:%r[0-9]+]], 1, 0, [[PRED]]; +; CHECK-NEXT: st.param.b32 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define i1 @test_fcmp_oge(half %a, half %b) #0 { + %r = fcmp oge half %a, %b + ret i1 %r +} + +; CHECK-LABEL: test_fcmp_olt( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_fcmp_olt_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_fcmp_olt_param_1]; +; CHECK-F16: setp.lt.f16 [[PRED:%p[0-9]+]], [[A]], [[B]] +; CHECK-NOF16-DAG: cvt.f32.f16 [[AF:%f[0-9]+]], [[A]]; +; CHECK-NOF16-DAG: cvt.f32.f16 [[BF:%f[0-9]+]], [[B]]; +; CHECK-NOF16: setp.lt.f32 [[PRED:%p[0-9]+]], [[AF]], [[BF]] +; CHECK-NEXT: selp.u32 [[R:%r[0-9]+]], 1, 0, [[PRED]]; +; CHECK-NEXT: st.param.b32 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define i1 @test_fcmp_olt(half %a, half %b) #0 { + %r = fcmp olt half %a, %b + ret i1 %r +} + +; CHECK-LABEL: test_fcmp_ole( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_fcmp_ole_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_fcmp_ole_param_1]; +; CHECK-F16: setp.le.f16 [[PRED:%p[0-9]+]], [[A]], [[B]] +; CHECK-NOF16-DAG: cvt.f32.f16 [[AF:%f[0-9]+]], [[A]]; +; CHECK-NOF16-DAG: cvt.f32.f16 [[BF:%f[0-9]+]], [[B]]; +; CHECK-NOF16: setp.le.f32 [[PRED:%p[0-9]+]], [[AF]], [[BF]] +; CHECK-NEXT: selp.u32 [[R:%r[0-9]+]], 1, 0, [[PRED]]; +; CHECK-NEXT: st.param.b32 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define i1 @test_fcmp_ole(half %a, half %b) #0 { + %r = fcmp ole half %a, %b + ret i1 %r +} + +; CHECK-LABEL: test_fcmp_ord( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_fcmp_ord_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_fcmp_ord_param_1]; +; CHECK-F16: setp.num.f16 [[PRED:%p[0-9]+]], [[A]], [[B]] +; CHECK-NOF16-DAG: cvt.f32.f16 [[AF:%f[0-9]+]], [[A]]; +; CHECK-NOF16-DAG: cvt.f32.f16 [[BF:%f[0-9]+]], [[B]]; +; CHECK-NOF16: setp.num.f32 [[PRED:%p[0-9]+]], [[AF]], [[BF]] +; CHECK-NEXT: selp.u32 [[R:%r[0-9]+]], 1, 0, [[PRED]]; +; CHECK-NEXT: st.param.b32
[func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define i1 @test_fcmp_ord(half %a, half %b) #0 { + %r = fcmp ord half %a, %b + ret i1 %r +} + +; CHECK-LABEL: test_br_cc( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_br_cc_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_br_cc_param_1]; +; CHECK-DAG: ld.param.u64 %[[C:rd[0-9]+]], [test_br_cc_param_2]; +; CHECK-DAG: ld.param.u64 %[[D:rd[0-9]+]], [test_br_cc_param_3]; +; CHECK-F16: setp.lt.f16 [[PRED:%p[0-9]+]], [[A]], [[B]] +; CHECK-NOF16-DAG: cvt.f32.f16 [[AF:%f[0-9]+]], [[A]]; +; CHECK-NOF16-DAG: cvt.f32.f16 [[BF:%f[0-9]+]], [[B]]; +; CHECK-NOF16: setp.lt.f32 [[PRED:%p[0-9]+]], [[AF]], [[BF]] +; CHECK-NEXT: @%p1 bra [[LABEL:LBB.*]]; +; CHECK: st.u32 [%[[C]]], +; CHECK: [[LABEL]]: +; CHECK: st.u32 [%[[D]]], +; CHECK: ret; +define void @test_br_cc(half %a, half %b, i32* %p1, i32* %p2) #0 { + %c = fcmp uge half %a, %b + br i1 %c, label %then, label %else +then: + store i32 0, i32* %p1 + ret void +else: + store i32 0, i32* %p2 + ret void +} + +; CHECK-LABEL: test_phi( +; CHECK: ld.param.u64 %[[P1:rd[0-9]+]], [test_phi_param_0]; +; CHECK: ld.b16 {{%h[0-9]+}}, [%[[P1]]]; +; CHECK: [[LOOP:LBB[0-9_]+]]: +; CHECK: mov.b16 [[R:%h[0-9]+]], [[AB:%h[0-9]+]]; +; CHECK: ld.b16 [[AB:%h[0-9]+]], [%[[P1]]]; +; CHECK: { +; CHECK: st.param.b64 [param0+0], %[[P1]]; +; CHECK: call.uni (retval0), +; CHECK-NEXT: test_dummy +; CHECK: } +; CHECK: setp.eq.b32 [[PRED:%p[0-9]+]], %r{{[0-9]+}}, 1; +; CHECK: @[[PRED]] bra [[LOOP]]; +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret; +define half @test_phi(half* %p1) #0 { +entry: + %a = load half, half* %p1 + br label %loop +loop: + %r = phi half [%a, %entry], [%b, %loop] + %b = load half, half* %p1 + %c = call i1 @test_dummy(half* %p1) + br i1 %c, label %loop, label %return +return: + ret half %r +} +declare i1 @test_dummy(half* %p1) #0 + +; CHECK-LABEL: test_fptosi_i32( +; CHECK: ld.param.b16 [[A:%h[0-9]+]], [test_fptosi_i32_param_0]; +; CHECK: cvt.rzi.s32.f16 
[[R:%r[0-9]+]], [[A]]; +; CHECK: st.param.b32 [func_retval0+0], [[R]]; +; CHECK: ret; +define i32 @test_fptosi_i32(half %a) #0 { + %r = fptosi half %a to i32 + ret i32 %r +} + +; CHECK-LABEL: test_fptosi_i64( +; CHECK: ld.param.b16 [[A:%h[0-9]+]], [test_fptosi_i64_param_0]; +; CHECK: cvt.rzi.s64.f16 [[R:%rd[0-9]+]], [[A]]; +; CHECK: st.param.b64 [func_retval0+0], [[R]]; +; CHECK: ret; +define i64 @test_fptosi_i64(half %a) #0 { + %r = fptosi half %a to i64 + ret i64 %r +} + +; CHECK-LABEL: test_fptoui_i32( +; CHECK: ld.param.b16 [[A:%h[0-9]+]], [test_fptoui_i32_param_0]; +; CHECK: cvt.rzi.u32.f16 [[R:%r[0-9]+]], [[A]]; +; CHECK: st.param.b32 [func_retval0+0], [[R]]; +; CHECK: ret; +define i32 @test_fptoui_i32(half %a) #0 { + %r = fptoui half %a to i32 + ret i32 %r +} + +; CHECK-LABEL: test_fptoui_i64( +; CHECK: ld.param.b16 [[A:%h[0-9]+]], [test_fptoui_i64_param_0]; +; CHECK: cvt.rzi.u64.f16 [[R:%rd[0-9]+]], [[A]]; +; CHECK: st.param.b64 [func_retval0+0], [[R]]; +; CHECK: ret; +define i64 @test_fptoui_i64(half %a) #0 { + %r = fptoui half %a to i64 + ret i64 %r +} + +; CHECK-LABEL: test_uitofp_i32( +; CHECK: ld.param.u32 [[A:%r[0-9]+]], [test_uitofp_i32_param_0]; +; CHECK: cvt.rn.f16.u32 [[R:%h[0-9]+]], [[A]]; +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret; +define half @test_uitofp_i32(i32 %a) #0 { + %r = uitofp i32 %a to half + ret half %r +} + +; CHECK-LABEL: test_uitofp_i64( +; CHECK: ld.param.u64 [[A:%rd[0-9]+]], [test_uitofp_i64_param_0]; +; CHECK: cvt.rn.f16.u64 [[R:%h[0-9]+]], [[A]]; +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret; +define half @test_uitofp_i64(i64 %a) #0 { + %r = uitofp i64 %a to half + ret half %r +} + +; CHECK-LABEL: test_sitofp_i32( +; CHECK: ld.param.u32 [[A:%r[0-9]+]], [test_sitofp_i32_param_0]; +; CHECK: cvt.rn.f16.s32 [[R:%h[0-9]+]], [[A]]; +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret; +define half @test_sitofp_i32(i32 %a) #0 { + %r = sitofp i32 %a to half + ret half %r +} + +; 
CHECK-LABEL: test_sitofp_i64( +; CHECK: ld.param.u64 [[A:%rd[0-9]+]], [test_sitofp_i64_param_0]; +; CHECK: cvt.rn.f16.s64 [[R:%h[0-9]+]], [[A]]; +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret; +define half @test_sitofp_i64(i64 %a) #0 { + %r = sitofp i64 %a to half + ret half %r +} + +; CHECK-LABEL: test_uitofp_i32_fadd( +; CHECK-DAG: ld.param.u32 [[A:%r[0-9]+]], [test_uitofp_i32_fadd_param_0]; +; CHECK-DAG: cvt.rn.f16.u32 [[C:%h[0-9]+]], [[A]]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_uitofp_i32_fadd_param_1]; +; CHECK-F16: add.rn.f16 [[R:%h[0-9]+]], [[B]], [[C]]; +; CHECK-NOF16-DAG: cvt.f32.f16 [[B32:%f[0-9]+]], [[B]] +; CHECK-NOF16-DAG: cvt.f32.f16 [[C32:%f[0-9]+]], [[C]] +; CHECK-NOF16-NEXT: add.rn.f32 [[R32:%f[0-9]+]], [[B32]], [[C32]]; +; CHECK-NOF16-NEXT: cvt.rn.f16.f32 [[R:%h[0-9]+]], [[R32]] +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret; +define half @test_uitofp_i32_fadd(i32 %a, half %b) #0 { + %c = uitofp i32 %a to half + %r = fadd half %b, %c + ret half %r +} + +; CHECK-LABEL: test_sitofp_i32_fadd( +; CHECK-DAG: ld.param.u32 [[A:%r[0-9]+]], [test_sitofp_i32_fadd_param_0]; +; CHECK-DAG: cvt.rn.f16.s32 [[C:%h[0-9]+]], [[A]]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_sitofp_i32_fadd_param_1]; +; CHECK-F16: add.rn.f16 [[R:%h[0-9]+]], [[B]], [[C]]; +; XCHECK-NOF16-DAG: cvt.f32.f16 [[B32:%f[0-9]+]], [[B]] +; XCHECK-NOF16-DAG: cvt.f32.f16 [[C32:%f[0-9]+]], [[C]] +; XCHECK-NOF16-NEXT: add.rn.f32 [[R32:%f[0-9]+]], [[B32]], [[C32]]; +; XCHECK-NOF16-NEXT: cvt.rn.f16.f32 [[R:%h[0-9]+]], [[R32]] +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret; +define half @test_sitofp_i32_fadd(i32 %a, half %b) #0 { + %c = sitofp i32 %a to half + %r = fadd half %b, %c + ret half %r +} + +; CHECK-LABEL: test_fptrunc_float( +; CHECK: ld.param.f32 [[A:%f[0-9]+]], [test_fptrunc_float_param_0]; +; CHECK: cvt.rn.f16.f32 [[R:%h[0-9]+]], [[A]]; +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret; +define half 
@test_fptrunc_float(float %a) #0 { + %r = fptrunc float %a to half + ret half %r +} + +; CHECK-LABEL: test_fptrunc_double( +; CHECK: ld.param.f64 [[A:%fd[0-9]+]], [test_fptrunc_double_param_0]; +; CHECK: cvt.rn.f16.f64 [[R:%h[0-9]+]], [[A]]; +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret; +define half @test_fptrunc_double(double %a) #0 { + %r = fptrunc double %a to half + ret half %r +} + +; CHECK-LABEL: test_fpext_float( +; CHECK: ld.param.b16 [[A:%h[0-9]+]], [test_fpext_float_param_0]; +; CHECK: cvt.f32.f16 [[R:%f[0-9]+]], [[A]]; +; CHECK: st.param.f32 [func_retval0+0], [[R]]; +; CHECK: ret; +define float @test_fpext_float(half %a) #0 { + %r = fpext half %a to float + ret float %r +} + +; CHECK-LABEL: test_fpext_double( +; CHECK: ld.param.b16 [[A:%h[0-9]+]], [test_fpext_double_param_0]; +; CHECK: cvt.f64.f16 [[R:%fd[0-9]+]], [[A]]; +; CHECK: st.param.f64 [func_retval0+0], [[R]]; +; CHECK: ret; +define double @test_fpext_double(half %a) #0 { + %r = fpext half %a to double + ret double %r +} + + +; CHECK-LABEL: test_bitcast_halftoi16( +; CHECK: ld.param.b16 [[AH:%h[0-9]+]], [test_bitcast_halftoi16_param_0]; +; CHECK: mov.b16 [[AS:%rs[0-9]+]], [[AH]] +; CHECK: cvt.u32.u16 [[R:%r[0-9]+]], [[AS]] +; CHECK: st.param.b32 [func_retval0+0], [[R]]; +; CHECK: ret; +define i16 @test_bitcast_halftoi16(half %a) #0 { + %r = bitcast half %a to i16 + ret i16 %r +} + +; CHECK-LABEL: test_bitcast_i16tohalf( +; CHECK: ld.param.u16 [[AS:%rs[0-9]+]], [test_bitcast_i16tohalf_param_0]; +; CHECK: mov.b16 [[AH:%h[0-9]+]], [[AS]] +; CHECK: st.param.b16 [func_retval0+0], [[AH]]; +; CHECK: ret; +define half @test_bitcast_i16tohalf(i16 %a) #0 { + %r = bitcast i16 %a to half + ret half %r +} + + +declare half @llvm.sqrt.f16(half %a) #0 +declare half @llvm.powi.f16(half %a, i32 %b) #0 +declare half @llvm.sin.f16(half %a) #0 +declare half @llvm.cos.f16(half %a) #0 +declare half @llvm.pow.f16(half %a, half %b) #0 +declare half @llvm.exp.f16(half %a) #0 +declare half 
@llvm.exp2.f16(half %a) #0 +declare half @llvm.log.f16(half %a) #0 +declare half @llvm.log10.f16(half %a) #0 +declare half @llvm.log2.f16(half %a) #0 +declare half @llvm.fma.f16(half %a, half %b, half %c) #0 +declare half @llvm.fabs.f16(half %a) #0 +declare half @llvm.minnum.f16(half %a, half %b) #0 +declare half @llvm.maxnum.f16(half %a, half %b) #0 +declare half @llvm.copysign.f16(half %a, half %b) #0 +declare half @llvm.floor.f16(half %a) #0 +declare half @llvm.ceil.f16(half %a) #0 +declare half @llvm.trunc.f16(half %a) #0 +declare half @llvm.rint.f16(half %a) #0 +declare half @llvm.nearbyint.f16(half %a) #0 +declare half @llvm.round.f16(half %a) #0 +declare half @llvm.fmuladd.f16(half %a, half %b, half %c) #0 + +; CHECK-LABEL: test_sqrt( +; CHECK: ld.param.b16 [[A:%h[0-9]+]], [test_sqrt_param_0]; +; CHECK: cvt.f32.f16 [[AF:%f[0-9]+]], [[A]]; +; CHECK: sqrt.rn.f32 [[RF:%f[0-9]+]], [[AF]]; +; CHECK: cvt.rn.f16.f32 [[R:%h[0-9]+]], [[RF]]; +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret; +define half @test_sqrt(half %a) #0 { + %r = call half @llvm.sqrt.f16(half %a) + ret half %r +} + +;;; Can't do this yet: requires libcall. 
+; XCHECK-LABEL: test_powi( +;define half @test_powi(half %a, i32 %b) #0 { +; %r = call half @llvm.powi.f16(half %a, i32 %b) +; ret half %r +;} + +; CHECK-LABEL: test_sin( +; CHECK: ld.param.b16 [[A:%h[0-9]+]], [test_sin_param_0]; +; CHECK: cvt.f32.f16 [[AF:%f[0-9]+]], [[A]]; +; CHECK: sin.approx.f32 [[RF:%f[0-9]+]], [[AF]]; +; CHECK: cvt.rn.f16.f32 [[R:%h[0-9]+]], [[RF]]; +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret; +define half @test_sin(half %a) #0 { + %r = call half @llvm.sin.f16(half %a) + ret half %r +} + +; CHECK-LABEL: test_cos( +; CHECK: ld.param.b16 [[A:%h[0-9]+]], [test_cos_param_0]; +; CHECK: cvt.f32.f16 [[AF:%f[0-9]+]], [[A]]; +; CHECK: cos.approx.f32 [[RF:%f[0-9]+]], [[AF]]; +; CHECK: cvt.rn.f16.f32 [[R:%h[0-9]+]], [[RF]]; +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret; +define half @test_cos(half %a) #0 { + %r = call half @llvm.cos.f16(half %a) + ret half %r +} + +;;; Can't do this yet: requires libcall. +; XCHECK-LABEL: test_pow( +;define half @test_pow(half %a, half %b) #0 { +; %r = call half @llvm.pow.f16(half %a, half %b) +; ret half %r +;} + +;;; Can't do this yet: requires libcall. +; XCHECK-LABEL: test_exp( +;define half @test_exp(half %a) #0 { +; %r = call half @llvm.exp.f16(half %a) +; ret half %r +;} + +;;; Can't do this yet: requires libcall. +; XCHECK-LABEL: test_exp2( +;define half @test_exp2(half %a) #0 { +; %r = call half @llvm.exp2.f16(half %a) +; ret half %r +;} + +;;; Can't do this yet: requires libcall. +; XCHECK-LABEL: test_log( +;define half @test_log(half %a) #0 { +; %r = call half @llvm.log.f16(half %a) +; ret half %r +;} + +;;; Can't do this yet: requires libcall. +; XCHECK-LABEL: test_log10( +;define half @test_log10(half %a) #0 { +; %r = call half @llvm.log10.f16(half %a) +; ret half %r +;} + +;;; Can't do this yet: requires libcall. 
+; XCHECK-LABEL: test_log2( +;define half @test_log2(half %a) #0 { +; %r = call half @llvm.log2.f16(half %a) +; ret half %r +;} + +; CHECK-LABEL: test_fma( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_fma_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_fma_param_1]; +; CHECK-DAG: ld.param.b16 [[C:%h[0-9]+]], [test_fma_param_2]; +; CHECK-F16: fma.rn.f16 [[R:%h[0-9]+]], [[A]], [[B]], [[C]]; +; CHECK-NOF16-DAG: cvt.f32.f16 [[A32:%f[0-9]+]], [[A]] +; CHECK-NOF16-DAG: cvt.f32.f16 [[B32:%f[0-9]+]], [[B]] +; CHECK-NOF16-DAG: cvt.f32.f16 [[C32:%f[0-9]+]], [[C]] +; CHECK-NOF16-NEXT: fma.rn.f32 [[R32:%f[0-9]+]], [[A32]], [[B32]], [[C32]]; +; CHECK-NOF16-NEXT: cvt.rn.f16.f32 [[R:%h[0-9]+]], [[R32]] +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret +define half @test_fma(half %a, half %b, half %c) #0 { + %r = call half @llvm.fma.f16(half %a, half %b, half %c) + ret half %r +} + +; CHECK-LABEL: test_fabs( +; CHECK: ld.param.b16 [[A:%h[0-9]+]], [test_fabs_param_0]; +; CHECK: cvt.f32.f16 [[AF:%f[0-9]+]], [[A]]; +; CHECK: abs.f32 [[RF:%f[0-9]+]], [[AF]]; +; CHECK: cvt.rn.f16.f32 [[R:%h[0-9]+]], [[RF]]; +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret; +define half @test_fabs(half %a) #0 { + %r = call half @llvm.fabs.f16(half %a) + ret half %r +} + +; CHECK-LABEL: test_minnum( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_minnum_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_minnum_param_1]; +; CHECK-DAG: cvt.f32.f16 [[AF:%f[0-9]+]], [[A]]; +; CHECK-DAG: cvt.f32.f16 [[BF:%f[0-9]+]], [[B]]; +; CHECK: min.f32 [[RF:%f[0-9]+]], [[AF]], [[BF]]; +; CHECK: cvt.rn.f16.f32 [[R:%h[0-9]+]], [[RF]]; +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret; +define half @test_minnum(half %a, half %b) #0 { + %r = call half @llvm.minnum.f16(half %a, half %b) + ret half %r +} + +; CHECK-LABEL: test_maxnum( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_maxnum_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_maxnum_param_1]; +; 
CHECK-DAG: cvt.f32.f16 [[AF:%f[0-9]+]], [[A]]; +; CHECK-DAG: cvt.f32.f16 [[BF:%f[0-9]+]], [[B]]; +; CHECK: max.f32 [[RF:%f[0-9]+]], [[AF]], [[BF]]; +; CHECK: cvt.rn.f16.f32 [[R:%h[0-9]+]], [[RF]]; +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret; +define half @test_maxnum(half %a, half %b) #0 { + %r = call half @llvm.maxnum.f16(half %a, half %b) + ret half %r +} + +; CHECK-LABEL: test_copysign( +; CHECK-DAG: ld.param.b16 [[AH:%h[0-9]+]], [test_copysign_param_0]; +; CHECK-DAG: ld.param.b16 [[BH:%h[0-9]+]], [test_copysign_param_1]; +; CHECK-DAG: mov.b16 [[AS:%rs[0-9]+]], [[AH]]; +; CHECK-DAG: mov.b16 [[BS:%rs[0-9]+]], [[BH]]; +; CHECK-DAG: and.b16 [[AX:%rs[0-9]+]], [[AS]], 32767; +; CHECK-DAG: and.b16 [[BX:%rs[0-9]+]], [[BS]], -32768; +; CHECK: or.b16 [[RX:%rs[0-9]+]], [[AX]], [[BX]]; +; CHECK: mov.b16 [[R:%h[0-9]+]], [[RX]]; +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret; +define half @test_copysign(half %a, half %b) #0 { + %r = call half @llvm.copysign.f16(half %a, half %b) + ret half %r +} + +; CHECK-LABEL: test_copysign_f32( +; CHECK-DAG: ld.param.b16 [[AH:%h[0-9]+]], [test_copysign_f32_param_0]; +; CHECK-DAG: ld.param.f32 [[BF:%f[0-9]+]], [test_copysign_f32_param_1]; +; CHECK-DAG: mov.b16 [[A:%rs[0-9]+]], [[AH]]; +; CHECK-DAG: mov.b32 [[B:%r[0-9]+]], [[BF]]; +; CHECK-DAG: and.b16 [[AX:%rs[0-9]+]], [[A]], 32767; +; CHECK-DAG: and.b32 [[BX0:%r[0-9]+]], [[B]], -2147483648; +; CHECK-DAG: shr.u32 [[BX1:%r[0-9]+]], [[BX0]], 16; +; CHECK-DAG: cvt.u16.u32 [[BX2:%rs[0-9]+]], [[BX1]]; +; CHECK: or.b16 [[RX:%rs[0-9]+]], [[AX]], [[BX2]]; +; CHECK: mov.b16 [[R:%h[0-9]+]], [[RX]]; +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret; +define half @test_copysign_f32(half %a, float %b) #0 { + %tb = fptrunc float %b to half + %r = call half @llvm.copysign.f16(half %a, half %tb) + ret half %r +} + +; CHECK-LABEL: test_copysign_f64( +; CHECK-DAG: ld.param.b16 [[AH:%h[0-9]+]], [test_copysign_f64_param_0]; +; CHECK-DAG: ld.param.f64 
[[BD:%fd[0-9]+]], [test_copysign_f64_param_1]; +; CHECK-DAG: mov.b16 [[A:%rs[0-9]+]], [[AH]]; +; CHECK-DAG: mov.b64 [[B:%rd[0-9]+]], [[BD]]; +; CHECK-DAG: and.b16 [[AX:%rs[0-9]+]], [[A]], 32767; +; CHECK-DAG: and.b64 [[BX0:%rd[0-9]+]], [[B]], -9223372036854775808; +; CHECK-DAG: shr.u64 [[BX1:%rd[0-9]+]], [[BX0]], 48; +; CHECK-DAG: cvt.u16.u64 [[BX2:%rs[0-9]+]], [[BX1]]; +; CHECK: or.b16 [[RX:%rs[0-9]+]], [[AX]], [[BX2]]; +; CHECK: mov.b16 [[R:%h[0-9]+]], [[RX]]; +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret; +define half @test_copysign_f64(half %a, double %b) #0 { + %tb = fptrunc double %b to half + %r = call half @llvm.copysign.f16(half %a, half %tb) + ret half %r +} + +; CHECK-LABEL: test_copysign_extended( +; CHECK-DAG: ld.param.b16 [[AH:%h[0-9]+]], [test_copysign_extended_param_0]; +; CHECK-DAG: ld.param.b16 [[BH:%h[0-9]+]], [test_copysign_extended_param_1]; +; CHECK-DAG: mov.b16 [[AS:%rs[0-9]+]], [[AH]]; +; CHECK-DAG: mov.b16 [[BS:%rs[0-9]+]], [[BH]]; +; CHECK-DAG: and.b16 [[AX:%rs[0-9]+]], [[AS]], 32767; +; CHECK-DAG: and.b16 [[BX:%rs[0-9]+]], [[BS]], -32768; +; CHECK: or.b16 [[RX:%rs[0-9]+]], [[AX]], [[BX]]; +; CHECK: mov.b16 [[R:%h[0-9]+]], [[RX]]; +; CHECK: cvt.f32.f16 [[XR:%f[0-9]+]], [[R]]; +; CHECK: st.param.f32 [func_retval0+0], [[XR]]; +; CHECK: ret; +define float @test_copysign_extended(half %a, half %b) #0 { + %r = call half @llvm.copysign.f16(half %a, half %b) + %xr = fpext half %r to float + ret float %xr +} + +; CHECK-LABEL: test_floor( +; CHECK: ld.param.b16 [[A:%h[0-9]+]], [test_floor_param_0]; +; CHECK: cvt.rmi.f16.f16 [[R:%h[0-9]+]], [[A]]; +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret; +define half @test_floor(half %a) #0 { + %r = call half @llvm.floor.f16(half %a) + ret half %r +} + +; CHECK-LABEL: test_ceil( +; CHECK: ld.param.b16 [[A:%h[0-9]+]], [test_ceil_param_0]; +; CHECK: cvt.rpi.f16.f16 [[R:%h[0-9]+]], [[A]]; +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret; +define half @test_ceil(half 
%a) #0 { + %r = call half @llvm.ceil.f16(half %a) + ret half %r +} + +; CHECK-LABEL: test_trunc( +; CHECK: ld.param.b16 [[A:%h[0-9]+]], [test_trunc_param_0]; +; CHECK: cvt.rzi.f16.f16 [[R:%h[0-9]+]], [[A]]; +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret; +define half @test_trunc(half %a) #0 { + %r = call half @llvm.trunc.f16(half %a) + ret half %r +} + +; CHECK-LABEL: test_rint( +; CHECK: ld.param.b16 [[A:%h[0-9]+]], [test_rint_param_0]; +; CHECK: cvt.rni.f16.f16 [[R:%h[0-9]+]], [[A]]; +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret; +define half @test_rint(half %a) #0 { + %r = call half @llvm.rint.f16(half %a) + ret half %r +} + +; CHECK-LABEL: test_nearbyint( +; CHECK: ld.param.b16 [[A:%h[0-9]+]], [test_nearbyint_param_0]; +; CHECK: cvt.rni.f16.f16 [[R:%h[0-9]+]], [[A]]; +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret; +define half @test_nearbyint(half %a) #0 { + %r = call half @llvm.nearbyint.f16(half %a) + ret half %r +} + +; CHECK-LABEL: test_round( +; CHECK: ld.param.b16 [[A:%h[0-9]+]], [test_round_param_0]; +; CHECK: cvt.rni.f16.f16 [[R:%h[0-9]+]], [[A]]; +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret; +define half @test_round(half %a) #0 { + %r = call half @llvm.round.f16(half %a) + ret half %r +} + +; CHECK-LABEL: test_fmuladd( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_fmuladd_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_fmuladd_param_1]; +; CHECK-DAG: ld.param.b16 [[C:%h[0-9]+]], [test_fmuladd_param_2]; +; CHECK-F16: fma.rn.f16 [[R:%h[0-9]+]], [[A]], [[B]], [[C]]; +; CHECK-NOF16-DAG: cvt.f32.f16 [[A32:%f[0-9]+]], [[A]] +; CHECK-NOF16-DAG: cvt.f32.f16 [[B32:%f[0-9]+]], [[B]] +; CHECK-NOF16-DAG: cvt.f32.f16 [[C32:%f[0-9]+]], [[C]] +; CHECK-NOF16-NEXT: fma.rn.f32 [[R32:%f[0-9]+]], [[A32]], [[B32]], [[C32]]; +; CHECK-NOF16-NEXT: cvt.rn.f16.f32 [[R:%h[0-9]+]], [[R32]] +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret; +define half @test_fmuladd(half %a, half %b, half %c) 
#0 { + %r = call half @llvm.fmuladd.f16(half %a, half %b, half %c) + ret half %r +} + +attributes #0 = { nounwind } Index: test/CodeGen/NVPTX/half.ll =================================================================== --- test/CodeGen/NVPTX/half.ll +++ test/CodeGen/NVPTX/half.ll @@ -2,8 +2,8 @@ define void @test_load_store(half addrspace(1)* %in, half addrspace(1)* %out) { ; CHECK-LABEL: @test_load_store -; CHECK: ld.global.u16 [[TMP:%rs[0-9]+]], [{{%r[0-9]+}}] -; CHECK: st.global.u16 [{{%r[0-9]+}}], [[TMP]] +; CHECK: ld.global.b16 [[TMP:%h[0-9]+]], [{{%r[0-9]+}}] +; CHECK: st.global.b16 [{{%r[0-9]+}}], [[TMP]] %val = load half, half addrspace(1)* %in store half %val, half addrspace(1) * %out ret void @@ -11,8 +11,8 @@ define void @test_bitcast_from_half(half addrspace(1)* %in, i16 addrspace(1)* %out) { ; CHECK-LABEL: @test_bitcast_from_half -; CHECK: ld.global.u16 [[TMP:%rs[0-9]+]], [{{%r[0-9]+}}] -; CHECK: st.global.u16 [{{%r[0-9]+}}], [[TMP]] +; CHECK: ld.global.b16 [[TMP:%h[0-9]+]], [{{%r[0-9]+}}] +; CHECK: st.global.b16 [{{%r[0-9]+}}], [[TMP]] %val = load half, half addrspace(1) * %in %val_int = bitcast half %val to i16 store i16 %val_int, i16 addrspace(1)* %out