diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -45696,28 +45696,49 @@
   llvm_unreachable("Impossible");
 }
 
-/// Try to map a 128-bit or larger integer comparison to vector instructions
-/// before type legalization splits it up into chunks.
+/// Try to map a 128-bit or larger integer comparison or OR vector reduction
+/// to vector instructions before type legalization splits it up into chunks.
 static SDValue combineVectorSizedSetCCEquality(SDNode *SetCC, SelectionDAG &DAG,
                                                const X86Subtarget &Subtarget) {
   ISD::CondCode CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get();
   assert((CC == ISD::SETNE || CC == ISD::SETEQ) && "Bad comparison predicate");
 
-  // We're looking for an oversized integer equality comparison.
+  // We're looking for an oversized integer equality comparison or an
+  // OR vector reduction compared against zero.
   SDValue X = SetCC->getOperand(0);
   SDValue Y = SetCC->getOperand(1);
   EVT OpVT = X.getValueType();
+  EVT VT = SetCC->getValueType(0);
+  bool IsCmpWithZero = isNullConstant(Y);
   unsigned OpSize = OpVT.getSizeInBits();
-  if (!OpVT.isScalarInteger() || OpSize < 128)
+  if (!OpVT.isScalarInteger())
     return SDValue();
 
+  // If we're comparing a smaller integer, see if it came from an OR vector
+  // reduction compared against zero; otherwise bail out.
+  SDLoc DL(SetCC);
+  bool IsReduction = false;
+  if (OpSize < 128) {
+    if (!IsCmpWithZero || X.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
+      return SDValue();
+    ISD::NodeType BinOp;
+    SDValue Match = DAG.matchBinOpReduction(X.getNode(), BinOp, {ISD::OR});
+    if (!Match || Match.getValueSizeInBits() < 128)
+      return SDValue();
+    // Adjust the comparison values to the reduction's vector sources.
+    IsReduction = true;
+    OpSize = Match.getValueSizeInBits();
+    X = Match;
+    Y = DAG.getConstant(0, DL, X.getValueType());
+  }
+
   // Ignore a comparison with zero because that gets special treatment in
   // EmitTest(). But make an exception for the special case of a pair of
   // logically-combined vector-sized operands compared to zero. This pattern may
   // be generated by the memcmp expansion pass with oversized integer compares
   // (see PR33325).
-  bool IsOrXorXorTreeCCZero = isNullConstant(Y) && isOrXorXorTree(X);
-  if (isNullConstant(Y) && !IsOrXorXorTreeCCZero)
+  bool IsOrXorXorTreeCCZero = IsCmpWithZero && isOrXorXorTree(X);
+  if (IsCmpWithZero && !IsOrXorXorTreeCCZero && !IsReduction)
     return SDValue();
 
   // Don't perform this combine if constructing the vector will be expensive.
@@ -45726,19 +45747,15 @@
     return isa<ConstantSDNode>(X) || X.getValueType().isVector() ||
            X.getOpcode() == ISD::LOAD;
   };
-  if ((!IsVectorBitCastCheap(X) || !IsVectorBitCastCheap(Y)) &&
-      !IsOrXorXorTreeCCZero)
+  if (!IsOrXorXorTreeCCZero && !IsReduction &&
+      (!IsVectorBitCastCheap(X) || !IsVectorBitCastCheap(Y)))
     return SDValue();
 
-  EVT VT = SetCC->getValueType(0);
-  SDLoc DL(SetCC);
-  bool HasAVX = Subtarget.hasAVX();
-
   // Use XOR (plus OR) and PTEST after SSE4.1 for 128/256-bit operands.
   // Use PCMPNEQ (plus OR) and KORTEST for 512-bit operands.
   // Otherwise use PCMPEQ (plus AND) and mask testing.
   if ((OpSize == 128 && Subtarget.hasSSE2()) ||
-      (OpSize == 256 && HasAVX) ||
+      (OpSize == 256 && Subtarget.hasAVX()) ||
       (OpSize == 512 && Subtarget.useAVX512Regs())) {
     bool HasPT = Subtarget.hasSSE41();
diff --git a/llvm/test/CodeGen/X86/pr45378.ll b/llvm/test/CodeGen/X86/pr45378.ll
--- a/llvm/test/CodeGen/X86/pr45378.ll
+++ b/llvm/test/CodeGen/X86/pr45378.ll
@@ -9,43 +9,29 @@
 declare i64 @llvm.experimental.vector.reduce.or.v2i64(<2 x i64>)
 
 define i1 @parseHeaders(i64 * %ptr) nounwind {
-; SSE-LABEL: parseHeaders:
-; SSE:       # %bb.0:
-; SSE-NEXT:    movdqu (%rdi), %xmm0
-; SSE-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; SSE-NEXT:    por %xmm0, %xmm1
-; SSE-NEXT:    movq %xmm1, %rax
-; SSE-NEXT:    testq %rax, %rax
-; SSE-NEXT:    sete %al
-; SSE-NEXT:    retq
-;
-; AVX1-LABEL: parseHeaders:
-; AVX1:       # %bb.0:
-; AVX1-NEXT:    vmovdqu (%rdi), %xmm0
-; AVX1-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX1-NEXT:    vpor %xmm1, %xmm0, %xmm0
-; AVX1-NEXT:    vmovq %xmm0, %rax
-; AVX1-NEXT:    testq %rax, %rax
-; AVX1-NEXT:    sete %al
-; AVX1-NEXT:    retq
+; SSE2-LABEL: parseHeaders:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movdqu (%rdi), %xmm0
+; SSE2-NEXT:    pxor %xmm1, %xmm1
+; SSE2-NEXT:    pcmpeqb %xmm0, %xmm1
+; SSE2-NEXT:    pmovmskb %xmm1, %eax
+; SSE2-NEXT:    cmpl $65535, %eax # imm = 0xFFFF
+; SSE2-NEXT:    sete %al
+; SSE2-NEXT:    retq
 ;
-; AVX2-LABEL: parseHeaders:
-; AVX2:       # %bb.0:
-; AVX2-NEXT:    vpbroadcastq 8(%rdi), %xmm0
-; AVX2-NEXT:    vpor (%rdi), %xmm0, %xmm0
-; AVX2-NEXT:    vmovq %xmm0, %rax
-; AVX2-NEXT:    testq %rax, %rax
-; AVX2-NEXT:    sete %al
-; AVX2-NEXT:    retq
+; SSE41-LABEL: parseHeaders:
+; SSE41:       # %bb.0:
+; SSE41-NEXT:    movdqu (%rdi), %xmm0
+; SSE41-NEXT:    ptest %xmm0, %xmm0
+; SSE41-NEXT:    sete %al
+; SSE41-NEXT:    retq
 ;
-; AVX512-LABEL: parseHeaders:
-; AVX512:       # %bb.0:
-; AVX512-NEXT:    vpbroadcastq 8(%rdi), %xmm0
-; AVX512-NEXT:    vpor (%rdi), %xmm0, %xmm0
-; AVX512-NEXT:    vmovq %xmm0, %rax
-; AVX512-NEXT:    testq %rax, %rax
-; AVX512-NEXT:    sete %al
-; AVX512-NEXT:    retq
+; AVX-LABEL: parseHeaders:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vmovdqu (%rdi), %xmm0
+; AVX-NEXT:    vptest %xmm0, %xmm0
+; AVX-NEXT:    sete %al
+; AVX-NEXT:    retq
   %vptr = bitcast i64 * %ptr to <2 x i64> *
   %vload = load <2 x i64>, <2 x i64> * %vptr, align 8
   %vreduce = call i64 @llvm.experimental.vector.reduce.or.v2i64(<2 x i64> %vload)
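
For reference, a minimal LLVM IR sketch (hypothetical function and value names, not part of this patch) of the kind of reduction-is-zero check the combine now catches, modelled on the parseHeaders test above but using the v4i32 flavour of the reduction intrinsic. With SSE4.1 this should now select to movdqu + ptest + sete (vptest under AVX) instead of scalarizing the OR reduction:

; Hypothetical example, assuming the same intrinsic naming scheme as
; pr45378.ll: an OR reduction of a loaded vector compared against zero.
declare i32 @llvm.experimental.vector.reduce.or.v4i32(<4 x i32>)

define i1 @or_reduce_eq_zero(<4 x i32>* %p) nounwind {
  %v = load <4 x i32>, <4 x i32>* %p, align 4
  %r = call i32 @llvm.experimental.vector.reduce.or.v4i32(<4 x i32> %v)
  %c = icmp eq i32 %r, 0
  ret i1 %c
}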