diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1817,8 +1817,12 @@
   SDValue getSplatSourceVector(SDValue V, int &SplatIndex);
 
   /// If V is a splat vector, return its scalar source operand by extracting
-  /// that element from the source vector.
-  SDValue getSplatValue(SDValue V);
+  /// that element from the source vector. If LegalTypes is true, the returned
+  /// value is guaranteed to be legally typed: an illegal integer element is
+  /// extracted as the wider legal integer it promotes to (its extra high bits
+  /// are undefined), and SDValue() is returned if no such type exists (e.g.
+  /// for illegal floating-point elements or integers that must be expanded).
+  SDValue getSplatValue(SDValue V, bool LegalTypes = false);
 
   /// If a SHL/SRA/SRL node \p V has a constant or splat constant shift amount
   /// that is less than the element bit-width of the shift node, return it.
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2636,12 +2636,26 @@
   return SDValue();
 }
 
-SDValue SelectionDAG::getSplatValue(SDValue V) {
+SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) {
   int SplatIdx;
-  if (SDValue SrcVector = getSplatSourceVector(V, SplatIdx))
-    return getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(V),
-                   SrcVector.getValueType().getScalarType(), SrcVector,
-                   getVectorIdxConstant(SplatIdx, SDLoc(V)));
+  if (SDValue SrcVector = getSplatSourceVector(V, SplatIdx)) {
+    EVT SVT = SrcVector.getValueType().getScalarType();
+    EVT LegalSVT = SVT;
+    if (LegalTypes && !TLI->isTypeLegal(SVT)) {
+      // Only integer elements can be legalized here: EXTRACT_VECTOR_ELT may
+      // return a wider integer than the element type (implicit any-extend).
+      if (!SVT.isInteger())
+        return SDValue();
+      LegalSVT = TLI->getTypeToTransformTo(*getContext(), SVT);
+      // Reject integers that would be split (e.g. i64 on a 32-bit target) and
+      // ones whose first legalization step is still an illegal type.
+      if (LegalSVT.bitsLT(SVT) || !TLI->isTypeLegal(LegalSVT))
+        return SDValue();
+    }
+    SDLoc DL(V);
+    return getNode(ISD::EXTRACT_VECTOR_ELT, DL, LegalSVT, SrcVector,
+                   getVectorIdxConstant(SplatIdx, DL));
+  }
   return SDValue();
 }
 
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1849,7 +1849,7 @@
   SDValue Gather;
   // TODO: This doesn't trigger for i64 vectors on RV32, since there we
   // encounter a bitcasted BUILD_VECTOR with low/high i32 values.
-  if (SDValue SplatValue = DAG.getSplatValue(V1)) {
+  if (SDValue SplatValue = DAG.getSplatValue(V1, /*LegalTypes=*/true)) {
     Gather = lowerScalarSplat(SplatValue, VL, ContainerVT, DL, DAG, Subtarget);
   } else {
     SDValue LHSIndices = DAG.getBuildVector(IndexVT, DL, GatherIndicesLHS);
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-shuffles.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-shuffles.ll
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-shuffles.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-shuffles.ll
@@ -347,3 +347,29 @@
   %s = shufflevector <8 x i64> %x, <8 x i64> , <8 x i32>
   ret <8 x i64> %s
 }
+
+define <4 x i8> @interleave_shuffles(<4 x i8> %x) {
+; CHECK-LABEL: interleave_shuffles:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, zero, e8,mf4,ta,mu
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    vsetivli a1, 4, e8,mf4,ta,mu
+; CHECK-NEXT:    vrgather.vi v25, v8, 1
+; CHECK-NEXT:    addi a1, zero, 1
+; CHECK-NEXT:    vmv.s.x v26, a1
+; CHECK-NEXT:    vmv.v.i v27, 0
+; CHECK-NEXT:    vsetivli a1, 4, e8,mf4,tu,mu
+; CHECK-NEXT:    vslideup.vi v27, v26, 3
+; CHECK-NEXT:    addi a1, zero, 10
+; CHECK-NEXT:    vsetivli a2, 1, e8,mf8,ta,mu
+; CHECK-NEXT:    vmv.s.x v0, a1
+; CHECK-NEXT:    vsetivli a1, 4, e8,mf4,ta,mu
+; CHECK-NEXT:    vmv.v.x v8, a0
+; CHECK-NEXT:    vsetivli a0, 4, e8,mf4,tu,mu
+; CHECK-NEXT:    vrgather.vv v8, v25, v27, v0.t
+; CHECK-NEXT:    ret
+  %y = shufflevector <4 x i8> %x, <4 x i8> undef, <4 x i32> <i32 0, i32 0, i32 0, i32 0>
+  %z = shufflevector <4 x i8> %x, <4 x i8> undef, <4 x i32> <i32 1, i32 1, i32 1, i32 1>
+  %w = shufflevector <4 x i8> %y, <4 x i8> %z, <4 x i32> <i32 0, i32 4, i32 1, i32 5>
+  ret <4 x i8> %w
+}