diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17804,9 +17804,63 @@
   return DAG.getNode(ISD::SUB, SDLoc(N), VT, Sub, M2);
 }
 
+// This works on the following patterns:
+//   add v1, (mul v2, v3)
+//   sub v1, (mul v2, v3)
+// for vectors of type <1 x i64> and <2 x i64> when SVE is available.
+// It transforms the add/sub into its scalable form so that SVE's MLA/MLS
+// instructions can be selected for the combined operation.
+static SDValue performMulAddSubCombine(SDNode *N, SelectionDAG &DAG) {
+  // Before using any SVE feature, check that SVE is actually available.
+  if (!DAG.getSubtarget<AArch64Subtarget>().hasSVE())
+    return SDValue();
+
+  if (N->getOpcode() != ISD::ADD && N->getOpcode() != ISD::SUB)
+    return SDValue();
+
+  if (!N->getValueType(0).isFixedLengthVector())
+    return SDValue();
+
+  SDValue MulValue, Op, ExtractIndexValue;
+  if (N->getOperand(0)->getOpcode() == ISD::EXTRACT_SUBVECTOR) {
+    MulValue = N->getOperand(0).getOperand(0);
+    ExtractIndexValue = N->getOperand(0).getOperand(1);
+    Op = N->getOperand(1);
+  } else if (N->getOperand(1)->getOpcode() == ISD::EXTRACT_SUBVECTOR) {
+    MulValue = N->getOperand(1).getOperand(0);
+    ExtractIndexValue = N->getOperand(1).getOperand(1);
+    Op = N->getOperand(0);
+  } else
+    return SDValue();
+
+  // The extracted value must be a predicated multiply (MUL_PRED).
+  if (MulValue.getOpcode() != AArch64ISD::MUL_PRED)
+    return SDValue();
+
+  if (!N->getOperand(0).hasOneUse())
+    return SDValue();
+
+  // If the multiply's result type is not a scalable vector, this is not the
+  // expected pattern.
+  EVT VT = MulValue.getValueType();
+  if (!VT.isScalableVector())
+    return SDValue();
+
+  // The extract index must be 0, i.e. the low part of the scalable vector.
+  if (!cast<ConstantSDNode>(ExtractIndexValue)->isZero())
+    return SDValue();
+
+  SDValue ScaledOp = convertToScalableVector(DAG, VT, Op);
+  SDValue NewValue =
+      DAG.getNode(N->getOpcode(), SDLoc(N), VT, {ScaledOp, MulValue});
+  return convertFromScalableVector(DAG, N->getValueType(0), NewValue);
+}
+
 static SDValue performAddSubCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI,
                                     SelectionDAG &DAG) {
+  if (SDValue Val = performMulAddSubCombine(N, DAG))
+    return Val;
   // Try to change sum of two reductions.
   if (SDValue Val = performAddUADDVCombine(N, DAG))
     return Val;
diff --git a/llvm/test/CodeGen/AArch64/aarch64-combine-add-sub-mul.ll b/llvm/test/CodeGen/AArch64/aarch64-combine-add-sub-mul.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/aarch64-combine-add-sub-mul.ll
@@ -0,0 +1,62 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=aarch64-none-linux-gnu -mattr=+sve | FileCheck %s
+
+define <2 x i64> @test_mul_add_2x64(<2 x i64> %a, <2 x i64> %b, <2 x i64> %c) {
+; CHECK-LABEL: test_mul_add_2x64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+  %mul = mul <2 x i64> %b, %c
+  %add = add <2 x i64> %a, %mul
+  ret <2 x i64> %add
+}
+
+define <1 x i64> @test_mul_add_1x64(<1 x i64> %a, <1 x i64> %b, <1 x i64> %c) {
+; CHECK-LABEL: test_mul_add_1x64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl1
+; CHECK-NEXT:    // kill: def $d2 killed $d2 def $z2
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
+; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT:    ret
+  %mul = mul <1 x i64> %b, %c
+  %add = add <1 x i64> %mul, %a
+  ret <1 x i64> %add
+}
+
+define <2 x i64> @test_mul_sub_2x64(<2 x i64> %a, <2 x i64> %b, <2 x i64> %c) {
+; CHECK-LABEL: test_mul_sub_2x64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    mls z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+  %mul = mul <2 x i64> %b, %c
+  %sub = sub <2 x i64> %a, %mul
+  ret <2 x i64> %sub
+}
+
+define <1 x i64> @test_mul_sub_1x64(<1 x i64> %a, <1 x i64> %b, <1 x i64> %c) {
+; CHECK-LABEL: test_mul_sub_1x64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl1
+; CHECK-NEXT:    // kill: def $d2 killed $d2 def $z2
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
+; CHECK-NEXT:    mls z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT:    ret
+  %mul = mul <1 x i64> %b, %c
+  %sub = sub <1 x i64> %mul, %a
+  ret <1 x i64> %sub
+}
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-int-rem.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-int-rem.ll
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-int-rem.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-int-rem.ll
@@ -606,7 +606,7 @@
 ;
 ; VBITS_GE_256-LABEL: srem_v16i32:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    mov x8, #8
+; VBITS_GE_256-NEXT:    mov x8, #8 // =0x8
 ; VBITS_GE_256-NEXT:    ptrue p0.s, vl8
 ; VBITS_GE_256-NEXT:    ld1w { z0.s }, p0/z, [x0, x8, lsl #2]
 ; VBITS_GE_256-NEXT:    ld1w { z1.s }, p0/z, [x0]
@@ -757,7 +757,7 @@
 ;
 ; VBITS_GE_256-LABEL: srem_v8i64:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    mov x8, #4
+; VBITS_GE_256-NEXT:    mov x8, #4 // =0x4
 ; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
 ; VBITS_GE_256-NEXT:    ld1d { z0.d }, p0/z, [x0, x8, lsl #3]
 ; VBITS_GE_256-NEXT:    ld1d { z1.d }, p0/z, [x0]
@@ -1426,7 +1426,7 @@
 ;
 ; VBITS_GE_256-LABEL: urem_v16i32:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    mov x8, #8
+; VBITS_GE_256-NEXT:    mov x8, #8 // =0x8
 ; VBITS_GE_256-NEXT:    ptrue p0.s, vl8
 ; VBITS_GE_256-NEXT:    ld1w { z0.s }, p0/z, [x0, x8, lsl #2]
 ; VBITS_GE_256-NEXT:    ld1w { z1.s }, p0/z, [x0]
@@ -1577,7 +1577,7 @@
 ;
 ; VBITS_GE_256-LABEL: urem_v8i64:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    mov x8, #4
+; VBITS_GE_256-NEXT:    mov x8, #4 // =0x4
 ; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
 ; VBITS_GE_256-NEXT:    ld1d { z0.d }, p0/z, [x0, x8, lsl #3]
 ; VBITS_GE_256-NEXT:    ld1d { z1.d }, p0/z, [x0]
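
Note: the new combine can also be exercised outside of lit by feeding a small IR file through the same llc invocation used in the RUN line above. A minimal sketch, assuming an llc build with the AArch64 target enabled; repro.ll and @repro are hypothetical names, not part of the patch:

; repro.ll: an add of a mul on <2 x i64>, which should now select the SVE mla
; shown in test_mul_add_2x64 above.
define <2 x i64> @repro(<2 x i64> %a, <2 x i64> %b, <2 x i64> %c) {
  %mul = mul <2 x i64> %b, %c
  %add = add <2 x i64> %a, %mul
  ret <2 x i64> %add
}

$ llc < repro.ll -mtriple=aarch64-none-linux-gnu -mattr=+sve -o -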