diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -417,6 +417,10 @@
         return Val;
       if (PartEVT.isInteger() && ValueVT.isFloatingPoint())
         return DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
+
+      // Vector/Vector bitcast (e.g. <2 x bfloat> -> <2 x half>).
+      if (ValueVT.getSizeInBits() == PartEVT.getSizeInBits())
+        return DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
     }
 
     // Promoted vector extract
@@ -622,6 +626,8 @@
     return SDValue();
 
   EVT ValueVT = Val.getValueType();
+  EVT PartEVT = PartVT.getVectorElementType();
+  EVT ValueEVT = ValueVT.getVectorElementType();
   ElementCount PartNumElts = PartVT.getVectorElementCount();
   ElementCount ValueNumElts = ValueVT.getVectorElementCount();
 
@@ -629,22 +635,30 @@
   // fixed/scalable properties. If a target needs to widen a fixed-length type
   // to a scalable one, it should be possible to use INSERT_SUBVECTOR below.
   if (ElementCount::isKnownLE(PartNumElts, ValueNumElts) ||
-      PartNumElts.isScalable() != ValueNumElts.isScalable() ||
-      PartVT.getVectorElementType() != ValueVT.getVectorElementType())
+      PartNumElts.isScalable() != ValueNumElts.isScalable())
     return SDValue();
 
+  // Have a try for bf16 because some targets share its ABI with fp16.
+  if (ValueEVT == MVT::bf16 && PartEVT == MVT::f16) {
+    assert(DAG.getTargetLoweringInfo().isTypeLegal(PartVT) &&
+           "Cannot widen to illegal type");
+    Val = DAG.getNode(ISD::BITCAST, DL,
+                      ValueVT.changeVectorElementType(MVT::f16), Val);
+  } else if (PartEVT != ValueEVT) {
+    return SDValue();
+  }
+
   // Widening a scalable vector to another scalable vector is done by inserting
   // the vector into a larger undef one.
   if (PartNumElts.isScalable())
     return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, PartVT, DAG.getUNDEF(PartVT),
                        Val, DAG.getVectorIdxConstant(0, DL));
 
-  EVT ElementVT = PartVT.getVectorElementType();
   // Vector widening case, e.g. <2 x float> -> <4 x float>.  Shuffle in
   // undef elements.
   SmallVector<SDValue, 16> Ops;
   DAG.ExtractVectorElements(Val, Ops);
-  SDValue EltUndef = DAG.getUNDEF(ElementVT);
+  SDValue EltUndef = DAG.getUNDEF(PartEVT);
   Ops.append((PartNumElts - ValueNumElts).getFixedValue(), EltUndef);
 
   // FIXME: Use CONCAT for 2x -> 4x.
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -2608,7 +2608,7 @@
 
   if (VT.isVector() && VT.getVectorElementType() == MVT::bf16)
     return getRegisterTypeForCallingConv(Context, CC,
-                                         VT.changeVectorElementTypeToInteger());
+                                         VT.changeVectorElementType(MVT::f16));
 
   return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
 }
@@ -2643,7 +2643,7 @@
 
   if (VT.isVector() && VT.getVectorElementType() == MVT::bf16)
     return getNumRegistersForCallingConv(Context, CC,
-                                         VT.changeVectorElementTypeToInteger());
+                                         VT.changeVectorElementType(MVT::f16));
 
   return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT);
 }
diff --git a/llvm/test/CodeGen/X86/bfloat.ll b/llvm/test/CodeGen/X86/bfloat.ll
--- a/llvm/test/CodeGen/X86/bfloat.ll
+++ b/llvm/test/CodeGen/X86/bfloat.ll
@@ -317,13 +317,13 @@
 ; SSE2-NEXT:    movq %rdx, %rax
 ; SSE2-NEXT:    shrq $48, %rax
 ; SSE2-NEXT:    movq %rax, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
-; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm0[2,3,2,3]
+; SSE2-NEXT:    punpckhqdq {{.*#+}} xmm0 = xmm0[1,1]
 ; SSE2-NEXT:    movq %xmm0, %r12
 ; SSE2-NEXT:    movq %r12, %rax
 ; SSE2-NEXT:    shrq $32, %rax
 ; SSE2-NEXT:    movq %rax, (%rsp) # 8-byte Spill
-; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm1[2,3,2,3]
-; SSE2-NEXT:    movq %xmm0, %r14
+; SSE2-NEXT:    punpckhqdq {{.*#+}} xmm1 = xmm1[1,1]
+; SSE2-NEXT:    movq %xmm1, %r14
 ; SSE2-NEXT:    movq %r14, %rbp
 ; SSE2-NEXT:    shrq $32, %rbp
 ; SSE2-NEXT:    movq %r12, %r15
@@ -543,3 +543,25 @@
   %add = fadd <8 x bfloat> %a, %b
   ret <8 x bfloat> %add
 }
+
+define <2 x bfloat> @pr62997(bfloat %a, bfloat %b) {
+; SSE2-LABEL: pr62997:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movd %xmm0, %eax
+; SSE2-NEXT:    movd %xmm1, %ecx
+; SSE2-NEXT:    pinsrw $0, %ecx, %xmm1
+; SSE2-NEXT:    pinsrw $0, %eax, %xmm0
+; SSE2-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
+; SSE2-NEXT:    retq
+;
+; BF16-LABEL: pr62997:
+; BF16:       # %bb.0:
+; BF16-NEXT:    vmovd %xmm1, %eax
+; BF16-NEXT:    vmovd %xmm0, %ecx
+; BF16-NEXT:    vmovd %ecx, %xmm0
+; BF16-NEXT:    vpinsrw $1, %eax, %xmm0, %xmm0
+; BF16-NEXT:    retq
+  %1 = insertelement <2 x bfloat> undef, bfloat %a, i64 0
+  %2 = insertelement <2 x bfloat> %1, bfloat %b, i64 1
+  ret <2 x bfloat> %2
+}
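
Note (commentary, not part of the patch): with the X86 hooks above now reporting an f16-based register type for bf16 vectors, widenVectorToPartType sees a <2 x bfloat> value whose part element type is f16; the new bf16/f16 branch bitcasts the value to half elements before padding it with undef, instead of bailing out on the element-type mismatch. The @pr62997 test added above is the reduced reproducer; a minimal standalone sketch for local experimentation is below (the RUN line is an assumption for a quick crash check, not taken from bfloat.ll, which has its own RUN lines behind the SSE2/BF16 prefixes used above):

  ; RUN: llc -mtriple=x86_64-linux-gnu < %s
  define <2 x bfloat> @pr62997(bfloat %a, bfloat %b) {
    %v0 = insertelement <2 x bfloat> undef, bfloat %a, i64 0
    %v1 = insertelement <2 x bfloat> %v0, bfloat %b, i64 1
    ret <2 x bfloat> %v1
  }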