diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17804,9 +17804,63 @@
   return DAG.getNode(ISD::SUB, SDLoc(N), VT, Sub, M2);
 }
 
+// Transform (add/sub x, (extract_subvector (mul_pred y, z), 0)) into an
+// add/sub on the scalable type so that instruction selection can use SVE's
+// MLA/MLS. The multiply is already scalable because NEON has no 64-bit
+// vector multiply: v1i64/v2i64 muls are lowered to AArch64ISD::MUL_PRED.
+static SDValue performMulAddSubCombine(SDNode *N, SelectionDAG &DAG) {
+  // Bail out early if SVE is not available.
+  if (!DAG.getSubtarget<AArch64Subtarget>().hasSVE())
+    return SDValue();
+
+  if (N->getOpcode() != ISD::ADD && N->getOpcode() != ISD::SUB)
+    return SDValue();
+
+  // Only a fixed-width result can be rewritten to the scalable form.
+  if (!N->getValueType(0).isFixedLengthVector())
+    return SDValue();
+
+  // Find the extract_subvector operand. An add may have its operands in
+  // either order, but for a sub only (sub x, (extract ...)) is handled:
+  // MLS computes addend-minus-product, not product-minus-addend.
+  SDValue MulValue, AddValue, ExtractIndexValue;
+  if (N->getOpcode() == ISD::ADD &&
+      N->getOperand(0)->getOpcode() == ISD::EXTRACT_SUBVECTOR) {
+    MulValue = N->getOperand(0).getOperand(0);
+    ExtractIndexValue = N->getOperand(0).getOperand(1);
+    AddValue = N->getOperand(1);
+  } else if (N->getOperand(1)->getOpcode() == ISD::EXTRACT_SUBVECTOR) {
+    MulValue = N->getOperand(1).getOperand(0);
+    ExtractIndexValue = N->getOperand(1).getOperand(1);
+    AddValue = N->getOperand(0);
+  } else
+    return SDValue();
+
+  // The extracted value must be a predicated multiply.
+  if (MulValue.getOpcode() != AArch64ISD::MUL_PRED)
+    return SDValue();
+
+  // The multiply must produce a scalable vector.
+  EVT VT = MulValue.getValueType();
+  if (!VT.isScalableVector())
+    return SDValue();
+
+  // Only extracts of the low subvector (index 0) are handled.
+  ConstantSDNode *CV = dyn_cast<ConstantSDNode>(ExtractIndexValue);
+  if (!CV || CV->getSExtValue() != 0)
+    return SDValue();
+
+  SDValue ScalableAddValue = convertToScalableVector(DAG, VT, AddValue);
+  SDValue NewOp =
+      DAG.getNode(N->getOpcode(), SDLoc(N), VT, ScalableAddValue, MulValue);
+  return convertFromScalableVector(DAG, N->getValueType(0), NewOp);
+}
+
 static SDValue performAddSubCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI,
                                     SelectionDAG &DAG) {
+  if (SDValue Val = performMulAddSubCombine(N, DAG))
+    return Val;
   // Try to change sum of two reductions.
   if (SDValue Val = performAddUADDVCombine(N, DAG))
     return Val;
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -3165,7 +3165,7 @@
     std::pair<BasicBlock *, Value *> AdditionalBypass) {
   Value *VectorTripCount = getOrCreateVectorTripCount(LoopVectorPreHeader);
   assert(VectorTripCount && "Expected valid arguments");
-
+  LLVM_DEBUG(dbgs() << "VectorTripCount: "; VectorTripCount->dump());
   Instruction *OldInduction = Legal->getPrimaryInduction();
   Value *&EndValue = IVEndValues[OrigPhi];
   Value *EndValueFromAdditionalBypass = AdditionalBypass.second;
diff --git a/llvm/test/CodeGen/AArch64/aarch64-combine-add-sub-mul.ll b/llvm/test/CodeGen/AArch64/aarch64-combine-add-sub-mul.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/aarch64-combine-add-sub-mul.ll
@@ -0,0 +1,62 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=aarch64-none-linux-gnu -mattr=+sve | FileCheck %s
+
+define <2 x i64> @test_mul_add_2x64(<2 x i64> %a, <2 x i64> %b, <2 x i64> %c) {
+; CHECK-LABEL: test_mul_add_2x64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+  %mul = mul <2 x i64> %b, %c
+  %add = add <2 x i64> %a, %mul
+  ret <2 x i64> %add
+}
+
+define <1 x i64> @test_mul_add_1x64(<1 x i64> %a, <1 x i64> %b, <1 x i64> %c) {
+; CHECK-LABEL: test_mul_add_1x64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl1
+; CHECK-NEXT:    // kill: def $d2 killed $d2 def $z2
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
+; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT:    ret
+  %mul = mul <1 x i64> %b, %c
+  %add = add <1 x i64> %mul, %a
+  ret <1 x i64> %add
+}
+
+define <2 x i64> @test_mul_sub_2x64(<2 x i64> %a, <2 x i64> %b, <2 x i64> %c) {
+; CHECK-LABEL: test_mul_sub_2x64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    mls z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+  %mul = mul <2 x i64> %b, %c
+  %sub = sub <2 x i64> %a, %mul
+  ret <2 x i64> %sub
+}
+
+define <1 x i64> @test_mul_sub_1x64(<1 x i64> %a, <1 x i64> %b, <1 x i64> %c) {
+; CHECK-LABEL: test_mul_sub_1x64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl1
+; CHECK-NEXT:    // kill: def $d2 killed $d2 def $z2
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
+; CHECK-NEXT:    mls z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT:    ret
+  %mul = mul <1 x i64> %b, %c
+  %sub = sub <1 x i64> %a, %mul
+  ret <1 x i64> %sub
+}
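
For reference, a minimal standalone reproducer of the combine, equivalent to the tests above; the file and function names here are illustrative and not part of the patch:

; repro.ll -- build with: llc repro.ll -mtriple=aarch64-none-linux-gnu -mattr=+sve -o -
; NEON has no 64-bit vector multiply, so with SVE available the mul below is
; lowered to a scalable AArch64ISD::MUL_PRED whose result is extracted back to
; v2i64. The combine rewrites the fixed-width add to operate on the scalable
; type, letting instruction selection fold the pair into a single predicated
; mla.
define <2 x i64> @mul_then_add(<2 x i64> %acc, <2 x i64> %x, <2 x i64> %y) {
  %prod = mul <2 x i64> %x, %y
  %sum = add <2 x i64> %acc, %prod
  ret <2 x i64> %sum
}

Note the deliberate asymmetry for subtraction: (sub (mul x, y), z) is left alone, since SVE's MLS and MSB both compute addend-minus-product; only (sub z, (mul x, y)) maps onto a single instruction.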