diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -171,7 +171,8 @@ def SDT_AArch64FMA : SDTypeProfile<1, 4, [ SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisVec<3>, SDTCisVec<4>, - SDTCVecEltisVT<1,i1>, SDTCisSameAs<0,2>, SDTCisSameAs<2,3>, SDTCisSameAs<3,4> + SDTCVecEltisVT<1,i1>, SDTCisSameNumEltsAs<0,1>, + SDTCisSameAs<0,2>, SDTCisSameAs<0,3>, SDTCisSameAs<0,4> ]>; // Predicated operations with the result of inactive lanes being unspecified. @@ -244,6 +245,12 @@ def AArch64revw_mt : SDNode<"AArch64ISD::REVW_MERGE_PASSTHRU", SDT_AArch64Arith>; def AArch64revd_mt : SDNode<"AArch64ISD::REVD_MERGE_PASSTHRU", SDT_AArch64Arith>; +// AArch64fneg_mt, matched only when the node carries the no-signed-zeros flag. +def AArch64fneg_mt_nsz : PatFrag<(ops node:$pred, node:$op, node:$pt), + (AArch64fneg_mt node:$pred, node:$op, node:$pt), [{ + return N->getFlags().hasNoSignedZeros(); +}]>; + // These are like the above but we don't yet have need for ISD nodes. They allow // a single pattern to match intrinsic and ISD operand layouts. def AArch64cls_mt : PatFrags<(ops node:$pg, node:$op, node:$pt), [(int_aarch64_sve_cls node:$pt, node:$pg, node:$op)]>; @@ -349,19 +356,27 @@ def AArch64fabd_p : PatFrag<(ops node:$pg, node:$op1, node:$op2), (AArch64fabs_mt node:$pg, (AArch64fsub_p node:$pg, node:$op1, node:$op2), undef)>; -// FMAs with a negated multiplication operand can be commuted.
-def AArch64fmls_p : PatFrags<(ops node:$pred, node:$op1, node:$op2, node:$op3), - [(AArch64fma_p node:$pred, (AArch64fneg_mt node:$pred, node:$op1, (undef)), node:$op2, node:$op3), - (AArch64fma_p node:$pred, node:$op2, (AArch64fneg_mt node:$pred, node:$op1, (undef)), node:$op3)]>; +// Predicated FMA fragments whose inactive lanes are unspecified. Merging +// (vselect) forms must not be added here: an _UNDEF pseudo may be expanded +// to a reversed form such as FMAD, leaving Zn/Zm rather than Za inactive. +def AArch64fmla_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm), + [(AArch64fma_p node:$pg, node:$zn, node:$zm, node:$za)]>; + +// FMAs with a negated multiplication operand can be commuted. +def AArch64fmls_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm), + [(AArch64fma_p node:$pg, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$zm, node:$za), + (AArch64fma_p node:$pg, node:$zm, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$za)]>; + +def AArch64fnmla_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm), + [(AArch64fma_p node:$pg, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef))), + (AArch64fneg_mt_nsz node:$pg, (AArch64fma_p node:$pg, node:$zn, node:$zm, node:$za), (undef))]>; + +def AArch64fnmls_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm), + [(AArch64fma_p node:$pg, node:$zn, node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef)))]>; def AArch64fsubr_p : PatFrag<(ops node:$pg, node:$op1, node:$op2), (AArch64fsub_p node:$pg, node:$op2, node:$op1)>; -def AArch64fneg_mt_nsz : PatFrag<(ops node:$pred, node:$op, node:$pt), - (AArch64fneg_mt node:$pred, node:$op, node:$pt), [{ - return N->getFlags().hasNoSignedZeros(); -}]>; - def SDT_AArch64Arith_Unpred : SDTypeProfile<1, 2, [ SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisSameAs<0,1>, SDTCisSameAs<1,2> @@ -649,7 +664,7 @@ } // End HasSVE let Predicates = [HasSVEorSME] in { - defm FCADD_ZPmZ : sve_fp_fcadd<"fcadd", int_aarch64_sve_fcadd>; + defm FCADD_ZPmZ :
sve_fp_fcadd<"fcadd", int_aarch64_sve_fcadd>; defm FCMLA_ZPmZZ : sve_fp_fcmla<"fcmla", int_aarch64_sve_fcmla>; defm FMLA_ZPmZZ : sve_fp_3op_p_zds_a<0b00, "fmla", "FMLA_ZPZZZ", AArch64fmla_m1, "FMAD_ZPmZZ">; @@ -662,48 +677,30 @@ defm FNMAD_ZPmZZ : sve_fp_3op_p_zds_b<0b10, "fnmad", int_aarch64_sve_fnmad, "FNMLA_ZPmZZ", /*isReverseInstr*/ 1>; defm FNMSB_ZPmZZ : sve_fp_3op_p_zds_b<0b11, "fnmsb", int_aarch64_sve_fnmsb, "FNMLS_ZPmZZ", /*isReverseInstr*/ 1>; - defm FMLA_ZPZZZ : sve_fp_3op_p_zds_zx; - defm FMLS_ZPZZZ : sve_fp_3op_p_zds_zx; - defm FNMLA_ZPZZZ : sve_fp_3op_p_zds_zx; - defm FNMLS_ZPZZZ : sve_fp_3op_p_zds_zx; - - multiclass fma { - // Zd = Za + Zn * Zm - def : Pat<(Ty (AArch64fma_p PredTy:$P, Ty:$Zn, Ty:$Zm, Ty:$Za)), - (!cast("FMLA_ZPZZZ_UNDEF_"#Suffix) $P, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; - - // Zd = Za + -Zn * Zm - def : Pat<(Ty (AArch64fmls_p PredTy:$P, Ty:$Zn, Ty:$Zm, Ty:$Za)), - (!cast("FMLS_ZPZZZ_UNDEF_"#Suffix) $P, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; - - // Zd = -Za + Zn * Zm - def : Pat<(Ty (AArch64fma_p PredTy:$P, Ty:$Zn, Ty:$Zm, (AArch64fneg_mt PredTy:$P, Ty:$Za, (Ty (undef))))), - (!cast("FNMLS_ZPZZZ_UNDEF_"#Suffix) $P, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; - - // Zd = -Za + -Zn * Zm - def : Pat<(Ty (AArch64fma_p PredTy:$P, (AArch64fneg_mt PredTy:$P, Ty:$Zn, (Ty (undef))), Ty:$Zm, (AArch64fneg_mt PredTy:$P, Ty:$Za, (Ty (undef))))), - (!cast("FNMLA_ZPZZZ_UNDEF_"#Suffix) $P, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; - - // Zd = -(Za + Zn * Zm) - // (with nsz neg.)
- def : Pat<(AArch64fneg_mt_nsz PredTy:$P, (AArch64fma_p PredTy:$P, Ty:$Zn, Ty:$Zm, Ty:$Za), (Ty (undef))), - (!cast("FNMLA_ZPZZZ_UNDEF_"#Suffix) $P, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; - - // Zda = Zda + Zn * Zm - def : Pat<(vselect (PredTy PPR:$Pg), (Ty (AArch64fma_p (PredTy (AArch64ptrue 31)), ZPR:$Zn, ZPR:$Zm, ZPR:$Za)), ZPR:$Za), - (!cast("FMLA_ZPmZZ_"#Suffix) PPR:$Pg, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; - - // Zda = Zda + -Zn * Zm - def : Pat<(vselect (PredTy PPR:$Pg), (Ty (AArch64fma_p (PredTy (AArch64ptrue 31)), (AArch64fneg_mt (PredTy (AArch64ptrue 31)), Ty:$Zn, (Ty (undef))), ZPR:$Zm, ZPR:$Za)), ZPR:$Za), - (!cast("FMLS_ZPmZZ_"#Suffix) PPR:$Pg, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; - } - - defm : fma; - defm : fma; - defm : fma; - defm : fma; - defm : fma; - defm : fma; + defm FMLA_ZPZZZ : sve_fp_3op_pred_hfd; + defm FMLS_ZPZZZ : sve_fp_3op_pred_hfd; + defm FNMLA_ZPZZZ : sve_fp_3op_pred_hfd; + defm FNMLS_ZPZZZ : sve_fp_3op_pred_hfd; + + // Merging forms keep Za in the inactive lanes, so they must select the + // merging FMLA/FMLS instructions. They cannot reuse the _UNDEF pseudos + // above, which may be expanded to a reversed form such as FMAD. + multiclass fma_m<ValueType Ty, ValueType PredTy, string Suffix> { + // Zda = Zda + Zn * Zm + def : Pat<(vselect (PredTy PPR:$Pg), (Ty (AArch64fma_p (PredTy (AArch64ptrue 31)), ZPR:$Zn, ZPR:$Zm, ZPR:$Za)), ZPR:$Za), + (!cast<Instruction>("FMLA_ZPmZZ_"#Suffix) PPR:$Pg, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; + + // Zda = Zda + -Zn * Zm + def : Pat<(vselect (PredTy PPR:$Pg), (Ty (AArch64fma_p (PredTy (AArch64ptrue 31)), (AArch64fneg_mt (PredTy (AArch64ptrue 31)), Ty:$Zn, (Ty (undef))), ZPR:$Zm, ZPR:$Za)), ZPR:$Za), + (!cast<Instruction>("FMLS_ZPmZZ_"#Suffix) PPR:$Pg, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; + } + + defm : fma_m<nxv8f16, nxv8i1, "H">; + defm : fma_m<nxv4f16, nxv4i1, "H">; + defm : fma_m<nxv2f16, nxv2i1, "H">; + defm : fma_m<nxv4f32, nxv4i1, "S">; + defm : fma_m<nxv2f32, nxv2i1, "S">; + defm : fma_m<nxv2f64, nxv2i1, "D">; } // End HasSVEorSME let Predicates = [HasSVE] in { diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td --- a/llvm/lib/Target/AArch64/SVEInstrFormats.td +++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td @@ -2286,12 +2286,6 @@ def : SVE_4_Op_Pat(NAME # _D)>; } -multiclass sve_fp_3op_p_zds_zx { - def _UNDEF_H : PredThreeOpPseudo; - def _UNDEF_S : PredThreeOpPseudo; - def _UNDEF_D : PredThreeOpPseudo; -} - //===----------------------------------------------------------------------===// // SVE Floating Point Multiply-Add - Indexed Group //===----------------------------------------------------------------------===// @@ -8963,6 +8957,20 @@ def : SVE_3_Op_Pat(NAME # _UNDEF_D)>; } +// Predicated pseudo floating point three operand instructions.
+multiclass sve_fp_3op_pred_hfd { + def _UNDEF_H : PredThreeOpPseudo; + def _UNDEF_S : PredThreeOpPseudo; + def _UNDEF_D : PredThreeOpPseudo; + + def : SVE_4_Op_Pat(NAME # _UNDEF_H)>; + def : SVE_4_Op_Pat(NAME # _UNDEF_H)>; + def : SVE_4_Op_Pat(NAME # _UNDEF_H)>; + def : SVE_4_Op_Pat(NAME # _UNDEF_S)>; + def : SVE_4_Op_Pat(NAME # _UNDEF_S)>; + def : SVE_4_Op_Pat(NAME # _UNDEF_D)>; +} + // Predicated pseudo integer two operand instructions. multiclass sve_int_bin_pred_bhsd { def _UNDEF_B : PredTwoOpPseudo;