diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -2276,9 +2276,10 @@
     addRegisterClass(MVT::v8bf16, &X86::VR128XRegClass);
     addRegisterClass(MVT::v16bf16, &X86::VR256XRegClass);
     // We set the type action of bf16 to TypeSoftPromoteHalf, but we don't
-    // provide the method to promote BUILD_VECTOR. Set the operation action
-    // Custom to do the customization later.
+    // provide the method to promote BUILD_VECTOR and INSERT_VECTOR_ELT.
+    // Set the operation action Custom to do the customization later.
     setOperationAction(ISD::BUILD_VECTOR, MVT::bf16, Custom);
+    setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::bf16, Custom);
     for (auto VT : {MVT::v8bf16, MVT::v16bf16}) {
       setF16Action(VT, Expand);
       setOperationAction(ISD::FADD, VT, Expand);
@@ -20751,6 +20752,17 @@
   SDValue N2 = Op.getOperand(2);
   auto *N2C = dyn_cast<ConstantSDNode>(N2);
 
+  // Insert bf16 elements through the integer domain: bitcast the vector to
+  // its integer-element equivalent and the scalar to i16, insert, then
+  // bitcast the result back to the bf16 vector type.
+  if (EltVT == MVT::bf16) {
+    MVT IVT = VT.changeVectorElementTypeToInteger();
+    SDValue Res = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, IVT,
+                              DAG.getBitcast(IVT, N0),
+                              DAG.getBitcast(MVT::i16, N1), N2);
+    return DAG.getBitcast(VT, Res);
+  }
+
   if (!N2C) {
     // Variable insertion indices, usually we're better off spilling to stack,
     // but AVX512 can use a variable compare+select by comparing against all
diff --git a/llvm/test/CodeGen/X86/bfloat.ll b/llvm/test/CodeGen/X86/bfloat.ll
--- a/llvm/test/CodeGen/X86/bfloat.ll
+++ b/llvm/test/CodeGen/X86/bfloat.ll
@@ -1158,4 +1158,29 @@
   ret <32 x bfloat> %1
 }
 
+define <32 x bfloat> @pr62997_3(<32 x bfloat> %0, bfloat %1) {
+; SSE2-LABEL: pr62997_3:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movq %xmm0, %rax
+; SSE2-NEXT:    movabsq $-4294967296, %rcx # imm = 0xFFFFFFFF00000000
+; SSE2-NEXT:    andq %rax, %rcx
+; SSE2-NEXT:    movzwl %ax, %eax
+; SSE2-NEXT:    movd %xmm4, %edx
+; SSE2-NEXT:    shll $16, %edx
+; SSE2-NEXT:    orl %eax, %edx
+; SSE2-NEXT:    orq %rcx, %rdx
+; SSE2-NEXT:    movq %rdx, %xmm4
+; SSE2-NEXT:    movsd {{.*#+}} xmm0 = xmm4[0],xmm0[1]
+; SSE2-NEXT:    retq
+;
+; BF16-LABEL: pr62997_3:
+; BF16:       # %bb.0:
+; BF16-NEXT:    vmovd %xmm1, %eax
+; BF16-NEXT:    vpinsrw $1, %eax, %xmm0, %xmm1
+; BF16-NEXT:    vinserti32x4 $0, %xmm1, %zmm0, %zmm0
+; BF16-NEXT:    retq
+  %3 = insertelement <32 x bfloat> %0, bfloat %1, i64 1
+  ret <32 x bfloat> %3
+}
+
 declare <32 x bfloat> @llvm.masked.load.v32bf16.p0(ptr, i32, <32 x i1>, <32 x bfloat>)