Index: llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -66,6 +66,9 @@
   template <signed Low, signed High, signed Scale>
   bool SelectRDVLImm(SDValue N, SDValue &Imm);
 
+  template <unsigned ElementSize>
+  bool PTrueLoImm(SDValue N, SDValue &Imm);
+
   bool tryMLAV64LaneV128(SDNode *N);
   bool tryMULLV64LaneV128(unsigned IntNo, SDNode *N);
   bool SelectArithExtendedRegister(SDValue N, SDValue &Reg, SDValue &Shift);
@@ -848,6 +851,23 @@
   return false;
 }
 
+// Returns a suitable immediate operand for the PTRUE instruction.
+template <unsigned ElementSize>
+bool AArch64DAGToDAGISel::PTrueLoImm(SDValue N, SDValue &Imm) {
+  if (!isa<ConstantSDNode>(N))
+    return false;
+  int64_t NodeImm = cast<ConstantSDNode>(N)->getZExtValue();
+  int64_t MinSVEVectorSize = Subtarget->getMinSVEVectorSizeInBits()
+                                 ? Subtarget->getMinSVEVectorSizeInBits()
+                                 : 128;
+  if (getSVEPredPatternFromNumElements(NodeImm) != None &&
+      NodeImm < (MinSVEVectorSize / ElementSize)) {
+    Imm = CurDAG->getTargetConstant(NodeImm, SDLoc(N), MVT::i32);
+    return true;
+  }
+  return false;
+}
+
 /// SelectArithExtendedRegister - Select a "extended register" operand. This
 /// operand folds in an extend followed by an optional left shift.
 bool AArch64DAGToDAGISel::SelectArithExtendedRegister(SDValue N, SDValue &Reg,
Index: llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -141,6 +141,11 @@
 def sve_cntw_imm : ComplexPattern<i32, 1, "SelectRDVLImm<1, 16, 4>">;
 def sve_cntd_imm : ComplexPattern<i32, 1, "SelectRDVLImm<1, 16, 2>">;
 
+def ptrue_pred_lo_b : ComplexPattern<i64, 1, "PTrueLoImm<8>">;
+def ptrue_pred_lo_h : ComplexPattern<i64, 1, "PTrueLoImm<16>">;
+def ptrue_pred_lo_s : ComplexPattern<i64, 1, "PTrueLoImm<32>">;
+def ptrue_pred_lo_d : ComplexPattern<i64, 1, "PTrueLoImm<64>">;
+
 // SVE DEC
 def sve_cnth_imm_neg : ComplexPattern<i32, 1, "SelectRDVLImm<1, 16, -8>">;
 def sve_cntw_imm_neg : ComplexPattern<i32, 1, "SelectRDVLImm<1, 16, -4>">;
@@ -1808,6 +1813,15 @@
   def : Pat<(nxv2bf16 (extract_subvector (nxv8bf16 ZPR:$Zs), (i64 6))),
             (UUNPKHI_ZZ_D (UUNPKHI_ZZ_S ZPR:$Zs))>;
 
+  def : Pat<(nxv16i1 (int_aarch64_sve_whilelo (i64 0), (ptrue_pred_lo_b i32:$imm))),
+            (PTRUE_B i32:$imm)>;
+  def : Pat<(nxv8i1 (int_aarch64_sve_whilelo (i64 0), (ptrue_pred_lo_h i32:$imm))),
+            (PTRUE_H i32:$imm)>;
+  def : Pat<(nxv4i1 (int_aarch64_sve_whilelo (i64 0), (ptrue_pred_lo_s i32:$imm))),
+            (PTRUE_S i32:$imm)>;
+  def : Pat<(nxv2i1 (int_aarch64_sve_whilelo (i64 0), (ptrue_pred_lo_d i32:$imm))),
+            (PTRUE_D i32:$imm)>;
+
   // Concatenate two predicates.
   def : Pat<(nxv2i1 (concat_vectors nxv1i1:$p1, nxv1i1:$p2)),
             (UZP1_PPP_D $p1, $p2)>;
Index: llvm/test/CodeGen/AArch64/active_lane_mask.ll
===================================================================
--- llvm/test/CodeGen/AArch64/active_lane_mask.ll
+++ llvm/test/CodeGen/AArch64/active_lane_mask.ll
@@ -475,6 +475,38 @@
   ret <2 x i1> %active.lane.mask
 }
 
+define <vscale x 4 x i1> @lane_mask_nxv4i1_imm3() {
+; CHECK-LABEL: lane_mask_nxv4i1_imm3:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.s, vl3
+; CHECK-NEXT:    ret
+entry:
+  %active.lane.mask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 0, i64 3)
+  ret <vscale x 4 x i1> %active.lane.mask
+}
+
+define <vscale x 4 x i1> @lane_mask_nxv4i1_imm4() {
+; CHECK-LABEL: lane_mask_nxv4i1_imm4:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov w8, #4
+; CHECK-NEXT:    whilelo p0.s, xzr, x8
+; CHECK-NEXT:    ret
+entry:
+  %active.lane.mask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 0, i64 4)
+  ret <vscale x 4 x i1> %active.lane.mask
+}
+
+define <vscale x 16 x i1> @lane_mask_nxv16i1_imm10() {
+; CHECK-LABEL: lane_mask_nxv16i1_imm10:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov w8, #10
+; CHECK-NEXT:    whilelo p0.b, xzr, x8
+; CHECK-NEXT:    ret
+entry:
+  %active.lane.mask = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 0, i64 10)
+  ret <vscale x 16 x i1> %active.lane.mask
+}
+
 declare <vscale x 32 x i1> @llvm.get.active.lane.mask.nxv32i1.i32(i32, i32)
 declare <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i32, i32)
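The added tests only cover the default 128-bit minimum vector length, so the getMinSVEVectorSizeInBits() arm of PTrueLoImm goes unexercised. A minimal sketch of such a test follows (not part of the patch; the function name is hypothetical, while the -aarch64-sve-vector-bits-min llc flag and the FileCheck idiom match the existing file). With a 256-bit minimum, MinSVEVectorSize / ElementSize is 256 / 32 = 8 for .s elements, so the trip count of 4 that stays a whilelo above should instead select a ptrue:

; RUN: llc -mtriple=aarch64 -mattr=+sve -aarch64-sve-vector-bits-min=256 < %s | FileCheck %s

; Sketch only: 4 < 256/32 = 8, so PTrueLoImm<32> accepts the immediate
; and the whilelo with a constant trip count folds to a ptrue.
define <vscale x 4 x i1> @lane_mask_nxv4i1_imm4_minvl256() {
; CHECK-LABEL: lane_mask_nxv4i1_imm4_minvl256:
; CHECK: ptrue p0.s, vl4
; CHECK-NEXT: ret
entry:
  %active.lane.mask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 0, i64 4)
  ret <vscale x 4 x i1> %active.lane.mask
}

declare <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64, i64)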