Index: llvm/lib/Target/RISCV/RISCVISelLowering.h =================================================================== --- llvm/lib/Target/RISCV/RISCVISelLowering.h +++ llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -111,6 +111,9 @@ FCVT_W_RV64, FCVT_WU_RV64, + FP_ROUND_BF16, /* f32 -> bf16 rounding; selected to fcvt.bf16.s under Zfbfmin. */ + FP_EXTEND_BF16, /* bf16 -> f32 extension; selected to fcvt.s.bf16 under Zfbfmin. */ + // Rounds an FP value to its corresponding integer in the same FP format. // First operand is the value to round, the second operand is the largest // integer that can be represented exactly in the FP format. This will be Index: llvm/lib/Target/RISCV/RISCVISelLowering.cpp =================================================================== --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -116,6 +116,8 @@ if (Subtarget.hasStdExtZfhOrZfhmin()) addRegisterClass(MVT::f16, &RISCV::FPR16RegClass); + if (Subtarget.hasStdExtZfbfmin()) /* bf16 shares the 16-bit FPR class with f16. NOTE(review): FPR16 must also list bf16 among its value types (RISCVRegisterInfo.td, not part of this view) -- confirm. */ + addRegisterClass(MVT::bf16, &RISCV::FPR16RegClass); if (Subtarget.hasStdExtF()) addRegisterClass(MVT::f32, &RISCV::FPR32RegClass); if (Subtarget.hasStdExtD()) @@ -359,6 +361,15 @@ if (Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) setOperationAction(ISD::BITCAST, MVT::i16, Custom); + + if (Subtarget.hasStdExtZfbfmin()) { /* Scalar bf16 under Zfbfmin: custom-lower i16<->bf16 bitcasts and f32/f64<->bf16 conversions. */ + setOperationAction(ISD::BITCAST, MVT::i16, Custom); /* Repeats the Zfh/Zhinx setting above so Zfbfmin-only configurations also get it. */ + setOperationAction(ISD::BITCAST, MVT::bf16, Custom); + setOperationAction(ISD::FP_ROUND, MVT::bf16, Custom); + setOperationAction(ISD::FP_EXTEND, MVT::f32, Custom); + setOperationAction(ISD::FP_EXTEND, MVT::f64, Custom); + setOperationAction(ISD::ConstantFP, MVT::bf16, Expand); /* No bf16 immediate materialization; constants end up as constant-pool loads. */ + } if (Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) { if (Subtarget.hasStdExtZfhOrZhinx()) { @@ -4769,6 +4780,12 @@ SDValue FPConv = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::f16, NewOp0); return FPConv; } + if (VT == MVT::bf16 && Op0VT == MVT::i16 && + Subtarget.hasStdExtZfbfmin()) { /* i16 -> bf16: any-extend to XLen, then FMV_H_X into an FPR16, mirroring the f16 case above. */ + SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Op0); + SDValue FPConv = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::bf16, NewOp0); + return FPConv; + } if (VT
== MVT::f32 && Op0VT == MVT::i32 && Subtarget.is64Bit() && Subtarget.hasStdExtFOrZfinx()) { SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Op0); @@ -4932,11 +4949,42 @@ } return SDValue(); } - case ISD::FP_EXTEND: - case ISD::FP_ROUND: + case ISD::FP_EXTEND: { /* Scalar bf16 sources go through FP_EXTEND_BF16 (bf16 -> f32); other cases keep the previous handling. */ + SDLoc DL(Op); + EVT VT = Op.getValueType(); + SDValue Op0 = Op.getOperand(0); + EVT Op0VT = Op0.getValueType(); + if (VT == MVT::f32 && Op0VT == MVT::bf16 && Subtarget.hasStdExtZfbfmin()) + return DAG.getNode(RISCVISD::FP_EXTEND_BF16, DL, MVT::f32, Op0); + if (VT == MVT::f64 && Op0VT == MVT::bf16 && Subtarget.hasStdExtZfbfmin()) { /* No direct bf16 -> f64 node: extend to f32, then f32 -> f64. NOTE(review): unlike the f64 FP_ROUND path below, this does not check hasStdExtDOrZdinx(); presumably an f64 result already implies D -- confirm. */ + SDValue FloatVal = + DAG.getNode(RISCVISD::FP_EXTEND_BF16, DL, MVT::f32, Op0); + return DAG.getNode(ISD::FP_EXTEND, DL, MVT::f64, FloatVal); + } + + if (!Op.getValueType().isVector()) + return Op; + return lowerVectorFPExtendOrRoundLike(Op, DAG); + } + case ISD::FP_ROUND: { /* Scalar bf16 results go through FP_ROUND_BF16 (f32 -> bf16); other cases keep the previous handling. */ + SDLoc DL(Op); + EVT VT = Op.getValueType(); + SDValue Op0 = Op.getOperand(0); + EVT Op0VT = Op0.getValueType(); + if (VT == MVT::bf16 && Op0VT == MVT::f32 && Subtarget.hasStdExtZfbfmin()) + return DAG.getNode(RISCVISD::FP_ROUND_BF16, DL, MVT::bf16, Op0); + if (VT == MVT::bf16 && Op0VT == MVT::f64 && Subtarget.hasStdExtZfbfmin() && + Subtarget.hasStdExtDOrZdinx()) { /* f64 -> f32 -> bf16. NOTE(review): rounding twice can differ from a single correctly-rounded f64 -> bf16 conversion for some inputs -- confirm this is acceptable. */ + SDValue FloatVal = + DAG.getNode(ISD::FP_ROUND, DL, MVT::f32, Op0, + DAG.getIntPtrConstant(0, DL, /*isTarget=*/true)); + return DAG.getNode(RISCVISD::FP_ROUND_BF16, DL, MVT::bf16, FloatVal); + } + if (!Op.getValueType().isVector()) return Op; return lowerVectorFPExtendOrRoundLike(Op, DAG); + } case ISD::STRICT_FP_ROUND: case ISD::STRICT_FP_EXTEND: return lowerStrictFPExtendOrRoundLike(Op, DAG); @@ -9927,6 +9975,10 @@ Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) { SDValue FPConv = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Op0); Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FPConv)); + } else if (VT == MVT::i16 && Op0VT == MVT::bf16 && + Subtarget.hasStdExtZfbfmin()) { /* bf16 -> i16: FMV_X_ANYEXTH into XLen, then truncate, as in the f16 case. */ + SDValue FPConv =
DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Op0); + Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FPConv)); } else if (VT == MVT::i32 && Op0VT == MVT::f32 && Subtarget.is64Bit() && Subtarget.hasStdExtFOrZfinx()) { SDValue FPConv = @@ -14868,7 +14920,8 @@ // similar local variables rather than directly checking against the target // ABI. - if (UseGPRForF16_F32 && (ValVT == MVT::f16 || ValVT == MVT::f32)) { + if (UseGPRForF16_F32 && + (ValVT == MVT::f16 || ValVT == MVT::bf16 || ValVT == MVT::f32)) { /* bf16 follows the f16 convention: passed in an XLen GPR via BCvt when UseGPRForF16_F32. */ LocVT = XLenVT; LocInfo = CCValAssign::BCvt; } else if (UseGPRForF64 && XLen == 64 && ValVT == MVT::f64) { @@ -14961,7 +15014,7 @@ unsigned StoreSizeBytes = XLen / 8; Align StackAlign = Align(XLen / 8); - if (ValVT == MVT::f16 && !UseGPRForF16_F32) + if ((ValVT == MVT::f16 || ValVT == MVT::bf16) && !UseGPRForF16_F32) /* Otherwise bf16 takes an FPR16 argument register, like f16. */ Reg = State.AllocateReg(ArgFPR16s); else if (ValVT == MVT::f32 && !UseGPRForF16_F32) Reg = State.AllocateReg(ArgFPR32s); @@ -15118,8 +15171,9 @@ Val = convertFromScalableVector(VA.getValVT(), Val, DAG, Subtarget); break; case CCValAssign::BCvt: - if (VA.getLocVT().isInteger() && VA.getValVT() == MVT::f16) - Val = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::f16, Val); + if (VA.getLocVT().isInteger() && + (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16)) + Val = DAG.getNode(RISCVISD::FMV_H_X, DL, VA.getValVT(), Val); /* f16 and bf16 share FMV_H_X; the result type comes from the value VT. */ else if (VA.getLocVT() == MVT::i64 && VA.getValVT() == MVT::f32) Val = DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32, Val); else @@ -15177,7 +15231,8 @@ Val = convertToScalableVector(LocVT, Val, DAG, Subtarget); break; case CCValAssign::BCvt: - if (VA.getLocVT().isInteger() && VA.getValVT() == MVT::f16) + if (VA.getLocVT().isInteger() && + (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16)) Val = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, VA.getLocVT(), Val); else if (VA.getLocVT() == MVT::i64 && VA.getValVT() == MVT::f32) Val = DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Val); @@ -16197,6 +16252,8 @@
NODE_NAME_CASE(FCVT_WU_RV64) NODE_NAME_CASE(STRICT_FCVT_W_RV64) NODE_NAME_CASE(STRICT_FCVT_WU_RV64) + NODE_NAME_CASE(FP_ROUND_BF16) + NODE_NAME_CASE(FP_EXTEND_BF16) NODE_NAME_CASE(FROUND) NODE_NAME_CASE(FPCLASS) NODE_NAME_CASE(READ_CYCLE_WIDE) Index: llvm/lib/Target/RISCV/RISCVInstrInfo.td =================================================================== --- llvm/lib/Target/RISCV/RISCVInstrInfo.td +++ llvm/lib/Target/RISCV/RISCVInstrInfo.td @@ -1941,8 +1941,8 @@ include "RISCVInstrInfoZvfbf.td" include "RISCVInstrInfoZvk.td" include "RISCVInstrInfoZfa.td" -include "RISCVInstrInfoZfbfmin.td" include "RISCVInstrInfoZfh.td" +include "RISCVInstrInfoZfbfmin.td" /* Now after Zfh.td, presumably so Zfh instruction records used by the bf16 patterns (e.g. FMV_H_X, FMV_X_H) are defined first -- confirm. */ include "RISCVInstrInfoZicbo.td" include "RISCVInstrInfoZicond.td" Index: llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td =================================================================== --- llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td +++ llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td @@ -13,6 +13,30 @@ // //===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// RISC-V specific DAG Nodes.
+//===----------------------------------------------------------------------===// + +def SDT_RISCVFP_ROUND_BF16 + : SDTypeProfile<1, 1, [SDTCisVT<0, bf16>, SDTCisVT<1, f32>]>; /* bf16 result, f32 operand. */ +def SDT_RISCVFP_EXTEND_BF16 + : SDTypeProfile<1, 1, [SDTCisVT<0, f32>, SDTCisVT<1, bf16>]>; /* f32 result, bf16 operand. */ +def SDT_RISCVFMV_H_X_BF16 + : SDTypeProfile<1, 1, [SDTCisVT<0, bf16>, SDTCisVT<1, XLenVT>]>; +def SDT_RISCVFMV_X_EXTH_BF16 + : SDTypeProfile<1, 1, [SDTCisVT<0, XLenVT>, SDTCisVT<1, bf16>]>; + +def riscv_fpround_bf16 + : SDNode<"RISCVISD::FP_ROUND_BF16", SDT_RISCVFP_ROUND_BF16>; +def riscv_fpextend_bf16 + : SDNode<"RISCVISD::FP_EXTEND_BF16", SDT_RISCVFP_EXTEND_BF16>; +def riscv_fmv_h_x_bf16 + : SDNode<"RISCVISD::FMV_H_X", SDT_RISCVFMV_H_X_BF16>; /* Same RISCVISD move opcodes as the f16 nodes, retyped for bf16. */ +def riscv_fmv_x_anyexth_bf16 + : SDNode<"RISCVISD::FMV_X_ANYEXTH", SDT_RISCVFMV_X_EXTH_BF16>; +def riscv_fmv_x_signexth_bf16 + : SDNode<"RISCVISD::FMV_X_SIGNEXTH", SDT_RISCVFMV_X_EXTH_BF16>; + //===----------------------------------------------------------------------===// // Instructions //===----------------------------------------------------------------------===// @@ -23,3 +47,27 @@ def FCVT_S_BF16 : FPUnaryOp_r_frm<0b0100000, 0b00110, FPR32, FPR16, "fcvt.s.bf16">, Sched<[WriteFCvtF32ToF16, ReadFCvtF32ToF16]>; } // Predicates = [HasStdExtZfbfmin] + +//===----------------------------------------------------------------------===// +// Pseudo-instructions and codegen patterns +//===----------------------------------------------------------------------===// + +let Predicates = [HasStdExtZfbfmin] in { /* NOTE(review): FLH/FSH/FMV_H_X/FMV_X_H are presumably defined in RISCVInstrInfoZfh.td with Zfh/Zfhmin predicates; confirm they are usable (incl. by the assembler) with Zfbfmin alone. */ +/// Loads: bf16 is loaded into an FPR16 with the 16-bit FLH. +def : LdPat<load, FLH, bf16>; + +/// Stores: bf16 is stored from an FPR16 with the 16-bit FSH. +def : StPat<store, FSH, FPR16, bf16>; + +/// Float conversion operations +// f32 -> bf16, bf16 -> f32 (dynamic rounding mode) +def : Pat<(bf16 (riscv_fpround_bf16 FPR32:$rs1)), + (FCVT_BF16_S FPR32:$rs1, FRM_DYN)>; +def : Pat<(riscv_fpextend_bf16 (bf16 FPR16:$rs1)), + (FCVT_S_BF16 FPR16:$rs1, FRM_DYN)>; + +// Moves (no conversion): raw bit copies between a GPR and an FPR16 +def : Pat<(riscv_fmv_h_x_bf16 GPR:$src), (FMV_H_X GPR:$src)>; +def : Pat<(riscv_fmv_x_anyexth_bf16 (bf16 FPR16:$src)), (FMV_X_H FPR16:$src)>;
+def : Pat<(riscv_fmv_x_signexth_bf16 (bf16 FPR16:$src)), (FMV_X_H FPR16:$src)>; /* The same FMV_X_H instruction serves both the any-extending and the sign-extending move nodes. */ +} // Predicates = [HasStdExtZfbfmin] Index: llvm/test/CodeGen/RISCV/zfbfmin.ll =================================================================== --- /dev/null +++ llvm/test/CodeGen/RISCV/zfbfmin.ll @@ -0,0 +1,92 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+experimental-zfbfmin -verify-machineinstrs \ +; RUN: -target-abi ilp32d < %s | FileCheck -check-prefix=CHECKIZFBFMIN %s +; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+experimental-zfbfmin -verify-machineinstrs \ +; RUN: -target-abi lp64d < %s | FileCheck -check-prefix=CHECKIZFBFMIN %s + +define bfloat @bitcast_bf16_i16(i16 %a) nounwind { ; NOTE(review): both RUN lines also enable +zfh; consider a Zfbfmin-only RUN line so fmv.h.x, fmv.x.h, flh and fsh are exercised without Zfh. +; CHECKIZFBFMIN-LABEL: bitcast_bf16_i16: +; CHECKIZFBFMIN: # %bb.0: +; CHECKIZFBFMIN-NEXT: fmv.h.x fa0, a0 +; CHECKIZFBFMIN-NEXT: ret + %1 = bitcast i16 %a to bfloat + ret bfloat %1 +} + +define i16 @bitcast_i16_bf16(bfloat %a) nounwind { +; CHECKIZFBFMIN-LABEL: bitcast_i16_bf16: +; CHECKIZFBFMIN: # %bb.0: +; CHECKIZFBFMIN-NEXT: fmv.x.h a0, fa0 +; CHECKIZFBFMIN-NEXT: ret + %1 = bitcast bfloat %a to i16 + ret i16 %1 +} + +define bfloat @fcvt_bf16_s(float %a) nounwind { +; CHECKIZFBFMIN-LABEL: fcvt_bf16_s: +; CHECKIZFBFMIN: # %bb.0: +; CHECKIZFBFMIN-NEXT: fcvt.bf16.s fa0, fa0 +; CHECKIZFBFMIN-NEXT: ret + %1 = fptrunc float %a to bfloat + ret bfloat %1 +} + +define float @fcvt_s_bf16(bfloat %a) nounwind { +; CHECKIZFBFMIN-LABEL: fcvt_s_bf16: +; CHECKIZFBFMIN: # %bb.0: +; CHECKIZFBFMIN-NEXT: fcvt.s.bf16 fa0, fa0 +; CHECKIZFBFMIN-NEXT: ret + %1 = fpext bfloat %a to float + ret float %1 +} + +define bfloat @fcvt_bf16_d(double %a) nounwind { ; f64 -> bf16 rounds to f32 first (fcvt.s.d), then to bf16 (fcvt.bf16.s). +; CHECKIZFBFMIN-LABEL: fcvt_bf16_d: +; CHECKIZFBFMIN: # %bb.0: +; CHECKIZFBFMIN-NEXT: fcvt.s.d fa5, fa0 +; CHECKIZFBFMIN-NEXT: fcvt.bf16.s fa0, fa5 +; CHECKIZFBFMIN-NEXT: ret + %1 = fptrunc double %a to bfloat + ret bfloat %1 +} + +define double @fcvt_d_bf16(bfloat %a) nounwind { +;
CHECKIZFBFMIN-LABEL: fcvt_d_bf16: +; CHECKIZFBFMIN: # %bb.0: +; CHECKIZFBFMIN-NEXT: fcvt.s.bf16 fa5, fa0 +; CHECKIZFBFMIN-NEXT: fcvt.d.s fa0, fa5 +; CHECKIZFBFMIN-NEXT: ret + %1 = fpext bfloat %a to double + ret double %1 +} + +define bfloat @bfloat_load(ptr %a) nounwind { +; CHECKIZFBFMIN-LABEL: bfloat_load: +; CHECKIZFBFMIN: # %bb.0: +; CHECKIZFBFMIN-NEXT: flh fa0, 6(a0) +; CHECKIZFBFMIN-NEXT: ret + %1 = getelementptr bfloat, ptr %a, i32 3 + %2 = load bfloat, ptr %1 + ret bfloat %2 +} + +define bfloat @bfloat_imm() nounwind { ; bf16 FP immediates are loaded from the constant pool. +; CHECKIZFBFMIN-LABEL: bfloat_imm: +; CHECKIZFBFMIN: # %bb.0: +; CHECKIZFBFMIN-NEXT: lui a0, %hi(.LCPI7_0) +; CHECKIZFBFMIN-NEXT: flh fa0, %lo(.LCPI7_0)(a0) +; CHECKIZFBFMIN-NEXT: ret + ret bfloat 3.0 +} + +define dso_local void @bfloat_store(ptr %a, bfloat %b) nounwind { +; CHECKIZFBFMIN-LABEL: bfloat_store: +; CHECKIZFBFMIN: # %bb.0: +; CHECKIZFBFMIN-NEXT: fsh fa0, 0(a0) +; CHECKIZFBFMIN-NEXT: fsh fa0, 16(a0) +; CHECKIZFBFMIN-NEXT: ret + store bfloat %b, ptr %a + %1 = getelementptr bfloat, ptr %a, i32 8 + store bfloat %b, ptr %1 + ret void +}