diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -14509,6 +14509,22 @@
 static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
                         AArch64CC::CondCode Cond);
 
+static bool isPredicateCCSettingOp(SDValue N) {
+  if ((N.getOpcode() == ISD::SETCC) ||
+      (N.getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
+       (N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilege ||
+        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilegt ||
+        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehi ||
+        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehs ||
+        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilele ||
+        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelo ||
+        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilels ||
+        N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelt)))
+    return true;
+
+  return false;
+}
+
 // Materialize : i1 = extract_vector_elt t37, Constant:i64<0>
 // ... into: "ptrue p, all" + PTEST
 static SDValue
@@ -14520,22 +14536,22 @@
   if (!Subtarget->hasSVE() || DCI.isBeforeLegalize())
     return SDValue();
 
-  SDValue SetCC = N->getOperand(0);
-  EVT VT = SetCC.getValueType();
+  SDValue N0 = N->getOperand(0);
+  EVT VT = N0.getValueType();
 
-  if (!VT.isScalableVector() || VT.getVectorElementType() != MVT::i1)
+  if (!VT.isScalableVector() || VT.getVectorElementType() != MVT::i1 ||
+      !isNullConstant(N->getOperand(1)))
     return SDValue();
 
   // Restricted the DAG combine to only cases where we're extracting from a
-  // flag-setting operation
-  auto *Idx = dyn_cast<ConstantSDNode>(N->getOperand(1));
-  if (!Idx || !Idx->isZero() || SetCC.getOpcode() != ISD::SETCC)
+  // flag-setting operation.
+  if (!isPredicateCCSettingOp(N0))
     return SDValue();
 
   // Extracts of lane 0 for SVE can be expressed as PTEST(Op, FIRST) ? 1 : 0
   SelectionDAG &DAG = DCI.DAG;
   SDValue Pg = getPTrue(DAG, SDLoc(N), VT, AArch64SVEPredPattern::all);
-  return getPTest(DAG, N->getValueType(0), Pg, SetCC, AArch64CC::FIRST_ACTIVE);
+  return getPTest(DAG, N->getValueType(0), Pg, N0, AArch64CC::FIRST_ACTIVE);
 }
 
 // Materialize : Idx = (add (mul vscale, NumEls), -1)
diff --git a/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll b/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll
--- a/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll
+++ b/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=aarch64-linux-unknown -mattr=+sve -o - < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64-linux-unknown -mattr=+sve2 -o - < %s | FileCheck %s
 
 define <vscale x 8 x i1> @not_icmp_sle_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
 ; CHECK-LABEL: not_icmp_sle_nxv8i16:
@@ -82,5 +82,100 @@
   ret i1 %bit
 }
+define i1 @whilege_first(i64 %next, i64 %end) {
+; CHECK-LABEL: whilege_first:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    whilege p0.s, x0, x1
+; CHECK-NEXT:    cset w0, mi
+; CHECK-NEXT:    ret
+  %predicate = call <vscale x 4 x i1> @llvm.aarch64.sve.whilege.nxv4i1.i64(i64 %next, i64 %end)
+  %bit = extractelement <vscale x 4 x i1> %predicate, i64 0
+  ret i1 %bit
+}
+
+define i1 @whilegt_first(i64 %next, i64 %end) {
+; CHECK-LABEL: whilegt_first:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    whilegt p0.s, x0, x1
+; CHECK-NEXT:    cset w0, mi
+; CHECK-NEXT:    ret
+  %predicate = call <vscale x 4 x i1> @llvm.aarch64.sve.whilegt.nxv4i1.i64(i64 %next, i64 %end)
+  %bit = extractelement <vscale x 4 x i1> %predicate, i64 0
+  ret i1 %bit
+}
+
+define i1 @whilehi_first(i64 %next, i64 %end) {
+; CHECK-LABEL: whilehi_first:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    whilehi p0.s, x0, x1
+; CHECK-NEXT:    cset w0, mi
+; CHECK-NEXT:    ret
+  %predicate = call <vscale x 4 x i1> @llvm.aarch64.sve.whilehi.nxv4i1.i64(i64 %next, i64 %end)
+  %bit = extractelement <vscale x 4 x i1> %predicate, i64 0
+  ret i1 %bit
+}
+
+define i1 @whilehs_first(i64 %next, i64 %end) {
+; CHECK-LABEL: whilehs_first:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    whilehs p0.s, x0, x1
+; CHECK-NEXT:    cset w0, mi
+; CHECK-NEXT:    ret
+  %predicate = call <vscale x 4 x i1> @llvm.aarch64.sve.whilehs.nxv4i1.i64(i64 %next, i64 %end)
+  %bit = extractelement <vscale x 4 x i1> %predicate, i64 0
+  ret i1 %bit
+}
+
+define i1 @whilele_first(i64 %next, i64 %end) {
+; CHECK-LABEL: whilele_first:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    whilele p0.s, x0, x1
+; CHECK-NEXT:    cset w0, mi
+; CHECK-NEXT:    ret
+  %predicate = call <vscale x 4 x i1> @llvm.aarch64.sve.whilele.nxv4i1.i64(i64 %next, i64 %end)
+  %bit = extractelement <vscale x 4 x i1> %predicate, i64 0
+  ret i1 %bit
+}
+
+define i1 @whilelo_first(i64 %next, i64 %end) {
+; CHECK-LABEL: whilelo_first:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    whilelo p0.s, x0, x1
+; CHECK-NEXT:    cset w0, mi
+; CHECK-NEXT:    ret
+  %predicate = call <vscale x 4 x i1> @llvm.aarch64.sve.whilelo.nxv4i1.i64(i64 %next, i64 %end)
+  %bit = extractelement <vscale x 4 x i1> %predicate, i64 0
+  ret i1 %bit
+}
+
+define i1 @whilels_first(i64 %next, i64 %end) {
+; CHECK-LABEL: whilels_first:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    whilels p0.s, x0, x1
+; CHECK-NEXT:    cset w0, mi
+; CHECK-NEXT:    ret
+  %predicate = call <vscale x 4 x i1> @llvm.aarch64.sve.whilels.nxv4i1.i64(i64 %next, i64 %end)
+  %bit = extractelement <vscale x 4 x i1> %predicate, i64 0
+  ret i1 %bit
+}
+
+define i1 @whilelt_first(i64 %next, i64 %end) {
+; CHECK-LABEL: whilelt_first:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    whilelt p0.s, x0, x1
+; CHECK-NEXT:    cset w0, mi
+; CHECK-NEXT:    ret
+  %predicate = call <vscale x 4 x i1> @llvm.aarch64.sve.whilelt.nxv4i1.i64(i64 %next, i64 %end)
+  %bit = extractelement <vscale x 4 x i1> %predicate, i64 0
+  ret i1 %bit
+}
 
 declare i64 @llvm.vscale.i64()
+declare <vscale x 4 x i1> @llvm.aarch64.sve.whilege.nxv4i1.i64(i64, i64)
+declare <vscale x 4 x i1> @llvm.aarch64.sve.whilegt.nxv4i1.i64(i64, i64)
+declare <vscale x 4 x i1> @llvm.aarch64.sve.whilehi.nxv4i1.i64(i64, i64)
+declare <vscale x 4 x i1> @llvm.aarch64.sve.whilehs.nxv4i1.i64(i64, i64)
+declare <vscale x 4 x i1> @llvm.aarch64.sve.whilele.nxv4i1.i64(i64, i64)
+declare <vscale x 4 x i1> @llvm.aarch64.sve.whilelo.nxv4i1.i64(i64, i64)
+declare <vscale x 4 x i1> @llvm.aarch64.sve.whilels.nxv4i1.i64(i64, i64)
+declare <vscale x 4 x i1> @llvm.aarch64.sve.whilelt.nxv4i1.i64(i64, i64)