diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -14171,6 +14171,64 @@
     return DAG.getNode(PreferredFusedOpcode, SL, VT, A, B, CDE);
   }
 
+  // Transform
+  //   (fadd (fma (a, b, ... (fma (c, d, (fmul (e, f)))))), g)
+  // into
+  //   (fadd (fma (a, b, ... (fma (c, d, g)))), (fmul (e, f)))
+  // to allow for further FMA combines.
+  //
+  // The rewrite mutates the innermost FMA in place, which changes the value
+  // computed by every FMA nested above it, so each of them must have a
+  // single use.
+  if (CanReassociate) {
+    auto IsCopyOrExtract = [](SDValue Val) {
+      return Val.getOpcode() == ISD::CopyFromReg ||
+             Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT;
+    };
+    // Root is the outermost FMA; Addend is the value sunk into the nest.
+    auto IsCandidatePair = [&](SDValue Root, SDValue Addend) {
+      return Root.getOpcode() == ISD::FMA && Root.hasOneUse() &&
+             Addend.hasOneUse() && IsCopyOrExtract(Addend);
+    };
+
+    SDValue Fma;
+    SDValue G;
+    if (IsCandidatePair(N0, N1)) {
+      Fma = N0;
+      G = N1;
+    } else if (IsCandidatePair(N1, N0)) {
+      Fma = N1;
+      G = N0;
+    }
+
+    if (Fma) {
+      // Walk down the addend operand (operand 2) of each nested FMA.
+      SDValue Tmp = Fma.getOperand(2);
+      while (Tmp.getOpcode() == ISD::FMA && Tmp.hasOneUse()) {
+        SDValue FMul = Tmp.getOperand(2);
+        if (FMul.getOpcode() != ISD::FMUL) {
+          Tmp = FMul;
+          continue;
+        }
+
+        // Give up unless both FMUL operands are copies or extracts.
+        if (!IsCopyOrExtract(FMul.getOperand(0)) ||
+            !IsCopyOrExtract(FMul.getOperand(1)))
+          break;
+
+        // UpdateNodeOperands leaves Tmp untouched and returns a different
+        // node when an identical (fma c, d, g) already exists. Fma would then
+        // still contain the FMUL, so give up instead of adding it twice.
+        SDNode *Updated = DAG.UpdateNodeOperands(
+            Tmp.getNode(), Tmp.getOperand(0), Tmp.getOperand(1), G);
+        if (Updated != Tmp.getNode())
+          break;
+
+        return DAG.getNode(ISD::FADD, SL, VT, Fma, FMul);
+      }
+    }
+  }
+
   // Look through FP_EXTEND nodes to do more combining.
// fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z) diff --git a/llvm/test/CodeGen/AMDGPU/dagcombine-fma-fmad.ll b/llvm/test/CodeGen/AMDGPU/dagcombine-fma-fmad.ll --- a/llvm/test/CodeGen/AMDGPU/dagcombine-fma-fmad.ll +++ b/llvm/test/CodeGen/AMDGPU/dagcombine-fma-fmad.ll @@ -196,10 +196,9 @@ ; GCN: ; %bb.0: ; GCN-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) ; GCN-NEXT: s_waitcnt_vscnt null, 0x0 -; GCN-NEXT: v_mul_f32_e32 v2, v2, v3 -; GCN-NEXT: v_fmac_f32_e32 v2, v0, v1 -; GCN-NEXT: v_fmac_f32_e32 v2, v4, v5 -; GCN-NEXT: v_add_f32_e32 v0, v2, v6 +; GCN-NEXT: v_fma_f32 v0, v0, v1, v6 +; GCN-NEXT: v_fmac_f32_e32 v0, v4, v5 +; GCN-NEXT: v_mac_f32_e32 v0, v2, v3 ; GCN-NEXT: s_setpc_b64 s[30:31] %t0 = fmul fast float %a, %b %t1 = fmul fast float %c, %d