diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -6653,7 +6653,7 @@
   EVT InVT = InOp.getValueType();
   assert(InVT.getVectorElementType() == NVT.getVectorElementType() &&
          "input and widen element type must match");
-  assert(!InVT.isScalableVector() && !NVT.isScalableVector() &&
+  assert(InVT.isScalableVector() == NVT.isScalableVector() &&
          "cannot modify scalable vectors in this way");
   SDLoc dl(InOp);
@@ -6661,8 +6661,8 @@
   if (InVT == NVT)
     return InOp;

-  unsigned InNumElts = InVT.getVectorNumElements();
-  unsigned WidenNumElts = NVT.getVectorNumElements();
+  unsigned InNumElts = InVT.getVectorMinNumElements();
+  unsigned WidenNumElts = NVT.getVectorMinNumElements();
   if (WidenNumElts > InNumElts && WidenNumElts % InNumElts == 0) {
     unsigned NumConcat = WidenNumElts / InNumElts;
     SmallVector<SDValue, 16> Ops(NumConcat);
@@ -6679,6 +6679,9 @@
     return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, NVT, InOp,
                        DAG.getVectorIdxConstant(0, dl));

+  assert(!InVT.isScalableVector() && !NVT.isScalableVector() &&
+         "Scalable vectors should have been handled already.");
+
   // Fall back to extract and build.
   SmallVector<SDValue, 16> Ops(WidenNumElts);
   EVT EltVT = NVT.getVectorElementType();
diff --git a/llvm/lib/Target/AArch64/AArch64CallingConvention.td b/llvm/lib/Target/AArch64/AArch64CallingConvention.td
--- a/llvm/lib/Target/AArch64/AArch64CallingConvention.td
+++ b/llvm/lib/Target/AArch64/AArch64CallingConvention.td
@@ -82,9 +82,9 @@
              nxv2bf16, nxv4bf16, nxv8bf16, nxv2f32, nxv4f32, nxv2f64],
     CCPassIndirect<i64>>,

-  CCIfType<[nxv2i1, nxv4i1, nxv8i1, nxv16i1],
+  CCIfType<[nxv1i1, nxv2i1, nxv4i1, nxv8i1, nxv16i1],
     CCAssignToReg<[P0, P1, P2, P3]>>,
-  CCIfType<[nxv2i1, nxv4i1, nxv8i1, nxv16i1],
+  CCIfType<[nxv1i1, nxv2i1, nxv4i1, nxv8i1, nxv16i1],
    CCPassIndirect<i64>>,

   // Handle i1, i8, i16, i32, i64, f32, f64 and v2f64 by passing in registers,
@@ -149,7 +149,7 @@
              nxv2bf16, nxv4bf16, nxv8bf16, nxv2f32, nxv4f32, nxv2f64],
     CCAssignToReg<[Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7]>>,

-  CCIfType<[nxv2i1, nxv4i1, nxv8i1, nxv16i1],
+  CCIfType<[nxv1i1, nxv2i1, nxv4i1, nxv8i1, nxv16i1],
     CCAssignToReg<[P0, P1, P2, P3]>>
 ]>;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -291,6 +291,7 @@
   if (Subtarget->hasSVE() || Subtarget->hasSME()) {
     // Add legal sve predicate types
+    addRegisterClass(MVT::nxv1i1, &AArch64::PPRRegClass);
     addRegisterClass(MVT::nxv2i1, &AArch64::PPRRegClass);
     addRegisterClass(MVT::nxv4i1, &AArch64::PPRRegClass);
     addRegisterClass(MVT::nxv8i1, &AArch64::PPRRegClass);
@@ -1155,7 +1156,8 @@
                     MVT::nxv4i16, MVT::nxv4i32, MVT::nxv8i8, MVT::nxv8i16 })
       setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Legal);

-    for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1}) {
+    for (auto VT :
+         {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1, MVT::nxv1i1}) {
       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
       setOperationAction(ISD::SELECT, VT, Custom);
       setOperationAction(ISD::SETCC, VT, Custom);
@@ -4257,6 +4259,13 @@
                      DAG.getNode(ISD::BITCAST, DL, Op1VT, N01), Op1));
 }

+static inline SDValue getPUNPKLO(SelectionDAG &DAG, SDLoc DL, SDValue Op) {
+  return DAG.getNode(
+      ISD::INTRINSIC_WO_CHAIN, DL,
+      Op.getValueType().getHalfNumVectorElementsVT(*DAG.getContext()),
+      DAG.getConstant(Intrinsic::aarch64_sve_punpklo, DL, MVT::i64), Op);
+}
+
 static inline SDValue getPTrue(SelectionDAG &DAG, SDLoc DL, EVT VT,
                                int Pattern) {
   return DAG.getNode(AArch64ISD::PTRUE, DL, VT,
@@ -4605,7 +4614,6 @@
         Op.getOperand(2), Op.getOperand(3),
         DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i32)),
         Op.getOperand(1));
-
   case Intrinsic::localaddress: {
     const auto &MF = DAG.getMachineFunction();
     const auto *RegInfo = Subtarget->getRegisterInfo();
@@ -10471,8 +10479,12 @@
                            DAG.getValueType(MVT::i1));
   SDValue ID =
       DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64);
-  return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, ID,
-                     DAG.getConstant(0, DL, MVT::i64), SplatVal);
+  SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
+  if (VT == MVT::nxv1i1)
+    return getPUNPKLO(DAG, DL,
+                      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv2i1, ID,
+                                  Zero, SplatVal));
+  return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, ID, Zero, SplatVal);
 }

 SDValue AArch64TargetLowering::LowerDUPQLane(SDValue Op,
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
@@ -871,7 +871,7 @@
 // SVE predicate register classes.
 class PPRClass<int lastreg> : RegisterClass<
                                   "AArch64",
-                                  [ nxv16i1, nxv8i1, nxv4i1, nxv2i1 ], 16,
+                                  [ nxv16i1, nxv8i1, nxv4i1, nxv2i1, nxv1i1 ], 16,
                                   (sequence "P%u", 0, lastreg)> {
   let Size = 16;
 }
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -747,6 +747,11 @@
   defm PUNPKLO_PP : sve_int_perm_punpk<0b0, "punpklo", int_aarch64_sve_punpklo>;
   defm PUNPKHI_PP : sve_int_perm_punpk<0b1, "punpkhi", int_aarch64_sve_punpkhi>;

+  // Define pattern for `nxv1i1 splat_vector(1)`.
+  // We do this here instead of in ISelLowering such that PatFrag's can still
+  // recognize a splat.
+  def : Pat<(nxv1i1 immAllOnesV), (PUNPKLO_PP (PTRUE_D 31))>;
+
   defm MOVPRFX_ZPzZ : sve_int_movprfx_pred_zero<0b000, "movprfx">;
   defm MOVPRFX_ZPmZ : sve_int_movprfx_pred_merge<0b001, "movprfx">;
   def MOVPRFX_ZZ : sve_int_bin_cons_misc_0_c<0b00000001, "movprfx", ZPRAny>;
@@ -1598,6 +1603,8 @@
             (UUNPKHI_ZZ_D (UUNPKHI_ZZ_S ZPR:$Zs))>;

   // Concatenate two predicates.
+  def : Pat<(nxv2i1 (concat_vectors nxv1i1:$p1, nxv1i1:$p2)),
+            (UZP1_PPP_D $p1, $p2)>;
   def : Pat<(nxv4i1 (concat_vectors nxv2i1:$p1, nxv2i1:$p2)),
             (UZP1_PPP_S $p1, $p2)>;
   def : Pat<(nxv8i1 (concat_vectors nxv4i1:$p1, nxv4i1:$p2)),
@@ -2297,15 +2304,23 @@
   def : Pat<(nxv16i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
   def : Pat<(nxv16i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
   def : Pat<(nxv16i1 (reinterpret_cast (nxv2i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
+  def : Pat<(nxv16i1 (reinterpret_cast (nxv1i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
   def : Pat<(nxv8i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
   def : Pat<(nxv8i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
   def : Pat<(nxv8i1 (reinterpret_cast (nxv2i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
+  def : Pat<(nxv8i1 (reinterpret_cast (nxv1i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
   def : Pat<(nxv4i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
   def : Pat<(nxv4i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
   def : Pat<(nxv4i1 (reinterpret_cast (nxv2i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
+  def : Pat<(nxv4i1 (reinterpret_cast (nxv1i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
   def : Pat<(nxv2i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
   def : Pat<(nxv2i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
   def : Pat<(nxv2i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
+  def : Pat<(nxv2i1 (reinterpret_cast (nxv1i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
+  def : Pat<(nxv1i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
+  def : Pat<(nxv1i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
+  def : Pat<(nxv1i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
+  def : Pat<(nxv1i1 (reinterpret_cast (nxv2i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;

   // These allow casting from/to unpacked floating-point types.
   def : Pat<(nxv2f16 (reinterpret_cast (nxv8f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -641,6 +641,7 @@
   def : Pat<(nxv8i1 immAllZerosV), (!cast<Instruction>(NAME))>;
   def : Pat<(nxv4i1 immAllZerosV), (!cast<Instruction>(NAME))>;
   def : Pat<(nxv2i1 immAllZerosV), (!cast<Instruction>(NAME))>;
+  def : Pat<(nxv1i1 immAllZerosV), (!cast<Instruction>(NAME))>;
 }

 class sve_int_ptest<bits<6> opc, string asm>
@@ -1658,6 +1659,7 @@
   def : SVE_3_Op_Pat(NAME)>;
   def : SVE_3_Op_Pat(NAME)>;
   def : SVE_3_Op_Pat(NAME)>;
+  def : SVE_3_Op_Pat(NAME)>;
   def : SVE_2_Op_AllActive_Pat(NAME), PTRUE_B>;
   def : SVE_2_Op_AllActive_Pat(NAME)>;
   def : SVE_1_Op_Pat(NAME)>;
   def : SVE_1_Op_Pat(NAME)>;
+  def : SVE_1_Op_Pat(NAME)>;
 }

 class sve_int_rdffr_pred
diff --git a/llvm/test/CodeGen/AArch64/sve-select.ll b/llvm/test/CodeGen/AArch64/sve-select.ll
--- a/llvm/test/CodeGen/AArch64/sve-select.ll
+++ b/llvm/test/CodeGen/AArch64/sve-select.ll
@@ -187,6 +187,7 @@
 ; CHECK-NEXT: // kill: def $w0 killed $w0 def $x0
 ; CHECK-NEXT: sbfx x8, x0, #0, #1
 ; CHECK-NEXT: whilelo p2.d, xzr, x8
+; CHECK-NEXT: punpklo p2.h, p2.b
 ; CHECK-NEXT: sel p0.b, p2, p0.b, p1.b
 ; CHECK-NEXT: ret
   %res = select i1 %cond, <vscale x 1 x i1> %a, <vscale x 1 x i1> %b
@@ -225,6 +226,7 @@
 define <vscale x 1 x i64> @sel_nxv1i64(<vscale x 1 x i1> %p, <vscale x 1 x i64> %dst, <vscale x 1 x i64> %a) {
 ; CHECK-LABEL: sel_nxv1i64:
 ; CHECK: // %bb.0:
+; CHECK-NEXT: uzp1 p0.d, p0.d, p0.d
 ; CHECK-NEXT: mov z0.d, p0/m, z1.d
 ; CHECK-NEXT: ret
   %sel = select <vscale x 1 x i1> %p, <vscale x 1 x i64> %a, <vscale x 1 x i64> %dst
@@ -483,6 +485,7 @@
 ; CHECK-NEXT: cset w8, eq
 ; CHECK-NEXT: sbfx x8, x8, #0, #1
 ; CHECK-NEXT: whilelo p2.d, xzr, x8
+; CHECK-NEXT: punpklo p2.h, p2.b
 ; CHECK-NEXT: sel p0.b, p2, p0.b, p1.b
 ; CHECK-NEXT: ret
   %mask = icmp eq i64 %x0, 0
diff --git a/llvm/test/CodeGen/AArch64/sve-zeroinit.ll b/llvm/test/CodeGen/AArch64/sve-zeroinit.ll
--- a/llvm/test/CodeGen/AArch64/sve-zeroinit.ll
+++ b/llvm/test/CodeGen/AArch64/sve-zeroinit.ll
@@ -52,6 +52,13 @@
   ret zeroinitializer
 }

+define <vscale x 1 x i1> @test_zeroinit_1xi1() {
+; CHECK-LABEL: test_zeroinit_1xi1
+; CHECK: pfalse p0.b
+; CHECK-NEXT: ret
+  ret <vscale x 1 x i1> zeroinitializer
+}
+
 define <vscale x 2 x i1> @test_zeroinit_2xi1() {
 ; CHECK-LABEL: test_zeroinit_2xi1
 ; CHECK: pfalse p0.b
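For illustration only, and not part of the patch above: a sketch of an additional IR test, in the style of sve-zeroinit.ll, that would exercise the new `(nxv1i1 immAllOnesV)` pattern added to AArch64SVEInstrInfo.td. The function name and the exact CHECK lines are assumptions, derived from the `(PUNPKLO_PP (PTRUE_D 31))` expansion that pattern selects (an all-lanes ptrue.d followed by punpklo to narrow it to a single-element predicate).

; Hypothetical test, not present in the patch; constant is an all-true
; <vscale x 1 x i1> splat written with the usual insertelement/shufflevector idiom.
define <vscale x 1 x i1> @test_allones_1xi1() {
; CHECK-LABEL: test_allones_1xi1
; CHECK: ptrue p0.d
; CHECK-NEXT: punpklo p0.h, p0.b
; CHECK-NEXT: ret
  ret <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i32 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer)
}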