Index: llvm/lib/Target/RISCV/RISCVISelLowering.h
===================================================================
--- llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -829,6 +829,8 @@
                                             unsigned ExtendOpc) const;
   SDValue lowerGET_ROUNDING(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerSET_ROUNDING(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerFP_TO_BF16(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerBF16_TO_FP(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerEH_DWARF_CFA(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerCTLZ_CTTZ_ZERO_UNDEF(SDValue Op, SelectionDAG &DAG) const;
 
Index: llvm/lib/Target/RISCV/RISCVISelLowering.cpp
===================================================================
--- llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -426,7 +426,11 @@
     setOperationAction(ISD::BR_CC, MVT::f32, Expand);
     setOperationAction(FPOpToExpand, MVT::f32, Expand);
     setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
+    setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand);
     setTruncStoreAction(MVT::f32, MVT::f16, Expand);
+    setTruncStoreAction(MVT::f32, MVT::bf16, Expand);
+    setOperationAction(ISD::FP_TO_BF16, MVT::f32, Custom);
+    setOperationAction(ISD::BF16_TO_FP, MVT::f32, Custom);
     setOperationAction(ISD::IS_FPCLASS, MVT::f32, Custom);
 
     if (Subtarget.hasStdExtZfa())
@@ -461,6 +465,10 @@
     setOperationAction(FPOpToExpand, MVT::f64, Expand);
     setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
     setTruncStoreAction(MVT::f64, MVT::f16, Expand);
+    setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand);
+    setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
+    setOperationAction(ISD::FP_TO_BF16, MVT::f64, Custom);
+    setOperationAction(ISD::BF16_TO_FP, MVT::f64, Expand);
     setOperationAction(ISD::IS_FPCLASS, MVT::f64, Custom);
   }
 
@@ -1154,6 +1162,7 @@
 
   setLibcallName(RTLIB::FPEXT_F16_F32, "__extendhfsf2");
   setLibcallName(RTLIB::FPROUND_F32_F16, "__truncsfhf2");
+  setLibcallName(RTLIB::FPROUND_F32_BF16, "__truncsfbf2");
 
   // Disable strict node mutation.
   IsStrictFPEnabled = true;
@@ -2424,6 +2433,41 @@
   return Res;
 }
 
+/// Custom-lower ISD::FP_TO_BF16 through the truncation libcall that
+/// RTLIB::getFPROUND selects for the operand's floating-point type.
+SDValue RISCVTargetLowering::lowerFP_TO_BF16(SDValue Op,
+                                             SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  MakeLibCallOptions CallOptions;
+  RTLIB::Libcall LC =
+      RTLIB::getFPROUND(Op.getOperand(0).getValueType(), MVT::bf16);
+  // The call is emitted with an f32 result type; its bits are then moved into
+  // an integer of XLen width (i64 on RV64, i32 otherwise).
+  SDValue Res =
+      makeLibCall(DAG, LC, MVT::f32, Op.getOperand(0), CallOptions, DL).first;
+  if (Subtarget.is64Bit())
+    return DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Res);
+  return DAG.getBitcast(MVT::i32, Res);
+}
+
+/// Custom-lower ISD::BF16_TO_FP to f32 without a libcall.
+///
+/// bf16 is the upper 16 bits of the IEEE single-precision encoding, so the
+/// extension is exact: shift the payload into bits [31:16] of the integer
+/// operand and reinterpret the low 32 bits as f32. The IEEE-half helper
+/// (RTLIB::FPEXT_F16_F32) must not be used: half has a different layout.
+SDValue RISCVTargetLowering::lowerBF16_TO_FP(SDValue Op,
+                                             SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  SDValue Src = Op.getOperand(0);
+  EVT SrcVT = Src.getValueType();
+  SDValue Bits = DAG.getNode(ISD::SHL, DL, SrcVT, Src,
+                             DAG.getShiftAmountConstant(16, SrcVT, DL));
+  if (Subtarget.is64Bit())
+    return DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32, Bits);
+  return DAG.getBitcast(MVT::f32, Bits);
+}
+
 static RISCVFPRndMode::RoundingMode matchRoundingOp(unsigned Opc) {
   switch (Opc) {
   case ISD::FROUNDEVEN:
@@ -4751,6 +4795,10 @@
   case ISD::FP_TO_SINT_SAT:
   case ISD::FP_TO_UINT_SAT:
     return lowerFP_TO_INT_SAT(Op, DAG, Subtarget);
+  case ISD::FP_TO_BF16:
+    return lowerFP_TO_BF16(Op, DAG);
+  case ISD::BF16_TO_FP:
+    return lowerBF16_TO_FP(Op, DAG);
   case ISD::FTRUNC:
   case ISD::FCEIL:
   case ISD::FFLOOR:
@@ -15980,8 +16028,9 @@
     unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
   bool IsABIRegCopy = CC.has_value();
   EVT ValueVT = Val.getValueType();
-  if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
-    // Cast the f16 to i16, extend to i32, pad with ones to make a float nan,
-    // and cast to f32.
+  if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) &&
+      PartVT == MVT::f32) {
+    // Cast the f16/bf16 to i16, extend to i32, pad with ones to make a float
+    // nan, and cast to f32.
     Val = DAG.getNode(ISD::BITCAST, DL, MVT::i16, Val);
     Val = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Val);
@@ -16033,13 +16082,14 @@
     SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts, unsigned NumParts,
     MVT PartVT, EVT ValueVT, std::optional<CallingConv::ID> CC) const {
   bool IsABIRegCopy = CC.has_value();
-  if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
+  if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) &&
+      PartVT == MVT::f32) {
     SDValue Val = Parts[0];
 
-    // Cast the f32 to i32, truncate to i16, and cast back to f16.
+    // Cast the f32 to i32, truncate to i16, and cast back to f16/bf16.
     Val = DAG.getNode(ISD::BITCAST, DL, MVT::i32, Val);
     Val = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Val);
-    Val = DAG.getNode(ISD::BITCAST, DL, MVT::f16, Val);
+    Val = DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
     return Val;
   }
 