diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -53980,6 +53980,44 @@
   return SDValue();
 }
 
+static SDValue getInvertedVectorForFMA(SDValue V, SelectionDAG &DAG) {
+  assert(ISD::isBuildVectorOfConstantFPSDNodes(V.getNode()) &&
+         "ConstantFP build vector expected");
+  // Check if we can eliminate the constant completely: all users must be FMAs.
+  for (const SDNode *User : V->uses()) {
+    if (User->getOpcode() != ISD::FMA && User->getOpcode() != ISD::STRICT_FMA)
+      return SDValue();
+  }
+  // Form the inverted (negated) constant vector.
+  SmallVector<SDValue, 8> Ops;
+  EVT VT = V.getValueType();
+  EVT EltVT = VT.getVectorElementType();
+  for (auto op : V->op_values()) {
+    if (auto *Cst = dyn_cast<ConstantFPSDNode>(op)) {
+      Ops.push_back(DAG.getConstantFP(-Cst->getValueAPF(), SDLoc(op), EltVT));
+    } else {
+      assert(op.isUndef());
+      Ops.push_back(DAG.getUNDEF(EltVT));
+    }
+  }
+
+  SDNode *NV = DAG.getNodeIfExists(ISD::BUILD_VECTOR, DAG.getVTList(VT), Ops);
+  if (!NV)
+    return SDValue();
+
+  // If the inverted value can also be eliminated, we have to consistently
+  // prefer one of the two values. We prefer the constant whose first element
+  // is negative.
+  for (const SDNode *User : NV->uses())
+    if (User->getOpcode() != ISD::FMA && User->getOpcode() != ISD::STRICT_FMA)
+      return SDValue(NV, 0);
+
+  if (cast<ConstantFPSDNode>(V->getOperand(0))->isNegative())
+    return SDValue();
+
+  return SDValue(NV, 0);
+}
+
 static SDValue combineFMA(SDNode *N, SelectionDAG &DAG,
                           TargetLowering::DAGCombinerInfo &DCI,
                           const X86Subtarget &Subtarget) {
@@ -54031,7 +54069,13 @@
         return true;
       }
     }
-
+    // Check if the DAG contains an inverted version of constant vector V.
+    if (ISD::isBuildVectorOfConstantFPSDNodes(V.getNode())) {
+      if (SDValue NegV = getInvertedVectorForFMA(V, DAG)) {
+        V = NegV;
+        return true;
+      }
+    }
     return false;
   };
 
diff --git a/llvm/test/CodeGen/X86/avx2-fma-fneg-combine.ll b/llvm/test/CodeGen/X86/avx2-fma-fneg-combine.ll
--- a/llvm/test/CodeGen/X86/avx2-fma-fneg-combine.ll
+++ b/llvm/test/CodeGen/X86/avx2-fma-fneg-combine.ll
@@ -154,15 +154,13 @@
 ; X32-LABEL: test9:
 ; X32:       # %bb.0:
 ; X32-NEXT:    vbroadcastsd {{.*#+}} ymm1 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
-; X32-NEXT:    vbroadcastsd {{.*#+}} ymm2 = [5.0E-1,5.0E-1,5.0E-1,5.0E-1]
-; X32-NEXT:    vfmadd213pd {{.*#+}} ymm0 = (ymm2 * ymm0) + ymm1
+; X32-NEXT:    vfnmadd213pd {{.*#+}} ymm0 = -(ymm1 * ymm0) + ymm1
 ; X32-NEXT:    retl
 ;
 ; X64-LABEL: test9:
 ; X64:       # %bb.0:
 ; X64-NEXT:    vbroadcastsd {{.*#+}} ymm1 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
-; X64-NEXT:    vbroadcastsd {{.*#+}} ymm2 = [5.0E-1,5.0E-1,5.0E-1,5.0E-1]
-; X64-NEXT:    vfmadd213pd {{.*#+}} ymm0 = (ymm2 * ymm0) + ymm1
+; X64-NEXT:    vfnmadd213pd {{.*#+}} ymm0 = -(ymm1 * ymm0) + ymm1
 ; X64-NEXT:    retq
   %t = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> , <4 x double> )
   ret <4 x double> %t
@@ -172,19 +170,19 @@
 ; X32-LABEL: test10:
 ; X32:       # %bb.0:
 ; X32-NEXT:    vbroadcastsd {{.*#+}} ymm2 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
-; X32-NEXT:    vfmadd213pd {{.*#+}} ymm2 = (ymm0 * ymm2) + ymm1
-; X32-NEXT:    vbroadcastsd {{.*#+}} ymm3 = [5.0E-1,5.0E-1,5.0E-1,5.0E-1]
+; X32-NEXT:    vmovapd %ymm2, %ymm3
 ; X32-NEXT:    vfmadd213pd {{.*#+}} ymm3 = (ymm0 * ymm3) + ymm1
-; X32-NEXT:    vaddpd %ymm3, %ymm2, %ymm0
+; X32-NEXT:    vfnmadd213pd {{.*#+}} ymm2 = -(ymm0 * ymm2) + ymm1
+; X32-NEXT:    vaddpd %ymm2, %ymm3, %ymm0
 ; X32-NEXT:    retl
 ;
 ; X64-LABEL: test10:
 ; X64:       # %bb.0:
 ; X64-NEXT:    vbroadcastsd {{.*#+}} ymm2 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
-; X64-NEXT:    vfmadd213pd {{.*#+}} ymm2 = (ymm0 * ymm2) + ymm1
-; X64-NEXT:    vbroadcastsd {{.*#+}} ymm3 = [5.0E-1,5.0E-1,5.0E-1,5.0E-1]
+; X64-NEXT:    vmovapd %ymm2, %ymm3
 ; X64-NEXT:    vfmadd213pd {{.*#+}} ymm3 = (ymm0 * ymm3) + ymm1
-; X64-NEXT:    vaddpd %ymm3, %ymm2, %ymm0
+; X64-NEXT:    vfnmadd213pd {{.*#+}} ymm2 = -(ymm0 * ymm2) + ymm1
+; X64-NEXT:    vaddpd %ymm2, %ymm3, %ymm0
 ; X64-NEXT:    retq
   %t0 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> , <4 x double> %b)
   %t1 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> , <4 x double> %b)
@@ -196,17 +194,15 @@
 ; X32-LABEL: test11:
 ; X32:       # %bb.0:
 ; X32-NEXT:    vbroadcastsd {{.*#+}} ymm1 = [5.0E-1,5.0E-1,5.0E-1,5.0E-1]
-; X32-NEXT:    vaddpd %ymm1, %ymm0, %ymm2
-; X32-NEXT:    vbroadcastsd {{.*#+}} ymm0 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
-; X32-NEXT:    vfmadd231pd {{.*#+}} ymm0 = (ymm1 * ymm2) + ymm0
+; X32-NEXT:    vaddpd %ymm1, %ymm0, %ymm0
+; X32-NEXT:    vfmsub213pd {{.*#+}} ymm0 = (ymm1 * ymm0) - ymm1
 ; X32-NEXT:    retl
 ;
 ; X64-LABEL: test11:
 ; X64:       # %bb.0:
 ; X64-NEXT:    vbroadcastsd {{.*#+}} ymm1 = [5.0E-1,5.0E-1,5.0E-1,5.0E-1]
-; X64-NEXT:    vaddpd %ymm1, %ymm0, %ymm2
-; X64-NEXT:    vbroadcastsd {{.*#+}} ymm0 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
-; X64-NEXT:    vfmadd231pd {{.*#+}} ymm0 = (ymm1 * ymm2) + ymm0
+; X64-NEXT:    vaddpd %ymm1, %ymm0, %ymm0
+; X64-NEXT:    vfmsub213pd {{.*#+}} ymm0 = (ymm1 * ymm0) - ymm1
 ; X64-NEXT:    retq
   %t0 = fadd <4 x double> %a, 
   %t1 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %t0, <4 x double> , <4 x double> )
diff --git a/llvm/test/CodeGen/X86/fma-fneg-combine-2.ll b/llvm/test/CodeGen/X86/fma-fneg-combine-2.ll
--- a/llvm/test/CodeGen/X86/fma-fneg-combine-2.ll
+++ b/llvm/test/CodeGen/X86/fma-fneg-combine-2.ll
@@ -129,14 +129,14 @@
 define <4 x double> @negated_constant_v4f64(<4 x double> %a) {
 ; FMA3-LABEL: negated_constant_v4f64:
 ; FMA3:       # %bb.0:
-; FMA3-NEXT:    vmovapd {{.*#+}} ymm1 = [5.0E-1,5.0E-1,5.0E-1,5.0E-1]
-; FMA3-NEXT:    vfmadd213pd {{.*#+}} ymm0 = (ymm1 * ymm0) + mem
+; FMA3-NEXT:    vmovapd {{.*#+}} ymm1 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
+; FMA3-NEXT:    vfnmadd213pd {{.*#+}} ymm0 = -(ymm1 * ymm0) + ymm1
 ; FMA3-NEXT:    retq
 ;
 ; FMA4-LABEL: negated_constant_v4f64:
 ; FMA4:       # %bb.0:
-; FMA4-NEXT:    vmovapd {{.*#+}} ymm1 = [5.0E-1,5.0E-1,5.0E-1,5.0E-1]
-; FMA4-NEXT:    vfmaddpd {{.*#+}} ymm0 = (ymm0 * ymm1) + mem
+; FMA4-NEXT:    vmovapd {{.*#+}} ymm1 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
+; FMA4-NEXT:    vfnmaddpd {{.*#+}} ymm0 = -(ymm0 * ymm1) + ymm1
 ; FMA4-NEXT:    retq
   %t = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> , <4 x double> )
   ret <4 x double> %t
@@ -146,16 +146,18 @@
 ; FMA3-LABEL: negated_constant_v4f64_2fmas:
 ; FMA3:       # %bb.0:
 ; FMA3-NEXT:    vmovapd {{.*#+}} ymm2 = <-5.0E-1,u,-5.0E-1,-5.0E-1>
-; FMA3-NEXT:    vfmadd213pd {{.*#+}} ymm2 = (ymm0 * ymm2) + ymm1
-; FMA3-NEXT:    vfmadd231pd {{.*#+}} ymm1 = (ymm0 * mem) + ymm1
-; FMA3-NEXT:    vaddpd %ymm1, %ymm2, %ymm0
+; FMA3-NEXT:    vmovapd %ymm2, %ymm3
+; FMA3-NEXT:    vfmadd213pd {{.*#+}} ymm3 = (ymm0 * ymm3) + ymm1
+; FMA3-NEXT:    vfnmadd213pd {{.*#+}} ymm2 = -(ymm0 * ymm2) + ymm1
+; FMA3-NEXT:    vaddpd %ymm2, %ymm3, %ymm0
 ; FMA3-NEXT:    retq
 ;
 ; FMA4-LABEL: negated_constant_v4f64_2fmas:
 ; FMA4:       # %bb.0:
-; FMA4-NEXT:    vfmaddpd {{.*#+}} ymm2 = (ymm0 * mem) + ymm1
-; FMA4-NEXT:    vfmaddpd {{.*#+}} ymm0 = (ymm0 * mem) + ymm1
-; FMA4-NEXT:    vaddpd %ymm0, %ymm2, %ymm0
+; FMA4-NEXT:    vmovapd {{.*#+}} ymm2 = <-5.0E-1,u,-5.0E-1,-5.0E-1>
+; FMA4-NEXT:    vfmaddpd {{.*#+}} ymm3 = (ymm0 * ymm2) + ymm1
+; FMA4-NEXT:    vfnmaddpd {{.*#+}} ymm0 = -(ymm0 * ymm2) + ymm1
+; FMA4-NEXT:    vaddpd %ymm0, %ymm3, %ymm0
 ; FMA4-NEXT:    retq
   %t0 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> , <4 x double> %b)
   %t1 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> , <4 x double> %b)
@@ -168,14 +170,14 @@
 ; FMA3:       # %bb.0:
 ; FMA3-NEXT:    vmovapd {{.*#+}} ymm1 = [5.0E-1,5.0E-1,5.0E-1,5.0E-1]
 ; FMA3-NEXT:    vaddpd %ymm1, %ymm0, %ymm0
-; FMA3-NEXT:    vfmadd213pd {{.*#+}} ymm0 = (ymm1 * ymm0) + mem
+; FMA3-NEXT:    vfmsub213pd {{.*#+}} ymm0 = (ymm1 * ymm0) - ymm1
 ; FMA3-NEXT:    retq
 ;
 ; FMA4-LABEL: negated_constant_v4f64_fadd:
 ; FMA4:       # %bb.0:
 ; FMA4-NEXT:    vmovapd {{.*#+}} ymm1 = [5.0E-1,5.0E-1,5.0E-1,5.0E-1]
 ; FMA4-NEXT:    vaddpd %ymm1, %ymm0, %ymm0
-; FMA4-NEXT:    vfmaddpd {{.*#+}} ymm0 = (ymm0 * ymm1) + mem
+; FMA4-NEXT:    vfmsubpd {{.*#+}} ymm0 = (ymm0 * ymm1) - ymm1
 ; FMA4-NEXT:    retq
   %t0 = fadd <4 x double> %a, 
   %t1 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %t0, <4 x double> , <4 x double> )