diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h @@ -914,7 +914,7 @@ SDValue WidenVecRes_MLOAD(MaskedLoadSDNode* N); SDValue WidenVecRes_MGATHER(MaskedGatherSDNode* N); SDValue WidenVecRes_ScalarOp(SDNode* N); - SDValue WidenVecRes_SELECT(SDNode* N); + SDValue WidenVecRes_Select(SDNode *N); SDValue WidenVSELECTMask(SDNode *N); SDValue WidenVecRes_SELECT_CC(SDNode* N); SDValue WidenVecRes_SETCC(SDNode* N); diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -3047,7 +3047,10 @@ break; case ISD::SIGN_EXTEND_INREG: Res = WidenVecRes_InregOp(N); break; case ISD::VSELECT: - case ISD::SELECT: Res = WidenVecRes_SELECT(N); break; + case ISD::SELECT: + case ISD::VP_SELECT: + Res = WidenVecRes_Select(N); + break; case ISD::SELECT_CC: Res = WidenVecRes_SELECT_CC(N); break; case ISD::SETCC: Res = WidenVecRes_SETCC(N); break; case ISD::UNDEF: Res = WidenVecRes_UNDEF(N); break; @@ -4522,19 +4525,19 @@ return Mask; } -SDValue DAGTypeLegalizer::WidenVecRes_SELECT(SDNode *N) { +SDValue DAGTypeLegalizer::WidenVecRes_Select(SDNode *N) { EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0)); ElementCount WidenEC = WidenVT.getVectorElementCount(); SDValue Cond1 = N->getOperand(0); EVT CondVT = Cond1.getValueType(); + unsigned Opcode = N->getOpcode(); if (CondVT.isVector()) { if (SDValue WideCond = WidenVSELECTMask(N)) { SDValue InOp1 = GetWidenedVector(N->getOperand(1)); SDValue InOp2 = GetWidenedVector(N->getOperand(2)); assert(InOp1.getValueType() == WidenVT && InOp2.getValueType() == WidenVT); - return DAG.getNode(N->getOpcode(), SDLoc(N), - WidenVT, WideCond, InOp1, InOp2); + return DAG.getNode(Opcode, 
SDLoc(N), WidenVT, WideCond, InOp1, InOp2); } EVT CondEltVT = CondVT.getVectorElementType(); @@ -4560,8 +4563,10 @@ SDValue InOp1 = GetWidenedVector(N->getOperand(1)); SDValue InOp2 = GetWidenedVector(N->getOperand(2)); assert(InOp1.getValueType() == WidenVT && InOp2.getValueType() == WidenVT); - return DAG.getNode(N->getOpcode(), SDLoc(N), - WidenVT, Cond1, InOp1, InOp2); + return Opcode == ISD::VP_SELECT + ? DAG.getNode(Opcode, SDLoc(N), WidenVT, Cond1, InOp1, InOp2, + N->getOperand(3)) + : DAG.getNode(Opcode, SDLoc(N), WidenVT, Cond1, InOp1, InOp2); } SDValue DAGTypeLegalizer::WidenVecRes_SELECT_CC(SDNode *N) { diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vselect-vp.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vselect-vp.ll --- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vselect-vp.ll +++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vselect-vp.ll @@ -98,6 +98,18 @@ ret <4 x i8> %v } +declare <5 x i8> @llvm.vp.select.v5i8(<5 x i1>, <5 x i8>, <5 x i8>, i32) + +define <5 x i8> @select_v5i8(<5 x i1> %a, <5 x i8> %b, <5 x i8> %c, i32 zeroext %evl) { +; CHECK-LABEL: select_v5i8: +; CHECK: # %bb.0: +; CHECK-NEXT: vsetvli zero, a0, e8, mf2, ta, mu +; CHECK-NEXT: vmerge.vvm v8, v9, v8, v0 +; CHECK-NEXT: ret + %v = call <5 x i8> @llvm.vp.select.v5i8(<5 x i1> %a, <5 x i8> %b, <5 x i8> %c, i32 %evl) + ret <5 x i8> %v +} + declare <8 x i8> @llvm.vp.select.v8i8(<8 x i1>, <8 x i8>, <8 x i8>, i32) define <8 x i8> @select_v8i8(<8 x i1> %a, <8 x i8> %b, <8 x i8> %c, i32 zeroext %evl) { diff --git a/llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll --- a/llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll +++ b/llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll @@ -150,6 +150,18 @@ ret %v } +declare <vscale x 14 x i8> @llvm.vp.select.nxv14i8(<vscale x 14 x i1>, <vscale x 14 x i8>, <vscale x 14 x i8>, i32) + +define <vscale x 14 x i8> @select_nxv14i8(<vscale x 14 x i1> %a, <vscale x 14 x i8> %b, <vscale x 14 x i8> %c, i32 zeroext %evl) { +; CHECK-LABEL: select_nxv14i8: +; CHECK: # %bb.0: +; CHECK-NEXT: vsetvli zero, a0, e8, m2, ta, mu +; CHECK-NEXT: vmerge.vvm v8, v10, v8, v0 +; CHECK-NEXT: ret + %v 
= call <vscale x 14 x i8> @llvm.vp.select.nxv14i8(<vscale x 14 x i1> %a, <vscale x 14 x i8> %b, <vscale x 14 x i8> %c, i32 %evl) + ret <vscale x 14 x i8> %v +} + declare <vscale x 16 x i8> @llvm.vp.select.nxv16i8(<vscale x 16 x i1>, <vscale x 16 x i8>, <vscale x 16 x i8>, i32) define <vscale x 16 x i8> @select_nxv16i8(<vscale x 16 x i1> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c, i32 zeroext %evl) {