diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -14155,26 +14155,37 @@ // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E) // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E) + // This also works with nested fma instructions: + // fadd (fma A, B, (fma C, D, (fmul E, F))), G --> + // fma A, B, (fma C, D, (fma E, F, G)) + // fadd G, (fma A, B, (fma C, D, (fmul E, F))) --> + // fma A, B, (fma C, D, (fma E, F, G)). // This requires reassociation because it changes the order of operations. - SDValue FMA, E; - if (CanReassociate && isFusedOp(N0) && - N0.getOperand(2).getOpcode() == ISD::FMUL && N0.hasOneUse() && - N0.getOperand(2).hasOneUse()) { - FMA = N0; - E = N1; - } else if (CanReassociate && isFusedOp(N1) && - N1.getOperand(2).getOpcode() == ISD::FMUL && N1.hasOneUse() && - N1.getOperand(2).hasOneUse()) { - FMA = N1; - E = N0; - } - if (FMA && E) { - SDValue A = FMA.getOperand(0); - SDValue B = FMA.getOperand(1); - SDValue C = FMA.getOperand(2).getOperand(0); - SDValue D = FMA.getOperand(2).getOperand(1); - SDValue CDE = DAG.getNode(PreferredFusedOpcode, SL, VT, C, D, E); - return DAG.getNode(PreferredFusedOpcode, SL, VT, A, B, CDE); + if (CanReassociate) { + SDValue FMA, E; + if (isFusedOp(N0) && N0.hasOneUse()) { + FMA = N0; + E = N1; + } else if (isFusedOp(N1) && N1.hasOneUse()) { + FMA = N1; + E = N0; + } + + SDValue TmpFMA = FMA; + while (E && TmpFMA && isFusedOp(TmpFMA)) { + SDValue FMul = TmpFMA->getOperand(2); + if (FMul.getOpcode() == ISD::FMUL && FMul.hasOneUse()) { + SDValue C = FMul.getOperand(0); + SDValue D = FMul.getOperand(1); + + DAG.MorphNodeTo(FMul.getNode(), PreferredFusedOpcode, FMul->getVTList(), + {C, D, E}); + + return FMA; + } + + TmpFMA = TmpFMA->getOperand(2); + } } // Look through FP_EXTEND nodes to do more combining. 
diff --git a/llvm/test/CodeGen/AMDGPU/dagcombine-fma-fmad.ll b/llvm/test/CodeGen/AMDGPU/dagcombine-fma-fmad.ll --- a/llvm/test/CodeGen/AMDGPU/dagcombine-fma-fmad.ll +++ b/llvm/test/CodeGen/AMDGPU/dagcombine-fma-fmad.ll @@ -49,7 +49,7 @@ ; GCN-NEXT: v_mad_f32 v10, s2, v6, v2 ; GCN-NEXT: s_mov_b32 s0, 0x3c23d70a ; GCN-NEXT: v_fmac_f32_e32 v1, v6, v8 -; GCN-NEXT: v_mac_f32_e32 v10, v7, v6 +; GCN-NEXT: v_fmac_f32_e32 v10, v7, v6 ; GCN-NEXT: s_waitcnt lgkmcnt(0) ; GCN-NEXT: v_mul_f32_e32 v9, s10, v0 ; GCN-NEXT: v_fma_f32 v0, -v0, s10, s14 @@ -196,10 +196,10 @@ ; GCN: ; %bb.0: ; GCN-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ; GCN-NEXT: s_waitcnt_vscnt null, 0x0 -; GCN-NEXT: v_mul_f32_e32 v2, v2, v3 +; GCN-NEXT: v_mad_f32 v2, v2, v3, v6 ; GCN-NEXT: v_fmac_f32_e32 v2, v0, v1 ; GCN-NEXT: v_fmac_f32_e32 v2, v4, v5 -; GCN-NEXT: v_add_f32_e32 v0, v2, v6 +; GCN-NEXT: v_mov_b32_e32 v0, v2 ; GCN-NEXT: s_setpc_b64 s[30:31] %t0 = fmul fast float %a, %b %t1 = fmul fast float %c, %d