diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -13051,6 +13051,11 @@
   unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);

+  auto isFusedOp = [&](SDValue N) {
+    unsigned Opcode = N.getOpcode();
+    return Opcode == ISD::FMA || Opcode == ISD::FMAD;
+  };
+
   // Is the node an FMUL and contractable either due to global flags or
   // SDNodeFlags.
   auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
@@ -13082,12 +13087,12 @@
   // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
   // This requires reassociation because it changes the order of operations.
   SDValue FMA, E;
-  if (CanReassociate && N0.getOpcode() == PreferredFusedOpcode &&
+  if (CanReassociate && isFusedOp(N0) &&
       N0.getOperand(2).getOpcode() == ISD::FMUL && N0.hasOneUse() &&
       N0.getOperand(2).hasOneUse()) {
     FMA = N0;
     E = N1;
-  } else if (CanReassociate && N1.getOpcode() == PreferredFusedOpcode &&
+  } else if (CanReassociate && isFusedOp(N1) &&
              N1.getOperand(2).getOpcode() == ISD::FMUL && N1.hasOneUse() &&
              N1.getOperand(2).hasOneUse()) {
     FMA = N1;
@@ -13098,8 +13103,8 @@
     SDValue B = FMA.getOperand(1);
     SDValue C = FMA.getOperand(2).getOperand(0);
     SDValue D = FMA.getOperand(2).getOperand(1);
-    SDValue CDE = DAG.getNode(PreferredFusedOpcode, SL, VT, C, D, E);
-    return DAG.getNode(PreferredFusedOpcode, SL, VT, A, B, CDE);
+    SDValue CDE = DAG.getNode(FMA.getOpcode(), SL, VT, C, D, E);
+    return DAG.getNode(FMA.getOpcode(), SL, VT, A, B, CDE);
   }

   // Look through FP_EXTEND nodes to do more combining.
@@ -13135,24 +13140,24 @@
   if (Aggressive) {
     // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
     //   -> (fma x, y, (fma (fpext u), (fpext v), z))
-    auto FoldFAddFMAFPExtFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
-                                    SDValue Z) {
-      return DAG.getNode(PreferredFusedOpcode, SL, VT, X, Y,
-                         DAG.getNode(PreferredFusedOpcode, SL, VT,
+    auto FoldFAddFMAFPExtFMul = [&](unsigned FusedOpcode, SDValue X, SDValue Y,
+                                    SDValue U, SDValue V, SDValue Z) {
+      return DAG.getNode(FusedOpcode, SL, VT, X, Y,
+                         DAG.getNode(FusedOpcode, SL, VT,
                                      DAG.getNode(ISD::FP_EXTEND, SL, VT, U),
                                      DAG.getNode(ISD::FP_EXTEND, SL, VT, V),
                                      Z));
     };
-    if (N0.getOpcode() == PreferredFusedOpcode) {
+    if (isFusedOp(N0)) {
+      unsigned FusedOpcode = N0.getOpcode();
       SDValue N02 = N0.getOperand(2);
       if (N02.getOpcode() == ISD::FP_EXTEND) {
         SDValue N020 = N02.getOperand(0);
         if (isContractableFMUL(N020) &&
-            TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
-                                N020.getValueType())) {
-          return FoldFAddFMAFPExtFMul(N0.getOperand(0), N0.getOperand(1),
-                                      N020.getOperand(0), N020.getOperand(1),
-                                      N1);
+            TLI.isFPExtFoldable(DAG, FusedOpcode, VT, N020.getValueType())) {
+          return FoldFAddFMAFPExtFMul(FusedOpcode, N0.getOperand(0),
+                                      N0.getOperand(1), N020.getOperand(0),
+                                      N020.getOperand(1), N1);
         }
       }
     }
@@ -13162,41 +13167,41 @@
     // FIXME: This turns two single-precision and one double-precision
    // operation into two double-precision operations, which might not be
    // interesting for all targets, especially GPUs.
-    auto FoldFAddFPExtFMAFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
-                                    SDValue Z) {
+    auto FoldFAddFPExtFMAFMul = [&](unsigned FusedOpcode, SDValue X, SDValue Y,
+                                    SDValue U, SDValue V, SDValue Z) {
       return DAG.getNode(
-          PreferredFusedOpcode, SL, VT, DAG.getNode(ISD::FP_EXTEND, SL, VT, X),
+          FusedOpcode, SL, VT, DAG.getNode(ISD::FP_EXTEND, SL, VT, X),
           DAG.getNode(ISD::FP_EXTEND, SL, VT, Y),
-          DAG.getNode(PreferredFusedOpcode, SL, VT,
+          DAG.getNode(FusedOpcode, SL, VT,
                       DAG.getNode(ISD::FP_EXTEND, SL, VT, U),
                       DAG.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
     };
     if (N0.getOpcode() == ISD::FP_EXTEND) {
       SDValue N00 = N0.getOperand(0);
-      if (N00.getOpcode() == PreferredFusedOpcode) {
+      if (isFusedOp(N00)) {
+        unsigned FusedOpcode = N00.getOpcode();
         SDValue N002 = N00.getOperand(2);
         if (isContractableFMUL(N002) &&
-            TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
-                                N00.getValueType())) {
-          return FoldFAddFPExtFMAFMul(N00.getOperand(0), N00.getOperand(1),
-                                      N002.getOperand(0), N002.getOperand(1),
-                                      N1);
+            TLI.isFPExtFoldable(DAG, FusedOpcode, VT, N00.getValueType())) {
+          return FoldFAddFPExtFMAFMul(FusedOpcode, N00.getOperand(0),
+                                      N00.getOperand(1), N002.getOperand(0),
+                                      N002.getOperand(1), N1);
         }
       }
     }

     // fold (fadd x, (fma y, z, (fpext (fmul u, v)))
     //   -> (fma y, z, (fma (fpext u), (fpext v), x))
-    if (N1.getOpcode() == PreferredFusedOpcode) {
+    if (isFusedOp(N1)) {
+      unsigned FusedOpcode = N1.getOpcode();
       SDValue N12 = N1.getOperand(2);
       if (N12.getOpcode() == ISD::FP_EXTEND) {
         SDValue N120 = N12.getOperand(0);
         if (isContractableFMUL(N120) &&
-            TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
-                                N120.getValueType())) {
-          return FoldFAddFMAFPExtFMul(N1.getOperand(0), N1.getOperand(1),
-                                      N120.getOperand(0), N120.getOperand(1),
-                                      N0);
+            TLI.isFPExtFoldable(DAG, FusedOpcode, VT, N120.getValueType())) {
+          return FoldFAddFMAFPExtFMul(FusedOpcode, N1.getOperand(0),
+                                      N1.getOperand(1), N120.getOperand(0),
+                                      N120.getOperand(1), N0);
         }
       }
     }
@@ -13208,14 +13213,14 @@
     // interesting for all targets, especially GPUs.
     if (N1.getOpcode() == ISD::FP_EXTEND) {
       SDValue N10 = N1.getOperand(0);
-      if (N10.getOpcode() == PreferredFusedOpcode) {
+      if (isFusedOp(N10)) {
+        unsigned FusedOpcode = N10.getOpcode();
         SDValue N102 = N10.getOperand(2);
         if (isContractableFMUL(N102) &&
-            TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
-                                N10.getValueType())) {
-          return FoldFAddFPExtFMAFMul(N10.getOperand(0), N10.getOperand(1),
-                                      N102.getOperand(0), N102.getOperand(1),
-                                      N0);
+            TLI.isFPExtFoldable(DAG, FusedOpcode, VT, N10.getValueType())) {
+          return FoldFAddFPExtFMAFMul(FusedOpcode, N10.getOperand(0),
+                                      N10.getOperand(1), N102.getOperand(0),
+                                      N102.getOperand(1), N0);
         }
       }
     }
@@ -13404,50 +13409,55 @@
     return isContractableFMUL(N) && isReassociable(N.getNode());
   };

+  auto isFusedOp = [&](SDValue N) {
+    unsigned Opcode = N.getOpcode();
+    return Opcode == ISD::FMA || Opcode == ISD::FMAD;
+  };
+
   // More folding opportunities when target permits.
   if (Aggressive && isReassociable(N)) {
     bool CanFuse = Options.UnsafeFPMath || N->getFlags().hasAllowContract();
     // fold (fsub (fma x, y, (fmul u, v)), z)
     //   -> (fma x, y (fma u, v, (fneg z)))
-    if (CanFuse && N0.getOpcode() == PreferredFusedOpcode &&
+    if (CanFuse && isFusedOp(N0) &&
         isContractableAndReassociableFMUL(N0.getOperand(2)) &&
         N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) {
-      return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0),
-                         N0.getOperand(1),
-                         DAG.getNode(PreferredFusedOpcode, SL, VT,
-                                     N0.getOperand(2).getOperand(0),
-                                     N0.getOperand(2).getOperand(1),
-                                     DAG.getNode(ISD::FNEG, SL, VT, N1)));
+      unsigned FusedOpcode = N0.getOpcode();
+      return DAG.getNode(
+          FusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1),
+          DAG.getNode(FusedOpcode, SL, VT, N0.getOperand(2).getOperand(0),
+                      N0.getOperand(2).getOperand(1),
+                      DAG.getNode(ISD::FNEG, SL, VT, N1)));
     }

     // fold (fsub x, (fma y, z, (fmul u, v)))
     //   -> (fma (fneg y), z, (fma (fneg u), v, x))
-    if (CanFuse && N1.getOpcode() == PreferredFusedOpcode &&
+    if (CanFuse && isFusedOp(N1) &&
         isContractableAndReassociableFMUL(N1.getOperand(2)) &&
         N1->hasOneUse() && NoSignedZero) {
+      unsigned FusedOpcode = N1.getOpcode();
       SDValue N20 = N1.getOperand(2).getOperand(0);
       SDValue N21 = N1.getOperand(2).getOperand(1);
       return DAG.getNode(
-          PreferredFusedOpcode, SL, VT,
-          DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), N1.getOperand(1),
-          DAG.getNode(PreferredFusedOpcode, SL, VT,
-                      DAG.getNode(ISD::FNEG, SL, VT, N20), N21, N0));
+          FusedOpcode, SL, VT, DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)),
+          N1.getOperand(1),
+          DAG.getNode(FusedOpcode, SL, VT, DAG.getNode(ISD::FNEG, SL, VT, N20),
+                      N21, N0));
     }

     // fold (fsub (fma x, y, (fpext (fmul u, v))), z)
     //   -> (fma x, y (fma (fpext u), (fpext v), (fneg z)))
-    if (N0.getOpcode() == PreferredFusedOpcode &&
-        N0->hasOneUse()) {
+    if (isFusedOp(N0) && N0->hasOneUse()) {
+      unsigned FusedOpcode = N0.getOpcode();
       SDValue N02 = N0.getOperand(2);
       if (N02.getOpcode() == ISD::FP_EXTEND) {
         SDValue N020 = N02.getOperand(0);
         if (isContractableAndReassociableFMUL(N020) &&
-            TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
-                                N020.getValueType())) {
+            TLI.isFPExtFoldable(DAG, FusedOpcode, VT, N020.getValueType())) {
           return DAG.getNode(
-              PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1),
+              FusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1),
               DAG.getNode(
-                  PreferredFusedOpcode, SL, VT,
+                  FusedOpcode, SL, VT,
                   DAG.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(0)),
                   DAG.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(1)),
                   DAG.getNode(ISD::FNEG, SL, VT, N1)));
@@ -13463,17 +13473,17 @@
     // interesting for all targets, especially GPUs.
     if (N0.getOpcode() == ISD::FP_EXTEND) {
       SDValue N00 = N0.getOperand(0);
-      if (N00.getOpcode() == PreferredFusedOpcode) {
+      if (isFusedOp(N00)) {
+        unsigned FusedOpcode = N00.getOpcode();
         SDValue N002 = N00.getOperand(2);
         if (isContractableAndReassociableFMUL(N002) &&
-            TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
-                                N00.getValueType())) {
+            TLI.isFPExtFoldable(DAG, FusedOpcode, VT, N00.getValueType())) {
           return DAG.getNode(
-              PreferredFusedOpcode, SL, VT,
+              FusedOpcode, SL, VT,
               DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
               DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
               DAG.getNode(
-                  PreferredFusedOpcode, SL, VT,
+                  FusedOpcode, SL, VT,
                   DAG.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(0)),
                   DAG.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(1)),
                   DAG.getNode(ISD::FNEG, SL, VT, N1)));
@@ -13483,19 +13493,18 @@

     // fold (fsub x, (fma y, z, (fpext (fmul u, v))))
     //   -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x))
-    if (N1.getOpcode() == PreferredFusedOpcode &&
-        N1.getOperand(2).getOpcode() == ISD::FP_EXTEND &&
+    if (isFusedOp(N1) && N1.getOperand(2).getOpcode() == ISD::FP_EXTEND &&
         N1->hasOneUse()) {
+      unsigned FusedOpcode = N1.getOpcode();
       SDValue N120 = N1.getOperand(2).getOperand(0);
       if (isContractableAndReassociableFMUL(N120) &&
-          TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
-                              N120.getValueType())) {
+          TLI.isFPExtFoldable(DAG, FusedOpcode, VT, N120.getValueType())) {
         SDValue N1200 = N120.getOperand(0);
         SDValue N1201 = N120.getOperand(1);
         return DAG.getNode(
-            PreferredFusedOpcode, SL, VT,
+            FusedOpcode, SL, VT,
             DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), N1.getOperand(1),
-            DAG.getNode(PreferredFusedOpcode, SL, VT,
+            DAG.getNode(FusedOpcode, SL, VT,
                         DAG.getNode(ISD::FNEG, SL, VT,
                                     DAG.getNode(ISD::FP_EXTEND, SL, VT, N1200)),
                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N1201), N0));
@@ -13508,23 +13517,22 @@
     // FIXME: This turns two single-precision and one double-precision
     // operation into two double-precision operations, which might not be
     // interesting for all targets, especially GPUs.
-    if (N1.getOpcode() == ISD::FP_EXTEND &&
-        N1.getOperand(0).getOpcode() == PreferredFusedOpcode) {
+    if (N1.getOpcode() == ISD::FP_EXTEND && isFusedOp(N1.getOperand(0))) {
       SDValue CvtSrc = N1.getOperand(0);
       SDValue N100 = CvtSrc.getOperand(0);
       SDValue N101 = CvtSrc.getOperand(1);
       SDValue N102 = CvtSrc.getOperand(2);
+      unsigned FusedOpcode = CvtSrc.getOpcode();
       if (isContractableAndReassociableFMUL(N102) &&
-          TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
-                              CvtSrc.getValueType())) {
+          TLI.isFPExtFoldable(DAG, FusedOpcode, VT, CvtSrc.getValueType())) {
         SDValue N1020 = N102.getOperand(0);
         SDValue N1021 = N102.getOperand(1);
         return DAG.getNode(
-            PreferredFusedOpcode, SL, VT,
+            FusedOpcode, SL, VT,
             DAG.getNode(ISD::FNEG, SL, VT,
                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N100)),
             DAG.getNode(ISD::FP_EXTEND, SL, VT, N101),
-            DAG.getNode(PreferredFusedOpcode, SL, VT,
+            DAG.getNode(FusedOpcode, SL, VT,
                         DAG.getNode(ISD::FNEG, SL, VT,
                                     DAG.getNode(ISD::FP_EXTEND, SL, VT, N1020)),
                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N1021), N0));
diff --git a/llvm/test/CodeGen/AMDGPU/dagcombine-fma-fmad.ll b/llvm/test/CodeGen/AMDGPU/dagcombine-fma-fmad.ll
--- a/llvm/test/CodeGen/AMDGPU/dagcombine-fma-fmad.ll
+++ b/llvm/test/CodeGen/AMDGPU/dagcombine-fma-fmad.ll
@@ -27,50 +27,49 @@
 ; GCN-NEXT:    s_buffer_load_dwordx4 s[0:3], s[0:3], 0x40
 ; GCN-NEXT:    s_waitcnt lgkmcnt(0)
 ; GCN-NEXT:    v_sub_f32_e64 v5, s24, s28
-; GCN-NEXT:    v_add_f32_e64 v7, s29, -1.0
 ; GCN-NEXT:    s_clause 0x1
 ; GCN-NEXT:    s_buffer_load_dwordx4 s[4:7], s[0:3], 0x50
 ; GCN-NEXT:    s_nop 0
 ; GCN-NEXT:    s_buffer_load_dword s0, s[0:3], 0x2c
 ; GCN-NEXT:    v_fma_f32 v1, v1, v5, s28
+; GCN-NEXT:    v_add_f32_e64 v5, s29, -1.0
 ; GCN-NEXT:    s_waitcnt lgkmcnt(0)
-; GCN-NEXT:    s_clause 0x3
+; GCN-NEXT:    s_clause 0x4
 ; GCN-NEXT:    s_buffer_load_dwordx4 s[8:11], s[0:3], 0x60
 ; GCN-NEXT:    s_buffer_load_dwordx4 s[12:15], s[0:3], 0x20
 ; GCN-NEXT:    s_buffer_load_dwordx4 s[16:19], s[0:3], 0x0
 ; GCN-NEXT:    s_buffer_load_dwordx4 s[20:23], s[0:3], 0x70
-; GCN-NEXT:    v_max_f32_e64 v6, s0, s0 clamp
 ; GCN-NEXT:    s_buffer_load_dwordx4 s[24:27], s[0:3], 0x10
-; GCN-NEXT:    v_sub_f32_e32 v9, s0, v1
+; GCN-NEXT:    v_max_f32_e64 v6, s0, s0 clamp
+; GCN-NEXT:    v_sub_f32_e32 v8, s0, v1
 ; GCN-NEXT:    s_mov_b32 s0, 0x3c23d70a
-; GCN-NEXT:    v_mul_f32_e32 v5, s2, v6
-; GCN-NEXT:    v_fma_f32 v8, -s2, v6, s6
-; GCN-NEXT:    v_fmac_f32_e32 v1, v6, v9
-; GCN-NEXT:    v_fma_f32 v7, v6, v7, 1.0
-; GCN-NEXT:    v_fmac_f32_e32 v5, v8, v6
+; GCN-NEXT:    v_fma_f32 v7, -s2, v6, s6
+; GCN-NEXT:    v_fmac_f32_e32 v1, v6, v8
+; GCN-NEXT:    v_fma_f32 v5, v6, v5, 1.0
 ; GCN-NEXT:    s_waitcnt lgkmcnt(0)
-; GCN-NEXT:    v_mul_f32_e32 v8, s10, v0
+; GCN-NEXT:    v_mul_f32_e32 v9, s10, v0
 ; GCN-NEXT:    v_fma_f32 v0, -v0, s10, s14
-; GCN-NEXT:    v_fmac_f32_e32 v8, v0, v6
-; GCN-NEXT:    v_sub_f32_e32 v0, v1, v7
-; GCN-NEXT:    v_fmac_f32_e32 v7, v0, v6
+; GCN-NEXT:    v_fmac_f32_e32 v9, v0, v6
+; GCN-NEXT:    v_sub_f32_e32 v0, v1, v5
+; GCN-NEXT:    v_fmac_f32_e32 v5, v0, v6
 ; GCN-NEXT:    s_waitcnt vmcnt(2)
-; GCN-NEXT:    v_mul_f32_e32 v9, s18, v2
+; GCN-NEXT:    v_fma_f32 v10, s2, v6, v2
+; GCN-NEXT:    v_mul_f32_e32 v8, s18, v2
 ; GCN-NEXT:    s_waitcnt vmcnt(1)
 ; GCN-NEXT:    v_mul_f32_e32 v3, s22, v3
-; GCN-NEXT:    v_add_f32_e32 v5, v2, v5
-; GCN-NEXT:    v_mul_f32_e32 v1, v9, v6
-; GCN-NEXT:    v_mul_f32_e32 v9, v6, v3
-; GCN-NEXT:    v_fmac_f32_e64 v8, -v6, v3
+; GCN-NEXT:    v_fmac_f32_e32 v10, v7, v6
+; GCN-NEXT:    v_mul_f32_e32 v1, v8, v6
+; GCN-NEXT:    v_mul_f32_e32 v7, v6, v3
+; GCN-NEXT:    v_fmac_f32_e64 v9, -v6, v3
 ; GCN-NEXT:    s_waitcnt vmcnt(0)
-; GCN-NEXT:    v_add_f32_e32 v4, v4, v5
+; GCN-NEXT:    v_add_f32_e32 v3, v4, v10
 ; GCN-NEXT:    v_fma_f32 v0, v2, s26, -v1
-; GCN-NEXT:    v_fmac_f32_e32 v9, v8, v6
-; GCN-NEXT:    v_mul_f32_e32 v3, v4, v6
-; GCN-NEXT:    v_fma_f32 v4, v7, s0, 0x3ca3d70a
+; GCN-NEXT:    v_fma_f32 v4, v5, s0, 0x3ca3d70a
+; GCN-NEXT:    v_fmac_f32_e32 v7, v9, v6
+; GCN-NEXT:    v_mul_f32_e32 v3, v3, v6
 ; GCN-NEXT:    v_fmac_f32_e32 v1, v0, v6
 ; GCN-NEXT:    v_mul_f32_e32 v0, v2, v6
-; GCN-NEXT:    v_mul_f32_e32 v2, v9, v4
+; GCN-NEXT:    v_mul_f32_e32 v2, v7, v4
 ; GCN-NEXT:    v_mul_f32_e32 v1, v3, v1
 ; GCN-NEXT:    v_fmac_f32_e32 v1, v2, v0
 ; GCN-NEXT:    v_max_f32_e32 v0, 0, v1
diff --git a/llvm/test/CodeGen/AMDGPU/mad-combine.ll b/llvm/test/CodeGen/AMDGPU/mad-combine.ll
--- a/llvm/test/CodeGen/AMDGPU/mad-combine.ll
+++ b/llvm/test/CodeGen/AMDGPU/mad-combine.ll
@@ -400,9 +400,12 @@
 ; SI-DAG: buffer_load_dword [[D:v[0-9]+]], v{{\[[0-9]+:[0-9]+\]}}, s{{\[[0-9]+:[0-9]+\]}}, 0 addr64 offset:12 glc{{$}}
 ; SI-DAG: buffer_load_dword [[E:v[0-9]+]], v{{\[[0-9]+:[0-9]+\]}}, s{{\[[0-9]+:[0-9]+\]}}, 0 addr64 offset:16 glc{{$}}

-; SI-STD: v_mul_f32_e32 [[TMP0:v[0-9]+]], [[D]], [[E]]
-; SI-STD: v_fma_f32 [[TMP1:v[0-9]+]], [[A]], [[B]], [[TMP0]]
-; SI-STD: v_sub_f32_e32 [[RESULT:v[0-9]+]], [[TMP1]], [[C]]
+; SI-STD-SAFE: v_mul_f32_e32 [[TMP0:v[0-9]+]], [[D]], [[E]]
+; SI-STD-SAFE: v_fma_f32 [[TMP1:v[0-9]+]], [[A]], [[B]], [[TMP0]]
+; SI-STD-SAFE: v_sub_f32_e32 [[RESULT:v[0-9]+]], [[TMP1]], [[C]]
+
+; SI-STD-UNSAFE: v_fma_f32 [[TMP0:v[0-9]+]], [[D]], [[E]], -[[C]]
+; SI-STD-UNSAFE: v_fma_f32 [[RESULT:v[0-9]+]], [[A]], [[B]], [[TMP0]]

 ; SI-DENORM: v_mul_f32_e32 [[TMP0:v[0-9]+]], [[D]], [[E]]
 ; SI-DENORM: v_fma_f32 [[TMP1:v[0-9]+]], [[A]], [[B]], [[TMP0]]
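
As an illustration of the pattern behind the new SI-STD-UNSAFE checks: with contraction and reassociation allowed, (fsub (fma a, b, (fmul d, e)), c) contracts to (fma a, b, (fma d, e, (fneg c))), and after this change the inner fused node may be either ISD::FMA or ISD::FMAD rather than only the target's preferred opcode. A minimal IR sketch in the spirit of mad-combine.ll (the function name is hypothetical, and the fast-math flags stand in for the unsafe-math/contract conditions the fold requires):

; Reduced example; on the unsafe/contract path this is expected to lower to
; two fused ops: v_fma_f32 tmp, d, e, -c followed by v_fma_f32 result, a, b, tmp.
define float @fsub_fma_fmul_sketch(float %a, float %b, float %c,
                                   float %d, float %e) {
  %mul = fmul contract reassoc float %d, %e           ; (fmul d, e)
  %fma = call contract reassoc float @llvm.fma.f32(float %a, float %b, float %mul)
  %sub = fsub contract reassoc float %fma, %c         ; (fsub (fma ...), c)
  ret float %sub
}

declare float @llvm.fma.f32(float, float, float)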