diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1211,6 +1211,11 @@
     setTruncStoreAction(MVT::v4i16, MVT::v4i8, Custom);
 
+    setTruncStoreAction(MVT::v16i8, MVT::v16i1, Custom);
+    setTruncStoreAction(MVT::v8i16, MVT::v8i1, Custom);
+    setTruncStoreAction(MVT::v4i32, MVT::v4i1, Custom);
+    setTruncStoreAction(MVT::v2i64, MVT::v2i1, Custom);
+
     setLoadExtAction(ISD::EXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
     setLoadExtAction(ISD::SEXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
     setLoadExtAction(ISD::ZEXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
@@ -19532,6 +19537,115 @@
   return DAG.getMergeValues({ExtractSubVector, TokenFactor}, DL);
 }
 
+// When performing a vector compare with n elements followed by some form of
+// truncation/casting to <n x i1>, we can use a trick that extracts the i^th
+// bit from the i^th element and then performs a vector add to get a scalar
+// bitmask.
+static SDValue vectorCompareToBitVector(SDNode *N, SelectionDAG &DAG) {
+  EVT VT = N->getValueType(0);
+  assert(VT.isVector() && "Should be a vector type");
+  assert(N->getOpcode() == ISD::SETCC && "Must be vector compare.");
+
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+  assert(LHS.getValueType() == RHS.getValueType());
+  EVT VecVT = LHS.getValueType();
+  EVT ElementType = VecVT.getVectorElementType();
+
+  if (VecVT != MVT::v2i64 && VecVT != MVT::v2i32 && VecVT != MVT::v4i32 &&
+      VecVT != MVT::v4i16 && VecVT != MVT::v8i16 && VecVT != MVT::v8i8 &&
+      VecVT != MVT::v16i8)
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue ComparisonResult = DAG.getNode(N->getOpcode(), DL, VecVT, LHS, RHS,
+                                         N->getOperand(2), N->getFlags());
+
+  SDValue VectorBits;
+  if (VecVT == MVT::v16i8) {
+    // v16i8 is a special case, as an i8 element can only hold bit positions
+    // 0-7: we split the vector into two halves, perform the mask+addition on
+    // each half separately, and then combine the two partial bitmasks.
+    SmallVector<SDValue, 16> MaskConstants;
+    for (unsigned Half = 0; Half < 2; ++Half) {
+      for (unsigned MaskBit = 1; MaskBit <= 128; MaskBit *= 2) {
+        MaskConstants.push_back(DAG.getConstant(MaskBit, DL, MVT::i32));
+      }
+    }
+    SDValue Mask = DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, MaskConstants);
+    SDValue RepresentativeBits =
+        DAG.getNode(ISD::AND, DL, VecVT, ComparisonResult, Mask);
+
+    EVT HalfVT = VecVT.getHalfNumVectorElementsVT(*DAG.getContext());
+    unsigned NumElementsInHalf = HalfVT.getVectorNumElements();
+
+    SDValue LowHalf =
+        DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, RepresentativeBits,
+                    DAG.getConstant(0, DL, MVT::i64));
+    SDValue HighHalf =
+        DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, RepresentativeBits,
+                    DAG.getConstant(NumElementsInHalf, DL, MVT::i64));
+
+    SDValue ReducedLowBits =
+        DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i16, LowHalf);
+    SDValue ReducedHighBits =
+        DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i16, HighHalf);
+
+    SDValue ShiftedHighBits =
+        DAG.getNode(ISD::SHL, DL, MVT::i16, ReducedHighBits,
+                    DAG.getConstant(NumElementsInHalf, DL, MVT::i32));
+    VectorBits =
+        DAG.getNode(ISD::OR, DL, MVT::i16, ShiftedHighBits, ReducedLowBits);
+  } else {
+    SmallVector<SDValue, 16> MaskConstants;
+    unsigned MaxBitMask = 1u << (VecVT.getVectorNumElements() - 1);
+    for (unsigned MaskBit = 1; MaskBit <= MaxBitMask; MaskBit *= 2) {
+      MaskConstants.push_back(DAG.getConstant(MaskBit, DL, ElementType));
+    }
+
+    SDValue Mask = DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, MaskConstants);
+    SDValue RepresentativeBits =
+        DAG.getNode(ISD::AND, DL, VecVT, ComparisonResult, Mask);
+    VectorBits =
+        DAG.getNode(ISD::VECREDUCE_ADD, DL, ElementType, RepresentativeBits);
+  }
+
+  return VectorBits;
+}
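For illustration (an editor's sketch, not part of the patch): at the IR level, the trick implemented above corresponds roughly to the following, shown for a 4-lane compare. The function name and the sample lane values in the comments are hypothetical.

  declare i32 @llvm.vector.reduce.add.v4i32(<4 x i32>)

  ; Sketch of the bitmask trick for <4 x i32>, e.g. %vec = <7, 0, 3, 9>.
  define i32 @bitmask_trick_sketch(<4 x i32> %vec) {
    ; Lane-wise compare: every lane becomes all-ones or all-zeros.
    %cmp = icmp ne <4 x i32> %vec, zeroinitializer
    %ext = sext <4 x i1> %cmp to <4 x i32>        ; <-1, 0, -1, -1>
    ; Keep only bit i in lane i: <-1, 0, -1, -1> & <1, 2, 4, 8> = <1, 0, 4, 8>.
    %bits = and <4 x i32> %ext, <i32 1, i32 2, i32 4, i32 8>
    ; Horizontal add (ADDV on AArch64): 1 + 0 + 4 + 8 = 13 = 0b1101.
    %mask = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %bits)
    ret i32 %mask
  }

The add is exact because the masked lanes hold disjoint bits, so no lane can carry into another lane's bit position.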
+
+static SDValue combineVectorCompareAndTruncateStore(SelectionDAG &DAG,
+                                                    StoreSDNode *Store) {
+  if (!Store->isTruncatingStore())
+    return SDValue();
+
+  SDValue VecCompareOp = Store->getValue();
+  EVT VT = VecCompareOp.getValueType();
+  EVT MemVT = Store->getMemoryVT();
+
+  if (!MemVT.isVector() || !VT.isVector())
+    return SDValue();
+
+  // We only want to combine truncating stores to single bits.
+  if (MemVT.getVectorElementType() != MVT::i1 ||
+      MemVT.getVectorNumElements() != VT.getVectorNumElements())
+    return SDValue();
+
+  // We can only apply this if we know that the input is all 1s or all 0s,
+  // which is the case for vector comparisons.
+  if (VecCompareOp->getOpcode() != ISD::SETCC)
+    return SDValue();
+
+  SDValue VectorBits = vectorCompareToBitVector(VecCompareOp.getNode(), DAG);
+  if (!VectorBits)
+    return SDValue();
+
+  SDLoc DL(Store);
+  EVT StoreVT =
+      EVT::getIntegerVT(*DAG.getContext(), MemVT.getStoreSizeInBits());
+  SDValue ExtendedBits = DAG.getZExtOrTrunc(VectorBits, DL, StoreVT);
+  return DAG.getStore(Store->getChain(), DL, ExtendedBits, Store->getBasePtr(),
+                      Store->getMemOperand());
+}
+
 static SDValue performSTORECombine(SDNode *N,
                                    TargetLowering::DAGCombinerInfo &DCI,
                                    SelectionDAG &DAG,
@@ -19570,6 +19684,9 @@
   if (SDValue Store = foldTruncStoreOfExt(DAG, N))
     return Store;
 
+  if (SDValue Store = combineVectorCompareAndTruncateStore(DAG, ST))
+    return Store;
+
   return SDValue();
 }
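As a hedged usage sketch (hypothetical function, same shape as the tests added below), the new combine fires on a setcc feeding a truncating store of an i1 vector:

  define void @truncstore_sketch(<4 x i32> %vec, ptr %out) {
    %cmp = icmp ne <4 x i32> %vec, zeroinitializer
    ; MemVT is <4 x i1>, whose store size rounds up to one byte, so StoreVT
    ; above becomes i8: the 4-bit mask is zero-extended and written with a
    ; single byte store (strb) instead of lane-by-lane bit insertion.
    store <4 x i1> %cmp, ptr %out
    ret void
  }

This is also why the <4 x i1> and <2 x i1> tests below check for strb rather than a narrower store: MemVT.getStoreSizeInBits() is 8 in both cases.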
@@ -20450,111 +20567,49 @@
                        Op0ExtV, Op1ExtV, Op->getOperand(2));
 }
 
-// When performing a vector compare with n elements followed by a bitcast to
-// <n x i1>, we can use a trick that extracts the i^th bit from the i^th
-// element and then performs a vector add to get a scalar bitmask.
 static SDValue combineVectorCompareAndBitcast(SDNode *N,
                                               TargetLowering::DAGCombinerInfo &DCI,
                                               SelectionDAG &DAG) {
-  EVT VT = N->getValueType(0);
-  if (!VT.isVector() || VT.getVectorElementType() != MVT::i1)
+  SDLoc DL(N);
+  EVT VecVT = N->getValueType(0);
+  if (!VecVT.isVector() || VecVT.getVectorElementType() != MVT::i1)
     return SDValue();
 
-  SDValue LHS = N->getOperand(0);
-  SDValue RHS = N->getOperand(1);
-  assert(LHS.getValueType() == RHS.getValueType());
-  EVT VecVT = LHS.getValueType();
-  EVT ElementType = VecVT.getVectorElementType();
-
-  if (VecVT != MVT::v2i64 && VecVT != MVT::v2i32 && VecVT != MVT::v4i32 &&
-      VecVT != MVT::v4i16 && VecVT != MVT::v8i16 && VecVT != MVT::v8i8 &&
-      VecVT != MVT::v16i8)
+  SDValue VectorBits = vectorCompareToBitVector(N, DAG);
+  if (!VectorBits)
     return SDValue();
 
-  SDLoc DL(N);
-  SDValue ComparisonResult = DAG.getNode(N->getOpcode(), DL, VecVT, LHS, RHS,
-                                         N->getOperand(2), N->getFlags());
-
-  SDValue VectorBits;
-  if (VecVT == MVT::v16i8) {
-    // v16i8 is a special case, as we need to split it into two halves and
-    // combine, perform the mask+addition twice, and then combine them.
-    SmallVector<SDValue, 16> MaskConstants;
-    for (unsigned Half = 0; Half < 2; ++Half) {
-      for (unsigned MaskBit = 1; MaskBit <= 128; MaskBit *= 2) {
-        MaskConstants.push_back(DAG.getConstant(MaskBit, DL, MVT::i32));
-      }
-    }
-    SDValue Mask = DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, MaskConstants);
-    SDValue RepresentativeBits =
-        DAG.getNode(ISD::AND, DL, VecVT, ComparisonResult, Mask);
-
-    EVT HalfVT = VecVT.getHalfNumVectorElementsVT(*DAG.getContext());
-    unsigned NumElementsInHalf = HalfVT.getVectorNumElements();
-
-    SDValue LowHalf =
-        DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, RepresentativeBits,
-                    DAG.getConstant(0, DL, MVT::i64));
-    SDValue HighHalf =
-        DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, RepresentativeBits,
-                    DAG.getConstant(NumElementsInHalf, DL, MVT::i64));
-
-    SDValue ReducedLowBits =
-        DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i16, LowHalf);
-    SDValue ReducedHighBits =
-        DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i16, HighHalf);
-
-    SDValue ShiftedHighBits =
-        DAG.getNode(ISD::SHL, DL, MVT::i16, ReducedHighBits,
-                    DAG.getConstant(NumElementsInHalf, DL, MVT::i32));
-    VectorBits =
-        DAG.getNode(ISD::OR, DL, MVT::i16, ShiftedHighBits, ReducedLowBits);
-  } else {
-    SmallVector<SDValue, 16> MaskConstants;
-    unsigned MaxBitMask = 1u << (VecVT.getVectorNumElements() - 1);
-    for (unsigned MaskBit = 1; MaskBit <= MaxBitMask; MaskBit *= 2) {
-      MaskConstants.push_back(DAG.getConstant(MaskBit, DL, ElementType));
-    }
-
-    SDValue Mask = DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, MaskConstants);
-    SDValue RepresentativeBits =
-        DAG.getNode(ISD::AND, DL, VecVT, ComparisonResult, Mask);
-    VectorBits =
-        DAG.getNode(ISD::VECREDUCE_ADD, DL, ElementType, RepresentativeBits);
-  }
-
   // Check chain of uses to see if the compare is followed by a bitcast.
+  bool CombinedBitcast = false;
   for (SDNode *User : N->uses()) {
-    // For v4i1 and v2i1, we get a vector concatenation to fill the remaining
-    // bits before bitcasting. This op can be skipped if we replace the user
-    // chain. We are defensive here to avoid producing wrong code in case we
-    // are not aware of the bitcast pattern. This pattern is generated by
-    // clang, e.g., when using __builtin_convertvector().
+    // When using Clang's __builtin_convertvector(), for v4i1 and v2i1, we
+    // get a vector concatenation to fill the remaining bits before
+    // bitcasting. This op can be skipped if we replace the user chain. We
+    // are defensive here to avoid producing wrong code in case we are not
+    // aware of the pattern.
     if (User->getOpcode() == ISD::CONCAT_VECTORS) {
-      if (!User->hasOneUse() || User->getOperand(0).getValueType() != VT)
+      if (!User->hasOneUse())
         return SDValue();
 
       // The vector with the relevant bits must be the first vector.
-      if (User->getOperand(0) != SDValue(N, 0))
+      if (User->getOperand(0) != SDValue(N, 0) ||
+          !User->getOperand(1).isUndef())
         return SDValue();
 
-      // The other vectors must be undef.
-      for (unsigned I = 1; I < User->getNumOperands(); ++I)
-        if (!User->getOperand(I).isUndef())
-          return SDValue();
-
       User = *User->use_begin();
     }
 
+    // The comparison result may have other users, but we only want to replace
+    // the explicit bitcast and let other passes combine other instructions.
     if (User->getOpcode() != ISD::BITCAST)
       continue;
 
     EVT TargetVT = User->getValueType(0);
     SDValue BitcastedResult = DAG.getZExtOrTrunc(VectorBits, DL, TargetVT);
     DCI.CombineTo(User, BitcastedResult);
+    CombinedBitcast = true;
   }
 
-  return SDValue(N, 0);
+  return CombinedBitcast ? SDValue(N, 0) : SDValue();
 }
 
 static SDValue performSETCCCombine(SDNode *N,
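For context on the CONCAT_VECTORS case above, this is roughly the IR clang emits for __builtin_convertvector of a 4-lane compare followed by a bitcast (a sketch with hypothetical names; the exact shuffle mask clang uses may differ):

  define i8 @convertvector_pattern_sketch(<4 x i32> %a, <4 x i32> %b) {
    %cmp = icmp slt <4 x i32> %a, %b
    ; Pad <4 x i1> to <8 x i1>. On the DAG this becomes
    ; concat_vectors(setcc, undef), which the combine looks through before
    ; matching the bitcast.
    %pad = shufflevector <4 x i1> %cmp, <4 x i1> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison>
    %mask = bitcast <8 x i1> %pad to i8
    ret i8 %mask
  }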
diff --git a/llvm/test/CodeGen/AArch64/vec-combine-compare-to-bitmask.ll b/llvm/test/CodeGen/AArch64/vec-combine-compare-and-store.ll
copy from llvm/test/CodeGen/AArch64/vec-combine-compare-to-bitmask.ll
copy to llvm/test/CodeGen/AArch64/vec-combine-compare-and-store.ll
--- a/llvm/test/CodeGen/AArch64/vec-combine-compare-to-bitmask.ll
+++ b/llvm/test/CodeGen/AArch64/vec-combine-compare-and-store.ll
@@ -1,11 +1,6 @@
-; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -mtriple=aarch64-apple-darwin -mattr=+neon -verify-machineinstrs < %s | FileCheck %s
 
-; Basic tests from input vector to bitmask
-; IR generated from clang for:
-; __builtin_convertvector + reinterpret_cast
-
-define i16 @convert_to_bitmask16(<16 x i8> %vec) {
+define void @store_16_elements(<16 x i8> %vec, ptr %out) {
 ; Bits used in mask
 ; CHECK-LABEL: lCPI0_0
 ; CHECK-NEXT: .byte 1
@@ -26,7 +21,7 @@
 ; CHECK-NEXT: .byte 128
 
 ; Actual conversion
-; CHECK-LABEL: convert_to_bitmask16
+; CHECK-LABEL: store_16_elements
 ; CHECK: adrp x8, lCPI0_0@PAGE
 ; CHECK: cmeq.16b v0, v0, #0
 ; CHECK: ldr q1, [x8, lCPI0_0@PAGEOFF]
@@ -36,15 +31,16 @@
 ; CHECK: addv.8b b1, v1
 ; CHECK: fmov w9, s0
 ; CHECK: fmov w8, s1
-; CHECK: orr w0, w9, w8, lsl #8
+; CHECK: orr w8, w9, w8, lsl #8
+; CHECK: strh w8, [x0]
 ; CHECK: ret
 
   %cmp_result = icmp ne <16 x i8> %vec, zeroinitializer
-  %bitmask = bitcast <16 x i1> %cmp_result to i16
-  ret i16 %bitmask
+  store <16 x i1> %cmp_result, ptr %out
+  ret void
 }
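A hedged sketch of the two-half reduction that store_16_elements checks above (hypothetical function; the DAG combine reduces each half directly to i16, approximated here with i8 reductions plus zext). An i8 lane can only hold bit positions 0-7, so the mask constants in lCPI0_0 repeat 1..128 across the two halves, and the partial bitmasks are recombined as (high << 8) | low, which is why the test ends in orr and strh:

  declare i8 @llvm.vector.reduce.add.v8i8(<8 x i8>)

  ; %masked is assumed to be the compare result ANDed with the repeated mask.
  define i16 @v16i8_halves_sketch(<16 x i8> %masked) {
    %lo = shufflevector <16 x i8> %masked, <16 x i8> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
    %hi = shufflevector <16 x i8> %masked, <16 x i8> poison, <8 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
    ; Each half sums to an 8-bit partial mask (the two addv.8b in the test).
    %lo.red = call i8 @llvm.vector.reduce.add.v8i8(<8 x i8> %lo)
    %hi.red = call i8 @llvm.vector.reduce.add.v8i8(<8 x i8> %hi)
    %lo.ext = zext i8 %lo.red to i16
    %hi.ext = zext i8 %hi.red to i16
    ; Recombine as (high << 8) | low; the i16 result is stored with strh.
    %hi.shl = shl i16 %hi.ext, 8
    %mask = or i16 %hi.shl, %lo.ext
    ret i16 %mask
  }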
 
-define i16 @convert_to_bitmask8(<8 x i16> %vec) {
+define void @store_8_elements(<8 x i16> %vec, ptr %out) {
 ; CHECK-LABEL: lCPI1_0:
 ; CHECK-NEXT: .short 1
 ; CHECK-NEXT: .short 2
@@ -61,16 +57,15 @@
 ; CHECK: bic.16b v0, v1, v0
 ; CHECK: addv.8h h0, v0
 ; CHECK: fmov w8, s0
-; CHECK: and w0, w8, #0xff
+; CHECK: strb w8, [x0]
 ; CHECK: ret
 
   %cmp_result = icmp ne <8 x i16> %vec, zeroinitializer
-  %bitmask = bitcast <8 x i1> %cmp_result to i8
-  %extended_bitmask = zext i8 %bitmask to i16
-  ret i16 %extended_bitmask
+  store <8 x i1> %cmp_result, ptr %out
+  ret void
 }
 
-define i16 @convert_to_bitmask4(<4 x i32> %vec) {
+define void @store_4_elements(<4 x i32> %vec, ptr %out) {
 ; CHECK-LABEL: lCPI2_0:
 ; CHECK: .long 1
 ; CHECK: .long 2
@@ -84,20 +79,18 @@
 ; CHECK: bic.16b v0, v1, v0
 ; CHECK: addv.4s s0, v0
 ; CHECK: fmov w8, s0
-; CHECK: and w0, w8, #0xff
+; CHECK: strb w8, [x0]
 ; CHECK: ret
 
   %cmp_result = icmp ne <4 x i32> %vec, zeroinitializer
-  %vector_pad = shufflevector <4 x i1> %cmp_result, <4 x i1> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison>
-  %bitmask = bitcast <8 x i1> %vector_pad to i8
-  %extended_bitmask = zext i8 %bitmask to i16
-  ret i16 %extended_bitmask
+  store <4 x i1> %cmp_result, ptr %out
+  ret void
 }
 
-define i16 @convert_to_bitmask2(<2 x i64> %vec) {
+define void @store_2_elements(<2 x i64> %vec, ptr %out) {
 ; CHECK-LABEL: lCPI3_0:
-; CHECK-NEXXT: .quad 1
-; CHECK-NEXXT: .quad 2
+; CHECK-NEXT: .quad 1
+; CHECK-NEXT: .quad 2
 
 ; CHECK: adrp x8, lCPI3_0@PAGE
 ; CHECK: cmeq.2d v0, v0, #0
@@ -106,12 +99,31 @@
 ; CHECK: bic.16b v0, v1, v0
 ; CHECK: addp.2d d0, v0
 ; CHECK: fmov x8, d0
-; CHECK: and w0, w8, #0xff
+; CHECK: strb w8, [x0]
 ; CHECK: ret
 
   %cmp_result = icmp ne <2 x i64> %vec, zeroinitializer
-  %vector_pad = shufflevector <2 x i1> %cmp_result, <2 x i1> poison, <8 x i32> <i32 0, i32 1, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-  %bitmask = bitcast <8 x i1> %vector_pad to i8
-  %extended_bitmask = zext i8 %bitmask to i16
-  ret i16 %extended_bitmask
+  store <2 x i1> %cmp_result, ptr %out
+  ret void
+}
+
+define void @no_combine_without_truncate(<16 x i8> %vec, ptr %out) {
+; CHECK-LABEL: no_combine_without_truncate
+; CHECK: cmtst.16b v0, v0, v0
+; CHECK-NOT: addv.8b b0, v0
+
+  %cmp_result = icmp ne <16 x i8> %vec, zeroinitializer
+  %extended_result = sext <16 x i1> %cmp_result to <16 x i8>
+  store <16 x i8> %extended_result, ptr %out
+  ret void
+}
+
+define void @no_combine_without_compare(<16 x i8> %vec, ptr %out) {
+; CHECK-LABEL: no_combine_without_compare
+; CHECK: umov.b w8, v0[0]
+; CHECK-NOT: addv.8b b0, v0
+
+  %trunc = trunc <16 x i8> %vec to <16 x i1>
+  store <16 x i1> %trunc, ptr %out
+  ret void
+}
diff --git a/llvm/test/CodeGen/AArch64/vec-combine-compare-to-bitmask.ll b/llvm/test/CodeGen/AArch64/vec-combine-compare-to-bitmask.ll
--- a/llvm/test/CodeGen/AArch64/vec-combine-compare-to-bitmask.ll
+++ b/llvm/test/CodeGen/AArch64/vec-combine-compare-to-bitmask.ll
@@ -1,4 +1,3 @@
-; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -mtriple=aarch64-apple-darwin -mattr=+neon -verify-machineinstrs < %s | FileCheck %s
 
 ; Basic tests from input vector to bitmask
@@ -115,3 +114,26 @@
   %extended_bitmask = zext i8 %bitmask to i16
   ret i16 %extended_bitmask
 }
+
+
+define i16 @no_convert_bad_concat(<4 x i32> %vec) {
+; CHECK-LABEL: no_convert_bad_concat:
+; CHECK: cmtst.4s v0, v0, v0
+; CHECK-NOT: addv.4s s0, v0
+
+  %cmp_result = icmp ne <4 x i32> %vec, zeroinitializer
+  %vector_pad = shufflevector <4 x i1> poison, <4 x i1> %cmp_result, <8 x i32> <i32 poison, i32 poison, i32 poison, i32 poison, i32 4, i32 5, i32 6, i32 7>
+  %bitmask = bitcast <8 x i1> %vector_pad to i8
+  %extended_bitmask = zext i8 %bitmask to i16
+  ret i16 %extended_bitmask
+}
+
+define <8 x i1> @no_convert_without_direct_bitcast(<8 x i16> %vec) {
+; CHECK-LABEL: no_convert_without_direct_bitcast:
+; CHECK: cmtst.8h v0, v0, v0
+; CHECK-NOT: addv.4s s0, v0
+
+  %cmp_result = icmp ne <8 x i16> %vec, zeroinitializer
+  ret <8 x i1> %cmp_result
+}