diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -3113,27 +3113,36 @@ } // Lower a deinterleave shuffle to vnsrl. -static SDValue getDeinterleaveViaVNSRL(const SDLoc &DL, MVT VT, - MVT ContainerVT, - SDValue Src, bool EvenElts, - SDValue TrueMask, SDValue VL, +// [a, p, b, q, c, r, d, s] -> [a, b, c, d] (EvenElts == true) +// -> [p, q, r, s] (EvenElts == false) +// VT is the type of the vector to return, <[vscale x ]n x ty> +// Src is the vector to deinterleave of type <[vscale x ]n*2 x ty> +static SDValue getDeinterleaveViaVNSRL(const SDLoc &DL, MVT VT, SDValue Src, + bool EvenElts, const RISCVSubtarget &Subtarget, SelectionDAG &DAG) { - // Convert the source using a container type with twice the elements. Since - // source VT is legal and twice this VT, we know VT isn't LMUL=8 so it is - // safe to double. - MVT DoubleContainerVT = - MVT::getVectorVT(ContainerVT.getVectorElementType(), - ContainerVT.getVectorElementCount() * 2); - Src = convertToScalableVector(DoubleContainerVT, Src, DAG, Subtarget); - - // Convert the vector to a wider integer type with the original element - // count. This also converts FP to int. + // The result is a vector of type VT, <[vscale x ]n x ty>. + MVT ContainerVT = VT; + // Convert fixed vectors to scalable if needed + if (ContainerVT.isFixedLengthVector()) { + assert(Src.getSimpleValueType().isFixedLengthVector()); + ContainerVT = getContainerForFixedLengthVector(DAG, ContainerVT, Subtarget); + + // The source container type: ContainerVT with twice the element count. + MVT SrcContainerVT = + MVT::getVectorVT(ContainerVT.getVectorElementType(), + ContainerVT.getVectorElementCount() * 2); + Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget); + } + + auto [TrueMask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget); + + // Bitcast the source from 2K elements of ty to K integers twice as wide, + // where K is ContainerVT's element count. This also converts FP to int.
unsigned EltBits = ContainerVT.getScalarSizeInBits(); - MVT WideIntContainerVT = - MVT::getVectorVT(MVT::getIntegerVT(EltBits * 2), - ContainerVT.getVectorElementCount()); - Src = DAG.getBitcast(WideIntContainerVT, Src); + MVT WideSrcContainerVT = MVT::getVectorVT( + MVT::getIntegerVT(EltBits * 2), ContainerVT.getVectorElementCount()); + Src = DAG.getBitcast(WideSrcContainerVT, Src); // The integer version of the container type. MVT IntContainerVT = ContainerVT.changeVectorElementTypeToInteger(); @@ -3150,7 +3159,9 @@ // Cast back to FP if needed. Res = DAG.getBitcast(ContainerVT, Res); - return convertFromScalableVector(VT, Res, DAG, Subtarget); + if (VT.isFixedLengthVector()) + Res = convertFromScalableVector(VT, Res, DAG, Subtarget); + return Res; } static SDValue @@ -3461,9 +3472,12 @@ return convertFromScalableVector(VT, Res, DAG, Subtarget); } - if (isDeinterleaveShuffle(VT, ContainerVT, V1, V2, Mask, Subtarget)) - return getDeinterleaveViaVNSRL(DL, VT, ContainerVT, V1.getOperand(0), - Mask[0] == 0, TrueMask, VL, Subtarget, DAG); + // If this is a deinterleave and we can widen the vector, then we can use + // vnsrl to deinterleave.
+ if (isDeinterleaveShuffle(VT, ContainerVT, V1, V2, Mask, Subtarget)) { + return getDeinterleaveViaVNSRL(DL, VT, V1.getOperand(0), Mask[0] == 0, + Subtarget, DAG); + } // Detect an interleave shuffle and lower to // (vmaccu.vx (vwaddu.vx lohalf(V1), lohalf(V2)), lohalf(V2), (2^eltbits - 1)) @@ -6619,33 +6633,14 @@ auto [Mask, VL] = getDefaultScalableVLOps(ConcatVT, DL, DAG, Subtarget); SDValue Passthru = DAG.getUNDEF(ConcatVT); - // If the element type is smaller than ELEN, then we can deinterleave - // through vnsrl.wi + // We can deinterleave through vnsrl.wi if the element type is smaller than + // ELEN, as each element pair is read as one integer twice as wide. if (VecVT.getScalarSizeInBits() < Subtarget.getELEN()) { - // Bitcast the concatenated vector from -> - // This is also casts FPs to ints - MVT WideVT = MVT::getVectorVT( - MVT::getIntegerVT(ConcatVT.getScalarSizeInBits() * 2), - ConcatVT.getVectorElementCount().divideCoefficientBy(2)); - SDValue Wide = DAG.getBitcast(WideVT, Concat); - - MVT NarrowVT = VecVT.changeVectorElementTypeToInteger(); - SDValue Passthru = DAG.getUNDEF(VecVT); - - SDValue Even = DAG.getNode( - RISCVISD::VNSRL_VL, DL, NarrowVT, Wide, - DAG.getSplatVector(NarrowVT, DL, DAG.getConstant(0, DL, XLenVT)), - Passthru, Mask, VL); - SDValue Odd = DAG.getNode( - RISCVISD::VNSRL_VL, DL, NarrowVT, Wide, - DAG.getSplatVector( - NarrowVT, DL, - DAG.getConstant(VecVT.getScalarSizeInBits(), DL, XLenVT)), - Passthru, Mask, VL); - - // Bitcast the results back in case it was casted from an FP vector - return DAG.getMergeValues( - {DAG.getBitcast(VecVT, Even), DAG.getBitcast(VecVT, Odd)}, DL); + SDValue Even = + getDeinterleaveViaVNSRL(DL, VecVT, Concat, true, Subtarget, DAG); + SDValue Odd = + getDeinterleaveViaVNSRL(DL, VecVT, Concat, false, Subtarget, DAG); + return DAG.getMergeValues({Even, Odd}, DL); } // For the indices, use the same SEW to avoid an extra vsetvli