diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -2503,6 +2503,31 @@ return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), VS); } + // Fold (add step_vector(c1), step_vector(c2)) to step_vector(c1+c2) + if (N0.getOpcode() == ISD::STEP_VECTOR && + N1.getOpcode() == ISD::STEP_VECTOR) { + const APInt &C0 = N0->getConstantOperandAPInt(0); + const APInt &C1 = N1->getConstantOperandAPInt(0); + EVT SVT = N0.getOperand(0).getValueType(); + SDValue NewStep = DAG.getConstant(C0 + C1, DL, SVT); + return DAG.getStepVector(DL, VT, NewStep); + } + + // Fold a + step_vector(c1) + step_vector(c2) to a + step_vector(c1+c2) + if ((N0.getOpcode() == ISD::ADD) && + (N0.getOperand(1).getOpcode() == ISD::STEP_VECTOR) && + (N1.getOpcode() == ISD::STEP_VECTOR)) { + const APInt &SV0 = N0.getOperand(1)->getConstantOperandAPInt(0); + const APInt &SV1 = N1->getConstantOperandAPInt(0); + EVT SVT = N1.getOperand(0).getValueType(); + assert(N1.getOperand(0).getValueType() == + N0.getOperand(1)->getOperand(0).getValueType() && + "Different operand types of STEP_VECTOR."); + SDValue NewStep = DAG.getConstant(SV0 + SV1, DL, SVT); + SDValue SV = DAG.getStepVector(DL, VT, NewStep); + return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), SV); + } + return SDValue(); } @@ -3893,6 +3918,17 @@ return DAG.getVScale(SDLoc(N), VT, C0 * C1); } + // Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)). 
+ APInt MulVal; + if (N0.getOpcode() == ISD::STEP_VECTOR) + if (ISD::isConstantSplatVector(N1.getNode(), MulVal)) { + const APInt &C0 = N0.getConstantOperandAPInt(0); + EVT SVT = N0.getOperand(0).getValueType(); + SDValue NewStep = DAG.getConstant( + C0 * MulVal.sextOrTrunc(SVT.getSizeInBits()), SDLoc(N), SVT); + return DAG.getStepVector(SDLoc(N), VT, NewStep); + } + // Fold ((mul x, 0/undef) -> 0, // (mul x, 1) -> x) -> x) // -> and(x, mask) @@ -8381,6 +8417,17 @@ return DAG.getVScale(SDLoc(N), VT, C0 << C1); } + // Fold (shl step_vector(C0), C1) to (step_vector(C0 << C1)). + APInt ShlVal; + if (N0.getOpcode() == ISD::STEP_VECTOR) + if (ISD::isConstantSplatVector(N1.getNode(), ShlVal)) { + const APInt &C0 = N0.getConstantOperandAPInt(0); + EVT SVT = N0.getOperand(0).getValueType(); + SDValue NewStep = DAG.getConstant( + C0 << ShlVal.sextOrTrunc(SVT.getSizeInBits()), SDLoc(N), SVT); + return DAG.getStepVector(SDLoc(N), VT, NewStep); + } + return SDValue(); } diff --git a/llvm/test/CodeGen/AArch64/sve-stepvector.ll b/llvm/test/CodeGen/AArch64/sve-stepvector.ll --- a/llvm/test/CodeGen/AArch64/sve-stepvector.ll +++ b/llvm/test/CodeGen/AArch64/sve-stepvector.ll @@ -105,6 +105,59 @@ ret %0 } +define <vscale x 8 x i8> @add_stepvector_nxv8i8() { +; CHECK-LABEL: add_stepvector_nxv8i8: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: index z0.h, #0, #2 +; CHECK-NEXT: ret +entry: + %0 = call <vscale x 8 x i8> @llvm.experimental.stepvector.nxv8i8() + %1 = call <vscale x 8 x i8> @llvm.experimental.stepvector.nxv8i8() + %2 = add <vscale x 8 x i8> %0, %1 + ret <vscale x 8 x i8> %2 +} + +define <vscale x 8 x i8> @add_stepvector_nxv8i8_1(<vscale x 8 x i8> %p) { +; CHECK-LABEL: add_stepvector_nxv8i8_1: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: index z1.h, #0, #2 +; CHECK-NEXT: add z0.h, z0.h, z1.h +; CHECK-NEXT: ret +entry: + %0 = call <vscale x 8 x i8> @llvm.experimental.stepvector.nxv8i8() + %1 = add <vscale x 8 x i8> %p, %0 + %2 = call <vscale x 8 x i8> @llvm.experimental.stepvector.nxv8i8() + %3 = add <vscale x 8 x i8> %1, %2 + ret <vscale x 8 x i8> %3 +} + +define <vscale x 8 x i8> @mul_stepvector_nxv8i8() { +; CHECK-LABEL: mul_stepvector_nxv8i8: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: index z0.h, #0, 
#2 +; CHECK-NEXT: ret +entry: + %0 = insertelement <vscale x 8 x i8> poison, i8 2, i32 0 + %1 = shufflevector <vscale x 8 x i8> %0, <vscale x 8 x i8> poison, <vscale x 8 x i32> zeroinitializer + %2 = call <vscale x 8 x i8> @llvm.experimental.stepvector.nxv8i8() + %3 = mul <vscale x 8 x i8> %2, %1 + ret <vscale x 8 x i8> %3 +} + +define <vscale x 8 x i8> @shl_stepvector_nxv8i8() { +; CHECK-LABEL: shl_stepvector_nxv8i8: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: index z0.h, #0, #4 +; CHECK-NEXT: ret +entry: + %0 = insertelement <vscale x 8 x i8> poison, i8 2, i32 0 + %1 = shufflevector <vscale x 8 x i8> %0, <vscale x 8 x i8> poison, <vscale x 8 x i32> zeroinitializer + %2 = call <vscale x 8 x i8> @llvm.experimental.stepvector.nxv8i8() + %3 = shl <vscale x 8 x i8> %2, %1 + ret <vscale x 8 x i8> %3 +} + + declare <vscale x 2 x i64> @llvm.experimental.stepvector.nxv2i64() declare <vscale x 4 x i32> @llvm.experimental.stepvector.nxv4i32() declare <vscale x 8 x i16> @llvm.experimental.stepvector.nxv8i16()