Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -11827,6 +11827,32 @@
                        N1.getOperand(0), N1.getOperand(1), N0, Flags);
   }
 
+  // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y, (fma u, v, z))
+  if (CanFuse &&
+      N0.getOpcode() == PreferredFusedOpcode &&
+      N0.getOperand(2).getOpcode() == ISD::FMUL &&
+      N0.hasOneUse() && N0.getOperand(2).hasOneUse()) {
+    return DAG.getNode(PreferredFusedOpcode, SL, VT,
+                       N0.getOperand(0), N0.getOperand(1),
+                       DAG.getNode(PreferredFusedOpcode, SL, VT,
+                                   N0.getOperand(2).getOperand(0),
+                                   N0.getOperand(2).getOperand(1),
+                                   N1, Flags), Flags);
+  }
+
+  // fold (fadd x, (fma y, z, (fmul u, v))) -> (fma y, z, (fma u, v, x))
+  if (CanFuse &&
+      N1.getOpcode() == PreferredFusedOpcode &&
+      N1.getOperand(2).getOpcode() == ISD::FMUL &&
+      N1.hasOneUse() && N1.getOperand(2).hasOneUse()) {
+    return DAG.getNode(PreferredFusedOpcode, SL, VT,
+                       N1.getOperand(0), N1.getOperand(1),
+                       DAG.getNode(PreferredFusedOpcode, SL, VT,
+                                   N1.getOperand(2).getOperand(0),
+                                   N1.getOperand(2).getOperand(1),
+                                   N0, Flags), Flags);
+  }
+
   // Look through FP_EXTEND nodes to do more combining.
 
   // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
@@ -11860,33 +11886,6 @@
 
   // More folding opportunities when target permits.
   if (Aggressive) {
-    // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y (fma u, v, z))
-    if (CanFuse &&
-        N0.getOpcode() == PreferredFusedOpcode &&
-        N0.getOperand(2).getOpcode() == ISD::FMUL &&
-        N0.hasOneUse() && N0.getOperand(2).hasOneUse()) {
-      return DAG.getNode(PreferredFusedOpcode, SL, VT,
-                         N0.getOperand(0), N0.getOperand(1),
-                         DAG.getNode(PreferredFusedOpcode, SL, VT,
-                                     N0.getOperand(2).getOperand(0),
-                                     N0.getOperand(2).getOperand(1),
-                                     N1, Flags), Flags);
-    }
-
-    // fold (fadd x, (fma y, z, (fmul u, v)) -> (fma y, z (fma u, v, x))
-    if (CanFuse &&
-        N1->getOpcode() == PreferredFusedOpcode &&
-        N1.getOperand(2).getOpcode() == ISD::FMUL &&
-        N1.hasOneUse() && N1.getOperand(2).hasOneUse()) {
-      return DAG.getNode(PreferredFusedOpcode, SL, VT,
-                         N1.getOperand(0), N1.getOperand(1),
-                         DAG.getNode(PreferredFusedOpcode, SL, VT,
-                                     N1.getOperand(2).getOperand(0),
-                                     N1.getOperand(2).getOperand(1),
-                                     N0, Flags), Flags);
-    }
-
     // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
     //   -> (fma x, y, (fma (fpext u), (fpext v), z))
     auto FoldFAddFMAFPExtFMul = [&] (
Index: llvm/test/CodeGen/AArch64/fadd-combines.ll
===================================================================
--- llvm/test/CodeGen/AArch64/fadd-combines.ll
+++ llvm/test/CodeGen/AArch64/fadd-combines.ll
@@ -197,9 +197,8 @@
 define double @fadd_fma_fmul_1(double %a, double %b, double %c, double %d, double %n1) nounwind {
 ; CHECK-LABEL: fadd_fma_fmul_1:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmul d2, d2, d3
+; CHECK-NEXT:    fmadd d2, d2, d3, d4
 ; CHECK-NEXT:    fmadd d0, d0, d1, d2
-; CHECK-NEXT:    fadd d0, d0, d4
 ; CHECK-NEXT:    ret
   %m1 = fmul fast double %a, %b
   %m2 = fmul fast double %c, %d
@@ -213,9 +212,8 @@
 define float @fadd_fma_fmul_2(float %a, float %b, float %c, float %d, float %n0) nounwind {
 ; CHECK-LABEL: fadd_fma_fmul_2:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmul s2, s2, s3
+; CHECK-NEXT:    fmadd s2, s2, s3, s4
 ; CHECK-NEXT:    fmadd s0, s0, s1, s2
-; CHECK-NEXT:    fadd s0, s4, s0
 ; CHECK-NEXT:    ret
   %m1 = fmul float %a, %b
   %m2 = fmul float %c, %d
@@ -230,10 +228,10 @@
 ; CHECK-LABEL: fadd_fma_fmul_3:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    fmul v2.2d, v2.2d, v3.2d
-; CHECK-NEXT:    fmul v3.2d, v6.2d, v7.2d
 ; CHECK-NEXT:    fmla v2.2d, v1.2d, v0.2d
-; CHECK-NEXT:    fmla v3.2d, v5.2d, v4.2d
-; CHECK-NEXT:    fadd v0.2d, v2.2d, v3.2d
+; CHECK-NEXT:    fmla v2.2d, v7.2d, v6.2d
+; CHECK-NEXT:    fmla v2.2d, v5.2d, v4.2d
+; CHECK-NEXT:    mov v0.16b, v2.16b
 ; CHECK-NEXT:    ret
   %m1 = fmul fast <2 x double> %x1, %x2
   %m2 = fmul fast <2 x double> %x3, %x4
@@ -245,6 +243,8 @@
   ret <2 x double> %a3
 }
 
+; negative test
+
 define float @fadd_fma_fmul_extra_use_1(float %a, float %b, float %c, float %d, float %n0, float* %p) nounwind {
 ; CHECK-LABEL: fadd_fma_fmul_extra_use_1:
 ; CHECK:       // %bb.0:
@@ -261,6 +261,8 @@
   ret float %a2
 }
 
+; negative test
+
 define float @fadd_fma_fmul_extra_use_2(float %a, float %b, float %c, float %d, float %n0, float* %p) nounwind {
 ; CHECK-LABEL: fadd_fma_fmul_extra_use_2:
 ; CHECK:       // %bb.0:
@@ -277,6 +279,8 @@
   ret float %a2
 }
 
+; negative test
+
 define float @fadd_fma_fmul_extra_use_3(float %a, float %b, float %c, float %d, float %n0, float* %p) nounwind {
 ; CHECK-LABEL: fadd_fma_fmul_extra_use_3:
 ; CHECK:       // %bb.0:
Index: llvm/test/CodeGen/X86/fma_patterns.ll
===================================================================
--- llvm/test/CodeGen/X86/fma_patterns.ll
+++ llvm/test/CodeGen/X86/fma_patterns.ll
@@ -1799,23 +1799,20 @@
 define double @fadd_fma_fmul_1(double %a, double %b, double %c, double %d, double %n1) nounwind {
 ; FMA-LABEL: fadd_fma_fmul_1:
 ; FMA:       # %bb.0:
-; FMA-NEXT:    vmulsd %xmm3, %xmm2, %xmm2
-; FMA-NEXT:    vfmadd231sd {{.*#+}} xmm2 = (xmm1 * xmm0) + xmm2
-; FMA-NEXT:    vaddsd %xmm4, %xmm2, %xmm0
+; FMA-NEXT:    vfmadd213sd {{.*#+}} xmm2 = (xmm3 * xmm2) + xmm4
+; FMA-NEXT:    vfmadd213sd {{.*#+}} xmm0 = (xmm1 * xmm0) + xmm2
 ; FMA-NEXT:    retq
 ;
 ; FMA4-LABEL: fadd_fma_fmul_1:
 ; FMA4:       # %bb.0:
-; FMA4-NEXT:    vmulsd %xmm3, %xmm2, %xmm2
+; FMA4-NEXT:    vfmaddsd {{.*#+}} xmm2 = (xmm2 * xmm3) + xmm4
 ; FMA4-NEXT:    vfmaddsd {{.*#+}} xmm0 = (xmm0 * xmm1) + xmm2
-; FMA4-NEXT:    vaddsd %xmm4, %xmm0, %xmm0
 ; FMA4-NEXT:    retq
 ;
 ; AVX512-LABEL: fadd_fma_fmul_1:
 ; AVX512:       # %bb.0:
-; AVX512-NEXT:    vmulsd %xmm3, %xmm2, %xmm2
-; AVX512-NEXT:    vfmadd231sd {{.*#+}} xmm2 = (xmm1 * xmm0) + xmm2
-; AVX512-NEXT:    vaddsd %xmm4, %xmm2, %xmm0
+; AVX512-NEXT:    vfmadd213sd {{.*#+}} xmm2 = (xmm3 * xmm2) + xmm4
+; AVX512-NEXT:    vfmadd213sd {{.*#+}} xmm0 = (xmm1 * xmm0) + xmm2
 ; AVX512-NEXT:    retq
   %m1 = fmul fast double %a, %b
   %m2 = fmul fast double %c, %d
@@ -1829,23 +1826,20 @@
 define float @fadd_fma_fmul_2(float %a, float %b, float %c, float %d, float %n0) nounwind {
 ; FMA-LABEL: fadd_fma_fmul_2:
 ; FMA:       # %bb.0:
-; FMA-NEXT:    vmulss %xmm3, %xmm2, %xmm2
-; FMA-NEXT:    vfmadd231ss {{.*#+}} xmm2 = (xmm1 * xmm0) + xmm2
-; FMA-NEXT:    vaddss %xmm2, %xmm4, %xmm0
+; FMA-NEXT:    vfmadd213ss {{.*#+}} xmm2 = (xmm3 * xmm2) + xmm4
+; FMA-NEXT:    vfmadd213ss {{.*#+}} xmm0 = (xmm1 * xmm0) + xmm2
 ; FMA-NEXT:    retq
 ;
 ; FMA4-LABEL: fadd_fma_fmul_2:
 ; FMA4:       # %bb.0:
-; FMA4-NEXT:    vmulss %xmm3, %xmm2, %xmm2
+; FMA4-NEXT:    vfmaddss {{.*#+}} xmm2 = (xmm2 * xmm3) + xmm4
 ; FMA4-NEXT:    vfmaddss {{.*#+}} xmm0 = (xmm0 * xmm1) + xmm2
-; FMA4-NEXT:    vaddss %xmm0, %xmm4, %xmm0
 ; FMA4-NEXT:    retq
 ;
 ; AVX512-LABEL: fadd_fma_fmul_2:
 ; AVX512:       # %bb.0:
-; AVX512-NEXT:    vmulss %xmm3, %xmm2, %xmm2
-; AVX512-NEXT:    vfmadd231ss {{.*#+}} xmm2 = (xmm1 * xmm0) + xmm2
-; AVX512-NEXT:    vaddss %xmm2, %xmm4, %xmm0
+; AVX512-NEXT:    vfmadd213ss {{.*#+}} xmm2 = (xmm3 * xmm2) + xmm4
+; AVX512-NEXT:    vfmadd213ss {{.*#+}} xmm0 = (xmm1 * xmm0) + xmm2
 ; AVX512-NEXT:    retq
   %m1 = fmul float %a, %b
   %m2 = fmul float %c, %d
@@ -1860,28 +1854,27 @@
 ; FMA-LABEL: fadd_fma_fmul_3:
 ; FMA:       # %bb.0:
 ; FMA-NEXT:    vmulpd %xmm3, %xmm2, %xmm2
-; FMA-NEXT:    vmulpd %xmm7, %xmm6, %xmm3
 ; FMA-NEXT:    vfmadd231pd {{.*#+}} xmm2 = (xmm1 * xmm0) + xmm2
-; FMA-NEXT:    vfmadd231pd {{.*#+}} xmm3 = (xmm5 * xmm4) + xmm3
-; FMA-NEXT:    vaddpd %xmm3, %xmm2, %xmm0
+; FMA-NEXT:    vfmadd231pd {{.*#+}} xmm2 = (xmm7 * xmm6) + xmm2
+; FMA-NEXT:    vfmadd231pd {{.*#+}} xmm2 = (xmm5 * xmm4) + xmm2
+; FMA-NEXT:    vmovapd %xmm2, %xmm0
 ; FMA-NEXT:    retq
 ;
 ; FMA4-LABEL: fadd_fma_fmul_3:
 ; FMA4:       # %bb.0:
 ; FMA4-NEXT:    vmulpd %xmm3, %xmm2, %xmm2
-; FMA4-NEXT:    vmulpd %xmm7, %xmm6, %xmm3
 ; FMA4-NEXT:    vfmaddpd {{.*#+}} xmm0 = (xmm0 * xmm1) + xmm2
-; FMA4-NEXT:    vfmaddpd {{.*#+}} xmm1 = (xmm4 * xmm5) + xmm3
-; FMA4-NEXT:    vaddpd %xmm1, %xmm0, %xmm0
+; FMA4-NEXT:    vfmaddpd {{.*#+}} xmm0 = (xmm6 * xmm7) + xmm0
+; FMA4-NEXT:    vfmaddpd {{.*#+}} xmm0 = (xmm4 * xmm5) + xmm0
 ; FMA4-NEXT:    retq
 ;
 ; AVX512-LABEL: fadd_fma_fmul_3:
 ; AVX512:       # %bb.0:
 ; AVX512-NEXT:    vmulpd %xmm3, %xmm2, %xmm2
-; AVX512-NEXT:    vmulpd %xmm7, %xmm6, %xmm3
 ; AVX512-NEXT:    vfmadd231pd {{.*#+}} xmm2 = (xmm1 * xmm0) + xmm2
-; AVX512-NEXT:    vfmadd231pd {{.*#+}} xmm3 = (xmm5 * xmm4) + xmm3
-; AVX512-NEXT:    vaddpd %xmm3, %xmm2, %xmm0
+; AVX512-NEXT:    vfmadd231pd {{.*#+}} xmm2 = (xmm7 * xmm6) + xmm2
+; AVX512-NEXT:    vfmadd231pd {{.*#+}} xmm2 = (xmm5 * xmm4) + xmm2
+; AVX512-NEXT:    vmovapd %xmm2, %xmm0
 ; AVX512-NEXT:    retq
   %m1 = fmul fast <2 x double> %x1, %x2
   %m2 = fmul fast <2 x double> %x3, %x4
@@ -1893,6 +1886,8 @@
   ret <2 x double> %a3
 }
 
+; negative test
+
 define float @fadd_fma_fmul_extra_use_1(float %a, float %b, float %c, float %d, float %n0, float* %p) nounwind {
 ; FMA-LABEL: fadd_fma_fmul_extra_use_1:
 ; FMA:       # %bb.0:
@@ -1925,6 +1920,8 @@
   ret float %a2
 }
 
+; negative test
+
 define float @fadd_fma_fmul_extra_use_2(float %a, float %b, float %c, float %d, float %n0, float* %p) nounwind {
 ; FMA-LABEL: fadd_fma_fmul_extra_use_2:
 ; FMA:       # %bb.0:
@@ -1957,6 +1954,8 @@
   ret float %a2
 }
 
+; negative test
+
 define float @fadd_fma_fmul_extra_use_3(float %a, float %b, float %c, float %d, float %n0, float* %p) nounwind {
 ; FMA-LABEL: fadd_fma_fmul_extra_use_3:
 ; FMA:       # %bb.0:
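
A minimal sketch of the IR pattern these folds target, distilled from the tests above; the function name and the llc invocation are illustrative and not part of the patch:

; (fadd (fma a, b, (fmul c, d)), n1) -> (fma a, b, (fma c, d, n1))
define double @fadd_fma_fmul(double %a, double %b, double %c, double %d, double %n1) {
  %m1 = fmul fast double %a, %b     ; contracted into the outer fma
  %m2 = fmul fast double %c, %d     ; becomes the inner fma
  %a1 = fadd fast double %m1, %m2
  %a2 = fadd fast double %a1, %n1   ; %n1 becomes the inner fma's addend
  ret double %a2
}

Run through llc with an FMA-capable target (e.g. llc -mtriple=aarch64 -o -), this should now produce the two chained fmadd instructions checked in fadd_fma_fmul_1 above, without the target opting in via enableAggressiveFMAFusion(). The hasOneUse() guards are what the three extra-use negative tests exercise: if the fma or its inner fmul has another user, the intermediate value must be computed anyway, so fusing would duplicate work and the combine deliberately bails out.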