diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -13433,6 +13433,36 @@
   return DAG.getNode(N->getOpcode(), SDLoc(N), VT, LHS, RHS);
 }
 
+// Fold an ADD of a zero-based INDEX_VECTOR and a splatted scalar into a single
+// INDEX_VECTOR that starts at that scalar:
+//   add(index_vector(zero, step), dup(X)) -> index_vector(X, step)
+// Returns an empty SDValue when the pattern does not match.
+static SDValue performAddIndexVectorCombine(SDNode *N, SelectionDAG &DAG) {
+  if (N->getOpcode() != ISD::ADD)
+    return SDValue();
+
+  SDValue IV = N->getOperand(0);
+  SDValue A = N->getOperand(1);
+  // ADD is commutative: the INDEX_VECTOR may be either operand.
+  if (IV.getOpcode() != AArch64ISD::INDEX_VECTOR)
+    std::swap(IV, A);
+  if (IV.getOpcode() != AArch64ISD::INDEX_VECTOR)
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+  assert(VT.isScalableVector() &&
+         "Only expect scalable vectors for INDEX_VECTOR");
+
+  // Only a constant-zero base can be replaced by the splatted scalar.
+  auto *BaseVal = dyn_cast<ConstantSDNode>(IV->getOperand(0));
+  if (BaseVal && BaseVal->isNullValue() &&
+      (A.getOpcode() == AArch64ISD::DUP || A.getOpcode() == ISD::SPLAT_VECTOR))
+    return DAG.getNode(AArch64ISD::INDEX_VECTOR, SDLoc(N), VT, A->getOperand(0),
+                       IV->getOperand(1));
+
+  return SDValue();
+}
+
 static SDValue performAddSubCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI,
                                     SelectionDAG &DAG) {
@@ -13442,6 +13472,10 @@
   if (SDValue Val = performAddDotCombine(N, DAG))
     return Val;
 
+  // Try to combine add with index_vector
+  if (SDValue Val = performAddIndexVectorCombine(N, DAG))
+    return Val;
+
   return performAddSubLongCombine(N, DCI, DAG);
 }
 
diff --git a/llvm/test/CodeGen/AArch64/sve-stepvector.ll b/llvm/test/CodeGen/AArch64/sve-stepvector.ll
--- a/llvm/test/CodeGen/AArch64/sve-stepvector.ll
+++ b/llvm/test/CodeGen/AArch64/sve-stepvector.ll
@@ -49,9 +49,8 @@
 ; CHECK-LABEL: stepvector_nxv4i64:
 ; CHECK: // %bb.0: // %entry
 ; CHECK-NEXT: cntd x8
-; CHECK-NEXT: mov z1.d, x8
+; CHECK-NEXT: index z1.d, x8, #1
 ; CHECK-NEXT: index z0.d, #0, #1
-; CHECK-NEXT: add z1.d, z0.d, z1.d
 ; CHECK-NEXT: ret
 entry:
   %0 = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
@@ -61,14 +60,13 @@
 define <vscale x 16 x i32> @stepvector_nxv16i32() {
 ; CHECK-LABEL: stepvector_nxv16i32:
 ; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: cntw x9
 ; CHECK-NEXT: cnth x8
+; CHECK-NEXT: cntw x9
+; CHECK-NEXT: mov z0.s, w8
+; CHECK-NEXT: index z1.s, w9, #1
+; CHECK-NEXT: index z2.s, w8, #1
+; CHECK-NEXT: add z3.s, z1.s, z0.s
 ; CHECK-NEXT: index z0.s, #0, #1
-; CHECK-NEXT: mov z1.s, w9
-; CHECK-NEXT: mov z3.s, w8
-; CHECK-NEXT: add z1.s, z0.s, z1.s
-; CHECK-NEXT: add z2.s, z0.s, z3.s
-; CHECK-NEXT: add z3.s, z1.s, z3.s
 ; CHECK-NEXT: ret
 entry:
   %0 = call <vscale x 16 x i32> @llvm.experimental.stepvector.nxv16i32()