diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h --- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -2075,6 +2075,12 @@ bool isConstant() const; + /// Returns true if this BUILD_VECTOR has at least two operands, all integer + /// constants, whose values (truncated to the element width) form the + /// sequence 0, S, 2*S, 3*S, ... modulo 2^EltSize for some non-zero S; on + /// success \p Stride is set to S, otherwise it is left unmodified. + bool isConstantSequence(APInt &Stride) const; + /// Recast bit data \p SrcBitElements to \p DstEltSizeInBits wide elements. /// Undef elements are treated as zero, and entirely undefined elements are /// flagged in \p DstUndefElements. diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -11517,6 +11517,39 @@ return true; } +bool BuildVectorSDNode::isConstantSequence(APInt &Stride) const { + unsigned NumOps = getNumOperands(); + if (NumOps < 2) + return false; + + if (!isa<ConstantSDNode>(getOperand(0)) || + !isa<ConstantSDNode>(getOperand(1))) + return false; + + unsigned EltSize = getValueType(0).getScalarSizeInBits(); + APInt ExpectedValue = getConstantOperandAPInt(0).truncOrSelf(EltSize); + APInt PossibleStride = getConstantOperandAPInt(1).truncOrSelf(EltSize); + + // Only sequences starting at zero are matched, so that callers can emit + // them as a STEP_VECTOR; a zero stride would describe a splat instead. + if (!ExpectedValue.isZero() || PossibleStride.isZero()) + return false; + + for (const SDValue &Op : op_values()) { + if (!isa<ConstantSDNode>(Op)) + return false; + + APInt Val = cast<ConstantSDNode>(Op)->getAPIntValue().truncOrSelf(EltSize); + if (Val != ExpectedValue) + return false; + + ExpectedValue += PossibleStride; + } + + Stride = PossibleStride; + return true; +} + bool ShuffleVectorSDNode::isSplatMask(const int *Mask, EVT VT) { // Find the first non-undef value in the shuffle mask.
unsigned i, e; diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1571,6 +1571,7 @@ setOperationAction(ISD::ANY_EXTEND, VT, Custom); setOperationAction(ISD::BITCAST, VT, Custom); setOperationAction(ISD::BITREVERSE, VT, Custom); + setOperationAction(ISD::BUILD_VECTOR, VT, Custom); setOperationAction(ISD::BSWAP, VT, Custom); setOperationAction(ISD::CONCAT_VECTORS, VT, Custom); setOperationAction(ISD::CTLZ, VT, Custom); @@ -10930,6 +10931,18 @@ SelectionDAG &DAG) const { EVT VT = Op.getValueType(); + if (useSVEForFixedLengthVectorVT(VT)) { + APInt Stride; + if (cast<BuildVectorSDNode>(Op)->isConstantSequence(Stride)) { + EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT); + SDValue Seq = DAG.getStepVector(SDLoc(Op), ContainerVT, Stride); + return convertFromScalableVector(DAG, VT, Seq); + } + + // Revert to common legalisation for all other variants. + return SDValue(); + } + // Try to build a simple constant vector. Op = NormalizeBuildVector(Op, DAG); if (VT.isInteger()) { @@ -17165,9 +17178,9 @@ // Match: // Index = step(const) int64_t Stride = 0; - if (Index.getOpcode() == ISD::STEP_VECTOR) + if (Index.getOpcode() == ISD::STEP_VECTOR) { Stride = cast<ConstantSDNode>(Index.getOperand(0))->getSExtValue(); - + } // Match: // Index = step(const) << shift(const) else if (Index.getOpcode() == ISD::SHL && @@ -17179,6 +17192,13 @@ Stride = Step << Shift->getZExtValue(); } } + // Match: + // Index = build_vector(0, n, 2n, 3n,...) + else if (auto *BV = dyn_cast<BuildVectorSDNode>(Index)) { + APInt PossibleStride; + if (BV->isConstantSequence(PossibleStride)) + Stride = PossibleStride.getSExtValue(); + } // Return early because no supported pattern is found.
if (Stride == 0) @@ -17202,8 +17222,8 @@ EVT NewIndexVT = IndexVT.changeVectorElementType(MVT::i32); // Stride does not scale explicitly by 'Scale', because it happens in // the gather/scatter addressing mode. - Index = DAG.getNode(ISD::STEP_VECTOR, SDLoc(N), NewIndexVT, - DAG.getTargetConstant(Stride, SDLoc(N), MVT::i32)); + // Stride can be negative, so build the 32-bit APInt as a signed value. + Index = DAG.getStepVector(SDLoc(N), NewIndexVT, APInt(32, Stride, true)); return true; }