[AArch64][SVE] Select predicated MLA/MLS for add/sub of a masked multiply

Add AArch64mla_m1 / AArch64mls_m1 PatFrags. Besides the
int_aarch64_sve_mla / int_aarch64_sve_mls intrinsics and add/sub of a
one-use predicated multiply, they also match:

  add(a, select(mask, mul(b, c), splat(0))) -> mla(a, mask, b, c)
  sub(a, select(mask, mul(b, c), splat(0))) -> mls(a, mask, b, c)

Lanes where mask is false compute a + 0 (or a - 0), i.e. they keep a.
This is exactly what the merging (p/m) MLA/MLS produce, since their
destination is the accumulator a.

sve_int_mlas_vvv_pred now takes the single combined operator. The
explicit outerop/mulop Pat<> definitions in that multiclass are removed.

In sve-masked-int-arith.ll, the ptrue + mul + add/sub sequences now
become a single mla/mls.

diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -345,6 +345,16 @@
 def AArch64sub_m1 : PatFrags<(ops node:$pred, node:$op1, node:$op2),
                              [(int_aarch64_sve_sub node:$pred, node:$op1, node:$op2),
                               (sub node:$op1, (vselect node:$pred, node:$op2, (SVEDup0)))]>;
+def AArch64mla_m1 : PatFrags<(ops node:$pred, node:$op1, node:$op2, node:$op3),
+                             [(int_aarch64_sve_mla node:$pred, node:$op1, node:$op2, node:$op3),
+                              (add node:$op1, (AArch64mul_p_oneuse node:$pred, node:$op2, node:$op3)),
+                              // add(a, select(mask, mul(b, c), splat(0))) -> mla(a, mask, b, c)
+                              (add node:$op1, (vselect node:$pred, (AArch64mul_p_oneuse (SVEAllActive), node:$op2, node:$op3), (SVEDup0)))]>;
+def AArch64mls_m1 : PatFrags<(ops node:$pred, node:$op1, node:$op2, node:$op3),
+                             [(int_aarch64_sve_mls node:$pred, node:$op1, node:$op2, node:$op3),
+                              (sub node:$op1, (AArch64mul_p_oneuse node:$pred, node:$op2, node:$op3)),
+                              // sub(a, select(mask, mul(b, c), splat(0))) -> mls(a, mask, b, c)
+                              (sub node:$op1, (vselect node:$pred, (AArch64mul_p_oneuse (SVEAllActive), node:$op2, node:$op3), (SVEDup0)))]>;
 
 let Predicates = [HasSVE] in {
   defm RDFFR_PPz  : sve_int_rdffr_pred<0b0, "rdffr", int_aarch64_sve_rdffr_z>;
@@ -399,8 +409,8 @@
 
   defm MAD_ZPmZZ : sve_int_mladdsub_vvv_pred<0b0, "mad", int_aarch64_sve_mad>;
   defm MSB_ZPmZZ : sve_int_mladdsub_vvv_pred<0b1, "msb", int_aarch64_sve_msb>;
-  defm MLA_ZPmZZ : sve_int_mlas_vvv_pred<0b0, "mla", int_aarch64_sve_mla, add, AArch64mul_p_oneuse>;
-  defm MLS_ZPmZZ : sve_int_mlas_vvv_pred<0b1, "mls", int_aarch64_sve_mls, sub, AArch64mul_p_oneuse>;
+  defm MLA_ZPmZZ : sve_int_mlas_vvv_pred<0b0, "mla", AArch64mla_m1>;
+  defm MLS_ZPmZZ : sve_int_mlas_vvv_pred<0b1, "mls", AArch64mls_m1>;
 
   // SVE predicated integer reductions.
   defm SADDV_VPZ : sve_int_reduce_0_saddv<0b000, "saddv", AArch64saddv_p>;
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -2958,8 +2958,7 @@
   let ElementSize = zprty.ElementSize;
 }
 
-multiclass sve_int_mlas_vvv_pred<bits<1> opc, string asm, SDPatternOperator op,
-                                 SDPatternOperator outerop, SDPatternOperator mulop> {
+multiclass sve_int_mlas_vvv_pred<bits<1> opc, string asm, SDPatternOperator op> {
   def _B : sve_int_mlas_vvv_pred<0b00, opc, asm, ZPR8>;
   def _H : sve_int_mlas_vvv_pred<0b01, opc, asm, ZPR16>;
   def _S : sve_int_mlas_vvv_pred<0b10, opc, asm, ZPR32>;
@@ -2969,15 +2968,6 @@
   def : SVE_4_Op_Pat<nxv8i16, op, nxv8i1, nxv8i16, nxv8i16, nxv8i16, !cast<Instruction>(NAME # _H)>;
   def : SVE_4_Op_Pat<nxv4i32, op, nxv4i1, nxv4i32, nxv4i32, nxv4i32, !cast<Instruction>(NAME # _S)>;
   def : SVE_4_Op_Pat<nxv2i64, op, nxv2i1, nxv2i64, nxv2i64, nxv2i64, !cast<Instruction>(NAME # _D)>;
-
-  def : Pat<(outerop nxv16i8:$Op1, (mulop nxv16i1:$pred, nxv16i8:$Op2, nxv16i8:$Op3)),
-            (!cast<Instruction>(NAME # _B) $pred, $Op1, $Op2, $Op3)>;
-  def : Pat<(outerop nxv8i16:$Op1, (mulop nxv8i1:$pred, nxv8i16:$Op2, nxv8i16:$Op3)),
-            (!cast<Instruction>(NAME # _H) $pred, $Op1, $Op2, $Op3)>;
-  def : Pat<(outerop nxv4i32:$Op1, (mulop nxv4i1:$pred, nxv4i32:$Op2, nxv4i32:$Op3)),
-            (!cast<Instruction>(NAME # _S) $pred, $Op1, $Op2, $Op3)>;
-  def : Pat<(outerop nxv2i64:$Op1, (mulop nxv2i1:$pred, nxv2i64:$Op2, nxv2i64:$Op3)),
-            (!cast<Instruction>(NAME # _D) $pred, $Op1, $Op2, $Op3)>;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/test/CodeGen/AArch64/sve-masked-int-arith.ll b/llvm/test/CodeGen/AArch64/sve-masked-int-arith.ll
--- a/llvm/test/CodeGen/AArch64/sve-masked-int-arith.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-int-arith.ll
@@ -96,9 +96,7 @@
 define <vscale x 16 x i8> @masked_mla_nxv16i8(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c, <vscale x 16 x i1> %mask) {
 ; CHECK-LABEL: masked_mla_nxv16i8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p1.b
-; CHECK-NEXT:    mul z1.b, p1/m, z1.b, z2.b
-; CHECK-NEXT:    add z0.b, p0/m, z0.b, z1.b
+; CHECK-NEXT:    mla z0.b, p0/m, z1.b, z2.b
 ; CHECK-NEXT:    ret
   %mul = mul nsw <vscale x 16 x i8> %b, %c
   %sel = select <vscale x 16 x i1> %mask, <vscale x 16 x i8> %mul, <vscale x 16 x i8> zeroinitializer
@@ -109,9 +107,7 @@
 define <vscale x 8 x i16> @masked_mla_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c, <vscale x 8 x i1> %mask) {
 ; CHECK-LABEL: masked_mla_nxv8i16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p1.h
-; CHECK-NEXT:    mul z1.h, p1/m, z1.h, z2.h
-; CHECK-NEXT:    add z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT:    mla z0.h, p0/m, z1.h, z2.h
 ; CHECK-NEXT:    ret
   %mul = mul nsw <vscale x 8 x i16> %b, %c
   %sel = select <vscale x 8 x i1> %mask, <vscale x 8 x i16> %mul, <vscale x 8 x i16> zeroinitializer
@@ -122,9 +118,7 @@
 define <vscale x 4 x i32> @masked_mla_nxv4i32(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b, <vscale x 4 x i32> %c, <vscale x 4 x i1> %mask) {
 ; CHECK-LABEL: masked_mla_nxv4i32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p1.s
-; CHECK-NEXT:    mul z1.s, p1/m, z1.s, z2.s
-; CHECK-NEXT:    add z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    mla z0.s, p0/m, z1.s, z2.s
 ; CHECK-NEXT:    ret
   %mul = mul nsw <vscale x 4 x i32> %b, %c
   %sel = select <vscale x 4 x i1> %mask, <vscale x 4 x i32> %mul, <vscale x 4 x i32> zeroinitializer
@@ -135,9 +129,7 @@
 define <vscale x 2 x i64> @masked_mla_nxv2i64(<vscale x 2 x i64> %a, <vscale x 2 x i64> %b, <vscale x 2 x i64> %c, <vscale x 2 x i1> %mask) {
 ; CHECK-LABEL: masked_mla_nxv2i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p1.d
-; CHECK-NEXT:    mul z1.d, p1/m, z1.d, z2.d
-; CHECK-NEXT:    add z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
 ; CHECK-NEXT:    ret
   %mul = mul nsw <vscale x 2 x i64> %b, %c
   %sel = select <vscale x 2 x i1> %mask, <vscale x 2 x i64> %mul, <vscale x 2 x i64> zeroinitializer
@@ -152,9 +144,7 @@
 define <vscale x 16 x i8> @masked_mls_nxv16i8(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c, <vscale x 16 x i1> %mask) {
 ; CHECK-LABEL: masked_mls_nxv16i8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p1.b
-; CHECK-NEXT:    mul z1.b, p1/m, z1.b, z2.b
-; CHECK-NEXT:    sub z0.b, p0/m, z0.b, z1.b
+; CHECK-NEXT:    mls z0.b, p0/m, z1.b, z2.b
 ; CHECK-NEXT:    ret
   %mul = mul nsw <vscale x 16 x i8> %b, %c
   %sel = select <vscale x 16 x i1> %mask, <vscale x 16 x i8> %mul, <vscale x 16 x i8> zeroinitializer
@@ -165,9 +155,7 @@
 define <vscale x 8 x i16> @masked_mls_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c, <vscale x 8 x i1> %mask) {
 ; CHECK-LABEL: masked_mls_nxv8i16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p1.h
-; CHECK-NEXT:    mul z1.h, p1/m, z1.h, z2.h
-; CHECK-NEXT:    sub z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT:    mls z0.h, p0/m, z1.h, z2.h
 ; CHECK-NEXT:    ret
   %mul = mul nsw <vscale x 8 x i16> %b, %c
   %sel = select <vscale x 8 x i1> %mask, <vscale x 8 x i16> %mul, <vscale x 8 x i16> zeroinitializer
@@ -178,9 +166,7 @@
 define <vscale x 4 x i32> @masked_mls_nxv4i32(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b, <vscale x 4 x i32> %c, <vscale x 4 x i1> %mask) {
 ; CHECK-LABEL: masked_mls_nxv4i32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p1.s
-; CHECK-NEXT:    mul z1.s, p1/m, z1.s, z2.s
-; CHECK-NEXT:    sub z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    mls z0.s, p0/m, z1.s, z2.s
 ; CHECK-NEXT:    ret
   %mul = mul nsw <vscale x 4 x i32> %b, %c
   %sel = select <vscale x 4 x i1> %mask, <vscale x 4 x i32> %mul, <vscale x 4 x i32> zeroinitializer
@@ -191,9 +177,7 @@
 define <vscale x 2 x i64> @masked_mls_nxv2i64(<vscale x 2 x i64> %a, <vscale x 2 x i64> %b, <vscale x 2 x i64> %c, <vscale x 2 x i1> %mask) {
 ; CHECK-LABEL: masked_mls_nxv2i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p1.d
-; CHECK-NEXT:    mul z1.d, p1/m, z1.d, z2.d
-; CHECK-NEXT:    sub z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    mls z0.d, p0/m, z1.d, z2.d
 ; CHECK-NEXT:    ret
   %mul = mul nsw <vscale x 2 x i64> %b, %c
   %sel = select <vscale x 2 x i1> %mask, <vscale x 2 x i64> %mul, <vscale x 2 x i64> zeroinitializer