diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -54011,7 +54011,7 @@
       !(ScalarVT == MVT::f16 && Subtarget.hasFP16()))
     return SDValue();
 
-  auto invertIfNegative = [&DAG, &TLI, &DCI](SDValue &V) {
+  auto invertIfNegative = [&DAG, &TLI, &DCI, &Subtarget](SDValue &V) {
     bool CodeSize = DAG.getMachineFunction().getFunction().hasOptSize();
     bool LegalOperations = !DCI.isBeforeLegalizeOps();
     if (SDValue NegV = TLI.getCheaperNegatedExpression(V, DAG, LegalOperations,
@@ -54031,7 +54031,29 @@
         return true;
       }
     }
-
+    // Lookup if there is a negative version of V in DAG.
+    APInt SplatValue;
+    if (Subtarget.hasAVX2() && V.hasOneUse() &&
+        ISD::isConstantSplatVector(V.getNode(), SplatValue)) {
+      SmallVector<SDValue, 8> Ops;
+      EVT VT = V.getValueType();
+      EVT EltVT = VT.getVectorElementType();
+      for (auto op : V->op_values()) {
+        if (auto *Cst = dyn_cast<ConstantFPSDNode>(op)) {
+          Ops.push_back(
+              DAG.getConstantFP(-Cst->getValueAPF(), SDLoc(op), EltVT));
+        } else {
+          assert(op.isUndef());
+          Ops.push_back(DAG.getUNDEF(EltVT));
+        }
+      }
+      SDNode *NV =
+          DAG.getNodeIfExists(ISD::BUILD_VECTOR, DAG.getVTList(VT), {Ops});
+      if (NV && !NV->use_empty()) {
+        V.setNode(NV);
+        return true;
+      }
+    }
     return false;
   };
diff --git a/llvm/test/CodeGen/X86/avx2-fma-fneg-combine.ll b/llvm/test/CodeGen/X86/avx2-fma-fneg-combine.ll
--- a/llvm/test/CodeGen/X86/avx2-fma-fneg-combine.ll
+++ b/llvm/test/CodeGen/X86/avx2-fma-fneg-combine.ll
@@ -154,19 +154,19 @@
 ; X32-LABEL: test9:
 ; X32:       # %bb.0:
 ; X32-NEXT:    vbroadcastsd {{.*#+}} ymm2 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
-; X32-NEXT:    vfmadd213pd {{.*#+}} ymm2 = (ymm0 * ymm2) + ymm1
-; X32-NEXT:    vbroadcastsd {{.*#+}} ymm3 = [5.0E-1,5.0E-1,5.0E-1,5.0E-1]
+; X32-NEXT:    vmovapd %ymm2, %ymm3
 ; X32-NEXT:    vfmadd213pd {{.*#+}} ymm3 = (ymm0 * ymm3) + ymm1
-; X32-NEXT:    vaddpd %ymm3, %ymm2, %ymm0
+; X32-NEXT:    vfnmadd213pd {{.*#+}} ymm2 = -(ymm0 * ymm2) + ymm1
+; X32-NEXT:    vaddpd %ymm2, %ymm3, %ymm0
 ; X32-NEXT:    retl
 ;
 ; X64-LABEL: test9:
 ; X64:       # %bb.0:
 ; X64-NEXT:    vbroadcastsd {{.*#+}} ymm2 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
-; X64-NEXT:    vfmadd213pd {{.*#+}} ymm2 = (ymm0 * ymm2) + ymm1
-; X64-NEXT:    vbroadcastsd {{.*#+}} ymm3 = [5.0E-1,5.0E-1,5.0E-1,5.0E-1]
+; X64-NEXT:    vmovapd %ymm2, %ymm3
 ; X64-NEXT:    vfmadd213pd {{.*#+}} ymm3 = (ymm0 * ymm3) + ymm1
-; X64-NEXT:    vaddpd %ymm3, %ymm2, %ymm0
+; X64-NEXT:    vfnmadd213pd {{.*#+}} ymm2 = -(ymm0 * ymm2) + ymm1
+; X64-NEXT:    vaddpd %ymm2, %ymm3, %ymm0
 ; X64-NEXT:    retq
   %t0 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double -5.000000e-01, double -5.000000e-01, double -5.000000e-01, double -5.000000e-01>, <4 x double> %b)
   %t1 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double 5.000000e-01, double 5.000000e-01, double 5.000000e-01, double 5.000000e-01>, <4 x double> %b)