diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -14168,6 +14168,59 @@ return DAG.getNode(PreferredFusedOpcode, SL, VT, A, B, CDE); } + // Transform + // (fadd (fma (a, b, ... (fma (c, d, (fmul (e, f)))))), g) + // into + // (fadd (fma (a, b, ... (fma (c, d, g)))), (fmul (e, f))) + // to allow for further FMA combines. + if (CanReassociate) { + auto IsCopyOrExtract = [](SDValue Val) { + return Val.getOpcode() == ISD::CopyFromReg || + Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT; + }; + auto TryOperandPair = [&](SDValue N0, SDValue N1) { + return N0.getOpcode() == ISD::FMA && N1.hasOneUse() && + IsCopyOrExtract(N1); + }; + + SDValue Fma; + SDValue G; + bool TryCombine = false; + + if (TryOperandPair(N0, N1)) { + Fma = N0; + G = N1; + TryCombine = true; + } else if (TryOperandPair(N1, N0)) { + Fma = N1; + G = N0; + TryCombine = true; + } + + if (TryCombine) { + SDValue Tmp = Fma->getOperand(2); + while (Tmp.getOpcode() == ISD::FMA) { + if (Tmp.getOperand(2).getOpcode() == ISD::FMUL) { + SDValue FMul = Tmp.getOperand(2); + SDValue FMul0 = FMul.getOperand(0); + SDValue FMul1 = FMul.getOperand(1); + if (IsCopyOrExtract(FMul0) && IsCopyOrExtract(FMul1)) { + SDValue UpdatedFAdd = DAG.getNode(ISD::FADD, SL, VT, Fma, FMul); + DAG.UpdateNodeOperands(Tmp.getNode(), Tmp.getOperand(0), + Tmp.getOperand(1), G); + + return UpdatedFAdd; + } + + // Give up + break; + } + + Tmp = Tmp.getOperand(2); + } + } + } + // Look through FP_EXTEND nodes to do more combining. 
// fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z) diff --git a/llvm/test/CodeGen/AMDGPU/dagcombine-fma-fmad.ll b/llvm/test/CodeGen/AMDGPU/dagcombine-fma-fmad.ll --- a/llvm/test/CodeGen/AMDGPU/dagcombine-fma-fmad.ll +++ b/llvm/test/CodeGen/AMDGPU/dagcombine-fma-fmad.ll @@ -175,6 +175,40 @@ ret float %.i2551 } +define float @fmac_sequence_simple(float %a, float %b, float %c, float %d, float %e) #0 { +; GCN-LABEL: fmac_sequence_simple: +; GCN: ; %bb.0: +; GCN-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +; GCN-NEXT: s_waitcnt_vscnt null, 0x0 +; GCN-NEXT: v_fma_f32 v2, v2, v3, v4 +; GCN-NEXT: v_fmac_f32_e32 v2, v0, v1 +; GCN-NEXT: v_mov_b32_e32 v0, v2 +; GCN-NEXT: s_setpc_b64 s[30:31] + %t0 = fmul fast float %a, %b + %t1 = fmul fast float %c, %d + %t2 = fadd fast float %t0, %t1 + %t5 = fadd fast float %t2, %e + ret float %t5 +} + +define float @fmac_sequence_innermost_fmul(float %a, float %b, float %c, float %d, float %e, float %f, float %g) #0 { +; GCN-LABEL: fmac_sequence_innermost_fmul: +; GCN: ; %bb.0: +; GCN-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) +; GCN-NEXT: s_waitcnt_vscnt null, 0x0 +; GCN-NEXT: v_fma_f32 v0, v0, v1, v6 +; GCN-NEXT: v_fmac_f32_e32 v0, v4, v5 +; GCN-NEXT: v_mac_f32_e32 v0, v2, v3 +; GCN-NEXT: s_setpc_b64 s[30:31] + %t0 = fmul fast float %a, %b + %t1 = fmul fast float %c, %d + %t2 = fadd fast float %t0, %t1 + %t3 = fmul fast float %e, %f + %t4 = fadd fast float %t2, %t3 + %t5 = fadd fast float %t4, %g + ret float %t5 +} + ; Function Attrs: nofree nosync nounwind readnone speculatable willreturn declare float @llvm.maxnum.f32(float, float) #1