Index: llvm/lib/Target/AArch64/AArch64ISelLowering.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1154,10 +1154,6 @@
   // This function does not handle predicate bitcasts.
   SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
 
-  // Returns a safe bitcast between two scalable vector predicates, where
-  // any newly created lanes from a widening bitcast are defined as zero.
-  SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
-
   bool isConstantUnsignedBitfieldExtractLegal(unsigned Opc, LLT Ty1,
                                               LLT Ty2) const override;
 };
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -237,6 +237,39 @@
   }
 }
 
+// Returns true if inactive lanes are known to be zeroed by construction.
+static bool isZeroingInactiveLanes(SDValue Op) {
+  switch (Op.getOpcode()) {
+  default:
+    // We guarantee i1 splat_vectors to zero the other lanes by
+    // implementing it with ptrue and possibly a punpklo for nxv1i1.
+    if (ISD::isConstantSplatVectorAllOnes(Op.getNode()))
+      return true;
+    return false;
+  case AArch64ISD::PTRUE:
+  case AArch64ISD::SETCC_MERGE_ZERO:
+    return true;
+  case ISD::INTRINSIC_WO_CHAIN:
+    switch (Op.getConstantOperandVal(0)) {
+    default:
+      return false;
+    case Intrinsic::aarch64_sve_ptrue:
+    case Intrinsic::aarch64_sve_pnext:
+    case Intrinsic::aarch64_sve_cmpeq_wide:
+    case Intrinsic::aarch64_sve_cmpne_wide:
+    case Intrinsic::aarch64_sve_cmpge_wide:
+    case Intrinsic::aarch64_sve_cmpgt_wide:
+    case Intrinsic::aarch64_sve_cmplt_wide:
+    case Intrinsic::aarch64_sve_cmple_wide:
+    case Intrinsic::aarch64_sve_cmphs_wide:
+    case Intrinsic::aarch64_sve_cmphi_wide:
+    case Intrinsic::aarch64_sve_cmplo_wide:
+    case Intrinsic::aarch64_sve_cmpls_wide:
+      return true;
+    }
+  }
+}
+
 AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
                                              const AArch64Subtarget &STI)
     : TargetLowering(TM), Subtarget(&STI) {
@@ -4368,16 +4401,18 @@
                          DAG.getTargetConstant(Pattern, DL, MVT::i32));
 }
 
-SDValue AArch64TargetLowering::getSVEPredicateBitCast(EVT VT, SDValue Op,
-                                                      SelectionDAG &DAG) const {
+// Returns a safe bitcast between two scalable vector predicates, where
+// any newly created lanes from a widening bitcast are defined as zero.
+static SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) {
   SDLoc DL(Op);
   EVT InVT = Op.getValueType();
 
   assert(InVT.getVectorElementType() == MVT::i1 &&
          VT.getVectorElementType() == MVT::i1 &&
          "Expected a predicate-to-predicate bitcast");
-  assert(VT.isScalableVector() && isTypeLegal(VT) &&
-         InVT.isScalableVector() && isTypeLegal(InVT) &&
+  assert(VT.isScalableVector() && DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
+         InVT.isScalableVector() &&
+         DAG.getTargetLoweringInfo().isTypeLegal(InVT) &&
          "Only expect to cast between legal scalable predicate types!");
 
   // Return the operand if the cast isn't changing type,
@@ -4396,33 +4431,8 @@
 
   // Check if the other lanes are already known to be zeroed by
   // construction.
-  switch (Op.getOpcode()) {
-  default:
-    // We guarantee i1 splat_vectors to zero the other lanes by
-    // implementing it with ptrue and possibly a punpklo for nxv1i1.
-    if (ISD::isConstantSplatVectorAllOnes(Op.getNode()))
-      return Reinterpret;
-    break;
-  case AArch64ISD::SETCC_MERGE_ZERO:
+  if (isZeroingInactiveLanes(Op))
     return Reinterpret;
-  case ISD::INTRINSIC_WO_CHAIN:
-    switch (Op.getConstantOperandVal(0)) {
-    default:
-      break;
-    case Intrinsic::aarch64_sve_ptrue:
-    case Intrinsic::aarch64_sve_cmpeq_wide:
-    case Intrinsic::aarch64_sve_cmpne_wide:
-    case Intrinsic::aarch64_sve_cmpge_wide:
-    case Intrinsic::aarch64_sve_cmpgt_wide:
-    case Intrinsic::aarch64_sve_cmplt_wide:
-    case Intrinsic::aarch64_sve_cmple_wide:
-    case Intrinsic::aarch64_sve_cmphs_wide:
-    case Intrinsic::aarch64_sve_cmphi_wide:
-    case Intrinsic::aarch64_sve_cmplo_wide:
-    case Intrinsic::aarch64_sve_cmpls_wide:
-      return Reinterpret;
-    }
-  }
 
   // Zero the newly introduced lanes.
   SDValue Mask = DAG.getConstant(1, DL, InVT);
@@ -16164,12 +16174,24 @@
   assert(Op.getValueType().isScalableVector() &&
          TLI.isTypeLegal(Op.getValueType()) &&
          "Expected legal scalable vector type!");
+  assert(Op.getValueType() == Pg.getValueType() &&
+         "Expected same type for PTEST operands");
 
   // Ensure target specific opcodes are using legal type.
   EVT OutVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
   SDValue TVal = DAG.getConstant(1, DL, OutVT);
   SDValue FVal = DAG.getConstant(0, DL, OutVT);
 
+  // Ensure operands have type nxv16i1.
+  if (Op.getValueType() != MVT::nxv16i1) {
+    if ((Cond == AArch64CC::ANY_ACTIVE || Cond == AArch64CC::NONE_ACTIVE) &&
+        isZeroingInactiveLanes(Op))
+      Pg = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv16i1, Pg);
+    else
+      Pg = getSVEPredicateBitCast(MVT::nxv16i1, Pg, DAG);
+    Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv16i1, Op);
+  }
+
   // Set condition code (CC) flags.
   SDValue Test = DAG.getNode(AArch64ISD::PTEST, DL, MVT::Other, Pg, Op);
 
Index: llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -778,7 +778,7 @@
   defm BRKB_PPmP : sve_int_break_m<0b101, "brkb", int_aarch64_sve_brkb>;
   defm BRKBS_PPzP : sve_int_break_z<0b110, "brkbs", null_frag>;
 
-  def PTEST_PP : sve_int_ptest<0b010000, "ptest">;
+  def PTEST_PP : sve_int_ptest<0b010000, "ptest", AArch64ptest>;
   defm PFALSE : sve_int_pfalse<0b000000, "pfalse">;
   defm PFIRST : sve_int_pfirst<0b00000, "pfirst", int_aarch64_sve_pfirst>;
   defm PNEXT : sve_int_pnext<0b00110, "pnext", int_aarch64_sve_pnext>;
@@ -2131,17 +2131,6 @@
   def STR_ZZZZXI : Pseudo<(outs), (ins ZZZZ_b:$Zs, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
   }
 
-  def : Pat<(AArch64ptest (nxv16i1 PPR:$pg), (nxv16i1 PPR:$src)),
-            (PTEST_PP PPR:$pg, PPR:$src)>;
-  def : Pat<(AArch64ptest (nxv8i1 PPR:$pg), (nxv8i1 PPR:$src)),
-            (PTEST_PP PPR:$pg, PPR:$src)>;
-  def : Pat<(AArch64ptest (nxv4i1 PPR:$pg), (nxv4i1 PPR:$src)),
-            (PTEST_PP PPR:$pg, PPR:$src)>;
-  def : Pat<(AArch64ptest (nxv2i1 PPR:$pg), (nxv2i1 PPR:$src)),
-            (PTEST_PP PPR:$pg, PPR:$src)>;
-  def : Pat<(AArch64ptest (nxv1i1 PPR:$pg), (nxv1i1 PPR:$src)),
-            (PTEST_PP PPR:$pg, PPR:$src)>;
-
   let AddedComplexity = 1 in {
   class LD1RPat<ValueType vt, SDPatternOperator operator,
                 Instruction load, Instruction ptrue, ValueType index_vt, ComplexPattern CP, Operand immtype> :
Index: llvm/lib/Target/AArch64/SVEInstrFormats.td
===================================================================
--- llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -650,11 +650,11 @@
   def : Pat<(nxv1i1 immAllZerosV), (!cast<Instruction>(NAME))>;
 }
 
-class sve_int_ptest<bits<6> opc, string asm>
+class sve_int_ptest<bits<6> opc, string asm, SDPatternOperator op>
 : I<(outs), (ins PPRAny:$Pg, PPR8:$Pn),
   asm, "\t$Pg, $Pn",
   "",
-  []>, Sched<[]> {
+  [(op (nxv16i1 PPRAny:$Pg), (nxv16i1 PPR8:$Pn))]>, Sched<[]> {
   bits<4> Pg;
   bits<4> Pn;
   let Inst{31-24} = 0b00100101;
Index: llvm/test/CodeGen/AArch64/sve-setcc.ll
===================================================================
--- llvm/test/CodeGen/AArch64/sve-setcc.ll
+++ llvm/test/CodeGen/AArch64/sve-setcc.ll
@@ -51,7 +51,10 @@
 define void @sve_cmplt_setcc_hslo(<vscale x 8 x i16>* %out, <vscale x 8 x i16> %in, <vscale x 8 x i1> %pg) {
 ; CHECK-LABEL: sve_cmplt_setcc_hslo:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    cmplt p1.h, p0/z, z0.h, #0
+; CHECK-NEXT:    ptrue p1.h
+; CHECK-NEXT:    cmplt p2.h, p0/z, z0.h, #0
+; CHECK-NEXT:    and p1.b, p0/z, p0.b, p1.b
+; CHECK-NEXT:    ptest p1, p2.b
 ; CHECK-NEXT:    b.hs .LBB2_2
 ; CHECK-NEXT:    // %bb.1: // %if.then
 ; CHECK-NEXT:    st1h { z0.h }, p0, [x0]
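
For context, the sketch below shows the kind of input that now reaches getPTest with a predicate type narrower than nxv16i1. It is an illustration only, not part of the patch: the function name is hypothetical, and the intrinsic pair mirrors the style of the existing tests in sve-setcc.ll.

; Hypothetical reduced example: an nxv8i1 predicate test. Lowering the
; ptest.any call reaches getPTest with nxv8i1 operands, which the code
; above widens to nxv16i1 before emitting AArch64ISD::PTEST.
declare <vscale x 8 x i1> @llvm.aarch64.sve.cmplt.wide.nxv8i16(<vscale x 8 x i1>, <vscale x 8 x i16>, <vscale x 2 x i64>)
declare i1 @llvm.aarch64.sve.ptest.any.nxv8i1(<vscale x 8 x i1>, <vscale x 8 x i1>)

define i1 @ptest_any_h(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %in) {
  %cmp = call <vscale x 8 x i1> @llvm.aarch64.sve.cmplt.wide.nxv8i16(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %in, <vscale x 2 x i64> zeroinitializer)
  %any = call i1 @llvm.aarch64.sve.ptest.any.nxv8i1(<vscale x 8 x i1> %pg, <vscale x 8 x i1> %cmp)
  ret i1 %any
}

Since the wide compare is one of the zeroing operations listed in isZeroingInactiveLanes and ptest.any lowers with AArch64CC::ANY_ACTIVE, this example should take the cheap REINTERPRET_CAST path for %pg rather than the zeroing predicate bitcast.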