diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -14410,12 +14410,57 @@
   return getPTest(DAG, N->getValueType(0), Pg, SetCC, AArch64CC::FIRST_ACTIVE);
 }
 
+// Materialize : Idx = (add (mul vscale, NumEls), -1)
+//               i1 = extract_vector_elt t37, Constant:i64<Idx>
+//     ... into: "ptrue p, all" + PTEST
+static SDValue
+performLastTrueTestVectorCombine(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI,
+                                 const AArch64Subtarget *Subtarget) {
+  assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT);
+  // Make sure PTEST is legal types.
+  if (!Subtarget->hasSVE() || DCI.isBeforeLegalize())
+    return SDValue();
+
+  SDValue SetCC = N->getOperand(0);
+  EVT OpVT = SetCC.getValueType();
+
+  if (!OpVT.isScalableVector() || OpVT.getVectorElementType() != MVT::i1)
+    return SDValue();
+
+  // Idx == (add (mul vscale, NumEls), -1)
+  SDValue Idx = N->getOperand(1);
+  if (Idx.getOpcode() != ISD::ADD)
+    return SDValue();
+
+  SDValue VS = Idx.getOperand(0);
+  if (VS.getOpcode() != ISD::VSCALE)
+    return SDValue();
+
+  unsigned NumEls = OpVT.getVectorElementCount().getKnownMinValue();
+  if (VS.getConstantOperandVal(0) != NumEls)
+    return SDValue();
+
+  // Restricted the DAG combine to only cases where we're extracting from a
+  // flag-setting operation
+  auto *CI = dyn_cast<ConstantSDNode>(Idx.getOperand(1));
+  if (!CI || !CI->isAllOnes() || SetCC.getOpcode() != ISD::SETCC)
+    return SDValue();
+
+  // Extracts of lane EC-1 for SVE can be expressed as PTEST(Op, LAST) ? 1 : 0
+  SelectionDAG &DAG = DCI.DAG;
+  SDValue Pg = getPTrue(DAG, SDLoc(N), OpVT, AArch64SVEPredPattern::all);
+  return getPTest(DAG, N->getValueType(0), Pg, SetCC, AArch64CC::LAST_ACTIVE);
+}
+
 static SDValue
 performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                                const AArch64Subtarget *Subtarget) {
   assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT);
   if (SDValue Res = performFirstTrueTestVectorCombine(N, DCI, Subtarget))
     return Res;
+  if (SDValue Res = performLastTrueTestVectorCombine(N, DCI, Subtarget))
+    return Res;
 
   SelectionDAG &DAG = DCI.DAG;
   SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
diff --git a/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll b/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll
--- a/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll
+++ b/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll
@@ -53,8 +53,8 @@
   ret <vscale x 4 x i1> %not
 }
 
-define i1 @foo(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
-; CHECK-LABEL: foo:
+define i1 @foo_first(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
+; CHECK-LABEL: foo_first:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ptrue p0.s
 ; CHECK-NEXT:    fcmeq p1.s, p0/z, z0.s, z1.s
@@ -66,3 +66,21 @@
   ret i1 %bit
 }
 
+define i1 @foo_last(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
+; CHECK-LABEL: foo_last:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fcmeq p1.s, p0/z, z0.s, z1.s
+; CHECK-NEXT:    ptest p0, p1.b
+; CHECK-NEXT:    cset w0, lo
+; CHECK-NEXT:    ret
+  %vcond = fcmp oeq <vscale x 4 x float> %a, %b
+  %vscale = call i64 @llvm.vscale.i64()
+  %shl2 = shl nuw nsw i64 %vscale, 2
+  %idx = add nuw nsw i64 %shl2, -1
+  %bit = extractelement <vscale x 4 x i1> %vcond, i64 %idx
+  ret i1 %bit
+}
+
+
+declare i64 @llvm.vscale.i64()