diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -2729,6 +2729,82 @@
   return splatSplitI64WithVL(DL, VT, Passthru, Scalar, VL, DAG);
 }
 
+// If the vector_shuffle sources form a single contiguous vector starting at
+// index 0, return that vector. Otherwise, return SDValue().
+static SDValue getSingleShuffleSource(SDValue N0, SDValue N1,
+                                      ArrayRef<int> Mask) {
+  if (N1.isUndef()) {
+    if (N0.getOpcode() != ISD::EXTRACT_SUBVECTOR)
+      return N0;
+    return N0.getOperand(0);
+  }
+
+  // Both inputs must be extracts.
+  if (N0.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+      N1.getOpcode() != ISD::EXTRACT_SUBVECTOR)
+    return SDValue();
+
+  // Extracting from the same source.
+  SDValue Src = N0.getOperand(0);
+  if (Src != N1.getOperand(0))
+    return SDValue();
+
+  // Make sure N0 and N1 are contiguous halves of Src.
+  if (N0.getConstantOperandVal(1) != 0 ||
+      N1.getConstantOperandVal(1) != Mask.size())
+    return SDValue();
+
+  return Src;
+}
+
+// The mask must have the following form:
+// X X X X ... -1 -1 -1 ..., where X is not -1.
+// The X values must be in ascending order.
+// For example,
+// [0, 2, 4, 6] -> vnsrl src, 0
+// [0, 4, 8, 12] -> vnsrl (vnsrl src, 0), 0
+// [3, 7, 11, 15] -> vnsrl (vnsrl src, EltSize * 2), EltSize
+// [2, 10, 18, 26] -> vnsrl (vnsrl (vnsrl src, 0), EltSize * 2), 0
+// In addition, N0 and N1 must come from the same vector (or N1 is undef).
+static SDValue isVnsrlShuffle(SDValue N0, SDValue N1, ArrayRef<int> Mask,
+                              EVT VT, const RISCVSubtarget &Subtarget) {
+  SDValue Src = getSingleShuffleSource(N0, N1, Mask);
+  if (!Src)
+    return SDValue();
+  if (Mask.size() < 2)
+    return SDValue();
+  // Find the first -1 and make sure everything after it is also -1.
+  auto FirstUndef = find(Mask, -1);
+  if (std::any_of(FirstUndef, Mask.end(),
+                  [](int MaskIdx) { return MaskIdx != -1; }))
+    return SDValue();
+  ptrdiff_t ValidMaskEnd = std::distance(Mask.begin(), FirstUndef);
+  // Do not convert to vnsrl if there are fewer than two non-undef elements.
+  if (ValidMaskEnd < 2)
+    return SDValue();
+  int Difference = Mask[1] - Mask[0];
+  if (Difference <= Mask[0])
+    return SDValue();
+  // Use vslidedown instead when Difference is 1.
+  if (Difference == 1 || !isPowerOf2_32(Difference))
+    return SDValue();
+  // Make sure it is a narrowing shuffle; Difference determines the scaling.
+  if (Src.getSimpleValueType().getVectorElementCount().divideCoefficientBy(
+          Difference) != VT.getVectorElementCount())
+    return SDValue();
+  unsigned EltSize = VT.getScalarSizeInBits();
+  // The smallest type for vnsrl is i8.
+  if (EltSize < 8)
+    return SDValue();
+  // vnsrl will be used, so the widened element type must not exceed ELEN.
+  if (Subtarget.getELEN() < Difference * EltSize)
+    return SDValue();
+  for (ptrdiff_t i = 2; i != ValidMaskEnd; ++i)
+    if (Mask[i - 1] + Difference != Mask[i])
+      return SDValue();
+  return Src;
+}
+
 static bool isInterleaveShuffle(ArrayRef<int> Mask, MVT VT, bool &SwapSources,
                                 const RISCVSubtarget &Subtarget) {
   // We need to be able to widen elements to the next larger integer type.
@@ -2844,84 +2920,55 @@
   return Rotation;
 }
 
-// Lower the following shuffles to vnsrl.
-// t34: v8i8 = extract_subvector t11, Constant:i64<0>
-// t33: v8i8 = extract_subvector t11, Constant:i64<8>
-// a) t35: v8i8 = vector_shuffle<0,2,4,6,8,10,12,14> t34, t33
-// b) t35: v8i8 = vector_shuffle<1,3,5,7,9,11,13,15> t34, t33
-static SDValue lowerVECTOR_SHUFFLEAsVNSRL(const SDLoc &DL, MVT VT,
-                                          MVT ContainerVT, SDValue V1,
-                                          SDValue V2, SDValue TrueMask,
-                                          SDValue VL, ArrayRef<int> Mask,
+// Lower any pattern that matches isVnsrlShuffle.
+static SDValue lowerVECTOR_SHUFFLEAsVNSRL(const SDLoc &DL, MVT VT, SDValue V1,
+                                          SDValue V2, ArrayRef<int> Mask,
                                           const RISCVSubtarget &Subtarget,
                                           SelectionDAG &DAG) {
-  // Need to be able to widen the vector.
-  if (VT.getScalarSizeInBits() >= Subtarget.getELEN())
+  SDValue Src = isVnsrlShuffle(V1, V2, Mask, VT, Subtarget);
+  if (!Src)
     return SDValue();
-
-  // Both input must be extracts.
-  if (V1.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
-      V2.getOpcode() != ISD::EXTRACT_SUBVECTOR)
-    return SDValue();
-
-  // Extracting from the same source.
-  SDValue Src = V1.getOperand(0);
-  if (Src != V2.getOperand(0))
-    return SDValue();
-
-  // Src needs to have twice the number of elements.
-  if (Src.getValueType().getVectorNumElements() != (Mask.size() * 2))
-    return SDValue();
-
-  // The extracts must extract the two halves of the source.
-  if (V1.getConstantOperandVal(1) != 0 ||
-      V2.getConstantOperandVal(1) != Mask.size())
-    return SDValue();
-
-  // First index must be the first even or odd element from V1.
-  if (Mask[0] != 0 && Mask[0] != 1)
-    return SDValue();
-
-  // The others must increase by 2 each time.
-  // TODO: Support undef elements?
-  for (unsigned i = 1; i != Mask.size(); ++i)
-    if (Mask[i] != Mask[i - 1] + 2)
-      return SDValue();
-
-  // Convert the source using a container type with twice the elements. Since
-  // source VT is legal and twice this VT, we know VT isn't LMUL=8 so it is
-  // safe to double.
-  MVT DoubleContainerVT =
-      MVT::getVectorVT(ContainerVT.getVectorElementType(),
-                       ContainerVT.getVectorElementCount() * 2);
-  Src = convertToScalableVector(DoubleContainerVT, Src, DAG, Subtarget);
-
-  // Convert the vector to a wider integer type with the original element
-  // count. This also converts FP to int.
-  unsigned EltBits = ContainerVT.getScalarSizeInBits();
-  MVT WideIntEltVT = MVT::getIntegerVT(EltBits * 2);
-  MVT WideIntContainerVT =
-      MVT::getVectorVT(WideIntEltVT, ContainerVT.getVectorElementCount());
-  Src = DAG.getBitcast(WideIntContainerVT, Src);
-
-  // Convert to the integer version of the container type.
-  MVT IntEltVT = MVT::getIntegerVT(EltBits);
-  MVT IntContainerVT =
-      MVT::getVectorVT(IntEltVT, ContainerVT.getVectorElementCount());
-
-  // If we want even elements, then the shift amount is 0. Otherwise, shift by
-  // the original element size.
-  unsigned Shift = Mask[0] == 0 ? 0 : EltBits;
-  SDValue SplatShift = DAG.getNode(
-      RISCVISD::VMV_V_X_VL, DL, IntContainerVT, DAG.getUNDEF(ContainerVT),
-      DAG.getConstant(Shift, DL, Subtarget.getXLenVT()), VL);
-  SDValue Res =
-      DAG.getNode(RISCVISD::VNSRL_VL, DL, IntContainerVT, Src, SplatShift,
-                  DAG.getUNDEF(IntContainerVT), TrueMask, VL);
+  MVT SrcVT = Src.getSimpleValueType();
+  unsigned EltSizeInBits = VT.getScalarSizeInBits();
+  unsigned Difference = Mask[1] - Mask[0];
+
+  // We use SrcVT.getVectorNumElements() instead of VT.getVectorNumElements()
+  // because Src may come from an extract_subvector (and is then twice as
+  // long as VT).
+  MVT WidenVT = MVT::getVectorVT(MVT::getIntegerVT(EltSizeInBits * Difference),
+                                 SrcVT.getVectorNumElements() / Difference);
+  MVT ContainerWidenVT =
+      getContainerForFixedLengthVector(DAG, WidenVT, Subtarget);
+  // Bitcast first, then convert to a scalable vector.
+  SDValue WidenSrc = DAG.getBitcast(WidenVT, Src);
+  WidenSrc =
+      convertToScalableVector(ContainerWidenVT, WidenSrc, DAG, Subtarget);
+
+  // TODO: Shrink VL when the mask has an undef tail for better performance.
+  auto [TrueMask, VL] =
+      getDefaultVLOps(WidenVT, ContainerWidenVT, DL, DAG, Subtarget);
+
+  for (unsigned i = Difference; i != 1; i >>= 1) {
+    MVT HalfEltSizeContainerWidenVT =
+        MVT::getVectorVT(MVT::getIntegerVT(EltSizeInBits * i / 2),
+                         WidenSrc.getSimpleValueType().getVectorElementCount());
+    // If Difference is 4 and Mask[0] is 2 (0b10), we have to take the upper
+    // half first, then the lower half.
+    unsigned Shift = (Mask[0] & (i >> 1)) ? EltSizeInBits * i / 2 : 0;
+    SDValue SplatShift =
+        DAG.getNode(RISCVISD::VMV_V_X_VL, DL, HalfEltSizeContainerWidenVT,
+                    DAG.getUNDEF(HalfEltSizeContainerWidenVT),
+                    DAG.getConstant(Shift, DL, Subtarget.getXLenVT()), VL);
+    WidenSrc = DAG.getNode(
+        RISCVISD::VNSRL_VL, DL, HalfEltSizeContainerWidenVT, WidenSrc,
+        SplatShift, DAG.getUNDEF(HalfEltSizeContainerWidenVT), TrueMask, VL);
+  }
+
+  // In reverse order: convert from a scalable vector first, then bitcast.
+  SDValue Res = convertFromScalableVector(VT.changeVectorElementTypeToInteger(),
+                                          WidenSrc, DAG, Subtarget);
   // Cast back to FP if needed.
-  Res = DAG.getBitcast(ContainerVT, Res);
-
-  return convertFromScalableVector(VT, Res, DAG, Subtarget);
+  return DAG.getBitcast(VT, Res);
 }
 
 // Lower the following shuffle to vslidedown.
@@ -3107,8 +3154,8 @@
     return convertFromScalableVector(VT, Res, DAG, Subtarget);
   }
 
-  if (SDValue V = lowerVECTOR_SHUFFLEAsVNSRL(
-          DL, VT, ContainerVT, V1, V2, TrueMask, VL, Mask, Subtarget, DAG))
+  if (SDValue V =
+          lowerVECTOR_SHUFFLEAsVNSRL(DL, VT, V1, V2, Mask, Subtarget, DAG))
     return V;
 
   // Detect an interleave shuffle and lower to
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-shufflevector-vnsrl.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-shufflevector-vnsrl.ll
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-shufflevector-vnsrl.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-shufflevector-vnsrl.ll
@@ -372,18 +372,8 @@
 ; CHECK-NEXT:    vsetivli zero, 16, e8, mf2, ta, ma
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vsetivli zero, 8, e8, mf4, ta, ma
-; CHECK-NEXT:    vid.v v9
-; CHECK-NEXT:    vadd.vv v9, v9, v9
-; CHECK-NEXT:    vadd.vi v10, v9, 1
-; CHECK-NEXT:    vrgather.vv v11, v8, v10
-; CHECK-NEXT:    li a0, 112
-; CHECK-NEXT:    vmv.s.x v0, a0
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT:    vslidedown.vi v8, v8, 8
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf4, ta, mu
-; CHECK-NEXT:    vadd.vi v9, v9, -7
-; CHECK-NEXT:    vrgather.vv v11, v8, v9, v0.t
-; CHECK-NEXT:    vse8.v v11, (a1)
+; CHECK-NEXT:    vnsrl.wi v8, v8, 8
+; CHECK-NEXT:    vse8.v v8, (a1)
 ; CHECK-NEXT:    ret
 entry:
   %0 = load <16 x i8>, ptr %in, align 1
@@ -437,16 +427,13 @@
 ; V-NEXT:    li a2, 32
 ; V-NEXT:    vsetvli zero, a2, e8, m1, ta, ma
 ; V-NEXT:    vle8.v v8, (a0)
-; V-NEXT:    li a0, 2
-; V-NEXT:    vsetivli zero, 1, e8, mf8, ta, ma
-; V-NEXT:    vmv.s.x v0, a0
-; V-NEXT:    vsetivli zero, 8, e8, m1, ta, ma
-; V-NEXT:    vslidedown.vi v9, v8, 8
-; V-NEXT:    vsetivli zero, 8, e8, mf4, ta, mu
-; V-NEXT:    vrgather.vi v10, v8, 5
-; V-NEXT:    vrgather.vi v10, v9, 5, v0.t
-; V-NEXT:    vsetivli zero, 4, e8, mf8, ta, ma
-; V-NEXT:    vse8.v v10, (a1)
+; V-NEXT:    vsetivli zero, 4, e32, mf2, ta, ma
+; V-NEXT:    vnsrl.wx v8, v8, a2
+; V-NEXT:    vsetvli zero, zero, e16, mf4, ta, ma
+; V-NEXT:    vnsrl.wi v8, v8, 0
+; V-NEXT:    vsetvli zero, zero, e8, mf8, ta, ma
+; V-NEXT:    vnsrl.wi v8, v8, 8
+; V-NEXT:    vse8.v v8, (a1)
 ; V-NEXT:    ret
 ;
 ; ZVE32F-LABEL: vnsrl_8_undef_i8:
@@ -473,17 +460,25 @@
 }
 
 define void @vnsrl_2_undef_i16(ptr %in, ptr %out) {
-; CHECK-LABEL: vnsrl_2_undef_i16:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 16, e16, m1, ta, ma
-; CHECK-NEXT:    vle16.v v8, (a0)
-; CHECK-NEXT:    vsetivli zero, 8, e16, mf2, ta, ma
-; CHECK-NEXT:    vid.v v9
-; CHECK-NEXT:    vadd.vv v9, v9, v9
-; CHECK-NEXT:    vadd.vi v9, v9, 1
-; CHECK-NEXT:    vrgather.vv v10, v8, v9
-; CHECK-NEXT:    vse16.v v10, (a1)
-; CHECK-NEXT:    ret
+; V-LABEL: vnsrl_2_undef_i16:
+; V:       # %bb.0: # %entry
+; V-NEXT:    vsetivli zero, 16, e16, m1, ta, ma
+; V-NEXT:    vle16.v v8, (a0)
+; V-NEXT:    vsetivli zero, 4, e16, mf4, ta, ma
+; V-NEXT:    vnsrl.wi v8, v8, 16
+; V-NEXT:    vsetivli zero, 8, e16, mf2, ta, ma
+; V-NEXT:    vse16.v v8, (a1)
+; V-NEXT:    ret
+;
+; ZVE32F-LABEL: vnsrl_2_undef_i16:
+; ZVE32F:       # %bb.0: # %entry
+; ZVE32F-NEXT:    vsetivli zero, 16, e16, m1, ta, ma
+; ZVE32F-NEXT:    vle16.v v8, (a0)
+; ZVE32F-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; ZVE32F-NEXT:    vnsrl.wi v8, v8, 16
+; ZVE32F-NEXT:    vsetivli zero, 8, e16, mf2, ta, ma
+; ZVE32F-NEXT:    vse16.v v8, (a1)
+; ZVE32F-NEXT:    ret
 entry:
   %0 = load <16 x i16>, ptr %in, align 2
   %1 = shufflevector <16 x i16> %0, <16 x i16> poison, <8 x i32>
@@ -496,19 +491,11 @@
 ; V:       # %bb.0: # %entry
 ; V-NEXT:    vsetivli zero, 16, e16, m1, ta, ma
 ; V-NEXT:    vle16.v v8, (a0)
-; V-NEXT:    vsetivli zero, 8, e16, mf2, ta, ma
-; V-NEXT:    vid.v v9
-; V-NEXT:    vsll.vi v9, v9, 2
-; V-NEXT:    vadd.vi v9, v9, 1
-; V-NEXT:    vrgather.vv v10, v8, v9
-; V-NEXT:    li a0, 4
-; V-NEXT:    vmv.s.x v0, a0
-; V-NEXT:    vsetivli zero, 8, e16, m1, ta, ma
-; V-NEXT:    vslidedown.vi v8, v8, 8
-; V-NEXT:    vsetivli zero, 8, e16, mf2, ta, mu
-; V-NEXT:    vrgather.vi v10, v8, 1, v0.t
-; V-NEXT:    vsetivli zero, 4, e16, mf4, ta, ma
-; V-NEXT:    vse16.v v10, (a1)
+; V-NEXT:    vsetivli zero, 4, e32, mf2, ta, ma
+; V-NEXT:    vnsrl.wi v8, v8, 0
+; V-NEXT:    vsetvli zero, zero, e16, mf4, ta, ma
+; V-NEXT:    vnsrl.wi v8, v8, 16
+; V-NEXT:    vse16.v v8, (a1)
 ; V-NEXT:    ret
 ;
 ; ZVE32F-LABEL: vnsrl_4_undef_i16:
@@ -537,17 +524,28 @@
 }
 
 define void @vnsrl_2_undef_i32(ptr %in, ptr %out) {
-; CHECK-LABEL: vnsrl_2_undef_i32:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 16, e32, m2, ta, ma
-; CHECK-NEXT:    vle32.v v8, (a0)
-; CHECK-NEXT:    vsetivli zero, 8, e32, m1, ta, ma
-; CHECK-NEXT:    vid.v v10
-; CHECK-NEXT:    vadd.vv v10, v10, v10
-; CHECK-NEXT:    vadd.vi v10, v10, 1
-; CHECK-NEXT:    vrgather.vv v11, v8, v10
-; CHECK-NEXT:    vse32.v v11, (a1)
-; CHECK-NEXT:    ret
+; V-LABEL: vnsrl_2_undef_i32:
+; V:       # %bb.0: # %entry
+; V-NEXT:    vsetivli zero, 16, e32, m2, ta, ma
+; V-NEXT:    vle32.v v8, (a0)
+; V-NEXT:    li a0, 32
+; V-NEXT:    vsetivli zero, 4, e32, mf2, ta, ma
+; V-NEXT:    vnsrl.wx v8, v8, a0
+; V-NEXT:    vsetivli zero, 8, e32, m1, ta, ma
+; V-NEXT:    vse32.v v8, (a1)
+; V-NEXT:    ret
+;
+; ZVE32F-LABEL: vnsrl_2_undef_i32:
+; ZVE32F:       # %bb.0: # %entry
+; ZVE32F-NEXT:    vsetivli zero, 16, e32, m2, ta, ma
+; ZVE32F-NEXT:    vle32.v v8, (a0)
+; ZVE32F-NEXT:    vsetivli zero, 8, e32, m1, ta, ma
+; ZVE32F-NEXT:    vid.v v10
+; ZVE32F-NEXT:    vadd.vv v10, v10, v10
+; ZVE32F-NEXT:    vadd.vi v10, v10, 1
+; ZVE32F-NEXT:    vrgather.vv v11, v8, v10
+; ZVE32F-NEXT:    vse32.v v11, (a1)
+; ZVE32F-NEXT:    ret
 entry:
   %0 = load <16 x i32>, ptr %in, align 4
   %1 = shufflevector <16 x i32> %0, <16 x i32> poison, <8 x i32>
@@ -561,18 +559,8 @@
 ; CHECK-NEXT:    vsetivli zero, 16, e16, m1, ta, ma
 ; CHECK-NEXT:    vle16.v v8, (a0)
 ; CHECK-NEXT:    vsetivli zero, 8, e16, mf2, ta, ma
-; CHECK-NEXT:    vid.v v9
-; CHECK-NEXT:    vadd.vv v9, v9, v9
-; CHECK-NEXT:    vadd.vi v10, v9, 1
-; CHECK-NEXT:    vrgather.vv v11, v8, v10
-; CHECK-NEXT:    li a0, 112
-; CHECK-NEXT:    vmv.s.x v0, a0
-; CHECK-NEXT:    vsetivli zero, 8, e16, m1, ta, ma
-; CHECK-NEXT:    vslidedown.vi v8, v8, 8
-; CHECK-NEXT:    vsetivli zero, 8, e16, mf2, ta, mu
-; CHECK-NEXT:    vadd.vi v9, v9, -7
-; CHECK-NEXT:    vrgather.vv v11, v8, v9, v0.t
-; CHECK-NEXT:    vse16.v v11, (a1)
+; CHECK-NEXT:    vnsrl.wi v8, v8, 16
+; CHECK-NEXT:    vse16.v v8, (a1)
 ; CHECK-NEXT:    ret
 entry:
   %0 = load <16 x half>, ptr %in, align 2
@@ -621,17 +609,28 @@
 }
 
 define void @vnsrl_2_undef_float(ptr %in, ptr %out) {
-; CHECK-LABEL: vnsrl_2_undef_float:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 16, e32, m2, ta, ma
-; CHECK-NEXT:    vle32.v v8, (a0)
-; CHECK-NEXT:    vsetivli zero, 8, e32, m1, ta, ma
-; CHECK-NEXT:    vid.v v10
-; CHECK-NEXT:    vadd.vv v10, v10, v10
-; CHECK-NEXT:    vadd.vi v10, v10, 1
-; CHECK-NEXT:    vrgather.vv v11, v8, v10
-; CHECK-NEXT:    vse32.v v11, (a1)
-; CHECK-NEXT:    ret
+; V-LABEL: vnsrl_2_undef_float:
+; V:       # %bb.0: # %entry
+; V-NEXT:    vsetivli zero, 16, e32, m2, ta, ma
+; V-NEXT:    vle32.v v8, (a0)
+; V-NEXT:    li a0, 32
+; V-NEXT:    vsetivli zero, 4, e32, mf2, ta, ma
+; V-NEXT:    vnsrl.wx v8, v8, a0
+; V-NEXT:    vsetivli zero, 8, e32, m1, ta, ma
+; V-NEXT:    vse32.v v8, (a1)
+; V-NEXT:    ret
+;
+; ZVE32F-LABEL: vnsrl_2_undef_float:
+; ZVE32F:       # %bb.0: # %entry
+; ZVE32F-NEXT:    vsetivli zero, 16, e32, m2, ta, ma
+; ZVE32F-NEXT:    vle32.v v8, (a0)
+; ZVE32F-NEXT:    vsetivli zero, 8, e32, m1, ta, ma
+; ZVE32F-NEXT:    vid.v v10
+; ZVE32F-NEXT:    vadd.vv v10, v10, v10
+; ZVE32F-NEXT:    vadd.vi v10, v10, 1
+; ZVE32F-NEXT:    vrgather.vv v11, v8, v10
+; ZVE32F-NEXT:    vse32.v v11, (a1)
+; ZVE32F-NEXT:    ret
 entry:
   %0 = load <16 x float>, ptr %in, align 4
   %1 = shufflevector <16 x float> %0, <16 x float> poison, <8 x i32>
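
Not part of the patch: the shift-amount selection in the new lowering loop is the least obvious piece, so below is a minimal standalone C++ sketch of it. The helper printVnsrlChain and its parameters are invented for illustration; only the expression (Mask[0] & (i >> 1)) ? EltSizeInBits * i / 2 : 0 mirrors the patched lowerVECTOR_SHUFFLEAsVNSRL, and its output can be checked by hand against the examples in the comment above isVnsrlShuffle.

// Standalone illustration (not LLVM code) of how the lowering loop picks the
// shift amount for each vnsrl step. EltSize, Difference and FirstIdx stand in
// for EltSizeInBits, Difference and Mask[0] in lowerVECTOR_SHUFFLEAsVNSRL.
#include <cstdio>

static void printVnsrlChain(unsigned EltSize, unsigned Difference,
                            unsigned FirstIdx) {
  std::printf("stride %u, first index %u:\n", Difference, FirstIdx);
  // Each iteration halves the element width, like the patch's loop:
  //   for (unsigned i = Difference; i != 1; i >>= 1)
  for (unsigned I = Difference; I != 1; I >>= 1) {
    // Bit (I >> 1) of the first index decides whether this step keeps the
    // upper or the lower half of each widened element.
    unsigned Shift = (FirstIdx & (I >> 1)) ? EltSize * I / 2 : 0;
    std::printf("  vnsrl from e%u to e%u, shift %u\n", EltSize * I,
                EltSize * I / 2, Shift);
  }
}

int main() {
  printVnsrlChain(8, 2, 0); // [0, 2, 4, 6]    -> one vnsrl, shift 0
  printVnsrlChain(8, 4, 3); // [3, 7, 11, 15]  -> shifts 16 then 8
  printVnsrlChain(8, 8, 2); // [2, 10, 18, 26] -> shifts 0, 16, 0
  return 0;
}

Running the sketch prints one vnsrl step per halving of the element width, matching, for example, [2, 10, 18, 26] -> vnsrl (vnsrl (vnsrl src, 0), EltSize * 2), 0 from the isVnsrlShuffle comment.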