Summary of this patch (text before the first "diff --git" line is ignored by
patch tools):
  * LegalizeIntegerTypes.cpp / LegalizeTypes.h: fold PromoteIntRes_SELECT and
    PromoteIntRes_VSELECT into a single PromoteIntRes_Select, and route
    ISD::VP_SELECT to it as well. The merged handler rebuilds the node with
    the opcode of N; for VP_SELECT it additionally forwards N's operand 3.
  * RISC-V RVV tests: add vp.select cases on <8 x i7> and <vscale x 8 x i7>,
    whose i7 element type has to be promoted.

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -75,8 +75,11 @@
                          break;
   case ISD::MGATHER:     Res = PromoteIntRes_MGATHER(cast<MaskedGatherSDNode>(N));
                          break;
-  case ISD::SELECT:      Res = PromoteIntRes_SELECT(N); break;
-  case ISD::VSELECT:     Res = PromoteIntRes_VSELECT(N); break;
+  case ISD::SELECT:
+  case ISD::VSELECT:
+  case ISD::VP_SELECT:
+    Res = PromoteIntRes_Select(N);
+    break;
   case ISD::SELECT_CC:   Res = PromoteIntRes_SELECT_CC(N); break;
   case ISD::STRICT_FSETCC:
   case ISD::STRICT_FSETCCS:
@@ -1127,20 +1130,18 @@
   return Res;
 }
 
-SDValue DAGTypeLegalizer::PromoteIntRes_SELECT(SDNode *N) {
-  SDValue LHS = GetPromotedInteger(N->getOperand(1));
-  SDValue RHS = GetPromotedInteger(N->getOperand(2));
-  return DAG.getSelect(SDLoc(N),
-                       LHS.getValueType(), N->getOperand(0), LHS, RHS);
-}
-
-SDValue DAGTypeLegalizer::PromoteIntRes_VSELECT(SDNode *N) {
+SDValue DAGTypeLegalizer::PromoteIntRes_Select(SDNode *N) {
   SDValue Mask = N->getOperand(0);
 
   SDValue LHS = GetPromotedInteger(N->getOperand(1));
   SDValue RHS = GetPromotedInteger(N->getOperand(2));
-  return DAG.getNode(ISD::VSELECT, SDLoc(N),
-                     LHS.getValueType(), Mask, LHS, RHS);
+
+  unsigned Opcode = N->getOpcode();
+  return Opcode == ISD::VP_SELECT
+             ? DAG.getNode(Opcode, SDLoc(N), LHS.getValueType(), Mask, LHS, RHS,
+                           N->getOperand(3))
+             : DAG.getNode(Opcode, SDLoc(N), LHS.getValueType(), Mask, LHS,
+                           RHS);
 }
 
 SDValue DAGTypeLegalizer::PromoteIntRes_SELECT_CC(SDNode *N) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -334,8 +334,7 @@
   SDValue PromoteIntRes_MGATHER(MaskedGatherSDNode *N);
   SDValue PromoteIntRes_Overflow(SDNode *N);
   SDValue PromoteIntRes_SADDSUBO(SDNode *N, unsigned ResNo);
-  SDValue PromoteIntRes_SELECT(SDNode *N);
-  SDValue PromoteIntRes_VSELECT(SDNode *N);
+  SDValue PromoteIntRes_Select(SDNode *N);
   SDValue PromoteIntRes_SELECT_CC(SDNode *N);
   SDValue PromoteIntRes_SETCC(SDNode *N);
   SDValue PromoteIntRes_SHL(SDNode *N, bool IsVP);
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vselect-vp.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vselect-vp.ll
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vselect-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vselect-vp.ll
@@ -74,6 +74,18 @@
   ret <16 x i1> %v
 }
 
+declare <8 x i7> @llvm.vp.select.v8i7(<8 x i1>, <8 x i7>, <8 x i7>, i32)
+
+define <8 x i7> @select_v8i7(<8 x i1> %a, <8 x i7> %b, <8 x i7> %c, i32 zeroext %evl) {
+; CHECK-LABEL: select_v8i7:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf2, ta, mu
+; CHECK-NEXT:    vmerge.vvm v8, v9, v8, v0
+; CHECK-NEXT:    ret
+  %v = call <8 x i7> @llvm.vp.select.v8i7(<8 x i1> %a, <8 x i7> %b, <8 x i7> %c, i32 %evl)
+  ret <8 x i7> %v
+}
+
 declare <2 x i8> @llvm.vp.select.v2i8(<2 x i1>, <2 x i8>, <2 x i8>, i32)
 
 define <2 x i8> @select_v2i8(<2 x i1> %a, <2 x i8> %b, <2 x i8> %c, i32 zeroext %evl) {
diff --git a/llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll
--- a/llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll
@@ -102,6 +102,18 @@
   ret <vscale x 64 x i1> %v
 }
 
+declare <vscale x 8 x i7> @llvm.vp.select.nxv8i7(<vscale x 8 x i1>, <vscale x 8 x i7>, <vscale x 8 x i7>, i32)
+
+define <vscale x 8 x i7> @select_nxv8i7(<vscale x 8 x i1> %a, <vscale x 8 x i7> %b, <vscale x 8 x i7> %c, i32 zeroext %evl) {
+; CHECK-LABEL: select_nxv8i7:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e8, m1, ta, mu
+; CHECK-NEXT:    vmerge.vvm v8, v9, v8, v0
+; CHECK-NEXT:    ret
+  %v = call <vscale x 8 x i7> @llvm.vp.select.nxv8i7(<vscale x 8 x i1> %a, <vscale x 8 x i7> %b, <vscale x 8 x i7> %c, i32 %evl)
+  ret <vscale x 8 x i7> %v
+}
+
 declare <vscale x 1 x i8> @llvm.vp.select.nxv1i8(<vscale x 1 x i1>, <vscale x 1 x i8>, <vscale x 1 x i8>, i32)
 
 define <vscale x 1 x i8> @select_nxv1i8(<vscale x 1 x i1> %a, <vscale x 1 x i8> %b, <vscale x 1 x i8> %c, i32 zeroext %evl) {