diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp
--- a/llvm/lib/Target/X86/X86InstrInfo.cpp
+++ b/llvm/lib/Target/X86/X86InstrInfo.cpp
@@ -1883,7 +1883,7 @@
   unsigned KMaskOp = -1U;
   if (X86II::isKMasked(TSFlags)) {
     // For k-zero-masked operations it is Ok to commute the first vector
-    // operand.
+    // operand. Unless this is an intrinsic instruction.
     // For regular k-masked operations a conservative choice is done as the
     // elements of the first vector operand, for which the corresponding bit
     // in the k-mask operand is set to 0, are copied to the result of the
@@ -1902,7 +1902,7 @@
     // The operand with index = 1 is used as a source for those elements for
     // which the corresponding bit in the k-mask is set to 0.
-    if (X86II::isKMergeMasked(TSFlags))
+    if (X86II::isKMergeMasked(TSFlags) || IsIntrinsic)
       FirstCommutableVecOp = 3;
 
     LastCommutableVecOp++;
diff --git a/llvm/test/CodeGen/X86/avx512-intrinsics.ll b/llvm/test/CodeGen/X86/avx512-intrinsics.ll
--- a/llvm/test/CodeGen/X86/avx512-intrinsics.ll
+++ b/llvm/test/CodeGen/X86/avx512-intrinsics.ll
@@ -5812,6 +5812,37 @@
   ret <4 x float> %8
 }
 
+; Make sure we don't commute this to fold the load as that source isn't commutable.
+define <4 x float> @test_int_x86_avx512_maskz_vfmadd_ss_load0(i8 zeroext %0, <4 x float>* nocapture readonly %1, <4 x float> %2, <4 x float> %3) {
+; X64-LABEL: test_int_x86_avx512_maskz_vfmadd_ss_load0:
+; X64:       # %bb.0:
+; X64-NEXT:    vmovaps (%rsi), %xmm2
+; X64-NEXT:    kmovw %edi, %k1
+; X64-NEXT:    vfmadd213ss {{.*#+}} xmm2 = (xmm0 * xmm2) + xmm1
+; X64-NEXT:    vmovaps %xmm2, %xmm0
+; X64-NEXT:    retq
+;
+; X86-LABEL: test_int_x86_avx512_maskz_vfmadd_ss_load0:
+; X86:       # %bb.0:
+; X86-NEXT:    movb {{[0-9]+}}(%esp), %al
+; X86-NEXT:    movl {{[0-9]+}}(%esp), %ecx
+; X86-NEXT:    vmovaps (%ecx), %xmm2
+; X86-NEXT:    kmovw %eax, %k1
+; X86-NEXT:    vfmadd213ss {{.*#+}} xmm2 = (xmm0 * xmm2) + xmm1
+; X86-NEXT:    vmovaps %xmm2, %xmm0
+; X86-NEXT:    retl
+  %5 = load <4 x float>, <4 x float>* %1, align 16
+  %6 = extractelement <4 x float> %5, i64 0
+  %7 = extractelement <4 x float> %2, i64 0
+  %8 = extractelement <4 x float> %3, i64 0
+  %9 = tail call float @llvm.fma.f32(float %6, float %7, float %8) #2
+  %10 = bitcast i8 %0 to <8 x i1>
+  %11 = extractelement <8 x i1> %10, i64 0
+  %12 = select i1 %11, float %9, float 0.000000e+00
+  %13 = insertelement <4 x float> %5, float %12, i64 0
+  ret <4 x float> %13
+}
+
 define <2 x double>@test_int_x86_avx512_mask3_vfmadd_sd(<2 x double> %x0, <2 x double> %x1, <2 x double> %x2, i8 %x3,i32 %x4 ){
 ; X64-LABEL: test_int_x86_avx512_mask3_vfmadd_sd:
 ; X64:       # %bb.0:
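
For reference, a standalone sketch of the hazard the new test guards against. The IR below mirrors the added test (plus the @llvm.fma.f32 declaration the diff context omits); the @maskz_fma_load0_repro name, the repro.ll file name, and the llc flags in the comment are illustrative assumptions, not part of this patch, and the authoritative RUN lines are the ones in avx512-intrinsics.ll. The point is that the loaded vector feeds operand 1 of the zero-masked scalar-FMA pattern and also supplies the upper result elements, which is why the commute (and the load fold it would enable) has to be suppressed.

; Standalone reproducer sketch (assumed invocation, not the test's RUN line):
;   llc -mtriple=x86_64-unknown-linux-gnu -mattr=+avx512f < repro.ll
declare float @llvm.fma.f32(float, float, float)

define <4 x float> @maskz_fma_load0_repro(i8 zeroext %mask, <4 x float>* %p, <4 x float> %a, <4 x float> %b) {
  ; The loaded vector supplies both the first multiplicand and the upper
  ; (pass-through) elements of the result, so it cannot be swapped into the
  ; operand position whose load would be folded into the FMA.
  %v = load <4 x float>, <4 x float>* %p, align 16
  %x = extractelement <4 x float> %v, i64 0
  %y = extractelement <4 x float> %a, i64 0
  %z = extractelement <4 x float> %b, i64 0
  %f = tail call float @llvm.fma.f32(float %x, float %y, float %z)
  %m = bitcast i8 %mask to <8 x i1>
  %m0 = extractelement <8 x i1> %m, i64 0
  %s = select i1 %m0, float %f, float 0.000000e+00
  %r = insertelement <4 x float> %v, float %s, i64 0
  ret <4 x float> %r
}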