Index: llvm/lib/Target/AArch64/AArch64ISelLowering.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1154,10 +1154,6 @@
   // This function does not handle predicate bitcasts.
   SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
 
-  // Returns a safe bitcast between two scalable vector predicates, where
-  // any newly created lanes from a widening bitcast are defined as zero.
-  SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
-
   bool isConstantUnsignedBitfieldExtractLegal(unsigned Opc, LLT Ty1,
                                               LLT Ty2) const override;
 };
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -237,6 +237,39 @@
   }
 }
 
+// Returns true if newly defined lanes are known to be zeroed by construction.
+static bool hasZeroedOtherLanes(SDValue Op) {
+  switch (Op.getOpcode()) {
+  default:
+    // We guarantee i1 splat_vectors to zero the other lanes by
+    // implementing it with ptrue and possibly a punpklo for nxv1i1.
+    if (ISD::isConstantSplatVectorAllOnes(Op.getNode()))
+      return true;
+    return false;
+  case AArch64ISD::PTRUE:
+  case AArch64ISD::SETCC_MERGE_ZERO:
+    return true;
+  case ISD::INTRINSIC_WO_CHAIN:
+    switch (Op.getConstantOperandVal(0)) {
+    default:
+      return false;
+    case Intrinsic::aarch64_sve_ptrue:
+    case Intrinsic::aarch64_sve_pnext:
+    case Intrinsic::aarch64_sve_cmpeq_wide:
+    case Intrinsic::aarch64_sve_cmpne_wide:
+    case Intrinsic::aarch64_sve_cmpge_wide:
+    case Intrinsic::aarch64_sve_cmpgt_wide:
+    case Intrinsic::aarch64_sve_cmplt_wide:
+    case Intrinsic::aarch64_sve_cmple_wide:
+    case Intrinsic::aarch64_sve_cmphs_wide:
+    case Intrinsic::aarch64_sve_cmphi_wide:
+    case Intrinsic::aarch64_sve_cmplo_wide:
+    case Intrinsic::aarch64_sve_cmpls_wide:
+      return true;
+    }
+  }
+}
+
 AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
                                              const AArch64Subtarget &STI)
     : TargetLowering(TM), Subtarget(&STI) {
@@ -4368,16 +4401,18 @@
       DAG.getTargetConstant(Pattern, DL, MVT::i32));
 }
 
-SDValue AArch64TargetLowering::getSVEPredicateBitCast(EVT VT, SDValue Op,
-                                                      SelectionDAG &DAG) const {
+// Returns a safe bitcast between two scalable vector predicates, where
+// any newly created lanes from a widening bitcast are defined as zero.
+static SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) {
   SDLoc DL(Op);
   EVT InVT = Op.getValueType();
 
   assert(InVT.getVectorElementType() == MVT::i1 &&
          VT.getVectorElementType() == MVT::i1 &&
          "Expected a predicate-to-predicate bitcast");
-  assert(VT.isScalableVector() && isTypeLegal(VT) &&
-         InVT.isScalableVector() && isTypeLegal(InVT) &&
+  assert(VT.isScalableVector() && DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
+         InVT.isScalableVector() &&
+         DAG.getTargetLoweringInfo().isTypeLegal(InVT) &&
          "Only expect to cast between legal scalable predicate types!");
 
   // Return the operand if the cast isn't changing type,
@@ -4396,33 +4431,8 @@
 
   // Check if the other lanes are already known to be zeroed by
   // construction.
-  switch (Op.getOpcode()) {
-  default:
-    // We guarantee i1 splat_vectors to zero the other lanes by
-    // implementing it with ptrue and possibly a punpklo for nxv1i1.
-    if (ISD::isConstantSplatVectorAllOnes(Op.getNode()))
-      return Reinterpret;
-    break;
-  case AArch64ISD::SETCC_MERGE_ZERO:
+  if (hasZeroedOtherLanes(Op))
     return Reinterpret;
-  case ISD::INTRINSIC_WO_CHAIN:
-    switch (Op.getConstantOperandVal(0)) {
-    default:
-      break;
-    case Intrinsic::aarch64_sve_ptrue:
-    case Intrinsic::aarch64_sve_cmpeq_wide:
-    case Intrinsic::aarch64_sve_cmpne_wide:
-    case Intrinsic::aarch64_sve_cmpge_wide:
-    case Intrinsic::aarch64_sve_cmpgt_wide:
-    case Intrinsic::aarch64_sve_cmplt_wide:
-    case Intrinsic::aarch64_sve_cmple_wide:
-    case Intrinsic::aarch64_sve_cmphs_wide:
-    case Intrinsic::aarch64_sve_cmphi_wide:
-    case Intrinsic::aarch64_sve_cmplo_wide:
-    case Intrinsic::aarch64_sve_cmpls_wide:
-      return Reinterpret;
-    }
-  }
 
   // Zero the newly introduced lanes.
   SDValue Mask = DAG.getConstant(1, DL, InVT);
@@ -16171,8 +16181,38 @@
   SDValue TVal = DAG.getConstant(1, DL, OutVT);
   SDValue FVal = DAG.getConstant(0, DL, OutVT);
 
+  // Ensure operands have type nxv16i1.
+  MVT WidenVT = MVT::nxv16i1;
+  SDValue ReinterpretOp = Op;
+  SDValue ReinterpretPg = Pg;
+  if (Op.getValueType() != WidenVT) {
+    switch (Cond) {
+    default:
+      ReinterpretPg = getSVEPredicateBitCast(WidenVT, Pg, DAG);
+      break;
+    case AArch64CC::ANY_ACTIVE:
+    case AArch64CC::NONE_ACTIVE:
+      if (hasZeroedOtherLanes(Op) || hasZeroedOtherLanes(Pg))
+        ReinterpretPg =
+            DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, WidenVT, Pg);
+      else
+        ReinterpretPg = getSVEPredicateBitCast(WidenVT, Pg, DAG);
+      break;
+    case AArch64CC::FIRST_ACTIVE:
+    case AArch64CC::LAST_ACTIVE:
+      if (hasZeroedOtherLanes(Pg))
+        ReinterpretPg =
+            DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, WidenVT, Pg);
+      else
+        ReinterpretPg = getSVEPredicateBitCast(WidenVT, Pg, DAG);
+      break;
+    }
+    ReinterpretOp = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, WidenVT, Op);
+  }
+
   // Set condition code (CC) flags.
-  SDValue Test = DAG.getNode(AArch64ISD::PTEST, DL, MVT::Other, Pg, Op);
+  SDValue Test = DAG.getNode(AArch64ISD::PTEST, DL, MVT::Other, ReinterpretPg,
+                             ReinterpretOp);
 
   // Convert CC to integer based on requested condition.
   // NOTE: Cond is inverted to promote CSEL's removal when it feeds a compare.
Index: llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -2133,14 +2133,6 @@
 
   def : Pat<(AArch64ptest (nxv16i1 PPR:$pg), (nxv16i1 PPR:$src)),
             (PTEST_PP PPR:$pg, PPR:$src)>;
-  def : Pat<(AArch64ptest (nxv8i1 PPR:$pg), (nxv8i1 PPR:$src)),
-            (PTEST_PP PPR:$pg, PPR:$src)>;
-  def : Pat<(AArch64ptest (nxv4i1 PPR:$pg), (nxv4i1 PPR:$src)),
-            (PTEST_PP PPR:$pg, PPR:$src)>;
-  def : Pat<(AArch64ptest (nxv2i1 PPR:$pg), (nxv2i1 PPR:$src)),
-            (PTEST_PP PPR:$pg, PPR:$src)>;
-  def : Pat<(AArch64ptest (nxv1i1 PPR:$pg), (nxv1i1 PPR:$src)),
-            (PTEST_PP PPR:$pg, PPR:$src)>;
 
   let AddedComplexity = 1 in {
   class LD1RPat<ValueType vt, SDPatternOperator operator,
Index: llvm/test/CodeGen/AArch64/sve-setcc.ll
===================================================================
--- llvm/test/CodeGen/AArch64/sve-setcc.ll
+++ llvm/test/CodeGen/AArch64/sve-setcc.ll
@@ ... @@
 define void @sve_cmplt_setcc_hslo(<vscale x 8 x i16>* %out, <vscale x 8 x i16> %in, <vscale x 8 x i1> %pg) {
 ; CHECK-LABEL: sve_cmplt_setcc_hslo:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    cmplt p1.h, p0/z, z0.h, #0
+; CHECK-NEXT:    ptrue p1.h
+; CHECK-NEXT:    cmplt p2.h, p0/z, z0.h, #0
+; CHECK-NEXT:    and p1.b, p0/z, p0.b, p1.b
+; CHECK-NEXT:    ptest p1, p2.b
 ; CHECK-NEXT:    b.hs .LBB2_2
 ; CHECK-NEXT:  // %bb.1: // %if.then
 ; CHECK-NEXT:    st1h { z0.h }, p0, [x0]
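Note on the test hunk above: with the nxv8i1/nxv4i1/nxv2i1/nxv1i1 PTEST patterns removed, every ptest is selected on nxv16i1 operands, and an operand whose widened lanes are not known to be zero must first be masked by getSVEPredicateBitCast. A minimal IR sketch of the two widening paths follows; it is not part of the patch, and the function and value names are illustrative (the intrinsic declarations are the real nxv8i1 overloads the test file uses at this revision):

  declare <vscale x 8 x i1> @llvm.aarch64.sve.cmplt.wide.nxv8i16(<vscale x 8 x i1>, <vscale x 8 x i16>, <vscale x 2 x i64>)
  declare i1 @llvm.aarch64.sve.ptest.any.nxv8i1(<vscale x 8 x i1>, <vscale x 8 x i1>)
  declare i1 @llvm.aarch64.sve.ptest.last.nxv8i1(<vscale x 8 x i1>, <vscale x 8 x i1>)

  ; %cmp is produced by a wide compare, so hasZeroedOtherLanes(Op) holds and
  ; the ANY_ACTIVE path widens both ptest operands with a plain
  ; REINTERPRET_CAST at ISel, with no zeroing AND.
  define i1 @ptest_any_widen(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %in) {
    %cmp = call <vscale x 8 x i1> @llvm.aarch64.sve.cmplt.wide.nxv8i16(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %in, <vscale x 2 x i64> zeroinitializer)
    %any = call i1 @llvm.aarch64.sve.ptest.any.nxv8i1(<vscale x 8 x i1> %pg, <vscale x 8 x i1> %cmp)
    ret i1 %any
  }

  ; For LAST_ACTIVE only the governing predicate's zeroing matters. %pg is an
  ; incoming argument whose widened lanes are unknown, so getSVEPredicateBitCast
  ; emits the ptrue/and pair seen in the CHECK lines above before the ptest.
  define i1 @ptest_last_widen(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %in) {
    %cmp = call <vscale x 8 x i1> @llvm.aarch64.sve.cmplt.wide.nxv8i16(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %in, <vscale x 2 x i64> zeroinitializer)
    %last = call i1 @llvm.aarch64.sve.ptest.last.nxv8i1(<vscale x 8 x i1> %pg, <vscale x 8 x i1> %cmp)
    ret i1 %last
  }

This also explains the sve_cmplt_setcc_hslo regression: the zeroing AND gives the ptest a governing predicate different from the compare's, so the later redundant-ptest fold no longer fires and the ptrue/and/ptest sequence survives.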