diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -14103,6 +14103,23 @@
     return DAG.getNode(Opc, DL, N->getValueType(0), And);
   }
 
+  // Fold:
+  //   and(
+  //     setcc_merge_zero(pred, ...),
+  //     setcc_merge_zero(pred, ...))
+  // -> setcc_merge_zero(
+  //      setcc_merge_zero(pred, ...),
+  //      ...)
+  if (N->getOperand(0).getOpcode() == AArch64ISD::SETCC_MERGE_ZERO &&
+      N->getOperand(1).getOpcode() == AArch64ISD::SETCC_MERGE_ZERO) {
+    SDValue LHS = N->getOperand(0);
+    SDValue RHS = N->getOperand(1);
+    if (LHS.getOperand(0) == RHS.getOperand(0))
+      return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
+                         N->getValueType(0), LHS, RHS.getOperand(1),
+                         RHS.getOperand(2), RHS.getOperand(3));
+  }
+
   if (!EnableCombineMGatherIntrinsics)
     return SDValue();
 
@@ -17072,6 +17089,16 @@
   return SDValue();
 }
 
+static SDValue getPredicateForFixedLengthVector(SelectionDAG &DAG, SDLoc &DL,
+                                                EVT VT);
+// Pattern match utility function to return true if V is a conversion of a
+// fixed-width vector to a scalable vector.
+static bool isConvertToScalableVector(SDValue V) {
+  return V.getOpcode() == ISD::INSERT_SUBVECTOR && V.getOperand(0).isUndef() &&
+         V.getOperand(1).getValueType().isFixedLengthVector() &&
+         V.getConstantOperandVal(2) == 0;
+}
+
 static SDValue
 performSetccMergeZeroCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   assert(N->getOpcode() == AArch64ISD::SETCC_MERGE_ZERO &&
@@ -17086,6 +17113,24 @@
   if (SDValue V = performSetCCPunpkCombine(N, DAG))
     return V;
 
+  // Make the predicate more specific, because there is likely already a
+  // predicate available for the given VL.
+  //   setcc(cc, <all active>, insert_subvector(undef, fixed_width_vec, 0), vec)
+  // -> setcc(cc, <ptrue for VL of fixed_width_vec>, insert_subvector(undef, fixed_width_vec, 0), vec)
+  //
+  // This avoids ending up with multiple PTRUEs, and it also helps some of the
+  // other setcc_merge_zero folds, which try to match sequences of setccs whose
+  // predicates are equivalent.
+  EVT FixedLengthVT;
+  if (isConvertToScalableVector(LHS) && isAllActivePredicate(DAG, Pred)) {
+    FixedLengthVT = LHS.getOperand(1).getValueType();
+    SDLoc DL(N);
+    SDValue FixedLengthPred =
+        getPredicateForFixedLengthVector(DAG, DL, FixedLengthVT);
+    return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, DL, N->getValueType(0),
+                       FixedLengthPred, LHS, RHS, N->getOperand(3));
+  }
+
   if (Cond == ISD::SETNE && isZerosVector(RHS.getNode()) &&
       LHS->getOpcode() == ISD::SIGN_EXTEND &&
       LHS->getOperand(0)->getValueType(0) == N->getValueType(0)) {
diff --git a/llvm/test/CodeGen/AArch64/sve-setcc.ll b/llvm/test/CodeGen/AArch64/sve-setcc.ll
--- a/llvm/test/CodeGen/AArch64/sve-setcc.ll
+++ b/llvm/test/CodeGen/AArch64/sve-setcc.ll
@@ -115,6 +115,48 @@
 
 declare <vscale x 16 x i1> @llvm.aarch64.sve.cmpne.nxv16i8(<vscale x 16 x i1>, <vscale x 16 x i8>, <vscale x 16 x i8>)
 
+; Make sure that only a single PTRUE instruction is generated (to avoid having both a `ptrue p0.b` and a `ptrue p0.b, vl16` for the different operations).
+define <vscale x 16 x i1> @setcc_reuse_ptrue_vl16(<vscale x 16 x i8> %other, <16 x i8> %subvec) {
+; CHECK-LABEL: setcc_reuse_ptrue_vl16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.b, vl16
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    cmpne p1.b, p0/z, z0.b, #0
+; CHECK-NEXT:    cmpne p0.b, p0/z, z1.b, z0.b
+; CHECK-NEXT:    mov p0.b, p1/m, p1.b
+; CHECK-NEXT:    ret
+  %ptrue_vl16 = call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 9)
+  %cmp1 = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpne.nxv16i8(<vscale x 16 x i1> %ptrue_vl16, <vscale x 16 x i8> %other, <vscale x 16 x i8> zeroinitializer)
+  %vec = call <vscale x 16 x i8> @llvm.experimental.vector.insert.nxv16i8.v16i8(<vscale x 16 x i8> poison, <16 x i8> %subvec, i64 0)
+  %cmp2 = icmp ne <vscale x 16 x i8> %vec, %other
+  %retval = or <vscale x 16 x i1> %cmp1, %cmp2
+  ret <vscale x 16 x i1> %retval
+}
+
+; Same as the above test, but make sure we still fold the AND operation into the SETCC.
+;
+; Fold:
+;   and(setcc_merge_zero(pred, ...), setcc_merge_zero(pred, ...))
+; -> setcc_merge_zero(setcc_merge_zero(pred, ...), ...)
+define <vscale x 16 x i1> @setcc_reuse_ptrue_vl16_fold_and(<vscale x 16 x i8> %other, <16 x i8> %subvec) {
+; CHECK-LABEL: setcc_reuse_ptrue_vl16_fold_and:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.b, vl16
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    cmpne p0.b, p0/z, z0.b, #0
+; CHECK-NEXT:    cmpne p0.b, p0/z, z1.b, z0.b
+; CHECK-NEXT:    ret
+  %ptrue_vl16 = call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 9)
+  %cmp1 = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpne.nxv16i8(<vscale x 16 x i1> %ptrue_vl16, <vscale x 16 x i8> %other, <vscale x 16 x i8> zeroinitializer)
+  %vec = call <vscale x 16 x i8> @llvm.experimental.vector.insert.nxv16i8.v16i8(<vscale x 16 x i8> poison, <16 x i8> %subvec, i64 0)
+  %cmp2 = icmp ne <vscale x 16 x i8> %vec, %other
+  %retval = and <vscale x 16 x i1> %cmp1, %cmp2
+  ret <vscale x 16 x i1> %retval
+}
+
+declare <vscale x 16 x i8> @llvm.experimental.vector.insert.nxv16i8.v16i8(<vscale x 16 x i8>, <16 x i8>, i64)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32)
+
 declare i1 @llvm.aarch64.sve.ptest.any.nxv8i1(<vscale x 8 x i1>, <vscale x 8 x i1>)
 
 declare i1 @llvm.aarch64.sve.ptest.last.nxv8i1(<vscale x 8 x i1>, <vscale x 8 x i1>)