diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1211,6 +1211,11 @@
   setTruncStoreAction(MVT::v4i16, MVT::v4i8, Custom);
+  setOperationAction(ISD::BITCAST, MVT::i2, Custom);
+  setOperationAction(ISD::BITCAST, MVT::i4, Custom);
+  setOperationAction(ISD::BITCAST, MVT::i8, Custom);
+  setOperationAction(ISD::BITCAST, MVT::i16, Custom);
+
   setLoadExtAction(ISD::EXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
   setLoadExtAction(ISD::SEXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
   setLoadExtAction(ISD::ZEXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
@@ -19626,6 +19631,127 @@
   return DAG.getMergeValues({ExtractSubVector, TokenFactor}, DL);
 }
 
+// Small helper to check if a node chain consists entirely of comparisons
+// combined with logical operations. This guarantees that all elements' bits
+// are either 1 or 0. `BaseVT` contains the type of the base comparison operand
+// if valid or something unspecified otherwise.
+static bool isChainOfComparesAndLogicalOps(SDValue Op, EVT &BaseVT,
+                                           int Depth = 0) {
+  if (Depth > 3)
+    return false;
+
+  if (Op.getOpcode() == ISD::SETCC) {
+    BaseVT = Op.getOperand(0).getValueType();
+    return true;
+  }
+
+  unsigned Opcode = Op.getOpcode();
+  if (Opcode == ISD::OR || Opcode == ISD::AND || Opcode == ISD::XOR)
+    return isChainOfComparesAndLogicalOps(Op.getOperand(0), BaseVT,
+                                          Depth + 1) &&
+           isChainOfComparesAndLogicalOps(Op.getOperand(1), BaseVT, Depth + 1);
+
+  return false;
+}
+
+// When converting a <N x iX> vector to <N x i1> to store or use as a scalar
+// iN, we can use a trick that extracts the i^th bit from the i^th element and
+// then performs a vector add to get a scalar bitmask. This requires that each
+// element's bits are either all 1 or all 0.
+static SDValue vectorToScalarBitmask(SDNode *N, SelectionDAG &DAG) {
+  SDLoc DL(N);
+  SDValue ComparisonResult(N, 0);
+  EVT VecVT = ComparisonResult.getValueType();
+  assert(VecVT.isVector() && "Must be a vector type");
+
+  unsigned NumElts = VecVT.getVectorNumElements();
+  if (NumElts != 2 && NumElts != 4 && NumElts != 8 && NumElts != 16)
+    return SDValue();
+
+  EVT BaseVT;
+  if (isChainOfComparesAndLogicalOps(ComparisonResult, BaseVT)) {
+    // If we have a comparison, we can get the original types to work on
+    // instead of a vector of i1, which may remove conversion instructions.
+    ComparisonResult =
+        DAG.getBoolExtOrTrunc(ComparisonResult, DL, BaseVT, VecVT);
+    VecVT = BaseVT;
+  } else {
+    // We need to ensure correct truncation semantics here, i.e., only use the
+    // least significant bit. So we mask it and set all bits to that value.
+    ComparisonResult = DAG.getNode(
+        ISD::AND, DL, VecVT, ComparisonResult,
+        DAG.getSplatBuildVector(VecVT, DL, DAG.getConstant(1, DL, MVT::i64)));
+
+    // We may not have the original vector type here anymore but only one
+    // consisting of i1's, which gets promoted to <N x i8> (or i16 depending
+    // on N). So if the original vector is 16-byte, we may only get an 8-byte
+    // comparison here, which means there will be an extra vector extract
+    // somewhere along the way.
+    ComparisonResult = DAG.getSetCC(
+        DL, VecVT, ComparisonResult,
+        DAG.getSplatBuildVector(VecVT, DL, DAG.getConstant(0, DL, MVT::i64)),
+        ISD::CondCode::SETNE);
+  }
+
+  // Larger vectors don't map directly to this conversion, so to avoid too many
+  // edge cases, we don't apply it here. The conversion will likely still be
+  // applied later via multiple smaller vectors, whose results are concatenated.
+  if (VecVT.getSizeInBits() > 128)
+    return SDValue();
+
+  SDValue VectorBits;
+  if (VecVT == MVT::v16i8) {
+    // v16i8 is a special case, as we need to split it into two halves, perform
+    // the mask+addition twice, and then combine the results.
+    SmallVector<SDValue, 16> MaskConstants;
+    for (unsigned Half = 0; Half < 2; ++Half) {
+      for (unsigned MaskBit = 1; MaskBit <= 128; MaskBit *= 2) {
+        MaskConstants.push_back(DAG.getConstant(MaskBit, DL, MVT::i32));
+      }
+    }
+    SDValue Mask = DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, MaskConstants);
+    SDValue RepresentativeBits =
+        DAG.getNode(ISD::AND, DL, VecVT, ComparisonResult, Mask);
+
+    EVT HalfVT = VecVT.getHalfNumVectorElementsVT(*DAG.getContext());
+    unsigned NumElementsInHalf = HalfVT.getVectorNumElements();
+
+    SDValue LowHalf =
+        DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, RepresentativeBits,
+                    DAG.getConstant(0, DL, MVT::i64));
+    SDValue HighHalf =
+        DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, RepresentativeBits,
+                    DAG.getConstant(NumElementsInHalf, DL, MVT::i64));
+
+    SDValue ReducedLowBits =
+        DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i16, LowHalf);
+    SDValue ReducedHighBits =
+        DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i16, HighHalf);
+
+    SDValue ShiftedHighBits =
+        DAG.getNode(ISD::SHL, DL, MVT::i16, ReducedHighBits,
+                    DAG.getConstant(NumElementsInHalf, DL, MVT::i32));
+    VectorBits =
+        DAG.getNode(ISD::OR, DL, MVT::i16, ShiftedHighBits, ReducedLowBits);
+  } else {
+    SmallVector<SDValue, 16> MaskConstants;
+    unsigned MaxBitMask = 1u << (VecVT.getVectorNumElements() - 1);
+    for (unsigned MaskBit = 1; MaskBit <= MaxBitMask; MaskBit *= 2) {
+      MaskConstants.push_back(DAG.getConstant(MaskBit, DL, MVT::i64));
+    }
+
+    SDValue Mask = DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, MaskConstants);
+    SDValue RepresentativeBits =
+        DAG.getNode(ISD::AND, DL, VecVT, ComparisonResult, Mask);
+    EVT ResultVT = MVT::getIntegerVT(std::max<unsigned>(
+        NumElts, VecVT.getVectorElementType().getSizeInBits()));
+    VectorBits =
+        DAG.getNode(ISD::VECREDUCE_ADD, DL, ResultVT, RepresentativeBits);
+  }
+
+  return VectorBits;
+}
+
 static SDValue performSTORECombine(SDNode *N,
                                    TargetLowering::DAGCombinerInfo &DCI,
                                    SelectionDAG &DAG,
@@ -22156,6 +22282,34 @@
   return true;
 }
 
+static void replaceBoolVectorBitcast(SDNode *N,
+                                     SmallVectorImpl<SDValue> &Results,
+                                     SelectionDAG &DAG) {
+  SDLoc DL(N);
+  SDValue Op = N->getOperand(0);
+  EVT VT = N->getValueType(0);
+  EVT SrcVT = Op.getValueType();
+  assert(SrcVT.isVector() && SrcVT.getVectorElementType() == MVT::i1 &&
+         "Must be bool vector.");
+
+  // Special handling for Clang's __builtin_convertvector. For vectors with <8
+  // elements, it adds a vector concatenation with undef(s). If we encounter
+  // this here, we can skip the concat.
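+  // For illustration (an assumed shape, not taken verbatim from any DAG dump):
+  // a <4 x i1> result can reach this point roughly as
+  //   (bitcast (concat_vectors (setcc ...), undef) to i8)
+  // in which case only the first concat operand carries the comparison bits,
+  // so we can build the bitmask directly from it.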
+  if (Op.getOpcode() == ISD::CONCAT_VECTORS && Op.hasOneUse() &&
+      !Op.getOperand(0).isUndef()) {
+    bool AllUndef = true;
+    for (unsigned I = 1; I < Op.getNumOperands(); ++I)
+      AllUndef &= Op.getOperand(I).isUndef();
+
+    if (AllUndef)
+      Op = Op.getOperand(0);
+  }
+
+  SDValue VectorBits = vectorToScalarBitmask(Op.getNode(), DAG);
+  if (VectorBits)
+    Results.push_back(DAG.getZExtOrTrunc(VectorBits, DL, VT));
+}
+
 void AArch64TargetLowering::ReplaceBITCASTResults(
     SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
   SDLoc DL(N);
@@ -22180,6 +22334,9 @@
     return;
   }
 
+  if (SrcVT.isVector() && SrcVT.getVectorElementType() == MVT::i1)
+    return replaceBoolVectorBitcast(N, Results, DAG);
+
   if (VT != MVT::i16 || (SrcVT != MVT::f16 && SrcVT != MVT::bf16))
     return;
 
diff --git a/llvm/test/CodeGen/AArch64/vec-combine-compare-to-bitmask.ll b/llvm/test/CodeGen/AArch64/vec-combine-compare-to-bitmask.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/vec-combine-compare-to-bitmask.ll
@@ -0,0 +1,274 @@
+; RUN: llc -mtriple=aarch64-apple-darwin -mattr=+neon -verify-machineinstrs < %s | FileCheck %s
+
+; Basic tests from input vector to bitmask
+; IR generated from clang for:
+; __builtin_convertvector + reinterpret_cast
+
+define i16 @convert_to_bitmask16(<16 x i8> %vec) {
+; Bits used in mask
+; CHECK-LABEL: lCPI0_0
+; CHECK-NEXT: .byte 1
+; CHECK-NEXT: .byte 2
+; CHECK-NEXT: .byte 4
+; CHECK-NEXT: .byte 8
+; CHECK-NEXT: .byte 16
+; CHECK-NEXT: .byte 32
+; CHECK-NEXT: .byte 64
+; CHECK-NEXT: .byte 128
+; CHECK-NEXT: .byte 1
+; CHECK-NEXT: .byte 2
+; CHECK-NEXT: .byte 4
+; CHECK-NEXT: .byte 8
+; CHECK-NEXT: .byte 16
+; CHECK-NEXT: .byte 32
+; CHECK-NEXT: .byte 64
+; CHECK-NEXT: .byte 128
+
+; Actual conversion
+; CHECK-LABEL: convert_to_bitmask16
+; CHECK: ; %bb.0:
+; CHECK-NEXT: Lloh0:
+; CHECK-NEXT: adrp x8, lCPI0_0@PAGE
+; CHECK-NEXT: cmeq.16b v0, v0, #0
+; CHECK-NEXT: Lloh1:
+; CHECK-NEXT: ldr q1, [x8, lCPI0_0@PAGEOFF]
+; CHECK-NEXT: bic.16b v0, v1, v0
+; CHECK-NEXT: ext.16b v1, v0, v0, #8
+; CHECK-NEXT: addv.8b b0, v0
+; CHECK-NEXT: addv.8b b1, v1
+; CHECK-NEXT: fmov w9, s0
+; CHECK-NEXT: fmov w8, s1
+; CHECK-NEXT: orr w0, w9, w8, lsl #8
+; CHECK-NEXT: ret
+
+  %cmp_result = icmp ne <16 x i8> %vec, zeroinitializer
+  %bitmask = bitcast <16 x i1> %cmp_result to i16
+  ret i16 %bitmask
+}
+
+define i16 @convert_to_bitmask8(<8 x i16> %vec) {
+; CHECK-LABEL: lCPI1_0:
+; CHECK-NEXT: .short 1
+; CHECK-NEXT: .short 2
+; CHECK-NEXT: .short 4
+; CHECK-NEXT: .short 8
+; CHECK-NEXT: .short 16
+; CHECK-NEXT: .short 32
+; CHECK-NEXT: .short 64
+; CHECK-NEXT: .short 128
+
+; CHECK-LABEL: convert_to_bitmask8
+; CHECK: ; %bb.0:
+; CHECK-NEXT: Lloh2:
+; CHECK-NEXT: adrp x8, lCPI1_0@PAGE
+; CHECK-NEXT: cmeq.8h v0, v0, #0
+; CHECK-NEXT: Lloh3:
+; CHECK-NEXT: ldr q1, [x8, lCPI1_0@PAGEOFF]
+; CHECK-NEXT: bic.16b v0, v1, v0
+; CHECK-NEXT: addv.8h h0, v0
+; CHECK-NEXT: fmov w8, s0
+; CHECK-NEXT: and w0, w8, #0xff
+; CHECK-NEXT: ret
+
+  %cmp_result = icmp ne <8 x i16> %vec, zeroinitializer
+  %bitmask = bitcast <8 x i1> %cmp_result to i8
+  %extended_bitmask = zext i8 %bitmask to i16
+  ret i16 %extended_bitmask
+}
+
+define i4 @convert_to_bitmask4(<4 x i32> %vec) {
+; CHECK-LABEL: lCPI2_0:
+; CHECK-NEXT: .long 1
+; CHECK-NEXT: .long 2
+; CHECK-NEXT: .long 4
+; CHECK-NEXT: .long 8
+
+; CHECK-LABEL: convert_to_bitmask4
+; CHECK: ; %bb.0:
+; CHECK-NEXT: Lloh4:
+; CHECK-NEXT: adrp x8, lCPI2_0@PAGE
+; CHECK-NEXT: cmeq.4s v0, v0, #0
+; CHECK-NEXT: Lloh5:
+; CHECK-NEXT: ldr q1, [x8, lCPI2_0@PAGEOFF]
+; CHECK-NEXT: bic.16b v0, v1, v0
+; CHECK-NEXT: addv.4s s0, v0
+; CHECK-NEXT: fmov w0, s0
+; CHECK-NEXT: ret
+
+  %cmp_result = icmp ne <4 x i32> %vec, zeroinitializer
+  %bitmask = bitcast <4 x i1> %cmp_result to i4
+  ret i4 %bitmask
+}
+
+define i8 @convert_to_bitmask2(<2 x i64> %vec) {
+; CHECK-LABEL: lCPI3_0:
+; CHECK-NEXT: .quad 1
+; CHECK-NEXT: .quad 2
+
+; CHECK-LABEL: convert_to_bitmask2
+; CHECK: ; %bb.0:
+; CHECK-NEXT: Lloh6:
+; CHECK-NEXT: adrp x8, lCPI3_0@PAGE
+; CHECK-NEXT: cmeq.2d v0, v0, #0
+; CHECK-NEXT: Lloh7:
+; CHECK-NEXT: ldr q1, [x8, lCPI3_0@PAGEOFF]
+; CHECK-NEXT: bic.16b v0, v1, v0
+; CHECK-NEXT: addp.2d d0, v0
+; CHECK-NEXT: fmov x8, d0
+; CHECK-NEXT: and w0, w8, #0x3
+; CHECK-NEXT: ret
+
+  %cmp_result = icmp ne <2 x i64> %vec, zeroinitializer
+  %bitmask = bitcast <2 x i1> %cmp_result to i2
+  %extended_bitmask = zext i2 %bitmask to i8
+  ret i8 %extended_bitmask
+}
+
+; Clang's __builtin_convertvector adds an undef vector concat for vectors with <8 elements.
+define i8 @clang_builtins_undef_concat_convert_to_bitmask4(<4 x i32> %vec) {
+; CHECK-LABEL: lCPI4_0:
+; CHECK-NEXT: .long 1
+; CHECK-NEXT: .long 2
+; CHECK-NEXT: .long 4
+; CHECK-NEXT: .long 8
+
+; CHECK-LABEL: clang_builtins_undef_concat_convert_to_bitmask4
+; CHECK: ; %bb.0:
+; CHECK-NEXT: Lloh8:
+; CHECK-NEXT: adrp x8, lCPI4_0@PAGE
+; CHECK-NEXT: cmeq.4s v0, v0, #0
+; CHECK-NEXT: Lloh9:
+; CHECK-NEXT: ldr q1, [x8, lCPI4_0@PAGEOFF]
+; CHECK-NEXT: bic.16b v0, v1, v0
+; CHECK-NEXT: addv.4s s0, v0
+; CHECK-NEXT: fmov w0, s0
+; CHECK-NEXT: ret
+
+  %cmp_result = icmp ne <4 x i32> %vec, zeroinitializer
+  %vector_pad = shufflevector <4 x i1> %cmp_result, <4 x i1> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  %bitmask = bitcast <8 x i1> %vector_pad to i8
+  ret i8 %bitmask
+}
+
+
+define i4 @convert_to_bitmask_no_compare(<4 x i32> %vec1, <4 x i32> %vec2) {
+; CHECK-LABEL: lCPI5_0:
+; CHECK-NEXT: .short 1
+; CHECK-NEXT: .short 2
+; CHECK-NEXT: .short 4
+; CHECK-NEXT: .short 8
+
+; CHECK-LABEL: convert_to_bitmask_no_compare
+; CHECK: ; %bb.0:
+; CHECK-NEXT: movi.4h v2, #1
+; CHECK-NEXT: Lloh10:
+; CHECK-NEXT: adrp x8, lCPI5_0@PAGE
+; CHECK-NEXT: and.16b v0, v0, v1
+; CHECK-NEXT: xtn.4h v0, v0
+; CHECK-NEXT: Lloh11:
+; CHECK-NEXT: ldr d1, [x8, lCPI5_0@PAGEOFF]
+; CHECK-NEXT: and.8b v0, v0, v2
+; CHECK-NEXT: cmeq.4h v0, v0, #0
+; CHECK-NEXT: bic.8b v0, v1, v0
+; CHECK-NEXT: addv.4h h0, v0
+; CHECK-NEXT: fmov w0, s0
+; CHECK-NEXT: ret
+
+  %cmp = and <4 x i32> %vec1, %vec2
+  %trunc = trunc <4 x i32> %cmp to <4 x i1>
+  %bitmask = bitcast <4 x i1> %trunc to i4
+  ret i4 %bitmask
+}
+
+define i4 @convert_to_bitmask_with_compare_chain(<4 x i32> %vec1, <4 x i32> %vec2) {
+; CHECK-LABEL: lCPI6_0:
+; CHECK-NEXT: .long 1
+; CHECK-NEXT: .long 2
+; CHECK-NEXT: .long 4
+; CHECK-NEXT: .long 8
+
+; CHECK-LABEL: convert_to_bitmask_with_compare_chain
+; CHECK: ; %bb.0:
+; CHECK-NEXT: Lloh12:
+; CHECK-NEXT: adrp x8, lCPI6_0@PAGE
+; CHECK-NEXT: cmeq.4s v2, v0, #0
+; CHECK-NEXT: cmeq.4s v0, v0, v1
+; CHECK-NEXT: Lloh13:
+; CHECK-NEXT: ldr q1, [x8, lCPI6_0@PAGEOFF]
+; CHECK-NEXT: bic.16b v0, v0, v2
+; CHECK-NEXT: and.16b v0, v0, v1
+; CHECK-NEXT: addv.4s s0, v0
+; CHECK-NEXT: fmov w0, s0
+; CHECK-NEXT: ret
+
+  %cmp1 = icmp ne <4 x i32> %vec1, zeroinitializer
+  %cmp2 = icmp eq <4 x i32> %vec1, %vec2
+  %cmp3 = and <4 x i1> %cmp1, %cmp2
+  %bitmask = bitcast <4 x i1> %cmp3 to i4
+  ret i4 %bitmask
+}
+
+define i4 @convert_to_bitmask_with_defensive_compare_for_bad_chain(<4 x i32> %vec1, <4 x i32> %vec2) {
+; CHECK-LABEL: lCPI7_0:
+; CHECK-NEXT: .short 1
+; CHECK-NEXT: .short 2
+; CHECK-NEXT: .short 4
+; CHECK-NEXT: .short 8
+
+; CHECK-LABEL: convert_to_bitmask_with_defensive_compare_for_bad_chain
+; CHECK: ; %bb.0:
+; CHECK-NEXT: cmeq.4s v0, v0, #0
+; CHECK-NEXT: Lloh14:
+; CHECK-NEXT: adrp x8, lCPI7_0@PAGE
+; CHECK-NEXT: movi.4h v2, #1
+; CHECK-NEXT: bic.16b v0, v1, v0
+; CHECK-NEXT: xtn.4h v0, v0
+; CHECK-NEXT: Lloh15:
+; CHECK-NEXT: ldr d1, [x8, lCPI7_0@PAGEOFF]
+; CHECK-NEXT: and.8b v0, v0, v2
+; CHECK-NEXT: cmeq.4h v0, v0, #0
+; CHECK-NEXT: bic.8b v0, v1, v0
+; CHECK-NEXT: addv.4h h0, v0
+; CHECK-NEXT: fmov w0, s0
+; CHECK-NEXT: ret
+
+  %cmp1 = icmp ne <4 x i32> %vec1, zeroinitializer
+  %trunc_vec = trunc <4 x i32> %vec2 to <4 x i1>
+  %and_res = and <4 x i1> %cmp1, %trunc_vec
+  %bitmask = bitcast <4 x i1> %and_res to i4
+  ret i4 %bitmask
+}
+
+; TODO(lawben): Change this in a follow-up patch to D145301, as truncating stores fix this.
+; Larger vector types don't map directly.
+define i8 @no_convert_large_vector(<8 x i32> %vec) {
+; CHECK-LABEL: no_convert_large_vector:
+; CHECK: cmeq.4s v1, v1, #0
+; CHECK-NOT: addv
+
+  %cmp_result = icmp ne <8 x i32> %vec, zeroinitializer
+  %bitmask = bitcast <8 x i1> %cmp_result to i8
+  ret i8 %bitmask
+}
+
+define i16 @no_convert_bad_concat(<4 x i32> %vec) {
+; CHECK-LABEL: no_convert_bad_concat:
+; CHECK: cmtst.4s v0, v0, v0
+; CHECK-NOT: addv.4s s0, v0
+
+  %cmp_result = icmp ne <4 x i32> %vec, zeroinitializer
+  %vector_pad = shufflevector <4 x i1> poison, <4 x i1> %cmp_result, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  %bitmask = bitcast <8 x i1> %vector_pad to i8
+  %extended_bitmask = zext i8 %bitmask to i16
+  ret i16 %extended_bitmask
+}
+
+define <8 x i1> @no_convert_without_direct_bitcast(<8 x i16> %vec) {
+; CHECK-LABEL: no_convert_without_direct_bitcast:
+; CHECK: cmtst.8h v0, v0, v0
+; CHECK-NOT: addv.4s s0, v0
+
+  %cmp_result = icmp ne <8 x i16> %vec, zeroinitializer
+  ret <8 x i1> %cmp_result
+}
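+
+; For reference, a minimal sketch (not compiled as part of this test) of the
+; kind of C++ source the header comment refers to: it assumes Clang's
+; ext_vector_type extension, and the type and function names are made up for
+; illustration only.
+;
+;   typedef __attribute__((ext_vector_type(16))) char char16;
+;   typedef __attribute__((ext_vector_type(16))) bool bool16;
+;
+;   unsigned short to_bitmask(char16 vec) {
+;     // Compare, convert the lane-wise result to a bool vector, then
+;     // reinterpret its 16 bits as a scalar bitmask.
+;     bool16 cmp = __builtin_convertvector(vec != 0, bool16);
+;     return reinterpret_cast<unsigned short &>(cmp);
+;   }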