diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -14202,6 +14202,43 @@
 
   // More folding opportunities when target permits.
   if (Aggressive) {
+    // fold (fadd x, (fma a, b, (fma c, d, (fmul y, z))))
+    //   -> (fma a, b, (fma c, d, (fma y, z, x)))
+    // fold (fadd (fma a, b, (fma c, d, (fmul y, z))), x)
+    //   -> (fma a, b, (fma c, d, (fma y, z, x)))
+    // This generalizes to a chain of fused ops of any depth ending in an
+    // fmul. Sinking x from the outer fadd into the innermost fused op
+    // changes the order of the additions, so it requires reassociation.
+    // Every node in the chain must have a single use; otherwise the old
+    // chain stays live and the rewrite duplicates work instead of saving it.
+    if (Options.UnsafeFPMath || N->getFlags().hasAllowReassociation()) {
+      auto FoldFAddFMAChainFMul = [&](SDValue Chain, SDValue X) -> SDValue {
+        // Fused ops along the addend chain, outermost first.
+        SmallVector<SDValue, 4> Links;
+        SDValue Cur = Chain;
+        while (isFusedOp(Cur) && Cur.hasOneUse()) {
+          Links.push_back(Cur);
+          Cur = Cur.getOperand(2);
+        }
+        if (Links.empty() || !isContractableFMUL(Cur) || !Cur.hasOneUse())
+          return SDValue();
+        // The fmul becomes the innermost fused op and absorbs x. Rebuild the
+        // outer links with their own opcodes so an explicit ISD::FMA is never
+        // turned into an ISD::FMAD (or vice versa).
+        SDValue Acc = DAG.getNode(PreferredFusedOpcode, SL, VT,
+                                  Cur.getOperand(0), Cur.getOperand(1), X);
+        for (SDValue Link : llvm::reverse(Links))
+          Acc = DAG.getNode(Link.getOpcode(), SL, VT, Link.getOperand(0),
+                            Link.getOperand(1), Acc);
+        return Acc;
+      };
+      // fadd is commutative, so try the chain on either side.
+      if (SDValue Res = FoldFAddFMAChainFMul(N1, N0))
+        return Res;
+      if (SDValue Res = FoldFAddFMAChainFMul(N0, N1))
+        return Res;
+    }
+
     // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
     //   -> (fma x, y, (fma (fpext u), (fpext v), z))
     auto FoldFAddFMAFPExtFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
diff --git a/llvm/test/CodeGen/NVPTX/fma-assoc.ll b/llvm/test/CodeGen/NVPTX/fma-assoc.ll
--- a/llvm/test/CodeGen/NVPTX/fma-assoc.ll
+++ b/llvm/test/CodeGen/NVPTX/fma-assoc.ll
@@ -39,3 +39,41 @@
   ret double %3
 }
 
+define float @FoldFAddFMAFMAFMul(float %a, float %b, float %c, float %d,
+                                 float %x, float %y, float %z) {
+; CHECK-LABEL: FoldFAddFMAFMAFMul(
+; fold (fadd reassoc x, (fma a, b, (fma c, d, (fmul y, z))))
+;   -> (fma a, b, (fma c, d, (fma y, z, x)))
+; CHECK-NOT: mul.f32
+; CHECK: fma.rn.f32
+; CHECK: fma.rn.f32
+; CHECK: fma.rn.f32
+; CHECK-NOT: add.f32
+  %mul = fmul float %y, %z
+  %fma1 = call float @llvm.fma.f32(float %c, float %d, float %mul)
+  %fma2 = call float @llvm.fma.f32(float %a, float %b, float %fma1)
+  %res = fadd reassoc float %x, %fma2
+  ret float %res
+}
+
+define float @FoldFAddFMAFMAFMulCommuted(float %a, float %b, float %c,
+                                         float %d, float %x, float %y,
+                                         float %z) {
+; CHECK-LABEL: FoldFAddFMAFMAFMulCommuted(
+; fold (fadd reassoc (fma a, b, (fma c, d, (fmul y, z))), x)
+;   -> (fma a, b, (fma c, d, (fma y, z, x)))
+; CHECK-NOT: mul.f32
+; CHECK: fma.rn.f32
+; CHECK: fma.rn.f32
+; CHECK: fma.rn.f32
+; CHECK-NOT: add.f32
+  %mul = fmul float %y, %z
+  %fma1 = call float @llvm.fma.f32(float %c, float %d, float %mul)
+  %fma2 = call float @llvm.fma.f32(float %a, float %b, float %fma1)
+  %res = fadd reassoc float %fma2, %x
+  ret float %res
+}
+
+declare float @llvm.fma.f32(float, float, float) #0
+
+attributes #0 = { nounwind readnone }