diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17804,9 +17804,64 @@
   return DAG.getNode(ISD::SUB, SDLoc(N), VT, Sub, M2);
 }
 
+// This combine works on the patterns:
+//   add v1, (mul v2, v3)
+//   sub v1, (mul v2, v3)
+// for vectors of type <1 x i64> and <2 x i64> when SVE is available.
+// It transforms the add/sub into its scalable equivalent so that SVE's
+// MLA/MLS can be selected for the combined pattern.
+static SDValue performMulAddSubCombine(SDNode *N, SelectionDAG &DAG) {
+  // Bail out early if SVE is not available.
+  if (!DAG.getSubtarget<AArch64Subtarget>().hasSVE())
+    return SDValue();
+
+  if (N->getOpcode() != ISD::ADD && N->getOpcode() != ISD::SUB)
+    return SDValue();
+
+  if (!N->getValueType(0).isFixedLengthVector())
+    return SDValue();
+
+  SDValue MulValue, Op, ExtractIndexValue, ExtractOp;
+
+  if (N->getOperand(0)->getOpcode() == ISD::EXTRACT_SUBVECTOR) {
+    ExtractOp = N->getOperand(0);
+    Op = N->getOperand(1);
+  } else if (N->getOperand(1)->getOpcode() == ISD::EXTRACT_SUBVECTOR) {
+    ExtractOp = N->getOperand(1);
+    Op = N->getOperand(0);
+  } else
+    return SDValue();
+
+  MulValue = ExtractOp.getOperand(0);
+  ExtractIndexValue = ExtractOp.getOperand(1);
+
+  if (!ExtractOp.hasOneUse() && !MulValue.hasOneUse())
+    return SDValue();
+
+  // Only a predicated SVE MUL matches the pattern we are looking for.
+  if (MulValue.getOpcode() != AArch64ISD::MUL_PRED)
+    return SDValue();
+
+  // The MUL's result must be a scalable vector; otherwise this is not the
+  // expected pattern.
+  EVT VT = MulValue.getValueType();
+  if (!VT.isScalableVector())
+    return SDValue();
+
+  // The extract must read from lane 0 for the pattern to match.
+  if (!cast<ConstantSDNode>(ExtractIndexValue)->isZero())
+    return SDValue();
+
+  SDValue ScaledOp = convertToScalableVector(DAG, VT, Op);
+  SDValue NewValue = DAG.getNode(N->getOpcode(), SDLoc(N), VT, {ScaledOp, MulValue});
+  return convertFromScalableVector(DAG, N->getValueType(0), NewValue);
+}
+
 static SDValue performAddSubCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI,
                                     SelectionDAG &DAG) {
+  if (SDValue Val = performMulAddSubCombine(N, DAG))
+    return Val;
   // Try to change sum of two reductions.
   if (SDValue Val = performAddUADDVCombine(N, DAG))
     return Val;
diff --git a/llvm/test/CodeGen/AArch64/aarch64-combine-add-sub-mul.ll b/llvm/test/CodeGen/AArch64/aarch64-combine-add-sub-mul.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/aarch64-combine-add-sub-mul.ll
@@ -0,0 +1,62 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=aarch64-none-linux-gnu -mattr=+sve | FileCheck %s
+
+define <2 x i64> @test_mul_add_2x64(<2 x i64> %a, <2 x i64> %b, <2 x i64> %c) {
+; CHECK-LABEL: test_mul_add_2x64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+  %mul = mul <2 x i64> %b, %c
+  %add = add <2 x i64> %a, %mul
+  ret <2 x i64> %add
+}
+
+define <1 x i64> @test_mul_add_1x64(<1 x i64> %a, <1 x i64> %b, <1 x i64> %c) {
+; CHECK-LABEL: test_mul_add_1x64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl1
+; CHECK-NEXT:    // kill: def $d2 killed $d2 def $z2
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
+; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT:    ret
+  %mul = mul <1 x i64> %b, %c
+  %add = add <1 x i64> %mul, %a
+  ret <1 x i64> %add
+}
+
+define <2 x i64> @test_mul_sub_2x64(<2 x i64> %a, <2 x i64> %b, <2 x i64> %c) {
+; CHECK-LABEL: test_mul_sub_2x64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    mls z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+  %mul = mul <2 x i64> %b, %c
+  %sub = sub <2 x i64> %a, %mul
+  ret <2 x i64> %sub
+}
+
+define <1 x i64> @test_mul_sub_1x64(<1 x i64> %a, <1 x i64> %b, <1 x i64> %c) {
+; CHECK-LABEL: test_mul_sub_1x64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl1
+; CHECK-NEXT:    // kill: def $d2 killed $d2 def $z2
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
+; CHECK-NEXT:    mls z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT:    ret
+  %mul = mul <1 x i64> %b, %c
+  %sub = sub <1 x i64> %mul, %a
+  ret <1 x i64> %sub
+}
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-int-rem.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-int-rem.ll
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-int-rem.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-int-rem.ll
@@ -606,7 +606,7 @@
 ;
 ; VBITS_GE_256-LABEL: srem_v16i32:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    mov x8, #8
+; VBITS_GE_256-NEXT:    mov x8, #8 // =0x8
 ; VBITS_GE_256-NEXT:    ptrue p0.s, vl8
 ; VBITS_GE_256-NEXT:    ld1w { z0.s }, p0/z, [x0, x8, lsl #2]
 ; VBITS_GE_256-NEXT:    ld1w { z1.s }, p0/z, [x0]
@@ -680,13 +680,13 @@
 define <1 x i64> @srem_v1i64(<1 x i64> %op1, <1 x i64> %op2) vscale_range(1,0) #0 {
 ; CHECK-LABEL: srem_v1i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
-; CHECK-NEXT:    ptrue p0.d, vl1
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl1
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
 ; CHECK-NEXT:    movprfx z2, z0
 ; CHECK-NEXT:    sdiv z2.d, p0/m, z2.d, z1.d
-; CHECK-NEXT:    mul z1.d, p0/m, z1.d, z2.d
-; CHECK-NEXT:    sub d0, d0, d1
+; CHECK-NEXT:    mls z0.d, p0/m, z2.d, z1.d
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT:    ret
   %res = srem <1 x i64> %op1, %op2
   ret <1 x i64> %res
@@ -697,13 +697,13 @@
 define <2 x i64> @srem_v2i64(<2 x i64> %op1, <2 x i64> %op2) vscale_range(1,0) #0 {
 ; CHECK-LABEL: srem_v2i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
-; CHECK-NEXT:    ptrue p0.d, vl2
 ; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
 ; CHECK-NEXT:    movprfx z2, z0
 ; CHECK-NEXT:    sdiv z2.d, p0/m, z2.d, z1.d
-; CHECK-NEXT:    mul z1.d, p0/m, z1.d, z2.d
-; CHECK-NEXT:    sub v0.2d, v0.2d, v1.2d
+; CHECK-NEXT:    mls z0.d, p0/m, z2.d, z1.d
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
 ; CHECK-NEXT:    ret
   %res = srem <2 x i64> %op1, %op2
   ret <2 x i64> %res
@@ -730,34 +730,32 @@
 define void @srem_v8i64(ptr %a, ptr %b) #0 {
 ; VBITS_GE_128-LABEL: srem_v8i64:
 ; VBITS_GE_128:       // %bb.0:
-; VBITS_GE_128-NEXT:    ldp q4, q5, [x1]
-; VBITS_GE_128-NEXT:    ptrue p0.d, vl2
-; VBITS_GE_128-NEXT:    ldp q7, q6, [x1, #32]
 ; VBITS_GE_128-NEXT:    ldp q0, q1, [x0, #32]
-; VBITS_GE_128-NEXT:    ldp q2, q3, [x0]
-; VBITS_GE_128-NEXT:    movprfx z16, z3
-; VBITS_GE_128-NEXT:    sdiv z16.d, p0/m, z16.d, z5.d
-; VBITS_GE_128-NEXT:    movprfx z17, z2
-; VBITS_GE_128-NEXT:    sdiv z17.d, p0/m, z17.d, z4.d
-; VBITS_GE_128-NEXT:    mul z5.d, p0/m, z5.d, z16.d
+; VBITS_GE_128-NEXT:    ptrue p0.d, vl2
+; VBITS_GE_128-NEXT:    ldp q2, q3, [x1, #32]
 ; VBITS_GE_128-NEXT:    movprfx z16, z1
+; VBITS_GE_128-NEXT:    sdiv z16.d, p0/m, z16.d, z3.d
+; VBITS_GE_128-NEXT:    mls z1.d, p0/m, z16.d, z3.d
+; VBITS_GE_128-NEXT:    movprfx z3, z0
+; VBITS_GE_128-NEXT:    sdiv z3.d, p0/m, z3.d, z2.d
+; VBITS_GE_128-NEXT:    mls z0.d, p0/m, z3.d, z2.d
+; VBITS_GE_128-NEXT:    ldp q4, q5, [x0]
+; VBITS_GE_128-NEXT:    ldp q7, q6, [x1]
+; VBITS_GE_128-NEXT:    movprfx z16, z5
 ; VBITS_GE_128-NEXT:    sdiv z16.d, p0/m, z16.d, z6.d
-; VBITS_GE_128-NEXT:    mul z4.d, p0/m, z4.d, z17.d
-; VBITS_GE_128-NEXT:    movprfx z17, z0
-; VBITS_GE_128-NEXT:    sdiv z17.d, p0/m, z17.d, z7.d
-; VBITS_GE_128-NEXT:    mul z6.d, p0/m, z6.d, z16.d
-; VBITS_GE_128-NEXT:    mul z7.d, p0/m, z7.d, z17.d
-; VBITS_GE_128-NEXT:    sub v0.2d, v0.2d, v7.2d
-; VBITS_GE_128-NEXT:    sub v1.2d, v1.2d, v6.2d
-; VBITS_GE_128-NEXT:    sub v2.2d, v2.2d, v4.2d
+; VBITS_GE_128-NEXT:    movprfx z2, z4
+; VBITS_GE_128-NEXT:    sdiv z2.d, p0/m, z2.d, z7.d
 ; VBITS_GE_128-NEXT:    stp q0, q1, [x0, #32]
-; VBITS_GE_128-NEXT:    sub v0.2d, v3.2d, v5.2d
-; VBITS_GE_128-NEXT:    stp q2, q0, [x0]
+; VBITS_GE_128-NEXT:    movprfx z0, z4
+; VBITS_GE_128-NEXT:    mls z0.d, p0/m, z2.d, z7.d
+; VBITS_GE_128-NEXT:    movprfx z1, z5
+; VBITS_GE_128-NEXT:    mls z1.d, p0/m, z16.d, z6.d
+; VBITS_GE_128-NEXT:    stp q0, q1, [x0]
 ; VBITS_GE_128-NEXT:    ret
 ;
 ; VBITS_GE_256-LABEL: srem_v8i64:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    mov x8, #4
+; VBITS_GE_256-NEXT:    mov x8, #4 // =0x4
 ; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
 ; VBITS_GE_256-NEXT:    ld1d { z0.d }, p0/z, [x0, x8, lsl #3]
 ; VBITS_GE_256-NEXT:    ld1d { z1.d }, p0/z, [x0]
@@ -1426,7 +1424,7 @@
 ;
 ; VBITS_GE_256-LABEL: urem_v16i32:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    mov x8, #8
+; VBITS_GE_256-NEXT:    mov x8, #8 // =0x8
 ; VBITS_GE_256-NEXT:    ptrue p0.s, vl8
 ; VBITS_GE_256-NEXT:    ld1w { z0.s }, p0/z, [x0, x8, lsl #2]
 ; VBITS_GE_256-NEXT:    ld1w { z1.s }, p0/z, [x0]
@@ -1500,13 +1498,13 @@
 define <1 x i64> @urem_v1i64(<1 x i64> %op1, <1 x i64> %op2) vscale_range(1,0) #0 {
 ; CHECK-LABEL: urem_v1i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
-; CHECK-NEXT:    ptrue p0.d, vl1
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl1
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
 ; CHECK-NEXT:    movprfx z2, z0
 ; CHECK-NEXT:    udiv z2.d, p0/m, z2.d, z1.d
-; CHECK-NEXT:    mul z1.d, p0/m, z1.d, z2.d
-; CHECK-NEXT:    sub d0, d0, d1
+; CHECK-NEXT:    mls z0.d, p0/m, z2.d, z1.d
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT:    ret
   %res = urem <1 x i64> %op1, %op2
   ret <1 x i64> %res
@@ -1517,13 +1515,13 @@
 define <2 x i64> @urem_v2i64(<2 x i64> %op1, <2 x i64> %op2) vscale_range(1,0) #0 {
 ; CHECK-LABEL: urem_v2i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
-; CHECK-NEXT:    ptrue p0.d, vl2
 ; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
 ; CHECK-NEXT:    movprfx z2, z0
 ; CHECK-NEXT:    udiv z2.d, p0/m, z2.d, z1.d
-; CHECK-NEXT:    mul z1.d, p0/m, z1.d, z2.d
-; CHECK-NEXT:    sub v0.2d, v0.2d, v1.2d
+; CHECK-NEXT:    mls z0.d, p0/m, z2.d, z1.d
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
 ; CHECK-NEXT:    ret
   %res = urem <2 x i64> %op1, %op2
   ret <2 x i64> %res
@@ -1550,34 +1548,32 @@
 define void @urem_v8i64(ptr %a, ptr %b) #0 {
 ; VBITS_GE_128-LABEL: urem_v8i64:
 ; VBITS_GE_128:       // %bb.0:
-; VBITS_GE_128-NEXT:    ldp q4, q5, [x1]
-; VBITS_GE_128-NEXT:    ptrue p0.d, vl2
-; VBITS_GE_128-NEXT:    ldp q7, q6, [x1, #32]
 ; VBITS_GE_128-NEXT:    ldp q0, q1, [x0, #32]
-; VBITS_GE_128-NEXT:    ldp q2, q3, [x0]
-; VBITS_GE_128-NEXT:    movprfx z16, z3
-; VBITS_GE_128-NEXT:    udiv z16.d, p0/m, z16.d, z5.d
-; VBITS_GE_128-NEXT:    movprfx z17, z2
-; VBITS_GE_128-NEXT:    udiv z17.d, p0/m, z17.d, z4.d
-; VBITS_GE_128-NEXT:    mul z5.d, p0/m, z5.d, z16.d
+; VBITS_GE_128-NEXT:    ptrue p0.d, vl2
+; VBITS_GE_128-NEXT:    ldp q2, q3, [x1, #32]
 ; VBITS_GE_128-NEXT:    movprfx z16, z1
+; VBITS_GE_128-NEXT:    udiv z16.d, p0/m, z16.d, z3.d
+; VBITS_GE_128-NEXT:    mls z1.d, p0/m, z16.d, z3.d
+; VBITS_GE_128-NEXT:    movprfx z3, z0
+; VBITS_GE_128-NEXT:    udiv z3.d, p0/m, z3.d, z2.d
+; VBITS_GE_128-NEXT:    mls z0.d, p0/m, z3.d, z2.d
+; VBITS_GE_128-NEXT:    ldp q4, q5, [x0]
+; VBITS_GE_128-NEXT:    ldp q7, q6, [x1]
+; VBITS_GE_128-NEXT:    movprfx z16, z5
 ; VBITS_GE_128-NEXT:    udiv z16.d, p0/m, z16.d, z6.d
-; VBITS_GE_128-NEXT:    mul z4.d, p0/m, z4.d, z17.d
-; VBITS_GE_128-NEXT:    movprfx z17, z0
-; VBITS_GE_128-NEXT:    udiv z17.d, p0/m, z17.d, z7.d
-; VBITS_GE_128-NEXT:    mul z6.d, p0/m, z6.d, z16.d
-; VBITS_GE_128-NEXT:    mul z7.d, p0/m, z7.d, z17.d
-; VBITS_GE_128-NEXT:    sub v0.2d, v0.2d, v7.2d
-; VBITS_GE_128-NEXT:    sub v1.2d, v1.2d, v6.2d
-; VBITS_GE_128-NEXT:    sub v2.2d, v2.2d, v4.2d
+; VBITS_GE_128-NEXT:    movprfx z2, z4
+; VBITS_GE_128-NEXT:    udiv z2.d, p0/m, z2.d, z7.d
 ; VBITS_GE_128-NEXT:    stp q0, q1, [x0, #32]
-; VBITS_GE_128-NEXT:    sub v0.2d, v3.2d, v5.2d
-; VBITS_GE_128-NEXT:    stp q2, q0, [x0]
+; VBITS_GE_128-NEXT:    movprfx z0, z4
+; VBITS_GE_128-NEXT:    mls z0.d, p0/m, z2.d, z7.d
+; VBITS_GE_128-NEXT:    movprfx z1, z5
+; VBITS_GE_128-NEXT:    mls z1.d, p0/m, z16.d, z6.d
+; VBITS_GE_128-NEXT:    stp q0, q1, [x0]
 ; VBITS_GE_128-NEXT:    ret
 ;
 ; VBITS_GE_256-LABEL: urem_v8i64:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    mov x8, #4
+; VBITS_GE_256-NEXT:    mov x8, #4 // =0x4
 ; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
 ; VBITS_GE_256-NEXT:    ld1d { z0.d }, p0/z, [x0, x8, lsl #3]
 ; VBITS_GE_256-NEXT:    ld1d { z1.d }, p0/z, [x0]
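
A quick way to see the combine in action is to run llc over a reduced IR file that restates one of the new tests. The sketch below is illustrative, not part of the patch: the function name mla_candidate is made up, and the invocation simply mirrors the new test's RUN line.

  ; llc reduced.ll -mtriple=aarch64-none-linux-gnu -mattr=+sve -o -
  define <2 x i64> @mla_candidate(<2 x i64> %acc, <2 x i64> %b, <2 x i64> %c) {
    %mul = mul <2 x i64> %b, %c       ; multiply feeding the add below
    %add = add <2 x i64> %acc, %mul   ; expected to select: mla z0.d, p0/m, z1.d, z2.d
    ret <2 x i64> %add
  }

Without -mattr=+sve, performMulAddSubCombine returns early at the hasSVE() check and the add/sub is left to the existing lowering.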