diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -204,10 +204,18 @@
 def AArch64umin_p  : SDNode<"AArch64ISD::UMIN_PRED",  SDT_AArch64Arith>;
 def AArch64umulh_p : SDNode<"AArch64ISD::MULHU_PRED", SDT_AArch64Arith>;
 
+def AArch64fadd_p_contract : PatFrag<(ops node:$op1, node:$op2, node:$op3),
+                                     (AArch64fadd_p node:$op1, node:$op2, node:$op3), [{
+  return N->getFlags().hasAllowContract();
+}]>;
 def AArch64fadd_p_nsz : PatFrag<(ops node:$op1, node:$op2, node:$op3),
                                 (AArch64fadd_p node:$op1, node:$op2, node:$op3), [{
   return N->getFlags().hasNoSignedZeros();
 }]>;
+def AArch64fsub_p_contract : PatFrag<(ops node:$op1, node:$op2, node:$op3),
+                                     (AArch64fsub_p node:$op1, node:$op2, node:$op3), [{
+  return N->getFlags().hasAllowContract();
+}]>;
 def AArch64fsub_p_nsz : PatFrag<(ops node:$op1, node:$op2, node:$op3),
                                 (AArch64fsub_p node:$op1, node:$op2, node:$op3), [{
   return N->getFlags().hasNoSignedZeros();
@@ -363,14 +371,12 @@
                               (AArch64fabs_mt node:$pg, (AArch64fsub_p node:$pg, node:$op1, node:$op2), undef)]>;
 
 def AArch64fmla_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
-    [(AArch64fma_p node:$pg, node:$zn, node:$zm, node:$za),
-     (vselect node:$pg, (AArch64fma_p (AArch64ptrue 31), node:$zn, node:$zm, node:$za), node:$za)]>;
+    [(AArch64fma_p node:$pg, node:$zn, node:$zm, node:$za)]>;
 
 def AArch64fmls_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
     [(int_aarch64_sve_fmls_u node:$pg, node:$za, node:$zn, node:$zm),
      (AArch64fma_p node:$pg, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$zm, node:$za),
-     (AArch64fma_p node:$pg, node:$zm, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$za),
-     (vselect node:$pg, (AArch64fma_p (AArch64ptrue 31), (AArch64fneg_mt (AArch64ptrue 31), node:$zn, (undef)), node:$zm, node:$za), node:$za)]>;
+     (AArch64fma_p node:$pg, node:$zm, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$za)]>;
 
 def AArch64fnmla_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
     [(int_aarch64_sve_fnmla_u node:$pg, node:$za, node:$zn, node:$zm),
@@ -423,18 +429,15 @@
     [(int_aarch64_sve_eor3 node:$op1, node:$op2, node:$op3),
      (xor node:$op1, (xor node:$op2, node:$op3))]>;
 
-class fma_patfrags<SDPatternOperator intrinsic, SDPatternOperator add>
-  : PatFrags<(ops node:$pred, node:$op1, node:$op2, node:$op3),
-             [(intrinsic node:$pred, node:$op1, node:$op2, node:$op3),
-              (vselect node:$pred, (add (SVEAllActive), node:$op1, (AArch64fmul_p_oneuse (SVEAllActive), node:$op2, node:$op3)), node:$op1)],
-[{
-  if (N->getOpcode() == ISD::VSELECT)
-    return N->getOperand(1)->getFlags().hasAllowContract();
-  return true; // it's the intrinsic
-}]>;
+def AArch64fmla_m1 : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
+    [(int_aarch64_sve_fmla node:$pg, node:$za, node:$zn, node:$zm),
+     (vselect node:$pg, (AArch64fadd_p_contract (SVEAllActive), node:$za, (AArch64fmul_p_oneuse (SVEAllActive), node:$zn, node:$zm)), node:$za),
+     (vselect node:$pg, (AArch64fma_p (AArch64ptrue 31), node:$zn, node:$zm, node:$za), node:$za)]>;
 
-def AArch64fmla_m1 : fma_patfrags<int_aarch64_sve_fmla, AArch64fadd_p>;
-def AArch64fmls_m1 : fma_patfrags<int_aarch64_sve_fmls, AArch64fsub_p>;
+def AArch64fmls_m1 : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
+    [(int_aarch64_sve_fmls node:$pg, node:$za, node:$zn, node:$zm),
+     (vselect node:$pg, (AArch64fsub_p_contract (SVEAllActive), node:$za, (AArch64fmul_p_oneuse (SVEAllActive), node:$zn, node:$zm)), node:$za),
+     (vselect node:$pg, (AArch64fma_p (AArch64ptrue 31), (AArch64fneg_mt (AArch64ptrue 31), node:$zn, (undef)), node:$zm, node:$za), node:$za)]>;
 
 def AArch64add_m1 : VSelectUnpredOrPassthruPatFrags<int_aarch64_sve_add, add>;
 def AArch64sub_m1 : VSelectUnpredOrPassthruPatFrags<int_aarch64_sve_sub, sub>;
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -2317,7 +2317,10 @@
            SVEPseudo2Instr<Ps # _D, 1>, SVEInstr2Rev<NAME # _D, revname # _D, isReverseInstr>;
 
   def : SVE_4_Op_Pat<nxv8f16, op, nxv8i1, nxv8f16, nxv8f16, nxv8f16, !cast<Instruction>(NAME # _H)>;
+  def : SVE_4_Op_Pat<nxv4f16, op, nxv4i1, nxv4f16, nxv4f16, nxv4f16, !cast<Instruction>(NAME # _H)>;
+  def : SVE_4_Op_Pat<nxv2f16, op, nxv2i1, nxv2f16, nxv2f16, nxv2f16, !cast<Instruction>(NAME # _H)>;
   def : SVE_4_Op_Pat<nxv4f32, op, nxv4i1, nxv4f32, nxv4f32, nxv4f32, !cast<Instruction>(NAME # _S)>;
+  def : SVE_4_Op_Pat<nxv2f32, op, nxv2i1, nxv2f32, nxv2f32, nxv2f32, !cast<Instruction>(NAME # _S)>;
   def : SVE_4_Op_Pat<nxv2f64, op, nxv2i1, nxv2f64, nxv2f64, nxv2f64, !cast<Instruction>(NAME # _D)>;
 }
 
diff --git a/llvm/test/CodeGen/AArch64/sve-fp-combine.ll b/llvm/test/CodeGen/AArch64/sve-fp-combine.ll
--- a/llvm/test/CodeGen/AArch64/sve-fp-combine.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fp-combine.ll
@@ -1272,7 +1272,8 @@
 define <vscale x 8 x half> @fadd_sel_fmul_h_different_arg_order(<vscale x 8 x i1> %pred, <vscale x 8 x half> %m1, <vscale x 8 x half> %m2, <vscale x 8 x half> %acc) {
 ; CHECK-LABEL: fadd_sel_fmul_h_different_arg_order:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmad z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT:    fmla z2.h, p0/m, z0.h, z1.h
+; CHECK-NEXT:    mov z0.d, z2.d
 ; CHECK-NEXT:    ret
   %mul = fmul contract <vscale x 8 x half> %m1, %m2
   %masked.mul = select <vscale x 8 x i1> %pred, <vscale x 8 x half> %mul, <vscale x 8 x half> zeroinitializer
@@ -1283,7 +1284,8 @@
 define <vscale x 4 x float> @fadd_sel_fmul_s_different_arg_order(<vscale x 4 x i1> %pred, <vscale x 4 x float> %m1, <vscale x 4 x float> %m2, <vscale x 4 x float> %acc) {
 ; CHECK-LABEL: fadd_sel_fmul_s_different_arg_order:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    fmla z2.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    mov z0.d, z2.d
 ; CHECK-NEXT:    ret
   %mul = fmul contract <vscale x 4 x float> %m1, %m2
   %masked.mul = select <vscale x 4 x i1> %pred, <vscale x 4 x float> %mul, <vscale x 4 x float> zeroinitializer
@@ -1294,7 +1296,8 @@
 define <vscale x 2 x double> @fadd_sel_fmul_d_different_arg_order(<vscale x 2 x i1> %pred, <vscale x 2 x double> %m1, <vscale x 2 x double> %m2, <vscale x 2 x double> %acc) {
 ; CHECK-LABEL: fadd_sel_fmul_d_different_arg_order:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    fmla z2.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    mov z0.d, z2.d
 ; CHECK-NEXT:    ret
   %mul = fmul contract <vscale x 2 x double> %m1, %m2
   %masked.mul = select <vscale x 2 x i1> %pred, <vscale x 2 x double> %mul, <vscale x 2 x double> zeroinitializer