diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -6653,7 +6653,7 @@ EVT InVT = InOp.getValueType(); assert(InVT.getVectorElementType() == NVT.getVectorElementType() && "input and widen element type must match"); - assert(!InVT.isScalableVector() && !NVT.isScalableVector() && + assert(InVT.isScalableVector() == NVT.isScalableVector() && "cannot modify scalable vectors in this way"); SDLoc dl(InOp); @@ -6661,10 +6661,10 @@ if (InVT == NVT) return InOp; - unsigned InNumElts = InVT.getVectorNumElements(); - unsigned WidenNumElts = NVT.getVectorNumElements(); - if (WidenNumElts > InNumElts && WidenNumElts % InNumElts == 0) { - unsigned NumConcat = WidenNumElts / InNumElts; + ElementCount InEC = InVT.getVectorElementCount(); + ElementCount WidenEC = NVT.getVectorElementCount(); + if (WidenEC.hasKnownScalarFactor(InEC)) { + unsigned NumConcat = WidenEC.getKnownScalarFactor(InEC); SmallVector Ops(NumConcat); SDValue FillVal = FillWithZeroes ? DAG.getConstant(0, dl, InVT) : DAG.getUNDEF(InVT); @@ -6675,10 +6675,16 @@ return DAG.getNode(ISD::CONCAT_VECTORS, dl, NVT, Ops); } - if (WidenNumElts < InNumElts && InNumElts % WidenNumElts) + if (InEC.hasKnownScalarFactor(WidenEC)) return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, NVT, InOp, DAG.getVectorIdxConstant(0, dl)); + assert(!InVT.isScalableVector() && !NVT.isScalableVector() && + "Scalable vectors should have been handled already."); + + unsigned InNumElts = InEC.getFixedValue(); + unsigned WidenNumElts = WidenEC.getFixedValue(); + // Fall back to extract and build. SmallVector Ops(WidenNumElts); EVT EltVT = NVT.getVectorElementType(); diff --git a/llvm/lib/Target/AArch64/AArch64CallingConvention.td b/llvm/lib/Target/AArch64/AArch64CallingConvention.td --- a/llvm/lib/Target/AArch64/AArch64CallingConvention.td +++ b/llvm/lib/Target/AArch64/AArch64CallingConvention.td @@ -82,9 +82,9 @@ nxv2bf16, nxv4bf16, nxv8bf16, nxv2f32, nxv4f32, nxv2f64], CCPassIndirect>, - CCIfType<[nxv2i1, nxv4i1, nxv8i1, nxv16i1], + CCIfType<[nxv1i1, nxv2i1, nxv4i1, nxv8i1, nxv16i1], CCAssignToReg<[P0, P1, P2, P3]>>, - CCIfType<[nxv2i1, nxv4i1, nxv8i1, nxv16i1], + CCIfType<[nxv1i1, nxv2i1, nxv4i1, nxv8i1, nxv16i1], CCPassIndirect>, // Handle i1, i8, i16, i32, i64, f32, f64 and v2f64 by passing in registers, @@ -149,7 +149,7 @@ nxv2bf16, nxv4bf16, nxv8bf16, nxv2f32, nxv4f32, nxv2f64], CCAssignToReg<[Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7]>>, - CCIfType<[nxv2i1, nxv4i1, nxv8i1, nxv16i1], + CCIfType<[nxv1i1, nxv2i1, nxv4i1, nxv8i1, nxv16i1], CCAssignToReg<[P0, P1, P2, P3]>> ]>; diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -292,6 +292,7 @@ if (Subtarget->hasSVE() || Subtarget->hasSME()) { // Add legal sve predicate types + addRegisterClass(MVT::nxv1i1, &AArch64::PPRRegClass); addRegisterClass(MVT::nxv2i1, &AArch64::PPRRegClass); addRegisterClass(MVT::nxv4i1, &AArch64::PPRRegClass); addRegisterClass(MVT::nxv8i1, &AArch64::PPRRegClass); @@ -1156,7 +1157,8 @@ MVT::nxv4i16, MVT::nxv4i32, MVT::nxv8i8, MVT::nxv8i16 }) setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Legal); - for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1}) { + for (auto VT : + {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1, MVT::nxv1i1}) { setOperationAction(ISD::CONCAT_VECTORS, VT, Custom); setOperationAction(ISD::SELECT, VT, Custom); setOperationAction(ISD::SETCC, VT, Custom); @@ -4676,7 +4678,6 @@ Op.getOperand(2), Op.getOperand(3), DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i32)), Op.getOperand(1)); - case Intrinsic::localaddress: { const auto &MF = DAG.getMachineFunction(); const auto *RegInfo = Subtarget->getRegisterInfo(); @@ -10551,8 +10552,13 @@ DAG.getValueType(MVT::i1)); SDValue ID = DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64); - return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, ID, - DAG.getConstant(0, DL, MVT::i64), SplatVal); + SDValue Zero = DAG.getConstant(0, DL, MVT::i64); + if (VT == MVT::nxv1i1) + return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::nxv1i1, + DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv2i1, ID, + Zero, SplatVal), + Zero); + return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, ID, Zero, SplatVal); } SDValue AArch64TargetLowering::LowerDUPQLane(SDValue Op, diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td --- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td +++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td @@ -871,7 +871,7 @@ // SVE predicate register classes. class PPRClass : RegisterClass< "AArch64", - [ nxv16i1, nxv8i1, nxv4i1, nxv2i1 ], 16, + [ nxv16i1, nxv8i1, nxv4i1, nxv2i1, nxv1i1 ], 16, (sequence "P%u", 0, lastreg)> { let Size = 16; } diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -748,6 +748,11 @@ defm PUNPKLO_PP : sve_int_perm_punpk<0b0, "punpklo", int_aarch64_sve_punpklo>; defm PUNPKHI_PP : sve_int_perm_punpk<0b1, "punpkhi", int_aarch64_sve_punpkhi>; + // Define pattern for `nxv1i1 splat_vector(1)`. + // We do this here instead of in ISelLowering such that PatFrag's can still + // recognize a splat. + def : Pat<(nxv1i1 immAllOnesV), (PUNPKLO_PP (PTRUE_D 31))>; + defm MOVPRFX_ZPzZ : sve_int_movprfx_pred_zero<0b000, "movprfx">; defm MOVPRFX_ZPmZ : sve_int_movprfx_pred_merge<0b001, "movprfx">; def MOVPRFX_ZZ : sve_int_bin_cons_misc_0_c<0b00000001, "movprfx", ZPRAny>; @@ -1509,6 +1514,10 @@ defm TRN2_PPP : sve_int_perm_bin_perm_pp<0b101, "trn2", AArch64trn2>; // Extract lo/hi halves of legal predicate types. + def : Pat<(nxv1i1 (extract_subvector (nxv2i1 PPR:$Ps), (i64 0))), + (PUNPKLO_PP PPR:$Ps)>; + def : Pat<(nxv1i1 (extract_subvector (nxv2i1 PPR:$Ps), (i64 1))), + (PUNPKHI_PP PPR:$Ps)>; def : Pat<(nxv2i1 (extract_subvector (nxv4i1 PPR:$Ps), (i64 0))), (PUNPKLO_PP PPR:$Ps)>; def : Pat<(nxv2i1 (extract_subvector (nxv4i1 PPR:$Ps), (i64 2))), @@ -1599,6 +1608,8 @@ (UUNPKHI_ZZ_D (UUNPKHI_ZZ_S ZPR:$Zs))>; // Concatenate two predicates. + def : Pat<(nxv2i1 (concat_vectors nxv1i1:$p1, nxv1i1:$p2)), + (UZP1_PPP_D $p1, $p2)>; def : Pat<(nxv4i1 (concat_vectors nxv2i1:$p1, nxv2i1:$p2)), (UZP1_PPP_S $p1, $p2)>; def : Pat<(nxv8i1 (concat_vectors nxv4i1:$p1, nxv4i1:$p2)), @@ -2298,15 +2309,23 @@ def : Pat<(nxv16i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; def : Pat<(nxv16i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; def : Pat<(nxv16i1 (reinterpret_cast (nxv2i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv16i1 (reinterpret_cast (nxv1i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; def : Pat<(nxv8i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; def : Pat<(nxv8i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; def : Pat<(nxv8i1 (reinterpret_cast (nxv2i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv8i1 (reinterpret_cast (nxv1i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; def : Pat<(nxv4i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; def : Pat<(nxv4i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; def : Pat<(nxv4i1 (reinterpret_cast (nxv2i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv4i1 (reinterpret_cast (nxv1i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; def : Pat<(nxv2i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; def : Pat<(nxv2i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; def : Pat<(nxv2i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv2i1 (reinterpret_cast (nxv1i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv1i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv1i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv1i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv1i1 (reinterpret_cast (nxv2i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; // These allow casting from/to unpacked floating-point types. def : Pat<(nxv2f16 (reinterpret_cast (nxv8f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>; diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td --- a/llvm/lib/Target/AArch64/SVEInstrFormats.td +++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td @@ -647,6 +647,7 @@ def : Pat<(nxv8i1 immAllZerosV), (!cast(NAME))>; def : Pat<(nxv4i1 immAllZerosV), (!cast(NAME))>; def : Pat<(nxv2i1 immAllZerosV), (!cast(NAME))>; + def : Pat<(nxv1i1 immAllZerosV), (!cast(NAME))>; } class sve_int_ptest opc, string asm> @@ -1681,6 +1682,7 @@ def : SVE_3_Op_Pat(NAME)>; def : SVE_3_Op_Pat(NAME)>; def : SVE_3_Op_Pat(NAME)>; + def : SVE_3_Op_Pat(NAME)>; def : SVE_2_Op_AllActive_Pat(NAME), PTRUE_B>; def : SVE_2_Op_AllActive_Pat @llvm.vector.extract.nxv2f32.nxv4f32(, i64) declare @llvm.vector.extract.nxv4i32.nxv8i32(, i64) + +; +; Extract nxv1i1 type from: nxv2i1 +; + +define @extract_nxv1i1_nxv2i1_0( %in) { +; CHECK-LABEL: extract_nxv1i1_nxv2i1_0: +; CHECK: // %bb.0: +; CHECK-NEXT: punpklo p0.h, p0.b +; CHECK-NEXT: ret + %res = call @llvm.vector.extract.nxv1i1.nxv2i1( %in, i64 0) + ret %res +} + +define @extract_nxv1i1_nxv2i1_1( %in) { +; CHECK-LABEL: extract_nxv1i1_nxv2i1_1: +; CHECK: // %bb.0: +; CHECK-NEXT: punpkhi p0.h, p0.b +; CHECK-NEXT: ret + %res = call @llvm.vector.extract.nxv1i1.nxv2i1( %in, i64 1) + ret %res +} + +declare @llvm.vector.extract.nxv1i1.nxv2i1(, i64) diff --git a/llvm/test/CodeGen/AArch64/sve-select.ll b/llvm/test/CodeGen/AArch64/sve-select.ll --- a/llvm/test/CodeGen/AArch64/sve-select.ll +++ b/llvm/test/CodeGen/AArch64/sve-select.ll @@ -187,6 +187,7 @@ ; CHECK-NEXT: // kill: def $w0 killed $w0 def $x0 ; CHECK-NEXT: sbfx x8, x0, #0, #1 ; CHECK-NEXT: whilelo p2.d, xzr, x8 +; CHECK-NEXT: punpklo p2.h, p2.b ; CHECK-NEXT: sel p0.b, p2, p0.b, p1.b ; CHECK-NEXT: ret %res = select i1 %cond, %a, %b @@ -225,6 +226,7 @@ define @sel_nxv1i64( %p, %dst, %a) { ; CHECK-LABEL: sel_nxv1i64: ; CHECK: // %bb.0: +; CHECK-NEXT: uzp1 p0.d, p0.d, p0.d ; CHECK-NEXT: mov z0.d, p0/m, z1.d ; CHECK-NEXT: ret %sel = select %p, %a, %dst @@ -483,6 +485,7 @@ ; CHECK-NEXT: cset w8, eq ; CHECK-NEXT: sbfx x8, x8, #0, #1 ; CHECK-NEXT: whilelo p2.d, xzr, x8 +; CHECK-NEXT: punpklo p2.h, p2.b ; CHECK-NEXT: sel p0.b, p2, p0.b, p1.b ; CHECK-NEXT: ret %mask = icmp eq i64 %x0, 0 diff --git a/llvm/test/CodeGen/AArch64/sve-zeroinit.ll b/llvm/test/CodeGen/AArch64/sve-zeroinit.ll --- a/llvm/test/CodeGen/AArch64/sve-zeroinit.ll +++ b/llvm/test/CodeGen/AArch64/sve-zeroinit.ll @@ -52,6 +52,13 @@ ret zeroinitializer } +define @test_zeroinit_1xi1() { +; CHECK-LABEL: test_zeroinit_1xi1 +; CHECK: pfalse p0.b +; CHECK-NEXT: ret + ret zeroinitializer +} + define @test_zeroinit_2xi1() { ; CHECK-LABEL: test_zeroinit_2xi1 ; CHECK: pfalse p0.b