diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -199,6 +199,11 @@ def AArch64umin_p : SDNode<"AArch64ISD::UMIN_PRED", SDT_AArch64Arith>; def AArch64umulh_p : SDNode<"AArch64ISD::MULHU_PRED", SDT_AArch64Arith>; +def AArch64fadd_p_nsz : PatFrag<(ops node:$op1, node:$op2, node:$op3), + (AArch64fadd_p node:$op1, node:$op2, node:$op3), [{ + return N->getFlags().hasNoSignedZeros(); +}]>; + def SDT_AArch64Arith_Imm : SDTypeProfile<1, 3, [ SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisVT<3,i32>, SDTCVecEltisVT<1,i1>, SDTCisSameAs<0,2> @@ -242,8 +247,16 @@ def AArch64not_mt : PatFrags<(ops node:$pg, node:$op, node:$pt), [(int_aarch64_sve_not node:$pt, node:$pg, node:$op)]>; def AArch64fmul_m1 : EitherVSelectOrPassthruPatFrags; -def AArch64fadd_m1 : EitherVSelectOrPassthruPatFrags; -def AArch64fsub_m1 : EitherVSelectOrPassthruPatFrags; +def AArch64fadd_m1 : PatFrags<(ops node:$pg, node:$op1, node:$op2), [ + (int_aarch64_sve_fadd node:$pg, node:$op1, node:$op2), + (vselect node:$pg, (AArch64fadd_p (SVEAllActive), node:$op1, node:$op2), node:$op1), + (AArch64fadd_p_nsz (SVEAllActive), node:$op1, (vselect node:$pg, node:$op2, (SVEDup0))) +]>; +def AArch64fsub_m1 : PatFrags<(ops node:$pg, node:$op1, node:$op2), [ + (int_aarch64_sve_fsub node:$pg, node:$op1, node:$op2), + (vselect node:$pg, (AArch64fsub_p (SVEAllActive), node:$op1, node:$op2), node:$op1), + (AArch64fsub_p (SVEAllActive), node:$op1, (vselect node:$pg, node:$op2, (SVEDup0))) +]>; def AArch64saba : PatFrags<(ops node:$op1, node:$op2, node:$op3), [(int_aarch64_sve_saba node:$op1, node:$op2, node:$op3), @@ -308,6 +321,12 @@ return N->hasOneUse(); }]>; +def AArch64fmul_p_oneuse : PatFrag<(ops node:$pred, node:$src1, node:$src2), + (AArch64fmul_p node:$pred, node:$src1, node:$src2), [{ + return N->hasOneUse(); +}]>; + + def AArch64fabd_p : PatFrag<(ops node:$pg, node:$op1, node:$op2), (AArch64fabs_mt node:$pg, (AArch64fsub_p node:$pg, node:$op1, node:$op2), undef)>; @@ -356,6 +375,20 @@ // sub(a, select(mask, mul(b, c), splat(0))) -> mls(a, mask, b, c) (sub node:$op1, (vselect node:$pred, (AArch64mul_p_oneuse (SVEAllActive), node:$op2, node:$op3), (SVEDup0)))]>; +class fma_patfrags + : PatFrags<(ops node:$pred, node:$op1, node:$op2, node:$op3), + [(intrinsic node:$pred, node:$op1, node:$op2, node:$op3), + (sdnode (SVEAllActive), node:$op1, (vselect node:$pred, (AArch64fmul_p_oneuse (SVEAllActive), node:$op2, node:$op3), (SVEDup0)))], + [{ + if ((N->getOpcode() != AArch64ISD::FADD_PRED) && + (N->getOpcode() != AArch64ISD::FSUB_PRED)) + return true; // it's the intrinsic + return N->getFlags().hasAllowContract(); +}]>; + +def AArch64fmla_m1 : fma_patfrags; +def AArch64fmls_m1 : fma_patfrags; + let Predicates = [HasSVE] in { defm RDFFR_PPz : sve_int_rdffr_pred<0b0, "rdffr", int_aarch64_sve_rdffr_z>; def RDFFRS_PPz : sve_int_rdffr_pred<0b1, "rdffrs">; @@ -592,8 +625,8 @@ defm FCADD_ZPmZ : sve_fp_fcadd<"fcadd", int_aarch64_sve_fcadd>; defm FCMLA_ZPmZZ : sve_fp_fcmla<"fcmla", int_aarch64_sve_fcmla>; - defm FMLA_ZPmZZ : sve_fp_3op_p_zds_a<0b00, "fmla", "FMLA_ZPZZZ", int_aarch64_sve_fmla, "FMAD_ZPmZZ">; - defm FMLS_ZPmZZ : sve_fp_3op_p_zds_a<0b01, "fmls", "FMLS_ZPZZZ", int_aarch64_sve_fmls, "FMSB_ZPmZZ">; + defm FMLA_ZPmZZ : sve_fp_3op_p_zds_a<0b00, "fmla", "FMLA_ZPZZZ", AArch64fmla_m1, "FMAD_ZPmZZ">; + defm FMLS_ZPmZZ : sve_fp_3op_p_zds_a<0b01, "fmls", "FMLS_ZPZZZ", AArch64fmls_m1, "FMSB_ZPmZZ">; defm FNMLA_ZPmZZ : sve_fp_3op_p_zds_a<0b10, "fnmla", "FNMLA_ZPZZZ", int_aarch64_sve_fnmla, "FNMAD_ZPmZZ">; defm FNMLS_ZPmZZ : sve_fp_3op_p_zds_a<0b11, "fnmls", "FNMLS_ZPZZZ", int_aarch64_sve_fnmls, "FNMSB_ZPmZZ">; diff --git a/llvm/test/CodeGen/AArch64/sve-fp-combine.ll b/llvm/test/CodeGen/AArch64/sve-fp-combine.ll --- a/llvm/test/CodeGen/AArch64/sve-fp-combine.ll +++ b/llvm/test/CodeGen/AArch64/sve-fp-combine.ll @@ -826,9 +826,7 @@ define @fadd_h_sel( %a, %b, %mask) { ; CHECK-LABEL: fadd_h_sel: ; CHECK: // %bb.0: -; CHECK-NEXT: mov z2.h, #0 // =0x0 -; CHECK-NEXT: sel z1.h, p0, z1.h, z2.h -; CHECK-NEXT: fadd z0.h, z0.h, z1.h +; CHECK-NEXT: fadd z0.h, p0/m, z0.h, z1.h ; CHECK-NEXT: ret %sel = select %mask, %b, zeroinitializer %fadd = fadd nsz %a, %sel @@ -838,9 +836,7 @@ define @fadd_s_sel( %a, %b, %mask) { ; CHECK-LABEL: fadd_s_sel: ; CHECK: // %bb.0: -; CHECK-NEXT: mov z2.s, #0 // =0x0 -; CHECK-NEXT: sel z1.s, p0, z1.s, z2.s -; CHECK-NEXT: fadd z0.s, z0.s, z1.s +; CHECK-NEXT: fadd z0.s, p0/m, z0.s, z1.s ; CHECK-NEXT: ret %sel = select %mask, %b, zeroinitializer %fadd = fadd nsz %a, %sel @@ -850,9 +846,7 @@ define @fadd_d_sel( %a, %b, %mask) { ; CHECK-LABEL: fadd_d_sel: ; CHECK: // %bb.0: -; CHECK-NEXT: mov z2.d, #0 // =0x0 -; CHECK-NEXT: sel z1.d, p0, z1.d, z2.d -; CHECK-NEXT: fadd z0.d, z0.d, z1.d +; CHECK-NEXT: fadd z0.d, p0/m, z0.d, z1.d ; CHECK-NEXT: ret %sel = select %mask, %b, zeroinitializer %fadd = fadd nsz %a, %sel @@ -862,9 +856,7 @@ define @fsub_h_sel( %a, %b, %mask) { ; CHECK-LABEL: fsub_h_sel: ; CHECK: // %bb.0: -; CHECK-NEXT: mov z2.h, #0 // =0x0 -; CHECK-NEXT: sel z1.h, p0, z1.h, z2.h -; CHECK-NEXT: fsub z0.h, z0.h, z1.h +; CHECK-NEXT: fsub z0.h, p0/m, z0.h, z1.h ; CHECK-NEXT: ret %sel = select %mask, %b, zeroinitializer %fsub = fsub %a, %sel @@ -874,9 +866,7 @@ define @fsub_s_sel( %a, %b, %mask) { ; CHECK-LABEL: fsub_s_sel: ; CHECK: // %bb.0: -; CHECK-NEXT: mov z2.s, #0 // =0x0 -; CHECK-NEXT: sel z1.s, p0, z1.s, z2.s -; CHECK-NEXT: fsub z0.s, z0.s, z1.s +; CHECK-NEXT: fsub z0.s, p0/m, z0.s, z1.s ; CHECK-NEXT: ret %sel = select %mask, %b, zeroinitializer %fsub = fsub %a, %sel @@ -886,9 +876,7 @@ define @fsub_d_sel( %a, %b, %mask) { ; CHECK-LABEL: fsub_d_sel: ; CHECK: // %bb.0: -; CHECK-NEXT: mov z2.d, #0 // =0x0 -; CHECK-NEXT: sel z1.d, p0, z1.d, z2.d -; CHECK-NEXT: fsub z0.d, z0.d, z1.d +; CHECK-NEXT: fsub z0.d, p0/m, z0.d, z1.d ; CHECK-NEXT: ret %sel = select %mask, %b, zeroinitializer %fsub = fsub %a, %sel @@ -898,10 +886,7 @@ define @fadd_sel_fmul_h( %a, %b, %c, %mask) { ; CHECK-LABEL: fadd_sel_fmul_h: ; CHECK: // %bb.0: -; CHECK-NEXT: mov z3.h, #0 // =0x0 -; CHECK-NEXT: fmul z1.h, z1.h, z2.h -; CHECK-NEXT: sel z1.h, p0, z1.h, z3.h -; CHECK-NEXT: fadd z0.h, z0.h, z1.h +; CHECK-NEXT: fmla z0.h, p0/m, z1.h, z2.h ; CHECK-NEXT: ret %fmul = fmul %b, %c %sel = select %mask, %fmul, zeroinitializer @@ -912,10 +897,7 @@ define @fadd_sel_fmul_s( %a, %b, %c, %mask) { ; CHECK-LABEL: fadd_sel_fmul_s: ; CHECK: // %bb.0: -; CHECK-NEXT: mov z3.s, #0 // =0x0 -; CHECK-NEXT: fmul z1.s, z1.s, z2.s -; CHECK-NEXT: sel z1.s, p0, z1.s, z3.s -; CHECK-NEXT: fadd z0.s, z0.s, z1.s +; CHECK-NEXT: fmla z0.s, p0/m, z1.s, z2.s ; CHECK-NEXT: ret %fmul = fmul %b, %c %sel = select %mask, %fmul, zeroinitializer @@ -926,10 +908,7 @@ define @fadd_sel_fmul_d( %a, %b, %c, %mask) { ; CHECK-LABEL: fadd_sel_fmul_d: ; CHECK: // %bb.0: -; CHECK-NEXT: mov z3.d, #0 // =0x0 -; CHECK-NEXT: fmul z1.d, z1.d, z2.d -; CHECK-NEXT: sel z1.d, p0, z1.d, z3.d -; CHECK-NEXT: fadd z0.d, z0.d, z1.d +; CHECK-NEXT: fmla z0.d, p0/m, z1.d, z2.d ; CHECK-NEXT: ret %fmul = fmul %b, %c %sel = select %mask, %fmul, zeroinitializer @@ -940,10 +919,7 @@ define @fsub_sel_fmul_h( %a, %b, %c, %mask) { ; CHECK-LABEL: fsub_sel_fmul_h: ; CHECK: // %bb.0: -; CHECK-NEXT: mov z3.h, #0 // =0x0 -; CHECK-NEXT: fmul z1.h, z1.h, z2.h -; CHECK-NEXT: sel z1.h, p0, z1.h, z3.h -; CHECK-NEXT: fsub z0.h, z0.h, z1.h +; CHECK-NEXT: fmls z0.h, p0/m, z1.h, z2.h ; CHECK-NEXT: ret %fmul = fmul %b, %c %sel = select %mask, %fmul, zeroinitializer @@ -954,10 +930,7 @@ define @fsub_sel_fmul_s( %a, %b, %c, %mask) { ; CHECK-LABEL: fsub_sel_fmul_s: ; CHECK: // %bb.0: -; CHECK-NEXT: mov z3.s, #0 // =0x0 -; CHECK-NEXT: fmul z1.s, z1.s, z2.s -; CHECK-NEXT: sel z1.s, p0, z1.s, z3.s -; CHECK-NEXT: fsub z0.s, z0.s, z1.s +; CHECK-NEXT: fmls z0.s, p0/m, z1.s, z2.s ; CHECK-NEXT: ret %fmul = fmul %b, %c %sel = select %mask, %fmul, zeroinitializer @@ -968,13 +941,38 @@ define @fsub_sel_fmul_d( %a, %b, %c, %mask) { ; CHECK-LABEL: fsub_sel_fmul_d: ; CHECK: // %bb.0: -; CHECK-NEXT: mov z3.d, #0 // =0x0 -; CHECK-NEXT: fmul z1.d, z1.d, z2.d -; CHECK-NEXT: sel z1.d, p0, z1.d, z3.d -; CHECK-NEXT: fsub z0.d, z0.d, z1.d +; CHECK-NEXT: fmls z0.d, p0/m, z1.d, z2.d ; CHECK-NEXT: ret %fmul = fmul %b, %c %sel = select %mask, %fmul, zeroinitializer %fsub = fsub contract %a, %sel ret %fsub } + +; Verify combine requires contract fast-math flag. +define @fadd_sel_fmul_no_contract_s( %a, %b, %c, %mask) { +; CHECK-LABEL: fadd_sel_fmul_no_contract_s: +; CHECK: // %bb.0: +; CHECK-NEXT: fmul z1.s, z1.s, z2.s +; CHECK-NEXT: fadd z0.s, p0/m, z0.s, z1.s +; CHECK-NEXT: ret + %fmul = fmul %b, %c + %sel = select %mask, %fmul, zeroinitializer + %fadd = fadd nsz %a, %sel + ret %fadd +} + +; Verify combine requires no-signed zeros fast-math flag. +define @fadd_sel_fmul_no_nsz_s( %a, %b, %c, %mask) { +; CHECK-LABEL: fadd_sel_fmul_no_nsz_s: +; CHECK: // %bb.0: +; CHECK-NEXT: mov z3.s, #0 // =0x0 +; CHECK-NEXT: fmul z1.s, z1.s, z2.s +; CHECK-NEXT: sel z1.s, p0, z1.s, z3.s +; CHECK-NEXT: fadd z0.s, z0.s, z1.s +; CHECK-NEXT: ret + %fmul = fmul %b, %c + %sel = select %mask, %fmul, zeroinitializer + %fadd = fadd contract %a, %sel + ret %fadd +}