diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20450,6 +20450,113 @@
                            Op0ExtV, Op1ExtV, Op->getOperand(2));
 }
 
+// When performing a vector compare with n elements followed by a bitcast of
+// the <n x i1> result to an n-bit scalar, we can use a trick that extracts
+// the i^th bit from the i^th element and then performs a vector add to get a
+// scalar bitmask.
+static SDValue
+combineVectorCompareAndBitcast(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+                               SelectionDAG &DAG) {
+  EVT VT = N->getValueType(0);
+  if (!VT.isVector() || VT.getVectorElementType() != MVT::i1)
+    return SDValue();
+
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+  assert(LHS.getValueType() == RHS.getValueType());
+  EVT VecVT = LHS.getValueType();
+  EVT ElementType = VecVT.getVectorElementType();
+
+  if (VecVT != MVT::v2i64 && VecVT != MVT::v2i32 && VecVT != MVT::v4i32 &&
+      VecVT != MVT::v4i16 && VecVT != MVT::v8i16 && VecVT != MVT::v8i8 &&
+      VecVT != MVT::v16i8)
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue ComparisonResult = DAG.getNode(N->getOpcode(), DL, VecVT, LHS, RHS,
+                                         N->getOperand(2), N->getFlags());
+
+  SDValue VectorBits;
+  if (VecVT == MVT::v16i8) {
+    // v16i8 is a special case: we need to split it into two halves, perform
+    // the mask + addition on each half, and then combine the results.
+    SmallVector<SDValue, 16> MaskConstants;
+    for (unsigned Half = 0; Half < 2; ++Half) {
+      for (unsigned MaskBit = 1; MaskBit <= 128; MaskBit *= 2) {
+        MaskConstants.push_back(DAG.getConstant(MaskBit, DL, MVT::i32));
+      }
+    }
+    SDValue Mask = DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, MaskConstants);
+    SDValue RepresentativeBits =
+        DAG.getNode(ISD::AND, DL, VecVT, ComparisonResult, Mask);
+
+    EVT HalfVT = VecVT.getHalfNumVectorElementsVT(*DAG.getContext());
+    unsigned NumElementsInHalf = HalfVT.getVectorNumElements();
+
+    SDValue LowHalf =
+        DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, RepresentativeBits,
+                    DAG.getConstant(0, DL, MVT::i64));
+    SDValue HighHalf =
+        DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, RepresentativeBits,
+                    DAG.getConstant(NumElementsInHalf, DL, MVT::i64));
+
+    SDValue ReducedLowBits =
+        DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i16, LowHalf);
+    SDValue ReducedHighBits =
+        DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i16, HighHalf);
+
+    SDValue ShiftedHighBits =
+        DAG.getNode(ISD::SHL, DL, MVT::i16, ReducedHighBits,
+                    DAG.getConstant(NumElementsInHalf, DL, MVT::i32));
+    VectorBits =
+        DAG.getNode(ISD::OR, DL, MVT::i16, ShiftedHighBits, ReducedLowBits);
+  } else {
+    SmallVector<SDValue, 16> MaskConstants;
+    unsigned MaxBitMask = 1u << (VecVT.getVectorNumElements() - 1);
+    for (unsigned MaskBit = 1; MaskBit <= MaxBitMask; MaskBit *= 2) {
+      MaskConstants.push_back(DAG.getConstant(MaskBit, DL, ElementType));
+    }
+
+    SDValue Mask = DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, MaskConstants);
+    SDValue RepresentativeBits =
+        DAG.getNode(ISD::AND, DL, VecVT, ComparisonResult, Mask);
+    VectorBits =
+        DAG.getNode(ISD::VECREDUCE_ADD, DL, ElementType, RepresentativeBits);
+  }
+
+  // Check chain of uses to see if the compare is followed by a bitcast.
+  for (SDNode *User : N->uses()) {
+    // For v4i1 and v2i1, we get a vector concatenation to fill the remaining
+    // bits before bitcasting. This op can be skipped if we replace the user
+    // chain. We are defensive here to avoid producing wrong code in case we
+    // are not aware of the bitcast pattern. This pattern is generated by
+    // clang, e.g., when using __builtin_convertvector().
+    if (User->getOpcode() == ISD::CONCAT_VECTORS) {
+      if (!User->hasOneUse() || User->getOperand(0).getValueType() != VT)
+        return SDValue();
+
+      // The vector with the relevant bits must be the first vector.
+      if (User->getOperand(0) != SDValue(N, 0))
+        return SDValue();
+
+      // The other vectors must be undef.
+      for (unsigned I = 1; I < User->getNumOperands(); ++I)
+        if (!User->getOperand(I).isUndef())
+          return SDValue();
+
+      User = *User->use_begin();
+    }
+
+    if (User->getOpcode() != ISD::BITCAST)
+      continue;
+
+    EVT TargetVT = User->getValueType(0);
+    SDValue BitcastedResult = DAG.getZExtOrTrunc(VectorBits, DL, TargetVT);
+    DCI.CombineTo(User, BitcastedResult);
+  }
+
+  return SDValue(N, 0);
+}
+
 static SDValue performSETCCCombine(SDNode *N,
                                    TargetLowering::DAGCombinerInfo &DCI,
                                    SelectionDAG &DAG) {
@@ -20511,6 +20618,9 @@
     }
   }
 
+  if (SDValue V = combineVectorCompareAndBitcast(N, DCI, DAG))
+    return V;
+
   // Try to perform the memcmp when the result is tested for [in]equality with 0
   if (SDValue V = performOrXorChainCombine(N, DAG))
     return V;
diff --git a/llvm/test/CodeGen/AArch64/vec-combine-compare-to-bitmask.ll b/llvm/test/CodeGen/AArch64/vec-combine-compare-to-bitmask.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/vec-combine-compare-to-bitmask.ll
@@ -0,0 +1,117 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-apple-darwin -mattr=+neon -verify-machineinstrs < %s | FileCheck %s
+
+; Basic tests from input vector to bitmask
+; IR generated from clang for:
+; __builtin_convertvector + reinterpret_cast
+
+define i16 @convert_to_bitmask16(<16 x i8> %vec) {
+; Bits used in mask
+; CHECK-LABEL: lCPI0_0
+; CHECK-NEXT: .byte 1
+; CHECK-NEXT: .byte 2
+; CHECK-NEXT: .byte 4
+; CHECK-NEXT: .byte 8
+; CHECK-NEXT: .byte 16
+; CHECK-NEXT: .byte 32
+; CHECK-NEXT: .byte 64
+; CHECK-NEXT: .byte 128
+; CHECK-NEXT: .byte 1
+; CHECK-NEXT: .byte 2
+; CHECK-NEXT: .byte 4
+; CHECK-NEXT: .byte 8
+; CHECK-NEXT: .byte 16
+; CHECK-NEXT: .byte 32
+; CHECK-NEXT: .byte 64
+; CHECK-NEXT: .byte 128
+
+; Actual conversion
+; CHECK-LABEL: convert_to_bitmask16
+; CHECK: adrp x8, lCPI0_0@PAGE
+; CHECK: cmeq.16b v0, v0, #0
+; CHECK: ldr q1, [x8, lCPI0_0@PAGEOFF]
+; CHECK: bic.16b v0, v1, v0
+; CHECK: ext.16b v1, v0, v0, #8
+; CHECK: addv.8b b0, v0
+; CHECK: addv.8b b1, v1
+; CHECK: fmov w9, s0
+; CHECK: fmov w8, s1
+; CHECK: orr w0, w9, w8, lsl #8
+; CHECK: ret
+
+  %cmp_result = icmp ne <16 x i8> %vec, zeroinitializer
+  %bitmask = bitcast <16 x i1> %cmp_result to i16
+  ret i16 %bitmask
+}
+
+define i16 @convert_to_bitmask8(<8 x i16> %vec) {
+; CHECK-LABEL: lCPI1_0:
+; CHECK-NEXT: .short 1
+; CHECK-NEXT: .short 2
+; CHECK-NEXT: .short 4
+; CHECK-NEXT: .short 8
+; CHECK-NEXT: .short 16
+; CHECK-NEXT: .short 32
+; CHECK-NEXT: .short 64
+; CHECK-NEXT: .short 128
+
+; CHECK: adrp x8, lCPI1_0@PAGE
+; CHECK: cmeq.8h v0, v0, #0
+; CHECK: ldr q1, [x8, lCPI1_0@PAGEOFF]
+; CHECK: bic.16b v0, v1, v0
+; CHECK: addv.8h h0, v0
+; CHECK: fmov w8, s0
+; CHECK: and w0, w8, #0xff
+; CHECK: ret
+
+  %cmp_result = icmp ne <8 x i16> %vec, zeroinitializer
+  %bitmask = bitcast <8 x i1> %cmp_result to i8
+  %extended_bitmask = zext i8 %bitmask to i16
+  ret i16 %extended_bitmask
+}
+
+define i16 @convert_to_bitmask4(<4 x i32> %vec) {
+; CHECK-LABEL: lCPI2_0:
+; CHECK: .long 1
+; CHECK: .long 2
+; CHECK: .long 4
+; CHECK: .long 8
+
+; CHECK: adrp x8, lCPI2_0@PAGE
+; CHECK: cmeq.4s v0, v0, #0
+; CHECK: Lloh5:
+; CHECK: ldr q1, [x8, lCPI2_0@PAGEOFF]
+; CHECK: bic.16b v0, v1, v0
+; CHECK: addv.4s s0, v0
+; CHECK: fmov w8, s0
+; CHECK: and w0, w8, #0xff
+; CHECK: ret
+
+  %cmp_result = icmp ne <4 x i32> %vec, zeroinitializer
+  %vector_pad = shufflevector <4 x i1> %cmp_result, <4 x i1> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  %bitmask = bitcast <8 x i1> %vector_pad to i8
+  %extended_bitmask = zext i8 %bitmask to i16
+  ret i16 %extended_bitmask
+}
+
+define i16 @convert_to_bitmask2(<2 x i64> %vec) {
+; CHECK-LABEL: lCPI3_0:
+; CHECK-NEXT: .quad 1
+; CHECK-NEXT: .quad 2
+
+; CHECK: adrp x8, lCPI3_0@PAGE
+; CHECK: cmeq.2d v0, v0, #0
+; CHECK: Lloh7:
+; CHECK: ldr q1, [x8, lCPI3_0@PAGEOFF]
+; CHECK: bic.16b v0, v1, v0
+; CHECK: addp.2d d0, v0
+; CHECK: fmov x8, d0
+; CHECK: and w0, w8, #0xff
+; CHECK: ret
+
+  %cmp_result = icmp ne <2 x i64> %vec, zeroinitializer
+  %vector_pad = shufflevector <2 x i1> %cmp_result, <2 x i1> poison, <8 x i32> <i32 0, i32 1, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+  %bitmask = bitcast <8 x i1> %vector_pad to i8
+  %extended_bitmask = zext i8 %bitmask to i16
+  ret i16 %extended_bitmask
+}
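
Note (illustration, not part of the patch): the core of combineVectorCompareAndBitcast is the mask-and-horizontal-add trick described in its leading comment. The short C++ sketch below models that trick on scalars, only to show why AND-ing lane i of the all-ones/all-zeros compare result with (1 << i) and then summing the lanes (ISD::VECREDUCE_ADD, i.e. addv) yields the scalar bitmask; the helper name scalarBitmaskModel is invented for this illustration.

#include <array>
#include <cstdint>

// Illustrative only: models BUILD_VECTOR(1,2,4,...) + ISD::AND +
// ISD::VECREDUCE_ADD for an 8-lane compare result. Each lane of a NEON
// compare is all-ones or all-zeros, so Lane & (1 << I) is either (1 << I)
// or 0, and the sum has no carries -- it is exactly the bitmask.
static uint8_t scalarBitmaskModel(const std::array<bool, 8> &LaneIsTrue) {
  uint8_t Bits = 0;
  for (unsigned I = 0; I < 8; ++I) {
    uint8_t Lane = LaneIsTrue[I] ? 0xFF : 0x00;
    Bits += Lane & static_cast<uint8_t>(1u << I);
  }
  return Bits;
}

The v16i8 path in the patch applies the same idea to each 8-lane half and then combines the two partial sums, shifting the high half left by 8 before the final OR.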
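For context on the new test file: its header only says the IR was "generated from clang for __builtin_convertvector + reinterpret_cast". A guess at the kind of C++ source involved is sketched below, assuming a clang recent enough to support bool ext_vector_type; the type and function names are invented for the example and are not taken from the patch.

#include <cstdint>

typedef int32_t Int32x4 __attribute__((vector_size(16)));
typedef bool Bool4 __attribute__((ext_vector_type(4)));

// Compare four lanes against zero, convert to a 4-element bool vector, and
// reinterpret its one-byte storage as a uint8_t. This should lower to roughly
// an icmp ne on <4 x i32>, a shufflevector padding <4 x i1> to <8 x i1>, and
// a bitcast to i8 -- the shape of IR exercised by the tests above.
uint8_t nonzeroLaneMask(Int32x4 V) {
  Bool4 Lanes = __builtin_convertvector(V != 0, Bool4);
  return reinterpret_cast<uint8_t &>(Lanes);
}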