diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -12952,7 +12952,7 @@
   if (!isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, ResVT))
     return false;
 
-  return (Index == 0 || Index == ResVT.getVectorNumElements());
+  return (Index == 0 || Index == ResVT.getVectorMinNumElements());
 }
 
 /// Turn vector tests of the signbit in the form of:
@@ -14312,12 +14312,42 @@
 static SDValue performInsertSubvectorCombine(SDNode *N,
                                              TargetLowering::DAGCombinerInfo &DCI,
                                              SelectionDAG &DAG) {
+  SDLoc DL(N);
   SDValue Vec = N->getOperand(0);
   SDValue SubVec = N->getOperand(1);
   uint64_t IdxVal = N->getConstantOperandVal(2);
   EVT VecVT = Vec.getValueType();
   EVT SubVT = SubVec.getValueType();
 
+  if (VecVT.isScalableVector() &&
+      DAG.getTargetLoweringInfo().isTypeLegal(VecVT) &&
+      DAG.getTargetLoweringInfo().isTypeLegal(SubVT) &&
+      VecVT.getVectorElementType() == MVT::i1) {
+    // When inserting e.g. nxv4i1 into nxv16i1, we already know the other
+    // lanes are zeroed.
+    if (Vec.isUndef() || isNullOrNullSplat(Vec))
+      return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VecVT, SubVec);
+
+    // Break down insert_subvector into simpler parts.
+    unsigned NumElts = VecVT.getVectorMinNumElements();
+    EVT HalfVT = VecVT.getHalfNumVectorElementsVT(*DAG.getContext());
+
+    SDValue Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, Vec,
+                             DAG.getVectorIdxConstant(0, DL));
+    SDValue Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, Vec,
+                             DAG.getVectorIdxConstant(NumElts / 2, DL));
+    if (IdxVal < (NumElts / 2)) {
+      SDValue NewLo = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, HalfVT, Lo, SubVec,
+                                  DAG.getVectorIdxConstant(IdxVal, DL));
+      return DAG.getNode(ISD::CONCAT_VECTORS, DL, VecVT, NewLo, Hi);
+    } else {
+      SDValue NewHi =
+          DAG.getNode(ISD::INSERT_SUBVECTOR, DL, HalfVT, Hi, SubVec,
+                      DAG.getVectorIdxConstant(IdxVal - (NumElts / 2), DL));
+      return DAG.getNode(ISD::CONCAT_VECTORS, DL, VecVT, Lo, NewHi);
+    }
+  }
+
   // Only do this for legal fixed vector types.
   if (!VecVT.isFixedLengthVector() ||
       !DAG.getTargetLoweringInfo().isTypeLegal(VecVT) ||
@@ -14337,7 +14367,6 @@
   // Fold insert_subvector -> concat_vectors
   // insert_subvector(Vec,Sub,lo) -> concat_vectors(Sub,extract(Vec,hi))
   // insert_subvector(Vec,Sub,hi) -> concat_vectors(extract(Vec,lo),Sub)
-  SDLoc DL(N);
   SDValue Lo, Hi;
   if (IdxVal == 0) {
     Lo = SubVec;
diff --git a/llvm/test/CodeGen/AArch64/sve-insert-vector.ll b/llvm/test/CodeGen/AArch64/sve-insert-vector.ll
--- a/llvm/test/CodeGen/AArch64/sve-insert-vector.ll
+++ b/llvm/test/CodeGen/AArch64/sve-insert-vector.ll
@@ -501,6 +501,72 @@
   ret <vscale x 2 x bfloat> %v0
 }
 
+; Test predicate inserts of half size.
+define <vscale x 16 x i1> @insert_nxv16i1_nxv8i1_0(<vscale x 16 x i1> %vec, <vscale x 8 x i1> %sv) {
+; CHECK-LABEL: insert_nxv16i1_nxv8i1_0:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    uzp1 p0.b, p1.b, p0.b
+; CHECK-NEXT:    ret
+  %v0 = call <vscale x 16 x i1> @llvm.experimental.vector.insert.nxv16i1.nxv8i1(<vscale x 16 x i1> %vec, <vscale x 8 x i1> %sv, i64 0)
+  ret <vscale x 16 x i1> %v0
+}
+
+define <vscale x 16 x i1> @insert_nxv16i1_nxv8i1_8(<vscale x 16 x i1> %vec, <vscale x 8 x i1> %sv) {
+; CHECK-LABEL: insert_nxv16i1_nxv8i1_8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    punpklo p0.h, p0.b
+; CHECK-NEXT:    uzp1 p0.b, p0.b, p1.b
+; CHECK-NEXT:    ret
+  %v0 = call <vscale x 16 x i1> @llvm.experimental.vector.insert.nxv16i1.nxv8i1(<vscale x 16 x i1> %vec, <vscale x 8 x i1> %sv, i64 8)
+  ret <vscale x 16 x i1> %v0
+}
+
+; Test predicate inserts of less than half the size.
+define <vscale x 16 x i1> @insert_nxv16i1_nxv4i1_0(<vscale x 16 x i1> %vec, <vscale x 4 x i1> %sv) {
+; CHECK-LABEL: insert_nxv16i1_nxv4i1_0:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    punpklo p2.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    punpkhi p2.h, p2.b
+; CHECK-NEXT:    uzp1 p1.h, p1.h, p2.h
+; CHECK-NEXT:    uzp1 p0.b, p1.b, p0.b
+; CHECK-NEXT:    ret
+  %v0 = call <vscale x 16 x i1> @llvm.experimental.vector.insert.nxv16i1.nxv4i1(<vscale x 16 x i1> %vec, <vscale x 4 x i1> %sv, i64 0)
+  ret <vscale x 16 x i1> %v0
+}
+
+define <vscale x 16 x i1> @insert_nxv16i1_nxv4i1_12(<vscale x 16 x i1> %vec, <vscale x 4 x i1> %sv) {
+; CHECK-LABEL: insert_nxv16i1_nxv4i1_12:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    punpkhi p2.h, p0.b
+; CHECK-NEXT:    punpklo p0.h, p0.b
+; CHECK-NEXT:    punpklo p2.h, p2.b
+; CHECK-NEXT:    uzp1 p1.h, p2.h, p1.h
+; CHECK-NEXT:    uzp1 p0.b, p0.b, p1.b
+; CHECK-NEXT:    ret
+  %v0 = call <vscale x 16 x i1> @llvm.experimental.vector.insert.nxv16i1.nxv4i1(<vscale x 16 x i1> %vec, <vscale x 4 x i1> %sv, i64 12)
+  ret <vscale x 16 x i1> %v0
+}
+
+; Test predicate insert into undef/zero
+define <vscale x 16 x i1> @insert_nxv16i1_nxv4i1_into_zero(<vscale x 4 x i1> %sv) {
+; CHECK-LABEL: insert_nxv16i1_nxv4i1_into_zero:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %v0 = call <vscale x 16 x i1> @llvm.experimental.vector.insert.nxv16i1.nxv4i1(<vscale x 16 x i1> zeroinitializer, <vscale x 4 x i1> %sv, i64 0)
+  ret <vscale x 16 x i1> %v0
+}
+
+define <vscale x 16 x i1> @insert_nxv16i1_nxv4i1_into_poison(<vscale x 4 x i1> %sv) {
+; CHECK-LABEL: insert_nxv16i1_nxv4i1_into_poison:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %v0 = call <vscale x 16 x i1> @llvm.experimental.vector.insert.nxv16i1.nxv4i1(<vscale x 16 x i1> poison, <vscale x 4 x i1> %sv, i64 0)
+  ret <vscale x 16 x i1> %v0
+}
+
+
 declare <vscale x 3 x i32> @llvm.experimental.vector.insert.nxv3i32.nxv2i32(<vscale x 3 x i32>, <vscale x 2 x i32>, i64)
 declare <vscale x 3 x float> @llvm.experimental.vector.insert.nxv3f32.nxv2f32(<vscale x 3 x float>, <vscale x 2 x float>, i64)
 declare <vscale x 6 x i32> @llvm.experimental.vector.insert.nxv6i32.nxv2i32(<vscale x 6 x i32>, <vscale x 2 x i32>, i64)
@@ -511,3 +577,6 @@
 declare <vscale x 4 x bfloat> @llvm.experimental.vector.insert.nxv4bf16.nxv4bf16(<vscale x 4 x bfloat>, <vscale x 4 x bfloat>, i64)
 declare <vscale x 4 x bfloat> @llvm.experimental.vector.insert.nxv4bf16.v4bf16(<vscale x 4 x bfloat>, <4 x bfloat>, i64)
 declare <vscale x 2 x bfloat> @llvm.experimental.vector.insert.nxv2bf16.nxv2bf16(<vscale x 2 x bfloat>, <vscale x 2 x bfloat>, i64)
+
+declare <vscale x 16 x i1> @llvm.experimental.vector.insert.nxv16i1.nxv4i1(<vscale x 16 x i1>, <vscale x 4 x i1>, i64)
+declare <vscale x 16 x i1> @llvm.experimental.vector.insert.nxv16i1.nxv8i1(<vscale x 16 x i1>, <vscale x 8 x i1>, i64)
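
Editor's note (not part of the patch): the recursive decomposition in performInsertSubvectorCombine can be written out as equivalent IR, which may help when reading the @insert_nxv16i1_nxv4i1_12 test above. For an insert of nxv4i1 at index 12 into nxv16i1, the combine extracts the high half (index 8), inserts the subvector at index 12 - 8 = 4 of that half, and concatenates the halves back together. The sketch below expresses that sequence with the experimental insert/extract intrinsics of this era; the function name @sketch_insert_nxv4i1_at_12 and the use of llvm.experimental.vector.extract are illustrative assumptions, not something the patch emits.

; Editor's sketch only: IR equivalent of the DAG-level split performed by
; the combine for an insert at index 12 (NumElts = 16, so NumElts/2 = 8).
define <vscale x 16 x i1> @sketch_insert_nxv4i1_at_12(<vscale x 16 x i1> %vec, <vscale x 4 x i1> %sv) {
  ; Hi = extract_subvector(Vec, NumElts/2)
  %hi = call <vscale x 8 x i1> @llvm.experimental.vector.extract.nxv8i1.nxv16i1(<vscale x 16 x i1> %vec, i64 8)
  ; NewHi = insert_subvector(Hi, SubVec, IdxVal - NumElts/2)
  %newhi = call <vscale x 8 x i1> @llvm.experimental.vector.insert.nxv8i1.nxv4i1(<vscale x 8 x i1> %hi, <vscale x 4 x i1> %sv, i64 4)
  ; Concat(Lo, NewHi): write the updated high half back over the original.
  %res = call <vscale x 16 x i1> @llvm.experimental.vector.insert.nxv16i1.nxv8i1(<vscale x 16 x i1> %vec, <vscale x 8 x i1> %newhi, i64 8)
  ret <vscale x 16 x i1> %res
}

declare <vscale x 8 x i1> @llvm.experimental.vector.extract.nxv8i1.nxv16i1(<vscale x 16 x i1>, i64)
declare <vscale x 8 x i1> @llvm.experimental.vector.insert.nxv8i1.nxv4i1(<vscale x 8 x i1>, <vscale x 4 x i1>, i64)
declare <vscale x 16 x i1> @llvm.experimental.vector.insert.nxv16i1.nxv8i1(<vscale x 16 x i1>, <vscale x 8 x i1>, i64)

The inner insert at index 4 lands in the high half of the nxv8i1 value, so the combine fires again on it, which is why the test's expected output is a punpklo/punpkhi tree followed by uzp1 instructions rather than a single insert.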