Summary (informational text ahead of the first file header):
  * X86ISelLowering.cpp: mark ISD::VECTOR_SHUFFLE as Custom for `VT` in the
    first hunk (the definition of `VT` lies outside the hunk context) and for
    MVT::v32bf16; in the shuffle lowering, bitcast v16f16 and any vector with
    bf16 elements to the matching integer vector type, shuffle, and bitcast
    the result back to the original type.
  * X86InstrAVX512.td: select VPBROADCASTW for X86VBroadcastld16 and
    X86VBroadcast producing v8bf16, v16bf16 and v32bf16.
  * avx512bf16-vl-intrinsics.ll: add two broadcast tests.

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -2269,6 +2269,7 @@
       setOperationAction(ISD::FMUL, VT, Expand);
       setOperationAction(ISD::FDIV, VT, Expand);
       setOperationAction(ISD::BUILD_VECTOR, VT, Custom);
+      setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom);
     }
     addLegalFPImmediate(APFloat::getZero(APFloat::BFloat()));
   }
@@ -2281,6 +2282,7 @@
     setOperationAction(ISD::FMUL, MVT::v32bf16, Expand);
     setOperationAction(ISD::FDIV, MVT::v32bf16, Expand);
     setOperationAction(ISD::BUILD_VECTOR, MVT::v32bf16, Custom);
+    setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v32bf16, Custom);
   }
 
   if (!Subtarget.useSoftFloat() && Subtarget.hasVLX()) {
@@ -19085,11 +19087,11 @@
     return DAG.getBitcast(VT, DAG.getVectorShuffle(FpVT, DL, V1, V2, Mask));
   }
 
-  if (VT == MVT::v16f16) {
-    V1 = DAG.getBitcast(MVT::v16i16, V1);
-    V2 = DAG.getBitcast(MVT::v16i16, V2);
-    return DAG.getBitcast(MVT::v16f16,
-                          DAG.getVectorShuffle(MVT::v16i16, DL, V1, V2, Mask));
+  if (VT == MVT::v16f16 || VT.getVectorElementType() == MVT::bf16) {
+    MVT IVT = VT.changeVectorElementTypeToInteger();
+    V1 = DAG.getBitcast(IVT, V1);
+    V2 = DAG.getBitcast(IVT, V2);
+    return DAG.getBitcast(VT, DAG.getVectorShuffle(IVT, DL, V1, V2, Mask));
   }
 
   switch (VT.SimpleTy) {
diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td
--- a/llvm/lib/Target/X86/X86InstrAVX512.td
+++ b/llvm/lib/Target/X86/X86InstrAVX512.td
@@ -12965,6 +12965,27 @@
             (VCVTNEPS2BF16Z256rr VR256X:$src)>;
   def : Pat<(v8bf16 (int_x86_vcvtneps2bf16256 (loadv8f32 addr:$src))),
             (VCVTNEPS2BF16Z256rm addr:$src)>;
+
+  def : Pat<(v8bf16 (X86VBroadcastld16 addr:$src)),
+            (VPBROADCASTWZ128rm addr:$src)>;
+  def : Pat<(v16bf16 (X86VBroadcastld16 addr:$src)),
+            (VPBROADCASTWZ256rm addr:$src)>;
+
+  def : Pat<(v8bf16 (X86VBroadcast (v8bf16 VR128X:$src))),
+            (VPBROADCASTWZ128rr VR128X:$src)>;
+  def : Pat<(v16bf16 (X86VBroadcast (v8bf16 VR128X:$src))),
+            (VPBROADCASTWZ256rr VR128X:$src)>;
+
+  // TODO: No scalar broadcast because scalar bf16 is not a legal type yet.
+}
+
+let Predicates = [HasBF16] in {
+  def : Pat<(v32bf16 (X86VBroadcastld16 addr:$src)),
+            (VPBROADCASTWZrm addr:$src)>;
+
+  def : Pat<(v32bf16 (X86VBroadcast (v8bf16 VR128X:$src))),
+            (VPBROADCASTWZrr VR128X:$src)>;
+  // TODO: No scalar broadcast because scalar bf16 is not a legal type yet.
 }
 
 let Constraints = "$src1 = $dst" in {
diff --git a/llvm/test/CodeGen/X86/avx512bf16-vl-intrinsics.ll b/llvm/test/CodeGen/X86/avx512bf16-vl-intrinsics.ll
--- a/llvm/test/CodeGen/X86/avx512bf16-vl-intrinsics.ll
+++ b/llvm/test/CodeGen/X86/avx512bf16-vl-intrinsics.ll
@@ -356,3 +356,49 @@
   %2 = select <4 x i1> %1, <4 x float> %0, <4 x float> %E
   ret <4 x float> %2
 }
+
+define <16 x i16> @test_no_vbroadcast1() {
+; CHECK-LABEL: test_no_vbroadcast1:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vcvtneps2bf16 %xmm0, %xmm0 # encoding: [0x62,0xf2,0x7e,0x08,0x72,0xc0]
+; CHECK-NEXT:    vpbroadcastw %xmm0, %ymm0 # EVEX TO VEX Compression encoding: [0xc4,0xe2,0x7d,0x79,0xc0]
+; CHECK-NEXT:    ret{{[l|q]}} # encoding: [0xc3]
+entry:
+  %0 = tail call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> poison, <8 x bfloat> zeroinitializer, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
+  %1 = bitcast <8 x bfloat> %0 to <8 x i16>
+  %2 = shufflevector <8 x i16> %1, <8 x i16> undef, <16 x i32> zeroinitializer
+  ret <16 x i16> %2
+}
+
+;; FIXME: This should generate the same output as above, but let's fix the crash first.
+define <16 x bfloat> @test_no_vbroadcast2() nounwind {
+; X86-LABEL: test_no_vbroadcast2:
+; X86:       # %bb.0: # %entry
+; X86-NEXT:    pushl %ebp # encoding: [0x55]
+; X86-NEXT:    movl %esp, %ebp # encoding: [0x89,0xe5]
+; X86-NEXT:    andl $-32, %esp # encoding: [0x83,0xe4,0xe0]
+; X86-NEXT:    subl $64, %esp # encoding: [0x83,0xec,0x40]
+; X86-NEXT:    vcvtneps2bf16 %xmm0, %xmm0 # encoding: [0x62,0xf2,0x7e,0x08,0x72,0xc0]
+; X86-NEXT:    vmovaps %xmm0, (%esp) # EVEX TO VEX Compression encoding: [0xc5,0xf8,0x29,0x04,0x24]
+; X86-NEXT:    vpbroadcastw (%esp), %ymm0 # EVEX TO VEX Compression encoding: [0xc4,0xe2,0x7d,0x79,0x04,0x24]
+; X86-NEXT:    movl %ebp, %esp # encoding: [0x89,0xec]
+; X86-NEXT:    popl %ebp # encoding: [0x5d]
+; X86-NEXT:    retl # encoding: [0xc3]
+;
+; X64-LABEL: test_no_vbroadcast2:
+; X64:       # %bb.0: # %entry
+; X64-NEXT:    pushq %rbp # encoding: [0x55]
+; X64-NEXT:    movq %rsp, %rbp # encoding: [0x48,0x89,0xe5]
+; X64-NEXT:    andq $-32, %rsp # encoding: [0x48,0x83,0xe4,0xe0]
+; X64-NEXT:    subq $64, %rsp # encoding: [0x48,0x83,0xec,0x40]
+; X64-NEXT:    vcvtneps2bf16 %xmm0, %xmm0 # encoding: [0x62,0xf2,0x7e,0x08,0x72,0xc0]
+; X64-NEXT:    vmovaps %xmm0, (%rsp) # EVEX TO VEX Compression encoding: [0xc5,0xf8,0x29,0x04,0x24]
+; X64-NEXT:    vpbroadcastw (%rsp), %ymm0 # EVEX TO VEX Compression encoding: [0xc4,0xe2,0x7d,0x79,0x04,0x24]
+; X64-NEXT:    movq %rbp, %rsp # encoding: [0x48,0x89,0xec]
+; X64-NEXT:    popq %rbp # encoding: [0x5d]
+; X64-NEXT:    retq # encoding: [0xc3]
+entry:
+  %0 = tail call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> poison, <8 x bfloat> zeroinitializer, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
+  %1 = shufflevector <8 x bfloat> %0, <8 x bfloat> undef, <16 x i32> zeroinitializer
+  ret <16 x bfloat> %1
+}