diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17812,9 +17812,53 @@
   return DAG.getNode(ISD::SUB, SDLoc(N), VT, Sub, M2);
 }
 
+// When a fixed-length vector operation such as a MUL is lowered via SVE, the
+// scalable MUL_PRED result is extracted back into the fixed-length type with
+// an EXTRACT_SUBVECTOR. If such an extract feeds an ADD, perform the ADD on
+// the scalable container type instead, so that the ADD and the MUL_PRED can
+// later be selected together as a single predicated MLA.
+static SDValue performAddMulCombine(SDNode *N, SelectionDAG &DAG) {
+  EVT VT = N->getValueType(0);
+  if (N->getOpcode() != ISD::ADD || !VT.isFixedLengthVector())
+    return SDValue();
+
+  // Match a low-half EXTRACT_SUBVECTOR of a scalable MUL_PRED, as created by
+  // convertFromScalableVector. Requiring MUL_PRED keeps this combine from
+  // firing on unrelated subvector extracts, e.g. on pure NEON paths.
+  auto IsExtractOfScalableMul = [](SDValue Op) {
+    return Op.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+           Op.getConstantOperandVal(1) == 0 &&
+           Op.getOperand(0).getOpcode() == AArch64ISD::MUL_PRED;
+  };
+
+  SDValue MulValue, Addend;
+  if (IsExtractOfScalableMul(N->getOperand(0))) {
+    MulValue = N->getOperand(0).getOperand(0);
+    Addend = N->getOperand(1);
+  } else if (IsExtractOfScalableMul(N->getOperand(1))) {
+    MulValue = N->getOperand(1).getOperand(0);
+    Addend = N->getOperand(0);
+  } else {
+    return SDValue();
+  }
+
+  // Widen the addend to the container type the MUL was lowered to and perform
+  // the ADD there.
+  EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
+  if (MulValue.getValueType() != ContainerVT)
+    return SDValue();
+
+  SDValue ScalableAddend = convertToScalableVector(DAG, ContainerVT, Addend);
+  SDValue Add =
+      DAG.getNode(ISD::ADD, SDLoc(N), ContainerVT, ScalableAddend, MulValue);
+  return convertFromScalableVector(DAG, VT, Add);
+}
+
 static SDValue performAddSubCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI,
                                     SelectionDAG &DAG) {
+  if (SDValue Val = performAddMulCombine(N, DAG))
+    return Val;
   // Try to change sum of two reductions.
   if (SDValue Val = performAddUADDVCombine(N, DAG))
     return Val;
diff --git a/llvm/test/CodeGen/AArch64/aarch64-combine-mul-add.ll b/llvm/test/CodeGen/AArch64/aarch64-combine-mul-add.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/aarch64-combine-mul-add.ll
@@ -0,0 +1,32 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=aarch64-none-linux-gnu -mattr=+sve | FileCheck %s
+
+define <2 x i64> @test_add_mul_2x64(<2 x i64> %a, <2 x i64> %b, <2 x i64> %c) {
+; CHECK-LABEL: test_add_mul_2x64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+  %mul = mul <2 x i64> %b, %c
+  %add = add <2 x i64> %a, %mul
+  ret <2 x i64> %add
+}
+
+define <1 x i64> @test_add_mul_1x64(<1 x i64> %a, <1 x i64> %b, <1 x i64> %c) {
+; CHECK-LABEL: test_add_mul_1x64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl1
+; CHECK-NEXT:    // kill: def $d2 killed $d2 def $z2
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
+; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT:    ret
+  %mul = mul <1 x i64> %b, %c
+  %add = add <1 x i64> %a, %mul
+  ret <1 x i64> %add
+}
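--
For reference, a sketch of the DAG rewrite the new combine performs. This
assumes the node shapes produced by the AArch64 fixed-length SVE lowering,
where convertToScalableVector and convertFromScalableVector emit an
INSERT_SUBVECTOR into undef and an EXTRACT_SUBVECTOR at index 0; the value
numbers t0..t5 are illustrative only:

  t0: nxv2i64 = AArch64ISD::MUL_PRED pg, x, y
  t1: v2i64   = extract_subvector t0, 0
  t2: v2i64   = add a, t1
    is rewritten to
  t3: nxv2i64 = insert_subvector undef, a, 0
  t4: nxv2i64 = add t3, t0
  t5: v2i64   = extract_subvector t4, 0

so that instruction selection can then match t4 together with t0 as a single
predicated MLA, as the mla z0.d, p0/m, z1.d, z2.d lines in the test show.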