diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -54523,6 +54523,59 @@
   return SDValue();
 }
 
+// Inverting a constant vector is profitable if it can be eliminated and the
+// inverted vector is already present in DAG. Otherwise, it will be loaded
+// anyway.
+//
+// We determine which of the values can be completely eliminated and invert it.
+// If both are eliminable, select a vector with the first negative element.
+static SDValue getInvertedVectorForFMA(SDValue V, SelectionDAG &DAG) {
+  assert(ISD::isBuildVectorOfConstantFPSDNodes(V.getNode()) &&
+         "ConstantFP build vector expected");
+  // Check if we can eliminate V. We assume if a value is only used in FMAs, we
+  // can eliminate it, since this function is invoked for each FMA with this
+  // vector.
+  auto IsNotFMA = [](SDNode *Use) {
+    return Use->getOpcode() != ISD::FMA && Use->getOpcode() != ISD::STRICT_FMA;
+  };
+  if (llvm::any_of(V->uses(), IsNotFMA))
+    return SDValue();
+
+  SmallVector<SDValue, 8> Ops;
+  EVT VT = V.getValueType();
+  EVT EltVT = VT.getVectorElementType();
+  for (auto Op : V->op_values()) {
+    if (auto *Cst = dyn_cast<ConstantFPSDNode>(Op)) {
+      Ops.push_back(DAG.getConstantFP(-Cst->getValueAPF(), SDLoc(Op), EltVT));
+    } else {
+      assert(Op.isUndef());
+      Ops.push_back(DAG.getUNDEF(EltVT));
+    }
+  }
+
+  SDNode *NV = DAG.getNodeIfExists(ISD::BUILD_VECTOR, DAG.getVTList(VT), Ops);
+  if (!NV)
+    return SDValue();
+
+  // If an inverted version cannot be eliminated, choose it instead of the
+  // original version.
+  if (llvm::any_of(NV->uses(), IsNotFMA))
+    return SDValue(NV, 0);
+
+  // If the inverted version also can be eliminated, we have to consistently
+  // prefer one of the values. We prefer a constant with a negative value in
+  // the first position.
+  // N.B. We need to skip undefs that may precede a value.
+  for (auto Op : V->op_values()) {
+    if (auto *Cst = dyn_cast<ConstantFPSDNode>(Op)) {
+      if (Cst->isNegative())
+        return SDValue();
+      break;
+    }
+  }
+  return SDValue(NV, 0);
+}
+
 static SDValue combineFMA(SDNode *N, SelectionDAG &DAG,
                           TargetLowering::DAGCombinerInfo &DCI,
                           const X86Subtarget &Subtarget) {
@@ -54574,7 +54627,13 @@
         return true;
       }
     }
-
+    // Look up if there is an inverted version of constant vector V in DAG.
+    if (ISD::isBuildVectorOfConstantFPSDNodes(V.getNode())) {
+      if (SDValue NegV = getInvertedVectorForFMA(V, DAG)) {
+        V = NegV;
+        return true;
+      }
+    }
     return false;
   };
diff --git a/llvm/test/CodeGen/X86/avx2-fma-fneg-combine.ll b/llvm/test/CodeGen/X86/avx2-fma-fneg-combine.ll
--- a/llvm/test/CodeGen/X86/avx2-fma-fneg-combine.ll
+++ b/llvm/test/CodeGen/X86/avx2-fma-fneg-combine.ll
@@ -154,15 +154,13 @@
 ; X32-LABEL: test9:
 ; X32:       # %bb.0:
 ; X32-NEXT:    vbroadcastsd {{.*#+}} ymm1 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
-; X32-NEXT:    vbroadcastsd {{.*#+}} ymm2 = [5.0E-1,5.0E-1,5.0E-1,5.0E-1]
-; X32-NEXT:    vfmadd213pd {{.*#+}} ymm0 = (ymm2 * ymm0) + ymm1
+; X32-NEXT:    vfnmadd213pd {{.*#+}} ymm0 = -(ymm1 * ymm0) + ymm1
 ; X32-NEXT:    retl
 ;
 ; X64-LABEL: test9:
 ; X64:       # %bb.0:
 ; X64-NEXT:    vbroadcastsd {{.*#+}} ymm1 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
-; X64-NEXT:    vbroadcastsd {{.*#+}} ymm2 = [5.0E-1,5.0E-1,5.0E-1,5.0E-1]
-; X64-NEXT:    vfmadd213pd {{.*#+}} ymm0 = (ymm2 * ymm0) + ymm1
+; X64-NEXT:    vfnmadd213pd {{.*#+}} ymm0 = -(ymm1 * ymm0) + ymm1
 ; X64-NEXT:    retq
   %t = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double 5.000000e-01, double 5.000000e-01, double 5.000000e-01, double 5.000000e-01>, <4 x double> <double -5.000000e-01, double -5.000000e-01, double -5.000000e-01, double -5.000000e-01>)
   ret <4 x double> %t
@@ -172,17 +170,19 @@
 ; X32-LABEL: test10:
 ; X32:       # %bb.0:
 ; X32-NEXT:    vmovapd {{.*#+}} ymm2 = <-9.5E+0,u,-5.5E+0,-2.5E+0>
-; X32-NEXT:    vfmadd213pd {{.*#+}} ymm2 = (ymm0 * ymm2) + ymm1
-; X32-NEXT:    vfmadd231pd {{.*#+}} ymm1 = (ymm0 * mem) + ymm1
-; X32-NEXT:    vaddpd %ymm1, %ymm2, %ymm0
+; X32-NEXT:    vmovapd %ymm2, %ymm3
+; X32-NEXT:    vfmadd213pd {{.*#+}} ymm3 = (ymm0 * ymm3) + ymm1
+; X32-NEXT:    vfnmadd213pd {{.*#+}} ymm2 = -(ymm0 * ymm2) + ymm1
+; X32-NEXT:    vaddpd %ymm2, %ymm3, %ymm0
 ; X32-NEXT:    retl
 ;
 ; X64-LABEL: test10:
 ; X64:       # %bb.0:
 ; X64-NEXT:    vmovapd {{.*#+}} ymm2 = <-9.5E+0,u,-5.5E+0,-2.5E+0>
-; X64-NEXT:    vfmadd213pd {{.*#+}} ymm2 = (ymm0 * ymm2) + ymm1
-; X64-NEXT:    vfmadd231pd {{.*#+}} ymm1 = (ymm0 * mem) + ymm1
-; X64-NEXT:    vaddpd %ymm1, %ymm2, %ymm0
+; X64-NEXT:    vmovapd %ymm2, %ymm3
+; X64-NEXT:    vfmadd213pd {{.*#+}} ymm3 = (ymm0 * ymm3) + ymm1
+; X64-NEXT:    vfnmadd213pd {{.*#+}} ymm2 = -(ymm0 * ymm2) + ymm1
+; X64-NEXT:    vaddpd %ymm2, %ymm3, %ymm0
 ; X64-NEXT:    retq
   %t0 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double -9.500000e+00, double undef, double -5.500000e+00, double -2.500000e+00>, <4 x double> %b)
   %t1 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double 9.500000e+00, double undef, double 5.500000e+00, double 2.500000e+00>, <4 x double> %b)
@@ -196,7 +196,7 @@
 ; X32-NEXT:    vbroadcastf128 {{.*#+}} ymm1 = [5.0E-1,2.5E+0,5.0E-1,2.5E+0]
 ; X32-NEXT:    # ymm1 = mem[0,1,0,1]
 ; X32-NEXT:    vaddpd %ymm1, %ymm0, %ymm0
-; X32-NEXT:    vfmadd213pd {{.*#+}} ymm0 = (ymm1 * ymm0) + mem
+; X32-NEXT:    vfmsub213pd {{.*#+}} ymm0 = (ymm1 * ymm0) - ymm1
 ; X32-NEXT:    retl
 ;
 ; X64-LABEL: test11:
@@ -204,7 +204,7 @@
 ; X64-NEXT:    vbroadcastf128 {{.*#+}} ymm1 = [5.0E-1,2.5E+0,5.0E-1,2.5E+0]
 ; X64-NEXT:    # ymm1 = mem[0,1,0,1]
 ; X64-NEXT:    vaddpd %ymm1, %ymm0, %ymm0
-; X64-NEXT:    vfmadd213pd {{.*#+}} ymm0 = (ymm1 * ymm0) + mem
+; X64-NEXT:    vfmsub213pd {{.*#+}} ymm0 = (ymm1 * ymm0) - ymm1
 ; X64-NEXT:    retq
   %t0 = fadd <4 x double> %a, <double 5.000000e-01, double 2.500000e+00, double 5.000000e-01, double 2.500000e+00>
   %t1 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %t0, <4 x double> <double 5.000000e-01, double 2.500000e+00, double 5.000000e-01, double 2.500000e+00>, <4 x double> <double -5.000000e-01, double -2.500000e+00, double -5.000000e-01, double -2.500000e+00>)
@@ -214,20 +214,18 @@
 define <4 x double> @test12(<4 x double> %a, <4 x double> %b) {
 ; X32-LABEL: test12:
 ; X32:       # %bb.0:
-; X32-NEXT:    vmovapd {{.*#+}} ymm2 = [7.5E+0,2.5E+0,5.5E+0,9.5E+0]
-; X32-NEXT:    vfmadd213pd {{.*#+}} ymm2 = (ymm0 * ymm2) + mem
-; X32-NEXT:    vmovapd {{.*#+}} ymm0 =
-; X32-NEXT:    vfmadd213pd {{.*#+}} ymm0 = (ymm1 * ymm0) + mem
-; X32-NEXT:    vaddpd %ymm0, %ymm2, %ymm0
+; X32-NEXT:    vmovapd {{.*#+}} ymm2 = [-7.5E+0,-2.5E+0,-5.5E+0,-9.5E+0]
+; X32-NEXT:    vfnmadd213pd {{.*#+}} ymm0 = -(ymm2 * ymm0) + mem
+; X32-NEXT:    vfmadd132pd {{.*#+}} ymm1 = (ymm1 * mem) + ymm2
+; X32-NEXT:    vaddpd %ymm1, %ymm0, %ymm0
 ; X32-NEXT:    retl
 ;
 ; X64-LABEL: test12:
 ; X64:       # %bb.0:
-; X64-NEXT:    vmovapd {{.*#+}} ymm2 = [7.5E+0,2.5E+0,5.5E+0,9.5E+0]
-; X64-NEXT:    vfmadd213pd {{.*#+}} ymm2 = (ymm0 * ymm2) + mem
-; X64-NEXT:    vmovapd {{.*#+}} ymm0 =
-; X64-NEXT:    vfmadd213pd {{.*#+}} ymm0 = (ymm1 * ymm0) + mem
-; X64-NEXT:    vaddpd %ymm0, %ymm2, %ymm0
+; X64-NEXT:    vmovapd {{.*#+}} ymm2 = [-7.5E+0,-2.5E+0,-5.5E+0,-9.5E+0]
+; X64-NEXT:    vfnmadd213pd {{.*#+}} ymm0 = -(ymm2 * ymm0) + mem
+; X64-NEXT:    vfmadd132pd {{.*#+}} ymm1 = (ymm1 * mem) + ymm2
+; X64-NEXT:    vaddpd %ymm1, %ymm0, %ymm0
 ; X64-NEXT:    retq
   %t0 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double 7.500000e+00, double 2.500000e+00, double 5.500000e+00, double 9.500000e+00>, <4 x double> )
   %t1 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %b, <4 x double> , <4 x double> <double -7.500000e+00, double -2.500000e+00, double -5.500000e+00, double -9.500000e+00>)
diff --git a/llvm/test/CodeGen/X86/fma-fneg-combine-2.ll b/llvm/test/CodeGen/X86/fma-fneg-combine-2.ll
--- a/llvm/test/CodeGen/X86/fma-fneg-combine-2.ll
+++ b/llvm/test/CodeGen/X86/fma-fneg-combine-2.ll
@@ -129,14 +129,14 @@
 define <4 x double> @negated_constant_v4f64(<4 x double> %a) {
 ; FMA3-LABEL: negated_constant_v4f64:
 ; FMA3:       # %bb.0:
-; FMA3-NEXT:    vmovapd {{.*#+}} ymm1 = [5.0E-1,2.5E-1,1.25E-1,6.25E-2]
-; FMA3-NEXT:    vfmadd213pd {{.*#+}} ymm0 = (ymm1 * ymm0) + mem
+; FMA3-NEXT:    vmovapd {{.*#+}} ymm1 = [-5.0E-1,-2.5E-1,-1.25E-1,-6.25E-2]
+; FMA3-NEXT:    vfnmadd213pd {{.*#+}} ymm0 = -(ymm1 * ymm0) + ymm1
 ; FMA3-NEXT:    retq
 ;
 ; FMA4-LABEL: negated_constant_v4f64:
 ; FMA4:       # %bb.0:
-; FMA4-NEXT:    vmovapd {{.*#+}} ymm1 = [5.0E-1,2.5E-1,1.25E-1,6.25E-2]
-; FMA4-NEXT:    vfmaddpd {{.*#+}} ymm0 = (ymm0 * ymm1) + mem
+; FMA4-NEXT:    vmovapd {{.*#+}} ymm1 = [-5.0E-1,-2.5E-1,-1.25E-1,-6.25E-2]
+; FMA4-NEXT:    vfnmaddpd {{.*#+}} ymm0 = -(ymm0 * ymm1) + ymm1
 ; FMA4-NEXT:    retq
   %t = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double 5.000000e-01, double 2.500000e-01, double 1.250000e-01, double 6.250000e-02>, <4 x double> <double -5.000000e-01, double -2.500000e-01, double -1.250000e-01, double -6.250000e-02>)
   ret <4 x double> %t
@@ -146,16 +146,18 @@
 ; FMA3-LABEL: negated_constant_v4f64_2fmas:
 ; FMA3:       # %bb.0:
 ; FMA3-NEXT:    vmovapd {{.*#+}} ymm2 = <-5.0E-1,u,-2.5E+0,-4.5E+0>
-; FMA3-NEXT:    vfmadd213pd {{.*#+}} ymm2 = (ymm0 * ymm2) + ymm1
-; FMA3-NEXT:    vfmadd231pd {{.*#+}} ymm1 = (ymm0 * mem) + ymm1
-; FMA3-NEXT:    vaddpd %ymm1, %ymm2, %ymm0
+; FMA3-NEXT:    vmovapd %ymm2, %ymm3
+; FMA3-NEXT:    vfmadd213pd {{.*#+}} ymm3 = (ymm0 * ymm3) + ymm1
+; FMA3-NEXT:    vfnmadd213pd {{.*#+}} ymm2 = -(ymm0 * ymm2) + ymm1
+; FMA3-NEXT:    vaddpd %ymm2, %ymm3, %ymm0
 ; FMA3-NEXT:    retq
 ;
 ; FMA4-LABEL: negated_constant_v4f64_2fmas:
 ; FMA4:       # %bb.0:
-; FMA4-NEXT:    vfmaddpd {{.*#+}} ymm2 = (ymm0 * mem) + ymm1
-; FMA4-NEXT:    vfmaddpd {{.*#+}} ymm0 = (ymm0 * mem) + ymm1
-; FMA4-NEXT:    vaddpd %ymm0, %ymm2, %ymm0
+; FMA4-NEXT:    vmovapd {{.*#+}} ymm2 = <-5.0E-1,u,-2.5E+0,-4.5E+0>
+; FMA4-NEXT:    vfmaddpd {{.*#+}} ymm3 = (ymm0 * ymm2) + ymm1
+; FMA4-NEXT:    vfnmaddpd {{.*#+}} ymm0 = -(ymm0 * ymm2) + ymm1
+; FMA4-NEXT:    vaddpd %ymm0, %ymm3, %ymm0
 ; FMA4-NEXT:    retq
   %t0 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double -5.000000e-01, double undef, double -2.500000e+00, double -4.500000e+00>, <4 x double> %b)
   %t1 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double 5.000000e-01, double undef, double 2.500000e+00, double 4.500000e+00>, <4 x double> %b)
@@ -169,7 +171,7 @@
 ; FMA3-NEXT:    vbroadcastf128 {{.*#+}} ymm1 = [1.5E+0,1.25E-1,1.5E+0,1.25E-1]
 ; FMA3-NEXT:    # ymm1 = mem[0,1,0,1]
 ; FMA3-NEXT:    vaddpd %ymm1, %ymm0, %ymm0
-; FMA3-NEXT:    vfmadd213pd {{.*#+}} ymm0 = (ymm1 * ymm0) + mem
+; FMA3-NEXT:    vfmsub213pd {{.*#+}} ymm0 = (ymm1 * ymm0) - ymm1
 ; FMA3-NEXT:    retq
 ;
 ; FMA4-LABEL: negated_constant_v4f64_fadd:
@@ -177,7 +179,7 @@
 ; FMA4-NEXT:    vbroadcastf128 {{.*#+}} ymm1 = [1.5E+0,1.25E-1,1.5E+0,1.25E-1]
 ; FMA4-NEXT:    # ymm1 = mem[0,1,0,1]
 ; FMA4-NEXT:    vaddpd %ymm1, %ymm0, %ymm0
-; FMA4-NEXT:    vfmaddpd {{.*#+}} ymm0 = (ymm0 * ymm1) + mem
+; FMA4-NEXT:    vfmsubpd {{.*#+}} ymm0 = (ymm0 * ymm1) - ymm1
 ; FMA4-NEXT:    retq
   %t0 = fadd <4 x double> %a, <double 1.500000e+00, double 1.250000e-01, double 1.500000e+00, double 1.250000e-01>
   %t1 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %t0, <4 x double> <double 1.500000e+00, double 1.250000e-01, double 1.500000e+00, double 1.250000e-01>, <4 x double> <double -1.500000e+00, double -1.250000e-01, double -1.500000e+00, double -1.250000e-01>)
@@ -187,19 +189,17 @@
 define <4 x double> @negated_constant_v4f64_2fma_undefs(<4 x double> %a, <4 x double> %b) {
 ; FMA3-LABEL: negated_constant_v4f64_2fma_undefs:
 ; FMA3:       # %bb.0:
-; FMA3-NEXT:    vmovapd {{.*#+}} ymm2 = [5.0E-1,5.0E-1,5.0E-1,5.0E-1]
-; FMA3-NEXT:    vfmadd213pd {{.*#+}} ymm2 = (ymm0 * ymm2) + mem
-; FMA3-NEXT:    vmovapd {{.*#+}} ymm0 =
-; FMA3-NEXT:    vfmadd213pd {{.*#+}} ymm0 = (ymm1 * ymm0) + mem
-; FMA3-NEXT:    vaddpd %ymm0, %ymm2, %ymm0
+; FMA3-NEXT:    vmovapd {{.*#+}} ymm2 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
+; FMA3-NEXT:    vfnmadd213pd {{.*#+}} ymm0 = -(ymm2 * ymm0) + mem
+; FMA3-NEXT:    vfmadd132pd {{.*#+}} ymm1 = (ymm1 * mem) + ymm2
+; FMA3-NEXT:    vaddpd %ymm1, %ymm0, %ymm0
 ; FMA3-NEXT:    retq
 ;
 ; FMA4-LABEL: negated_constant_v4f64_2fma_undefs:
 ; FMA4:       # %bb.0:
-; FMA4-NEXT:    vmovapd {{.*#+}} ymm2 = [5.0E-1,5.0E-1,5.0E-1,5.0E-1]
-; FMA4-NEXT:    vfmaddpd {{.*#+}} ymm0 = (ymm0 * ymm2) + mem
-; FMA4-NEXT:    vmovapd {{.*#+}} ymm2 =
-; FMA4-NEXT:    vfmaddpd {{.*#+}} ymm1 = (ymm1 * ymm2) + mem
+; FMA4-NEXT:    vmovapd {{.*#+}} ymm2 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
+; FMA4-NEXT:    vfnmaddpd {{.*#+}} ymm0 = -(ymm0 * ymm2) + mem
+; FMA4-NEXT:    vfmaddpd {{.*#+}} ymm1 = (ymm1 * mem) + ymm2
 ; FMA4-NEXT:    vaddpd %ymm1, %ymm0, %ymm0
 ; FMA4-NEXT:    retq
   %t0 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double 5.000000e-01, double 5.000000e-01, double 5.000000e-01, double 5.000000e-01>, <4 x double> )
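
Not part of the patch, for illustration only: a minimal IR sketch of the pattern the new getInvertedVectorForFMA() combine targets. The function name @inverted_constant_pair and the constant values below are invented; only the @llvm.fma.v4f64 intrinsic is taken from the tests above. Both constant multiplicands are used solely by FMAs and are element-wise negations of each other, so the combine keeps the vector whose first element is negative and rewrites the other FMA to an FNMADD, leaving a single constant to materialize.

declare <4 x double> @llvm.fma.v4f64(<4 x double>, <4 x double>, <4 x double>)

define <4 x double> @inverted_constant_pair(<4 x double> %a, <4 x double> %b) {
  ; This FMA's multiplicand <2, 3, 4, 5> is inverted so it can reuse the
  ; <-2, -3, -4, -5> constant of the second call; it becomes an FNMADD.
  %p = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double 2.000000e+00, double 3.000000e+00, double 4.000000e+00, double 5.000000e+00>, <4 x double> %b)
  ; This FMA keeps its <-2, -3, -4, -5> constant as-is (plain FMADD).
  %n = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double -2.000000e+00, double -3.000000e+00, double -4.000000e+00, double -5.000000e+00>, <4 x double> %b)
  %r = fadd <4 x double> %p, %n
  ret <4 x double> %r
}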