Index: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
@@ -14519,6 +14519,29 @@
   return false;
 }
 
+/// Try to emit a KTEST node for a test of an AVX-512 mask (bit vector).
+///
+/// If \p Op is a BITCAST of a vector of i1 whose width has a KTEST form on
+/// \p Subtarget, returns an X86ISD::KTEST of that vector with itself.
+/// Otherwise returns an empty SDValue.
+static SDValue EmitKTEST(SDValue Op, SelectionDAG &DAG,
+                         const X86Subtarget &Subtarget) {
+  if (Op.getOpcode() == ISD::BITCAST) {
+    // KTESTB/KTESTW are part of AVX512DQ; KTESTD/KTESTQ are part of AVX512BW.
+    auto hasKTEST = [&](MVT VT) {
+      unsigned SizeInBits = VT.getSizeInBits();
+      return (Subtarget.hasDQI() && (SizeInBits == 8 || SizeInBits == 16)) ||
+        (Subtarget.hasBWI() && (SizeInBits == 32 || SizeInBits == 64));
+    };
+    SDValue Op0 = Op.getOperand(0);
+    MVT Op0VT = Op0.getValueType().getSimpleVT();
+    if (Op0VT.isVector() && Op0VT.getVectorElementType() == MVT::i1 &&
+        hasKTEST(Op0VT))
+      return DAG.getNode(X86ISD::KTEST, SDLoc(Op), Op0VT, Op0, Op0);
+  }
+  return SDValue();
+}
+
 /// Emit nodes that will be selected as "test Op0,Op0", or something
 /// equivalent.
 SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, SDLoc dl,
@@ -14564,10 +14587,10 @@
   // doing a separate TEST. TEST always sets OF and CF to 0, so unless
   // we prove that the arithmetic won't overflow, we can't use OF or CF.
   if (Op.getResNo() != 0 || NeedOF || NeedCF) {
+    // Emit KTEST for bit vectors
+    if (auto Node = EmitKTEST(Op, DAG, Subtarget))
+      return Node;
     // Emit a CMP with 0, which is the TEST pattern.
-    //if (Op.getValueType() == MVT::i1)
-    //  return DAG.getNode(X86ISD::CMP, dl, MVT::i1, Op,
-    //                     DAG.getConstant(0, MVT::i1));
     return DAG.getNode(X86ISD::CMP, dl, MVT::i32, Op,
                        DAG.getConstant(0, dl, Op.getValueType()));
   }
@@ -14739,11 +14762,15 @@
     }
   }
 
-  if (Opcode == 0)
+  if (Opcode == 0) {
+    // Emit KTEST for bit vectors
+    if (auto Node = EmitKTEST(Op, DAG, Subtarget))
+      return Node;
+
     // Emit a CMP with 0, which is the TEST pattern.
     return DAG.getNode(X86ISD::CMP, dl, MVT::i32, Op,
                        DAG.getConstant(0, dl, Op.getValueType()));
-
+  }
   SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::i32);
   SmallVector Ops(Op->op_begin(), Op->op_begin() + NumOperands);
 
Index: llvm/trunk/test/CodeGen/X86/avx512-mask-op.ll
===================================================================
--- llvm/trunk/test/CodeGen/X86/avx512-mask-op.ll
+++ llvm/trunk/test/CodeGen/X86/avx512-mask-op.ll
@@ -244,8 +244,7 @@
 ; SKX-NEXT:    movb $85, %al
 ; SKX-NEXT:    kmovb %eax, %k1
 ; SKX-NEXT:    korb %k1, %k0, %k0
-; SKX-NEXT:    kmovb %k0, %eax
-; SKX-NEXT:    testb %al, %al
+; SKX-NEXT:    ktestb %k0, %k0
 ; SKX-NEXT:    retq
 allocas:
   %a= or <8 x i1> %mask,
@@ -1681,3 +1680,113 @@
   %ret = select <64 x i1> , <64 x i8> %x, <64 x i8> zeroinitializer
   ret <64 x i8> %ret
 }
+
+define void @ktest_1(<8 x double> %in, double * %base) {
+; KNL-LABEL: ktest_1:
+; KNL:       ## BB#0:
+; KNL-NEXT:    vmovupd (%rdi), %zmm1
+; KNL-NEXT:    vcmpltpd %zmm0, %zmm1, %k1
+; KNL-NEXT:    vmovupd 8(%rdi), %zmm1 {%k1} {z}
+; KNL-NEXT:    vcmpltpd %zmm1, %zmm0, %k0 {%k1}
+; KNL-NEXT:    kmovw %k0, %eax
+; KNL-NEXT:    testb %al, %al
+; KNL-NEXT:    je LBB38_2
+; KNL-NEXT:  ## BB#1: ## %L1
+; KNL-NEXT:    vmovapd %zmm0, (%rdi)
+; KNL-NEXT:    retq
+; KNL-NEXT:  LBB38_2: ## %L2
+; KNL-NEXT:    vmovapd %zmm0, 8(%rdi)
+; KNL-NEXT:    retq
+;
+; SKX-LABEL: ktest_1:
+; SKX:       ## BB#0:
+; SKX-NEXT:    vmovupd (%rdi), %zmm1
+; SKX-NEXT:    vcmpltpd %zmm0, %zmm1, %k1
+; SKX-NEXT:    vmovupd 8(%rdi), %zmm1 {%k1} {z}
+; SKX-NEXT:    vcmpltpd %zmm1, %zmm0, %k0 {%k1}
+; SKX-NEXT:    ktestb %k0, %k0
+; SKX-NEXT:    je LBB38_2
+; SKX-NEXT:  ## BB#1: ## %L1
+; SKX-NEXT:    vmovapd %zmm0, (%rdi)
+; SKX-NEXT:    retq
+; SKX-NEXT:  LBB38_2: ## %L2
+; SKX-NEXT:    vmovapd %zmm0, 8(%rdi)
+; SKX-NEXT:    retq
+  %addr1 = getelementptr double, double * %base, i64 0
+  %addr2 = getelementptr double, double * %base, i64 1
+
+  %vaddr1 = bitcast double* %addr1 to <8 x double>*
+  %vaddr2 = bitcast double* %addr2 to <8 x double>*
+
+  %val1 = load <8 x double>, <8 x double> *%vaddr1, align 1
+  %val2 = load <8 x double>, <8 x double> *%vaddr2, align 1
+
+  %sel1 = fcmp ogt <8 x double>%in, %val1
+  %val3 = select <8 x i1> %sel1, <8 x double> %val2, <8 x double> zeroinitializer
+  %sel2 = fcmp olt <8 x double> %in, %val3
+  %sel3 = and <8 x i1> %sel1, %sel2
+
+  %int_sel3 = bitcast <8 x i1> %sel3 to i8
+  %res = icmp eq i8 %int_sel3, zeroinitializer
+  br i1 %res, label %L2, label %L1
+L1:
+  store <8 x double> %in, <8 x double>* %vaddr1
+  br label %End
+L2:
+  store <8 x double> %in, <8 x double>* %vaddr2
+  br label %End
+End:
+  ret void
+}
+
+define void @ktest_2(<32 x float> %in, float * %base) {
+;
+; SKX-LABEL: ktest_2:
+; SKX:       ## BB#0:
+; SKX-NEXT:    vmovups 64(%rdi), %zmm2
+; SKX-NEXT:    vmovups (%rdi), %zmm3
+; SKX-NEXT:    vcmpltps %zmm0, %zmm3, %k1
+; SKX-NEXT:    vcmpltps %zmm1, %zmm2, %k2
+; SKX-NEXT:    kunpckwd %k1, %k2, %k0
+; SKX-NEXT:    vmovups 68(%rdi), %zmm2 {%k2} {z}
+; SKX-NEXT:    vmovups 4(%rdi), %zmm3 {%k1} {z}
+; SKX-NEXT:    vcmpltps %zmm3, %zmm0, %k1
+; SKX-NEXT:    vcmpltps %zmm2, %zmm1, %k2
+; SKX-NEXT:    kunpckwd %k1, %k2, %k1
+; SKX-NEXT:    kord %k1, %k0, %k0
+; SKX-NEXT:    ktestd %k0, %k0
+; SKX-NEXT:    je LBB39_2
+; SKX-NEXT:  ## BB#1: ## %L1
+; SKX-NEXT:    vmovaps %zmm0, (%rdi)
+; SKX-NEXT:    vmovaps %zmm1, 64(%rdi)
+; SKX-NEXT:    retq
+; SKX-NEXT:  LBB39_2: ## %L2
+; SKX-NEXT:    vmovaps %zmm0, 4(%rdi)
+; SKX-NEXT:    vmovaps %zmm1, 68(%rdi)
+; SKX-NEXT:    retq
+  %addr1 = getelementptr float, float * %base, i64 0
+  %addr2 = getelementptr float, float * %base, i64 1
+
+  %vaddr1 = bitcast float* %addr1 to <32 x float>*
+  %vaddr2 = bitcast float* %addr2 to <32 x float>*
+
+  %val1 = load <32 x float>, <32 x float> *%vaddr1, align 1
+  %val2 = load <32 x float>, <32 x float> *%vaddr2, align 1
+
+  %sel1 = fcmp ogt <32 x float>%in, %val1
+  %val3 = select <32 x i1> %sel1, <32 x float> %val2, <32 x float> zeroinitializer
+  %sel2 = fcmp olt <32 x float> %in, %val3
+  %sel3 = or <32 x i1> %sel1, %sel2
+
+  %int_sel3 = bitcast <32 x i1> %sel3 to i32
+  %res = icmp eq i32 %int_sel3, zeroinitializer
+  br i1 %res, label %L2, label %L1
+L1:
+  store <32 x float> %in, <32 x float>* %vaddr1
+  br label %End
+L2:
+  store <32 x float> %in, <32 x float>* %vaddr2
+  br label %End
+End:
+  ret void
+}