diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -2950,10 +2950,6 @@
 static bool isDeinterleaveShuffle(MVT VT, MVT ContainerVT, SDValue V1,
                                   SDValue V2, ArrayRef<int> Mask,
                                   const RISCVSubtarget &Subtarget) {
-  // Need to be able to widen the vector.
-  if (VT.getScalarSizeInBits() >= Subtarget.getELEN())
-    return false;
-
   // Both input must be extracts.
   if (V1.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
       V2.getOpcode() != ISD::EXTRACT_SUBVECTOR)
@@ -3116,27 +3112,41 @@
 }
 
 // Lower a deinterleave shuffle to vnsrl.
-static SDValue getDeinterleaveViaVNSRL(const SDLoc &DL, MVT VT,
-                                       MVT ContainerVT,
-                                       SDValue Src, bool EvenElts,
-                                       SDValue TrueMask, SDValue VL,
+// [a, p, b, q, c, r, d, s] -> [a, b, c, d] (EvenElts == true)
+//                          -> [p, q, r, s] (EvenElts == false)
+// VT is the type of the vector to return, <[vscale x ]n x ty>
+// Src is the vector to deinterleave of type <[vscale x ]n*2 x ty>
+static SDValue getDeinterleaveViaVNSRL(const SDLoc &DL, MVT VT, SDValue Src,
+                                       bool EvenElts,
                                        const RISCVSubtarget &Subtarget,
                                        SelectionDAG &DAG) {
-  // Convert the source using a container type with twice the elements. Since
-  // source VT is legal and twice this VT, we know VT isn't LMUL=8 so it is
-  // safe to double.
-  MVT DoubleContainerVT =
-      MVT::getVectorVT(ContainerVT.getVectorElementType(),
-                       ContainerVT.getVectorElementCount() * 2);
-  Src = convertToScalableVector(DoubleContainerVT, Src, DAG, Subtarget);
-
-  // Convert the vector to a wider integer type with the original element
-  // count. This also converts FP to int.
-  unsigned EltBits = ContainerVT.getScalarSizeInBits();
-  MVT WideIntContainerVT =
-      MVT::getVectorVT(MVT::getIntegerVT(EltBits * 2),
-                       ContainerVT.getVectorElementCount());
-  Src = DAG.getBitcast(WideIntContainerVT, Src);
+  // Need to be able to widen the vector; otherwise return a null SDValue.
+  if (VT.getScalarSizeInBits() >= Subtarget.getELEN())
+    return SDValue();
+
+  // The result is a vector of type <[vscale x ]n x ty>
+  MVT ContainerVT = VT;
+  // Convert fixed vectors to scalable if needed
+  if (ContainerVT.isFixedLengthVector())
+    ContainerVT = getContainerForFixedLengthVector(DAG, ContainerVT, Subtarget);
+
+  auto [TrueMask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
+
+  // The (minimum) number of elements in the result vector
+  ElementCount N = ContainerVT.getVectorElementCount();
+
+  // The source is a vector of type <[vscale x ]n*2 x ty>
+  MVT SrcContainerVT =
+      MVT::getVectorVT(ContainerVT.getVectorElementType(), N * 2);
+  // Make the source scalable if needed
+  if (Src.getSimpleValueType().isFixedLengthVector())
+    Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget);
+
+  // Bitcast the source vector from <[vscale x ]n*2 x ty> to
+  // <[vscale x ]n x ty*2>. This also converts FP to int.
+  unsigned EltBits = SrcContainerVT.getScalarSizeInBits();
+  MVT WideSrcContainerVT = MVT::getVectorVT(MVT::getIntegerVT(EltBits * 2), N);
+  Src = DAG.getBitcast(WideSrcContainerVT, Src);
 
   // The integer version of the container type.
   MVT IntContainerVT = ContainerVT.changeVectorElementTypeToInteger();
@@ -3153,7 +3163,9 @@
   // Cast back to FP if needed.
   Res = DAG.getBitcast(ContainerVT, Res);
 
-  return convertFromScalableVector(VT, Res, DAG, Subtarget);
+  if (VT.isFixedLengthVector())
+    Res = convertFromScalableVector(VT, Res, DAG, Subtarget);
+  return Res;
 }
 
 static SDValue
@@ -3465,8 +3477,9 @@
   }
 
   if (isDeinterleaveShuffle(VT, ContainerVT, V1, V2, Mask, Subtarget))
-    return getDeinterleaveViaVNSRL(DL, VT, ContainerVT, V1.getOperand(0),
-                                   Mask[0] == 0, TrueMask, VL, Subtarget, DAG);
+    if (SDValue Deinterleave = getDeinterleaveViaVNSRL(
+            DL, VT, V1.getOperand(0), Mask[0] == 0, Subtarget, DAG))
+      return Deinterleave;
 
   // Detect an interleave shuffle and lower to
   // (vmaccu.vx (vwaddu.vx lohalf(V1), lohalf(V2)), lohalf(V2), (2^eltbits - 1))
@@ -6622,33 +6635,15 @@
   auto [Mask, VL] = getDefaultScalableVLOps(ConcatVT, DL, DAG, Subtarget);
   SDValue Passthru = DAG.getUNDEF(ConcatVT);
 
-  // If the element type is smaller than ELEN, then we can deinterleave
-  // through vnsrl.wi
-  if (VecVT.getScalarSizeInBits() < Subtarget.getELEN()) {
-    // Bitcast the concatenated vector from <n x m x ty> -> <n x m/2 x ty*2>
-    // This is also casts FPs to ints
-    MVT WideVT = MVT::getVectorVT(
-        MVT::getIntegerVT(ConcatVT.getScalarSizeInBits() * 2),
-        ConcatVT.getVectorElementCount().divideCoefficientBy(2));
-    SDValue Wide = DAG.getBitcast(WideVT, Concat);
-
-    MVT NarrowVT = VecVT.changeVectorElementTypeToInteger();
-    SDValue Passthru = DAG.getUNDEF(VecVT);
-
-    SDValue Even = DAG.getNode(
-        RISCVISD::VNSRL_VL, DL, NarrowVT, Wide,
-        DAG.getSplatVector(NarrowVT, DL, DAG.getConstant(0, DL, XLenVT)),
-        Passthru, Mask, VL);
-    SDValue Odd = DAG.getNode(
-        RISCVISD::VNSRL_VL, DL, NarrowVT, Wide,
-        DAG.getSplatVector(
-            NarrowVT, DL,
-            DAG.getConstant(VecVT.getScalarSizeInBits(), DL, XLenVT)),
-        Passthru, Mask, VL);
-
-    // Bitcast the results back in case it was casted from an FP vector
-    return DAG.getMergeValues(
-        {DAG.getBitcast(VecVT, Even), DAG.getBitcast(VecVT, Odd)}, DL);
+  // getDeinterleaveViaVNSRL returns a null SDValue when the element type is
+  // too wide (SEW >= ELEN) to deinterleave through vnsrl.wi.
+  SDValue Even =
+      getDeinterleaveViaVNSRL(DL, VecVT, Concat, true, Subtarget, DAG);
+  if (Even) {
+    SDValue Odd =
+        getDeinterleaveViaVNSRL(DL, VecVT, Concat, false, Subtarget, DAG);
+    assert(Odd && "Odd deinterleave should succeed whenever even does");
+    return DAG.getMergeValues({Even, Odd}, DL);
   }
 
   // For the indices, use the same SEW to avoid an extra vsetvli