diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -6653,7 +6653,7 @@ EVT InVT = InOp.getValueType(); assert(InVT.getVectorElementType() == NVT.getVectorElementType() && "input and widen element type must match"); - assert(!InVT.isScalableVector() && !NVT.isScalableVector() && + assert(InVT.isScalableVector() == NVT.isScalableVector() && "cannot modify scalable vectors in this way"); SDLoc dl(InOp); @@ -6661,8 +6661,8 @@ if (InVT == NVT) return InOp; - unsigned InNumElts = InVT.getVectorNumElements(); - unsigned WidenNumElts = NVT.getVectorNumElements(); + unsigned InNumElts = InVT.getVectorMinNumElements(); + unsigned WidenNumElts = NVT.getVectorMinNumElements(); if (WidenNumElts > InNumElts && WidenNumElts % InNumElts == 0) { unsigned NumConcat = WidenNumElts / InNumElts; SmallVector Ops(NumConcat); @@ -6679,6 +6679,9 @@ return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, NVT, InOp, DAG.getVectorIdxConstant(0, dl)); + assert(!InVT.isScalableVector() && !NVT.isScalableVector() && + "Scalable vectors should have been handled already."); + // Fall back to extract and build. SmallVector Ops(WidenNumElts); EVT EltVT = NVT.getVectorElementType(); diff --git a/llvm/lib/Target/AArch64/AArch64CallingConvention.td b/llvm/lib/Target/AArch64/AArch64CallingConvention.td --- a/llvm/lib/Target/AArch64/AArch64CallingConvention.td +++ b/llvm/lib/Target/AArch64/AArch64CallingConvention.td @@ -82,9 +82,9 @@ nxv2bf16, nxv4bf16, nxv8bf16, nxv2f32, nxv4f32, nxv2f64], CCPassIndirect>, - CCIfType<[nxv2i1, nxv4i1, nxv8i1, nxv16i1], + CCIfType<[nxv1i1, nxv2i1, nxv4i1, nxv8i1, nxv16i1], CCAssignToReg<[P0, P1, P2, P3]>>, - CCIfType<[nxv2i1, nxv4i1, nxv8i1, nxv16i1], + CCIfType<[nxv1i1, nxv2i1, nxv4i1, nxv8i1, nxv16i1], CCPassIndirect>, // Handle i1, i8, i16, i32, i64, f32, f64 and v2f64 by passing in registers, @@ -149,7 +149,7 @@ nxv2bf16, nxv4bf16, nxv8bf16, nxv2f32, nxv4f32, nxv2f64], CCAssignToReg<[Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7]>>, - CCIfType<[nxv2i1, nxv4i1, nxv8i1, nxv16i1], + CCIfType<[nxv1i1, nxv2i1, nxv4i1, nxv8i1, nxv16i1], CCAssignToReg<[P0, P1, P2, P3]>> ]>; diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -291,6 +291,7 @@ if (Subtarget->hasSVE() || Subtarget->hasSME()) { // Add legal sve predicate types + addRegisterClass(MVT::nxv1i1, &AArch64::PPRRegClass); addRegisterClass(MVT::nxv2i1, &AArch64::PPRRegClass); addRegisterClass(MVT::nxv4i1, &AArch64::PPRRegClass); addRegisterClass(MVT::nxv8i1, &AArch64::PPRRegClass); @@ -1155,7 +1156,8 @@ MVT::nxv4i16, MVT::nxv4i32, MVT::nxv8i8, MVT::nxv8i16 }) setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Legal); - for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1}) { + for (auto VT : + {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1, MVT::nxv1i1}) { setOperationAction(ISD::CONCAT_VECTORS, VT, Custom); setOperationAction(ISD::SELECT, VT, Custom); setOperationAction(ISD::SETCC, VT, Custom); diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td --- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td +++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td @@ -871,7 +871,7 @@ // SVE predicate register classes. class PPRClass : RegisterClass< "AArch64", - [ nxv16i1, nxv8i1, nxv4i1, nxv2i1 ], 16, + [ nxv16i1, nxv8i1, nxv4i1, nxv2i1, nxv1i1 ], 16, (sequence "P%u", 0, lastreg)> { let Size = 16; } diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -744,8 +744,10 @@ defm UUNPKLO_ZZ : sve_int_perm_unpk<0b10, "uunpklo", AArch64uunpklo>; defm UUNPKHI_ZZ : sve_int_perm_unpk<0b11, "uunpkhi", AArch64uunpkhi>; - defm PUNPKLO_PP : sve_int_perm_punpk<0b0, "punpklo", int_aarch64_sve_punpklo>; - defm PUNPKHI_PP : sve_int_perm_punpk<0b1, "punpkhi", int_aarch64_sve_punpkhi>; + // Instructions PUNPKLO/HI_PP are defined in SVEInstrFormats.td + // Here we only define the patterns. + defm : sve_int_perm_punpk_pat; + defm : sve_int_perm_punpk_pat; defm MOVPRFX_ZPzZ : sve_int_movprfx_pred_zero<0b000, "movprfx">; defm MOVPRFX_ZPmZ : sve_int_movprfx_pred_merge<0b001, "movprfx">; @@ -1598,6 +1600,8 @@ (UUNPKHI_ZZ_D (UUNPKHI_ZZ_S ZPR:$Zs))>; // Concatenate two predicates. + def : Pat<(nxv2i1 (concat_vectors nxv1i1:$p1, nxv1i1:$p2)), + (UZP1_PPP_D $p1, $p2)>; def : Pat<(nxv4i1 (concat_vectors nxv2i1:$p1, nxv2i1:$p2)), (UZP1_PPP_S $p1, $p2)>; def : Pat<(nxv8i1 (concat_vectors nxv4i1:$p1, nxv4i1:$p2)), diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td --- a/llvm/lib/Target/AArch64/SVEInstrFormats.td +++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td @@ -288,9 +288,29 @@ //===----------------------------------------------------------------------===// -// SVE PTrue - These are used extensively throughout the pattern matching so -// it's important we define them first. +// SVE PUNPKLO/HI and PTrue: +// These are used extensively throughout the pattern matching so +// it's important we define them first. //===----------------------------------------------------------------------===// +class sve_int_perm_punpk +: I<(outs PPR16:$Pd), (ins PPR8:$Pn), + asm, "\t$Pd, $Pn", + "", + []>, Sched<[]> { + bits<4> Pd; + bits<4> Pn; + let Inst{31-17} = 0b000001010011000; + let Inst{16} = opc; + let Inst{15-9} = 0b0100000; + let Inst{8-5} = Pn; + let Inst{4} = 0b0; + let Inst{3-0} = Pd; +} + +let Predicates = [HasSVEorSME] in { + def PUNPKLO_PP : sve_int_perm_punpk<0b0, "punpklo">; + def PUNPKHI_PP : sve_int_perm_punpk<0b1, "punpkhi">; +} class sve_int_ptrue sz8_64, bits<3> opc, string asm, PPRRegOp pprty, ValueType vt, SDPatternOperator op> @@ -342,6 +362,9 @@ def : Pat<(nxv8i1 immAllOnesV), (PTRUE_H 31)>; def : Pat<(nxv4i1 immAllOnesV), (PTRUE_S 31)>; def : Pat<(nxv2i1 immAllOnesV), (PTRUE_D 31)>; + def : Pat<(nxv1i1 (AArch64ptrue (sve_pred_enum:$pattern))), + (PUNPKLO_PP (PTRUE_D $pattern))>; + def : Pat<(nxv1i1 immAllOnesV), (PUNPKLO_PP (PTRUE_D 31))>; } //===----------------------------------------------------------------------===// @@ -641,6 +664,7 @@ def : Pat<(nxv8i1 immAllZerosV), (!cast(NAME))>; def : Pat<(nxv4i1 immAllZerosV), (!cast(NAME))>; def : Pat<(nxv2i1 immAllZerosV), (!cast(NAME))>; + def : Pat<(nxv1i1 immAllZerosV), (!cast(NAME))>; } class sve_int_ptest opc, string asm> @@ -1658,6 +1682,7 @@ def : SVE_3_Op_Pat(NAME)>; def : SVE_3_Op_Pat(NAME)>; def : SVE_3_Op_Pat(NAME)>; + def : SVE_3_Op_Pat(NAME)>; def : SVE_2_Op_AllActive_Pat(NAME), PTRUE_B>; def : SVE_2_Op_AllActive_Pat(NAME # _H)>; def : SVE_2_Op_Pat(NAME # _S)>; def : SVE_2_Op_Pat(NAME # _D)>; + def : Pat<(nxv1i1 (op i32:$Op1, i32:$Op2)), + (PUNPKLO_PP (!cast(NAME # _D) $Op1, $Op2))>; } multiclass sve_int_while8_rr opc, string asm, SDPatternOperator op> { @@ -4943,6 +4970,8 @@ def : SVE_2_Op_Pat(NAME # _H)>; def : SVE_2_Op_Pat(NAME # _S)>; def : SVE_2_Op_Pat(NAME # _D)>; + def : Pat<(nxv1i1 (op i64:$Op1, i64:$Op2)), + (PUNPKLO_PP (!cast(NAME # _D) $Op1, $Op2))>; } class sve2_int_while_rr sz8_64, bits<1> rw, string asm, @@ -6177,29 +6206,14 @@ def : SVE_2_Op_Pat(NAME # _H)>; def : SVE_2_Op_Pat(NAME # _S)>; def : SVE_2_Op_Pat(NAME # _D)>; -} -class sve_int_perm_punpk -: I<(outs PPR16:$Pd), (ins PPR8:$Pn), - asm, "\t$Pd, $Pn", - "", - []>, Sched<[]> { - bits<4> Pd; - bits<4> Pn; - let Inst{31-17} = 0b000001010011000; - let Inst{16} = opc; - let Inst{15-9} = 0b0100000; - let Inst{8-5} = Pn; - let Inst{4} = 0b0; - let Inst{3-0} = Pd; + // FIXME: ADD PATTERNS FOR _Q IF SME IS NOT AVAILABLE } -multiclass sve_int_perm_punpk { - def NAME : sve_int_perm_punpk; - - def : SVE_1_Op_Pat(NAME)>; - def : SVE_1_Op_Pat(NAME)>; - def : SVE_1_Op_Pat(NAME)>; +multiclass sve_int_perm_punpk_pat { + def : SVE_1_Op_Pat; + def : SVE_1_Op_Pat; + def : SVE_1_Op_Pat; } class sve_int_rdffr_pred diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-pred-creation.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-pred-creation.ll --- a/llvm/test/CodeGen/AArch64/sve-intrinsics-pred-creation.ll +++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-pred-creation.ll @@ -42,7 +42,17 @@ ret %out } +define @ptrue_b128() { +; CHECK-LABEL: ptrue_b128: +; CHECK: ptrue p0.d +; CHECK-NEXT: punpklo p0.h, p0.b +; CHECK-NEXT: ret + %out = call @llvm.aarch64.sve.ptrue.nxv1i1(i32 31) + ret %out +} + declare @llvm.aarch64.sve.ptrue.nxv16i1(i32 %pattern) declare @llvm.aarch64.sve.ptrue.nxv8i1(i32 %pattern) declare @llvm.aarch64.sve.ptrue.nxv4i1(i32 %pattern) declare @llvm.aarch64.sve.ptrue.nxv2i1(i32 %pattern) +declare @llvm.aarch64.sve.ptrue.nxv1i1(i32 %pattern) diff --git a/llvm/test/CodeGen/AArch64/sve-select.ll b/llvm/test/CodeGen/AArch64/sve-select.ll --- a/llvm/test/CodeGen/AArch64/sve-select.ll +++ b/llvm/test/CodeGen/AArch64/sve-select.ll @@ -187,6 +187,7 @@ ; CHECK-NEXT: // kill: def $w0 killed $w0 def $x0 ; CHECK-NEXT: sbfx x8, x0, #0, #1 ; CHECK-NEXT: whilelo p2.d, xzr, x8 +; CHECK-NEXT: punpklo p2.h, p2.b ; CHECK-NEXT: sel p0.b, p2, p0.b, p1.b ; CHECK-NEXT: ret %res = select i1 %cond, %a, %b @@ -225,6 +226,7 @@ define @sel_nxv1i64( %p, %dst, %a) { ; CHECK-LABEL: sel_nxv1i64: ; CHECK: // %bb.0: +; CHECK-NEXT: uzp1 p0.d, p0.d, p0.d ; CHECK-NEXT: mov z0.d, p0/m, z1.d ; CHECK-NEXT: ret %sel = select %p, %a, %dst @@ -483,6 +485,7 @@ ; CHECK-NEXT: cset w8, eq ; CHECK-NEXT: sbfx x8, x8, #0, #1 ; CHECK-NEXT: whilelo p2.d, xzr, x8 +; CHECK-NEXT: punpklo p2.h, p2.b ; CHECK-NEXT: sel p0.b, p2, p0.b, p1.b ; CHECK-NEXT: ret %mask = icmp eq i64 %x0, 0 diff --git a/llvm/test/CodeGen/AArch64/sve-zeroinit.ll b/llvm/test/CodeGen/AArch64/sve-zeroinit.ll --- a/llvm/test/CodeGen/AArch64/sve-zeroinit.ll +++ b/llvm/test/CodeGen/AArch64/sve-zeroinit.ll @@ -52,6 +52,13 @@ ret zeroinitializer } +define @test_zeroinit_1xi1() { +; CHECK-LABEL: test_zeroinit_1xi1 +; CHECK: pfalse p0.b +; CHECK-NEXT: ret + ret zeroinitializer +} + define @test_zeroinit_2xi1() { ; CHECK-LABEL: test_zeroinit_2xi1 ; CHECK: pfalse p0.b