Index: llvm/lib/Target/RISCV/RISCVFeatures.td
===================================================================
--- llvm/lib/Target/RISCV/RISCVFeatures.td
+++ llvm/lib/Target/RISCV/RISCVFeatures.td
@@ -664,6 +664,14 @@
                            "'Zfbfmin' (Scalar BF16 Converts) or "
                            "'Zvfbfwma' (Vector BF16 widening mul-add)">;
 
+def HasScalarHalfFPLoadStoreMove
+    : Predicate<"Subtarget->hasScalarHalfFPLoadStoreMove()">,
+      AssemblerPredicate<(any_of FeatureStdExtZfh, FeatureStdExtZfhmin,
+                                 FeatureStdExtZfbfmin),
+                         "'Zfh' (Half-Precision Floating-Point) or "
+                         "'Zfhmin' (Half-Precision Floating-Point Minimal) or "
+                         "'Zfbfmin' (Scalar BF16 Converts)">;
+
 def FeatureStdExtZacas
     : SubtargetFeature<"experimental-zacas", "HasStdExtZacas", "true",
                        "'Zacas' (Atomic Compare-And-Swap Instructions)">;
Index: llvm/lib/Target/RISCV/RISCVISelLowering.h
===================================================================
--- llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -111,6 +111,9 @@
   FCVT_W_RV64,
   FCVT_WU_RV64,
 
+  FP_ROUND_BF16,
+  FP_EXTEND_BF16,
+
   // Rounds an FP value to its corresponding integer in the same FP format.
   // First operand is the value to round, the second operand is the largest
   // integer that can be represented exactly in the FP format. This will be
Index: llvm/lib/Target/RISCV/RISCVISelLowering.cpp
===================================================================
--- llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -116,6 +116,8 @@
   if (Subtarget.hasStdExtZfhOrZfhmin())
     addRegisterClass(MVT::f16, &RISCV::FPR16RegClass);
+  if (Subtarget.hasStdExtZfbfmin())
+    addRegisterClass(MVT::bf16, &RISCV::FPR16RegClass);
   if (Subtarget.hasStdExtF())
     addRegisterClass(MVT::f32, &RISCV::FPR32RegClass);
   if (Subtarget.hasStdExtD())
@@ -360,6 +362,15 @@
   if (Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin())
     setOperationAction(ISD::BITCAST, MVT::i16, Custom);
+
+  if (Subtarget.hasStdExtZfbfmin()) {
+    setOperationAction(ISD::BITCAST, MVT::i16, Custom);
+    setOperationAction(ISD::BITCAST, MVT::bf16, Custom);
+    setOperationAction(ISD::FP_ROUND, MVT::bf16, Custom);
+    setOperationAction(ISD::FP_EXTEND, MVT::f32, Custom);
+    setOperationAction(ISD::FP_EXTEND, MVT::f64, Custom);
+    setOperationAction(ISD::ConstantFP, MVT::bf16, Expand);
+  }
 
   if (Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) {
     if (Subtarget.hasStdExtZfhOrZhinx()) {
@@ -4734,6 +4745,12 @@
       SDValue FPConv = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::f16, NewOp0);
       return FPConv;
     }
+    if (VT == MVT::bf16 && Op0VT == MVT::i16 &&
+        Subtarget.hasStdExtZfbfmin()) {
+      SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Op0);
+      SDValue FPConv = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::bf16, NewOp0);
+      return FPConv;
+    }
     if (VT == MVT::f32 && Op0VT == MVT::i32 && Subtarget.is64Bit() &&
         Subtarget.hasStdExtFOrZfinx()) {
       SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Op0);
@@ -4897,11 +4914,42 @@
     }
     return SDValue();
   }
-  case ISD::FP_EXTEND:
-  case ISD::FP_ROUND:
+  case ISD::FP_EXTEND: {
+    SDLoc DL(Op);
+    EVT VT = Op.getValueType();
+    SDValue Op0 = Op.getOperand(0);
+    EVT Op0VT = Op0.getValueType();
+    if (VT == MVT::f32 && Op0VT == MVT::bf16 && Subtarget.hasStdExtZfbfmin())
+      return DAG.getNode(RISCVISD::FP_EXTEND_BF16, DL, MVT::f32, Op0);
+    if (VT == MVT::f64 && Op0VT == MVT::bf16 && Subtarget.hasStdExtZfbfmin()) {
+      SDValue FloatVal =
+          DAG.getNode(RISCVISD::FP_EXTEND_BF16, DL, MVT::f32, Op0);
+      return DAG.getNode(ISD::FP_EXTEND, DL, MVT::f64, FloatVal);
+    }
+
+    if (!Op.getValueType().isVector())
+      return Op;
+    return lowerVectorFPExtendOrRoundLike(Op, DAG);
+  }
+  case ISD::FP_ROUND: {
+    SDLoc DL(Op);
+    EVT VT = Op.getValueType();
+    SDValue Op0 = Op.getOperand(0);
+    EVT Op0VT = Op0.getValueType();
+    if (VT == MVT::bf16 && Op0VT == MVT::f32 && Subtarget.hasStdExtZfbfmin())
+      return DAG.getNode(RISCVISD::FP_ROUND_BF16, DL, MVT::bf16, Op0);
+    if (VT == MVT::bf16 && Op0VT == MVT::f64 && Subtarget.hasStdExtZfbfmin() &&
+        Subtarget.hasStdExtDOrZdinx()) {
+      SDValue FloatVal =
+          DAG.getNode(ISD::FP_ROUND, DL, MVT::f32, Op0,
+                      DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
+      return DAG.getNode(RISCVISD::FP_ROUND_BF16, DL, MVT::bf16, FloatVal);
+    }
+
     if (!Op.getValueType().isVector())
       return Op;
     return lowerVectorFPExtendOrRoundLike(Op, DAG);
+  }
   case ISD::STRICT_FP_ROUND:
   case ISD::STRICT_FP_EXTEND:
     return lowerStrictFPExtendOrRoundLike(Op, DAG);
@@ -9887,6 +9935,10 @@
                Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) {
       SDValue FPConv = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Op0);
       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FPConv));
+    } else if (VT == MVT::i16 && Op0VT == MVT::bf16 &&
+               Subtarget.hasStdExtZfbfmin()) {
+      SDValue FPConv = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Op0);
+      Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FPConv));
     } else if (VT == MVT::i32 && Op0VT == MVT::f32 && Subtarget.is64Bit() &&
                Subtarget.hasStdExtFOrZfinx()) {
       SDValue FPConv =
@@ -14732,7 +14784,8 @@
   // similar local variables rather than directly checking against the target
   // ABI.
 
-  if (UseGPRForF16_F32 && (ValVT == MVT::f16 || ValVT == MVT::f32)) {
+  if (UseGPRForF16_F32 &&
+      (ValVT == MVT::f16 || ValVT == MVT::bf16 || ValVT == MVT::f32)) {
     LocVT = XLenVT;
     LocInfo = CCValAssign::BCvt;
   } else if (UseGPRForF64 && XLen == 64 && ValVT == MVT::f64) {
@@ -14825,7 +14878,7 @@
   unsigned StoreSizeBytes = XLen / 8;
   Align StackAlign = Align(XLen / 8);
 
-  if (ValVT == MVT::f16 && !UseGPRForF16_F32)
+  if ((ValVT == MVT::f16 || ValVT == MVT::bf16) && !UseGPRForF16_F32)
     Reg = State.AllocateReg(ArgFPR16s);
   else if (ValVT == MVT::f32 && !UseGPRForF16_F32)
     Reg = State.AllocateReg(ArgFPR32s);
@@ -14982,8 +15035,9 @@
     Val = convertFromScalableVector(VA.getValVT(), Val, DAG, Subtarget);
     break;
   case CCValAssign::BCvt:
-    if (VA.getLocVT().isInteger() && VA.getValVT() == MVT::f16)
-      Val = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::f16, Val);
+    if (VA.getLocVT().isInteger() &&
+        (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16))
+      Val = DAG.getNode(RISCVISD::FMV_H_X, DL, VA.getValVT(), Val);
     else if (VA.getLocVT() == MVT::i64 && VA.getValVT() == MVT::f32)
       Val = DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32, Val);
     else
@@ -15041,7 +15095,8 @@
     Val = convertToScalableVector(LocVT, Val, DAG, Subtarget);
     break;
   case CCValAssign::BCvt:
-    if (VA.getLocVT().isInteger() && VA.getValVT() == MVT::f16)
+    if (VA.getLocVT().isInteger() &&
+        (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16))
       Val = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, VA.getLocVT(), Val);
     else if (VA.getLocVT() == MVT::i64 && VA.getValVT() == MVT::f32)
       Val = DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Val);
@@ -16040,6 +16095,8 @@
   NODE_NAME_CASE(FCVT_WU_RV64)
   NODE_NAME_CASE(STRICT_FCVT_W_RV64)
   NODE_NAME_CASE(STRICT_FCVT_WU_RV64)
+  NODE_NAME_CASE(FP_ROUND_BF16)
+  NODE_NAME_CASE(FP_EXTEND_BF16)
   NODE_NAME_CASE(FROUND)
   NODE_NAME_CASE(FPCLASS)
   NODE_NAME_CASE(READ_CYCLE_WIDE)
Index: llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td
===================================================================
--- llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td
+++ llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td
@@ -13,6 +13,20 @@
 //
 //===----------------------------------------------------------------------===//
 
+//===----------------------------------------------------------------------===//
+// RISC-V specific DAG Nodes.
+//===----------------------------------------------------------------------===//
+
+def SDT_RISCVFP_ROUND_BF16
+    : SDTypeProfile<1, 1, [SDTCisVT<0, bf16>, SDTCisVT<1, f32>]>;
+def SDT_RISCVFP_EXTEND_BF16
+    : SDTypeProfile<1, 1, [SDTCisVT<0, f32>, SDTCisVT<1, bf16>]>;
+
+def riscv_fpround_bf16
+    : SDNode<"RISCVISD::FP_ROUND_BF16", SDT_RISCVFP_ROUND_BF16>;
+def riscv_fpextend_bf16
+    : SDNode<"RISCVISD::FP_EXTEND_BF16", SDT_RISCVFP_EXTEND_BF16>;
+
 //===----------------------------------------------------------------------===//
 // Instructions
 //===----------------------------------------------------------------------===//
@@ -23,3 +37,14 @@
 def FCVT_S_BF16 : FPUnaryOp_r_frm<0b0100000, 0b00110, FPR32, FPR16, "fcvt.s.bf16">,
                   Sched<[WriteFCvtF32ToF16, ReadFCvtF32ToF16]>;
 } // Predicates = [HasStdExtZfbfmin]
+
+//===----------------------------------------------------------------------===//
+// Pseudo-instructions and codegen patterns
+//===----------------------------------------------------------------------===//
+
+let Predicates = [HasStdExtZfbfmin] in {
+def : Pat<(bf16 (riscv_fpround_bf16 FPR32:$rs1)),
+          (FCVT_BF16_S FPR32:$rs1, FRM_DYN)>;
+def : Pat<(riscv_fpextend_bf16 (bf16 FPR16:$rs1)),
+          (FCVT_S_BF16 FPR16:$rs1, FRM_DYN)>;
+} // Predicates = [HasStdExtZfbfmin]
Index: llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td
===================================================================
--- llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td
+++ llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td
@@ -421,6 +421,14 @@
 def : StPat<store, FSH, FPR16, f16>;
 } // Predicates = [HasStdExtZfhOrZfhmin]
 
+let Predicates = [HasStdExtZfbfmin] in {
+/// Loads
+def : LdPat<load, FLH, bf16>;
+
+/// Stores
+def : StPat<store, FSH, FPR16, bf16>;
+} // Predicates = [HasStdExtZfbfmin]
+
 let Predicates = [HasStdExtZhinxOrZhinxmin] in {
 /// Loads
 def : Pat<(f16 (load GPR:$rs1)), (COPY_TO_REGCLASS (LH GPR:$rs1, 0), GPRF16)>;
@@ -437,13 +445,15 @@
 def : Pat<(f16 (any_fpround FPR32:$rs1)), (FCVT_H_S FPR32:$rs1, FRM_DYN)>;
 def : Pat<(any_fpextend (f16 FPR16:$rs1)), (FCVT_S_H FPR16:$rs1)>;
 
+def : Pat<(fcopysign FPR32:$rs1, (f16 FPR16:$rs2)), (FSGNJ_S $rs1, (FCVT_S_H $rs2))>;
+} // Predicates = [HasStdExtZfhOrZfhmin]
+
+let Predicates = [HasScalarHalfFPLoadStoreMove] in {
 // Moves (no conversion)
 def : Pat<(riscv_fmv_h_x GPR:$src), (FMV_H_X GPR:$src)>;
 def : Pat<(riscv_fmv_x_anyexth (f16 FPR16:$src)), (FMV_X_H FPR16:$src)>;
 def : Pat<(riscv_fmv_x_signexth (f16 FPR16:$src)), (FMV_X_H FPR16:$src)>;
-
-def : Pat<(fcopysign FPR32:$rs1, (f16 FPR16:$rs2)), (FSGNJ_S $rs1, (FCVT_S_H $rs2))>;
-} // Predicates = [HasStdExtZfhOrZfhmin]
+} // Predicates = [HasScalarHalfFPLoadStoreMove]
 
 let Predicates = [HasStdExtZhinxOrZhinxmin] in {
 /// Float conversion operations
Index: llvm/lib/Target/RISCV/RISCVSubtarget.h
===================================================================
--- llvm/lib/Target/RISCV/RISCVSubtarget.h
+++ llvm/lib/Target/RISCV/RISCVSubtarget.h
@@ -127,6 +127,9 @@
     return HasStdExtZfh || HasStdExtZfhmin || HasStdExtZfbfmin ||
            HasStdExtZvfbfwma;
   }
+  bool hasScalarHalfFPLoadStoreMove() const {
+    return HasStdExtZfh || HasStdExtZfhmin || HasStdExtZfbfmin;
+  }
   bool is64Bit() const { return IsRV64; }
   MVT getXLenVT() const { return XLenVT; }
   unsigned getXLen() const { return XLen; }
Index: llvm/test/CodeGen/RISCV/zfbfmin.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/RISCV/zfbfmin.ll
@@ -0,0 +1,92 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+experimental-zfbfmin -verify-machineinstrs \
+; RUN:   -target-abi ilp32d < %s | FileCheck -check-prefix=CHECKIZFBFMIN %s
+; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+experimental-zfbfmin -verify-machineinstrs \
+; RUN:   -target-abi lp64d < %s | FileCheck -check-prefix=CHECKIZFBFMIN %s
+
+define bfloat @bitcast_bf16_i16(i16 %a) nounwind {
+; CHECKIZFBFMIN-LABEL: bitcast_bf16_i16:
+; CHECKIZFBFMIN:       # %bb.0:
+; CHECKIZFBFMIN-NEXT:    fmv.h.x fa0, a0
+; CHECKIZFBFMIN-NEXT:    ret
+  %1 = bitcast i16 %a to bfloat
+  ret bfloat %1
+}
+
+define i16 @bitcast_i16_bf16(bfloat %a) nounwind {
+; CHECKIZFBFMIN-LABEL: bitcast_i16_bf16:
+; CHECKIZFBFMIN:       # %bb.0:
+; CHECKIZFBFMIN-NEXT:    fmv.x.h a0, fa0
+; CHECKIZFBFMIN-NEXT:    ret
+  %1 = bitcast bfloat %a to i16
+  ret i16 %1
+}
+
+define bfloat @fcvt_bf16_s(float %a) nounwind {
+; CHECKIZFBFMIN-LABEL: fcvt_bf16_s:
+; CHECKIZFBFMIN:       # %bb.0:
+; CHECKIZFBFMIN-NEXT:    fcvt.bf16.s fa0, fa0
+; CHECKIZFBFMIN-NEXT:    ret
+  %1 = fptrunc float %a to bfloat
+  ret bfloat %1
+}
+
+define float @fcvt_s_bf16(bfloat %a) nounwind {
+; CHECKIZFBFMIN-LABEL: fcvt_s_bf16:
+; CHECKIZFBFMIN:       # %bb.0:
+; CHECKIZFBFMIN-NEXT:    fcvt.s.bf16 fa0, fa0
+; CHECKIZFBFMIN-NEXT:    ret
+  %1 = fpext bfloat %a to float
+  ret float %1
+}
+
+define bfloat @fcvt_bf16_d(double %a) nounwind {
+; CHECKIZFBFMIN-LABEL: fcvt_bf16_d:
+; CHECKIZFBFMIN:       # %bb.0:
+; CHECKIZFBFMIN-NEXT:    fcvt.s.d fa5, fa0
+; CHECKIZFBFMIN-NEXT:    fcvt.bf16.s fa0, fa5
+; CHECKIZFBFMIN-NEXT:    ret
+  %1 = fptrunc double %a to bfloat
+  ret bfloat %1
+}
+
+define double @fcvt_d_bf16(bfloat %a) nounwind {
+; CHECKIZFBFMIN-LABEL: fcvt_d_bf16:
+; CHECKIZFBFMIN:       # %bb.0:
+; CHECKIZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa0
+; CHECKIZFBFMIN-NEXT:    fcvt.d.s fa0, fa5
+; CHECKIZFBFMIN-NEXT:    ret
+  %1 = fpext bfloat %a to double
+  ret double %1
+}
+
+define bfloat @bfloat_load(ptr %a) nounwind {
+; CHECKIZFBFMIN-LABEL: bfloat_load:
+; CHECKIZFBFMIN:       # %bb.0:
+; CHECKIZFBFMIN-NEXT:    flh fa0, 6(a0)
+; CHECKIZFBFMIN-NEXT:    ret
+  %1 = getelementptr bfloat, ptr %a, i32 3
+  %2 = load bfloat, ptr %1
+  ret bfloat %2
+}
+
+define bfloat @bfloat_imm() nounwind {
+; CHECKIZFBFMIN-LABEL: bfloat_imm:
+; CHECKIZFBFMIN:       # %bb.0:
+; CHECKIZFBFMIN-NEXT:    lui a0, %hi(.LCPI7_0)
+; CHECKIZFBFMIN-NEXT:    flh fa0, %lo(.LCPI7_0)(a0)
+; CHECKIZFBFMIN-NEXT:    ret
+  ret bfloat 3.0
+}
+
+define dso_local void @bfloat_store(ptr %a, bfloat %b) nounwind {
+; CHECKIZFBFMIN-LABEL: bfloat_store:
+; CHECKIZFBFMIN:       # %bb.0:
+; CHECKIZFBFMIN-NEXT:    fsh fa0, 0(a0)
+; CHECKIZFBFMIN-NEXT:    fsh fa0, 16(a0)
+; CHECKIZFBFMIN-NEXT:    ret
+  store bfloat %b, ptr %a
+  %1 = getelementptr bfloat, ptr %a, i32 8
+  store bfloat %b, ptr %1
+  ret void
+}