diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp @@ -107,6 +107,7 @@ case ISD::STRICT_FP_ROUND: case ISD::FP_ROUND: R = SoftenFloatRes_FP_ROUND(N); break; case ISD::FP16_TO_FP: R = SoftenFloatRes_FP16_TO_FP(N); break; + case ISD::BF16_TO_FP: R = SoftenFloatRes_BF16_TO_FP(N); break; case ISD::STRICT_FPOW: case ISD::FPOW: R = SoftenFloatRes_FPOW(N); break; case ISD::STRICT_FPOWI: @@ -510,10 +511,12 @@ return BitConvertToInteger(Op); } - // There's only a libcall for f16 -> f32, so proceed in two stages. Also, it's - // entirely possible for both f16 and f32 to be legal, so use the fully - // hard-float FP_EXTEND rather than FP16_TO_FP. - if (Op.getValueType() == MVT::f16 && N->getValueType(0) != MVT::f32) { + // There's only a libcall for f16 -> f32 and shifting is only valid for bf16 + // -> f32, so proceed in two stages. Also, it's entirely possible for both + // f16 and f32 to be legal, so use the fully hard-float FP_EXTEND rather + // than FP16_TO_FP. + if ((Op.getValueType() == MVT::f16 || Op.getValueType() == MVT::bf16) && + N->getValueType(0) != MVT::f32) { if (IsStrict) { Op = DAG.getNode(ISD::STRICT_FP_EXTEND, SDLoc(N), { MVT::f32, MVT::Other }, { Chain, Op }); @@ -523,6 +526,9 @@ } } + if (Op.getValueType() == MVT::bf16) + return SoftenFloatRes_BF16_TO_FP(N); + RTLIB::Libcall LC = RTLIB::getFPEXT(Op.getValueType(), N->getValueType(0)); assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unsupported FP_EXTEND!"); TargetLowering::MakeLibCallOptions CallOptions; @@ -555,6 +561,36 @@ return TLI.makeLibCall(DAG, LC, NVT, Res32, CallOptions, SDLoc(N)).first; } +// FIXME: Should we just use 'normal' FP_EXTEND / FP_TRUNC instead of special +// nodes?
+// +// Softens a bf16 -> f32 conversion by moving the 16 bf16 bits into bits +// [31:16] of the softened f32's integer representation. Reached both for +// BF16_TO_FP and, via SoftenFloatRes_FP_EXTEND, for FP_EXTEND and +// STRICT_FP_EXTEND from bf16 to f32. +SDValue DAGTypeLegalizer::SoftenFloatRes_BF16_TO_FP(SDNode *N) { + assert(N->getValueType(0) == MVT::f32 && + "Can only soften BF16_TO_FP with f32 result"); + EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), MVT::f32); + // For STRICT_FP_EXTEND, operand 0 is the input chain and the value being + // converted is operand 1. + bool IsStrict = N->isStrictFPOpcode(); + SDValue Op = N->getOperand(IsStrict ? 1 : 0); + SDLoc DL(N); + // The shift expansion below does not model FP exceptions, so the strict + // node's output chain is simply its input chain. + if (IsStrict) + ReplaceValueWith(SDValue(N, 1), N->getOperand(0)); + // bf16 has the same layout as the upper 16 bits of an f32, so shifting the + // raw bits left by 16 gives the f32 encoding of the same value. Signaling + // NaNs are not quieted by this expansion. + Op = DAG.getNode(ISD::ANY_EXTEND, DL, NVT, + DAG.getNode(ISD::BITCAST, DL, MVT::i16, Op)); + SDValue Res = DAG.getNode(ISD::SHL, DL, NVT, Op, + DAG.getShiftAmountConstant(16, NVT, DL)); + return Res; +} + SDValue DAGTypeLegalizer::SoftenFloatRes_FP_ROUND(SDNode *N) { bool IsStrict = N->isStrictFPOpcode(); EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0)); diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h @@ -560,6 +560,7 @@ SDValue SoftenFloatRes_FNEG(SDNode *N); SDValue SoftenFloatRes_FP_EXTEND(SDNode *N); SDValue SoftenFloatRes_FP16_TO_FP(SDNode *N); + SDValue SoftenFloatRes_BF16_TO_FP(SDNode *N); SDValue SoftenFloatRes_FP_ROUND(SDNode *N); SDValue SoftenFloatRes_FPOW(SDNode *N); SDValue SoftenFloatRes_FPOWI(SDNode *N); diff --git a/llvm/test/CodeGen/RISCV/bfloat.ll b/llvm/test/CodeGen/RISCV/bfloat.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/RISCV/bfloat.ll @@ -0,0 +1,116 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc -mtriple=riscv32 -verify-machineinstrs < %s | FileCheck %s -check-prefix=RV32I-ILP32 +; RUN: llc -mtriple=riscv64 -verify-machineinstrs < %s | FileCheck %s -check-prefix=RV64I-LP64 + +; TODO: Enable codegen for hard float.
+ +define bfloat @float_to_bfloat(float %a) nounwind { +; RV32I-ILP32-LABEL: float_to_bfloat: +; RV32I-ILP32: # %bb.0: +; RV32I-ILP32-NEXT: addi sp, sp, -16 +; RV32I-ILP32-NEXT: sw ra, 12(sp) # 4-byte Folded Spill +; RV32I-ILP32-NEXT: call __truncsfbf2@plt +; RV32I-ILP32-NEXT: lw ra, 12(sp) # 4-byte Folded Reload +; RV32I-ILP32-NEXT: addi sp, sp, 16 +; RV32I-ILP32-NEXT: ret +; +; RV64I-LP64-LABEL: float_to_bfloat: +; RV64I-LP64: # %bb.0: +; RV64I-LP64-NEXT: addi sp, sp, -16 +; RV64I-LP64-NEXT: sd ra, 8(sp) # 8-byte Folded Spill +; RV64I-LP64-NEXT: call __truncsfbf2@plt +; RV64I-LP64-NEXT: ld ra, 8(sp) # 8-byte Folded Reload +; RV64I-LP64-NEXT: addi sp, sp, 16 +; RV64I-LP64-NEXT: ret + %1 = fptrunc float %a to bfloat + ret bfloat %1 +} + +define bfloat @double_to_bfloat(double %a) nounwind { +; RV32I-ILP32-LABEL: double_to_bfloat: +; RV32I-ILP32: # %bb.0: +; RV32I-ILP32-NEXT: addi sp, sp, -16 +; RV32I-ILP32-NEXT: sw ra, 12(sp) # 4-byte Folded Spill +; RV32I-ILP32-NEXT: call __truncdfbf2@plt +; RV32I-ILP32-NEXT: lw ra, 12(sp) # 4-byte Folded Reload +; RV32I-ILP32-NEXT: addi sp, sp, 16 +; RV32I-ILP32-NEXT: ret +; +; RV64I-LP64-LABEL: double_to_bfloat: +; RV64I-LP64: # %bb.0: +; RV64I-LP64-NEXT: addi sp, sp, -16 +; RV64I-LP64-NEXT: sd ra, 8(sp) # 8-byte Folded Spill +; RV64I-LP64-NEXT: call __truncdfbf2@plt +; RV64I-LP64-NEXT: ld ra, 8(sp) # 8-byte Folded Reload +; RV64I-LP64-NEXT: addi sp, sp, 16 +; RV64I-LP64-NEXT: ret + %1 = fptrunc double %a to bfloat + ret bfloat %1 +} + +define float @bfloat_to_float(bfloat %a) nounwind { +; RV32I-ILP32-LABEL: bfloat_to_float: +; RV32I-ILP32: # %bb.0: +; RV32I-ILP32-NEXT: slli a0, a0, 16 +; RV32I-ILP32-NEXT: ret +; +; RV64I-LP64-LABEL: bfloat_to_float: +; RV64I-LP64: # %bb.0: +; RV64I-LP64-NEXT: slliw a0, a0, 16 +; RV64I-LP64-NEXT: ret + %1 = fpext bfloat %a to float + ret float %1 +} + +define double @bfloat_to_double(bfloat %a) nounwind { +; RV32I-ILP32-LABEL: bfloat_to_double: +; RV32I-ILP32: # %bb.0: +; RV32I-ILP32-NEXT: addi
sp, sp, -16 +; RV32I-ILP32-NEXT: sw ra, 12(sp) # 4-byte Folded Spill +; RV32I-ILP32-NEXT: slli a0, a0, 16 +; RV32I-ILP32-NEXT: call __extendsfdf2@plt +; RV32I-ILP32-NEXT: lw ra, 12(sp) # 4-byte Folded Reload +; RV32I-ILP32-NEXT: addi sp, sp, 16 +; RV32I-ILP32-NEXT: ret +; +; RV64I-LP64-LABEL: bfloat_to_double: +; RV64I-LP64: # %bb.0: +; RV64I-LP64-NEXT: addi sp, sp, -16 +; RV64I-LP64-NEXT: sd ra, 8(sp) # 8-byte Folded Spill +; RV64I-LP64-NEXT: slli a0, a0, 48 +; RV64I-LP64-NEXT: srli a0, a0, 32 +; RV64I-LP64-NEXT: call __extendsfdf2@plt +; RV64I-LP64-NEXT: ld ra, 8(sp) # 8-byte Folded Reload +; RV64I-LP64-NEXT: addi sp, sp, 16 +; RV64I-LP64-NEXT: ret + %1 = fpext bfloat %a to double + ret double %1 +} + +define bfloat @bfloat_add(bfloat %a, bfloat %b) nounwind { +; RV32I-ILP32-LABEL: bfloat_add: +; RV32I-ILP32: # %bb.0: +; RV32I-ILP32-NEXT: addi sp, sp, -16 +; RV32I-ILP32-NEXT: sw ra, 12(sp) # 4-byte Folded Spill +; RV32I-ILP32-NEXT: slli a0, a0, 16 +; RV32I-ILP32-NEXT: slli a1, a1, 16 +; RV32I-ILP32-NEXT: call __addsf3@plt +; RV32I-ILP32-NEXT: call __truncsfbf2@plt +; RV32I-ILP32-NEXT: lw ra, 12(sp) # 4-byte Folded Reload +; RV32I-ILP32-NEXT: addi sp, sp, 16 +; RV32I-ILP32-NEXT: ret +; +; RV64I-LP64-LABEL: bfloat_add: +; RV64I-LP64: # %bb.0: +; RV64I-LP64-NEXT: addi sp, sp, -16 +; RV64I-LP64-NEXT: sd ra, 8(sp) # 8-byte Folded Spill +; RV64I-LP64-NEXT: slliw a0, a0, 16 +; RV64I-LP64-NEXT: slliw a1, a1, 16 +; RV64I-LP64-NEXT: call __addsf3@plt +; RV64I-LP64-NEXT: call __truncsfbf2@plt +; RV64I-LP64-NEXT: ld ra, 8(sp) # 8-byte Folded Reload +; RV64I-LP64-NEXT: addi sp, sp, 16 +; RV64I-LP64-NEXT: ret + %1 = fadd bfloat %a, %b + ret bfloat %1 +}