Index: llvm/lib/Target/RISCV/RISCVISelLowering.h
===================================================================
--- llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -830,6 +830,9 @@
   SDValue lowerGET_ROUNDING(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerSET_ROUNDING(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue lowerFP_TO_BF16(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerBF16_TO_FP(SDValue Op, SelectionDAG &DAG) const;
+
   SDValue lowerEH_DWARF_CFA(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerCTLZ_CTTZ_ZERO_UNDEF(SDValue Op, SelectionDAG &DAG) const;
 
Index: llvm/lib/Target/RISCV/RISCVISelLowering.cpp
===================================================================
--- llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -426,7 +426,11 @@
     setOperationAction(ISD::BR_CC, MVT::f32, Expand);
     setOperationAction(FPOpToExpand, MVT::f32, Expand);
     setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
+    setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand);
     setTruncStoreAction(MVT::f32, MVT::f16, Expand);
+    setTruncStoreAction(MVT::f32, MVT::bf16, Expand);
+    setOperationAction(ISD::FP_TO_BF16, MVT::f32, Custom);
+    setOperationAction(ISD::BF16_TO_FP, MVT::f32, Custom);
     setOperationAction(ISD::IS_FPCLASS, MVT::f32, Custom);
 
     if (Subtarget.hasStdExtZfa())
@@ -461,6 +465,10 @@
     setOperationAction(FPOpToExpand, MVT::f64, Expand);
     setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
     setTruncStoreAction(MVT::f64, MVT::f16, Expand);
+    setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand);
+    setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
+    setOperationAction(ISD::FP_TO_BF16, MVT::f64, Custom);
+    setOperationAction(ISD::BF16_TO_FP, MVT::f64, Custom);
     setOperationAction(ISD::IS_FPCLASS, MVT::f64, Custom);
   }
 
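For reference, the bit-level semantics that the Custom actions above are lowered to can be written as plain host code. This is a minimal illustrative sketch, not part of the patch, and the helper names are made up; it assumes the usual bf16 layout, i.e. the high half of an IEEE-754 binary32. Note that the truncation direction simply drops the low 16 mantissa bits, matching the shift-right-by-16 sequence the patch emits (no round-to-nearest-even).

  #include <cstdint>
  #include <cstring>

  // bf16 -> f32 is exact: the bf16 payload becomes the top 16 bits of the
  // binary32 word and the low half is padded with zeros.
  static float bf16BitsToFloat(uint16_t Bits) {
    uint32_t W = static_cast<uint32_t>(Bits) << 16;
    float F;
    std::memcpy(&F, &W, sizeof(F));
    return F;
  }

  // f32 -> bf16 here is a plain bit shift: keep the sign, exponent, and top
  // mantissa bits, discarding the low 16 mantissa bits.
  static uint16_t floatToBF16Bits(float F) {
    uint32_t W;
    std::memcpy(&W, &F, sizeof(W));
    return static_cast<uint16_t>(W >> 16);
  }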
@@ -2424,6 +2432,74 @@
   return Res;
 }
 
+SDValue RISCVTargetLowering::lowerFP_TO_BF16(SDValue Op,
+                                             SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  SDValue Res;
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+
+  SDValue Src = Op.getOperand(0);
+  if (Src.getValueType() != MVT::f32)
+    Src = DAG.getNode(ISD::FP_ROUND, DL, MVT::f32, Src,
+                      DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
+
+  if (Subtarget.is64Bit()) {
+    SDValue FPConv =
+        DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Src);
+    Res = DAG.getNode(
+        ISD::SRL, DL, MVT::i64, FPConv,
+        DAG.getConstant(16, DL,
+                        TLI.getShiftAmountTy(MVT::i64, DAG.getDataLayout())));
+  } else {
+    Res = DAG.getNode(
+        ISD::SRL, DL, MVT::i32, DAG.getNode(ISD::BITCAST, DL, MVT::i32, Src),
+        DAG.getConstant(16, DL,
+                        TLI.getShiftAmountTy(MVT::i32, DAG.getDataLayout())));
+  }
+
+  // The result of this node can be bf16 or an integer type in case bf16 is
+  // not supported on the target and was softened to i16 for storage.
+  if (Op.getValueType() == MVT::bf16)
+    Res = DAG.getNode(ISD::BITCAST, DL, MVT::bf16,
+                      DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Res));
+
+  return Res;
+}
+
+SDValue RISCVTargetLowering::lowerBF16_TO_FP(SDValue Op,
+                                             SelectionDAG &DAG) const {
+  // Always expand bf16 to f32 casts; they lower to ext + shift.
+  //
+  // Note that the operand of this node can be bf16 or an integer type in
+  // case bf16 is not supported on the target and was softened.
+  SDLoc DL(Op);
+  SDValue Res;
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  SDValue Src = Op.getOperand(0);
+  if (Src.getValueType() == MVT::bf16)
+    Src = DAG.getNode(ISD::ANY_EXTEND, DL, Subtarget.getXLenVT(),
+                      DAG.getNode(ISD::BITCAST, DL, MVT::i16, Src));
+
+  if (Subtarget.is64Bit()) {
+    SDValue Shift = DAG.getNode(
+        ISD::SHL, DL, MVT::i64, Src,
+        DAG.getConstant(16, DL,
+                        TLI.getShiftAmountTy(MVT::i64, DAG.getDataLayout())));
+    Res = DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32, Shift);
+  } else {
+    SDValue Shift = DAG.getNode(
+        ISD::SHL, DL, MVT::i32, Src,
+        DAG.getConstant(16, DL,
+                        TLI.getShiftAmountTy(MVT::i32, DAG.getDataLayout())));
+    Res = DAG.getNode(ISD::BITCAST, DL, MVT::f32, Shift);
+  }
+
+  // Add an fp_extend in case the output is bigger than f32.
+  if (Op.getValueType() != MVT::f32)
+    Res = DAG.getNode(ISD::FP_EXTEND, DL, Op.getValueType(), Res);
+
+  return Res;
+}
+
 static RISCVFPRndMode::RoundingMode matchRoundingOp(unsigned Opc) {
   switch (Opc) {
   case ISD::FROUNDEVEN:
@@ -4751,6 +4827,10 @@
   case ISD::FP_TO_SINT_SAT:
   case ISD::FP_TO_UINT_SAT:
     return lowerFP_TO_INT_SAT(Op, DAG, Subtarget);
+  case ISD::FP_TO_BF16:
+    return lowerFP_TO_BF16(Op, DAG);
+  case ISD::BF16_TO_FP:
+    return lowerBF16_TO_FP(Op, DAG);
   case ISD::FTRUNC:
   case ISD::FCEIL:
   case ISD::FFLOOR:
@@ -15980,8 +16060,9 @@
     unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
   bool IsABIRegCopy = CC.has_value();
   EVT ValueVT = Val.getValueType();
-  if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
-    // Cast the f16 to i16, extend to i32, pad with ones to make a float nan,
+  if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) &&
+      PartVT == MVT::f32) {
+    // Cast the f16/bf16 to i16, extend to i32, pad with ones to make a float nan,
     // and cast to f32.
     Val = DAG.getNode(ISD::BITCAST, DL, MVT::i16, Val);
     Val = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Val);
 
@@ -16033,13 +16114,14 @@
     SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
     unsigned NumParts, MVT PartVT, EVT ValueVT,
     std::optional<CallingConv::ID> CC) const {
   bool IsABIRegCopy = CC.has_value();
-  if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
+  if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) &&
+      PartVT == MVT::f32) {
     SDValue Val = Parts[0];
 
-    // Cast the f32 to i32, truncate to i16, and cast back to f16.
+    // Cast the f32 to i32, truncate to i16, and cast back to f16/bf16.
     Val = DAG.getNode(ISD::BITCAST, DL, MVT::i32, Val);
     Val = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Val);
-    Val = DAG.getNode(ISD::BITCAST, DL, MVT::f16, Val);
+    Val = DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
     return Val;
   }
 
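The splitValueIntoRegisterParts/joinRegisterPartsIntoValue changes above make the hard-float ABI treat bf16 the same way as f16: the value lives in the low 16 bits of an f32 FPR, and the high bits are padded with ones so the register holds a valid 32-bit NaN ("NaN-boxing"). A minimal sketch of that boxing, with a hypothetical helper name; the 0xFFFF0000 mask is what surfaces as "lui a1, 1048560" followed by "or" in the tests below.

  #include <cstdint>

  // NaN-box a 16-bit f16/bf16 payload into a 32-bit FPR image. The top 16
  // ones force an all-ones exponent and a nonzero mantissa, so the register
  // always holds a float NaN regardless of the payload bits.
  static uint32_t nanBoxHalfBits(uint16_t Bits) {
    return 0xFFFF0000u | Bits;
  }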
Index: llvm/test/CodeGen/RISCV/bfloat.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/RISCV/bfloat.ll
@@ -0,0 +1,204 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=riscv64 -mattr=+f,+d -verify-machineinstrs | FileCheck %s --check-prefixes=RV64
+; RUN: llc < %s -mtriple=riscv32 -mattr=+f,+d -verify-machineinstrs | FileCheck %s --check-prefixes=RV32
+
+define void @add(ptr %pa, ptr %pb, ptr %pc) nounwind {
+; RV64-LABEL: add:
+; RV64:       # %bb.0:
+; RV64-NEXT:    lhu a1, 0(a1)
+; RV64-NEXT:    lhu a0, 0(a0)
+; RV64-NEXT:    slli a1, a1, 16
+; RV64-NEXT:    fmv.w.x fa5, a1
+; RV64-NEXT:    slli a0, a0, 16
+; RV64-NEXT:    fmv.w.x fa4, a0
+; RV64-NEXT:    fadd.s fa5, fa4, fa5
+; RV64-NEXT:    fmv.x.w a0, fa5
+; RV64-NEXT:    srli a0, a0, 16
+; RV64-NEXT:    sh a0, 0(a2)
+; RV64-NEXT:    ret
+;
+; RV32-LABEL: add:
+; RV32:       # %bb.0:
+; RV32-NEXT:    lhu a1, 0(a1)
+; RV32-NEXT:    lhu a0, 0(a0)
+; RV32-NEXT:    slli a1, a1, 16
+; RV32-NEXT:    fmv.w.x fa5, a1
+; RV32-NEXT:    slli a0, a0, 16
+; RV32-NEXT:    fmv.w.x fa4, a0
+; RV32-NEXT:    fadd.s fa5, fa4, fa5
+; RV32-NEXT:    fmv.x.w a0, fa5
+; RV32-NEXT:    srli a0, a0, 16
+; RV32-NEXT:    sh a0, 0(a2)
+; RV32-NEXT:    ret
+  %a = load bfloat, ptr %pa
+  %b = load bfloat, ptr %pb
+  %add = fadd bfloat %a, %b
+  store bfloat %add, ptr %pc
+  ret void
+}
+
+define bfloat @add2(bfloat %a, bfloat %b) nounwind {
+; RV64-LABEL: add2:
+; RV64:       # %bb.0:
+; RV64-NEXT:    fmv.x.w a0, fa0
+; RV64-NEXT:    lui a1, 16
+; RV64-NEXT:    addiw a1, a1, -1
+; RV64-NEXT:    and a0, a0, a1
+; RV64-NEXT:    fmv.x.w a2, fa1
+; RV64-NEXT:    and a1, a2, a1
+; RV64-NEXT:    slli a1, a1, 16
+; RV64-NEXT:    fmv.w.x fa5, a1
+; RV64-NEXT:    slli a0, a0, 16
+; RV64-NEXT:    fmv.w.x fa4, a0
+; RV64-NEXT:    fadd.s fa5, fa4, fa5
+; RV64-NEXT:    fmv.x.w a0, fa5
+; RV64-NEXT:    srli a0, a0, 16
+; RV64-NEXT:    lui a1, 1048560
+; RV64-NEXT:    or a0, a0, a1
+; RV64-NEXT:    fmv.w.x fa0, a0
+; RV64-NEXT:    ret
+;
+; RV32-LABEL: add2:
+; RV32:       # %bb.0:
+; RV32-NEXT:    fmv.x.w a0, fa0
+; RV32-NEXT:    fmv.x.w a1, fa1
+; RV32-NEXT:    slli a1, a1, 16
+; RV32-NEXT:    fmv.w.x fa5, a1
+; RV32-NEXT:    slli a0, a0, 16
+; RV32-NEXT:    fmv.w.x fa4, a0
+; RV32-NEXT:    fadd.s fa5, fa4, fa5
+; RV32-NEXT:    fmv.x.w a0, fa5
+; RV32-NEXT:    srli a0, a0, 16
+; RV32-NEXT:    lui a1, 1048560
+; RV32-NEXT:    or a0, a0, a1
+; RV32-NEXT:    fmv.w.x fa0, a0
+; RV32-NEXT:    ret
+  %add = fadd bfloat %a, %b
+  ret bfloat %add
+}
+
+define void @add_constant(ptr %pa, ptr %pc) nounwind {
+; RV64-LABEL: add_constant:
+; RV64:       # %bb.0:
+; RV64-NEXT:    lhu a0, 0(a0)
+; RV64-NEXT:    slli a0, a0, 16
+; RV64-NEXT:    fmv.w.x fa5, a0
+; RV64-NEXT:    lui a0, 260096
+; RV64-NEXT:    fmv.w.x fa4, a0
+; RV64-NEXT:    fadd.s fa5, fa5, fa4
+; RV64-NEXT:    fmv.x.w a0, fa5
+; RV64-NEXT:    srli a0, a0, 16
+; RV64-NEXT:    sh a0, 0(a1)
+; RV64-NEXT:    ret
+;
+; RV32-LABEL: add_constant:
+; RV32:       # %bb.0:
+; RV32-NEXT:    lhu a0, 0(a0)
+; RV32-NEXT:    slli a0, a0, 16
+; RV32-NEXT:    fmv.w.x fa5, a0
+; RV32-NEXT:    lui a0, 260096
+; RV32-NEXT:    fmv.w.x fa4, a0
+; RV32-NEXT:    fadd.s fa5, fa5, fa4
+; RV32-NEXT:    fmv.x.w a0, fa5
+; RV32-NEXT:    srli a0, a0, 16
+; RV32-NEXT:    sh a0, 0(a1)
+; RV32-NEXT:    ret
+  %a = load bfloat, ptr %pa
+  %add = fadd bfloat %a, 1.0
+  store bfloat %add, ptr %pc
+  ret void
+}
+
+define bfloat @add_constant2(bfloat %a) nounwind {
+; RV64-LABEL: add_constant2:
+; RV64:       # %bb.0:
+; RV64-NEXT:    fmv.x.w a0, fa0
+; RV64-NEXT:    slli a0, a0, 48
+; RV64-NEXT:    srli a0, a0, 48
+; RV64-NEXT:    slli a0, a0, 16
+; RV64-NEXT:    fmv.w.x fa5, a0
+; RV64-NEXT:    lui a0, 260096
+; RV64-NEXT:    fmv.w.x fa4, a0
+; RV64-NEXT:    fadd.s fa5, fa5, fa4
+; RV64-NEXT:    fmv.x.w a0, fa5
+; RV64-NEXT:    srli a0, a0, 16
+; RV64-NEXT:    lui a1, 1048560
+; RV64-NEXT:    or a0, a0, a1
+; RV64-NEXT:    fmv.w.x fa0, a0
+; RV64-NEXT:    ret
+;
+; RV32-LABEL: add_constant2:
+; RV32:       # %bb.0:
+; RV32-NEXT:    fmv.x.w a0, fa0
+; RV32-NEXT:    slli a0, a0, 16
+; RV32-NEXT:    fmv.w.x fa5, a0
+; RV32-NEXT:    lui a0, 260096
+; RV32-NEXT:    fmv.w.x fa4, a0
+; RV32-NEXT:    fadd.s fa5, fa5, fa4
+; RV32-NEXT:    fmv.x.w a0, fa5
+; RV32-NEXT:    srli a0, a0, 16
+; RV32-NEXT:    lui a1, 1048560
+; RV32-NEXT:    or a0, a0, a1
+; RV32-NEXT:    fmv.w.x fa0, a0
+; RV32-NEXT:    ret
+  %add = fadd bfloat %a, 1.0
+  ret bfloat %add
+}
+
+define void @store_constant(ptr %pc) nounwind {
+; RV64-LABEL: store_constant:
+; RV64:       # %bb.0:
+; RV64-NEXT:    lui a1, 4
+; RV64-NEXT:    addiw a1, a1, -128
+; RV64-NEXT:    sh a1, 0(a0)
+; RV64-NEXT:    ret
+;
+; RV32-LABEL: store_constant:
+; RV32:       # %bb.0:
+; RV32-NEXT:    lui a1, 4
+; RV32-NEXT:    addi a1, a1, -128
+; RV32-NEXT:    sh a1, 0(a0)
+; RV32-NEXT:    ret
+  store bfloat 1.0, ptr %pc
+  ret void
+}
+
+define void @fold_ext_trunc(ptr %pa, ptr %pc) nounwind {
+; RV64-LABEL: fold_ext_trunc:
+; RV64:       # %bb.0:
+; RV64-NEXT:    lh a0, 0(a0)
+; RV64-NEXT:    sh a0, 0(a1)
+; RV64-NEXT:    ret
+;
+; RV32-LABEL: fold_ext_trunc:
+; RV32:       # %bb.0:
+; RV32-NEXT:    lh a0, 0(a0)
+; RV32-NEXT:    sh a0, 0(a1)
+; RV32-NEXT:    ret
+  %a = load bfloat, ptr %pa
+  %ext = fpext bfloat %a to float
+  %trunc = fptrunc float %ext to bfloat
+  store bfloat %trunc, ptr %pc
+  ret void
+}
+
+define bfloat @fold_ext_trunc2(bfloat %a) nounwind {
+; RV64-LABEL: fold_ext_trunc2:
+; RV64:       # %bb.0:
+; RV64-NEXT:    fmv.x.w a0, fa0
+; RV64-NEXT:    lui a1, 1048560
+; RV64-NEXT:    or a0, a0, a1
+; RV64-NEXT:    fmv.w.x fa0, a0
+; RV64-NEXT:    ret
+;
+; RV32-LABEL: fold_ext_trunc2:
+; RV32:       # %bb.0:
+; RV32-NEXT:    fmv.x.w a0, fa0
+; RV32-NEXT:    lui a1, 1048560
+; RV32-NEXT:    or a0, a0, a1
+; RV32-NEXT:    fmv.w.x fa0, a0
+; RV32-NEXT:    ret
+  %ext = fpext bfloat %a to float
+  %trunc = fptrunc float %ext to bfloat
+  ret bfloat %trunc
+}
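As a sanity check on the store_constant output above: "lui a1, 4" materializes 4 << 12 = 0x4000, and "addi(w) a1, a1, -128" yields 0x3F80, the bfloat bit pattern for 1.0 (sign 0, biased exponent 127, mantissa 0). A small standalone check, illustrative only and not part of the patch:

  #include <cassert>
  #include <cstdint>

  int main() {
    uint32_t A1 = 4u << 12; // lui a1, 4         -> 0x4000
    A1 -= 128;              // addi a1, a1, -128 -> 0x3F80
    assert(A1 == 0x3F80u);  // bfloat 1.0 bit pattern, stored by "sh"
    // Widened by the ext + shift lowering: 0x3F80 << 16 == 0x3F800000,
    // which is the binary32 encoding of 1.0f.
    assert((A1 << 16) == 0x3F800000u);
    return 0;
  }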