diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -20923,11 +20923,11 @@
   return Result;
 }
 
-// Try to select this as a KORTEST+SETCC if possible.
-static SDValue EmitKORTEST(SDValue Op0, SDValue Op1, ISD::CondCode CC,
-                           const SDLoc &dl, SelectionDAG &DAG,
-                           const X86Subtarget &Subtarget,
-                           SDValue &X86CC) {
+// Try to select this as a KORTEST+SETCC or KTEST+SETCC if possible.
+static SDValue EmitAVX512Test(SDValue Op0, SDValue Op1, ISD::CondCode CC,
+                              const SDLoc &dl, SelectionDAG &DAG,
+                              const X86Subtarget &Subtarget,
+                              SDValue &X86CC) {
   // Only support equality comparisons.
   if (CC != ISD::SETEQ && CC != ISD::SETNE)
     return SDValue();
@@ -20952,6 +20952,21 @@
   } else
     return SDValue();
 
+  // If the input is an AND, we can combine its operands into the KTEST.
+  // KTESTB/KTESTW are only available with DQI and KTESTD/KTESTQ with BWI.
+  // KTEST sets ZF iff (LHS & RHS) == 0, so only a compare against zero
+  // maps onto it; any other Op1 falls through to the KORTEST path below.
+  bool KTestable =
+      isNullConstant(Op1) &&
+      ((Subtarget.hasDQI() && (VT == MVT::v8i1 || VT == MVT::v16i1)) ||
+       (Subtarget.hasBWI() && (VT == MVT::v32i1 || VT == MVT::v64i1)));
+  if (KTestable && Op0.getOpcode() == ISD::AND && Op0.hasOneUse()) {
+    SDValue LHS = Op0.getOperand(0);
+    SDValue RHS = Op0.getOperand(1);
+    X86CC = DAG.getTargetConstant(X86Cond, dl, MVT::i8);
+    return DAG.getNode(X86ISD::KTEST, dl, MVT::i32, LHS, RHS);
+  }
+
   // If the input is an OR, we can combine it's operands into the KORTEST.
   SDValue LHS = Op0;
   SDValue RHS = Op0;
@@ -20988,9 +21003,9 @@
       return PTEST;
   }
 
-  // Try to lower using KORTEST.
-  if (SDValue KORTEST = EmitKORTEST(Op0, Op1, CC, dl, DAG, Subtarget, X86CC))
-    return KORTEST;
+  // Try to lower using KORTEST or KTEST.
+  if (SDValue Test = EmitAVX512Test(Op0, Op1, CC, dl, DAG, Subtarget, X86CC))
+    return Test;
 
   // Look for X == 0, X == 1, X != 0, or X != 1. We can simplify some forms of
   // these.
diff --git a/llvm/test/CodeGen/X86/avx512-mask-op.ll b/test/CodeGen/X86/avx512-mask-op.ll
--- a/llvm/test/CodeGen/X86/avx512-mask-op.ll
+++ b/test/CodeGen/X86/avx512-mask-op.ll
@@ -2027,8 +2027,8 @@
 ; SKX:       ## %bb.0:
 ; SKX-NEXT:    vcmpgtpd (%rdi), %zmm0, %k1
 ; SKX-NEXT:    vmovupd 8(%rdi), %zmm1 {%k1} {z}
-; SKX-NEXT:    vcmpltpd %zmm1, %zmm0, %k0 {%k1}
-; SKX-NEXT:    kortestb %k0, %k0
+; SKX-NEXT:    vcmpltpd %zmm1, %zmm0, %k0
+; SKX-NEXT:    ktestb %k0, %k1
 ; SKX-NEXT:    je LBB43_2
 ; SKX-NEXT:  ## %bb.1: ## %L1
 ; SKX-NEXT:    vmovapd %zmm0, (%rdi)
@@ -2060,8 +2060,8 @@
 ; AVX512DQ:       ## %bb.0:
 ; AVX512DQ-NEXT:    vcmpgtpd (%rdi), %zmm0, %k1
 ; AVX512DQ-NEXT:    vmovupd 8(%rdi), %zmm1 {%k1} {z}
-; AVX512DQ-NEXT:    vcmpltpd %zmm1, %zmm0, %k0 {%k1}
-; AVX512DQ-NEXT:    kortestb %k0, %k0
+; AVX512DQ-NEXT:    vcmpltpd %zmm1, %zmm0, %k0
+; AVX512DQ-NEXT:    ktestb %k0, %k1
 ; AVX512DQ-NEXT:    je LBB43_2
 ; AVX512DQ-NEXT:  ## %bb.1: ## %L1
 ; AVX512DQ-NEXT:    vmovapd %zmm0, (%rdi)
@@ -2077,8 +2077,8 @@
 ; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; X86-NEXT:    vcmpgtpd (%eax), %zmm0, %k1
 ; X86-NEXT:    vmovupd 8(%eax), %zmm1 {%k1} {z}
-; X86-NEXT:    vcmpltpd %zmm1, %zmm0, %k0 {%k1}
-; X86-NEXT:    kortestb %k0, %k0
+; X86-NEXT:    vcmpltpd %zmm1, %zmm0, %k0
+; X86-NEXT:    ktestb %k0, %k1
 ; X86-NEXT:    je LBB43_2
 ; X86-NEXT:  ## %bb.1: ## %L1
 ; X86-NEXT:    vmovapd %zmm0, (%eax)