diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -1598,6 +1598,7 @@
     SDValue lowerFaddFsub(SDValue Op, SelectionDAG &DAG) const;
     SDValue LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const;
     SDValue LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const;
+    SDValue LowerFP_TO_BF16(SDValue Op, SelectionDAG &DAG) const;
 
     SDValue
     LowerFormalArguments(SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
@@ -1621,6 +1622,17 @@
       MachineBasicBlock *Entry,
       const SmallVectorImpl<MachineBasicBlock *> &Exits) const override;
 
+    bool
+    splitValueIntoRegisterParts(SelectionDAG &DAG, const SDLoc &DL, SDValue Val,
+                                SDValue *Parts, unsigned NumParts, MVT PartVT,
+                                Optional<CallingConv::ID> CC) const override;
+
+    SDValue
+    joinRegisterPartsIntoValue(SelectionDAG &DAG, const SDLoc &DL,
+                               const SDValue *Parts, unsigned NumParts,
+                               MVT PartVT, EVT ValueVT,
+                               Optional<CallingConv::ID> CC) const override;
+
     bool isUsedByReturnOnly(SDNode *N, SDValue &Chain) const override;
 
     bool mayBeEmittedAsTailCall(const CallInst *CI) const override;
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -418,7 +418,7 @@
     setTruncStoreAction(VT, MVT::bf16, Expand);
 
     setOperationAction(ISD::BF16_TO_FP, VT, Expand);
-    setOperationAction(ISD::FP_TO_BF16, VT, Expand);
+    setOperationAction(ISD::FP_TO_BF16, VT, Custom);
   }
 
   setOperationAction(ISD::PARITY, MVT::i8, Custom);
@@ -2495,6 +2495,11 @@
       !Subtarget.hasX87())
     return MVT::i32;
 
+  // Pass bf16 vectors the same way as the integer vector of the same shape.
+  if (VT.isVector() && VT.getVectorElementType() == MVT::bf16)
+    return getRegisterTypeForCallingConv(Context, CC,
+                                         VT.changeVectorElementTypeToInteger());
+
   return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
 }
 
@@ -2526,6 +2531,11 @@
     return 3;
   }
 
+  // Pass bf16 vectors the same way as the integer vector of the same shape.
+  if (VT.isVector() && VT.getVectorElementType() == MVT::bf16)
+    return getNumRegistersForCallingConv(Context, CC,
+                                         VT.changeVectorElementTypeToInteger());
+
   return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT);
 }
 
@@ -2734,6 +2744,47 @@
   return TargetLowering::getJumpTableEncoding();
 }
 
+/// Pass a scalar bf16 across a calling-convention boundary in an f32 register
+/// part: the bits are reinterpreted as i16, any-extended to i32 (so the upper
+/// 16 bits are undefined) and reinterpreted as f32. Returns false for every
+/// other value/part combination so the generic splitting logic is used.
+bool X86TargetLowering::splitValueIntoRegisterParts(
+    SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
+    unsigned NumParts, MVT PartVT, Optional<CallingConv::ID> CC) const {
+  bool IsABIRegCopy = CC.has_value();
+  EVT ValueVT = Val.getValueType();
+  if (IsABIRegCopy && ValueVT == MVT::bf16 && PartVT == MVT::f32) {
+    unsigned ValueBits = ValueVT.getSizeInBits();
+    unsigned PartBits = PartVT.getSizeInBits();
+    Val = DAG.getNode(ISD::BITCAST, DL, MVT::getIntegerVT(ValueBits), Val);
+    Val = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::getIntegerVT(PartBits), Val);
+    Val = DAG.getNode(ISD::BITCAST, DL, PartVT, Val);
+    Parts[0] = Val;
+    return true;
+  }
+  return false;
+}
+
+/// Inverse of splitValueIntoRegisterParts: rebuild a scalar bf16 from the low
+/// 16 bits of the f32 register part it was passed in. Returns an empty SDValue
+/// for every other value/part combination so the generic logic is used.
+SDValue X86TargetLowering::joinRegisterPartsIntoValue(
+    SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts, unsigned NumParts,
+    MVT PartVT, EVT ValueVT, Optional<CallingConv::ID> CC) const {
+  bool IsABIRegCopy = CC.has_value();
+  if (IsABIRegCopy && ValueVT == MVT::bf16 && PartVT == MVT::f32) {
+    unsigned ValueBits = ValueVT.getSizeInBits();
+    unsigned PartBits = PartVT.getSizeInBits();
+    SDValue Val = Parts[0];
+
+    Val = DAG.getNode(ISD::BITCAST, DL, MVT::getIntegerVT(PartBits), Val);
+    Val = DAG.getNode(ISD::TRUNCATE, DL, MVT::getIntegerVT(ValueBits), Val);
+    Val = DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
+    return Val;
+  }
+  return SDValue();
+}
+
 bool X86TargetLowering::useSoftFloat() const {
   return Subtarget.useSoftFloat();
 }
@@ -23041,6 +23092,24 @@
   return Res;
 }
 
+/// Lower FP_TO_BF16 to the float-to-bf16 truncation libcall chosen by
+/// RTLIB::getFPROUND. The libcall returns an f32 whose low 16 bits hold the
+/// bf16 result, so reinterpret it as i32 and truncate to the i16 result.
+SDValue X86TargetLowering::LowerFP_TO_BF16(SDValue Op,
+                                           SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  MakeLibCallOptions CallOptions;
+  RTLIB::Libcall LC =
+      RTLIB::getFPROUND(Op.getOperand(0).getValueType(), MVT::bf16);
+  // getFPROUND yields UNKNOWN_LIBCALL for source types it has no libcall for;
+  // catch that here rather than deep inside makeLibCall.
+  assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unsupported FP_TO_BF16!");
+  SDValue Res =
+      makeLibCall(DAG, LC, MVT::f32, Op.getOperand(0), CallOptions, DL).first;
+  return DAG.getNode(ISD::TRUNCATE, DL, MVT::i16,
+                     DAG.getBitcast(MVT::i32, Res));
+}
+
 /// Depending on uarch and/or optimizing for size, we might prefer to use a
 /// vector operation in place of the typical scalar operation.
 static SDValue lowerAddSubToHorizontalOp(SDValue Op, SelectionDAG &DAG,
@@ -32233,6 +32302,7 @@
   case ISD::STRICT_FP16_TO_FP:  return LowerFP16_TO_FP(Op, DAG);
   case ISD::FP_TO_FP16:
   case ISD::STRICT_FP_TO_FP16:  return LowerFP_TO_FP16(Op, DAG);
+  case ISD::FP_TO_BF16:         return LowerFP_TO_BF16(Op, DAG);
   case ISD::LOAD:               return LowerLoad(Op, Subtarget, DAG);
   case ISD::STORE:              return LowerStore(Op, Subtarget, DAG);
   case ISD::FADD:
diff --git a/llvm/test/CodeGen/X86/bfloat.ll b/llvm/test/CodeGen/X86/bfloat.ll
--- a/llvm/test/CodeGen/X86/bfloat.ll
+++ b/llvm/test/CodeGen/X86/bfloat.ll
@@ -1,12 +1,10 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc < %s -mtriple=x86_64-linux-gnu | FileCheck %s
 
-define void @add(ptr %pa, ptr %pb, ptr %pc) {
+define void @add(ptr %pa, ptr %pb, ptr %pc) nounwind {
 ; CHECK-LABEL: add:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    pushq %rbx
-; CHECK-NEXT:    .cfi_def_cfa_offset 16
-; CHECK-NEXT:    .cfi_offset %rbx, -16
 ; CHECK-NEXT:    movq %rdx, %rbx
 ; CHECK-NEXT:    movzwl (%rdi), %eax
 ; CHECK-NEXT:    shll $16, %eax
@@ -16,9 +14,9 @@
 ; CHECK-NEXT:    movd %eax, %xmm0
 ; CHECK-NEXT:    addss %xmm1, %xmm0
 ; CHECK-NEXT:    callq __truncsfbf2@PLT
+; CHECK-NEXT:    movd %xmm0, %eax
 ; CHECK-NEXT:    movw %ax, (%rbx)
 ; CHECK-NEXT:    popq %rbx
-; CHECK-NEXT:    .cfi_def_cfa_offset 8
 ; CHECK-NEXT:    retq
   %a = load bfloat, ptr %pa
   %b = load bfloat, ptr %pb
@@ -27,38 +25,48 @@
   ret void
 }
 
-define void @add_double(ptr %pa, ptr %pb, ptr %pc) {
+define bfloat @add2(bfloat %a, bfloat %b) nounwind {
+; CHECK-LABEL: add2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    pushq %rax
+; CHECK-NEXT:    movd %xmm1, %eax
+; CHECK-NEXT:    shll $16, %eax
+; CHECK-NEXT:    movd %eax, %xmm1
+; CHECK-NEXT:    movd %xmm0, %eax
+; CHECK-NEXT:    shll $16, %eax
+; CHECK-NEXT:    movd %eax, %xmm0
+; CHECK-NEXT:    addss %xmm1, %xmm0
+; CHECK-NEXT:    callq __truncsfbf2@PLT
+; CHECK-NEXT:    popq %rax
+; CHECK-NEXT:    retq
+  %add = fadd bfloat %a, %b
+  ret bfloat %add
+}
+
+define void @add_double(ptr %pa, ptr %pb, ptr %pc) nounwind {
 ; CHECK-LABEL: add_double:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    pushq %r14
-; CHECK-NEXT:    .cfi_def_cfa_offset 16
 ; CHECK-NEXT:    pushq %rbx
-; CHECK-NEXT:    .cfi_def_cfa_offset 24
 ; CHECK-NEXT:    pushq %rax
-; CHECK-NEXT:    .cfi_def_cfa_offset 32
-; CHECK-NEXT:    .cfi_offset %rbx, -24
-; CHECK-NEXT:    .cfi_offset %r14, -16
 ; CHECK-NEXT:    movq %rdx, %r14
 ; CHECK-NEXT:    movq %rsi, %rbx
-; CHECK-NEXT:    movsd {{.*#+}} xmm0 = mem[0],zero
+; CHECK-NEXT:    movq {{.*#+}} xmm0 = mem[0],zero
 ; CHECK-NEXT:    callq __truncdfbf2@PLT
-; CHECK-NEXT:    # kill: def $ax killed $ax def $eax
+; CHECK-NEXT:    movd %xmm0, %eax
 ; CHECK-NEXT:    shll $16, %eax
 ; CHECK-NEXT:    movl %eax, {{[-0-9]+}}(%r{{[sb]}}p) # 4-byte Spill
-; CHECK-NEXT:    movsd {{.*#+}} xmm0 = mem[0],zero
+; CHECK-NEXT:    movq {{.*#+}} xmm0 = mem[0],zero
 ; CHECK-NEXT:    callq __truncdfbf2@PLT
-; CHECK-NEXT:    # kill: def $ax killed $ax def $eax
+; CHECK-NEXT:    movd %xmm0, %eax
 ; CHECK-NEXT:    shll $16, %eax
 ; CHECK-NEXT:    movd %eax, %xmm0
 ; CHECK-NEXT:    addss {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 4-byte Folded Reload
 ; CHECK-NEXT:    cvtss2sd %xmm0, %xmm0
 ; CHECK-NEXT:    movsd %xmm0, (%r14)
 ; CHECK-NEXT:    addq $8, %rsp
-; CHECK-NEXT:    .cfi_def_cfa_offset 24
 ; CHECK-NEXT:    popq %rbx
-; CHECK-NEXT:    .cfi_def_cfa_offset 16
 ; CHECK-NEXT:    popq %r14
-; CHECK-NEXT:    .cfi_def_cfa_offset 8
 ; CHECK-NEXT:    retq
   %la = load double, ptr %pa
   %a = fptrunc double %la to bfloat
@@ -70,21 +78,45 @@
   ret void
 }
 
-define void @add_constant(ptr %pa, ptr %pc) {
+define double @add_double2(double %da, double %db) nounwind {
+; CHECK-LABEL: add_double2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    subq $24, %rsp
+; CHECK-NEXT:    movsd %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
+; CHECK-NEXT:    callq __truncdfbf2@PLT
+; CHECK-NEXT:    movd %xmm0, %eax
+; CHECK-NEXT:    shll $16, %eax
+; CHECK-NEXT:    movl %eax, {{[-0-9]+}}(%r{{[sb]}}p) # 4-byte Spill
+; CHECK-NEXT:    movq {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 8-byte Folded Reload
+; CHECK-NEXT:    # xmm0 = mem[0],zero
+; CHECK-NEXT:    callq __truncdfbf2@PLT
+; CHECK-NEXT:    movd %xmm0, %eax
+; CHECK-NEXT:    shll $16, %eax
+; CHECK-NEXT:    movd %eax, %xmm0
+; CHECK-NEXT:    addss {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 4-byte Folded Reload
+; CHECK-NEXT:    cvtss2sd %xmm0, %xmm0
+; CHECK-NEXT:    addq $24, %rsp
+; CHECK-NEXT:    retq
+  %a = fptrunc double %da to bfloat
+  %b = fptrunc double %db to bfloat
+  %add = fadd bfloat %a, %b
+  %dadd = fpext bfloat %add to double
+  ret double %dadd
+}
+
+define void @add_constant(ptr %pa, ptr %pc) nounwind {
 ; CHECK-LABEL: add_constant:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    pushq %rbx
-; CHECK-NEXT:    .cfi_def_cfa_offset 16
-; CHECK-NEXT:    .cfi_offset %rbx, -16
 ; CHECK-NEXT:    movq %rsi, %rbx
 ; CHECK-NEXT:    movzwl (%rdi), %eax
 ; CHECK-NEXT:    shll $16, %eax
 ; CHECK-NEXT:    movd %eax, %xmm0
 ; CHECK-NEXT:    addss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
 ; CHECK-NEXT:    callq __truncsfbf2@PLT
+; CHECK-NEXT:    movd %xmm0, %eax
 ; CHECK-NEXT:    movw %ax, (%rbx)
 ; CHECK-NEXT:    popq %rbx
-; CHECK-NEXT:    .cfi_def_cfa_offset 8
 ; CHECK-NEXT:    retq
   %a = load bfloat, ptr %pa
   %add = fadd bfloat %a, 1.0
@@ -92,7 +124,22 @@
   ret void
 }
 
-define void @store_constant(ptr %pc) {
+define bfloat @add_constant2(bfloat %a) nounwind {
+; CHECK-LABEL: add_constant2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    pushq %rax
+; CHECK-NEXT:    movd %xmm0, %eax
+; CHECK-NEXT:    shll $16, %eax
+; CHECK-NEXT:    movd %eax, %xmm0
+; CHECK-NEXT:    addss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
+; CHECK-NEXT:    callq __truncsfbf2@PLT
+; CHECK-NEXT:    popq %rax
+; CHECK-NEXT:    retq
+  %add = fadd bfloat %a, 1.0
+  ret bfloat %add
+}
+
+define void @store_constant(ptr %pc) nounwind {
 ; CHECK-LABEL: store_constant:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    movw $16256, (%rdi) # imm = 0x3F80
@@ -101,7 +148,7 @@
   ret void
 }
 
-define void @fold_ext_trunc(ptr %pa, ptr %pc) {
+define void @fold_ext_trunc(ptr %pa, ptr %pc) nounwind {
 ; CHECK-LABEL: fold_ext_trunc:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    movzwl (%rdi), %eax
@@ -113,3 +160,150 @@
   store bfloat %trunc, ptr %pc
   ret void
 }
+
+define bfloat @fold_ext_trunc2(bfloat %a) nounwind {
+; CHECK-LABEL: fold_ext_trunc2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
+  %ext = fpext bfloat %a to float
+  %trunc = fptrunc float %ext to bfloat
+  ret bfloat %trunc
+}
+
+define <8 x bfloat> @addv(<8 x bfloat> %a, <8 x bfloat> %b) nounwind {
+; CHECK-LABEL: addv:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    pushq %rbp
+; CHECK-NEXT:    pushq %r14
+; CHECK-NEXT:    pushq %rbx
+; CHECK-NEXT:    subq $32, %rsp
+; CHECK-NEXT:    movq %xmm1, %rax
+; CHECK-NEXT:    movq %rax, %rcx
+; CHECK-NEXT:    shrq $32, %rcx
+; CHECK-NEXT:    shll $16, %ecx
+; CHECK-NEXT:    movd %ecx, %xmm2
+; CHECK-NEXT:    movq %xmm0, %rcx
+; CHECK-NEXT:    movq %rcx, %rdx
+; CHECK-NEXT:    shrq $32, %rdx
+; CHECK-NEXT:    shll $16, %edx
+; CHECK-NEXT:    movd %edx, %xmm3
+; CHECK-NEXT:    addss %xmm2, %xmm3
+; CHECK-NEXT:    movss %xmm3, {{[-0-9]+}}(%r{{[sb]}}p) # 4-byte Spill
+; CHECK-NEXT:    movq %rax, %rdx
+; CHECK-NEXT:    shrq $48, %rdx
+; CHECK-NEXT:    shll $16, %edx
+; CHECK-NEXT:    movd %edx, %xmm2
+; CHECK-NEXT:    movq %rcx, %rdx
+; CHECK-NEXT:    shrq $48, %rdx
+; CHECK-NEXT:    shll $16, %edx
+; CHECK-NEXT:    movd %edx, %xmm3
+; CHECK-NEXT:    addss %xmm2, %xmm3
+; CHECK-NEXT:    movss %xmm3, {{[-0-9]+}}(%r{{[sb]}}p) # 4-byte Spill
+; CHECK-NEXT:    movl %eax, %edx
+; CHECK-NEXT:    shll $16, %edx
+; CHECK-NEXT:    movd %edx, %xmm2
+; CHECK-NEXT:    movl %ecx, %edx
+; CHECK-NEXT:    shll $16, %edx
+; CHECK-NEXT:    movd %edx, %xmm3
+; CHECK-NEXT:    addss %xmm2, %xmm3
+; CHECK-NEXT:    movss %xmm3, {{[-0-9]+}}(%r{{[sb]}}p) # 4-byte Spill
+; CHECK-NEXT:    andl $-65536, %eax # imm = 0xFFFF0000
+; CHECK-NEXT:    movd %eax, %xmm2
+; CHECK-NEXT:    andl $-65536, %ecx # imm = 0xFFFF0000
+; CHECK-NEXT:    movd %ecx, %xmm3
+; CHECK-NEXT:    addss %xmm2, %xmm3
+; CHECK-NEXT:    movss %xmm3, {{[-0-9]+}}(%r{{[sb]}}p) # 4-byte Spill
+; CHECK-NEXT:    pshufd {{.*#+}} xmm1 = xmm1[2,3,2,3]
+; CHECK-NEXT:    movq %xmm1, %rax
+; CHECK-NEXT:    movq %rax, %rcx
+; CHECK-NEXT:    shrq $32, %rcx
+; CHECK-NEXT:    shll $16, %ecx
+; CHECK-NEXT:    movd %ecx, %xmm1
+; CHECK-NEXT:    pshufd {{.*#+}} xmm0 = xmm0[2,3,2,3]
+; CHECK-NEXT:    movq %xmm0, %rcx
+; CHECK-NEXT:    movq %rcx, %rdx
+; CHECK-NEXT:    shrq $32, %rdx
+; CHECK-NEXT:    shll $16, %edx
+; CHECK-NEXT:    movd %edx, %xmm0
+; CHECK-NEXT:    addss %xmm1, %xmm0
+; CHECK-NEXT:    movss %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 4-byte Spill
+; CHECK-NEXT:    movq %rax, %rdx
+; CHECK-NEXT:    shrq $48, %rdx
+; CHECK-NEXT:    shll $16, %edx
+; CHECK-NEXT:    movd %edx, %xmm0
+; CHECK-NEXT:    movq %rcx, %rdx
+; CHECK-NEXT:    shrq $48, %rdx
+; CHECK-NEXT:    shll $16, %edx
+; CHECK-NEXT:    movd %edx, %xmm1
+; CHECK-NEXT:    addss %xmm0, %xmm1
+; CHECK-NEXT:    movss %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 4-byte Spill
+; CHECK-NEXT:    movl %eax, %edx
+; CHECK-NEXT:    shll $16, %edx
+; CHECK-NEXT:    movd %edx, %xmm0
+; CHECK-NEXT:    movl %ecx, %edx
+; CHECK-NEXT:    shll $16, %edx
+; CHECK-NEXT:    movd %edx, %xmm1
+; CHECK-NEXT:    addss %xmm0, %xmm1
+; CHECK-NEXT:    movss %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 4-byte Spill
+; CHECK-NEXT:    andl $-65536, %eax # imm = 0xFFFF0000
+; CHECK-NEXT:    movd %eax, %xmm1
+; CHECK-NEXT:    andl $-65536, %ecx # imm = 0xFFFF0000
+; CHECK-NEXT:    movd %ecx, %xmm0
+; CHECK-NEXT:    addss %xmm1, %xmm0
+; CHECK-NEXT:    callq __truncsfbf2@PLT
+; CHECK-NEXT:    movd %xmm0, %ebx
+; CHECK-NEXT:    shll $16, %ebx
+; CHECK-NEXT:    movd {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 4-byte Folded Reload
+; CHECK-NEXT:    # xmm0 = mem[0],zero,zero,zero
+; CHECK-NEXT:    callq __truncsfbf2@PLT
+; CHECK-NEXT:    movd %xmm0, %eax
+; CHECK-NEXT:    movzwl %ax, %r14d
+; CHECK-NEXT:    orl %ebx, %r14d
+; CHECK-NEXT:    movd {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 4-byte Folded Reload
+; CHECK-NEXT:    # xmm0 = mem[0],zero,zero,zero
+; CHECK-NEXT:    callq __truncsfbf2@PLT
+; CHECK-NEXT:    movd %xmm0, %ebp
+; CHECK-NEXT:    shll $16, %ebp
+; CHECK-NEXT:    movd {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 4-byte Folded Reload
+; CHECK-NEXT:    # xmm0 = mem[0],zero,zero,zero
+; CHECK-NEXT:    callq __truncsfbf2@PLT
+; CHECK-NEXT:    movd %xmm0, %eax
+; CHECK-NEXT:    movzwl %ax, %ebx
+; CHECK-NEXT:    orl %ebp, %ebx
+; CHECK-NEXT:    shlq $32, %rbx
+; CHECK-NEXT:    orq %r14, %rbx
+; CHECK-NEXT:    movd {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 4-byte Folded Reload
+; CHECK-NEXT:    # xmm0 = mem[0],zero,zero,zero
+; CHECK-NEXT:    callq __truncsfbf2@PLT
+; CHECK-NEXT:    movd %xmm0, %ebp
+; CHECK-NEXT:    shll $16, %ebp
+; CHECK-NEXT:    movd {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 4-byte Folded Reload
+; CHECK-NEXT:    # xmm0 = mem[0],zero,zero,zero
+; CHECK-NEXT:    callq __truncsfbf2@PLT
+; CHECK-NEXT:    movd %xmm0, %eax
+; CHECK-NEXT:    movzwl %ax, %r14d
+; CHECK-NEXT:    orl %ebp, %r14d
+; CHECK-NEXT:    movd {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 4-byte Folded Reload
+; CHECK-NEXT:    # xmm0 = mem[0],zero,zero,zero
+; CHECK-NEXT:    callq __truncsfbf2@PLT
+; CHECK-NEXT:    movd %xmm0, %ebp
+; CHECK-NEXT:    shll $16, %ebp
+; CHECK-NEXT:    movd {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 4-byte Folded Reload
+; CHECK-NEXT:    # xmm0 = mem[0],zero,zero,zero
+; CHECK-NEXT:    callq __truncsfbf2@PLT
+; CHECK-NEXT:    movd %xmm0, %eax
+; CHECK-NEXT:    movzwl %ax, %eax
+; CHECK-NEXT:    orl %ebp, %eax
+; CHECK-NEXT:    shlq $32, %rax
+; CHECK-NEXT:    orq %r14, %rax
+; CHECK-NEXT:    movq %rax, %xmm0
+; CHECK-NEXT:    movq %rbx, %xmm1
+; CHECK-NEXT:    punpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm1[0]
+; CHECK-NEXT:    addq $32, %rsp
+; CHECK-NEXT:    popq %rbx
+; CHECK-NEXT:    popq %r14
+; CHECK-NEXT:    popq %rbp
+; CHECK-NEXT:    retq
+  %add = fadd <8 x bfloat> %a, %b
+  ret <8 x bfloat> %add
+}