Index: llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -384,10 +384,7 @@
                              [(int_aarch64_sve_sub node:$pred, node:$op1, node:$op2),
                               (sub node:$op1, (vselect node:$pred, node:$op2, (SVEDup0)))]>;
 def AArch64mla_m1 : PatFrags<(ops node:$pred, node:$op1, node:$op2, node:$op3),
-                             [(int_aarch64_sve_mla node:$pred, node:$op1, node:$op2, node:$op3),
-                              (add node:$op1, (AArch64mul_p_oneuse node:$pred, node:$op2, node:$op3)),
-                              // add(a, select(mask, mul(b, c), splat(0))) -> mla(a, mask, b, c)
-                              (add node:$op1, (vselect node:$pred, (AArch64mul_p_oneuse (SVEAllActive), node:$op2, node:$op3), (SVEDup0)))]>;
+                             [(int_aarch64_sve_mla node:$pred, node:$op1, node:$op2, node:$op3)]>;
 def AArch64mls_m1 : PatFrags<(ops node:$pred, node:$op1, node:$op2, node:$op3),
                              [(int_aarch64_sve_mls node:$pred, node:$op1, node:$op2, node:$op3),
                               (sub node:$op1, (AArch64mul_p_oneuse node:$pred, node:$op2, node:$op3)),
@@ -467,10 +464,30 @@
   defm SQSUB_ZI : sve_int_arith_imm0<0b110, "sqsub", ssubsat>;
   defm UQSUB_ZI : sve_int_arith_imm0<0b111, "uqsub", usubsat>;
 
-  defm MAD_ZPmZZ : sve_int_mladdsub_vvv_pred<0b0, "mad", int_aarch64_sve_mad>;
-  defm MSB_ZPmZZ : sve_int_mladdsub_vvv_pred<0b1, "msb", int_aarch64_sve_msb>;
-  defm MLA_ZPmZZ : sve_int_mlas_vvv_pred<0b0, "mla", AArch64mla_m1>;
-  defm MLS_ZPmZZ : sve_int_mlas_vvv_pred<0b1, "mls", AArch64mls_m1>;
+  defm MAD_ZPmZZ : sve_int_mladdsub_vvv_pred<0b0, "mad", int_aarch64_sve_mad, "MLA_ZPmZZ", /*isReverseInstr*/ 1>;
+  defm MSB_ZPmZZ : sve_int_mladdsub_vvv_pred<0b1, "msb", int_aarch64_sve_msb, "MSB_ZPmZZ", /*isReverseInstr*/ 1>;
+  defm MLA_ZPmZZ : sve_int_mlas_vvv_pred<0b0, "mla", AArch64mla_m1, "MLA_ZPZZZ", "MAD_ZPmZZ">;
+  defm MLS_ZPmZZ : sve_int_mlas_vvv_pred<0b1, "mls", AArch64mls_m1, "MSB_ZPZZZ", "MSB_ZPmZZ">;
+
+  defm MLA_ZPZZZ : sve_int_mladdsub_pseudo;
+
+  multiclass sve_int_mladdsub<ValueType Ty, ValueType PredTy, string Suffix> {
+
+    let AddedComplexity = 9 in {
+    // Zd = Za + Zn * Zm
+    def : Pat<(add (AArch64mul_p_oneuse PredTy:$P, Ty:$Zn, Ty:$Zm), Ty:$Za),
+              (!cast<Instruction>("MLA_ZPZZZ_UNDEF_"#Suffix) $P, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>;
+
+    // add(a, select(mask, mul(b, c), splat(0))) -> mla(a, mask, b, c)
+    def : Pat<(add (vselect PredTy:$P, (AArch64mul_p_oneuse (SVEAllActive), Ty:$Zn, Ty:$Zm), (SVEDup0)), Ty:$Za),
+              (!cast<Instruction>("MLA_ZPZZZ_UNDEF_"#Suffix) $P, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>;
+    }
+
+  }
+  defm : sve_int_mladdsub<nxv16i8, nxv16i1, "B">;
+  defm : sve_int_mladdsub<nxv8i16, nxv8i1,  "H">;
+  defm : sve_int_mladdsub<nxv4i32, nxv4i1,  "S">;
+  defm : sve_int_mladdsub<nxv2i64, nxv2i1,  "D">;
 
   // SVE predicated integer reductions.
   defm SADDV_VPZ : sve_int_reduce_0_saddv<0b000, "saddv", AArch64saddv_p>;
Index: llvm/lib/Target/AArch64/SVEInstrFormats.td
===================================================================
--- llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -3129,11 +3129,16 @@
   let hasSideEffects = 0;
 }
 
-multiclass sve_int_mladdsub_vvv_pred<bits<1> opc, string asm, SDPatternOperator op> {
-  def _B : sve_int_mladdsub_vvv_pred<0b00, opc, asm, ZPR8>;
-  def _H : sve_int_mladdsub_vvv_pred<0b01, opc, asm, ZPR16>;
-  def _S : sve_int_mladdsub_vvv_pred<0b10, opc, asm, ZPR32>;
-  def _D : sve_int_mladdsub_vvv_pred<0b11, opc, asm, ZPR64>;
+multiclass sve_int_mladdsub_vvv_pred<bits<1> opc, string asm, SDPatternOperator op,
+                                     string revname, bit isReverseInstr=0> {
+  def _B : sve_int_mladdsub_vvv_pred<0b00, opc, asm, ZPR8>,
+           SVEInstr2Rev<NAME # _B, revname # _B, isReverseInstr>;
+  def _H : sve_int_mladdsub_vvv_pred<0b01, opc, asm, ZPR16>,
+           SVEInstr2Rev<NAME # _H, revname # _H, isReverseInstr>;
+  def _S : sve_int_mladdsub_vvv_pred<0b10, opc, asm, ZPR32>,
+           SVEInstr2Rev<NAME # _S, revname # _S, isReverseInstr>;
+  def _D : sve_int_mladdsub_vvv_pred<0b11, opc, asm, ZPR64>,
+           SVEInstr2Rev<NAME # _D, revname # _D, isReverseInstr>;
 
   def : SVE_4_Op_Pat<nxv16i8, op, nxv16i1, nxv16i8, nxv16i8, nxv16i8, !cast<Instruction>(NAME # _B)>;
   def : SVE_4_Op_Pat<nxv8i16, op, nxv8i1, nxv8i16, nxv8i16, nxv8i16, !cast<Instruction>(NAME # _H)>;
@@ -3162,16 +3167,21 @@
   let Inst{4-0}   = Zda;
 
   let Constraints = "$Zda = $_Zda";
-  let DestructiveInstType = DestructiveOther;
+  let DestructiveInstType = DestructiveTernaryCommWithRev;
   let ElementSize = zprty.ElementSize;
   let hasSideEffects = 0;
 }
 
-multiclass sve_int_mlas_vvv_pred<bits<1> opc, string asm, SDPatternOperator op> {
-  def _B : sve_int_mlas_vvv_pred<0b00, opc, asm, ZPR8>;
-  def _H : sve_int_mlas_vvv_pred<0b01, opc, asm, ZPR16>;
-  def _S : sve_int_mlas_vvv_pred<0b10, opc, asm, ZPR32>;
-  def _D : sve_int_mlas_vvv_pred<0b11, opc, asm, ZPR64>;
+multiclass sve_int_mlas_vvv_pred<bits<1> opc, string asm, SDPatternOperator op,
+                                 string Ps, string revname, bit isReverseInstr=0> {
+  def _B : sve_int_mlas_vvv_pred<0b00, opc, asm, ZPR8>,
+           SVEPseudo2Instr<Ps # _B, 1>, SVEInstr2Rev<NAME # _B, revname # _B, isReverseInstr>;
+  def _H : sve_int_mlas_vvv_pred<0b01, opc, asm, ZPR16>,
+           SVEPseudo2Instr<Ps # _H, 1>, SVEInstr2Rev<NAME # _H, revname # _H, isReverseInstr>;
+  def _S : sve_int_mlas_vvv_pred<0b10, opc, asm, ZPR32>,
+           SVEPseudo2Instr<Ps # _S, 1>, SVEInstr2Rev<NAME # _S, revname # _S, isReverseInstr>;
+  def _D : sve_int_mlas_vvv_pred<0b11, opc, asm, ZPR64>,
+           SVEPseudo2Instr<Ps # _D, 1>, SVEInstr2Rev<NAME # _D, revname # _D, isReverseInstr>;
 
   def : SVE_4_Op_Pat<nxv16i8, op, nxv16i1, nxv16i8, nxv16i8, nxv16i8, !cast<Instruction>(NAME # _B)>;
   def : SVE_4_Op_Pat<nxv8i16, op, nxv8i1, nxv8i16, nxv8i16, nxv8i16, !cast<Instruction>(NAME # _H)>;
@@ -3179,6 +3189,13 @@
   def : SVE_4_Op_Pat<nxv2i64, op, nxv2i1, nxv2i64, nxv2i64, nxv2i64, !cast<Instruction>(NAME # _D)>;
 }
 
+multiclass sve_int_mladdsub_pseudo {
+  def _UNDEF_B : PredThreeOpPseudo<NAME # _B, ZPR8, FalseLanesUndef>;
+  def _UNDEF_H : PredThreeOpPseudo<NAME # _H, ZPR16, FalseLanesUndef>;
+  def _UNDEF_S : PredThreeOpPseudo<NAME # _S, ZPR32, FalseLanesUndef>;
+  def _UNDEF_D : PredThreeOpPseudo<NAME # _D, ZPR64, FalseLanesUndef>;
+}
+
 //===----------------------------------------------------------------------===//
 // SVE2 Integer Multiply-Add - Unpredicated Group
 //===----------------------------------------------------------------------===//
Index: llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
===================================================================
--- llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
+++ llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
@@ -72,19 +72,18 @@
 ; CHECK-LABEL: scatter_f16_index_offset_var:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    index z1.d, #0, #1
-; CHECK-NEXT:    mov z3.d, x1
-; CHECK-NEXT:    mov z2.d, z1.d
-; CHECK-NEXT:    mov z4.d, z3.d
 ; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    mov z2.d, z1.d
+; CHECK-NEXT:    mov z3.d, x1
 ; CHECK-NEXT:    incd z2.d
-; CHECK-NEXT:    mla z3.d, p1/m, z1.d, z3.d
-; CHECK-NEXT:    mla z4.d, p1/m, z2.d, z4.d
+; CHECK-NEXT:    mad z1.d, p1/m, z3.d, z3.d
+; CHECK-NEXT:    mad z2.d, p1/m, z3.d, z3.d
 ; CHECK-NEXT:    punpklo p1.h, p0.b
-; CHECK-NEXT:    uunpklo z1.d, z0.s
+; CHECK-NEXT:    uunpklo z3.d, z0.s
 ; CHECK-NEXT:    punpkhi p0.h, p0.b
 ; CHECK-NEXT:    uunpkhi z0.d, z0.s
-; CHECK-NEXT:    st1h { z1.d }, p1, [x0, z3.d, lsl #1]
-; CHECK-NEXT:    st1h { z0.d }, p0, [x0, z4.d, lsl #1]
+; CHECK-NEXT:    st1h { z3.d }, p1, [x0, z1.d, lsl #1]
+; CHECK-NEXT:    st1h { z0.d }, p0, [x0, z2.d, lsl #1]
 ; CHECK-NEXT:    ret
   %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
   %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
Index: llvm/test/CodeGen/AArch64/sve-int-arith.ll
===================================================================
--- llvm/test/CodeGen/AArch64/sve-int-arith.ll
+++ llvm/test/CodeGen/AArch64/sve-int-arith.ll
@@ -337,12 +337,11 @@
   ret <vscale x 2 x i64> %res
 }
 
-define <vscale x 16 x i8> @mla_i8(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {
-; CHECK-LABEL: mla_i8:
+define <vscale x 16 x i8> @mad_i8(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {
+; CHECK-LABEL: mad_i8:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ptrue p0.b
-; CHECK-NEXT:    mla z2.b, p0/m, z0.b, z1.b
-; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    mad z0.b, p0/m, z1.b, z2.b
 ; CHECK-NEXT:    ret
   %prod = mul <vscale x 16 x i8> %a, %b
   %res = add <vscale x 16 x i8> %c, %prod
Index: llvm/test/CodeGen/AArch64/sve-multiply-add-accumulate.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/sve-multiply-add-accumulate.ll
@@ -0,0 +1,234 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-unknown-linux-gnu -mattr=+sve < %s | FileCheck %s
+
+define <vscale x 2 x i64> @muladd_i64_positiveAddend(<vscale x 2 x i64> %a, <vscale x 2 x i64> %b)
+; CHECK-LABEL: muladd_i64_positiveAddend:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    mov z2.d, #0xffffffff
+; CHECK-NEXT:    mad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    ret
+{
+  %1 = mul <vscale x 2 x i64> %a, %b
+  %2 = add <vscale x 2 x i64> %1, shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 4294967295, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer)
+  ret <vscale x 2 x i64> %2
+}
+
+define <vscale x 2 x i64> @muladd_i64_negativeAddend(<vscale x 2 x i64> %a, <vscale x 2 x i64> %b)
+; CHECK-LABEL: muladd_i64_negativeAddend:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    mov z2.d, #0xffffffff00000001
+; CHECK-NEXT:    mad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    ret
+{
+  %1 = mul <vscale x 2 x i64> %a, %b
+  %2 = add <vscale x 2 x i64> %1, shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 -4294967295, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer)
+  ret <vscale x 2 x i64> %2
+}
+
+define <vscale x 4 x i32> @muladd_i32_positiveAddend(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b)
+; CHECK-LABEL: muladd_i32_positiveAddend:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    mov z2.s, #0x10000
+; CHECK-NEXT:    mad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    ret
+{
+  %1 = mul <vscale x 4 x i32> %a, %b
+  %2 = add <vscale x 4 x i32> %1, shufflevector (<vscale x 4 x i32> insertelement (<vscale x 4 x i32> poison, i32 65536, i32 0), <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer)
+  ret <vscale x 4 x i32> %2
+}
+
+define <vscale x 4 x i32> @muladd_i32_negativeAddend(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b)
+; CHECK-LABEL: muladd_i32_negativeAddend:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    mov z2.s, #0xffff0000
+; CHECK-NEXT:    mad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    ret
+{
+  %1 = mul <vscale x 4 x i32> %a, %b
+  %2 = add <vscale x 4 x i32> %1, shufflevector (<vscale x 4 x i32> insertelement (<vscale x 4 x i32> poison, i32 -65536, i32 0), <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer)
+  ret <vscale x 4 x i32> %2
+}
+
+define <vscale x 8 x i16> @muladd_i16_positiveAddend(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b)
+; CHECK-LABEL: muladd_i16_positiveAddend:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    mov z2.h, #255 // =0xff
+; CHECK-NEXT:    mad z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT:    ret
+{
+  %1 = mul <vscale x 8 x i16> %a, %b
+  %2 = add <vscale x 8 x i16> %1, shufflevector (<vscale x 8 x i16> insertelement (<vscale x 8 x i16> poison, i16 255, i16 0), <vscale x 8 x i16> poison, <vscale x 8 x i32> zeroinitializer)
+  ret <vscale x 8 x i16> %2
+}
+
+define <vscale x 8 x i16> @muladd_i16_negativeAddend(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b)
+; CHECK-LABEL: muladd_i16_negativeAddend:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    mov z2.h, #-255 // =0xffffffffffffff01
+; CHECK-NEXT:    mad z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT:    ret
+{
+  %1 = mul <vscale x 8 x i16> %a, %b
+  %2 = add <vscale x 8 x i16> %1, shufflevector (<vscale x 8 x i16> insertelement (<vscale x 8 x i16> poison, i16 -255, i16 0), <vscale x 8 x i16> poison, <vscale x 8 x i32> zeroinitializer)
+  ret <vscale x 8 x i16> %2
+}
+
+define <vscale x 16 x i8> @muladd_i8_positiveAddend(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b)
+; CHECK-LABEL: muladd_i8_positiveAddend:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.b
+; CHECK-NEXT:    mov z2.b, #15 // =0xf
+; CHECK-NEXT:    mad z0.b, p0/m, z1.b, z2.b
+; CHECK-NEXT:    ret
+{
+  %1 = mul <vscale x 16 x i8> %a, %b
+  %2 = add <vscale x 16 x i8> %1, shufflevector (<vscale x 16 x i8> insertelement (<vscale x 16 x i8> poison, i8 15, i8 0), <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer)
+  ret <vscale x 16 x i8> %2
+}
+
+define <vscale x 16 x i8> @muladd_i8_negativeAddend(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b)
+; CHECK-LABEL: muladd_i8_negativeAddend:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.b
+; CHECK-NEXT:    mov z2.b, #-15 // =0xfffffffffffffff1
+; CHECK-NEXT:    mad z0.b, p0/m, z1.b, z2.b
+; CHECK-NEXT:    ret
+{
+  %1 = mul <vscale x 16 x i8> %a, %b
+  %2 = add <vscale x 16 x i8> %1, shufflevector (<vscale x 16 x i8> insertelement (<vscale x 16 x i8> poison, i8 -15, i8 0), <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer)
+  ret <vscale x 16 x i8> %2
+}
+
+; both mul operands have a use
+define <vscale x 8 x i16> @muladd_generic_test1(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b)
+; CHECK-LABEL: muladd_generic_test1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #200
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    mov z2.h, w8
+; CHECK-NEXT:    mla z2.h, p0/m, z0.h, z1.h
+; CHECK-NEXT:    mul z0.h, p0/m, z0.h, z2.h
+; CHECK-NEXT:    sub z0.h, z0.h, z1.h
+; CHECK-NEXT:    ret
+{
+  %1 = mul <vscale x 8 x i16> %a, %b
+  %2 = add <vscale x 8 x i16> %1, shufflevector (<vscale x 8 x i16> insertelement (<vscale x 8 x i16> poison, i16 200, i16 0), <vscale x 8 x i16> poison, <vscale x 8 x i32> zeroinitializer)
+  %3 = mul <vscale x 8 x i16> %2, %a
+  %4 = sub <vscale x 8 x i16> %3, %b
+  ret <vscale x 8 x i16> %4
+}
+
+; only the first mul operand has a use
+define <vscale x 8 x i16> @muladd_generic_test2(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b)
+; CHECK-LABEL: muladd_generic_test2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #200
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    mov z2.h, w8
+; CHECK-NEXT:    mad z1.h, p0/m, z0.h, z2.h
+; CHECK-NEXT:    mul z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT:    ret
+{
+  %1 = mul <vscale x 8 x i16> %a, %b
+  %2 = add <vscale x 8 x i16> %1, shufflevector (<vscale x 8 x i16> insertelement (<vscale x 8 x i16> poison, i16 200, i16 0), <vscale x 8 x i16> poison, <vscale x 8 x i32> zeroinitializer)
+  %3 = mul <vscale x 8 x i16> %2, %a
+  ret <vscale x 8 x i16> %3
+}
+
+; only the second mul operand has a use
+define <vscale x 8 x i16> @muladd_generic_test3(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b)
+; CHECK-LABEL: muladd_generic_test3:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #200
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    mov z2.h, w8
+; CHECK-NEXT:    mad z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT:    mul z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT:    ret
+{
+  %1 = mul <vscale x 8 x i16> %a, %b
+  %2 = add <vscale x 8 x i16> %1, shufflevector (<vscale x 8 x i16> insertelement (<vscale x 8 x i16> poison, i16 200, i16 0), <vscale x 8 x i16> poison, <vscale x 8 x i32> zeroinitializer)
+  %3 = mul <vscale x 8 x i16> %2, %b
+  ret <vscale x 8 x i16> %3
+}
+
+; negative integer splat as one of the addend
+define <vscale x 8 x i16> @muladd_generic_test4(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b)
+; CHECK-LABEL: muladd_generic_test4:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #-200
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    mov z2.h, w8
+; CHECK-NEXT:    mad z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT:    ret
+{
+  %1 = mul <vscale x 8 x i16> %a, %b
+  %2 = add <vscale x 8 x i16> %1, shufflevector (<vscale x 8 x i16> insertelement (<vscale x 8 x i16> poison, i16 -200, i16 0), <vscale x 8 x i16> poison, <vscale x 8 x i32> zeroinitializer)
+  ret <vscale x 8 x i16> %2
+}
+
+define void @fused_mul_add_in_loop(ptr noalias %a, ptr noalias %b, ptr noalias %c, i32 %n)
+; CHECK-LABEL: fused_mul_add_in_loop:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    cmp w3, #1
+; CHECK-NEXT:    b.lt .LBB12_3
+; CHECK-NEXT:  // %bb.1: // %for.body.preheader
+; CHECK-NEXT:    mov w9, w3
+; CHECK-NEXT:    mov x8, xzr
+; CHECK-NEXT:    cntw x10
+; CHECK-NEXT:    mov z0.s, #1 // =0x1
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    whilelo p1.s, xzr, x9
+; CHECK-NEXT:  .LBB12_2: // %vector.body
+; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    ld1w { z1.s }, p1/z, [x0, x8, lsl #2]
+; CHECK-NEXT:    ld1w { z2.s }, p1/z, [x1, x8, lsl #2]
+; CHECK-NEXT:    mad z1.s, p0/m, z2.s, z0.s
+; CHECK-NEXT:    st1w { z1.s }, p1, [x2, x8, lsl #2]
+; CHECK-NEXT:    add x8, x8, x10
+; CHECK-NEXT:    whilelo p1.s, x8, x9
+; CHECK-NEXT:    b.mi .LBB12_2
+; CHECK-NEXT:  .LBB12_3: // %for.cond.cleanup
+; CHECK-NEXT:    ret
+{
+entry:
+  %cmp9 = icmp sgt i32 %n, 0
+  br i1 %cmp9, label %for.body.preheader, label %for.cond.cleanup
+
+for.body.preheader:                               ; preds = %entry
+  %wide.trip.count = zext i32 %n to i64
+  %active.lane.mask.entry = tail call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 0, i64 %wide.trip.count)
+  %0 = tail call i64 @llvm.vscale.i64()
+  %1 = shl nuw nsw i64 %0, 2
+  br label %vector.body
+
+vector.body:                                      ; preds = %vector.body, %for.body.preheader
+  %index = phi i64 [ 0, %for.body.preheader ], [ %index.next, %vector.body ]
+  %active.lane.mask = phi <vscale x 4 x i1> [ %active.lane.mask.entry, %for.body.preheader ], [ %active.lane.mask.next, %vector.body ]
+  %2 = getelementptr inbounds i32, ptr %a, i64 %index
+  %wide.masked.load = tail call <vscale x 4 x i32> @llvm.masked.load.nxv4i32.p0(ptr %2, i32 4, <vscale x 4 x i1> %active.lane.mask, <vscale x 4 x i32> poison)
+  %3 = getelementptr inbounds i32, ptr %b, i64 %index
+  %wide.masked.load12 = tail call <vscale x 4 x i32> @llvm.masked.load.nxv4i32.p0(ptr %3, i32 4, <vscale x 4 x i1> %active.lane.mask, <vscale x 4 x i32> poison)
+  %4 = mul nsw <vscale x 4 x i32> %wide.masked.load12, %wide.masked.load
+  %5 = add nsw <vscale x 4 x i32> %4, shufflevector (<vscale x 4 x i32> insertelement (<vscale x 4 x i32> poison, i32 1, i64 0), <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer)
+  %6 = getelementptr inbounds i32, ptr %c, i64 %index
+  tail call void @llvm.masked.store.nxv4i32.p0(<vscale x 4 x i32> %5, ptr %6, i32 4, <vscale x 4 x i1> %active.lane.mask)
+  %index.next = add i64 %index, %1
+  %active.lane.mask.next = tail call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 %index.next, i64 %wide.trip.count)
+  %7 = extractelement <vscale x 4 x i1> %active.lane.mask.next, i64 0
+  br i1 %7, label %vector.body, label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %vector.body, %entry
+  ret void
+}
+
+declare i64 @llvm.vscale.i64()
+declare <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64, i64)
+declare <vscale x 4 x i32> @llvm.masked.load.nxv4i32.p0(ptr nocapture, i32 immarg, <vscale x 4 x i1>, <vscale x 4 x i32>)
+declare void @llvm.masked.store.nxv4i32.p0(<vscale x 4 x i32>, ptr nocapture, i32 immarg, <vscale x 4 x i1>)