Index: lib/Target/X86/X86InstrAVX512.td =================================================================== --- lib/Target/X86/X86InstrAVX512.td +++ lib/Target/X86/X86InstrAVX512.td @@ -6453,14 +6453,6 @@ _.RC:$src1,(_.VT (X86VBroadcast (_.ScalarLdFrag addr:$src3)))), 1, 0>, AVX512FMA3Base, EVEX_B; } - - // Additional pattern for folding broadcast nodes in other orders. - def : Pat<(_.VT (vselect _.KRCWM:$mask, - (OpNode _.RC:$src1, _.RC:$src2, - (X86VBroadcast (_.ScalarLdFrag addr:$src3))), - _.RC:$src1)), - (!cast<Instruction>(NAME#Suff#_.ZSuffix#mbk) _.RC:$src1, - _.KRCWM:$mask, _.RC:$src2, addr:$src3)>; } multiclass avx512_fma3_213_round<bits<8> opc, string OpcodeStr, SDNode OpNode, @@ -6497,7 +6489,7 @@ avx512vl_f64_info, "PD">, VEX_W; } -defm VFMADD213 : avx512_fma3p_213_f<0xA8, "vfmadd213", fma, X86FmaddRnd>; +defm VFMADD213 : avx512_fma3p_213_f<0xA8, "vfmadd213", X86Fmadd, X86FmaddRnd>; defm VFMSUB213 : avx512_fma3p_213_f<0xAA, "vfmsub213", X86Fmsub, X86FmsubRnd>; defm VFMADDSUB213 : avx512_fma3p_213_f<0xA6, "vfmaddsub213", X86Fmaddsub, X86FmaddsubRnd>; defm VFMSUBADD213 : avx512_fma3p_213_f<0xA7, "vfmsubadd213", X86Fmsubadd, X86FmsubaddRnd>; @@ -6528,24 +6520,6 @@ (_.VT (X86VBroadcast(_.ScalarLdFrag addr:$src3))), _.RC:$src1)), 1, 0>, AVX512FMA3Base, EVEX_B; } - - // Additional patterns for folding broadcast nodes in other orders.
- def : Pat<(_.VT (OpNode (X86VBroadcast (_.ScalarLdFrag addr:$src3)), - _.RC:$src2, _.RC:$src1)), - (!cast<Instruction>(NAME#Suff#_.ZSuffix#mb) _.RC:$src1, - _.RC:$src2, addr:$src3)>; - def : Pat<(_.VT (vselect _.KRCWM:$mask, - (OpNode (X86VBroadcast (_.ScalarLdFrag addr:$src3)), - _.RC:$src2, _.RC:$src1), - _.RC:$src1)), - (!cast<Instruction>(NAME#Suff#_.ZSuffix#mbk) _.RC:$src1, - _.KRCWM:$mask, _.RC:$src2, addr:$src3)>; - def : Pat<(_.VT (vselect _.KRCWM:$mask, - (OpNode (X86VBroadcast (_.ScalarLdFrag addr:$src3)), - _.RC:$src2, _.RC:$src1), - _.ImmAllZerosV)), - (!cast<Instruction>(NAME#Suff#_.ZSuffix#mbkz) _.RC:$src1, - _.KRCWM:$mask, _.RC:$src2, addr:$src3)>; } multiclass avx512_fma3_231_round<bits<8> opc, string OpcodeStr, SDNode OpNode, @@ -6583,7 +6557,7 @@ avx512vl_f64_info, "PD">, VEX_W; } -defm VFMADD231 : avx512_fma3p_231_f<0xB8, "vfmadd231", fma, X86FmaddRnd>; +defm VFMADD231 : avx512_fma3p_231_f<0xB8, "vfmadd231", X86Fmadd, X86FmaddRnd>; defm VFMSUB231 : avx512_fma3p_231_f<0xBA, "vfmsub231", X86Fmsub, X86FmsubRnd>; defm VFMADDSUB231 : avx512_fma3p_231_f<0xB6, "vfmaddsub231", X86Fmaddsub, X86FmaddsubRnd>; defm VFMSUBADD231 : avx512_fma3p_231_f<0xB7, "vfmsubadd231", X86Fmsubadd, X86FmsubaddRnd>; @@ -6613,14 +6587,6 @@ (_.VT (X86VBroadcast(_.ScalarLdFrag addr:$src3))), _.RC:$src2)), 1, 0>, AVX512FMA3Base, EVEX_B; } - - // Additional patterns for folding broadcast nodes in other orders.
- def : Pat<(_.VT (vselect _.KRCWM:$mask, - (OpNode (X86VBroadcast (_.ScalarLdFrag addr:$src3)), - _.RC:$src1, _.RC:$src2), - _.RC:$src1)), - (!cast<Instruction>(NAME#Suff#_.ZSuffix#mbk) _.RC:$src1, - _.KRCWM:$mask, _.RC:$src2, addr:$src3)>; } multiclass avx512_fma3_132_round<bits<8> opc, string OpcodeStr, SDNode OpNode, @@ -6658,7 +6624,7 @@ avx512vl_f64_info, "PD">, VEX_W; } -defm VFMADD132 : avx512_fma3p_132_f<0x98, "vfmadd132", fma, X86FmaddRnd>; +defm VFMADD132 : avx512_fma3p_132_f<0x98, "vfmadd132", X86Fmadd, X86FmaddRnd>; defm VFMSUB132 : avx512_fma3p_132_f<0x9A, "vfmsub132", X86Fmsub, X86FmsubRnd>; defm VFMADDSUB132 : avx512_fma3p_132_f<0x96, "vfmaddsub132", X86Fmaddsub, X86FmaddsubRnd>; defm VFMSUBADD132 : avx512_fma3p_132_f<0x97, "vfmsubadd132", X86Fmsubadd, X86FmsubaddRnd>; @@ -6757,7 +6723,7 @@ } } -defm VFMADD : avx512_fma3s<0xA9, 0xB9, 0x99, "vfmadd", fma, X86FmaddRnds1, +defm VFMADD : avx512_fma3s<0xA9, 0xB9, 0x99, "vfmadd", X86Fmadd, X86FmaddRnds1, X86FmaddRnds3>; defm VFMSUB : avx512_fma3s<0xAB, 0xBB, 0x9B, "vfmsub", X86Fmsub, X86FmsubRnds1, X86FmsubRnds3>; Index: lib/Target/X86/X86InstrFMA.td =================================================================== --- lib/Target/X86/X86InstrFMA.td +++ lib/Target/X86/X86InstrFMA.td @@ -89,7 +89,7 @@ // Fused Multiply-Add let ExeDomain = SSEPackedSingle in { defm VFMADD : fma3p_forms<0x98, 0xA8, 0xB8, "vfmadd", "ps", "PS", - loadv4f32, loadv8f32, fma, v4f32, v8f32>; + loadv4f32, loadv8f32, X86Fmadd, v4f32, v8f32>; defm VFMSUB : fma3p_forms<0x9A, 0xAA, 0xBA, "vfmsub", "ps", "PS", loadv4f32, loadv8f32, X86Fmsub, v4f32, v8f32>; defm VFMADDSUB : fma3p_forms<0x96, 0xA6, 0xB6, "vfmaddsub", "ps", "PS", @@ -102,7 +102,7 @@ let ExeDomain = SSEPackedDouble in { defm VFMADD : fma3p_forms<0x98, 0xA8, 0xB8, "vfmadd", "pd", "PD", - loadv2f64, loadv4f64, fma, v2f64, + loadv2f64, loadv4f64, X86Fmadd, v2f64, v4f64>, VEX_W; defm VFMSUB : fma3p_forms<0x9A, 0xAA, 0xBA, "vfmsub", "pd", "PD", loadv2f64, loadv4f64, X86Fmsub, v2f64, @@ -271,7
+271,7 @@ } defm VFMADD : fma3s<0x99, 0xA9, 0xB9, "vfmadd", int_x86_fma_vfmadd_ss, - int_x86_fma_vfmadd_sd, fma>, VEX_LIG; + int_x86_fma_vfmadd_sd, X86Fmadd>, VEX_LIG; defm VFMSUB : fma3s<0x9B, 0xAB, 0xBB, "vfmsub", int_x86_fma_vfmsub_ss, int_x86_fma_vfmsub_sd, X86Fmsub>, VEX_LIG; @@ -407,7 +407,7 @@ let ExeDomain = SSEPackedSingle in { // Scalar Instructions - defm VFMADDSS4 : fma4s<0x6A, "vfmaddss", FR32, f32mem, f32, fma, loadf32>, + defm VFMADDSS4 : fma4s<0x6A, "vfmaddss", FR32, f32mem, f32, X86Fmadd, loadf32>, fma4s_int<0x6A, "vfmaddss", ssmem, sse_load_f32, int_x86_fma_vfmadd_ss>; defm VFMSUBSS4 : fma4s<0x6E, "vfmsubss", FR32, f32mem, f32, X86Fmsub, loadf32>, @@ -422,7 +422,7 @@ fma4s_int<0x7E, "vfnmsubss", ssmem, sse_load_f32, int_x86_fma_vfnmsub_ss>; // Packed Instructions - defm VFMADDPS4 : fma4p<0x68, "vfmaddps", fma, v4f32, v8f32, + defm VFMADDPS4 : fma4p<0x68, "vfmaddps", X86Fmadd, v4f32, v8f32, loadv4f32, loadv8f32>; defm VFMSUBPS4 : fma4p<0x6C, "vfmsubps", X86Fmsub, v4f32, v8f32, loadv4f32, loadv8f32>; @@ -438,7 +438,7 @@ let ExeDomain = SSEPackedDouble in { // Scalar Instructions - defm VFMADDSD4 : fma4s<0x6B, "vfmaddsd", FR64, f64mem, f64, fma, loadf64>, + defm VFMADDSD4 : fma4s<0x6B, "vfmaddsd", FR64, f64mem, f64, X86Fmadd, loadf64>, fma4s_int<0x6B, "vfmaddsd", sdmem, sse_load_f64, int_x86_fma_vfmadd_sd>; defm VFMSUBSD4 : fma4s<0x6F, "vfmsubsd", FR64, f64mem, f64, X86Fmsub, loadf64>, @@ -453,7 +453,7 @@ fma4s_int<0x7F, "vfnmsubsd", sdmem, sse_load_f64, int_x86_fma_vfnmsub_sd>; // Packed Instructions - defm VFMADDPD4 : fma4p<0x69, "vfmaddpd", fma, v2f64, v4f64, + defm VFMADDPD4 : fma4p<0x69, "vfmaddpd", X86Fmadd, v2f64, v4f64, loadv2f64, loadv4f64>; defm VFMSUBPD4 : fma4p<0x6D, "vfmsubpd", X86Fmsub, v2f64, v4f64, loadv2f64, loadv4f64>; Index: lib/Target/X86/X86InstrFragmentsSIMD.td =================================================================== --- lib/Target/X86/X86InstrFragmentsSIMD.td +++ lib/Target/X86/X86InstrFragmentsSIMD.td @@ -481,19
+481,21 @@ def X86fgetexpRnd : SDNode<"X86ISD::FGETEXP_RND", SDTFPUnaryOpRound>; def X86fgetexpRnds : SDNode<"X86ISD::FGETEXPS_RND", SDTFPBinOpRound>; -// No need for FMADD because we use ISD::FMA. -def X86Fnmadd : SDNode<"X86ISD::FNMADD", SDTFPTernaryOp>; -def X86Fmsub : SDNode<"X86ISD::FMSUB", SDTFPTernaryOp>; -def X86Fnmsub : SDNode<"X86ISD::FNMSUB", SDTFPTernaryOp>; -def X86Fmaddsub : SDNode<"X86ISD::FMADDSUB", SDTFPTernaryOp>; -def X86Fmsubadd : SDNode<"X86ISD::FMSUBADD", SDTFPTernaryOp>; - -def X86FmaddRnd : SDNode<"X86ISD::FMADD_RND", SDTFmaRound>; -def X86FnmaddRnd : SDNode<"X86ISD::FNMADD_RND", SDTFmaRound>; -def X86FmsubRnd : SDNode<"X86ISD::FMSUB_RND", SDTFmaRound>; -def X86FnmsubRnd : SDNode<"X86ISD::FNMSUB_RND", SDTFmaRound>; -def X86FmaddsubRnd : SDNode<"X86ISD::FMADDSUB_RND", SDTFmaRound>; -def X86FmsubaddRnd : SDNode<"X86ISD::FMSUBADD_RND", SDTFmaRound>; +// X86Fmadd is the generic ISD::FMA redeclared with SDNPCommutative. On these +// 3-input nodes TableGen commutes only operands 0 and 1 (the multiplicands). +def X86Fmadd : SDNode<"ISD::FMA", SDTFPTernaryOp, [SDNPCommutative]>; +def X86Fnmadd : SDNode<"X86ISD::FNMADD", SDTFPTernaryOp, [SDNPCommutative]>; +def X86Fmsub : SDNode<"X86ISD::FMSUB", SDTFPTernaryOp, [SDNPCommutative]>; +def X86Fnmsub : SDNode<"X86ISD::FNMSUB", SDTFPTernaryOp, [SDNPCommutative]>; +def X86Fmaddsub : SDNode<"X86ISD::FMADDSUB", SDTFPTernaryOp, [SDNPCommutative]>; +def X86Fmsubadd : SDNode<"X86ISD::FMSUBADD", SDTFPTernaryOp, [SDNPCommutative]>; + +def X86FmaddRnd : SDNode<"X86ISD::FMADD_RND", SDTFmaRound, [SDNPCommutative]>; +def X86FnmaddRnd : SDNode<"X86ISD::FNMADD_RND", SDTFmaRound, [SDNPCommutative]>; +def X86FmsubRnd : SDNode<"X86ISD::FMSUB_RND", SDTFmaRound, [SDNPCommutative]>; +def X86FnmsubRnd : SDNode<"X86ISD::FNMSUB_RND", SDTFmaRound, [SDNPCommutative]>; +def X86FmaddsubRnd : SDNode<"X86ISD::FMADDSUB_RND", SDTFmaRound, [SDNPCommutative]>; +def X86FmsubaddRnd : SDNode<"X86ISD::FMSUBADD_RND", SDTFmaRound, [SDNPCommutative]>; // Scalar FMA intrinsics with passthru bits in operand 1.
def X86FmaddRnds1 : SDNode<"X86ISD::FMADDS1_RND", SDTFmaRound>; @@ -502,10 +504,10 @@ def X86FnmsubRnds1 : SDNode<"X86ISD::FNMSUBS1_RND", SDTFmaRound>; // Scalar FMA intrinsics with passthru bits in operand 3. -def X86FmaddRnds3 : SDNode<"X86ISD::FMADDS3_RND", SDTFmaRound>; -def X86FnmaddRnds3 : SDNode<"X86ISD::FNMADDS3_RND", SDTFmaRound>; -def X86FmsubRnds3 : SDNode<"X86ISD::FMSUBS3_RND", SDTFmaRound>; -def X86FnmsubRnds3 : SDNode<"X86ISD::FNMSUBS3_RND", SDTFmaRound>; +def X86FmaddRnds3 : SDNode<"X86ISD::FMADDS3_RND", SDTFmaRound, [SDNPCommutative]>; +def X86FnmaddRnds3 : SDNode<"X86ISD::FNMADDS3_RND", SDTFmaRound, [SDNPCommutative]>; +def X86FmsubRnds3 : SDNode<"X86ISD::FMSUBS3_RND", SDTFmaRound, [SDNPCommutative]>; +def X86FnmsubRnds3 : SDNode<"X86ISD::FNMSUBS3_RND", SDTFmaRound, [SDNPCommutative]>; def SDTIFma : SDTypeProfile<1, 3, [SDTCisInt<0>, SDTCisSameAs<0,1>, SDTCisSameAs<1,2>, SDTCisSameAs<1,3>]>; Index: utils/TableGen/CodeGenDAGPatterns.cpp =================================================================== --- utils/TableGen/CodeGenDAGPatterns.cpp +++ utils/TableGen/CodeGenDAGPatterns.cpp @@ -3744,7 +3744,7 @@ // If this node is commutative, consider the commuted order. bool isCommIntrinsic = N->isCommutativeIntrinsic(CDP); if (NodeInfo.hasProperty(SDNPCommutative) || isCommIntrinsic) { - assert((N->getNumChildren()==2 || isCommIntrinsic) && - "Commutative but doesn't have 2 children!"); + assert((N->getNumChildren()>=2 || isCommIntrinsic) && + "Commutative node must have at least 2 children!"); // Don't count children which are actually register references.
unsigned NC = 0; @@ -3772,9 +3772,16 @@ for (unsigned i = 3; i != NC; ++i) Variants.push_back(ChildVariants[i]); CombineChildVariants(N, Variants, OutVariants, CDP, DepVars); - } else if (NC == 2) - CombineChildVariants(N, ChildVariants[1], ChildVariants[0], - OutVariants, CDP, DepVars); + } else if (NC == N->getNumChildren()) { + // Ordinary commutative node with no register-reference children: swap + // operands 0 and 1 and keep any remaining operands (e.g. an FMA addend). + std::vector<std::vector<TreePatternNode*> > Variants; + Variants.push_back(ChildVariants[1]); + Variants.push_back(ChildVariants[0]); + for (unsigned i = 2; i != NC; ++i) + Variants.push_back(ChildVariants[i]); + CombineChildVariants(N, Variants, OutVariants, CDP, DepVars); + } } }