diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -21346,55 +21346,67 @@
   return true;
 }
 
-// Check whether an OR'd tree is PTEST-able, or if we can fallback to
+// Check whether an OR'd reduction tree is PTEST-able, or if we can fallback to
 // CMP(MOVMSK(PCMPEQB(X,0))).
 static SDValue LowerVectorAllZeroTest(SDValue Op, ISD::CondCode CC,
                                       const X86Subtarget &Subtarget,
                                       SelectionDAG &DAG, SDValue &X86CC) {
-  assert(Op.getOpcode() == ISD::OR && "Only check OR'd tree.");
-
-  if (!Subtarget.hasSSE2() || !Op->hasOneUse())
+  if (!Subtarget.hasSSE2())
     return SDValue();
 
+  bool UsePTEST = Subtarget.hasSSE41();
+
+  auto LowerCmpZero = [&](SDValue V, const SDLoc &DL) {
+    X86CC = DAG.getTargetConstant(CC == ISD::SETEQ ? X86::COND_E : X86::COND_NE,
+                                  DL, MVT::i8);
+
+    if (UsePTEST)
+      return DAG.getNode(X86ISD::PTEST, DL, MVT::i32, V, V);
+
+    SDValue Result = DAG.getNode(X86ISD::PCMPEQ, DL, MVT::v16i8, V,
+                                 getZeroVector(MVT::v16i8, Subtarget, DAG, DL));
+    Result = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Result);
+    return DAG.getNode(X86ISD::CMP, DL, MVT::i32, Result,
+                       DAG.getConstant(0xFFFF, DL, MVT::i32));
+  };
+
   SmallVector<SDValue, 8> VecIns;
-  if (!matchScalarReduction(Op, ISD::OR, VecIns))
-    return SDValue();
+  if (Op.getOpcode() == ISD::OR && matchScalarReduction(Op, ISD::OR, VecIns)) {
+    if (!Op->hasOneUse())
+      return SDValue();
 
-  // Quit if not 128/256-bit vector.
-  EVT VT = VecIns[0].getValueType();
-  if (!VT.is128BitVector() && !VT.is256BitVector())
-    return SDValue();
+    // Quit if not 128/256-bit vector.
+    EVT VT = VecIns[0].getValueType();
+    if (!VT.is128BitVector() && !VT.is256BitVector())
+      return SDValue();
 
-  SDLoc DL(Op);
-  bool UsePTEST = Subtarget.hasSSE41();
-  MVT TestVT =
-      VT.is128BitVector() ? (UsePTEST ? MVT::v2i64 : MVT::v16i8) : MVT::v4i64;
+    SDLoc DL(Op);
+    MVT TestVT =
+        VT.is128BitVector() ? (UsePTEST ? MVT::v2i64 : MVT::v16i8) : MVT::v4i64;
 
-  // Cast all vectors into TestVT for PTEST/PCMPEQ.
-  for (unsigned i = 0, e = VecIns.size(); i < e; ++i)
-    VecIns[i] = DAG.getBitcast(TestVT, VecIns[i]);
+    // Cast all vectors into TestVT for PTEST/PCMPEQ.
+    for (unsigned i = 0, e = VecIns.size(); i < e; ++i)
+      VecIns[i] = DAG.getBitcast(TestVT, VecIns[i]);
 
-  // If more than one full vector is evaluated, OR them first before PTEST.
-  for (unsigned Slot = 0, e = VecIns.size(); e - Slot > 1; Slot += 2, e += 1) {
-    // Each iteration will OR 2 nodes and append the result until there is only
-    // 1 node left, i.e. the final OR'd value of all vectors.
-    SDValue LHS = VecIns[Slot];
-    SDValue RHS = VecIns[Slot + 1];
-    VecIns.push_back(DAG.getNode(ISD::OR, DL, TestVT, LHS, RHS));
-  }
+    // If more than one full vector is evaluated, OR them first before PTEST.
+    for (unsigned Slot = 0, e = VecIns.size(); e - Slot > 1;
+         Slot += 2, e += 1) {
+      // Each iteration will OR 2 nodes and append the result until there is
+      // only 1 node left, i.e. the final OR'd value of all vectors.
+      SDValue LHS = VecIns[Slot];
+      SDValue RHS = VecIns[Slot + 1];
+      VecIns.push_back(DAG.getNode(ISD::OR, DL, TestVT, LHS, RHS));
+    }
 
-  X86CC = DAG.getTargetConstant(CC == ISD::SETEQ ? X86::COND_E : X86::COND_NE,
-                                DL, MVT::i8);
+    return LowerCmpZero(VecIns.back(), DL);
+  }
 
-  if (UsePTEST)
-    return DAG.getNode(X86ISD::PTEST, DL, MVT::i32, VecIns.back(),
-                       VecIns.back());
+  ISD::NodeType BinOp;
+  if (Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT)
+    if (SDValue Match = DAG.matchBinOpReduction(Op.getNode(), BinOp, {ISD::OR}))
+      return LowerCmpZero(Match, SDLoc(Op));
 
-  SDValue Result = DAG.getNode(X86ISD::PCMPEQ, DL, MVT::v16i8, VecIns.back(),
-                               getZeroVector(MVT::v16i8, Subtarget, DAG, DL));
-  Result = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Result);
-  return DAG.getNode(X86ISD::CMP, DL, MVT::i32, Result,
-                     DAG.getConstant(0xFFFF, DL, MVT::i32));
+  return SDValue();
 }
 
 /// return true if \c Op has a use that doesn't just read flags.
@@ -22544,11 +22556,9 @@
 
   // Try to use PTEST/PMOVMSKB for a tree ORs equality compared with 0.
   // TODO: We could do AND tree with all 1s as well by using the C flag.
-  if (Op0.getOpcode() == ISD::OR && isNullConstant(Op1) &&
-      (CC == ISD::SETEQ || CC == ISD::SETNE)) {
+  if (isNullConstant(Op1) && (CC == ISD::SETEQ || CC == ISD::SETNE))
     if (SDValue CmpZ = LowerVectorAllZeroTest(Op0, CC, Subtarget, DAG, X86CC))
       return CmpZ;
-  }
 
   // Try to lower using KORTEST or KTEST.
   if (SDValue Test = EmitAVX512Test(Op0, Op1, CC, dl, DAG, Subtarget, X86CC))
diff --git a/llvm/test/CodeGen/X86/pr45378.ll b/llvm/test/CodeGen/X86/pr45378.ll
--- a/llvm/test/CodeGen/X86/pr45378.ll
+++ b/llvm/test/CodeGen/X86/pr45378.ll
@@ -9,43 +9,29 @@
 declare i64 @llvm.experimental.vector.reduce.or.v2i64(<2 x i64>)
 
 define i1 @parseHeaders(i64 * %ptr) nounwind {
-; SSE-LABEL: parseHeaders:
-; SSE:       # %bb.0:
-; SSE-NEXT:    movdqu (%rdi), %xmm0
-; SSE-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; SSE-NEXT:    por %xmm0, %xmm1
-; SSE-NEXT:    movq %xmm1, %rax
-; SSE-NEXT:    testq %rax, %rax
-; SSE-NEXT:    sete %al
-; SSE-NEXT:    retq
-;
-; AVX1-LABEL: parseHeaders:
-; AVX1:       # %bb.0:
-; AVX1-NEXT:    vmovdqu (%rdi), %xmm0
-; AVX1-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; AVX1-NEXT:    vpor %xmm1, %xmm0, %xmm0
-; AVX1-NEXT:    vmovq %xmm0, %rax
-; AVX1-NEXT:    testq %rax, %rax
-; AVX1-NEXT:    sete %al
-; AVX1-NEXT:    retq
+; SSE2-LABEL: parseHeaders:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movdqu (%rdi), %xmm0
+; SSE2-NEXT:    pxor %xmm1, %xmm1
+; SSE2-NEXT:    pcmpeqb %xmm0, %xmm1
+; SSE2-NEXT:    pmovmskb %xmm1, %eax
+; SSE2-NEXT:    cmpl $65535, %eax # imm = 0xFFFF
+; SSE2-NEXT:    sete %al
+; SSE2-NEXT:    retq
 ;
-; AVX2-LABEL: parseHeaders:
-; AVX2:       # %bb.0:
-; AVX2-NEXT:    vpbroadcastq 8(%rdi), %xmm0
-; AVX2-NEXT:    vpor (%rdi), %xmm0, %xmm0
-; AVX2-NEXT:    vmovq %xmm0, %rax
-; AVX2-NEXT:    testq %rax, %rax
-; AVX2-NEXT:    sete %al
-; AVX2-NEXT:    retq
+; SSE41-LABEL: parseHeaders:
+; SSE41:       # %bb.0:
+; SSE41-NEXT:    movdqu (%rdi), %xmm0
+; SSE41-NEXT:    ptest %xmm0, %xmm0
+; SSE41-NEXT:    sete %al
+; SSE41-NEXT:    retq
 ;
-; AVX512-LABEL: parseHeaders:
-; AVX512:       # %bb.0:
-; AVX512-NEXT:    vpbroadcastq 8(%rdi), %xmm0
-; AVX512-NEXT:    vpor (%rdi), %xmm0, %xmm0
-; AVX512-NEXT:    vmovq %xmm0, %rax
-; AVX512-NEXT:    testq %rax, %rax
-; AVX512-NEXT:    sete %al
-; AVX512-NEXT:    retq
+; AVX-LABEL: parseHeaders:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vmovdqu (%rdi), %xmm0
+; AVX-NEXT:    vptest %xmm0, %xmm0
+; AVX-NEXT:    sete %al
+; AVX-NEXT:    retq
   %vptr = bitcast i64 * %ptr to <2 x i64> *
   %vload = load <2 x i64>, <2 x i64> * %vptr, align 8
   %vreduce = call i64 @llvm.experimental.vector.reduce.or.v2i64(<2 x i64> %vload)
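
For reference: by the time LowerVectorAllZeroTest runs, the llvm.experimental.vector.reduce.or intrinsic in the test above has been expanded into a shuffle+or pyramid, and the new EXTRACT_VECTOR_ELT path matches that pyramid via DAG.matchBinOpReduction. A minimal hand-expanded sketch of the same v2i64 reduction (the function name is illustrative, not taken from pr45378.ll) that should likewise now lower to ptest/sete on SSE41+ instead of movq+testq:

  ; Hand-expanded equivalent of the v2i64 OR reduction above:
  ; or the high and low halves, extract lane 0, compare with zero.
  define i1 @allzero_v2i64(<2 x i64> %a0) {
    %s = shufflevector <2 x i64> %a0, <2 x i64> undef, <2 x i32> <i32 1, i32 undef>
    %o = or <2 x i64> %a0, %s
    %r = extractelement <2 x i64> %o, i32 0
    %c = icmp eq i64 %r, 0
    ret i1 %c
  }

Factoring the PTEST and PCMPEQB/PMOVMSKB emission into the LowerCmpZero lambda is what lets the existing matchScalarReduction path and the new matchBinOpReduction path share a single lowering.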