diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -9895,6 +9895,9 @@
     if (GA->getOpcode() == ISD::GlobalAddress && TLI->isOffsetFoldingLegal(GA))
       return GA;
+  if ((N.getOpcode() == ISD::SPLAT_VECTOR) &&
+      isa<ConstantSDNode>(N.getOperand(0)))
+    return N.getNode();
   return nullptr;
 }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -921,6 +921,7 @@
     for (MVT VT : MVT::integer_scalable_vector_valuetypes()) {
       if (isTypeLegal(VT)) {
         setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
+        setOperationAction(ISD::MUL, VT, Custom);
         setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
         setOperationAction(ISD::SELECT, VT, Custom);
         setOperationAction(ISD::SDIV, VT, Custom);
@@ -3102,7 +3103,7 @@
   // If SVE is available then i64 vector multiplications can also be made legal.
   bool OverrideNEON = VT == MVT::v2i64 || VT == MVT::v1i64;
 
-  if (useSVEForFixedLengthVectorVT(VT, OverrideNEON))
+  if (VT.isScalableVector() || useSVEForFixedLengthVectorVT(VT, OverrideNEON))
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::MUL_PRED, OverrideNEON);
 
   // Multiplications are only custom-lowered for 128-bit vectors so that
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -291,23 +291,13 @@
   defm UMAX_ZI   : sve_int_arith_imm1_unsigned<0b01, "umax", AArch64umax_p>;
   defm UMIN_ZI   : sve_int_arith_imm1_unsigned<0b11, "umin", AArch64umin_p>;
 
-  defm MUL_ZI     : sve_int_arith_imm2<"mul", mul>;
+  defm MUL_ZI     : sve_int_arith_imm2<"mul", AArch64mul_p>;
   defm MUL_ZPmZ   : sve_int_bin_pred_arit_2<0b000, "mul",   "MUL_ZPZZ",   int_aarch64_sve_mul,   DestructiveBinaryComm>;
   defm SMULH_ZPmZ : sve_int_bin_pred_arit_2<0b010, "smulh", "SMULH_ZPZZ", int_aarch64_sve_smulh, DestructiveBinaryComm>;
   defm UMULH_ZPmZ : sve_int_bin_pred_arit_2<0b011, "umulh", "UMULH_ZPZZ", int_aarch64_sve_umulh, DestructiveBinaryComm>;
 
   defm MUL_ZPZZ   : sve_int_bin_pred_bhsd<AArch64mul_p>;
 
-  // Add unpredicated alternative for the mul instruction.
-  def : Pat<(mul nxv16i8:$Op1, nxv16i8:$Op2),
-            (MUL_ZPmZ_B (PTRUE_B 31), $Op1, $Op2)>;
-  def : Pat<(mul nxv8i16:$Op1, nxv8i16:$Op2),
-            (MUL_ZPmZ_H (PTRUE_H 31), $Op1, $Op2)>;
-  def : Pat<(mul nxv4i32:$Op1, nxv4i32:$Op2),
-            (MUL_ZPmZ_S (PTRUE_S 31), $Op1, $Op2)>;
-  def : Pat<(mul nxv2i64:$Op1, nxv2i64:$Op2),
-            (MUL_ZPmZ_D (PTRUE_D 31), $Op1, $Op2)>;
-
   defm SDIV_ZPmZ  : sve_int_bin_pred_arit_2_div<0b100, "sdiv",  "SDIV_ZPZZ",  int_aarch64_sve_sdiv,  DestructiveBinaryCommWithRev, "SDIVR_ZPmZ">;
   defm UDIV_ZPmZ  : sve_int_bin_pred_arit_2_div<0b101, "udiv",  "UDIV_ZPZZ",  int_aarch64_sve_udiv,  DestructiveBinaryCommWithRev, "UDIVR_ZPmZ">;
   defm SDIVR_ZPmZ : sve_int_bin_pred_arit_2_div<0b110, "sdivr", "SDIVR_ZPZZ", int_aarch64_sve_sdivr, DestructiveBinaryCommWithRev, "SDIV_ZPmZ", /*isReverseInstr*/ 1>;
@@ -2227,10 +2217,10 @@
   defm SQRDMULH_ZZZ : sve2_int_mul<0b101, "sqrdmulh", int_aarch64_sve_sqrdmulh>;
 
   // SVE2 integer multiply vectors (unpredicated)
-  defm MUL_ZZZ    : sve2_int_mul<0b000, "mul", mul>;
+  defm MUL_ZZZ    : sve2_int_mul<0b000, "mul", null_frag, AArch64mul_p>;
   defm SMULH_ZZZ  : sve2_int_mul<0b010, "smulh", null_frag>;
   defm UMULH_ZZZ  : sve2_int_mul<0b011, "umulh", null_frag>;
-  defm PMUL_ZZZ   : sve2_int_mul_single<0b001, "pmul", int_aarch64_sve_pmul>;
+  defm PMUL_ZZZ   : sve2_int_mul_single<0b001, "pmul", int_aarch64_sve_pmul>;
 
   // Add patterns for unpredicated version of smulh and umulh.
   def : Pat<(nxv16i8 (int_aarch64_sve_smulh (nxv16i1 (AArch64ptrue 31)), nxv16i8:$Op1, nxv16i8:$Op2)),
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -315,11 +315,6 @@
   : Pat<(vt (op (vt zprty:$Op1), (vt (AArch64dup (it (cpx i32:$imm, i32:$shift)))))),
         (inst $Op1, i32:$imm, i32:$shift)>;
 
-class SVE_1_Op_Imm_Arith_Pat<ValueType vt, SDPatternOperator op,
-                             ZPRRegOp zprty, ValueType it,
-                             ComplexPattern cpx, Instruction inst>
-  : Pat<(vt (op (vt zprty:$Op1), (vt (AArch64dup (it (cpx i32:$imm)))))),
-        (inst $Op1, i32:$imm)>;
-
 class SVE_1_Op_Imm_Shift_Pred_Pat<ValueType vt, ValueType pt,
                                   SDPatternOperator op, ZPRRegOp zprty,
                                   Operand ImmTy, Instruction inst>
   : Pat<(vt (op (pt (AArch64ptrue 31)), (vt zprty:$Op1), (vt (AArch64dup (ImmTy:$imm))))),
@@ -2867,7 +2862,8 @@
     let Inst{4-0} = Zd;
   }
 
-multiclass sve2_int_mul<bits<3> opc, string asm, SDPatternOperator op> {
+multiclass sve2_int_mul<bits<3> opc, string asm, SDPatternOperator op,
+                        SDPatternOperator op_pred = null_frag> {
   def _B : sve2_int_mul<0b00, opc, asm, ZPR8>;
   def _H : sve2_int_mul<0b01, opc, asm, ZPR16>;
   def _S : sve2_int_mul<0b10, opc, asm, ZPR32>;
@@ -2877,6 +2873,11 @@
   def : SVE_2_Op_Pat<nxv8i16, op, nxv8i16, nxv8i16, !cast<Instruction>(NAME # _H)>;
   def : SVE_2_Op_Pat<nxv4i32, op, nxv4i32, nxv4i32, !cast<Instruction>(NAME # _S)>;
   def : SVE_2_Op_Pat<nxv2i64, op, nxv2i64, nxv2i64, !cast<Instruction>(NAME # _D)>;
+
+  def : SVE_2_Op_Pred_All_Active<nxv16i8, op_pred, nxv16i1, nxv16i8, nxv16i8, !cast<Instruction>(NAME # _B)>;
+  def : SVE_2_Op_Pred_All_Active<nxv8i16, op_pred, nxv8i1,  nxv8i16, nxv8i16, !cast<Instruction>(NAME # _H)>;
+  def : SVE_2_Op_Pred_All_Active<nxv4i32, op_pred, nxv4i1,  nxv4i32, nxv4i32, !cast<Instruction>(NAME # _S)>;
+  def : SVE_2_Op_Pred_All_Active<nxv2i64, op_pred, nxv2i1,  nxv2i64, nxv2i64, !cast<Instruction>(NAME # _D)>;
 }
 
 multiclass sve2_int_mul_single<bits<3> opc, string asm, SDPatternOperator op> {
@@ -3914,10 +3915,10 @@
   def _S : sve_int_arith_imm<0b10, 0b110000, asm, ZPR32, simm8>;
   def _D : sve_int_arith_imm<0b11, 0b110000, asm, ZPR64, simm8>;
 
-  def : SVE_1_Op_Imm_Arith_Pat<nxv16i8, op, ZPR8,  i32, SVEArithMulImmPat, !cast<Instruction>(NAME # _B)>;
-  def : SVE_1_Op_Imm_Arith_Pat<nxv8i16, op, ZPR16, i32, SVEArithMulImmPat, !cast<Instruction>(NAME # _H)>;
-  def : SVE_1_Op_Imm_Arith_Pat<nxv4i32, op, ZPR32, i32, SVEArithMulImmPat, !cast<Instruction>(NAME # _S)>;
-  def : SVE_1_Op_Imm_Arith_Pat<nxv2i64, op, ZPR64, i64, SVEArithMulImmPat, !cast<Instruction>(NAME # _D)>;
+  def : SVE_1_Op_Imm_Arith_Pred_Pat<nxv16i8, nxv16i1, op, ZPR8,  i32, SVEArithMulImmPat, !cast<Instruction>(NAME # _B)>;
+  def : SVE_1_Op_Imm_Arith_Pred_Pat<nxv8i16, nxv8i1,  op, ZPR16, i32, SVEArithMulImmPat, !cast<Instruction>(NAME # _H)>;
+  def : SVE_1_Op_Imm_Arith_Pred_Pat<nxv4i32, nxv4i1,  op, ZPR32, i32, SVEArithMulImmPat, !cast<Instruction>(NAME # _S)>;
+  def : SVE_1_Op_Imm_Arith_Pred_Pat<nxv2i64, nxv2i1,  op, ZPR64, i64, SVEArithMulImmPat, !cast<Instruction>(NAME # _D)>;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/test/CodeGen/AArch64/llvm-ir-to-intrinsic.ll b/llvm/test/CodeGen/AArch64/llvm-ir-to-intrinsic.ll
--- a/llvm/test/CodeGen/AArch64/llvm-ir-to-intrinsic.ll
+++ b/llvm/test/CodeGen/AArch64/llvm-ir-to-intrinsic.ll
@@ -131,8 +131,8 @@
 ; CHECK-NEXT:    uzp1 z3.h, z4.h, z3.h
 ; CHECK-NEXT:    uzp1 z2.b, z3.b, z2.b
 ; CHECK-NEXT:    ptrue p0.b
-; CHECK-NEXT:    mul z2.b, p0/m, z2.b, z1.b
-; CHECK-NEXT:    sub z0.b, z0.b, z2.b
+; CHECK-NEXT:    mul z1.b, p0/m, z1.b, z2.b
+; CHECK-NEXT:    sub z0.b, z0.b, z1.b
 ; CHECK-NEXT:    ret
   %div = srem <vscale x 16 x i8> %a, %b
   ret <vscale x 16 x i8> %div
@@ -151,8 +151,8 @@
 ; CHECK-NEXT:    sdiv z3.s, p0/m, z3.s, z4.s
 ; CHECK-NEXT:    uzp1 z2.h, z3.h, z2.h
 ; CHECK-NEXT:    ptrue p0.h
-; CHECK-NEXT:    mul z2.h, p0/m, z2.h, z1.h
-; CHECK-NEXT:    sub z0.h, z0.h, z2.h
+; CHECK-NEXT:    mul z1.h, p0/m, z1.h, z2.h
+; CHECK-NEXT:    sub z0.h, z0.h, z1.h
 ; CHECK-NEXT:    ret
   %div = srem <vscale x 8 x i16> %a, %b
   ret <vscale x 8 x i16> %div
@@ -164,8 +164,8 @@
 ; CHECK-NEXT:    ptrue p0.s
 ; CHECK-NEXT:    movprfx z2, z0
 ; CHECK-NEXT:    sdiv z2.s, p0/m, z2.s, z1.s
-; CHECK-NEXT:    mul z2.s, p0/m, z2.s, z1.s
-; CHECK-NEXT:    sub z0.s, z0.s, z2.s
+; CHECK-NEXT:    mul z1.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    sub z0.s, z0.s, z1.s
 ; CHECK-NEXT:    ret
   %div = srem <vscale x 4 x i32> %a, %b
   ret <vscale x 4 x i32> %div
@@ -177,8 +177,8 @@
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:    movprfx z2, z0
 ; CHECK-NEXT:    sdiv z2.d, p0/m, z2.d, z1.d
-; CHECK-NEXT:    mul z2.d, p0/m, z2.d, z1.d
-; CHECK-NEXT:    sub z0.d, z0.d, z2.d
+; CHECK-NEXT:    mul z1.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    sub z0.d, z0.d, z1.d
 ; CHECK-NEXT:    ret
   %div = srem <vscale x 2 x i64> %a, %b
   ret <vscale x 2 x i64> %div
@@ -315,8 +315,8 @@
 ; CHECK-NEXT:    uzp1 z3.h, z4.h, z3.h
 ; CHECK-NEXT:    uzp1 z2.b, z3.b, z2.b
 ; CHECK-NEXT:    ptrue p0.b
-; CHECK-NEXT:    mul z2.b, p0/m, z2.b, z1.b
-; CHECK-NEXT:    sub z0.b, z0.b, z2.b
+; CHECK-NEXT:    mul z1.b, p0/m, z1.b, z2.b
+; CHECK-NEXT:    sub z0.b, z0.b, z1.b
 ; CHECK-NEXT:    ret
   %div = urem <vscale x 16 x i8> %a, %b
   ret <vscale x 16 x i8> %div
@@ -335,8 +335,8 @@
 ; CHECK-NEXT:    udiv z3.s, p0/m, z3.s, z4.s
 ; CHECK-NEXT:    uzp1 z2.h, z3.h, z2.h
 ; CHECK-NEXT:    ptrue p0.h
-; CHECK-NEXT:    mul z2.h, p0/m, z2.h, z1.h
-; CHECK-NEXT:    sub z0.h, z0.h, z2.h
+; CHECK-NEXT:    mul z1.h, p0/m, z1.h, z2.h
+; CHECK-NEXT:    sub z0.h, z0.h, z1.h
 ; CHECK-NEXT:    ret
   %div = urem <vscale x 8 x i16> %a, %b
   ret <vscale x 8 x i16> %div
@@ -348,8 +348,8 @@
 ; CHECK-NEXT:    ptrue p0.s
 ; CHECK-NEXT:    movprfx z2, z0
 ; CHECK-NEXT:    udiv z2.s, p0/m, z2.s, z1.s
-; CHECK-NEXT:    mul z2.s, p0/m, z2.s, z1.s
-; CHECK-NEXT:    sub z0.s, z0.s, z2.s
+; CHECK-NEXT:    mul z1.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    sub z0.s, z0.s, z1.s
 ; CHECK-NEXT:    ret
   %div = urem <vscale x 4 x i32> %a, %b
   ret <vscale x 4 x i32> %div
@@ -361,8 +361,8 @@
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:    movprfx z2, z0
 ; CHECK-NEXT:    udiv z2.d, p0/m, z2.d, z1.d
-; CHECK-NEXT:    mul z2.d, p0/m, z2.d, z1.d
-; CHECK-NEXT:    sub z0.d, z0.d, z2.d
+; CHECK-NEXT:    mul z1.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    sub z0.d, z0.d, z1.d
 ; CHECK-NEXT:    ret
   %div = urem <vscale x 2 x i64> %a, %b
   ret <vscale x 2 x i64> %div
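
Note (illustration only, not part of the patch's test updates): the SelectionDAG and MUL_ZI changes are easiest to see on a multiply by a constant splat. With SPLAT_VECTOR of a constant now reported as constant by isConstantIntBuildVectorOrConstantInt, the AArch64mul_p immediate patterns should be able to select the immediate form. A minimal sketch, with a hypothetical function name:

; Expected to select "mul z0.d, z0.d, #3" rather than materialising the
; splat of 3 in a register first (3 fits the simm8 multiply immediate).
define <vscale x 2 x i64> @mul_by_splat_const(<vscale x 2 x i64> %a) {
  %ins = insertelement <vscale x 2 x i64> undef, i64 3, i32 0
  %splat = shufflevector <vscale x 2 x i64> %ins, <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer
  %mul = mul <vscale x 2 x i64> %a, %splat
  ret <vscale x 2 x i64> %mul
}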
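Likewise, a plain two-operand multiply sketches the new ISD::MUL custom lowering (again illustrative only): LowerToPredicatedOp rewrites the node as AArch64ISD::MUL_PRED, so base SVE should still emit the merging form behind an all-active predicate, while under +sve2 the SVE_2_Op_Pred_All_Active patterns can drop the ptrue and pick the unpredicated MUL_ZZZ encoding.

; SVE  (expected): ptrue p0.d ; mul z0.d, p0/m, z0.d, z1.d
; SVE2 (expected): mul z0.d, z0.d, z1.d
define <vscale x 2 x i64> @mul_zz(<vscale x 2 x i64> %a, <vscale x 2 x i64> %b) {
  %mul = mul <vscale x 2 x i64> %a, %b
  ret <vscale x 2 x i64> %mul
}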