diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -21771,8 +21771,11 @@
                                   const X86Subtarget &Subtarget,
                                   SelectionDAG &DAG, X86::CondCode &X86CC) {
   EVT VT = V.getValueType();
-  assert(Mask.getBitWidth() == VT.getScalarSizeInBits() &&
-         "Element Mask vs Vector bitwidth mismatch");
+  unsigned ScalarSize = VT.getScalarSizeInBits(); // VT element width in bits
+  if (Mask.getBitWidth() != ScalarSize) {
+    assert(ScalarSize == 1 && "Element Mask vs Vector bitwidth mismatch");
+    return SDValue(); // width mismatch (asserted i1-only): return empty SDValue
+  }
   assert((CC == ISD::SETEQ || CC == ISD::SETNE) && "Unsupported ISD::CondCode");
   X86CC = (CC == ISD::SETEQ ? X86::COND_E : X86::COND_NE);
 
diff --git a/llvm/test/CodeGen/X86/vector-reduce-or-cmp.ll b/llvm/test/CodeGen/X86/vector-reduce-or-cmp.ll
--- a/llvm/test/CodeGen/X86/vector-reduce-or-cmp.ll
+++ b/llvm/test/CodeGen/X86/vector-reduce-or-cmp.ll
@@ -1043,6 +1043,110 @@
   ret i1 %6
 }
 
+define i32 @mask_v3i1(<3 x i32> %a, <3 x i32> %b) {
+; SSE2-LABEL: mask_v3i1:
+; SSE2: # %bb.0:
+; SSE2-NEXT: pcmpeqd %xmm1, %xmm0
+; SSE2-NEXT: pcmpeqd %xmm1, %xmm1
+; SSE2-NEXT: pxor %xmm0, %xmm1
+; SSE2-NEXT: movd %xmm1, %eax
+; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm1[1,1,1,1]
+; SSE2-NEXT: movd %xmm0, %ecx
+; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm1[2,3,2,3]
+; SSE2-NEXT: movd %xmm0, %edx
+; SSE2-NEXT: orl %ecx, %edx
+; SSE2-NEXT: orl %eax, %edx
+; SSE2-NEXT: testb $1, %dl
+; SSE2-NEXT: je .LBB27_2
+; SSE2-NEXT: # %bb.1:
+; SSE2-NEXT: xorl %eax, %eax
+; SSE2-NEXT: retq
+; SSE2-NEXT: .LBB27_2:
+; SSE2-NEXT: movl $1, %eax
+; SSE2-NEXT: retq
+;
+; SSE41-LABEL: mask_v3i1:
+; SSE41: # %bb.0:
+; SSE41-NEXT: pcmpeqd %xmm1, %xmm0
+; SSE41-NEXT: pcmpeqd %xmm1, %xmm1
+; SSE41-NEXT: pxor %xmm0, %xmm1
+; SSE41-NEXT: pextrd $1, %xmm1, %eax
+; SSE41-NEXT: movd %xmm1, %ecx
+; SSE41-NEXT: pextrd $2, %xmm1, %edx
+; SSE41-NEXT: orl %eax, %edx
+; SSE41-NEXT: orl %ecx, %edx
+; SSE41-NEXT: testb $1, %dl
+; SSE41-NEXT: je .LBB27_2
+; SSE41-NEXT: # %bb.1:
+; SSE41-NEXT: xorl %eax, %eax
+; SSE41-NEXT: retq
+; SSE41-NEXT: .LBB27_2:
+; SSE41-NEXT: movl $1, %eax
+; SSE41-NEXT: retq
+;
+; AVX1-LABEL: mask_v3i1:
+; AVX1: # %bb.0:
+; AVX1-NEXT: vpcmpeqd %xmm1, %xmm0, %xmm0
+; AVX1-NEXT: vpcmpeqd %xmm1, %xmm1, %xmm1
+; AVX1-NEXT: vpxor %xmm1, %xmm0, %xmm0
+; AVX1-NEXT: vpextrd $1, %xmm0, %eax
+; AVX1-NEXT: vmovd %xmm0, %ecx
+; AVX1-NEXT: vpextrd $2, %xmm0, %edx
+; AVX1-NEXT: orl %eax, %edx
+; AVX1-NEXT: orl %ecx, %edx
+; AVX1-NEXT: testb $1, %dl
+; AVX1-NEXT: je .LBB27_2
+; AVX1-NEXT: # %bb.1:
+; AVX1-NEXT: xorl %eax, %eax
+; AVX1-NEXT: retq
+; AVX1-NEXT: .LBB27_2:
+; AVX1-NEXT: movl $1, %eax
+; AVX1-NEXT: retq
+;
+; AVX2-LABEL: mask_v3i1:
+; AVX2: # %bb.0:
+; AVX2-NEXT: vpcmpeqd %xmm1, %xmm0, %xmm0
+; AVX2-NEXT: vpcmpeqd %xmm1, %xmm1, %xmm1
+; AVX2-NEXT: vpxor %xmm1, %xmm0, %xmm0
+; AVX2-NEXT: vpextrd $1, %xmm0, %eax
+; AVX2-NEXT: vmovd %xmm0, %ecx
+; AVX2-NEXT: vpextrd $2, %xmm0, %edx
+; AVX2-NEXT: orl %eax, %edx
+; AVX2-NEXT: orl %ecx, %edx
+; AVX2-NEXT: testb $1, %dl
+; AVX2-NEXT: je .LBB27_2
+; AVX2-NEXT: # %bb.1:
+; AVX2-NEXT: xorl %eax, %eax
+; AVX2-NEXT: retq
+; AVX2-NEXT: .LBB27_2:
+; AVX2-NEXT: movl $1, %eax
+; AVX2-NEXT: retq
+;
+; AVX512-LABEL: mask_v3i1:
+; AVX512: # %bb.0:
+; AVX512: vpcmpneqd %{{.}}mm1, %{{.}}mm0, %k0
+; AVX512-NEXT: kshiftrw $2, %k0, %k1
+; AVX512-NEXT: korw %k1, %k0, %k1
+; AVX512-NEXT: kshiftrw $1, %k0, %k0
+; AVX512-NEXT: korw %k0, %k1, %k0
+; AVX512-NEXT: kmovd %k0, %eax
+; AVX512-NEXT: testb $1, %al
+; AVX512-NEXT: je .LBB27_2
+; AVX512-NEXT: # %bb.1:
+; AVX512-NEXT: xorl %eax, %eax
+; AVX512: retq
+; AVX512-NEXT: .LBB27_2:
+; AVX512-NEXT: movl $1, %eax
+; AVX512: retq
+  %1 = icmp ne <3 x i32> %a, %b ; per-lane a != b, consumed below as <3 x i1>
+  %2 = call i1 @llvm.vector.reduce.or.v3i1(<3 x i1> %1) ; OR-reduce the compare lanes
+  br i1 %2, label %3, label %4 ; %3 when the reduction is true, otherwise %4
+3:
+  ret i32 0
+4:
+  ret i32 1
+}
+
 declare i64 @llvm.vector.reduce.or.v2i64(<2 x i64>)
 declare i64 @llvm.vector.reduce.or.v4i64(<4 x i64>)
 declare i64 @llvm.vector.reduce.or.v8i64(<8 x i64>)
@@ -1068,3 +1172,5 @@
 declare i8 @llvm.vector.reduce.or.v32i8(<32 x i8>)
 declare i8 @llvm.vector.reduce.or.v64i8(<64 x i8>)
 declare i8 @llvm.vector.reduce.or.v128i8(<128 x i8>)
+
+declare i1 @llvm.vector.reduce.or.v3i1(<3 x i1>)