diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -4274,13 +4274,60 @@
                                                     SelectionDAG &DAG) const {
   SDLoc DL(Op);
   MVT VecVT = Op.getSimpleValueType();
+  MVT XLenVT = Subtarget.getXLenVT();
   SDValue Vec = Op.getOperand(0);
   SDValue Val = Op.getOperand(1);
   SDValue Idx = Op.getOperand(2);
 
   if (VecVT.getVectorElementType() == MVT::i1) {
-    // FIXME: For now we just promote to an i8 vector and insert into that,
-    // but this is probably not optimal.
+    if (VecVT.isFixedLengthVector()) {
+      unsigned NumElts = VecVT.getVectorNumElements();
+      if (NumElts >= 8) {
+        MVT WideEltVT;
+        unsigned WidenVecLen;
+        SDValue ExtractElementIdx;
+        SDValue ExtractBitIdx;
+        unsigned MaxEEW = Subtarget.getMaxELENForFixedLengthVectors();
+        MVT LargestEltVT = MVT::getIntegerVT(
+            std::min(MaxEEW, unsigned(XLenVT.getSizeInBits())));
+        if (NumElts <= LargestEltVT.getSizeInBits()) {
+          assert(isPowerOf2_32(NumElts) &&
+                 "the number of elements should be power of 2");
+          WideEltVT = MVT::getIntegerVT(NumElts);
+          WidenVecLen = 1;
+          ExtractElementIdx = DAG.getConstant(0, DL, XLenVT);
+          ExtractBitIdx = Idx;
+        } else {
+          WideEltVT = LargestEltVT;
+          WidenVecLen = NumElts / WideEltVT.getSizeInBits();
+          // extract element index = index / element width
+          ExtractElementIdx = DAG.getNode(
+              ISD::SRL, DL, XLenVT, Idx,
+              DAG.getConstant(Log2_64(WideEltVT.getSizeInBits()), DL, XLenVT));
+          // mask bit index = index % element width
+          ExtractBitIdx = DAG.getNode(
+              ISD::AND, DL, XLenVT, Idx,
+              DAG.getConstant(WideEltVT.getSizeInBits() - 1, DL, XLenVT));
+        }
+        MVT WideVT = MVT::getVectorVT(WideEltVT, WidenVecLen);
+        Vec = DAG.getNode(ISD::BITCAST, DL, WideVT, Vec);
+        SDValue ExtractElt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, XLenVT,
+                                         Vec, ExtractElementIdx);
+        // Set the bit and insert it back into the widened vector.
+        SDValue ExtVal = DAG.getNode(ISD::ZERO_EXTEND, DL, XLenVT, Val);
+        SDValue Xor = DAG.getNode(ISD::XOR, DL, XLenVT, ExtVal,
+                                  DAG.getConstant(1, DL, XLenVT));
+        SDValue ShiftLeft =
+            DAG.getNode(ISD::SHL, DL, XLenVT, Xor, ExtractBitIdx);
+        SDValue Not = DAG.getNode(ISD::XOR, DL, XLenVT, ShiftLeft,
+                                  DAG.getConstant(-1, DL, XLenVT));
+        SDValue NewElt = DAG.getNode(ISD::OR, DL, XLenVT, ExtractElt, Not);
+        SDValue NewWidenVec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, WideVT,
+                                          Vec, NewElt, ExtractElementIdx);
+        return DAG.getNode(ISD::BITCAST, DL, VecVT, NewWidenVec);
+      }
+    }
+    // Otherwise, promote to an i8 vector and insert into it.
     MVT WideVT = MVT::getVectorVT(MVT::i8, VecVT.getVectorElementCount());
     Vec = DAG.getNode(ISD::ZERO_EXTEND, DL, WideVT, Vec);
     Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, WideVT, Vec, Val, Idx);
@@ -4294,8 +4341,6 @@
     Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
   }
 
-  MVT XLenVT = Subtarget.getXLenVT();
-
   SDValue Zero = DAG.getConstant(0, DL, XLenVT);
   bool IsLegalInsert = Subtarget.is64Bit() || Val.getValueType() != MVT::i64;
   // Even i64-element vectors on RV32 can be lowered without scalar
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-insert-i1.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-insert-i1.ll
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-insert-i1.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-insert-i1.ll
@@ -106,15 +106,14 @@
 define <8 x i1> @insertelt_v8i1(<8 x i1> %x, i1 %elt) nounwind {
 ; CHECK-LABEL: insertelt_v8i1:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, mu
-; CHECK-NEXT:    vmv.s.x v8, a0
-; CHECK-NEXT:    vmv.v.i v9, 0
-; CHECK-NEXT:    vmerge.vim v9, v9, 1, v0
-; CHECK-NEXT:    vsetivli zero, 2, e8, mf2, tu, mu
-; CHECK-NEXT:    vslideup.vi v9, v8, 1
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, mu
-; CHECK-NEXT:    vand.vi v8, v9, 1
-; CHECK-NEXT:    vmsne.vi v0, v8, 0
+; CHECK-NEXT:    xori a0, a0, 1
+; CHECK-NEXT:    slli a0, a0, 1
+; CHECK-NEXT:    not a0, a0
+; CHECK-NEXT:    vsetivli zero, 0, e8, mf8, ta, mu
+; CHECK-NEXT:    vmv.x.s a1, v0
+; CHECK-NEXT:    or a0, a1, a0
+; CHECK-NEXT:    vsetivli zero, 1, e8, mf8, tu, mu
+; CHECK-NEXT:    vmv.s.x v0, a0
 ; CHECK-NEXT:    ret
   %y = insertelement <8 x i1> %x, i1 %elt, i64 1
   ret <8 x i1> %y
@@ -123,50 +122,56 @@
 define <8 x i1> @insertelt_idx_v8i1(<8 x i1> %x, i1 %elt, i32 zeroext %idx) nounwind {
 ; RV32-LABEL: insertelt_idx_v8i1:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    vsetivli zero, 8, e8, mf2, ta, mu
-; RV32-NEXT:    vmv.s.x v8, a0
-; RV32-NEXT:    vmv.v.i v9, 0
-; RV32-NEXT:    vmerge.vim v9, v9, 1, v0
-; RV32-NEXT:    addi a0, a1, 1
-; RV32-NEXT:    vsetvli zero, a0, e8, mf2, tu, mu
-; RV32-NEXT:    vslideup.vx v9, v8, a1
-; RV32-NEXT:    vsetivli zero, 8, e8, mf2, ta, mu
-; RV32-NEXT:    vand.vi v8, v9, 1
-; RV32-NEXT:    vmsne.vi v0, v8, 0
+; RV32-NEXT:    xori a0, a0, 1
+; RV32-NEXT:    sll a0, a0, a1
+; RV32-NEXT:    not a0, a0
+; RV32-NEXT:    vsetivli zero, 0, e8, mf8, ta, mu
+; RV32-NEXT:    vmv.x.s a1, v0
+; RV32-NEXT:    or a0, a1, a0
+; RV32-NEXT:    vsetivli zero, 1, e8, mf8, tu, mu
+; RV32-NEXT:    vmv.s.x v0, a0
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: insertelt_idx_v8i1:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    vsetivli zero, 8, e8, mf2, ta, mu
-; RV64-NEXT:    vmv.s.x v8, a0
-; RV64-NEXT:    vmv.v.i v9, 0
-; RV64-NEXT:    vmerge.vim v9, v9, 1, v0
-; RV64-NEXT:    sext.w a0, a1
-; RV64-NEXT:    addi a1, a0, 1
-; RV64-NEXT:    vsetvli zero, a1, e8, mf2, tu, mu
-; RV64-NEXT:    vslideup.vx v9, v8, a0
-; RV64-NEXT:    vsetivli zero, 8, e8, mf2, ta, mu
-; RV64-NEXT:    vand.vi v8, v9, 1
-; RV64-NEXT:    vmsne.vi v0, v8, 0
+; RV64-NEXT:    sext.w a1, a1
+; RV64-NEXT:    xori a0, a0, 1
+; RV64-NEXT:    sll a0, a0, a1
+; RV64-NEXT:    not a0, a0
+; RV64-NEXT:    vsetivli zero, 0, e8, mf8, ta, mu
+; RV64-NEXT:    vmv.x.s a1, v0
+; RV64-NEXT:    or a0, a1, a0
+; RV64-NEXT:    vsetivli zero, 1, e8, mf8, tu, mu
+; RV64-NEXT:    vmv.s.x v0, a0
 ; RV64-NEXT:    ret
   %y = insertelement <8 x i1> %x, i1 %elt, i32 %idx
   ret <8 x i1> %y
 }
 
 define <64 x i1> @insertelt_v64i1(<64 x i1> %x, i1 %elt) nounwind {
-; CHECK-LABEL: insertelt_v64i1:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    li a1, 64
-; CHECK-NEXT:    vsetvli zero, a1, e8, m4, ta, mu
-; CHECK-NEXT:    vmv.s.x v8, a0
-; CHECK-NEXT:    vmv.v.i v12, 0
-; CHECK-NEXT:    vmerge.vim v12, v12, 1, v0
-; CHECK-NEXT:    vsetivli zero, 2, e8, m4, tu, mu
-; CHECK-NEXT:    vslideup.vi v12, v8, 1
-; CHECK-NEXT:    vsetvli zero, a1, e8, m4, ta, mu
-; CHECK-NEXT:    vand.vi v8, v12, 1
-; CHECK-NEXT:    vmsne.vi v0, v8, 0
-; CHECK-NEXT:    ret
+; RV32-LABEL: insertelt_v64i1:
+; RV32:       # %bb.0:
+; RV32-NEXT:    xori a0, a0, 1
+; RV32-NEXT:    slli a0, a0, 1
+; RV32-NEXT:    not a0, a0
+; RV32-NEXT:    vsetivli zero, 0, e32, mf2, ta, mu
+; RV32-NEXT:    vmv.x.s a1, v0
+; RV32-NEXT:    or a0, a1, a0
+; RV32-NEXT:    vsetivli zero, 2, e32, mf2, tu, mu
+; RV32-NEXT:    vmv.s.x v0, a0
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: insertelt_v64i1:
+; RV64:       # %bb.0:
+; RV64-NEXT:    xori a0, a0, 1
+; RV64-NEXT:    slli a0, a0, 1
+; RV64-NEXT:    not a0, a0
+; RV64-NEXT:    vsetivli zero, 0, e64, m1, ta, mu
+; RV64-NEXT:    vmv.x.s a1, v0
+; RV64-NEXT:    or a0, a1, a0
+; RV64-NEXT:    vsetivli zero, 1, e64, m1, tu, mu
+; RV64-NEXT:    vmv.s.x v0, a0
+; RV64-NEXT:    ret
   %y = insertelement <64 x i1> %x, i1 %elt, i64 1
   ret <64 x i1> %y
 }
@@ -174,33 +179,31 @@
 define <64 x i1> @insertelt_idx_v64i1(<64 x i1> %x, i1 %elt, i32 zeroext %idx) nounwind {
 ; RV32-LABEL: insertelt_idx_v64i1:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    li a2, 64
-; RV32-NEXT:    vsetvli zero, a2, e8, m4, ta, mu
+; RV32-NEXT:    srli a2, a1, 5
+; RV32-NEXT:    vsetivli zero, 1, e32, mf2, ta, mu
+; RV32-NEXT:    vslidedown.vx v8, v0, a2
+; RV32-NEXT:    vmv.x.s a3, v8
+; RV32-NEXT:    xori a0, a0, 1
+; RV32-NEXT:    sll a0, a0, a1
+; RV32-NEXT:    not a0, a0
+; RV32-NEXT:    or a0, a3, a0
 ; RV32-NEXT:    vmv.s.x v8, a0
-; RV32-NEXT:    vmv.v.i v12, 0
-; RV32-NEXT:    vmerge.vim v12, v12, 1, v0
-; RV32-NEXT:    addi a0, a1, 1
-; RV32-NEXT:    vsetvli zero, a0, e8, m4, tu, mu
-; RV32-NEXT:    vslideup.vx v12, v8, a1
-; RV32-NEXT:    vsetvli zero, a2, e8, m4, ta, mu
-; RV32-NEXT:    vand.vi v8, v12, 1
-; RV32-NEXT:    vmsne.vi v0, v8, 0
+; RV32-NEXT:    addi a0, a2, 1
+; RV32-NEXT:    vsetvli zero, a0, e32, mf2, tu, mu
+; RV32-NEXT:    vslideup.vx v0, v8, a2
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: insertelt_idx_v64i1:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    li a2, 64
-; RV64-NEXT:    vsetvli zero, a2, e8, m4, ta, mu
-; RV64-NEXT:    vmv.s.x v8, a0
-; RV64-NEXT:    vmv.v.i v12, 0
-; RV64-NEXT:    vmerge.vim v12, v12, 1, v0
-; RV64-NEXT:    sext.w a0, a1
-; RV64-NEXT:    addi a1, a0, 1
-; RV64-NEXT:    vsetvli zero, a1, e8, m4, tu, mu
-; RV64-NEXT:    vslideup.vx v12, v8, a0
-; RV64-NEXT:    vsetvli zero, a2, e8, m4, ta, mu
-; RV64-NEXT:    vand.vi v8, v12, 1
-; RV64-NEXT:    vmsne.vi v0, v8, 0
+; RV64-NEXT:    sext.w a1, a1
+; RV64-NEXT:    xori a0, a0, 1
+; RV64-NEXT:    sll a0, a0, a1
+; RV64-NEXT:    not a0, a0
+; RV64-NEXT:    vsetivli zero, 0, e64, m1, ta, mu
+; RV64-NEXT:    vmv.x.s a1, v0
+; RV64-NEXT:    or a0, a1, a0
+; RV64-NEXT:    vsetivli zero, 1, e64, m1, tu, mu
+; RV64-NEXT:    vmv.s.x v0, a0
 ; RV64-NEXT:    ret
   %y = insertelement <64 x i1> %x, i1 %elt, i32 %idx
   ret <64 x i1> %y
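For reference, a minimal standalone sketch of the index arithmetic the new lowering relies on when the i1 vector has more elements than the widest legal element type: the insertion index is split into the containing wide element (index / element width, done with a right shift) and the bit position inside that element (index % element width, done with a mask), mirroring the ISD::SRL/ISD::AND pair added above and the `srli a2, a1, 5` in the insertelt_idx_v64i1 RV32 output. The helper name and the plain C++ types below are illustrative only and are not part of the patch.

#include <cassert>
#include <cstdint>
#include <utility>

// Split a mask-bit index into (wide-element index, bit index within that
// element). EltBits plays the role of WideEltVT.getSizeInBits() and must be
// a power of two.
static std::pair<uint64_t, uint64_t> splitMaskIndex(uint64_t Idx,
                                                    unsigned EltBits) {
  assert(EltBits != 0 && (EltBits & (EltBits - 1)) == 0 &&
         "element width must be a power of two");
  unsigned Log2EltBits = __builtin_ctz(EltBits); // Log2_64 in LLVM
  return {Idx >> Log2EltBits,   // which wide element holds the bit
          Idx & (EltBits - 1)}; // bit position inside that element
}

// Example: bit 37 of a <64 x i1> viewed as <2 x i32> (the RV32 case) gives
// splitMaskIndex(37, 32) == {1, 5}, i.e. element 1, bit 5.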