diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -594,10 +594,10 @@
 
   /// STEP_VECTOR(IMM) - Returns a scalable vector whose lanes are comprised
   /// of a linear sequence of unsigned values starting from 0 with a step of
-  /// IMM, where IMM must be a vector index constant positive integer value
-  /// which must fit in the vector element type.
+  /// IMM, where IMM must be a vector index constant integer value which must
+  /// fit in the vector element type.
   /// Note that IMM may be a smaller type than the vector element type, in
-  /// which case the step is implicitly zero-extended to the vector element
+  /// which case the step is implicitly sign-extended to the vector element
   /// type. IMM may also be a larger type than the vector element type, in
   /// which case the step is implicitly truncated to the vector element type.
   /// The operation does not support returning fixed-width vectors or
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -3544,6 +3544,14 @@
     return DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getVScale(DL, VT, -IntVal));
   }
 
+  // Canonicalize (sub X, step_vector(C)) to (add X, step_vector(-C)).
+  if (N1.getOpcode() == ISD::STEP_VECTOR) {
+    SDValue NewStep = DAG.getConstant(-N1.getConstantOperandAPInt(0), DL,
+                                      N1.getOperand(0).getValueType());
+    return DAG.getNode(ISD::ADD, DL, VT, N0,
+                       DAG.getStepVector(DL, VT, NewStep));
+  }
+
   // Prefer an add for more folding potential and possibly better codegen:
   // sub N0, (lshr N10, width-1) --> add N0, (ashr N10, width-1)
   if (!LegalOperations && N1.getOpcode() == ISD::SRL && N1.hasOneUse()) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -4791,7 +4791,9 @@
   EVT NOutElemVT = TLI.getTypeToTransformTo(*DAG.getContext(),
                                             NOutVT.getVectorElementType());
   APInt StepVal = cast<ConstantSDNode>(N->getOperand(0))->getAPIntValue();
-  SDValue Step = DAG.getConstant(StepVal.getZExtValue(), dl, NOutElemVT);
+  // Steps may be negative: resize with sign extension, not zero extension.
+  SDValue Step = DAG.getConstant(
+      StepVal.sextOrTrunc(NOutElemVT.getSizeInBits()), dl, NOutElemVT);
   return DAG.getStepVector(dl, NOutVT, Step);
 }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1655,11 +1655,15 @@
 
   // Hi = Lo + (EltCnt * Step)
   EVT EltVT = Step.getValueType();
+  APInt StepVal = cast<ConstantSDNode>(Step)->getAPIntValue();
   SDValue StartOfHi =
-      DAG.getVScale(dl, EltVT,
-                    cast<ConstantSDNode>(Step)->getAPIntValue() *
-                        LoVT.getVectorMinNumElements());
-  StartOfHi = DAG.getZExtOrTrunc(StartOfHi, dl, HiVT.getVectorElementType());
+      DAG.getVScale(dl, EltVT, StepVal * LoVT.getVectorMinNumElements());
+  // A negative step gives a negative start offset for Hi, so it has to be
+  // sign-extended when resized to the element type of HiVT.
+  StartOfHi =
+      StepVal.isNonNegative()
+          ? DAG.getZExtOrTrunc(StartOfHi, dl, HiVT.getVectorElementType())
+          : DAG.getSExtOrTrunc(StartOfHi, dl, HiVT.getVectorElementType());
   StartOfHi = DAG.getNode(ISD::SPLAT_VECTOR, dl, HiVT, StartOfHi);
 
   Hi = DAG.getNode(ISD::STEP_VECTOR, dl, HiVT, Step);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -4717,10 +4717,9 @@
            "STEP_VECTOR can only be used with vectors of integers that are at "
            "least 8 bits wide");
     assert(isa<ConstantSDNode>(Operand) &&
-           cast<ConstantSDNode>(Operand)->getAPIntValue().isNonNegative() &&
            cast<ConstantSDNode>(Operand)->getAPIntValue().isSignedIntN(
                VT.getScalarSizeInBits()) &&
-           "Expected STEP_VECTOR integer constant to be positive and fit in "
+           "Expected STEP_VECTOR integer constant to fit in "
            "the vector element type");
     break;
   case ISD::FREEZE:
diff --git a/llvm/test/CodeGen/AArch64/sve-stepvector.ll b/llvm/test/CodeGen/AArch64/sve-stepvector.ll
--- a/llvm/test/CodeGen/AArch64/sve-stepvector.ll
+++ b/llvm/test/CodeGen/AArch64/sve-stepvector.ll
@@ -259,6 +259,51 @@
   ret %3
 }
 
+define @sub_stepvector_nxv8i16() {
+; CHECK-LABEL: sub_stepvector_nxv8i16:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    index z0.h, #2, #-1
+; CHECK-NEXT:    ret
+entry:
+  %0 = insertelement poison, i16 2, i32 0
+  %1 = shufflevector %0, poison, zeroinitializer
+  %2 = call @llvm.experimental.stepvector.nxv8i16()
+  %3 = sub %1, %2
+  ret %3
+}
+
+define @promote_sub_stepvector_nxv8i8() {
+; CHECK-LABEL: promote_sub_stepvector_nxv8i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    index z0.h, #2, #-1
+; CHECK-NEXT:    ret
+entry:
+  %0 = insertelement poison, i8 2, i32 0
+  %1 = shufflevector %0, poison, zeroinitializer
+  %2 = call @llvm.experimental.stepvector.nxv8i8()
+  %3 = sub %1, %2
+  ret %3
+}
+
+define @split_sub_stepvector_nxv16i32() {
+; CHECK-LABEL: split_sub_stepvector_nxv16i32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    cntw x9
+; CHECK-NEXT:    cnth x8
+; CHECK-NEXT:    neg x9, x9
+; CHECK-NEXT:    index z0.s, #0, #-1
+; CHECK-NEXT:    neg x8, x8
+; CHECK-NEXT:    mov z1.s, w9
+; CHECK-NEXT:    mov z3.s, w8
+; CHECK-NEXT:    add z1.s, z0.s, z1.s
+; CHECK-NEXT:    add z2.s, z0.s, z3.s
+; CHECK-NEXT:    add z3.s, z1.s, z3.s
+; CHECK-NEXT:    ret
+entry:
+  %0 = call @llvm.experimental.stepvector.nxv16i32()
+  %1 = sub zeroinitializer, %0
+  ret %1
+}
 
 declare @llvm.experimental.stepvector.nxv2i64()
 declare @llvm.experimental.stepvector.nxv4i32()