Index: llvm/lib/Target/RISCV/RISCVISelLowering.h
===================================================================
--- llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -829,6 +829,8 @@
                              unsigned ExtendOpc) const;
   SDValue lowerGET_ROUNDING(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerSET_ROUNDING(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerFP_TO_BF16(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerBF16_TO_FP(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerEH_DWARF_CFA(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerCTLZ_CTTZ_ZERO_UNDEF(SDValue Op, SelectionDAG &DAG) const;
Index: llvm/lib/Target/RISCV/RISCVISelLowering.cpp
===================================================================
--- llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -426,7 +426,11 @@
     setOperationAction(ISD::BR_CC, MVT::f32, Expand);
     setOperationAction(FPOpToExpand, MVT::f32, Expand);
     setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
+    setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand);
     setTruncStoreAction(MVT::f32, MVT::f16, Expand);
+    setTruncStoreAction(MVT::f32, MVT::bf16, Expand);
+    setOperationAction(ISD::FP_TO_BF16, MVT::f32, Custom);
+    setOperationAction(ISD::BF16_TO_FP, MVT::f32, Custom);
     setOperationAction(ISD::IS_FPCLASS, MVT::f32, Custom);
 
     if (Subtarget.hasStdExtZfa())
@@ -461,6 +465,10 @@
     setOperationAction(FPOpToExpand, MVT::f64, Expand);
     setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
     setTruncStoreAction(MVT::f64, MVT::f16, Expand);
+    setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand);
+    setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
+    setOperationAction(ISD::FP_TO_BF16, MVT::f64, Custom);
+    setOperationAction(ISD::BF16_TO_FP, MVT::f64, Expand);
     setOperationAction(ISD::IS_FPCLASS, MVT::f64, Custom);
   }
 
@@ -1154,6 +1162,7 @@
 
   setLibcallName(RTLIB::FPEXT_F16_F32, "__extendhfsf2");
   setLibcallName(RTLIB::FPROUND_F32_F16, "__truncsfhf2");
+  setLibcallName(RTLIB::FPROUND_F32_BF16, "__truncsfbf2");
 
   // Disable strict node mutation.
   IsStrictFPEnabled = true;
@@ -2424,6 +2433,33 @@
   return Res;
 }
 
+SDValue RISCVTargetLowering::lowerFP_TO_BF16(SDValue Op,
+                                             SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  MakeLibCallOptions CallOptions;
+  RTLIB::Libcall LC =
+      RTLIB::getFPROUND(Op.getOperand(0).getValueType(), MVT::bf16);
+  SDValue Res =
+      makeLibCall(DAG, LC, MVT::f32, Op.getOperand(0), CallOptions, DL).first;
+  if (Subtarget.is64Bit())
+    return DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Res);
+  return DAG.getBitcast(MVT::i32, Res);
+}
+
+SDValue RISCVTargetLowering::lowerBF16_TO_FP(SDValue Op,
+                                             SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  MakeLibCallOptions CallOptions;
+  SDValue Arg = Subtarget.is64Bit()
+                    ? DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32,
+                                  Op.getOperand(0))
+                    : DAG.getBitcast(MVT::f32, Op.getOperand(0));
+  SDValue Res =
+      makeLibCall(DAG, RTLIB::FPEXT_F16_F32, MVT::f32, Arg, CallOptions, DL)
+          .first;
+  return Res;
+}
+
 static RISCVFPRndMode::RoundingMode matchRoundingOp(unsigned Opc) {
   switch (Opc) {
   case ISD::FROUNDEVEN:
@@ -4751,6 +4787,10 @@
   case ISD::FP_TO_SINT_SAT:
   case ISD::FP_TO_UINT_SAT:
     return lowerFP_TO_INT_SAT(Op, DAG, Subtarget);
+  case ISD::FP_TO_BF16:
+    return lowerFP_TO_BF16(Op, DAG);
+  case ISD::BF16_TO_FP:
+    return lowerBF16_TO_FP(Op, DAG);
   case ISD::FTRUNC:
   case ISD::FCEIL:
   case ISD::FFLOOR:
@@ -15980,8 +16020,9 @@
     unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
   bool IsABIRegCopy = CC.has_value();
   EVT ValueVT = Val.getValueType();
-  if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
-    // Cast the f16 to i16, extend to i32, pad with ones to make a float nan,
+  if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) &&
+      PartVT == MVT::f32) {
+    // Cast the f16/bf16 to i16, extend to i32, pad with ones to make a float nan,
     // and cast to f32.
     Val = DAG.getNode(ISD::BITCAST, DL, MVT::i16, Val);
     Val = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Val);
@@ -16033,13 +16074,14 @@
     SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
     unsigned NumParts, MVT PartVT, EVT ValueVT,
     std::optional<CallingConv::ID> CC) const {
   bool IsABIRegCopy = CC.has_value();
-  if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
+  if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) &&
+      PartVT == MVT::f32) {
     SDValue Val = Parts[0];
 
-    // Cast the f32 to i32, truncate to i16, and cast back to f16.
+    // Cast the f32 to i32, truncate to i16, and cast back to f16/bf16.
     Val = DAG.getNode(ISD::BITCAST, DL, MVT::i32, Val);
     Val = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Val);
-    Val = DAG.getNode(ISD::BITCAST, DL, MVT::f16, Val);
+    Val = DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
     return Val;
   }
Index: llvm/test/CodeGen/RISCV/bfloat.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/RISCV/bfloat.ll
@@ -0,0 +1,276 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=riscv64 -mattr=+f,+d -verify-machineinstrs| FileCheck %s --check-prefixes=RV64
+; RUN: llc < %s -mtriple=riscv32 -mattr=+f,+d -verify-machineinstrs| FileCheck %s --check-prefixes=RV32
+
+define void @add(ptr %pa, ptr %pb, ptr %pc) nounwind {
+; RV64-LABEL: add:
+; RV64:       # %bb.0:
+; RV64-NEXT:    addi sp, sp, -32
+; RV64-NEXT:    sd ra, 24(sp) # 8-byte Folded Spill
+; RV64-NEXT:    sd s0, 16(sp) # 8-byte Folded Spill
+; RV64-NEXT:    sd s1, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT:    fsd fs0, 0(sp) # 8-byte Folded Spill
+; RV64-NEXT:    lhu s1, 0(a0)
+; RV64-NEXT:    lhu a0, 0(a1)
+; RV64-NEXT:    mv s0, a2
+; RV64-NEXT:    fmv.w.x fa0, a0
+; RV64-NEXT:    call __extendhfsf2@plt
+; RV64-NEXT:    fmv.s fs0, fa0
+; RV64-NEXT:    fmv.w.x fa0, s1
+; RV64-NEXT:    call __extendhfsf2@plt
+; RV64-NEXT:    fadd.s fa0, fa0, fs0
+; RV64-NEXT:    call __truncsfbf2@plt
+; RV64-NEXT:    fmv.x.w a0, fa0
+; RV64-NEXT:    sh a0, 0(s0)
+; RV64-NEXT:    ld ra, 24(sp) # 8-byte Folded Reload
+; RV64-NEXT:    ld s0, 16(sp) # 8-byte Folded Reload
+; RV64-NEXT:    ld s1, 8(sp) # 8-byte Folded Reload
+; RV64-NEXT:    fld fs0, 0(sp) # 8-byte Folded Reload
+; RV64-NEXT:    addi sp, sp, 32
+; RV64-NEXT:    ret
+;
+; RV32-LABEL: add:
+; RV32:       # %bb.0:
+; RV32-NEXT:    addi sp, sp, -32
+; RV32-NEXT:    sw ra, 28(sp) # 4-byte Folded Spill
+; RV32-NEXT:    sw s0, 24(sp) # 4-byte Folded Spill
+; RV32-NEXT:    fsd fs0, 16(sp) # 8-byte Folded Spill
+; RV32-NEXT:    fsd fs1, 8(sp) # 8-byte Folded Spill
+; RV32-NEXT:    lhu a0, 0(a0)
+; RV32-NEXT:    lhu a1, 0(a1)
+; RV32-NEXT:    mv s0, a2
+; RV32-NEXT:    fmv.w.x fs0, a0
+; RV32-NEXT:    fmv.w.x fa0, a1
+; RV32-NEXT:    call __extendhfsf2@plt
+; RV32-NEXT:    fmv.s fs1, fa0
+; RV32-NEXT:    fmv.s fa0, fs0
+; RV32-NEXT:    call __extendhfsf2@plt
+; RV32-NEXT:    fadd.s fa0, fa0, fs1
+; RV32-NEXT:    call __truncsfbf2@plt
+; RV32-NEXT:    fmv.x.w a0, fa0
+; RV32-NEXT:    sh a0, 0(s0)
+; RV32-NEXT:    lw ra, 28(sp) # 4-byte Folded Reload
+; RV32-NEXT:    lw s0, 24(sp) # 4-byte Folded Reload
+; RV32-NEXT:    fld fs0, 16(sp) # 8-byte Folded Reload
+; RV32-NEXT:    fld fs1, 8(sp) # 8-byte Folded Reload
+; RV32-NEXT:    addi sp, sp, 32
+; RV32-NEXT:    ret
+  %a = load bfloat, ptr %pa
+  %b = load bfloat, ptr %pb
+  %add = fadd bfloat %a, %b
+  store bfloat %add, ptr %pc
+  ret void
+}
+
+define bfloat @add2(bfloat %a, bfloat %b) nounwind {
+; RV64-LABEL: add2:
+; RV64:       # %bb.0:
+; RV64-NEXT:    addi sp, sp, -32
+; RV64-NEXT:    sd ra, 24(sp) # 8-byte Folded Spill
+; RV64-NEXT:    sd s0, 16(sp) # 8-byte Folded Spill
+; RV64-NEXT:    fsd fs0, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT:    fmv.x.w a0, fa0
+; RV64-NEXT:    lui a1, 16
+; RV64-NEXT:    addiw a1, a1, -1
+; RV64-NEXT:    and s0, a0, a1
+; RV64-NEXT:    fmv.x.w a0, fa1
+; RV64-NEXT:    and a0, a0, a1
+; RV64-NEXT:    fmv.w.x fa0, a0
+; RV64-NEXT:    call __extendhfsf2@plt
+; RV64-NEXT:    fmv.s fs0, fa0
+; RV64-NEXT:    fmv.w.x fa0, s0
+; RV64-NEXT:    call __extendhfsf2@plt
+; RV64-NEXT:    fadd.s fa0, fa0, fs0
+; RV64-NEXT:    call __truncsfbf2@plt
+; RV64-NEXT:    fmv.x.w a0, fa0
+; RV64-NEXT:    lui a1, 1048560
+; RV64-NEXT:    or a0, a0, a1
+; RV64-NEXT:    fmv.w.x fa0, a0
+; RV64-NEXT:    ld ra, 24(sp) # 8-byte Folded Reload
+; RV64-NEXT:    ld s0, 16(sp) # 8-byte Folded Reload
+; RV64-NEXT:    fld fs0, 8(sp) # 8-byte Folded Reload
+; RV64-NEXT:    addi sp, sp, 32
+; RV64-NEXT:    ret
+;
+; RV32-LABEL: add2:
+; RV32:       # %bb.0:
+; RV32-NEXT:    addi sp, sp, -32
+; RV32-NEXT:    sw ra, 28(sp) # 4-byte Folded Spill
+; RV32-NEXT:    fsd fs0, 16(sp) # 8-byte Folded Spill
+; RV32-NEXT:    fsd fs1, 8(sp) # 8-byte Folded Spill
+; RV32-NEXT:    fmv.x.w a0, fa0
+; RV32-NEXT:    lui a1, 16
+; RV32-NEXT:    addi a1, a1, -1
+; RV32-NEXT:    and a0, a0, a1
+; RV32-NEXT:    fmv.w.x fs0, a0
+; RV32-NEXT:    fmv.x.w a0, fa1
+; RV32-NEXT:    and a0, a0, a1
+; RV32-NEXT:    fmv.w.x fa0, a0
+; RV32-NEXT:    call __extendhfsf2@plt
+; RV32-NEXT:    fmv.s fs1, fa0
+; RV32-NEXT:    fmv.s fa0, fs0
+; RV32-NEXT:    call __extendhfsf2@plt
+; RV32-NEXT:    fadd.s fa0, fa0, fs1
+; RV32-NEXT:    call __truncsfbf2@plt
+; RV32-NEXT:    fmv.x.w a0, fa0
+; RV32-NEXT:    lui a1, 1048560
+; RV32-NEXT:    or a0, a0, a1
+; RV32-NEXT:    fmv.w.x fa0, a0
+; RV32-NEXT:    lw ra, 28(sp) # 4-byte Folded Reload
+; RV32-NEXT:    fld fs0, 16(sp) # 8-byte Folded Reload
+; RV32-NEXT:    fld fs1, 8(sp) # 8-byte Folded Reload
+; RV32-NEXT:    addi sp, sp, 32
+; RV32-NEXT:    ret
+  %add = fadd bfloat %a, %b
+  ret bfloat %add
+}
+
+define void @add_constant(ptr %pa, ptr %pc) nounwind {
+; RV64-LABEL: add_constant:
+; RV64:       # %bb.0:
+; RV64-NEXT:    addi sp, sp, -16
+; RV64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT:    sd s0, 0(sp) # 8-byte Folded Spill
+; RV64-NEXT:    lhu a0, 0(a0)
+; RV64-NEXT:    mv s0, a1
+; RV64-NEXT:    fmv.w.x fa0, a0
+; RV64-NEXT:    call __extendhfsf2@plt
+; RV64-NEXT:    lui a0, 260096
+; RV64-NEXT:    fmv.w.x fa5, a0
+; RV64-NEXT:    fadd.s fa0, fa0, fa5
+; RV64-NEXT:    call __truncsfbf2@plt
+; RV64-NEXT:    fmv.x.w a0, fa0
+; RV64-NEXT:    sh a0, 0(s0)
+; RV64-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64-NEXT:    ld s0, 0(sp) # 8-byte Folded Reload
+; RV64-NEXT:    addi sp, sp, 16
+; RV64-NEXT:    ret
+;
+; RV32-LABEL: add_constant:
+; RV32:       # %bb.0:
+; RV32-NEXT:    addi sp, sp, -16
+; RV32-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32-NEXT:    sw s0, 8(sp) # 4-byte Folded Spill
+; RV32-NEXT:    lhu a0, 0(a0)
+; RV32-NEXT:    mv s0, a1
+; RV32-NEXT:    fmv.w.x fa0, a0
+; RV32-NEXT:    call __extendhfsf2@plt
+; RV32-NEXT:    lui a0, 260096
+; RV32-NEXT:    fmv.w.x fa5, a0
+; RV32-NEXT:    fadd.s fa0, fa0, fa5
+; RV32-NEXT:    call __truncsfbf2@plt
+; RV32-NEXT:    fmv.x.w a0, fa0
+; RV32-NEXT:    sh a0, 0(s0)
+; RV32-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
+; RV32-NEXT:    lw s0, 8(sp) # 4-byte Folded Reload
+; RV32-NEXT:    addi sp, sp, 16
+; RV32-NEXT:    ret
+  %a = load bfloat, ptr %pa
+  %add = fadd bfloat %a, 1.0
+  store bfloat %add, ptr %pc
+  ret void
+}
+
+define bfloat @add_constant2(bfloat %a) nounwind {
+; RV64-LABEL: add_constant2:
+; RV64:       # %bb.0:
+; RV64-NEXT:    addi sp, sp, -16
+; RV64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT:    fmv.x.w a0, fa0
+; RV64-NEXT:    slli a0, a0, 48
+; RV64-NEXT:    srli a0, a0, 48
+; RV64-NEXT:    fmv.w.x fa0, a0
+; RV64-NEXT:    call __extendhfsf2@plt
+; RV64-NEXT:    lui a0, 260096
+; RV64-NEXT:    fmv.w.x fa5, a0
+; RV64-NEXT:    fadd.s fa0, fa0, fa5
+; RV64-NEXT:    call __truncsfbf2@plt
+; RV64-NEXT:    fmv.x.w a0, fa0
+; RV64-NEXT:    lui a1, 1048560
+; RV64-NEXT:    or a0, a0, a1
+; RV64-NEXT:    fmv.w.x fa0, a0
+; RV64-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64-NEXT:    addi sp, sp, 16
+; RV64-NEXT:    ret
+;
+; RV32-LABEL: add_constant2:
+; RV32:       # %bb.0:
+; RV32-NEXT:    addi sp, sp, -16
+; RV32-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32-NEXT:    fmv.x.w a0, fa0
+; RV32-NEXT:    slli a0, a0, 16
+; RV32-NEXT:    srli a0, a0, 16
+; RV32-NEXT:    fmv.w.x fa0, a0
+; RV32-NEXT:    call __extendhfsf2@plt
+; RV32-NEXT:    lui a0, 260096
+; RV32-NEXT:    fmv.w.x fa5, a0
+; RV32-NEXT:    fadd.s fa0, fa0, fa5
+; RV32-NEXT:    call __truncsfbf2@plt
+; RV32-NEXT:    fmv.x.w a0, fa0
+; RV32-NEXT:    lui a1, 1048560
+; RV32-NEXT:    or a0, a0, a1
+; RV32-NEXT:    fmv.w.x fa0, a0
+; RV32-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
+; RV32-NEXT:    addi sp, sp, 16
+; RV32-NEXT:    ret
+  %add = fadd bfloat %a, 1.0
+  ret bfloat %add
+}
+
+define void @store_constant(ptr %pc) nounwind {
+; RV64-LABEL: store_constant:
+; RV64:       # %bb.0:
+; RV64-NEXT:    lui a1, 4
+; RV64-NEXT:    addiw a1, a1, -128
+; RV64-NEXT:    sh a1, 0(a0)
+; RV64-NEXT:    ret
+;
+; RV32-LABEL: store_constant:
+; RV32:       # %bb.0:
+; RV32-NEXT:    lui a1, 4
+; RV32-NEXT:    addi a1, a1, -128
+; RV32-NEXT:    sh a1, 0(a0)
+; RV32-NEXT:    ret
+  store bfloat 1.0, ptr %pc
+  ret void
+}
+
+define void @fold_ext_trunc(ptr %pa, ptr %pc) nounwind {
+; RV64-LABEL: fold_ext_trunc:
+; RV64:       # %bb.0:
+; RV64-NEXT:    lh a0, 0(a0)
+; RV64-NEXT:    sh a0, 0(a1)
+; RV64-NEXT:    ret
+;
+; RV32-LABEL: fold_ext_trunc:
+; RV32:       # %bb.0:
+; RV32-NEXT:    lh a0, 0(a0)
+; RV32-NEXT:    sh a0, 0(a1)
+; RV32-NEXT:    ret
+  %a = load bfloat, ptr %pa
+  %ext = fpext bfloat %a to float
+  %trunc = fptrunc float %ext to bfloat
+  store bfloat %trunc, ptr %pc
+  ret void
+}
+
+define bfloat @fold_ext_trunc2(bfloat %a) nounwind {
+; RV64-LABEL: fold_ext_trunc2:
+; RV64:       # %bb.0:
+; RV64-NEXT:    fmv.x.w a0, fa0
+; RV64-NEXT:    lui a1, 1048560
+; RV64-NEXT:    or a0, a0, a1
+; RV64-NEXT:    fmv.w.x fa0, a0
+; RV64-NEXT:    ret
+;
+; RV32-LABEL: fold_ext_trunc2:
+; RV32:       # %bb.0:
+; RV32-NEXT:    fmv.x.w a0, fa0
+; RV32-NEXT:    lui a1, 1048560
+; RV32-NEXT:    or a0, a0, a1
+; RV32-NEXT:    fmv.w.x fa0, a0
+; RV32-NEXT:    ret
+  %ext = fpext bfloat %a to float
+  %trunc = fptrunc float %ext to bfloat
+  ret bfloat %trunc
+}
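
A minimal sketch, not part of the patch or its test: IR of the kind that would exercise the f64 actions registered above (FP_TO_BF16 marked Custom and BF16_TO_FP marked Expand for MVT::f64), since the committed bfloat.ll test only covers the f32 paths. The function names are illustrative; the module can be fed to llc with -mtriple=riscv64 -mattr=+f,+d.

; Illustrative only: fptrunc from double reaches the FP_TO_BF16 action the
; patch registers as Custom for MVT::f64.
define bfloat @trunc_from_double(double %x) nounwind {
  %r = fptrunc double %x to bfloat
  ret bfloat %r
}

; Illustrative only: fpext to double uses the BF16_TO_FP action, which the
; patch leaves as Expand for MVT::f64.
define double @ext_to_double(bfloat %x) nounwind {
  %r = fpext bfloat %x to double
  ret double %r
}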