diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -45532,12 +45532,22 @@
   assert((CC == ISD::SETNE || CC == ISD::SETEQ) && "Bad comparison predicate");
 
   // We're looking for an oversized integer equality comparison.
+  bool IsReduction = false;
   SDValue X = SetCC->getOperand(0);
   SDValue Y = SetCC->getOperand(1);
-  EVT OpVT = X.getValueType();
-  unsigned OpSize = OpVT.getSizeInBits();
-  if (!OpVT.isScalarInteger() || OpSize < 128)
-    return SDValue();
+  unsigned OpSize = X.getValueSizeInBits();
+  if (!X.getValueType().isScalarInteger() || OpSize < 128) {
+    // See if we can find a horizontal OR reduction, compared against zero.
+    if (X.getOpcode() != ISD::EXTRACT_VECTOR_ELT || !isNullConstant(Y))
+      return SDValue();
+    ISD::NodeType BinOp;
+    SDValue Match = DAG.matchBinOpReduction(X.getNode(), BinOp, {ISD::OR});
+    if (!Match || Match.getValueSizeInBits() < 128)
+      return SDValue();
+    OpSize = Match.getValueSizeInBits();
+    IsReduction = true;
+    X = Match;
+  }
 
   // Ignore a comparison with zero because that gets special treatment in
   // EmitTest(). But make an exception for the special case of a pair of
@@ -45545,7 +45555,7 @@
   // be generated by the memcmp expansion pass with oversized integer compares
   // (see PR33325).
   bool IsOrXorXorTreeCCZero = isNullConstant(Y) && isOrXorXorTree(X);
-  if (isNullConstant(Y) && !IsOrXorXorTreeCCZero)
+  if (isNullConstant(Y) && !IsOrXorXorTreeCCZero && !IsReduction)
     return SDValue();
 
   // Don't perform this combine if constructing the vector will be expensive.
@@ -45633,8 +45643,12 @@
       // MOVMSK.
       Cmp = emitOrXorXorTree(X, DL, DAG, VecVT, CmpVT, HasPT, ScalarToVector);
     } else {
+      // For reductions, we were comparing against a scalar zero, but now it
+      // needs to be compared to a vector zero.
       SDValue VecX = ScalarToVector(X);
-      SDValue VecY = ScalarToVector(Y);
+      SDValue VecY = IsReduction ? getZeroVector(VecX.getSimpleValueType(),
+                                                 Subtarget, DAG, DL)
+                                 : ScalarToVector(Y);
       if (VecVT != CmpVT) {
         Cmp = DAG.getSetCC(DL, CmpVT, VecX, VecY, ISD::SETNE);
       } else if (HasPT) {
diff --git a/llvm/test/CodeGen/X86/pr45378.ll b/llvm/test/CodeGen/X86/pr45378.ll
--- a/llvm/test/CodeGen/X86/pr45378.ll
+++ b/llvm/test/CodeGen/X86/pr45378.ll
@@ -9,43 +9,29 @@
 declare i64 @llvm.experimental.vector.reduce.or.v2i64(<2 x i64>)
 
 define i1 @parseHeaders(i64 * %ptr) nounwind {
-; SSE-LABEL: parseHeaders:
-; SSE:       # %bb.0:
-; SSE-NEXT:    movdqu (%rdi), %xmm0
-; SSE-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; SSE-NEXT:    por %xmm0, %xmm1
-; SSE-NEXT:    movq %xmm1, %rax
-; SSE-NEXT:    testq %rax, %rax
-; SSE-NEXT:    sete %al
-; SSE-NEXT:    retq
-;
-; AVX1-LABEL: parseHeaders:
-; AVX1:       # %bb.0:
-; AVX1-NEXT:    vmovdqu (%rdi), %xmm0
-; AVX1-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX1-NEXT:    vpor %xmm1, %xmm0, %xmm0
-; AVX1-NEXT:    vmovq %xmm0, %rax
-; AVX1-NEXT:    testq %rax, %rax
-; AVX1-NEXT:    sete %al
-; AVX1-NEXT:    retq
+; SSE2-LABEL: parseHeaders:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movdqu (%rdi), %xmm0
+; SSE2-NEXT:    pxor %xmm1, %xmm1
+; SSE2-NEXT:    pcmpeqb %xmm0, %xmm1
+; SSE2-NEXT:    pmovmskb %xmm1, %eax
+; SSE2-NEXT:    cmpl $65535, %eax # imm = 0xFFFF
+; SSE2-NEXT:    sete %al
+; SSE2-NEXT:    retq
 ;
-; AVX2-LABEL: parseHeaders:
-; AVX2:       # %bb.0:
-; AVX2-NEXT:    vpbroadcastq 8(%rdi), %xmm0
-; AVX2-NEXT:    vpor (%rdi), %xmm0, %xmm0
-; AVX2-NEXT:    vmovq %xmm0, %rax
-; AVX2-NEXT:    testq %rax, %rax
-; AVX2-NEXT:    sete %al
-; AVX2-NEXT:    retq
+; SSE41-LABEL: parseHeaders:
+; SSE41:       # %bb.0:
+; SSE41-NEXT:    movdqu (%rdi), %xmm0
+; SSE41-NEXT:    ptest %xmm0, %xmm0
+; SSE41-NEXT:    sete %al
+; SSE41-NEXT:    retq
 ;
-; AVX512-LABEL: parseHeaders:
-; AVX512:       # %bb.0:
-; AVX512-NEXT:    vpbroadcastq 8(%rdi), %xmm0
-; AVX512-NEXT:    vpor (%rdi), %xmm0, %xmm0
-; AVX512-NEXT:    vmovq %xmm0, %rax
-; AVX512-NEXT:    testq %rax, %rax
-; AVX512-NEXT:    sete %al
-; AVX512-NEXT:    retq
+; AVX-LABEL: parseHeaders:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vmovdqu (%rdi), %xmm0
+; AVX-NEXT:    vptest %xmm0, %xmm0
+; AVX-NEXT:    sete %al
+; AVX-NEXT:    retq
   %vptr = bitcast i64 * %ptr to <2 x i64> *
   %vload = load <2 x i64>, <2 x i64> * %vptr, align 8
   %vreduce = call i64 @llvm.experimental.vector.reduce.or.v2i64(<2 x i64> %vload)
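For reference, the input shape this combine now handles is a horizontal OR reduction whose scalar result is compared for equality against zero, as exercised by parseHeaders above. A minimal illustrative sketch in LLVM IR (the function name and signature here are hypothetical, not part of the patch) looks like:

  declare i64 @llvm.experimental.vector.reduce.or.v2i64(<2 x i64>)

  ; Returns true when every element of %v is zero.
  define i1 @is_all_zero(<2 x i64> %v) nounwind {
    %r = call i64 @llvm.experimental.vector.reduce.or.v2i64(<2 x i64> %v)
    %c = icmp eq i64 %r, 0
    ret i1 %c
  }

With the DAG combine above, the reduce-or feeding an icmp-eq-0 should lower to a single ptest (SSE4.1) or vptest (AVX) of the loaded vector, or a pcmpeqb/pmovmskb compare on plain SSE2, instead of extracting the reduction into a general-purpose register and using testq, as the updated CHECK lines show.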