diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -40079,7 +40079,8 @@
 /// If we are inverting an PTEST/TESTP operand, attempt to adjust the CC
 /// to avoid the inversion.
 static SDValue combinePTESTCC(SDValue EFLAGS, X86::CondCode &CC,
-                              SelectionDAG &DAG) {
+                              SelectionDAG &DAG,
+                              const X86Subtarget &Subtarget) {
   // TODO: Handle X86ISD::KTEST/X86ISD::KORTEST.
   if (EFLAGS.getOpcode() != X86ISD::PTEST &&
       EFLAGS.getOpcode() != X86ISD::TESTP)
@@ -40141,6 +40142,9 @@
 
   if (Op0 == Op1) {
     SDValue BC = peekThroughBitcasts(Op0);
+    EVT BCVT = BC.getValueType();
+    assert(BCVT.isVector() && DAG.getTargetLoweringInfo().isTypeLegal(BCVT) &&
+           "Unexpected vector type");
 
     // TESTZ(AND(X,Y),AND(X,Y)) == TESTZ(X,Y)
     if (BC.getOpcode() == ISD::AND || BC.getOpcode() == X86ISD::FAND) {
@@ -40156,6 +40160,35 @@
                          DAG.getBitcast(OpVT, BC.getOperand(0)),
                          DAG.getBitcast(OpVT, BC.getOperand(1)));
     }
+
+    // If every element is an all-sign value, see if we can use MOVMSK to
+    // more efficiently extract the sign bits and compare that.
+    // TODO: Handle TESTC with comparison inversion.
+    // TODO: Can we remove SimplifyMultipleUseDemandedBits and rely on
+    // MOVMSK combines to make sure it's never worse than PTEST?
+    unsigned EltBits = BCVT.getScalarSizeInBits();
+    if (DAG.ComputeNumSignBits(BC) == EltBits) {
+      assert(VT == MVT::i32 && "Expected i32 EFLAGS comparison result");
+      APInt SignMask = APInt::getSignMask(EltBits);
+      const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+      if (SDValue Res =
+              TLI.SimplifyMultipleUseDemandedBits(BC, SignMask, DAG)) {
+        // For vXi16 cases we need to use pmovmskb and extract every other
+        // sign bit.
+        SDLoc DL(EFLAGS);
+        if (EltBits == 16) {
+          MVT MovmskVT = BCVT.is128BitVector() ? MVT::v16i8 : MVT::v32i8;
+          Res = DAG.getBitcast(MovmskVT, Res);
+          Res = getPMOVMSKB(DL, Res, DAG, Subtarget);
+          Res = DAG.getNode(ISD::AND, DL, MVT::i32, Res,
+                            DAG.getConstant(0xAAAAAAAA, DL, MVT::i32));
+        } else {
+          Res = getPMOVMSKB(DL, Res, DAG, Subtarget);
+        }
+        return DAG.getNode(X86ISD::CMP, DL, MVT::i32, Res,
+                           DAG.getConstant(0, DL, MVT::i32));
+      }
+    }
   }
 
   // TESTZ(-1,X) == TESTZ(X,X)
@@ -40183,7 +40216,7 @@
   if (SDValue R = checkBoolTestSetCCCombine(EFLAGS, CC))
     return R;
 
-  if (SDValue R = combinePTESTCC(EFLAGS, CC, DAG))
+  if (SDValue R = combinePTESTCC(EFLAGS, CC, DAG, Subtarget))
     return R;
 
   return combineSetCCAtomicArith(EFLAGS, CC, DAG, Subtarget);
diff --git a/llvm/test/CodeGen/X86/combine-ptest.ll b/llvm/test/CodeGen/X86/combine-ptest.ll
--- a/llvm/test/CodeGen/X86/combine-ptest.ll
+++ b/llvm/test/CodeGen/X86/combine-ptest.ll
@@ -299,16 +299,15 @@
 }
 
 ;
-; TODO: testz(ashr(X,bw-1),-1) -> movmsk(X)
+; testz(ashr(X,bw-1),-1) -> movmsk(X)
 ;
 
 define i32 @ptestz_v2i64_signbits(<2 x i64> %c, i32 %a, i32 %b) {
 ; CHECK-LABEL: ptestz_v2i64_signbits:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    movl %edi, %eax
-; CHECK-NEXT:    vpxor %xmm1, %xmm1, %xmm1
-; CHECK-NEXT:    vpcmpgtq %xmm0, %xmm1, %xmm0
-; CHECK-NEXT:    vptest %xmm0, %xmm0
+; CHECK-NEXT:    vmovmskpd %xmm0, %ecx
+; CHECK-NEXT:    testl %ecx, %ecx
 ; CHECK-NEXT:    cmovnel %esi, %eax
 ; CHECK-NEXT:    retq
   %t1 = ashr <2 x i64> %c, <i64 63, i64 63>
@@ -334,8 +333,8 @@
 ; AVX2-LABEL: ptestz_v8i32_signbits:
 ; AVX2:       # %bb.0:
 ; AVX2-NEXT:    movl %edi, %eax
-; AVX2-NEXT:    vpsrad $31, %ymm0, %ymm0
-; AVX2-NEXT:    vptest %ymm0, %ymm0
+; AVX2-NEXT:    vmovmskps %ymm0, %ecx
+; AVX2-NEXT:    testl %ecx, %ecx
 ; AVX2-NEXT:    cmovnel %esi, %eax
 ; AVX2-NEXT:    vzeroupper
 ; AVX2-NEXT:    retq
@@ -351,8 +350,8 @@
 ; CHECK-LABEL: ptestz_v8i16_signbits:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    movl %edi, %eax
-; CHECK-NEXT:    vpsraw $15, %xmm0, %xmm0
-; CHECK-NEXT:    vptest %xmm0, %xmm0
+; CHECK-NEXT:    vpmovmskb %xmm0, %ecx
+; CHECK-NEXT:    testl $43690, %ecx # imm = 0xAAAA
 ; CHECK-NEXT:    cmovnel %esi, %eax
 ; CHECK-NEXT:    retq
   %t1 = ashr <8 x i16> %c, <i16 15, i16 15, i16 15, i16 15, i16 15, i16 15, i16 15, i16 15>
@@ -380,9 +379,8 @@
 ; AVX2-LABEL: ptestz_v32i8_signbits:
 ; AVX2:       # %bb.0:
 ; AVX2-NEXT:    movl %edi, %eax
-; AVX2-NEXT:    vpxor %xmm1, %xmm1, %xmm1
-; AVX2-NEXT:    vpcmpgtb %ymm0, %ymm1, %ymm0
-; AVX2-NEXT:    vptest %ymm0, %ymm0
+; AVX2-NEXT:    vpmovmskb %ymm0, %ecx
+; AVX2-NEXT:    testl %ecx, %ecx
 ; AVX2-NEXT:    cmovnel %esi, %eax
 ; AVX2-NEXT:    vzeroupper
 ; AVX2-NEXT:    retq
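For reference, the equivalence the new combinePTESTCC fold relies on can be sketched with x86 intrinsics. This is an illustrative sketch only, not part of the patch: the helper names are invented here, and the <2 x i64> case assumes the input elements are already all-sign values (all zeros or all ones), which is what the ComputeNumSignBits(BC) == EltBits guard establishes in the DAG combine.

#include <immintrin.h>

// Before the fold: PTEST of an all-sign vector against itself; ZF is set
// iff every element is zero.
static inline bool testz_allsign_ptest(__m128i AllSign) {
  return _mm_testz_si128(AllSign, AllSign); // vptest
}

// After the fold: only the sign bits matter, so extract them with MOVMSKPD
// and test the scalar mask against zero instead.
static inline bool testz_allsign_movmsk(__m128i AllSign) {
  return _mm_movemask_pd(_mm_castsi128_pd(AllSign)) == 0; // vmovmskpd + testl
}

This mirrors the ptestz_v2i64_signbits test above, where vpxor+vpcmpgtq+vptest becomes vmovmskpd+testl.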
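The vXi16 path is the one special case: there is no MOVMSK variant with 16-bit elements, so the patch bitcasts to vXi8, uses PMOVMSKB, and masks the result with 0xAAAAAAAA so that only one of the two (identical) byte sign bits per i16 lane is tested; the ptestz_v8i16_signbits test shows the resulting testl $0xAAAA. A rough sketch of the same idea, again with an invented helper name:

#include <immintrin.h>

// For an all-sign <8 x i16> input, PMOVMSKB yields one bit per byte, i.e.
// two identical bits per i16 lane. Testing only the odd bit positions
// (0xAAAA) checks exactly one sign bit per 16-bit element.
static inline bool testz_allsign_v8i16(__m128i AllSign) {
  int ByteMask = _mm_movemask_epi8(AllSign); // vpmovmskb
  return (ByteMask & 0xAAAA) == 0;           // testl $0xAAAA, %ecx
}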