diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -888,6 +888,7 @@
   setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT);
   setTargetDAGCombine(ISD::VECREDUCE_ADD);
   setTargetDAGCombine(ISD::VECREDUCE_OR);
+  setTargetDAGCombine(ISD::VECREDUCE_AND);
   setTargetDAGCombine(ISD::STEP_VECTOR);
   setTargetDAGCombine(ISD::MGATHER);
@@ -13327,7 +13328,12 @@
   // manually define these to be 0 or 1.
   if (!TLI.isAllActivePredicate(DAG, PredForVL)) {
     EVT PredVT = PredForVL.getValueType();
-    Pred = DAG.getNode(ISD::AND, DL, PredVT, Pred, PredForVL);
+    if (N->getOpcode() == ISD::VECREDUCE_OR)
+      Pred = DAG.getNode(ISD::AND, DL, PredVT, Pred, PredForVL);
+    else
+      Pred = DAG.getNode(
+          ISD::OR, DL, PredVT, Pred,
+          DAG.getNode(ISD::XOR, DL, PredVT, PredForVL, PredForVL));
   }
 
   return DAG.getNode(N->getOpcode(), DL, N->getValueType(0), Pred);
@@ -14309,6 +14315,20 @@
   if (SDValue R = PerformANDCSELCombine(N, DAG))
     return R;
 
+  // Try to perform the operation on SVE predicate vectors, if available.
+  SDValue NewLHS, NewRHS;
+  if (VT.isFixedLengthVector() &&
+      (NewLHS = findScalablePredicateOperand(N->getOperand(0), DAG)) &&
+      (NewRHS = findScalablePredicateOperand(N->getOperand(1), DAG))) {
+    assert(!(isa(NewLHS) && isa(NewRHS)) &&
+           "Expected nodes to have been constant folded");
+    EVT ContainerVT = getContainerForFixedLengthVector(DAG, N->getValueType(0));
+    SDValue PredAND =
+        DAG.getNode(ISD::AND, SDLoc(N), NewLHS.getValueType(), NewLHS, NewRHS);
+    SDValue Ext = DAG.getSExtOrTrunc(PredAND, SDLoc(N), ContainerVT);
+    return convertFromScalableVector(DAG, N->getValueType(0), Ext);
+  }
+
   if (!VT.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT))
     return SDValue();
 
@@ -18263,6 +18283,7 @@
   case AArch64ISD::UADDV:
     return performUADDVCombine(N, DAG);
   case ISD::VECREDUCE_OR:
+  case ISD::VECREDUCE_AND:
    return performVecreduceAndOrCombine(N, *this, DAG);
   case ISD::INTRINSIC_VOID:
   case ISD::INTRINSIC_W_CHAIN:
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-ptest.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-ptest.ll
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-ptest.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-ptest.ll
@@ -106,16 +106,12 @@
 ; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x0]
 ; CHECK-NEXT:    ld1w { z1.s }, p0/z, [x1]
 ; CHECK-NEXT:    fcmeq p1.s, p0/z, z0.s, #0.0
+; CHECK-NEXT:    fcmeq p2.s, p0/z, z1.s, #0.0
 ; CHECK-NEXT:    not p1.b, p0/z, p1.b
-; CHECK-NEXT:    fcmeq p0.s, p0/z, z1.s, #0.0
-; CHECK-NEXT:    bic p0.b, p1/z, p1.b, p0.b
-; CHECK-NEXT:    mov z0.s, p0/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT:    ptrue p0.b, vl16
-; CHECK-NEXT:    uzp1 z0.h, z0.h, z0.h
-; CHECK-NEXT:    uzp1 z0.b, z0.b, z0.b
-; CHECK-NEXT:    andv b0, p0, z0.b
-; CHECK-NEXT:    fmov w8, s0
-; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    bic p1.b, p1/z, p1.b, p2.b
+; CHECK-NEXT:    not p1.b, p0/z, p1.b
+; CHECK-NEXT:    ptest p0, p1.b
+; CHECK-NEXT:    cset w0, eq
 ; CHECK-NEXT:    ret
   %v0 = bitcast float* %a to <16 x float>*
   %v1 = load <16 x float>, <16 x float>* %v0, align 4
@@ -139,13 +135,9 @@
 ; CHECK-NEXT:    fcmeq p0.s, p0/z, z1.s, #0.0
 ; CHECK-NEXT:    not p1.b, p2/z, p1.b
 ; CHECK-NEXT:    bic p0.b, p1/z, p1.b, p0.b
-; CHECK-NEXT:    mov z0.s, p0/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT:    ptrue p0.b, vl16
-; CHECK-NEXT:    uzp1 z0.h, z0.h, z0.h
-; CHECK-NEXT:    uzp1 z0.b, z0.b, z0.b
-; CHECK-NEXT:    andv b0, p0, z0.b
-; CHECK-NEXT:    fmov w8, s0
-; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    not p0.b, p2/z, p0.b
+; CHECK-NEXT:    ptest p2, p0.b
+; CHECK-NEXT:    cset w0, eq
 ; CHECK-NEXT:    ret
   %v0 = bitcast float* %a to <16 x float>*
   %v1 = load <16 x float>, <16 x float>* %v0, align 4