diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1246,6 +1246,7 @@
       setOperationAction(ISD::SELECT_CC, VT, Expand);
      setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
      setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
+     setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
 
      // There are no legal MVT::nxv16f## based types.
      if (VT != MVT::nxv16i1) {
@@ -11029,6 +11030,28 @@
   if (!isTypeLegal(VT))
     return SDValue();
 
+  // Break down insert_subvector into simpler parts.
+  if (VT.getVectorElementType() == MVT::i1) {
+    unsigned NumElts = VT.getVectorMinNumElements();
+    EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
+
+    SDValue Lo, Hi;
+    Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, Vec0,
+                     DAG.getVectorIdxConstant(0, DL));
+    Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, Vec0,
+                     DAG.getVectorIdxConstant(NumElts / 2, DL));
+    if (Idx < (NumElts / 2)) {
+      SDValue NewLo = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, HalfVT, Lo, Vec1,
+                                  DAG.getVectorIdxConstant(Idx, DL));
+      return DAG.getNode(AArch64ISD::UZP1, DL, VT, NewLo, Hi);
+    } else {
+      SDValue NewHi =
+          DAG.getNode(ISD::INSERT_SUBVECTOR, DL, HalfVT, Hi, Vec1,
+                      DAG.getVectorIdxConstant(Idx - (NumElts / 2), DL));
+      return DAG.getNode(AArch64ISD::UZP1, DL, VT, Lo, NewHi);
+    }
+  }
+
   // Ensure the subvector is half the size of the main vector.
   if (VT.getVectorElementCount() != (InVT.getVectorElementCount() * 2))
     return SDValue();
@@ -12952,7 +12975,7 @@
   if (!isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, ResVT))
     return false;
 
-  return (Index == 0 || Index == ResVT.getVectorNumElements());
+  return (Index == 0 || Index == ResVT.getVectorMinNumElements());
 }
 
 /// Turn vector tests of the signbit in the form of:
@@ -14312,6 +14335,7 @@
 static SDValue
 performInsertSubvectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                               SelectionDAG &DAG) {
+  SDLoc DL(N);
   SDValue Vec = N->getOperand(0);
   SDValue SubVec = N->getOperand(1);
   uint64_t IdxVal = N->getConstantOperandVal(2);
@@ -14337,7 +14361,6 @@
   // Fold insert_subvector -> concat_vectors
   // insert_subvector(Vec,Sub,lo) -> concat_vectors(Sub,extract(Vec,hi))
   // insert_subvector(Vec,Sub,hi) -> concat_vectors(extract(Vec,lo),Sub)
-  SDLoc DL(N);
   SDValue Lo, Hi;
   if (IdxVal == 0) {
     Lo = SubVec;
diff --git a/llvm/test/CodeGen/AArch64/sve-insert-vector.ll b/llvm/test/CodeGen/AArch64/sve-insert-vector.ll
--- a/llvm/test/CodeGen/AArch64/sve-insert-vector.ll
+++ b/llvm/test/CodeGen/AArch64/sve-insert-vector.ll
@@ -501,6 +501,80 @@
   ret <vscale x 2 x bfloat> %v0
 }
 
+; Test predicate inserts of half size.
+define <vscale x 16 x i1> @insert_nxv16i1_nxv8i1_0(<vscale x 16 x i1> %vec, <vscale x 8 x i1> %sv) {
+; CHECK-LABEL: insert_nxv16i1_nxv8i1_0:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    uzp1 p0.b, p1.b, p0.b
+; CHECK-NEXT:    ret
+  %v0 = call <vscale x 16 x i1> @llvm.experimental.vector.insert.nx16i1.nxv8i1(<vscale x 16 x i1> %vec, <vscale x 8 x i1> %sv, i64 0)
+  ret <vscale x 16 x i1> %v0
+}
+
+define <vscale x 16 x i1> @insert_nxv16i1_nxv8i1_8(<vscale x 16 x i1> %vec, <vscale x 8 x i1> %sv) {
+; CHECK-LABEL: insert_nxv16i1_nxv8i1_8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    punpklo p0.h, p0.b
+; CHECK-NEXT:    uzp1 p0.b, p0.b, p1.b
+; CHECK-NEXT:    ret
+  %v0 = call <vscale x 16 x i1> @llvm.experimental.vector.insert.nx16i1.nxv8i1(<vscale x 16 x i1> %vec, <vscale x 8 x i1> %sv, i64 8)
+  ret <vscale x 16 x i1> %v0
+}
+
+; Test predicate inserts of less than half the size.
+define <vscale x 16 x i1> @insert_nxv16i1_nxv4i1_0(<vscale x 16 x i1> %vec, <vscale x 4 x i1> %sv) {
+; CHECK-LABEL: insert_nxv16i1_nxv4i1_0:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    punpklo p2.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    punpkhi p2.h, p2.b
+; CHECK-NEXT:    uzp1 p1.h, p1.h, p2.h
+; CHECK-NEXT:    uzp1 p0.b, p1.b, p0.b
+; CHECK-NEXT:    ret
+  %v0 = call <vscale x 16 x i1> @llvm.experimental.vector.insert.nx16i1.nxv4i1(<vscale x 16 x i1> %vec, <vscale x 4 x i1> %sv, i64 0)
+  ret <vscale x 16 x i1> %v0
+}
+
+define <vscale x 16 x i1> @insert_nxv16i1_nxv4i1_12(<vscale x 16 x i1> %vec, <vscale x 4 x i1> %sv) {
+; CHECK-LABEL: insert_nxv16i1_nxv4i1_12:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    punpkhi p2.h, p0.b
+; CHECK-NEXT:    punpklo p0.h, p0.b
+; CHECK-NEXT:    punpklo p2.h, p2.b
+; CHECK-NEXT:    uzp1 p1.h, p2.h, p1.h
+; CHECK-NEXT:    uzp1 p0.b, p0.b, p1.b
+; CHECK-NEXT:    ret
+  %v0 = call <vscale x 16 x i1> @llvm.experimental.vector.insert.nx16i1.nxv4i1(<vscale x 16 x i1> %vec, <vscale x 4 x i1> %sv, i64 12)
+  ret <vscale x 16 x i1> %v0
+}
+
+; Test predicate insert into undef/zero
+define <vscale x 16 x i1> @insert_nxv16i1_nxv4i1_into_zero(<vscale x 4 x i1> %sv) {
+; CHECK-LABEL: insert_nxv16i1_nxv4i1_into_zero:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    punpklo p2.h, p1.b
+; CHECK-NEXT:    punpkhi p1.h, p1.b
+; CHECK-NEXT:    punpkhi p2.h, p2.b
+; CHECK-NEXT:    uzp1 p0.h, p0.h, p2.h
+; CHECK-NEXT:    uzp1 p0.b, p0.b, p1.b
+; CHECK-NEXT:    ret
+  %v0 = call <vscale x 16 x i1> @llvm.experimental.vector.insert.nx16i1.nxv4i1(<vscale x 16 x i1> zeroinitializer, <vscale x 4 x i1> %sv, i64 0)
+  ret <vscale x 16 x i1> %v0
+}
+
+define <vscale x 16 x i1> @insert_nxv16i1_nxv4i1_into_poison(<vscale x 4 x i1> %sv) {
+; CHECK-LABEL: insert_nxv16i1_nxv4i1_into_poison:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uzp1 p0.h, p0.h, p0.h
+; CHECK-NEXT:    uzp1 p0.b, p0.b, p0.b
+; CHECK-NEXT:    ret
+  %v0 = call <vscale x 16 x i1> @llvm.experimental.vector.insert.nx16i1.nxv4i1(<vscale x 16 x i1> poison, <vscale x 4 x i1> %sv, i64 0)
+  ret <vscale x 16 x i1> %v0
+}
+
+
 declare <vscale x 3 x i32> @llvm.experimental.vector.insert.nxv3i32.nxv2i32(<vscale x 3 x i32>, <vscale x 2 x i32>, i64)
 declare <vscale x 3 x float> @llvm.experimental.vector.insert.nxv3f32.nxv2f32(<vscale x 3 x float>, <vscale x 2 x float>, i64)
 declare <vscale x 6 x i32> @llvm.experimental.vector.insert.nxv6i32.nxv2i32(<vscale x 6 x i32>, <vscale x 2 x i32>, i64)
@@ -511,3 +585,6 @@
 declare <vscale x 4 x bfloat> @llvm.experimental.vector.insert.nxv4bf16.nxv4bf16(<vscale x 4 x bfloat>, <vscale x 4 x bfloat>, i64)
 declare <vscale x 4 x bfloat> @llvm.experimental.vector.insert.nxv4bf16.v4bf16(<vscale x 4 x bfloat>, <4 x bfloat>, i64)
 declare <vscale x 2 x bfloat> @llvm.experimental.vector.insert.nxv2bf16.nxv2bf16(<vscale x 2 x bfloat>, <vscale x 2 x bfloat>, i64)
+
+declare <vscale x 16 x i1> @llvm.experimental.vector.insert.nx16i1.nxv4i1(<vscale x 16 x i1>, <vscale x 4 x i1>, i64)
+declare <vscale x 16 x i1> @llvm.experimental.vector.insert.nx16i1.nxv8i1(<vscale x 16 x i1>, <vscale x 8 x i1>, i64)