diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -13433,6 +13433,39 @@
   return DAG.getNode(N->getOpcode(), SDLoc(N), VT, LHS, RHS);
 }
 
+// Fold a splat added to a zero-based index_vector into the base operand:
+//   add(index_vector(0, step), dup(X)) -> index_vector(X, step)
+// Both operand orders of the add are accepted.
+static SDValue performAddIndexVectorCombine(SDNode *N, SelectionDAG &DAG) {
+  if (N->getOpcode() != ISD::ADD)
+    return SDValue();
+
+  SDValue IV = N->getOperand(0);
+  SDValue A = N->getOperand(1);
+  // ADD is commutative, so accept the index_vector as either operand.
+  if (IV.getOpcode() != AArch64ISD::INDEX_VECTOR)
+    std::swap(IV, A);
+  if (IV.getOpcode() != AArch64ISD::INDEX_VECTOR)
+    return SDValue();
+
+  // Only fold a single-use index_vector, otherwise we would emit a second one.
+  if (!IV->hasOneUse())
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+  assert(VT.isScalableVector() &&
+         "Only expect scalable vectors for INDEX_VECTOR");
+
+  // Only a constant-zero base is handled; the splatted scalar replaces it.
+  auto *BaseVal = dyn_cast<ConstantSDNode>(IV->getOperand(0));
+  if (BaseVal && BaseVal->isNullValue() &&
+      (A.getOpcode() == AArch64ISD::DUP || A.getOpcode() == ISD::SPLAT_VECTOR))
+    return DAG.getNode(AArch64ISD::INDEX_VECTOR, SDLoc(N), VT, A->getOperand(0),
+                       IV->getOperand(1));
+
+  return SDValue();
+}
+
 static SDValue performAddSubCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI,
                                     SelectionDAG &DAG) {
@@ -13442,6 +13475,10 @@
   if (SDValue Val = performAddDotCombine(N, DAG))
     return Val;
 
+  // Try to fold add(index_vector(0, step), dup(X)) into index_vector(X, step).
+  if (SDValue Val = performAddIndexVectorCombine(N, DAG))
+    return Val;
+
   return performAddSubLongCombine(N, DCI, DAG);
 }
 
diff --git a/llvm/test/CodeGen/AArch64/sve-stepvector.ll b/llvm/test/CodeGen/AArch64/sve-stepvector.ll
--- a/llvm/test/CodeGen/AArch64/sve-stepvector.ll
+++ b/llvm/test/CodeGen/AArch64/sve-stepvector.ll
@@ -131,6 +131,32 @@
   ret <vscale x 8 x i8> %3
 }
 
+define <vscale x 8 x i8> @add_stepvector_nxv8i8_2() {
+; CHECK-LABEL: add_stepvector_nxv8i8_2:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    index z0.h, #2, #1
+; CHECK-NEXT:    ret
+entry:
+  %0 = insertelement <vscale x 8 x i8> poison, i8 2, i32 0
+  %1 = shufflevector <vscale x 8 x i8> %0, <vscale x 8 x i8> poison, <vscale x 8 x i32> zeroinitializer
+  %2 = call <vscale x 8 x i8> @llvm.experimental.stepvector.nxv8i8()
+  %3 = add <vscale x 8 x i8> %2, %1
+  ret <vscale x 8 x i8> %3
+}
+
+define <vscale x 8 x i8> @add_stepvector_nxv8i8_3() {
+; CHECK-LABEL: add_stepvector_nxv8i8_3:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    index z0.h, #2, #1
+; CHECK-NEXT:    ret
+entry:
+  %0 = insertelement <vscale x 8 x i8> poison, i8 2, i32 0
+  %1 = shufflevector <vscale x 8 x i8> %0, <vscale x 8 x i8> poison, <vscale x 8 x i32> zeroinitializer
+  %2 = call <vscale x 8 x i8> @llvm.experimental.stepvector.nxv8i8()
+  %3 = add <vscale x 8 x i8> %1, %2
+  ret <vscale x 8 x i8> %3
+}
+
 define <vscale x 8 x i8> @mul_stepvector_nxv8i8() {
 ; CHECK-LABEL: mul_stepvector_nxv8i8:
 ; CHECK:       // %bb.0: // %entry