diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp @@ -2907,8 +2907,7 @@ break; case ISD::BF16_TO_FP: { // Always expand bf16 to f32 casts, they lower to ext + shift. - SDValue Op = DAG.getNode(ISD::BITCAST, dl, MVT::i16, Node->getOperand(0)); - Op = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, Op); + SDValue Op = DAG.getNode(ISD::BITCAST, dl, MVT::i32, Node->getOperand(0)); Op = DAG.getNode( ISD::SHL, dl, MVT::i32, Op, DAG.getConstant(16, dl, diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp @@ -2132,10 +2132,16 @@ SDValue Promoted = GetPromotedFloat(N->getOperand(0)); EVT PromotedVT = Promoted->getValueType(0); + unsigned Opc = GetPromotionOpcode(PromotedVT, OpVT); + if (OpVT == MVT::bf16 && !isa<ConstantFPSDNode>(Op)) { + SDValue Convert = DAG.getNode(Opc, SDLoc(N), MVT::f32, Promoted); + return DAG.getNode(ISD::TRUNCATE, SDLoc(N), MVT::i16, + DAG.getBitcast(MVT::i32, Convert)); + } + // Convert the promoted float value to the desired IVT. EVT IVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits()); - SDValue Convert = DAG.getNode(GetPromotionOpcode(PromotedVT, OpVT), SDLoc(N), - IVT, Promoted); + SDValue Convert = DAG.getNode(Opc, SDLoc(N), IVT, Promoted); // The final result type might not be an scalar so we need a bitcast. The // bitcast will be further legalized if needed. 
return DAG.getBitcast(N->getValueType(0), Convert); @@ -2209,11 +2215,17 @@ SDValue Promoted = GetPromotedFloat(Val); EVT VT = ST->getOperand(1).getValueType(); - EVT IVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits()); + unsigned Opc = GetPromotionOpcode(Promoted.getValueType(), VT); SDValue NewVal; - NewVal = DAG.getNode(GetPromotionOpcode(Promoted.getValueType(), VT), DL, - IVT, Promoted); + if (VT == MVT::bf16 && !isa<ConstantFPSDNode>(ST->getOperand(1))) { + NewVal = DAG.getNode(Opc, DL, MVT::f32, Promoted); + NewVal = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, + DAG.getBitcast(MVT::i32, NewVal)); + } else { + EVT IVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits()); + NewVal = DAG.getNode(Opc, DL, IVT, Promoted); + } return DAG.getStore(ST->getChain(), DL, NewVal, ST->getBasePtr(), ST->getMemOperand()); @@ -2323,11 +2335,17 @@ SDValue DAGTypeLegalizer::PromoteFloatRes_BITCAST(SDNode *N) { EVT VT = N->getValueType(0); EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT); - // Input type isn't guaranteed to be a scalar int so bitcast if not. The - // bitcast will be legalized further if necessary. - EVT IVT = EVT::getIntegerVT(*DAG.getContext(), - N->getOperand(0).getValueType().getSizeInBits()); - SDValue Cast = DAG.getBitcast(IVT, N->getOperand(0)); + SDValue Cast; + if (VT == MVT::bf16) { + Cast = DAG.getBitcast(MVT::f32, DAG.getNode(ISD::ANY_EXTEND, SDLoc(N), + MVT::i32, N->getOperand(0))); + } else { + // Input type isn't guaranteed to be a scalar int so bitcast if not. The + // bitcast will be legalized further if necessary. 
+ EVT IVT = EVT::getIntegerVT( + *DAG.getContext(), N->getOperand(0).getValueType().getSizeInBits()); + Cast = DAG.getBitcast(IVT, N->getOperand(0)); + } return DAG.getNode(GetPromotionOpcode(VT, NVT), SDLoc(N), NVT, Cast); } @@ -2475,10 +2493,12 @@ EVT VT = N->getValueType(0); EVT OpVT = Op->getValueType(0); EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0)); - EVT IVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits()); + EVT RVT = VT == MVT::bf16 ? MVT::f32 : EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits()); // Round promoted float to desired precision - SDValue Round = DAG.getNode(GetPromotionOpcode(OpVT, VT), DL, IVT, Op); + SDValue Round = DAG.getNode(GetPromotionOpcode(OpVT, VT), DL, RVT, Op); // Promote it back to the legal output type return DAG.getNode(GetPromotionOpcode(VT, NVT), DL, NVT, Round); } @@ -2497,6 +2517,10 @@ // new one ReplaceValueWith(SDValue(N, 1), newL.getValue(1)); + if (VT == MVT::bf16) + newL = DAG.getBitcast( + MVT::f32, DAG.getNode(ISD::ANY_EXTEND, SDLoc(N), MVT::i32, newL)); + // Convert the integer value to the desired FP type EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT); return DAG.getNode(GetPromotionOpcode(VT, NVT), SDLoc(N), NVT, newL); diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h --- a/llvm/lib/Target/X86/X86ISelLowering.h +++ b/llvm/lib/Target/X86/X86ISelLowering.h @@ -1621,6 +1621,17 @@ MachineBasicBlock *Entry, const SmallVectorImpl<MachineBasicBlock *> &Exits) const override; + bool + splitValueIntoRegisterParts(SelectionDAG &DAG, const SDLoc &DL, SDValue Val, + SDValue *Parts, unsigned NumParts, MVT PartVT, + Optional<CallingConv::ID> CC) const override; + + SDValue + joinRegisterPartsIntoValue(SelectionDAG &DAG, const SDLoc &DL, + const SDValue *Parts, unsigned NumParts, + MVT PartVT, EVT ValueVT, + Optional<CallingConv::ID> CC) const override; + bool isUsedByReturnOnly(SDNode *N, SDValue &Chain) const override; bool mayBeEmittedAsTailCall(const CallInst *CI) const override; 
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -2709,6 +2709,40 @@ return TargetLowering::getJumpTableEncoding(); } +bool X86TargetLowering::splitValueIntoRegisterParts( + SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts, + unsigned NumParts, MVT PartVT, Optional<CallingConv::ID> CC) const { + bool IsABIRegCopy = CC.has_value(); + EVT ValueVT = Val.getValueType(); + if (IsABIRegCopy && ValueVT == MVT::bf16 && PartVT == MVT::f32) { + unsigned ValueBits = ValueVT.getSizeInBits(); + unsigned PartBits = PartVT.getSizeInBits(); + Val = DAG.getNode(ISD::BITCAST, DL, MVT::getIntegerVT(ValueBits), Val); + Val = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::getIntegerVT(PartBits), Val); + Val = DAG.getNode(ISD::BITCAST, DL, PartVT, Val); + Parts[0] = Val; + return true; + } + return false; +} + +SDValue X86TargetLowering::joinRegisterPartsIntoValue( + SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts, unsigned NumParts, + MVT PartVT, EVT ValueVT, Optional<CallingConv::ID> CC) const { + bool IsABIRegCopy = CC.has_value(); + if (IsABIRegCopy && ValueVT == MVT::bf16 && PartVT == MVT::f32) { + unsigned ValueBits = ValueVT.getSizeInBits(); + unsigned PartBits = PartVT.getSizeInBits(); + SDValue Val = Parts[0]; + + Val = DAG.getNode(ISD::BITCAST, DL, MVT::getIntegerVT(PartBits), Val); + Val = DAG.getNode(ISD::TRUNCATE, DL, MVT::getIntegerVT(ValueBits), Val); + Val = DAG.getNode(ISD::BITCAST, DL, ValueVT, Val); + return Val; + } + return SDValue(); +} + bool X86TargetLowering::useSoftFloat() const { return Subtarget.useSoftFloat(); } diff --git a/llvm/test/CodeGen/X86/bfloat.ll b/llvm/test/CodeGen/X86/bfloat.ll --- a/llvm/test/CodeGen/X86/bfloat.ll +++ b/llvm/test/CodeGen/X86/bfloat.ll @@ -1,12 +1,10 @@ ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py ; RUN: llc < %s -mtriple=x86_64-linux-gnu | FileCheck %s -define void @add(ptr 
%pa, ptr %pb, ptr %pc) { +define void @add(ptr %pa, ptr %pb, ptr %pc) nounwind { ; CHECK-LABEL: add: ; CHECK: # %bb.0: ; CHECK-NEXT: pushq %rbx -; CHECK-NEXT: .cfi_def_cfa_offset 16 -; CHECK-NEXT: .cfi_offset %rbx, -16 ; CHECK-NEXT: movq %rdx, %rbx ; CHECK-NEXT: movzwl (%rdi), %eax ; CHECK-NEXT: shll $16, %eax @@ -16,9 +14,9 @@ ; CHECK-NEXT: movd %eax, %xmm0 ; CHECK-NEXT: addss %xmm1, %xmm0 ; CHECK-NEXT: callq __truncsfbf2@PLT +; CHECK-NEXT: movd %xmm0, %eax ; CHECK-NEXT: movw %ax, (%rbx) ; CHECK-NEXT: popq %rbx -; CHECK-NEXT: .cfi_def_cfa_offset 8 ; CHECK-NEXT: retq %a = load bfloat, ptr %pa %b = load bfloat, ptr %pb @@ -27,38 +25,48 @@ ret void } -define void @add_double(ptr %pa, ptr %pb, ptr %pc) { +define bfloat @add2(bfloat %a, bfloat %b, bfloat %c) nounwind { +; CHECK-LABEL: add2: +; CHECK: # %bb.0: +; CHECK-NEXT: pushq %rax +; CHECK-NEXT: movd %xmm1, %eax +; CHECK-NEXT: shll $16, %eax +; CHECK-NEXT: movd %eax, %xmm1 +; CHECK-NEXT: movd %xmm0, %eax +; CHECK-NEXT: shll $16, %eax +; CHECK-NEXT: movd %eax, %xmm0 +; CHECK-NEXT: addss %xmm1, %xmm0 +; CHECK-NEXT: callq __truncsfbf2@PLT +; CHECK-NEXT: popq %rax +; CHECK-NEXT: retq + %add = fadd bfloat %a, %b + ret bfloat %add +} + +define void @add_double(ptr %pa, ptr %pb, ptr %pc) nounwind { ; CHECK-LABEL: add_double: ; CHECK: # %bb.0: ; CHECK-NEXT: pushq %r14 -; CHECK-NEXT: .cfi_def_cfa_offset 16 ; CHECK-NEXT: pushq %rbx -; CHECK-NEXT: .cfi_def_cfa_offset 24 ; CHECK-NEXT: pushq %rax -; CHECK-NEXT: .cfi_def_cfa_offset 32 -; CHECK-NEXT: .cfi_offset %rbx, -24 -; CHECK-NEXT: .cfi_offset %r14, -16 ; CHECK-NEXT: movq %rdx, %r14 ; CHECK-NEXT: movq %rsi, %rbx -; CHECK-NEXT: movsd {{.*#+}} xmm0 = mem[0],zero +; CHECK-NEXT: movq {{.*#+}} xmm0 = mem[0],zero ; CHECK-NEXT: callq __truncdfbf2@PLT -; CHECK-NEXT: # kill: def $ax killed $ax def $eax +; CHECK-NEXT: movd %xmm0, %eax ; CHECK-NEXT: shll $16, %eax ; CHECK-NEXT: movl %eax, {{[-0-9]+}}(%r{{[sb]}}p) # 4-byte Spill -; CHECK-NEXT: movsd {{.*#+}} xmm0 = mem[0],zero +; 
CHECK-NEXT: movq {{.*#+}} xmm0 = mem[0],zero ; CHECK-NEXT: callq __truncdfbf2@PLT -; CHECK-NEXT: # kill: def $ax killed $ax def $eax +; CHECK-NEXT: movd %xmm0, %eax ; CHECK-NEXT: shll $16, %eax ; CHECK-NEXT: movd %eax, %xmm0 ; CHECK-NEXT: addss {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 4-byte Folded Reload ; CHECK-NEXT: cvtss2sd %xmm0, %xmm0 ; CHECK-NEXT: movsd %xmm0, (%r14) ; CHECK-NEXT: addq $8, %rsp -; CHECK-NEXT: .cfi_def_cfa_offset 24 ; CHECK-NEXT: popq %rbx -; CHECK-NEXT: .cfi_def_cfa_offset 16 ; CHECK-NEXT: popq %r14 -; CHECK-NEXT: .cfi_def_cfa_offset 8 ; CHECK-NEXT: retq %la = load double, ptr %pa %a = fptrunc double %la to bfloat @@ -70,21 +78,45 @@ ret void } -define void @add_constant(ptr %pa, ptr %pc) { +define double @add_double2(double %da, double %db, double %dc) nounwind { +; CHECK-LABEL: add_double2: +; CHECK: # %bb.0: +; CHECK-NEXT: subq $24, %rsp +; CHECK-NEXT: movsd %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill +; CHECK-NEXT: callq __truncdfbf2@PLT +; CHECK-NEXT: movd %xmm0, %eax +; CHECK-NEXT: shll $16, %eax +; CHECK-NEXT: movl %eax, {{[-0-9]+}}(%r{{[sb]}}p) # 4-byte Spill +; CHECK-NEXT: movq {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 8-byte Folded Reload +; CHECK-NEXT: # xmm0 = mem[0],zero +; CHECK-NEXT: callq __truncdfbf2@PLT +; CHECK-NEXT: movd %xmm0, %eax +; CHECK-NEXT: shll $16, %eax +; CHECK-NEXT: movd %eax, %xmm0 +; CHECK-NEXT: addss {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 4-byte Folded Reload +; CHECK-NEXT: cvtss2sd %xmm0, %xmm0 +; CHECK-NEXT: addq $24, %rsp +; CHECK-NEXT: retq + %a = fptrunc double %da to bfloat + %b = fptrunc double %db to bfloat + %add = fadd bfloat %a, %b + %dadd = fpext bfloat %add to double + ret double %dadd +} + +define void @add_constant(ptr %pa, ptr %pc) nounwind { ; CHECK-LABEL: add_constant: ; CHECK: # %bb.0: ; CHECK-NEXT: pushq %rbx -; CHECK-NEXT: .cfi_def_cfa_offset 16 -; CHECK-NEXT: .cfi_offset %rbx, -16 ; CHECK-NEXT: movq %rsi, %rbx ; CHECK-NEXT: movzwl (%rdi), %eax ; CHECK-NEXT: shll $16, %eax ; CHECK-NEXT: movd %eax, %xmm0 
; CHECK-NEXT: addss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 ; CHECK-NEXT: callq __truncsfbf2@PLT +; CHECK-NEXT: movd %xmm0, %eax ; CHECK-NEXT: movw %ax, (%rbx) ; CHECK-NEXT: popq %rbx -; CHECK-NEXT: .cfi_def_cfa_offset 8 ; CHECK-NEXT: retq %a = load bfloat, ptr %pa %add = fadd bfloat %a, 1.0 @@ -92,7 +124,22 @@ ret void } -define void @store_constant(ptr %pc) { +define bfloat @add_constant2(bfloat %a) nounwind { +; CHECK-LABEL: add_constant2: +; CHECK: # %bb.0: +; CHECK-NEXT: pushq %rax +; CHECK-NEXT: movd %xmm0, %eax +; CHECK-NEXT: shll $16, %eax +; CHECK-NEXT: movd %eax, %xmm0 +; CHECK-NEXT: addss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 +; CHECK-NEXT: callq __truncsfbf2@PLT +; CHECK-NEXT: popq %rax +; CHECK-NEXT: retq + %add = fadd bfloat %a, 1.0 + ret bfloat %add +} + +define void @store_constant(ptr %pc) nounwind { ; CHECK-LABEL: store_constant: ; CHECK: # %bb.0: ; CHECK-NEXT: movw $16256, (%rdi) # imm = 0x3F80 @@ -101,7 +148,7 @@ ret void } -define void @fold_ext_trunc(ptr %pa, ptr %pc) { +define void @fold_ext_trunc(ptr %pa, ptr %pc) nounwind { ; CHECK-LABEL: fold_ext_trunc: ; CHECK: # %bb.0: ; CHECK-NEXT: movzwl (%rdi), %eax @@ -113,3 +160,12 @@ store bfloat %trunc, ptr %pc ret void } + +define bfloat @fold_ext_trunc2(bfloat %a) nounwind { +; CHECK-LABEL: fold_ext_trunc2: +; CHECK: # %bb.0: +; CHECK-NEXT: retq + %ext = fpext bfloat %a to float + %trunc = fptrunc float %ext to bfloat + ret bfloat %trunc +}