diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -14364,7 +14364,46 @@
   }
 }
 
-static SDValue performExtractVectorEltCombine(SDNode *N, SelectionDAG &DAG) {
+static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
+                        AArch64CC::CondCode Cond);
+
+// Materialize: i1 = extract_vector_elt t37, Constant:i64<0>
+// ... into: "ptrue p, all" + PTEST
+static SDValue
+performFirstTrueTestVectorCombine(SDNode *N,
+                                  TargetLowering::DAGCombinerInfo &DCI,
+                                  const AArch64Subtarget *Subtarget) {
+  assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT);
+  // Only fold after legalisation so that PTEST is created with legal types.
+  if (!Subtarget->hasSVE() || DCI.isBeforeLegalize())
+    return SDValue();
+
+  SDValue SetCC = N->getOperand(0);
+  EVT VT = SetCC.getValueType();
+
+  if (!VT.isScalableVector() || VT.getVectorElementType() != MVT::i1)
+    return SDValue();
+
+  // Restrict the DAG combine to cases where we're extracting lane 0 of a
+  // flag-setting operation.
+  auto *Idx = dyn_cast<ConstantSDNode>(N->getOperand(1));
+  if (!Idx || !Idx->isZero() || SetCC.getOpcode() != ISD::SETCC)
+    return SDValue();
+
+  // Extracts of lane 0 for SVE can be expressed as PTEST(Op, FIRST) ? 1 : 0.
+  SelectionDAG &DAG = DCI.DAG;
+  SDValue Pg = getPTrue(DAG, SDLoc(N), VT, AArch64SVEPredPattern::all);
+  return getPTest(DAG, N->getValueType(0), Pg, SetCC, AArch64CC::FIRST_ACTIVE);
+}
+
+static SDValue
+performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+                               const AArch64Subtarget *Subtarget) {
+  assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT);
+  if (SDValue Res = performFirstTrueTestVectorCombine(N, DCI, Subtarget))
+    return Res;
+
+  SelectionDAG &DAG = DCI.DAG;
   SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
   ConstantSDNode *ConstantN1 = dyn_cast<ConstantSDNode>(N1);
 
@@ -18356,7 +18395,7 @@
   case ISD::INSERT_VECTOR_ELT:
     return performInsertVectorEltCombine(N, DCI);
   case ISD::EXTRACT_VECTOR_ELT:
-    return performExtractVectorEltCombine(N, DAG);
+    return performExtractVectorEltCombine(N, DCI, Subtarget);
   case ISD::VECREDUCE_ADD:
     return performVecReduceAddCombine(N, DCI.DAG, Subtarget);
   case AArch64ISD::UADDV:
diff --git a/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll b/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll
--- a/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll
+++ b/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll
@@ -52,3 +52,17 @@
   %not = xor %icmp, %ones
   ret %not
 }
+
+define i1 @foo(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
+; CHECK-LABEL: foo:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fcmeq p1.s, p0/z, z0.s, z1.s
+; CHECK-NEXT:    ptest p0, p1.b
+; CHECK-NEXT:    cset w0, mi
+; CHECK-NEXT:    ret
+  %vcond = fcmp oeq <vscale x 4 x float> %a, %b
+  %bit = extractelement <vscale x 4 x i1> %vcond, i64 0
+  ret i1 %bit
+}
+