Index: llvm/lib/CodeGen/TargetLoweringBase.cpp =================================================================== --- llvm/lib/CodeGen/TargetLoweringBase.cpp +++ llvm/lib/CodeGen/TargetLoweringBase.cpp @@ -909,9 +909,9 @@ // ConstantFP nodes default to expand. Targets can either change this to // Legal, in which case all fp constants are legal, or use isFPImmLegal() // to optimize expansions for certain constants. - setOperationAction(ISD::ConstantFP, - {MVT::f16, MVT::f32, MVT::f64, MVT::f80, MVT::f128}, - Expand); + setOperationAction( + ISD::ConstantFP, + {MVT::f16, MVT::bf16, MVT::f32, MVT::f64, MVT::f80, MVT::f128}, Expand); // These library functions default to expand. setOperationAction({ISD::FCBRT, ISD::FLOG, ISD::FLOG2, ISD::FLOG10, ISD::FEXP, Index: llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp =================================================================== --- llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp +++ llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp @@ -1234,9 +1234,9 @@ Op.Reg.RegNum = convertFPR64ToFPR32(Reg); return Match_Success; } - // As the parser couldn't differentiate an FPR16 from an FPR64, coerce the - // register from FPR64 to FPR16 if necessary. - if (IsRegFPR64 && Kind == MCK_FPR16) { + // As the parser couldn't differentiate a (B)FPR16 from an FPR64, coerce the + // register from FPR64 to (B)FPR16 if necessary. + if (IsRegFPR64 && Kind == MCK_BFPR16) { Op.Reg.RegNum = convertFPR64ToFPR16(Reg); return Match_Success; } Index: llvm/lib/Target/RISCV/RISCVISelLowering.h =================================================================== --- llvm/lib/Target/RISCV/RISCVISelLowering.h +++ llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -111,6 +111,9 @@ FCVT_W_RV64, FCVT_WU_RV64, + FP_ROUND_BF16, + FP_EXTEND_BF16, + // Rounds an FP value to its corresponding integer in the same FP format. // First operand is the value to round, the second operand is the largest // integer that can be represented exactly in the FP format. 
This will be Index: llvm/lib/Target/RISCV/RISCVISelLowering.cpp =================================================================== --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -116,6 +116,8 @@ if (Subtarget.hasStdExtZfhOrZfhmin()) addRegisterClass(MVT::f16, &RISCV::FPR16RegClass); + if (Subtarget.hasStdExtZfbfmin()) + addRegisterClass(MVT::bf16, &RISCV::BFPR16RegClass); if (Subtarget.hasStdExtF()) addRegisterClass(MVT::f32, &RISCV::FPR32RegClass); if (Subtarget.hasStdExtD()) @@ -361,6 +363,14 @@ if (Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) setOperationAction(ISD::BITCAST, MVT::i16, Custom); + if (Subtarget.hasStdExtZfbfmin()) { + setOperationAction(ISD::BITCAST, MVT::i16, Custom); + setOperationAction(ISD::BITCAST, MVT::bf16, Custom); + setOperationAction(ISD::FP_ROUND, MVT::bf16, Custom); + setOperationAction(ISD::FP_EXTEND, MVT::f32, Custom); + setOperationAction(ISD::FP_EXTEND, MVT::f64, Custom); + } + if (Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) { if (Subtarget.hasStdExtZfhOrZhinx()) { setOperationAction(FPLegalNodeTypes, MVT::f16, Legal); @@ -1832,6 +1842,8 @@ bool IsLegalVT = false; if (VT == MVT::f16) IsLegalVT = Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin(); + else if (VT == MVT::bf16) + IsLegalVT = Subtarget.hasStdExtZfbfmin(); else if (VT == MVT::f32) IsLegalVT = Subtarget.hasStdExtFOrZfinx(); else if (VT == MVT::f64) @@ -4602,6 +4614,12 @@ SDValue FPConv = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::f16, NewOp0); return FPConv; } + if (VT == MVT::bf16 && Op0VT == MVT::i16 && + Subtarget.hasStdExtZfbfmin()) { + SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Op0); + SDValue FPConv = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::bf16, NewOp0); + return FPConv; + } if (VT == MVT::f32 && Op0VT == MVT::i32 && Subtarget.is64Bit() && Subtarget.hasStdExtFOrZfinx()) { SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Op0); @@ -4765,11 +4783,42 @@ } return SDValue(); } - case 
ISD::FP_EXTEND: - case ISD::FP_ROUND: + case ISD::FP_EXTEND: { + SDLoc DL(Op); + EVT VT = Op.getValueType(); + SDValue Op0 = Op.getOperand(0); + EVT Op0VT = Op0.getValueType(); + if (VT == MVT::f32 && Op0VT == MVT::bf16 && Subtarget.hasStdExtZfbfmin()) + return DAG.getNode(RISCVISD::FP_EXTEND_BF16, DL, MVT::f32, Op0); + if (VT == MVT::f64 && Op0VT == MVT::bf16 && Subtarget.hasStdExtZfbfmin()) { + SDValue FloatVal = + DAG.getNode(RISCVISD::FP_EXTEND_BF16, DL, MVT::f32, Op0); + return DAG.getNode(ISD::FP_EXTEND, DL, MVT::f64, FloatVal); + } + + if (!Op.getValueType().isVector()) + return Op; + return lowerVectorFPExtendOrRoundLike(Op, DAG); + } + case ISD::FP_ROUND: { + SDLoc DL(Op); + EVT VT = Op.getValueType(); + SDValue Op0 = Op.getOperand(0); + EVT Op0VT = Op0.getValueType(); + if (VT == MVT::bf16 && Op0VT == MVT::f32 && Subtarget.hasStdExtZfbfmin()) + return DAG.getNode(RISCVISD::FP_ROUND_BF16, DL, MVT::bf16, Op0); + if (VT == MVT::bf16 && Op0VT == MVT::f64 && Subtarget.hasStdExtZfbfmin() && + Subtarget.hasStdExtDOrZdinx()) { + SDValue FloatVal = + DAG.getNode(ISD::FP_ROUND, DL, MVT::f32, Op0, + DAG.getIntPtrConstant(0, DL, /*isTarget=*/true)); + return DAG.getNode(RISCVISD::FP_ROUND_BF16, DL, MVT::bf16, FloatVal); + } + if (!Op.getValueType().isVector()) return Op; return lowerVectorFPExtendOrRoundLike(Op, DAG); + } case ISD::STRICT_FP_ROUND: case ISD::STRICT_FP_EXTEND: return lowerStrictFPExtendOrRoundLike(Op, DAG); @@ -9527,6 +9576,10 @@ Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) { SDValue FPConv = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Op0); Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FPConv)); + } else if (VT == MVT::i16 && Op0VT == MVT::bf16 && + Subtarget.hasStdExtZfbfmin()) { + SDValue FPConv = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Op0); + Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FPConv)); } else if (VT == MVT::i32 && Op0VT == MVT::f32 && Subtarget.is64Bit() && Subtarget.hasStdExtFOrZfinx()) { 
SDValue FPConv = @@ -14112,6 +14165,10 @@ RISCV::F10_H, RISCV::F11_H, RISCV::F12_H, RISCV::F13_H, RISCV::F14_H, RISCV::F15_H, RISCV::F16_H, RISCV::F17_H }; +static const MCPhysReg ArgBFPR16s[] = { + RISCV::F10_H, RISCV::F11_H, RISCV::F12_H, RISCV::F13_H, + RISCV::F14_H, RISCV::F15_H, RISCV::F16_H, RISCV::F17_H +}; static const MCPhysReg ArgFPR32s[] = { RISCV::F10_F, RISCV::F11_F, RISCV::F12_F, RISCV::F13_F, RISCV::F14_F, RISCV::F15_F, RISCV::F16_F, RISCV::F17_F @@ -14250,7 +14307,8 @@ // similar local variables rather than directly checking against the target // ABI. - if (UseGPRForF16_F32 && (ValVT == MVT::f16 || ValVT == MVT::f32)) { + if (UseGPRForF16_F32 && + (ValVT == MVT::f16 || ValVT == MVT::bf16 || ValVT == MVT::f32)) { LocVT = XLenVT; LocInfo = CCValAssign::BCvt; } else if (UseGPRForF64 && XLen == 64 && ValVT == MVT::f64) { @@ -14345,6 +14403,8 @@ if (ValVT == MVT::f16 && !UseGPRForF16_F32) Reg = State.AllocateReg(ArgFPR16s); + else if (ValVT == MVT::bf16 && !UseGPRForF16_F32) + Reg = State.AllocateReg(ArgBFPR16s); else if (ValVT == MVT::f32 && !UseGPRForF16_F32) Reg = State.AllocateReg(ArgFPR32s); else if (ValVT == MVT::f64 && !UseGPRForF64) @@ -14500,8 +14560,9 @@ Val = convertFromScalableVector(VA.getValVT(), Val, DAG, Subtarget); break; case CCValAssign::BCvt: - if (VA.getLocVT().isInteger() && VA.getValVT() == MVT::f16) - Val = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::f16, Val); + if (VA.getLocVT().isInteger() && + (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16)) + Val = DAG.getNode(RISCVISD::FMV_H_X, DL, VA.getValVT(), Val); else if (VA.getLocVT() == MVT::i64 && VA.getValVT() == MVT::f32) Val = DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32, Val); else @@ -14559,7 +14620,8 @@ Val = convertToScalableVector(LocVT, Val, DAG, Subtarget); break; case CCValAssign::BCvt: - if (VA.getLocVT().isInteger() && VA.getValVT() == MVT::f16) + if (VA.getLocVT().isInteger() && + (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16)) Val = 
DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, VA.getLocVT(), Val); else if (VA.getLocVT() == MVT::i64 && VA.getValVT() == MVT::f32) Val = DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Val); @@ -15551,6 +15613,8 @@ NODE_NAME_CASE(FCVT_WU_RV64) NODE_NAME_CASE(STRICT_FCVT_W_RV64) NODE_NAME_CASE(STRICT_FCVT_WU_RV64) + NODE_NAME_CASE(FP_ROUND_BF16) + NODE_NAME_CASE(FP_EXTEND_BF16) NODE_NAME_CASE(FROUND) NODE_NAME_CASE(FPCLASS) NODE_NAME_CASE(READ_CYCLE_WIDE) Index: llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td =================================================================== --- llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td +++ llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td @@ -13,6 +13,20 @@ // //===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// RISC-V specific DAG Nodes. +//===----------------------------------------------------------------------===// + +def SDT_RISCVFP_ROUND_BF16 + : SDTypeProfile<1, 1, [SDTCisVT<0, f16>, SDTCisVT<1, f32>]>; +def SDT_RISCVFP_EXTEND_BF16 + : SDTypeProfile<1, 1, [SDTCisVT<0, f32>, SDTCisVT<1, f16>]>; + +def riscv_fpround_bf16 + : SDNode<"RISCVISD::FP_ROUND_BF16", SDT_RISCVFP_ROUND_BF16>; +def riscv_fpextend_bf16 + : SDNode<"RISCVISD::FP_EXTEND_BF16", SDT_RISCVFP_EXTEND_BF16>; + //===----------------------------------------------------------------------===// // Instructions //===----------------------------------------------------------------------===// @@ -23,3 +37,12 @@ def FCVT_S_BF16 : FPUnaryOp_r_frm<0b0100000, 0b00110, FPR32, FPR16, "fcvt.s.bf16">, Sched<[WriteFCvtF32ToF16, ReadFCvtF32ToF16]>; } // Predicates = [HasStdExtZfbfmin] + +//===----------------------------------------------------------------------===// +// Pseudo-instructions and codegen patterns +//===----------------------------------------------------------------------===// + +let Predicates = [HasStdExtZfbfmin] in { +def : 
Pat<(riscv_fpround_bf16 FPR32:$rs1), (FCVT_BF16_S FPR32:$rs1, FRM_DYN)>; +def : Pat<(riscv_fpextend_bf16 FPR16:$rs1), (FCVT_S_BF16 FPR16:$rs1, FRM_DYN)>; +} // Predicates = [HasStdExtZfbfmin] Index: llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td =================================================================== --- llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td +++ llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td @@ -424,6 +424,20 @@ def : Pat<(fcopysign FPR32:$rs1, FPR16:$rs2), (FSGNJ_S $rs1, (FCVT_S_H $rs2))>; } // Predicates = [HasStdExtZfhOrZfhmin] +let Predicates = [HasStdExtZfbfmin] in { +/// Loads +def : Pat<(bf16 (load GPR:$rs1)), (COPY_TO_REGCLASS (FLH GPR:$rs1, 0), BFPR16)>; + +/// Stores +def : Pat<(store (bf16 BFPR16:$rs2), GPR:$rs1), + (FSH (COPY_TO_REGCLASS BFPR16:$rs2, FPR16), GPR:$rs1, 0)>; + +// Moves (no conversion) +def : Pat<(riscv_fmv_h_x GPR:$src), (FMV_H_X GPR:$src)>; +def : Pat<(riscv_fmv_x_anyexth FPR16:$src), (FMV_X_H FPR16:$src)>; +def : Pat<(riscv_fmv_x_signexth FPR16:$src), (FMV_X_H FPR16:$src)>; +} + let Predicates = [HasStdExtZhinxOrZhinxmin] in { /// Float conversion operations Index: llvm/lib/Target/RISCV/RISCVRegisterInfo.td =================================================================== --- llvm/lib/Target/RISCV/RISCVRegisterInfo.td +++ llvm/lib/Target/RISCV/RISCVRegisterInfo.td @@ -242,6 +242,15 @@ (sequence "F%u_H", 18, 27) // fs2-fs11 )>; +def BFPR16 : RegisterClass<"RISCV", [bf16], 16, (add + (sequence "F%u_H", 15, 10), // fa5-fa0 + (sequence "F%u_H", 0, 7), // ft0-ft7 + (sequence "F%u_H", 16, 17), // fa6-fa7 + (sequence "F%u_H", 28, 31), // ft8-ft11 + (sequence "F%u_H", 8, 9), // fs0-fs1 + (sequence "F%u_H", 18, 27) // fs2-fs11 +)>; + def FPR32 : RegisterClass<"RISCV", [f32], 32, (add (sequence "F%u_F", 15, 10), (sequence "F%u_F", 0, 7), Index: llvm/test/CodeGen/RISCV/zfbfmin.ll =================================================================== --- /dev/null +++ llvm/test/CodeGen/RISCV/zfbfmin.ll @@ -0,0 +1,83 @@ +; NOTE: Assertions 
have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+experimental-zfbfmin -verify-machineinstrs \ +; RUN: -target-abi ilp32d < %s | FileCheck -check-prefix=CHECKIZFBFMIN %s +; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+experimental-zfbfmin -verify-machineinstrs \ +; RUN: -target-abi lp64d < %s | FileCheck -check-prefix=CHECKIZFBFMIN %s + +define bfloat @bitcast_bf16_i16(i16 %a) nounwind { +; CHECKIZFBFMIN-LABEL: bitcast_bf16_i16: +; CHECKIZFBFMIN: # %bb.0: +; CHECKIZFBFMIN-NEXT: fmv.h.x fa0, a0 +; CHECKIZFBFMIN-NEXT: ret + %1 = bitcast i16 %a to bfloat + ret bfloat %1 +} + +define i16 @bitcast_i16_bf16(bfloat %a) nounwind { +; CHECKIZFBFMIN-LABEL: bitcast_i16_bf16: +; CHECKIZFBFMIN: # %bb.0: +; CHECKIZFBFMIN-NEXT: fmv.x.h a0, fa0 +; CHECKIZFBFMIN-NEXT: ret + %1 = bitcast bfloat %a to i16 + ret i16 %1 +} + +define bfloat @fcvt_bf16_s(float %a) nounwind { +; CHECKIZFBFMIN-LABEL: fcvt_bf16_s: +; CHECKIZFBFMIN: # %bb.0: +; CHECKIZFBFMIN-NEXT: fcvt.bf16.s fa0, fa0 +; CHECKIZFBFMIN-NEXT: ret + %1 = fptrunc float %a to bfloat + ret bfloat %1 +} + +define float @fcvt_s_bf16(bfloat %a) nounwind { +; CHECKIZFBFMIN-LABEL: fcvt_s_bf16: +; CHECKIZFBFMIN: # %bb.0: +; CHECKIZFBFMIN-NEXT: fcvt.s.bf16 fa0, fa0 +; CHECKIZFBFMIN-NEXT: ret + %1 = fpext bfloat %a to float + ret float %1 +} + +define bfloat @fcvt_bf16_d(double %a) nounwind { +; CHECKIZFBFMIN-LABEL: fcvt_bf16_d: +; CHECKIZFBFMIN: # %bb.0: +; CHECKIZFBFMIN-NEXT: fcvt.s.d fa5, fa0 +; CHECKIZFBFMIN-NEXT: fcvt.bf16.s fa0, fa5 +; CHECKIZFBFMIN-NEXT: ret + %1 = fptrunc double %a to bfloat + ret bfloat %1 +} + +define double @fcvt_d_bf16(bfloat %a) nounwind { +; CHECKIZFBFMIN-LABEL: fcvt_d_bf16: +; CHECKIZFBFMIN: # %bb.0: +; CHECKIZFBFMIN-NEXT: fcvt.s.bf16 fa5, fa0 +; CHECKIZFBFMIN-NEXT: fcvt.d.s fa0, fa5 +; CHECKIZFBFMIN-NEXT: ret + %1 = fpext bfloat %a to double + ret double %1 +} + +define bfloat @bfloat_imm() nounwind { +; CHECKIZFBFMIN-LABEL: bfloat_imm: +; CHECKIZFBFMIN: # 
%bb.0: +; CHECKIZFBFMIN-NEXT: lui a0, %hi(.LCPI6_0) +; CHECKIZFBFMIN-NEXT: flh fa0, %lo(.LCPI6_0)(a0) +; CHECKIZFBFMIN-NEXT: ret + ret bfloat 3.0 +} + +define dso_local void @bfloat_store(ptr %a, bfloat %b) nounwind { +; CHECKIZFBFMIN-LABEL: bfloat_store: +; CHECKIZFBFMIN: # %bb.0: +; CHECKIZFBFMIN-NEXT: fsh fa0, 0(a0) +; CHECKIZFBFMIN-NEXT: addi a0, a0, 16 +; CHECKIZFBFMIN-NEXT: fsh fa0, 0(a0) +; CHECKIZFBFMIN-NEXT: ret + store bfloat %b, ptr %a + %1 = getelementptr bfloat, ptr %a, i32 8 + store bfloat %b, ptr %1 + ret void +}