diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp @@ -75,8 +75,10 @@ break; case ISD::MGATHER: Res = PromoteIntRes_MGATHER(cast(N)); break; - case ISD::SELECT: Res = PromoteIntRes_SELECT(N); break; - case ISD::VSELECT: Res = PromoteIntRes_VSELECT(N); break; + case ISD::SELECT: + case ISD::VSELECT: + Res = PromoteIntRes_SELECT(N, /*IsVP*/ false); + break; case ISD::SELECT_CC: Res = PromoteIntRes_SELECT_CC(N); break; case ISD::STRICT_FSETCC: case ISD::STRICT_FSETCCS: @@ -235,6 +237,10 @@ Res = PromoteIntRes_VECREDUCE(N); break; + case ISD::VP_SELECT: + Res = PromoteIntRes_SELECT(N, /*IsVP*/ true); + break; + case ISD::VP_REDUCE_ADD: case ISD::VP_REDUCE_MUL: case ISD::VP_REDUCE_AND: @@ -1127,20 +1133,16 @@ return Res; } -SDValue DAGTypeLegalizer::PromoteIntRes_SELECT(SDNode *N) { - SDValue LHS = GetPromotedInteger(N->getOperand(1)); - SDValue RHS = GetPromotedInteger(N->getOperand(2)); - return DAG.getSelect(SDLoc(N), - LHS.getValueType(), N->getOperand(0), LHS, RHS); -} - -SDValue DAGTypeLegalizer::PromoteIntRes_VSELECT(SDNode *N) { +SDValue DAGTypeLegalizer::PromoteIntRes_SELECT(SDNode *N, bool IsVP) { SDValue Mask = N->getOperand(0); SDValue LHS = GetPromotedInteger(N->getOperand(1)); SDValue RHS = GetPromotedInteger(N->getOperand(2)); - return DAG.getNode(ISD::VSELECT, SDLoc(N), - LHS.getValueType(), Mask, LHS, RHS); + + if (!IsVP) + return DAG.getSelect(SDLoc(N), LHS.getValueType(), Mask, LHS, RHS); + return DAG.getNode(N->getOpcode(), SDLoc(N), + LHS.getValueType(), Mask, LHS, RHS, N->getOperand(3)); } SDValue DAGTypeLegalizer::PromoteIntRes_SELECT_CC(SDNode *N) { diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h @@ -334,8 +334,7 @@ SDValue PromoteIntRes_MGATHER(MaskedGatherSDNode *N); SDValue PromoteIntRes_Overflow(SDNode *N); SDValue PromoteIntRes_SADDSUBO(SDNode *N, unsigned ResNo); - SDValue PromoteIntRes_SELECT(SDNode *N); - SDValue PromoteIntRes_VSELECT(SDNode *N); + SDValue PromoteIntRes_SELECT(SDNode *N, bool IsVP); SDValue PromoteIntRes_SELECT_CC(SDNode *N); SDValue PromoteIntRes_SETCC(SDNode *N); SDValue PromoteIntRes_SHL(SDNode *N, bool IsVP); diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vselect-vp.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vselect-vp.ll --- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vselect-vp.ll +++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vselect-vp.ll @@ -4,6 +4,18 @@ ; RUN: llc -mtriple=riscv64 -mattr=+d,+experimental-zfh,+experimental-v -target-abi=lp64d -riscv-v-vector-bits-min=128 \ ; RUN: -verify-machineinstrs < %s | FileCheck %s +declare <8 x i7> @llvm.vp.select.v8i7(<8 x i1>, <8 x i7>, <8 x i7>, i32) + +define <8 x i7> @select_v8i7(<8 x i1> %a, <8 x i7> %b, <8 x i7> %c, i32 zeroext %evl) { +; CHECK-LABEL: select_v8i7: +; CHECK: # %bb.0: +; CHECK-NEXT: vsetvli zero, a0, e8, mf2, ta, mu +; CHECK-NEXT: vmerge.vvm v8, v9, v8, v0 +; CHECK-NEXT: ret + %v = call <8 x i7> @llvm.vp.select.v8i7(<8 x i1> %a, <8 x i7> %b, <8 x i7> %c, i32 %evl) + ret <8 x i7> %v +} + declare <2 x i8> @llvm.vp.select.v2i8(<2 x i1>, <2 x i8>, <2 x i8>, i32) define <2 x i8> @select_v2i8(<2 x i1> %a, <2 x i8> %b, <2 x i8> %c, i32 zeroext %evl) { diff --git a/llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll --- a/llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll +++ b/llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll @@ -4,6 +4,18 @@ ; RUN: llc -mtriple=riscv64 -mattr=+d,+experimental-zfh,+experimental-v -target-abi=lp64d \ ; RUN: -verify-machineinstrs < %s | FileCheck %s +declare @llvm.vp.select.nxv8i7(, , , i32) + +define @select_nxv8i7( %a, %b, %c, i32 zeroext %evl) { +; CHECK-LABEL: select_nxv8i7: +; CHECK: # %bb.0: +; CHECK-NEXT: vsetvli zero, a0, e8, m1, ta, mu +; CHECK-NEXT: vmerge.vvm v8, v9, v8, v0 +; CHECK-NEXT: ret + %v = call @llvm.vp.select.nxv8i7( %a, %b, %c, i32 %evl) + ret %v +} + declare @llvm.vp.select.nxv1i8(, , , i32) define @select_nxv1i8( %a, %b, %c, i32 zeroext %evl) {