diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst --- a/llvm/docs/LangRef.rst +++ b/llvm/docs/LangRef.rst @@ -2778,7 +2778,9 @@ ``contract`` Allow floating-point contraction (e.g. fusing a multiply followed by an - addition into a fused multiply-and-add). + addition into a fused multiply-and-add). This does not enable reassociating + to form arbitrary contractions. For example, ``(a*b) + (c*d) + e`` cannot + be transformed into ``(a*b) + ((c*d) + e)`` to create two fma operations. ``afn`` Approximate functions - Allow substitution of approximate calculations for diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -11986,6 +11986,8 @@ SDNodeFlags Flags = N->getFlags(); bool CanFuse = Options.UnsafeFPMath || isContractable(N); + bool CanReassociate = + Options.UnsafeFPMath || N->getFlags().hasAllowReassociation(); bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast || CanFuse || HasFMAD); // If the addition is not contractable, do not combine. @@ -12028,13 +12030,14 @@ // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E) // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E) + // This requires reassociation because it changes the order of operations. 
SDValue FMA, E; - if (CanFuse && N0.getOpcode() == PreferredFusedOpcode && + if (CanReassociate && N0.getOpcode() == PreferredFusedOpcode && N0.getOperand(2).getOpcode() == ISD::FMUL && N0.hasOneUse() && N0.getOperand(2).hasOneUse()) { FMA = N0; E = N1; - } else if (CanFuse && N1.getOpcode() == PreferredFusedOpcode && + } else if (CanReassociate && N1.getOpcode() == PreferredFusedOpcode && N1.getOperand(2).getOpcode() == ISD::FMUL && N1.hasOneUse() && N1.getOperand(2).hasOneUse()) { FMA = N1; diff --git a/llvm/test/CodeGen/AArch64/fadd-combines.ll b/llvm/test/CodeGen/AArch64/fadd-combines.ll --- a/llvm/test/CodeGen/AArch64/fadd-combines.ll +++ b/llvm/test/CodeGen/AArch64/fadd-combines.ll @@ -207,6 +207,10 @@ ret double %a2 } +; Minimum FMF - the 1st fadd is contracted because that combines +; fmul+fadd as specified by the order of operations; the 2nd fadd +; requires reassociation to fuse with c*d. + define float @fadd_fma_fmul_fmf(float %a, float %b, float %c, float %d, float %n0) nounwind { ; CHECK-LABEL: fadd_fma_fmul_fmf: ; CHECK: // %bb.0: @@ -220,13 +224,14 @@ ret float %a2 } -; Minimum FMF, commute final add operands, change type. +; Not minimum FMF. define float @fadd_fma_fmul_2(float %a, float %b, float %c, float %d, float %n0) nounwind { ; CHECK-LABEL: fadd_fma_fmul_2: ; CHECK: // %bb.0: -; CHECK-NEXT: fmadd s2, s2, s3, s4 +; CHECK-NEXT: fmul s2, s2, s3 ; CHECK-NEXT: fmadd s0, s0, s1, s2 +; CHECK-NEXT: fadd s0, s4, s0 ; CHECK-NEXT: ret %m1 = fmul float %a, %b %m2 = fmul float %c, %d diff --git a/llvm/test/CodeGen/X86/fma_patterns.ll b/llvm/test/CodeGen/X86/fma_patterns.ll --- a/llvm/test/CodeGen/X86/fma_patterns.ll +++ b/llvm/test/CodeGen/X86/fma_patterns.ll @@ -1821,6 +1821,10 @@ ret double %a2 } +; Minimum FMF - the 1st fadd is contracted because that combines +; fmul+fadd as specified by the order of operations; the 2nd fadd +; requires reassociation to fuse with c*d. 
+ define float @fadd_fma_fmul_fmf(float %a, float %b, float %c, float %d, float %n0) nounwind { ; FMA-LABEL: fadd_fma_fmul_fmf: ; FMA: # %bb.0: @@ -1846,25 +1850,28 @@ ret float %a2 } -; Minimum FMF, commute final add operands, change type. +; Not minimum FMF. define float @fadd_fma_fmul_2(float %a, float %b, float %c, float %d, float %n0) nounwind { ; FMA-LABEL: fadd_fma_fmul_2: ; FMA: # %bb.0: -; FMA-NEXT: vfmadd213ss {{.*#+}} xmm2 = (xmm3 * xmm2) + xmm4 -; FMA-NEXT: vfmadd213ss {{.*#+}} xmm0 = (xmm1 * xmm0) + xmm2 +; FMA-NEXT: vmulss %xmm3, %xmm2, %xmm2 +; FMA-NEXT: vfmadd231ss {{.*#+}} xmm2 = (xmm1 * xmm0) + xmm2 +; FMA-NEXT: vaddss %xmm2, %xmm4, %xmm0 ; FMA-NEXT: retq ; ; FMA4-LABEL: fadd_fma_fmul_2: ; FMA4: # %bb.0: -; FMA4-NEXT: vfmaddss {{.*#+}} xmm2 = (xmm2 * xmm3) + xmm4 +; FMA4-NEXT: vmulss %xmm3, %xmm2, %xmm2 ; FMA4-NEXT: vfmaddss {{.*#+}} xmm0 = (xmm0 * xmm1) + xmm2 +; FMA4-NEXT: vaddss %xmm0, %xmm4, %xmm0 ; FMA4-NEXT: retq ; ; AVX512-LABEL: fadd_fma_fmul_2: ; AVX512: # %bb.0: -; AVX512-NEXT: vfmadd213ss {{.*#+}} xmm2 = (xmm3 * xmm2) + xmm4 -; AVX512-NEXT: vfmadd213ss {{.*#+}} xmm0 = (xmm1 * xmm0) + xmm2 +; AVX512-NEXT: vmulss %xmm3, %xmm2, %xmm2 +; AVX512-NEXT: vfmadd231ss {{.*#+}} xmm2 = (xmm1 * xmm0) + xmm2 +; AVX512-NEXT: vaddss %xmm2, %xmm4, %xmm0 ; AVX512-NEXT: retq %m1 = fmul float %a, %b %m2 = fmul float %c, %d