Index: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
@@ -12123,10 +12123,15 @@
   }
 
   unsigned IdxVal = cast<ConstantSDNode>(Idx)->getZExtValue();
-  const TargetRegisterClass* rc = getRegClassFor(VecVT);
-  if (!Subtarget.hasDQI() && (VecVT.getVectorNumElements() <= 8))
-    rc = getRegClassFor(MVT::v16i1);
-  unsigned MaxSift = rc->getSize()*8 - 1;
+  if (!Subtarget.hasDQI() && (VecVT.getVectorNumElements() <= 8)) {
+    // Use kshiftlw/kshiftrw instructions.
+    VecVT = MVT::v16i1;
+    Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, VecVT,
+                      DAG.getUNDEF(VecVT),
+                      Vec,
+                      DAG.getIntPtrConstant(0, dl));
+  }
+  unsigned MaxSift = VecVT.getVectorNumElements() - 1;
   Vec = DAG.getNode(X86ISD::VSHLI, dl, VecVT, Vec,
                     DAG.getConstant(MaxSift - IdxVal, dl, MVT::i8));
   Vec = DAG.getNode(X86ISD::VSRLI, dl, VecVT, Vec,
Index: llvm/trunk/lib/Target/X86/X86InstrInfo.cpp
===================================================================
--- llvm/trunk/lib/Target/X86/X86InstrInfo.cpp
+++ llvm/trunk/lib/Target/X86/X86InstrInfo.cpp
@@ -4302,12 +4302,14 @@
   return 0;
 }
 
+static bool isMaskRegClass(const TargetRegisterClass *RC) {
+  // All KMASK RegClasses hold the same k registers; testing against any one of them is sufficient.
+  return X86::VK16RegClass.hasSubClassEq(RC);
+}
+
 static bool MaskRegClassContains(unsigned Reg) {
-  return X86::VK8RegClass.contains(Reg) ||
-         X86::VK16RegClass.contains(Reg) ||
-         X86::VK32RegClass.contains(Reg) ||
-         X86::VK64RegClass.contains(Reg) ||
-         X86::VK1RegClass.contains(Reg);
+  // All KMASK RegClasses hold the same k registers; testing against any one of them is sufficient.
+  return X86::VK16RegClass.contains(Reg);
 }
 
 static bool GRRegClassContains(unsigned Reg) {
@@ -4509,15 +4511,28 @@
   llvm_unreachable("Cannot emit physreg copy instruction");
 }
 
+static unsigned getLoadStoreMaskRegOpcode(const TargetRegisterClass *RC,
+                                          bool load) {
+  switch (RC->getSize()) {
+  default:
+    llvm_unreachable("Unknown spill size");
+  case 2:
+    return load ? X86::KMOVWkm : X86::KMOVWmk;
+  case 4:
+    return load ? X86::KMOVDkm : X86::KMOVDmk;
+  case 8:
+    return load ? X86::KMOVQkm : X86::KMOVQmk;
+  }
+}
+
 static unsigned getLoadStoreRegOpcode(unsigned Reg,
                                       const TargetRegisterClass *RC,
                                       bool isStackAligned,
                                       const X86Subtarget &STI,
                                       bool load) {
   if (STI.hasAVX512()) {
-    if (X86::VK8RegClass.hasSubClassEq(RC) ||
-        X86::VK16RegClass.hasSubClassEq(RC))
-      return load ? X86::KMOVWkm : X86::KMOVWmk;
+    if (isMaskRegClass(RC))
+      return getLoadStoreMaskRegOpcode(RC, load);
     if (RC->getSize() == 4 && X86::FR32XRegClass.hasSubClassEq(RC))
       return load ? X86::VMOVSSZrm : X86::VMOVSSZmr;
     if (RC->getSize() == 8 && X86::FR64XRegClass.hasSubClassEq(RC))
Index: llvm/trunk/lib/Target/X86/X86RegisterInfo.td
===================================================================
--- llvm/trunk/lib/Target/X86/X86RegisterInfo.td
+++ llvm/trunk/lib/Target/X86/X86RegisterInfo.td
@@ -477,18 +477,18 @@
                            256, (sequence "YMM%u", 0, 31)>;
 
 // Mask registers
-def VK1 : RegisterClass<"X86", [i1], 8, (sequence "K%u", 0, 7)> {let Size = 8;}
-def VK2 : RegisterClass<"X86", [v2i1], 8, (add VK1)> {let Size = 8;}
-def VK4 : RegisterClass<"X86", [v4i1], 8, (add VK2)> {let Size = 8;}
-def VK8 : RegisterClass<"X86", [v8i1], 8, (add VK4)> {let Size = 8;}
+def VK1 : RegisterClass<"X86", [i1], 16, (sequence "K%u", 0, 7)> {let Size = 16;}
+def VK2 : RegisterClass<"X86", [v2i1], 16, (add VK1)> {let Size = 16;}
+def VK4 : RegisterClass<"X86", [v4i1], 16, (add VK2)> {let Size = 16;}
+def VK8 : RegisterClass<"X86", [v8i1], 16, (add VK4)> {let Size = 16;}
 def VK16 : RegisterClass<"X86", [v16i1], 16, (add VK8)> {let Size = 16;}
 def VK32 : RegisterClass<"X86", [v32i1], 32, (add VK16)> {let Size = 32;}
 def VK64 : RegisterClass<"X86", [v64i1], 64, (add VK32)> {let Size = 64;}
 
-def VK1WM : RegisterClass<"X86", [i1], 8, (sub VK1, K0)> {let Size = 8;}
-def VK2WM : RegisterClass<"X86", [v2i1], 8, (sub VK2, K0)> {let Size = 8;}
-def VK4WM : RegisterClass<"X86", [v4i1], 8, (sub VK4, K0)> {let Size = 8;}
-def VK8WM : RegisterClass<"X86", [v8i1], 8, (sub VK8, K0)> {let Size = 8;}
+def VK1WM : RegisterClass<"X86", [i1], 16, (sub VK1, K0)> {let Size = 16;}
+def VK2WM : RegisterClass<"X86", [v2i1], 16, (sub VK2, K0)> {let Size = 16;}
+def VK4WM : RegisterClass<"X86", [v4i1], 16, (sub VK4, K0)> {let Size = 16;}
+def VK8WM : RegisterClass<"X86", [v8i1], 16, (sub VK8, K0)> {let Size = 16;}
 def VK16WM : RegisterClass<"X86", [v16i1], 16, (add VK8WM)> {let Size = 16;}
 def VK32WM : RegisterClass<"X86", [v32i1], 32, (add VK16WM)> {let Size = 32;}
 def VK64WM : RegisterClass<"X86", [v64i1], 64, (add VK32WM)> {let Size = 64;}
Index: llvm/trunk/test/CodeGen/X86/avx512-intel-ocl.ll
===================================================================
--- llvm/trunk/test/CodeGen/X86/avx512-intel-ocl.ll
+++ llvm/trunk/test/CodeGen/X86/avx512-intel-ocl.ll
@@ -68,10 +68,10 @@
 ; WIN64: vmovups {{.*(%rbp).*}}, %zmm21 # 64-byte Reload
 
 ; X64-LABEL: test_prolog_epilog
-; X64: kmovw %k7, {{.*}}(%rsp) ## 8-byte Folded Spill
-; X64: kmovw %k6, {{.*}}(%rsp) ## 8-byte Folded Spill
-; X64: kmovw %k5, {{.*}}(%rsp) ## 8-byte Folded Spill
-; X64: kmovw %k4, {{.*}}(%rsp) ## 8-byte Folded Spill
+; X64: kmovq %k7, {{.*}}(%rsp) ## 8-byte Folded Spill
+; X64: kmovq %k6, {{.*}}(%rsp) ## 8-byte Folded Spill
+; X64: kmovq %k5, {{.*}}(%rsp) ## 8-byte Folded Spill
+; X64: kmovq %k4, {{.*}}(%rsp) ## 8-byte Folded Spill
 ; X64: vmovups %zmm31, {{.*}}(%rsp) ## 64-byte Spill
 ; X64: vmovups %zmm16, {{.*}}(%rsp) ## 64-byte Spill
 ; X64: call
Index: llvm/trunk/test/CodeGen/X86/avx512-mask-spills.ll
===================================================================
--- llvm/trunk/test/CodeGen/X86/avx512-mask-spills.ll
+++ llvm/trunk/test/CodeGen/X86/avx512-mask-spills.ll
@@ -0,0 +1,126 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-apple-darwin -mcpu=skx | FileCheck %s --check-prefix=CHECK --check-prefix=SKX
+
+declare void @f()
+define <4 x i1> @test_4i1(<4 x i32> %a, <4 x i32> %b) {
+; CHECK-LABEL: test_4i1:
+; CHECK:       ## BB#0:
+; CHECK-NEXT:    pushq %rax
+; CHECK-NEXT:  Ltmp0:
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    vpcmpnleud %xmm1, %xmm0, %k0
+; CHECK-NEXT:    kmovw %k0, {{[0-9]+}}(%rsp) ## 2-byte Folded Spill
+; CHECK-NEXT:    vpcmpgtd %xmm1, %xmm0, %k0
+; CHECK-NEXT:    kmovw %k0, {{[0-9]+}}(%rsp) ## 2-byte Folded Spill
+; CHECK-NEXT:    callq _f
+; CHECK-NEXT:    kmovw {{[0-9]+}}(%rsp), %k0 ## 2-byte Folded Reload
+; CHECK-NEXT:    kmovw {{[0-9]+}}(%rsp), %k1 ## 2-byte Folded Reload
+; CHECK-NEXT:    korw %k1, %k0, %k0
+; CHECK-NEXT:    vpmovm2d %k0, %xmm0
+; CHECK-NEXT:    popq %rax
+; CHECK-NEXT:    retq
+
+  %cmp_res = icmp ugt <4 x i32> %a, %b
+  %cmp_res2 = icmp sgt <4 x i32> %a, %b
+  call void @f()
+  %res = or <4 x i1> %cmp_res, %cmp_res2
+  ret <4 x i1> %res
+}
+
+define <8 x i1> @test_8i1(<8 x i32> %a, <8 x i32> %b) {
+; CHECK-LABEL: test_8i1:
+; CHECK:       ## BB#0:
+; CHECK-NEXT:    pushq %rax
+; CHECK-NEXT:  Ltmp1:
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    vpcmpnleud %ymm1, %ymm0, %k0
+; CHECK-NEXT:    kmovw %k0, {{[0-9]+}}(%rsp) ## 2-byte Folded Spill
+; CHECK-NEXT:    vpcmpgtd %ymm1, %ymm0, %k0
+; CHECK-NEXT:    kmovw %k0, {{[0-9]+}}(%rsp) ## 2-byte Folded Spill
+; CHECK-NEXT:    callq _f
+; CHECK-NEXT:    kmovw {{[0-9]+}}(%rsp), %k0 ## 2-byte Folded Reload
+; CHECK-NEXT:    kmovw {{[0-9]+}}(%rsp), %k1 ## 2-byte Folded Reload
+; CHECK-NEXT:    korb %k1, %k0, %k0
+; CHECK-NEXT:    vpmovm2w %k0, %xmm0
+; CHECK-NEXT:    popq %rax
+; CHECK-NEXT:    retq
+
+  %cmp_res = icmp ugt <8 x i32> %a, %b
+  %cmp_res2 = icmp sgt <8 x i32> %a, %b
+  call void @f()
+  %res = or <8 x i1> %cmp_res, %cmp_res2
+  ret <8 x i1> %res
+}
+
+define <16 x i1> @test_16i1(<16 x i32> %a, <16 x i32> %b) {
+; CHECK-LABEL: test_16i1:
+; CHECK:       ## BB#0:
+; CHECK-NEXT:    pushq %rax
+; CHECK-NEXT:  Ltmp2:
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    vpcmpnleud %zmm1, %zmm0, %k0
+; CHECK-NEXT:    kmovw %k0, {{[0-9]+}}(%rsp) ## 2-byte Folded Spill
+; CHECK-NEXT:    vpcmpgtd %zmm1, %zmm0, %k0
+; CHECK-NEXT:    kmovw %k0, {{[0-9]+}}(%rsp) ## 2-byte Folded Spill
+; CHECK-NEXT:    callq _f
+; CHECK-NEXT:    kmovw {{[0-9]+}}(%rsp), %k0 ## 2-byte Folded Reload
+; CHECK-NEXT:    kmovw {{[0-9]+}}(%rsp), %k1 ## 2-byte Folded Reload
+; CHECK-NEXT:    korw %k1, %k0, %k0
+; CHECK-NEXT:    vpmovm2b %k0, %xmm0
+; CHECK-NEXT:    popq %rax
+; CHECK-NEXT:    retq
+  %cmp_res = icmp ugt <16 x i32> %a, %b
+  %cmp_res2 = icmp sgt <16 x i32> %a, %b
+  call void @f()
+  %res = or <16 x i1> %cmp_res, %cmp_res2
+  ret <16 x i1> %res
+}
+
+define <32 x i1> @test_32i1(<32 x i16> %a, <32 x i16> %b) {
+; CHECK-LABEL: test_32i1:
+; CHECK:       ## BB#0:
+; CHECK-NEXT:    pushq %rax
+; CHECK-NEXT:  Ltmp3:
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    vpcmpnleuw %zmm1, %zmm0, %k0
+; CHECK-NEXT:    kmovd %k0, {{[0-9]+}}(%rsp) ## 4-byte Folded Spill
+; CHECK-NEXT:    vpcmpgtw %zmm1, %zmm0, %k0
+; CHECK-NEXT:    kmovd %k0, (%rsp) ## 4-byte Folded Spill
+; CHECK-NEXT:    callq _f
+; CHECK-NEXT:    kmovd {{[0-9]+}}(%rsp), %k0 ## 4-byte Folded Reload
+; CHECK-NEXT:    kmovd (%rsp), %k1 ## 4-byte Folded Reload
+; CHECK-NEXT:    kord %k1, %k0, %k0
+; CHECK-NEXT:    vpmovm2b %k0, %ymm0
+; CHECK-NEXT:    popq %rax
+; CHECK-NEXT:    retq
+  %cmp_res = icmp ugt <32 x i16> %a, %b
+  %cmp_res2 = icmp sgt <32 x i16> %a, %b
+  call void @f()
+  %res = or <32 x i1> %cmp_res, %cmp_res2
+  ret <32 x i1> %res
+}
+
+define <64 x i1> @test_64i1(<64 x i8> %a, <64 x i8> %b) {
+; CHECK-LABEL: test_64i1:
+; CHECK:       ## BB#0:
+; CHECK-NEXT:    subq $24, %rsp
+; CHECK-NEXT:  Ltmp4:
+; CHECK-NEXT:    .cfi_def_cfa_offset 32
+; CHECK-NEXT:    vpcmpnleub %zmm1, %zmm0, %k0
+; CHECK-NEXT:    kmovq %k0, {{[0-9]+}}(%rsp) ## 8-byte Folded Spill
+; CHECK-NEXT:    vpcmpgtb %zmm1, %zmm0, %k0
+; CHECK-NEXT:    kmovq %k0, {{[0-9]+}}(%rsp) ## 8-byte Folded Spill
+; CHECK-NEXT:    callq _f
+; CHECK-NEXT:    kmovq {{[0-9]+}}(%rsp), %k0 ## 8-byte Folded Reload
+; CHECK-NEXT:    kmovq {{[0-9]+}}(%rsp), %k1 ## 8-byte Folded Reload
+; CHECK-NEXT:    korq %k1, %k0, %k0
+; CHECK-NEXT:    vpmovm2b %k0, %zmm0
+; CHECK-NEXT:    addq $24, %rsp
+; CHECK-NEXT:    retq
+
+  %cmp_res = icmp ugt <64 x i8> %a, %b
+  %cmp_res2 = icmp sgt <64 x i8> %a, %b
+  call void @f()
+  %res = or <64 x i1> %cmp_res, %cmp_res2
+  ret <64 x i1> %res
+}