Index: llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td =================================================================== --- llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -340,6 +340,24 @@ return N->hasOneUse(); }]>; +def AArch64mul_p_firstOpndWithSingleUse : PatFrag<(ops node:$pred, node:$src1, node:$src2), + (AArch64mul_p node:$pred, node:$src1, node:$src2), [{ + bool rootHasSingleUse = N->hasOneUse(); + + SDNode* Op1 = N->getOperand(1).getNode(); + + unsigned nUses = 0; + for(SDNode* use: Op1->uses()) + { + // TOFIX: How do you check for other opcodes which don't actually + // translate into real instructions post-isel? + if(use->getOpcode() == ISD::TokenFactor) + continue; + nUses++; + } + return (rootHasSingleUse && nUses<=1) ; +}]>; + def AArch64fmul_p_oneuse : PatFrag<(ops node:$pred, node:$src1, node:$src2), (AArch64fmul_p node:$pred, node:$src1, node:$src2), [{ return N->hasOneUse(); @@ -388,6 +406,9 @@ (add node:$op1, (AArch64mul_p_oneuse node:$pred, node:$op2, node:$op3)), // add(a, select(mask, mul(b, c), splat(0))) -> mla(a, mask, b, c) (add node:$op1, (vselect node:$pred, (AArch64mul_p_oneuse (SVEAllActive), node:$op2, node:$op3), (SVEDup0)))]>; +def AArch64mad_m1 : PatFrags<(ops node:$pred, node:$op1, node:$op2, node:$op3), + [(int_aarch64_sve_mad node:$pred, node:$op1, node:$op2, node:$op3), + (add node:$op3, (AArch64mul_p_firstOpndWithSingleUse node:$pred, node:$op1, node:$op2))]>; def AArch64mls_m1 : PatFrags<(ops node:$pred, node:$op1, node:$op2, node:$op3), [(int_aarch64_sve_mls node:$pred, node:$op1, node:$op2, node:$op3), (sub node:$op1, (AArch64mul_p_oneuse node:$pred, node:$op2, node:$op3)), @@ -467,7 +488,7 @@ defm SQSUB_ZI : sve_int_arith_imm0<0b110, "sqsub", ssubsat>; defm UQSUB_ZI : sve_int_arith_imm0<0b111, "uqsub", usubsat>; - defm MAD_ZPmZZ : sve_int_mladdsub_vvv_pred<0b0, "mad", int_aarch64_sve_mad>; + defm MAD_ZPmZZ : sve_int_mladdsub_vvv_pred<0b0, "mad", AArch64mad_m1>; defm 
MSB_ZPmZZ : sve_int_mladdsub_vvv_pred<0b1, "msb", int_aarch64_sve_msb>; defm MLA_ZPmZZ : sve_int_mlas_vvv_pred<0b0, "mla", AArch64mla_m1>; defm MLS_ZPmZZ : sve_int_mlas_vvv_pred<0b1, "mls", AArch64mls_m1>; Index: llvm/lib/Target/AArch64/SVEInstrFormats.td =================================================================== --- llvm/lib/Target/AArch64/SVEInstrFormats.td +++ llvm/lib/Target/AArch64/SVEInstrFormats.td @@ -480,6 +480,15 @@ : Pat<(vtd (op vt1:$Op1, vt2:$Op2, vt3:$Op3, vt4:$Op4)), (inst $Op1, $Op2, $Op3, $Op4)>; +// Not sure why this pattern doesn't kick in with value<9 +let AddedComplexity = 9 in { +class SVE_4_Mad_Op_Pat +: Pat<(vtd (op vt1:$Op1, vt2:$Op2, vt3:$Op3, vt4:$Op4)), + (inst $Op1, $Op2, $Op3, $Op4)>; +} + class SVE_2_Op_Imm_Pat : Pat<(vtd (op vt1:$Op1, (vt2 ImmTy:$Op2))), @@ -3103,10 +3112,10 @@ def _S : sve_int_mladdsub_vvv_pred<0b10, opc, asm, ZPR32>; def _D : sve_int_mladdsub_vvv_pred<0b11, opc, asm, ZPR64>; - def : SVE_4_Op_Pat(NAME # _B)>; - def : SVE_4_Op_Pat(NAME # _H)>; - def : SVE_4_Op_Pat(NAME # _S)>; - def : SVE_4_Op_Pat(NAME # _D)>; + def : SVE_4_Mad_Op_Pat(NAME # _B)>; + def : SVE_4_Mad_Op_Pat(NAME # _H)>; + def : SVE_4_Mad_Op_Pat(NAME # _S)>; + def : SVE_4_Mad_Op_Pat(NAME # _D)>; } class sve_int_mlas_vvv_pred sz8_64, bits<1> opc, string asm, Index: llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll =================================================================== --- llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll +++ llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll @@ -72,19 +72,18 @@ ; CHECK-LABEL: scatter_f16_index_offset_var: ; CHECK: // %bb.0: ; CHECK-NEXT: index z1.d, #0, #1 -; CHECK-NEXT: mov z3.d, x1 -; CHECK-NEXT: mov z2.d, z1.d -; CHECK-NEXT: mov z4.d, z3.d ; CHECK-NEXT: ptrue p1.d +; CHECK-NEXT: mov z2.d, z1.d +; CHECK-NEXT: mov z3.d, x1 ; CHECK-NEXT: incd z2.d +; CHECK-NEXT: mad z2.d, p1/m, z3.d, z3.d ; CHECK-NEXT: mla z3.d, p1/m, z1.d, z3.d -; CHECK-NEXT: mla z4.d, p1/m, z2.d, z4.d ; 
CHECK-NEXT: punpklo p1.h, p0.b ; CHECK-NEXT: uunpklo z1.d, z0.s ; CHECK-NEXT: punpkhi p0.h, p0.b ; CHECK-NEXT: uunpkhi z0.d, z0.s ; CHECK-NEXT: st1h { z1.d }, p1, [x0, z3.d, lsl #1] -; CHECK-NEXT: st1h { z0.d }, p0, [x0, z4.d, lsl #1] +; CHECK-NEXT: st1h { z0.d }, p0, [x0, z2.d, lsl #1] ; CHECK-NEXT: ret %t0 = insertelement undef, i64 %offset, i32 0 %t1 = shufflevector %t0, undef, zeroinitializer Index: llvm/test/CodeGen/AArch64/sve-gep.ll =================================================================== --- llvm/test/CodeGen/AArch64/sve-gep.ll +++ llvm/test/CodeGen/AArch64/sve-gep.ll @@ -227,7 +227,8 @@ ; CHECK-NEXT: rdvl x8, #1 ; CHECK-NEXT: sxtw z1.d, p0/m, z1.d ; CHECK-NEXT: mov z2.d, x8 -; CHECK-NEXT: mla z0.d, p0/m, z1.d, z2.d +; CHECK-NEXT: mad z1.d, p0/m, z2.d, z0.d +; CHECK-NEXT: mov z0.d, z1.d ; CHECK-NEXT: ret %d = getelementptr , *> %base, %idx ret *> %d Index: llvm/test/CodeGen/AArch64/sve-int-arith.ll =================================================================== --- llvm/test/CodeGen/AArch64/sve-int-arith.ll +++ llvm/test/CodeGen/AArch64/sve-int-arith.ll @@ -337,12 +337,11 @@ ret %res } -define @mla_i8( %a, %b, %c) { -; CHECK-LABEL: mla_i8: +define @mad_i8( %a, %b, %c) { +; CHECK-LABEL: mad_i8: ; CHECK: // %bb.0: ; CHECK-NEXT: ptrue p0.b -; CHECK-NEXT: mla z2.b, p0/m, z0.b, z1.b -; CHECK-NEXT: mov z0.d, z2.d +; CHECK-NEXT: mad z0.b, p0/m, z1.b, z2.b ; CHECK-NEXT: ret %prod = mul %a, %b %res = add %c, %prod Index: llvm/test/CodeGen/AArch64/sve-multiply-add-accumulate.ll =================================================================== --- /dev/null +++ llvm/test/CodeGen/AArch64/sve-multiply-add-accumulate.ll @@ -0,0 +1,107 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc -mtriple=aarch64-unknown-linux-gnu -mattr=+sve < %s | FileCheck %s + +define @muladd_i64( %a, %b) +; CHECK-LABEL: muladd_i64: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.d +; CHECK-NEXT: mov z2.d, #1 // =0x1 +; CHECK-NEXT: mad 
z0.d, p0/m, z1.d, z2.d +; CHECK-NEXT: ret +{ + %1 = mul %a, %b + %2 = add %1, shufflevector ( insertelement ( poison, i64 1, i64 0), poison, zeroinitializer) + ret %2 +} + +define @muladd_i32( %a, %b) +; CHECK-LABEL: muladd_i32: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.s +; CHECK-NEXT: mov z2.s, #1 // =0x1 +; CHECK-NEXT: mad z0.s, p0/m, z1.s, z2.s +; CHECK-NEXT: ret +{ + %1 = mul %a, %b + %2 = add %1, shufflevector ( insertelement ( poison, i32 1, i32 0), poison, zeroinitializer) + ret %2 +} + +define @muladd_i16_test1( %a, %b) +; CHECK-LABEL: muladd_i16_test1: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.h +; CHECK-NEXT: mov z2.h, #128 // =0x80 +; CHECK-NEXT: mad z0.h, p0/m, z1.h, z2.h +; CHECK-NEXT: ret +{ + %1 = mul %a, %b + %2 = add %1, shufflevector ( insertelement ( poison, i16 128, i16 0), poison, zeroinitializer) + ret %2 +} + +; 'a' and 'b' both have more than single use. Avoid using MAD +; TOFIX: MLA can be invoked for the first mul-add +define @muladd_i16_test2( %a, %b) +; CHECK-LABEL: muladd_i16_test2: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.h +; CHECK-NEXT: movprfx z2, z0 +; CHECK-NEXT: mul z2.h, p0/m, z2.h, z1.h +; CHECK-NEXT: add z2.h, z2.h, #200 // =0xc8 +; CHECK-NEXT: mul z0.h, p0/m, z0.h, z2.h +; CHECK-NEXT: sub z0.h, z0.h, z1.h +; CHECK-NEXT: ret +{ + %1 = mul %a, %b + %2 = add %1, shufflevector ( insertelement ( poison, i16 200, i16 0), poison, zeroinitializer) + %3 = mul %2, %a + %4 = sub %3, %b + ret %4 +} + +; TOFIX: 'a' has more than one use. 
But, can use register for 'b' for MAD result computation +define @muladd_i16_test3( %a, %b) +; CHECK-LABEL: muladd_i16_test3: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.h +; CHECK-NEXT: mul z1.h, p0/m, z1.h, z0.h +; CHECK-NEXT: add z1.h, z1.h, #200 // =0xc8 +; CHECK-NEXT: mul z0.h, p0/m, z0.h, z1.h +; CHECK-NEXT: ret +{ + %1 = mul %a, %b + %2 = add %1, shufflevector ( insertelement ( poison, i16 200, i16 0), poison, zeroinitializer) + %3 = mul %2, %a + ret %3 +} + +; MAD result can be computed in register for 'a' +define @muladd_i16_test4( %a, %b) +; CHECK-LABEL: muladd_i16_test4: +; CHECK: // %bb.0: +; CHECK-NEXT: mov w8, #200 +; CHECK-NEXT: ptrue p0.h +; CHECK-NEXT: mov z2.h, w8 +; CHECK-NEXT: mad z0.h, p0/m, z1.h, z2.h +; CHECK-NEXT: mul z0.h, p0/m, z0.h, z1.h +; CHECK-NEXT: ret +{ + %1 = mul %a, %b + %2 = add %1, shufflevector ( insertelement ( poison, i16 200, i16 0), poison, zeroinitializer) + %3 = mul %2, %b + ret %3 +} + +define @muladd_i8( %a, %b) +; CHECK-LABEL: muladd_i8: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.b +; CHECK-NEXT: mov z2.b, #-15 // =0xfffffffffffffff1 +; CHECK-NEXT: mad z0.b, p0/m, z1.b, z2.b +; CHECK-NEXT: ret +{ + %1 = mul %a, %b + %2 = add %1, shufflevector ( insertelement ( poison, i8 -15, i8 0), poison, zeroinitializer) + ret %2 +}