diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1211,6 +1211,11 @@
     setTruncStoreAction(MVT::v4i16, MVT::v4i8, Custom);
 
+    setTruncStoreAction(MVT::v16i8, MVT::v16i1, Custom);
+    setTruncStoreAction(MVT::v8i16, MVT::v8i1, Custom);
+    setTruncStoreAction(MVT::v4i32, MVT::v4i1, Custom);
+    setTruncStoreAction(MVT::v2i64, MVT::v2i1, Custom);
+
     setLoadExtAction(ISD::EXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
     setLoadExtAction(ISD::SEXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
     setLoadExtAction(ISD::ZEXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
@@ -19454,6 +19459,115 @@
   return DAG.getMergeValues({ExtractSubVector, TokenFactor}, DL);
 }
 
+// When performing a vector compare with n elements followed by some form of
+// truncation/casting to <n x i1>, we can use a trick that extracts the i^th
+// bit from the i^th element and then performs an add across the vector lanes
+// to get a scalar bitmask.
+static SDValue vectorCompareToBitVector(SDNode *N, SelectionDAG &DAG) {
+  EVT VT = N->getValueType(0);
+  assert(VT.isVector() && "Should be a vector type");
+  assert(N->getOpcode() == ISD::SETCC && "Must be vector compare.");
+
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+  assert(LHS.getValueType() == RHS.getValueType());
+  EVT VecVT = LHS.getValueType();
+  EVT ElementType = VecVT.getVectorElementType();
+
+  if (VecVT != MVT::v2i64 && VecVT != MVT::v2i32 && VecVT != MVT::v4i32 &&
+      VecVT != MVT::v4i16 && VecVT != MVT::v8i16 && VecVT != MVT::v8i8 &&
+      VecVT != MVT::v16i8)
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue ComparisonResult = DAG.getNode(N->getOpcode(), DL, VecVT, LHS, RHS,
+                                         N->getOperand(2), N->getFlags());
+
+  SDValue VectorBits;
+  if (VecVT == MVT::v16i8) {
+    // v16i8 is a special case: we need to split it into two halves, perform
+    // the mask+addition on each half, and then combine the two results.
+    SmallVector<SDValue, 16> MaskConstants;
+    for (unsigned Half = 0; Half < 2; ++Half) {
+      for (unsigned MaskBit = 1; MaskBit <= 128; MaskBit *= 2) {
+        MaskConstants.push_back(DAG.getConstant(MaskBit, DL, MVT::i32));
+      }
+    }
+    SDValue Mask = DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, MaskConstants);
+    SDValue RepresentativeBits =
+        DAG.getNode(ISD::AND, DL, VecVT, ComparisonResult, Mask);
+
+    EVT HalfVT = VecVT.getHalfNumVectorElementsVT(*DAG.getContext());
+    unsigned NumElementsInHalf = HalfVT.getVectorNumElements();
+
+    SDValue LowHalf =
+        DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, RepresentativeBits,
+                    DAG.getConstant(0, DL, MVT::i64));
+    SDValue HighHalf =
+        DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, RepresentativeBits,
+                    DAG.getConstant(NumElementsInHalf, DL, MVT::i64));
+
+    SDValue ReducedLowBits =
+        DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i16, LowHalf);
+    SDValue ReducedHighBits =
+        DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i16, HighHalf);
+
+    SDValue ShiftedHighBits =
+        DAG.getNode(ISD::SHL, DL, MVT::i16, ReducedHighBits,
+                    DAG.getConstant(NumElementsInHalf, DL, MVT::i32));
+    VectorBits =
+        DAG.getNode(ISD::OR, DL, MVT::i16, ShiftedHighBits, ReducedLowBits);
+  } else {
+    SmallVector<SDValue, 8> MaskConstants;
+    unsigned MaxBitMask = 1u << (VecVT.getVectorNumElements() - 1);
+    for (unsigned MaskBit = 1; MaskBit <= MaxBitMask; MaskBit *= 2) {
+      MaskConstants.push_back(DAG.getConstant(MaskBit, DL, ElementType));
+    }
+
+    SDValue Mask = DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, MaskConstants);
+    SDValue RepresentativeBits =
+        DAG.getNode(ISD::AND, DL, VecVT, ComparisonResult, Mask);
+    VectorBits =
+        DAG.getNode(ISD::VECREDUCE_ADD, DL, ElementType, RepresentativeBits);
+  }
+
+  return VectorBits;
+}
+
+static SDValue combineVectorCompareAndTruncateStore(SelectionDAG &DAG,
+                                                    StoreSDNode *Store) {
+  if (!Store->isTruncatingStore())
+    return SDValue();
+
+  SDValue VecCompareOp = Store->getValue();
+  EVT VT = VecCompareOp.getValueType();
+  EVT MemVT = Store->getMemoryVT();
+
+  if (!MemVT.isVector() || !VT.isVector())
+    return SDValue();
+
+  // We only want to combine truncating stores to single bits.
+  if (MemVT.getVectorElementType() != MVT::i1 ||
+      MemVT.getVectorNumElements() != VT.getVectorNumElements())
+    return SDValue();
+
+  // We can only apply this if we know that each lane of the input is either
+  // all 1s or all 0s, which is the case for vector comparison results.
+ if (VecCompareOp->getOpcode() != ISD::SETCC) + return SDValue(); + + SDValue VectorBits = vectorCompareToBitVector(VecCompareOp.getNode(), DAG); + if (!VectorBits) + return SDValue(); + + SDLoc DL(Store); + EVT StoreVT = + EVT::getIntegerVT(*DAG.getContext(), MemVT.getStoreSizeInBits()); + SDValue ExtendedBits = DAG.getZExtOrTrunc(VectorBits, DL, StoreVT); + return DAG.getStore(Store->getChain(), DL, ExtendedBits, Store->getBasePtr(), + Store->getMemOperand()); +} + static SDValue performSTORECombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG, @@ -19492,6 +19606,9 @@ if (SDValue Store = foldTruncStoreOfExt(DAG, N)) return Store; + if (SDValue Store = combineVectorCompareAndTruncateStore(DAG, ST)) + return Store; + return SDValue(); } @@ -20372,6 +20489,51 @@ Op0ExtV, Op1ExtV, Op->getOperand(2)); } +static SDValue +combineVectorCompareAndBitcast(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, + SelectionDAG &DAG) { + SDLoc DL(N); + EVT VecVT = N->getValueType(0); + if (!VecVT.isVector() || VecVT.getVectorElementType() != MVT::i1) + return SDValue(); + + SDValue VectorBits = vectorCompareToBitVector(N, DAG); + if (!VectorBits) + return SDValue(); + + // Check chain of uses to see if the compare is followed by a bitcast. + bool CombinedBitcast = false; + for (SDNode *User : N->uses()) { + // When using Clang's __builtin_convertvector(), for v4i1 and v2i1, we get a + // vector concatenation to fill the remaining bits before bitcasting. This + // op can be skipped if we replace the user chain. We are defensive here to + // avoid producing wrong code in case we are not aware of the pattern. + if (User->getOpcode() == ISD::CONCAT_VECTORS) { + if (!User->hasOneUse()) + return SDValue(); + + // The vector with the relevant bits must be the first vector. + if (User->getOperand(0) != SDValue(N, 0) || + !User->getOperand(1).isUndef()) + return SDValue(); + + User = *User->use_begin(); + } + + // The comparison result may have other users, but we only want to replace + // the explicit bitcast and let other passes combine other instructions. + if (User->getOpcode() != ISD::BITCAST) + continue; + + EVT TargetVT = User->getValueType(0); + SDValue BitcastedResult = DAG.getZExtOrTrunc(VectorBits, DL, TargetVT); + DCI.CombineTo(User, BitcastedResult); + CombinedBitcast = true; + } + + return CombinedBitcast ? 
SDValue(N, 0) : SDValue();
+}
+
 static SDValue performSETCCCombine(SDNode *N,
                                    TargetLowering::DAGCombinerInfo &DCI,
                                    SelectionDAG &DAG) {
@@ -20433,6 +20595,9 @@
     }
   }
 
+  if (SDValue V = combineVectorCompareAndBitcast(N, DCI, DAG))
+    return V;
+
   // Try to perform the memcmp when the result is tested for [in]equality with 0
   if (SDValue V = performOrXorChainCombine(N, DAG))
     return V;
diff --git a/llvm/test/CodeGen/AArch64/vec-combine-compare-and-store.ll b/llvm/test/CodeGen/AArch64/vec-combine-compare-and-store.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/vec-combine-compare-and-store.ll
@@ -0,0 +1,129 @@
+; RUN: llc -mtriple=aarch64-apple-darwin -mattr=+neon -verify-machineinstrs < %s | FileCheck %s
+
+define void @store_16_elements(<16 x i8> %vec, ptr %out) {
+; Bits used in mask
+; CHECK-LABEL: lCPI0_0
+; CHECK-NEXT: .byte 1
+; CHECK-NEXT: .byte 2
+; CHECK-NEXT: .byte 4
+; CHECK-NEXT: .byte 8
+; CHECK-NEXT: .byte 16
+; CHECK-NEXT: .byte 32
+; CHECK-NEXT: .byte 64
+; CHECK-NEXT: .byte 128
+; CHECK-NEXT: .byte 1
+; CHECK-NEXT: .byte 2
+; CHECK-NEXT: .byte 4
+; CHECK-NEXT: .byte 8
+; CHECK-NEXT: .byte 16
+; CHECK-NEXT: .byte 32
+; CHECK-NEXT: .byte 64
+; CHECK-NEXT: .byte 128
+
+; Actual conversion
+; CHECK-LABEL: store_16_elements
+; CHECK: adrp x8, lCPI0_0@PAGE
+; CHECK: cmeq.16b v0, v0, #0
+; CHECK: ldr q1, [x8, lCPI0_0@PAGEOFF]
+; CHECK: bic.16b v0, v1, v0
+; CHECK: ext.16b v1, v0, v0, #8
+; CHECK: addv.8b b0, v0
+; CHECK: addv.8b b1, v1
+; CHECK: fmov w9, s0
+; CHECK: fmov w8, s1
+; CHECK: orr w8, w9, w8, lsl #8
+; CHECK: strh w8, [x0]
+; CHECK: ret
+
+  %cmp_result = icmp ne <16 x i8> %vec, zeroinitializer
+  store <16 x i1> %cmp_result, ptr %out
+  ret void
+}
+
+define void @store_8_elements(<8 x i16> %vec, ptr %out) {
+; CHECK-LABEL: lCPI1_0:
+; CHECK-NEXT: .short 1
+; CHECK-NEXT: .short 2
+; CHECK-NEXT: .short 4
+; CHECK-NEXT: .short 8
+; CHECK-NEXT: .short 16
+; CHECK-NEXT: .short 32
+; CHECK-NEXT: .short 64
+; CHECK-NEXT: .short 128
+
+; CHECK: adrp x8, lCPI1_0@PAGE
+; CHECK: cmeq.8h v0, v0, #0
+; CHECK: ldr q1, [x8, lCPI1_0@PAGEOFF]
+; CHECK: bic.16b v0, v1, v0
+; CHECK: addv.8h h0, v0
+; CHECK: fmov w8, s0
+; CHECK: strb w8, [x0]
+; CHECK: ret
+
+  %cmp_result = icmp ne <8 x i16> %vec, zeroinitializer
+  store <8 x i1> %cmp_result, ptr %out
+  ret void
+}
+
+define void @store_4_elements(<4 x i32> %vec, ptr %out) {
+; CHECK-LABEL: lCPI2_0:
+; CHECK: .long 1
+; CHECK: .long 2
+; CHECK: .long 4
+; CHECK: .long 8
+
+; CHECK: adrp x8, lCPI2_0@PAGE
+; CHECK: cmeq.4s v0, v0, #0
+; CHECK: Lloh5:
+; CHECK: ldr q1, [x8, lCPI2_0@PAGEOFF]
+; CHECK: bic.16b v0, v1, v0
+; CHECK: addv.4s s0, v0
+; CHECK: fmov w8, s0
+; CHECK: strb w8, [x0]
+; CHECK: ret
+
+  %cmp_result = icmp ne <4 x i32> %vec, zeroinitializer
+  store <4 x i1> %cmp_result, ptr %out
+  ret void
+}
+
+define void @store_2_elements(<2 x i64> %vec, ptr %out) {
+; CHECK-LABEL: lCPI3_0:
+; CHECK-NEXT: .quad 1
+; CHECK-NEXT: .quad 2
+
+; CHECK: adrp x8, lCPI3_0@PAGE
+; CHECK: cmeq.2d v0, v0, #0
+; CHECK: Lloh7:
+; CHECK: ldr q1, [x8, lCPI3_0@PAGEOFF]
+; CHECK: bic.16b v0, v1, v0
+; CHECK: addp.2d d0, v0
+; CHECK: fmov x8, d0
+; CHECK: strb w8, [x0]
+; CHECK: ret
+
+  %cmp_result = icmp ne <2 x i64> %vec, zeroinitializer
+  store <2 x i1> %cmp_result, ptr %out
+  ret void
+}
+
+define void @no_combine_without_truncate(<16 x i8> %vec, ptr %out) {
+; CHECK-LABEL: no_combine_without_truncate
+; CHECK: cmtst.16b v0, v0, v0
+; CHECK-NOT: addv.8b b0, v0
+
+  %cmp_result = icmp ne <16 x i8> %vec, zeroinitializer
+  %extended_result = sext <16 x i1> %cmp_result to <16 x i8>
+  store <16 x i8> %extended_result, ptr %out
+  ret void
+}
+
+define void @no_combine_without_compare(<16 x i8> %vec, ptr %out) {
+; CHECK-LABEL: no_combine_without_compare
+; CHECK: umov.b w8, v0[0]
+; CHECK-NOT: addv.8b b0, v0
+
+  %trunc = trunc <16 x i8> %vec to <16 x i1>
+  store <16 x i1> %trunc, ptr %out
+  ret void
+}
diff --git a/llvm/test/CodeGen/AArch64/vec-combine-compare-to-bitmask.ll b/llvm/test/CodeGen/AArch64/vec-combine-compare-to-bitmask.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/vec-combine-compare-to-bitmask.ll
@@ -0,0 +1,139 @@
+; RUN: llc -mtriple=aarch64-apple-darwin -mattr=+neon -verify-machineinstrs < %s | FileCheck %s
+
+; Basic tests from input vector to bitmask
+; IR generated from clang for:
+; __builtin_convertvector + reinterpret_cast
+
+define i16 @convert_to_bitmask16(<16 x i8> %vec) {
+; Bits used in mask
+; CHECK-LABEL: lCPI0_0
+; CHECK-NEXT: .byte 1
+; CHECK-NEXT: .byte 2
+; CHECK-NEXT: .byte 4
+; CHECK-NEXT: .byte 8
+; CHECK-NEXT: .byte 16
+; CHECK-NEXT: .byte 32
+; CHECK-NEXT: .byte 64
+; CHECK-NEXT: .byte 128
+; CHECK-NEXT: .byte 1
+; CHECK-NEXT: .byte 2
+; CHECK-NEXT: .byte 4
+; CHECK-NEXT: .byte 8
+; CHECK-NEXT: .byte 16
+; CHECK-NEXT: .byte 32
+; CHECK-NEXT: .byte 64
+; CHECK-NEXT: .byte 128
+
+; Actual conversion
+; CHECK-LABEL: convert_to_bitmask16
+; CHECK: adrp x8, lCPI0_0@PAGE
+; CHECK: cmeq.16b v0, v0, #0
+; CHECK: ldr q1, [x8, lCPI0_0@PAGEOFF]
+; CHECK: bic.16b v0, v1, v0
+; CHECK: ext.16b v1, v0, v0, #8
+; CHECK: addv.8b b0, v0
+; CHECK: addv.8b b1, v1
+; CHECK: fmov w9, s0
+; CHECK: fmov w8, s1
+; CHECK: orr w0, w9, w8, lsl #8
+; CHECK: ret
+
+  %cmp_result = icmp ne <16 x i8> %vec, zeroinitializer
+  %bitmask = bitcast <16 x i1> %cmp_result to i16
+  ret i16 %bitmask
+}
+
+define i16 @convert_to_bitmask8(<8 x i16> %vec) {
+; CHECK-LABEL: lCPI1_0:
+; CHECK-NEXT: .short 1
+; CHECK-NEXT: .short 2
+; CHECK-NEXT: .short 4
+; CHECK-NEXT: .short 8
+; CHECK-NEXT: .short 16
+; CHECK-NEXT: .short 32
+; CHECK-NEXT: .short 64
+; CHECK-NEXT: .short 128
+
+; CHECK: adrp x8, lCPI1_0@PAGE
+; CHECK: cmeq.8h v0, v0, #0
+; CHECK: ldr q1, [x8, lCPI1_0@PAGEOFF]
+; CHECK: bic.16b v0, v1, v0
+; CHECK: addv.8h h0, v0
+; CHECK: fmov w8, s0
+; CHECK: and w0, w8, #0xff
+; CHECK: ret
+
+  %cmp_result = icmp ne <8 x i16> %vec, zeroinitializer
+  %bitmask = bitcast <8 x i1> %cmp_result to i8
+  %extended_bitmask = zext i8 %bitmask to i16
+  ret i16 %extended_bitmask
+}
+
+define i16 @convert_to_bitmask4(<4 x i32> %vec) {
+; CHECK-LABEL: lCPI2_0:
+; CHECK: .long 1
+; CHECK: .long 2
+; CHECK: .long 4
+; CHECK: .long 8
+
+; CHECK: adrp x8, lCPI2_0@PAGE
+; CHECK: cmeq.4s v0, v0, #0
+; CHECK: Lloh5:
+; CHECK: ldr q1, [x8, lCPI2_0@PAGEOFF]
+; CHECK: bic.16b v0, v1, v0
+; CHECK: addv.4s s0, v0
+; CHECK: fmov w8, s0
+; CHECK: and w0, w8, #0xff
+; CHECK: ret
+
+  %cmp_result = icmp ne <4 x i32> %vec, zeroinitializer
+  %vector_pad = shufflevector <4 x i1> %cmp_result, <4 x i1> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  %bitmask = bitcast <8 x i1> %vector_pad to i8
+  %extended_bitmask = zext i8 %bitmask to i16
+  ret i16 %extended_bitmask
+}
+
+define i16 @convert_to_bitmask2(<2 x i64> %vec) {
+; CHECK-LABEL: lCPI3_0:
+; CHECK-NEXT: .quad 1
+; CHECK-NEXT: .quad 2
+
+; CHECK: adrp x8, lCPI3_0@PAGE
+; CHECK: cmeq.2d v0, v0, #0
+; CHECK: Lloh7:
+; CHECK: ldr q1, [x8, lCPI3_0@PAGEOFF]
+; CHECK: bic.16b v0, v1, v0
+; CHECK: addp.2d d0, v0
+; CHECK: fmov x8, d0
+; CHECK: and w0, w8, #0xff
+; CHECK: ret
+
+  %cmp_result = icmp ne <2 x i64> %vec, zeroinitializer
+  %vector_pad = shufflevector <2 x i1> %cmp_result, <2 x i1> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 2, i32 3, i32 2, i32 3>
+  %bitmask = bitcast <8 x i1> %vector_pad to i8
+  %extended_bitmask = zext i8 %bitmask to i16
+  ret i16 %extended_bitmask
+}
+
+
+define i16 @no_convert_bad_concat(<4 x i32> %vec) {
+; CHECK-LABEL: no_convert_bad_concat:
+; CHECK: cmtst.4s v0, v0, v0
+; CHECK-NOT: addv.4s s0, v0
+
+  %cmp_result = icmp ne <4 x i32> %vec, zeroinitializer
+  %vector_pad = shufflevector <4 x i1> poison, <4 x i1> %cmp_result, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  %bitmask = bitcast <8 x i1> %vector_pad to i8
+  %extended_bitmask = zext i8 %bitmask to i16
+  ret i16 %extended_bitmask
+}
+
+define <8 x i1> @no_convert_without_direct_bitcast(<8 x i16> %vec) {
+; CHECK-LABEL: no_convert_without_direct_bitcast:
+; CHECK: cmtst.8h v0, v0, v0
+; CHECK-NOT: addv.4s s0, v0
+
+  %cmp_result = icmp ne <8 x i16> %vec, zeroinitializer
+  ret <8 x i1> %cmp_result
+}
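Note (illustrative only, not part of the patch): the scalar sketch below models the mask + horizontal-add trick that vectorCompareToBitVector emits, assuming each compare lane is either all ones or all zeros, as NEON comparisons guarantee. The helper name lanesToBitmask and the example lane values are made up for illustration.

#include <array>
#include <cstdint>
#include <cstdio>

// Keep only bit i of lane i (the AND with the 1,2,4,... mask vector), then
// sum the lanes (the VECREDUCE_ADD); the sum is the packed bitmask.
static unsigned lanesToBitmask(const std::array<uint16_t, 8> &CmpResult) {
  unsigned Bitmask = 0;
  for (unsigned I = 0; I < 8; ++I)
    Bitmask += CmpResult[I] & (1u << I);
  return Bitmask;
}

int main() {
  // Lanes 0, 2 and 5 compared not-equal (all ones); the rest are all zeros.
  std::array<uint16_t, 8> Cmp = {0xFFFF, 0, 0xFFFF, 0, 0, 0xFFFF, 0, 0};
  std::printf("0x%02x\n", lanesToBitmask(Cmp)); // prints 0x25
  return 0;
}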