diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17804,9 +17804,64 @@
   return DAG.getNode(ISD::SUB, SDLoc(N), VT, Sub, M2);
 }
 
+// This works on the patterns of:
+//   add v1, (mul v2, v3)
+//   sub v1, (mul v2, v3)
+// for vectors of type <1 x i64> and <2 x i64> when SVE is available.
+// It will transform the add/sub to a scalable version, so that we can
+// make use of SVE's MLA/MLS that will be generated for that pattern.
+static SDValue performMulAddSubCombine(SDNode *N, SelectionDAG &DAG) {
+  // Before using SVE's features, check first if it's available.
+  if (!DAG.getSubtarget<AArch64Subtarget>().hasSVE())
+    return SDValue();
+
+  if (N->getOpcode() != ISD::ADD && N->getOpcode() != ISD::SUB)
+    return SDValue();
+
+  // The combine rewrites a fixed-length add/sub; bail out otherwise.
+  if (!N->getValueType(0).isFixedLengthVector())
+    return SDValue();
+
+  // Try to fold (OPC Acc, (extract_subvector (mul_pred ...), 0)).
+  auto performOpt = [&DAG, &N](SDValue Acc, SDValue Extract) -> SDValue {
+    if (Extract.getOpcode() != ISD::EXTRACT_SUBVECTOR)
+      return SDValue();
+
+    // Only a zero-index extract matches the fixed-from-scalable convention
+    // used by convertFromScalableVector.
+    if (!cast<ConstantSDNode>(Extract.getOperand(1))->isZero())
+      return SDValue();
+
+    SDValue MulValue = Extract.getOperand(0);
+    // If the Opcode is NOT MUL_PRED, then that is NOT the expected pattern.
+    if (MulValue.getOpcode() != AArch64ISD::MUL_PRED)
+      return SDValue();
+
+    // Only combine when the mul has no other users; otherwise the original
+    // extract stays live and the rewrite just duplicates the multiply.
+    if (!MulValue.hasOneUse())
+      return SDValue();
+
+    // The mul must already be a scalable vector for the MLA/MLS patterns.
+    EVT ScalableVT = MulValue.getValueType();
+    if (!ScalableVT.isScalableVector())
+      return SDValue();
+
+    SDValue ScaledAcc = convertToScalableVector(DAG, ScalableVT, Acc);
+    SDValue NewValue = DAG.getNode(N->getOpcode(), SDLoc(N), ScalableVT,
+                                   {ScaledAcc, MulValue});
+    return convertFromScalableVector(DAG, N->getValueType(0), NewValue);
+  };
+
+  if (SDValue Res = performOpt(N->getOperand(0), N->getOperand(1)))
+    return Res;
+  // ADD is commutative so the mul may sit on either side; SUB is not, so
+  // swapping operands for SUB would compute Acc - Mul instead of Mul - Acc.
+  if (N->getOpcode() == ISD::ADD)
+    return performOpt(N->getOperand(1), N->getOperand(0));
+
+  return SDValue();
+}
+
 static SDValue performAddSubCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI,
                                     SelectionDAG &DAG) {
+  if (SDValue Val = performMulAddSubCombine(N, DAG))
+    return Val;
   // Try to change sum of two reductions.
   if (SDValue Val = performAddUADDVCombine(N, DAG))
     return Val;
diff --git a/llvm/test/CodeGen/AArch64/aarch64-combine-add-sub-mul.ll b/llvm/test/CodeGen/AArch64/aarch64-combine-add-sub-mul.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/aarch64-combine-add-sub-mul.ll
@@ -0,0 +1,62 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=aarch64-none-linux-gnu -mattr=+sve | FileCheck %s
+
+define <2 x i64> @test_mul_add_2x64(<2 x i64> %a, <2 x i64> %b, <2 x i64> %c) {
+; CHECK-LABEL: test_mul_add_2x64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+  %mul = mul <2 x i64> %b, %c
+  %add = add <2 x i64> %a, %mul
+  ret <2 x i64> %add
+}
+
+define <1 x i64> @test_mul_add_1x64(<1 x i64> %a, <1 x i64> %b, <1 x i64> %c) {
+; CHECK-LABEL: test_mul_add_1x64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl1
+; CHECK-NEXT:    // kill: def $d2 killed $d2 def $z2
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
+; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT:    ret
+  %mul = mul <1 x i64> %b, %c
+  %add = add <1 x i64> %mul, %a
+  ret <1 x i64> %add
+}
+
+define <2 x i64> @test_mul_sub_2x64(<2 x i64> %a, <2 x i64> %b, <2 x i64> %c) {
+; CHECK-LABEL: test_mul_sub_2x64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    mls z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+  %mul = mul <2 x i64> %b, %c
+  %sub = sub <2 x i64> %a, %mul
+  ret <2 x i64> %sub
+}
+
+define <1 x i64> @test_mul_sub_1x64(<1 x i64> %a, <1 x i64> %b, <1 x i64> %c) {
+; CHECK-LABEL: test_mul_sub_1x64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl1
+; CHECK-NEXT:    // kill: def $d2 killed $d2 def $z2
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
+; CHECK-NEXT:    mls z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT:    ret
+  %mul = mul <1 x i64> %b, %c
+  %sub = sub <1 x i64> %a, %mul
+  ret <1 x i64> %sub
+}