Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17010,15 +17010,43 @@
   SDValue LHS = N->getOperand(1);
   SDValue RHS = N->getOperand(2);
   ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(3))->get();
+  SDLoc dl(N);
 
-  // setcc_merge_zero pred (sign_extend (setcc_merge_zero ... pred ...)), 0, ne
-  // => inner setcc_merge_zero
-  if (Cond == ISD::SETNE && isZerosVector(RHS.getNode()) &&
-      LHS->getOpcode() == ISD::SIGN_EXTEND &&
-      LHS->getOperand(0)->getValueType(0) == N->getValueType(0) &&
-      LHS->getOperand(0)->getOpcode() == AArch64ISD::SETCC_MERGE_ZERO &&
-      LHS->getOperand(0)->getOperand(0) == Pred)
-    return LHS->getOperand(0);
+  // Common `X != 0` combines.
+  if (Cond == ISD::SETNE && isZerosVector(RHS.getNode())) {
+    // setcc_merge_zero pred (sext (setcc_merge_zero ... pred ...)), 0, ne
+    // => inner setcc_merge_zero
+    if (LHS->getOpcode() == ISD::SIGN_EXTEND &&
+        LHS->getOperand(0)->getValueType(0) == N->getValueType(0) &&
+        LHS->getOperand(0)->getOpcode() == AArch64ISD::SETCC_MERGE_ZERO &&
+        LHS->getOperand(0)->getOperand(0) == Pred)
+      return LHS->getOperand(0);
+
+    // setcc_merge_zero pred (splat(1) & 1), 0, ne
+    // => ptrue
+    uint64_t SplatVal;
+    bool Negated;
+    if (isAllActivePredicate(DAG, Pred) && LHS.getOpcode() == ISD::AND &&
+        isPow2Splat(LHS.getOperand(1), SplatVal, Negated) && SplatVal == 1) {
+      // We found a vXi1 truncate. Now check if we're truncating a fixed width
+      // splat.
+      SDValue Trunc = LHS.getOperand(0);
+      uint64_t TruncSplatVal;
+      bool TruncNegated;
+      if (Trunc.getOpcode() == ISD::INSERT_SUBVECTOR &&
+          Trunc.getOperand(0).isUndef() &&
+          isPow2Splat(Trunc.getOperand(1), TruncSplatVal, TruncNegated) &&
+          TruncSplatVal == 1) {
+        // Generate a PTRUE. The VL pattern is the number of elements
+        // of the fixed width splat input.
+        EVT FixedVT = Trunc.getOperand(1).getValueType();
+        Optional<unsigned> PredPattern =
+            getSVEPredPatternFromNumElements(FixedVT.getVectorNumElements());
+        auto PredTy = LHS.getValueType().changeVectorElementType(MVT::i1);
+        return getPTrue(DAG, dl, PredTy, *PredPattern);
+      }
+    }
+  }
 
   if (SDValue V = performSetCCPunpkCombine(N, DAG))
     return V;
Index: llvm/test/CodeGen/AArch64/sve-fixed-length-ptrue.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/sve-fixed-length-ptrue.ll
@@ -0,0 +1,51 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define <vscale x 2 x i1> @ptest_v8i1() #0 {
+; CHECK-LABEL: ptest_v8i1:
+; CHECK:       // %bb.0: // %L.entry
+; CHECK-NEXT:    ptrue p0.d, vl8
+; CHECK-NEXT:    ret
+L.entry:
+  %0 = call <vscale x 2 x i1> @llvm.experimental.vector.insert.nxv2i1.v8i1(<vscale x 2 x i1> undef, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, i64 0)
+  ret <vscale x 2 x i1> %0
+}
+
+define <vscale x 4 x i1> @ptest_v16i1() #0 {
+; CHECK-LABEL: ptest_v16i1:
+; CHECK:       // %bb.0: // %L.entry
+; CHECK-NEXT:    ptrue p0.s, vl16
+; CHECK-NEXT:    ret
+L.entry:
+  %0 = call <vscale x 4 x i1> @llvm.experimental.vector.insert.nxv4i1.v16i1(<vscale x 4 x i1> undef, <16 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, i64 0)
+  ret <vscale x 4 x i1> %0
+}
+
+define <vscale x 8 x i1> @ptest_v32i1() #0 {
+; CHECK-LABEL: ptest_v32i1:
+; CHECK:       // %bb.0: // %L.entry
+; CHECK-NEXT:    ptrue p0.h, vl32
+; CHECK-NEXT:    ret
+L.entry:
+  %0 = call <vscale x 8 x i1> @llvm.experimental.vector.insert.nxv8i1.v32i1(<vscale x 8 x i1> undef, <32 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, i64 0)
+  ret <vscale x 8 x i1> %0
+}
+
+define <vscale x 16 x i1> @ptest_v64i1() #0 {
+; CHECK-LABEL: ptest_v64i1:
+; CHECK:       // %bb.0: // %L.entry
+; CHECK-NEXT:    ptrue p0.b, vl64
+; CHECK-NEXT:    ret
+L.entry:
+  %0 = call <vscale x 16 x i1> @llvm.experimental.vector.insert.nxv16i1.v64i1(<vscale x 16 x i1> undef, <64 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, i64 0)
+  ret <vscale x 16 x i1> %0
+}
+
+declare <vscale x 2 x i1> @llvm.experimental.vector.insert.nxv2i1.v8i1(<vscale x 2 x i1>, <8 x i1>, i64)
+declare <vscale x 4 x i1> @llvm.experimental.vector.insert.nxv4i1.v16i1(<vscale x 4 x i1>, <16 x i1>, i64)
+declare <vscale x 8 x i1> @llvm.experimental.vector.insert.nxv8i1.v32i1(<vscale x 8 x i1>, <32 x i1>, i64)
+declare <vscale x 16 x i1> @llvm.experimental.vector.insert.nxv16i1.v64i1(<vscale x 16 x i1>, <64 x i1>, i64)
+
+attributes #0 = { vscale_range(4,4) "target-features"="+sve" }
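
Note on the VL pattern selection: the combine picks the ptrue pattern from the element count of the fixed-width splat. Below is a minimal standalone sketch of the mapping getSVEPredPatternFromNumElements is assumed to perform; the function name predPatternFromNumElements and the exact switch are illustrative, but the enumerator values are the PTRUE pattern encodings defined by the SVE instruction set (POW2 = 0, VL1..VL8 = 1..8, VL16..VL256 = 9..13, ALL = 31).

#include <optional>

// PTRUE pattern immediates from the SVE ISA. Only the vlN patterns relevant
// to this combine are listed.
enum SVEPredPattern : unsigned {
  POW2 = 0,
  VL1 = 1, VL2 = 2, VL3 = 3, VL4 = 4,
  VL5 = 5, VL6 = 6, VL7 = 7, VL8 = 8,
  VL16 = 9, VL32 = 10, VL64 = 11, VL128 = 12, VL256 = 13,
  ALL = 31
};

// Return the vlN pattern that selects exactly NumElts lanes, or std::nullopt
// when no dedicated pattern exists for that count (in which case a fixed-VL
// ptrue cannot represent the splat).
std::optional<SVEPredPattern> predPatternFromNumElements(unsigned NumElts) {
  switch (NumElts) {
  case 1: return VL1;
  case 2: return VL2;
  case 3: return VL3;
  case 4: return VL4;
  case 5: return VL5;
  case 6: return VL6;
  case 7: return VL7;
  case 8: return VL8;
  case 16: return VL16;
  case 32: return VL32;
  case 64: return VL64;
  case 128: return VL128;
  case 256: return VL256;
  default: return std::nullopt;
  }
}

For the tests above, vscale_range(4,4) fixes the register size at 512 bits, so the <8 x i1> splat inserted into <vscale x 2 x i1> covers all eight .d lanes; ptrue p0.d, vl8 is therefore an all-active predicate for that configuration, and likewise for the vl16/vl32/vl64 cases.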