diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -560,6 +560,10 @@
         setOperationAction(ISD::UMAX, VT, Custom);
 
         setOperationAction(ISD::VSELECT, VT, Custom);
+
+        setOperationAction(ISD::ANY_EXTEND, VT, Custom);
+        setOperationAction(ISD::SIGN_EXTEND, VT, Custom);
+        setOperationAction(ISD::ZERO_EXTEND, VT, Custom);
       }
 
       for (MVT VT : MVT::fp_fixedlen_vector_valuetypes()) {
@@ -1741,32 +1745,53 @@
 SDValue RISCVTargetLowering::lowerVectorMaskExt(SDValue Op, SelectionDAG &DAG,
                                                 int64_t ExtTrueVal) const {
   SDLoc DL(Op);
-  EVT VecVT = Op.getValueType();
+  MVT VecVT = Op.getSimpleValueType();
   SDValue Src = Op.getOperand(0);
   // Only custom-lower extensions from mask types
   if (!Src.getValueType().isVector() ||
       Src.getValueType().getVectorElementType() != MVT::i1)
     return Op;
 
-  // Be careful not to introduce illegal scalar types at this stage, and be
-  // careful also about splatting constants as on RV32, vXi64 SPLAT_VECTOR is
-  // illegal and must be expanded. Since we know that the constants are
-  // sign-extended 32-bit values, we use SPLAT_VECTOR_I64 directly.
-  bool IsRV32E64 =
-      !Subtarget.is64Bit() && VecVT.getVectorElementType() == MVT::i64;
-  SDValue SplatZero = DAG.getConstant(0, DL, Subtarget.getXLenVT());
-  SDValue SplatTrueVal = DAG.getConstant(ExtTrueVal, DL, Subtarget.getXLenVT());
+  MVT XLenVT = Subtarget.getXLenVT();
+  SDValue SplatZero = DAG.getConstant(0, DL, XLenVT);
+  SDValue SplatTrueVal = DAG.getConstant(ExtTrueVal, DL, XLenVT);
+
+  if (VecVT.isScalableVector()) {
+    // Be careful not to introduce illegal scalar types at this stage, and be
+    // careful also about splatting constants as on RV32, vXi64 SPLAT_VECTOR is
+    // illegal and must be expanded. Since we know that the constants are
+    // sign-extended 32-bit values, we use SPLAT_VECTOR_I64 directly.
+    bool IsRV32E64 =
+        !Subtarget.is64Bit() && VecVT.getVectorElementType() == MVT::i64;
+
+    if (!IsRV32E64) {
+      SplatZero = DAG.getSplatVector(VecVT, DL, SplatZero);
+      SplatTrueVal = DAG.getSplatVector(VecVT, DL, SplatTrueVal);
+    } else {
+      SplatZero = DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, SplatZero);
+      SplatTrueVal =
+          DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, SplatTrueVal);
+    }
 
-  if (!IsRV32E64) {
-    SplatZero = DAG.getSplatVector(VecVT, DL, SplatZero);
-    SplatTrueVal = DAG.getSplatVector(VecVT, DL, SplatTrueVal);
-  } else {
-    SplatZero = DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, SplatZero);
-    SplatTrueVal =
-        DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, SplatTrueVal);
+    return DAG.getNode(ISD::VSELECT, DL, VecVT, Src, SplatTrueVal, SplatZero);
   }
 
-  return DAG.getNode(ISD::VSELECT, DL, VecVT, Src, SplatTrueVal, SplatZero);
+  MVT ContainerVT = getContainerForFixedLengthVector(DAG, VecVT, Subtarget);
+  MVT I1ContainerVT =
+      MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
+
+  SDValue CC = convertToScalableVector(I1ContainerVT, Src, DAG, Subtarget);
+
+  SDValue Mask, VL;
+  std::tie(Mask, VL) = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget);
+
+  SplatZero = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT, SplatZero, VL);
+  SplatTrueVal =
+      DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT, SplatTrueVal, VL);
+  SDValue Select = DAG.getNode(RISCVISD::VSELECT_VL, DL, ContainerVT, CC,
+                               SplatTrueVal, SplatZero, VL);
+
+  return convertFromScalableVector(VecVT, Select, DAG, Subtarget);
 }
 
 // Custom-lower truncations from vectors to mask vectors by using a mask and a
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-setcc.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-setcc.ll
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-setcc.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-setcc.ll
@@ -6,37 +6,43 @@
 ; stores is calculated assuming byte elements. We need to deal with mismatched
 ; subvector "casts" to make other elements work.
 
-define void @seteq_vv_v16i8(<16 x i8>* %x, <16 x i8>* %y, <16 x i1>* %z) {
+define void @seteq_vv_v16i8(<16 x i8>* %x, <16 x i8>* %y) {
 ; CHECK-LABEL: seteq_vv_v16i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi a3, zero, 16
-; CHECK-NEXT:    vsetvli a3, a3, e8,m1,ta,mu
+; CHECK-NEXT:    addi a2, zero, 16
+; CHECK-NEXT:    vsetvli a2, a2, e8,m1,ta,mu
 ; CHECK-NEXT:    vle8.v v25, (a0)
 ; CHECK-NEXT:    vle8.v v26, (a1)
-; CHECK-NEXT:    vmseq.vv v27, v25, v26
-; CHECK-NEXT:    vse1.v v27, (a2)
+; CHECK-NEXT:    vmseq.vv v0, v25, v26
+; CHECK-NEXT:    vmv.v.i v25, 0
+; CHECK-NEXT:    vmerge.vim v25, v25, -1, v0
+; CHECK-NEXT:    vse8.v v25, (a0)
 ; CHECK-NEXT:    ret
   %a = load <16 x i8>, <16 x i8>* %x
   %b = load <16 x i8>, <16 x i8>* %y
   %c = icmp eq <16 x i8> %a, %b
-  store <16 x i1> %c, <16 x i1>* %z
+  %d = sext <16 x i1> %c to <16 x i8>
+  store <16 x i8> %d, <16 x i8>* %x
   ret void
 }
 
-define void @setne_vv_v32i8(<32 x i8>* %x, <32 x i8>* %y, <32 x i1>* %z) {
+define void @setne_vv_v32i8(<32 x i8>* %x, <32 x i8>* %y) {
 ; CHECK-LABEL: setne_vv_v32i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi a3, zero, 32
-; CHECK-NEXT:    vsetvli a3, a3, e8,m2,ta,mu
+; CHECK-NEXT:    addi a2, zero, 32
+; CHECK-NEXT:    vsetvli a2, a2, e8,m2,ta,mu
 ; CHECK-NEXT:    vle8.v v26, (a0)
 ; CHECK-NEXT:    vle8.v v28, (a1)
-; CHECK-NEXT:    vmsne.vv v25, v26, v28
-; CHECK-NEXT:    vse1.v v25, (a2)
+; CHECK-NEXT:    vmsne.vv v0, v26, v28
+; CHECK-NEXT:    vmv.v.i v26, 0
+; CHECK-NEXT:    vmerge.vim v26, v26, 1, v0
+; CHECK-NEXT:    vse8.v v26, (a0)
 ; CHECK-NEXT:    ret
   %a = load <32 x i8>, <32 x i8>* %x
   %b = load <32 x i8>, <32 x i8>* %y
   %c = icmp ne <32 x i8> %a, %b
-  store <32 x i1> %c, <32 x i1>* %z
+  %d = zext <32 x i1> %c to <32 x i8>
+  store <32 x i8> %d, <32 x i8>* %x
   ret void
 }
 
@@ -693,3 +699,123 @@
   store <8 x i1> %d, <8 x i1>* %z
   ret void
 }
+
+define void @seteq_vv_v8i16(<8 x i16>* %x, <8 x i16>* %y) {
+; CHECK-LABEL: seteq_vv_v8i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, zero, 8
+; CHECK-NEXT:    vsetvli a2, a2, e16,m1,ta,mu
+; CHECK-NEXT:    vle16.v v25, (a0)
+; CHECK-NEXT:    vle16.v v26, (a1)
+; CHECK-NEXT:    vmseq.vv v0, v25, v26
+; CHECK-NEXT:    vmv.v.i v25, 0
+; CHECK-NEXT:    vmerge.vim v25, v25, -1, v0
+; CHECK-NEXT:    vse16.v v25, (a0)
+; CHECK-NEXT:    ret
+  %a = load <8 x i16>, <8 x i16>* %x
+  %b = load <8 x i16>, <8 x i16>* %y
+  %c = icmp eq <8 x i16> %a, %b
+  %d = sext <8 x i1> %c to <8 x i16>
+  store <8 x i16> %d, <8 x i16>* %x
+  ret void
+}
+
+define void @setne_vv_v4i32(<4 x i32>* %x, <4 x i32>* %y) {
+; CHECK-LABEL: setne_vv_v4i32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, zero, 4
+; CHECK-NEXT:    vsetvli a2, a2, e32,m1,ta,mu
+; CHECK-NEXT:    vle32.v v25, (a0)
+; CHECK-NEXT:    vle32.v v26, (a1)
+; CHECK-NEXT:    vmsne.vv v0, v25, v26
+; CHECK-NEXT:    vmv.v.i v25, 0
+; CHECK-NEXT:    vmerge.vim v25, v25, -1, v0
+; CHECK-NEXT:    vse32.v v25, (a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i32>, <4 x i32>* %x
+  %b = load <4 x i32>, <4 x i32>* %y
+  %c = icmp ne <4 x i32> %a, %b
+  %d = sext <4 x i1> %c to <4 x i32>
+  store <4 x i32> %d, <4 x i32>* %x
+  ret void
+}
+
+define void @setgt_vv_v2i64(<2 x i64>* %x, <2 x i64>* %y) {
+; CHECK-LABEL: setgt_vv_v2i64:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, zero, 2
+; CHECK-NEXT:    vsetvli a2, a2, e64,m1,ta,mu
+; CHECK-NEXT:    vle64.v v25, (a0)
+; CHECK-NEXT:    vle64.v v26, (a1)
+; CHECK-NEXT:    vmslt.vv v0, v26, v25
+; CHECK-NEXT:    vmv.v.i v25, 0
+; CHECK-NEXT:    vmerge.vim v25, v25, -1, v0
+; CHECK-NEXT:    vse64.v v25, (a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i64>, <2 x i64>* %x
+  %b = load <2 x i64>, <2 x i64>* %y
+  %c = icmp sgt <2 x i64> %a, %b
+  %d = sext <2 x i1> %c to <2 x i64>
+  store <2 x i64> %d, <2 x i64>* %x
+  ret void
+}
+
+define void @setlt_vv_v16i16(<16 x i16>* %x, <16 x i16>* %y) {
+; CHECK-LABEL: setlt_vv_v16i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, zero, 16
+; CHECK-NEXT:    vsetvli a2, a2, e16,m2,ta,mu
+; CHECK-NEXT:    vle16.v v26, (a0)
+; CHECK-NEXT:    vle16.v v28, (a1)
+; CHECK-NEXT:    vmslt.vv v0, v26, v28
+; CHECK-NEXT:    vmv.v.i v26, 0
+; CHECK-NEXT:    vmerge.vim v26, v26, 1, v0
+; CHECK-NEXT:    vse16.v v26, (a0)
+; CHECK-NEXT:    ret
+  %a = load <16 x i16>, <16 x i16>* %x
+  %b = load <16 x i16>, <16 x i16>* %y
+  %c = icmp slt <16 x i16> %a, %b
+  %d = zext <16 x i1> %c to <16 x i16>
+  store <16 x i16> %d, <16 x i16>* %x
+  ret void
+}
+
+define void @setugt_vv_v8i32(<8 x i32>* %x, <8 x i32>* %y) {
+; CHECK-LABEL: setugt_vv_v8i32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, zero, 8
+; CHECK-NEXT:    vsetvli a2, a2, e32,m2,ta,mu
+; CHECK-NEXT:    vle32.v v26, (a0)
+; CHECK-NEXT:    vle32.v v28, (a1)
+; CHECK-NEXT:    vmsltu.vv v0, v28, v26
+; CHECK-NEXT:    vmv.v.i v26, 0
+; CHECK-NEXT:    vmerge.vim v26, v26, 1, v0
+; CHECK-NEXT:    vse32.v v26, (a0)
+; CHECK-NEXT:    ret
+  %a = load <8 x i32>, <8 x i32>* %x
+  %b = load <8 x i32>, <8 x i32>* %y
+  %c = icmp ugt <8 x i32> %a, %b
+  %d = zext <8 x i1> %c to <8 x i32>
+  store <8 x i32> %d, <8 x i32>* %x
+  ret void
+}
+
+define void @setult_vv_v4i64(<4 x i64>* %x, <4 x i64>* %y) {
+; CHECK-LABEL: setult_vv_v4i64:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, zero, 4
+; CHECK-NEXT:    vsetvli a2, a2, e64,m2,ta,mu
+; CHECK-NEXT:    vle64.v v26, (a0)
+; CHECK-NEXT:    vle64.v v28, (a1)
+; CHECK-NEXT:    vmsltu.vv v0, v26, v28
+; CHECK-NEXT:    vmv.v.i v26, 0
+; CHECK-NEXT:    vmerge.vim v26, v26, 1, v0
+; CHECK-NEXT:    vse64.v v26, (a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i64>, <4 x i64>* %x
+  %b = load <4 x i64>, <4 x i64>* %y
+  %c = icmp ult <4 x i64> %a, %b
+  %d = zext <4 x i1> %c to <4 x i64>
+  store <4 x i64> %d, <4 x i64>* %x
+  ret void
+}
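
A note on the new fixed-length path: lowerVectorMaskExt now splats 0 and
ExtTrueVal (-1 for the sext tests, 1 for the zext tests above) as
RISCVISD::VMV_V_X_VL nodes on the container type and picks between them with
RISCVISD::VSELECT_VL, which is what produces the vmv.v.i/vmerge.vim pairs in
the CHECK lines. As a minimal illustrative sketch (not one of this patch's
tests; the function name and types are arbitrary), a case such as

  ; Hypothetical example, written in the same style as the tests above:
  ; the sge compare yields a <8 x i1> mask that the sext forces through
  ; the new fixed-length mask-extension lowering.
  define void @setge_vv_v8i16(<8 x i16>* %x, <8 x i16>* %y) {
    %a = load <8 x i16>, <8 x i16>* %x
    %b = load <8 x i16>, <8 x i16>* %y
    %c = icmp sge <8 x i16> %a, %b
    %d = sext <8 x i1> %c to <8 x i16>
    store <8 x i16> %d, <8 x i16>* %x
    ret void
  }

would take the same route: the mask is converted into a scalable i1 container,
the two splats are merged under it, and the result is converted back to the
fixed-length type. CHECK lines are omitted here since they would need
regenerating with update_llc_test_checks.py.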