Index: lib/Transforms/InstCombine/InstCombineCalls.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -736,7 +736,8 @@
   return Builder.CreateInsertElement(Dst, Res, (uint64_t)0);
 }
 
-static Value *simplifyX86movmsk(const IntrinsicInst &II) {
+static Value *simplifyX86movmsk(const IntrinsicInst &II,
+                                InstCombiner::BuilderTy &Builder) {
   Value *Arg = II.getArgOperand(0);
   Type *ResTy = II.getType();
   Type *ArgTy = Arg->getType();
@@ -749,29 +750,46 @@
   if (!ArgTy->isVectorTy())
     return nullptr;
 
-  auto *C = dyn_cast<Constant>(Arg);
-  if (!C)
-    return nullptr;
+  if (auto *C = dyn_cast<Constant>(Arg)) {
+    // Extract signbits of the vector input and pack into integer result.
+    APInt Result(ResTy->getPrimitiveSizeInBits(), 0);
+    for (unsigned I = 0, E = ArgTy->getVectorNumElements(); I != E; ++I) {
+      auto *COp = C->getAggregateElement(I);
+      if (!COp)
+        return nullptr;
+      if (isa<UndefValue>(COp))
+        continue;
 
-  // Extract signbits of the vector input and pack into integer result.
-  APInt Result(ResTy->getPrimitiveSizeInBits(), 0);
-  for (unsigned I = 0, E = ArgTy->getVectorNumElements(); I != E; ++I) {
-    auto *COp = C->getAggregateElement(I);
-    if (!COp)
-      return nullptr;
-    if (isa<UndefValue>(COp))
-      continue;
+      auto *CInt = dyn_cast<ConstantInt>(COp);
+      auto *CFp = dyn_cast<ConstantFP>(COp);
+      if (!CInt && !CFp)
+        return nullptr;
 
-    auto *CInt = dyn_cast<ConstantInt>(COp);
-    auto *CFp = dyn_cast<ConstantFP>(COp);
-    if (!CInt && !CFp)
-      return nullptr;
+      if ((CInt && CInt->isNegative()) || (CFp && CFp->isNegative()))
+        Result.setBit(I);
+    }
+    return Constant::getIntegerValue(ResTy, Result);
+  }
 
-    if ((CInt && CInt->isNegative()) || (CFp && CFp->isNegative()))
-      Result.setBit(I);
+  // If the argument is bitcast, look through that, but make sure the source of
+  // that bitcast is still a vector with the same number of elements.
+  // TODO: We can also convert a bitcast with wider elements, but that requires
+  // duplicating the bool source sign bits to match the number of elements
+  // expected by the movmsk call.
+  Arg = peekThroughBitcast(Arg);
+  Value *BoolSrc;
+  if (Arg->getType()->isVectorTy() &&
+      Arg->getType()->getVectorNumElements() == ArgTy->getVectorNumElements() &&
+      match(Arg, m_SExt(m_Value(BoolSrc))) &&
+      BoolSrc->getType()->getScalarSizeInBits() == 1) {
+    // call iM movmsk(sext X) --> zext (bitcast X to iN) to iM
+    unsigned NumElts = BoolSrc->getType()->getVectorNumElements();
+    Type *ScalarTy = Type::getIntNTy(Arg->getContext(), NumElts);
+    Value *BC = Builder.CreateBitCast(BoolSrc, ScalarTy);
+    return Builder.CreateZExtOrTrunc(BC, ResTy);
   }
 
-  return Constant::getIntegerValue(ResTy, Result);
+  return nullptr;
 }
 
 static Value *simplifyX86insertps(const IntrinsicInst &II,
@@ -2543,7 +2561,7 @@
   case Intrinsic::x86_avx_movmsk_pd_256:
   case Intrinsic::x86_avx_movmsk_ps_256:
   case Intrinsic::x86_avx2_pmovmskb:
-    if (Value *V = simplifyX86movmsk(*II))
+    if (Value *V = simplifyX86movmsk(*II, Builder))
      return replaceInstUsesWith(*II, V);
    break;
 
Index: test/Transforms/InstCombine/X86/x86-movmsk.ll
===================================================================
--- test/Transforms/InstCombine/X86/x86-movmsk.ll
+++ test/Transforms/InstCombine/X86/x86-movmsk.ll
@@ -315,10 +315,9 @@
 
 define i32 @sext_sse_movmsk_ps(<4 x i1> %x) {
 ; CHECK-LABEL: @sext_sse_movmsk_ps(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext <4 x i1> [[X:%.*]] to <4 x i32>
-; CHECK-NEXT:    [[BC:%.*]] = bitcast <4 x i32> [[SEXT]] to <4 x float>
-; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> [[BC]])
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i1> [[X:%.*]] to i4
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i4 [[TMP1]] to i32
+; CHECK-NEXT:    ret i32 [[TMP2]]
 ;
   %sext = sext <4 x i1> %x to <4 x i32>
   %bc = bitcast <4 x i32> %sext to <4 x float>
@@ -328,10 +327,9 @@
 
 define i32 @sext_sse2_movmsk_pd(<2 x i1> %x) {
 ; CHECK-LABEL: @sext_sse2_movmsk_pd(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext <2 x i1> [[X:%.*]] to <2 x i64>
-; CHECK-NEXT:    [[BC:%.*]] = bitcast <2 x i64> [[SEXT]] to <2 x double>
-; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.sse2.movmsk.pd(<2 x double> [[BC]])
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i1> [[X:%.*]] to i2
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i2 [[TMP1]] to i32
+; CHECK-NEXT:    ret i32 [[TMP2]]
 ;
   %sext = sext <2 x i1> %x to <2 x i64>
   %bc = bitcast <2 x i64> %sext to <2 x double>
@@ -341,9 +339,9 @@
 
 define i32 @sext_sse2_pmovmskb_128(<16 x i1> %x) {
 ; CHECK-LABEL: @sext_sse2_pmovmskb_128(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext <16 x i1> [[X:%.*]] to <16 x i8>
-; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.sse2.pmovmskb.128(<16 x i8> [[SEXT]])
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <16 x i1> [[X:%.*]] to i16
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i16 [[TMP1]] to i32
+; CHECK-NEXT:    ret i32 [[TMP2]]
 ;
   %sext = sext <16 x i1> %x to <16 x i8>
   %r = call i32 @llvm.x86.sse2.pmovmskb.128(<16 x i8> %sext)
@@ -352,10 +350,9 @@
 
 define i32 @sext_avx_movmsk_ps_256(<8 x i1> %x) {
 ; CHECK-LABEL: @sext_avx_movmsk_ps_256(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext <8 x i1> [[X:%.*]] to <8 x i32>
-; CHECK-NEXT:    [[BC:%.*]] = bitcast <8 x i32> [[SEXT]] to <8 x float>
-; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.avx.movmsk.ps.256(<8 x float> [[BC]])
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <8 x i1> [[X:%.*]] to i8
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i8 [[TMP1]] to i32
+; CHECK-NEXT:    ret i32 [[TMP2]]
 ;
   %sext = sext <8 x i1> %x to <8 x i32>
   %bc = bitcast <8 x i32> %sext to <8 x float>
@@ -365,10 +362,9 @@
 
 define i32 @sext_avx_movmsk_pd_256(<4 x i1> %x) {
 ; CHECK-LABEL: @sext_avx_movmsk_pd_256(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext <4 x i1> [[X:%.*]] to <4 x i64>
-; CHECK-NEXT:    [[BC:%.*]] = bitcast <4 x i64> [[SEXT]] to <4 x double>
-; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.avx.movmsk.pd.256(<4 x double> [[BC]])
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i1> [[X:%.*]] to i4
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i4 [[TMP1]] to i32
+; CHECK-NEXT:    ret i32 [[TMP2]]
 ;
   %sext = sext <4 x i1> %x to <4 x i64>
   %bc = bitcast <4 x i64> %sext to <4 x double>
@@ -378,15 +374,60 @@
 
 define i32 @sext_avx2_pmovmskb(<32 x i1> %x) {
 ; CHECK-LABEL: @sext_avx2_pmovmskb(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext <32 x i1> [[X:%.*]] to <32 x i8>
-; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.avx2.pmovmskb(<32 x i8> [[SEXT]])
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <32 x i1> [[X:%.*]] to i32
+; CHECK-NEXT:    ret i32 [[TMP1]]
 ;
   %sext = sext <32 x i1> %x to <32 x i8>
   %r = call i32 @llvm.x86.avx2.pmovmskb(<32 x i8> %sext)
   ret i32 %r
 }
 
+; Negative test - bitcast from scalar.
+
+define i32 @sext_sse_movmsk_ps_scalar_source(i1 %x) {
+; CHECK-LABEL: @sext_sse_movmsk_ps_scalar_source(
+; CHECK-NEXT:    [[SEXT:%.*]] = sext i1 [[X:%.*]] to i128
+; CHECK-NEXT:    [[BC:%.*]] = bitcast i128 [[SEXT]] to <4 x float>
+; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> [[BC]])
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %sext = sext i1 %x to i128
+  %bc = bitcast i128 %sext to <4 x float>
+  %r = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> %bc)
+  ret i32 %r
+}
+
+; Negative test - bitcast from vector type with more elements.
+
+define i32 @sext_sse_movmsk_ps_too_many_elts(<8 x i1> %x) {
+; CHECK-LABEL: @sext_sse_movmsk_ps_too_many_elts(
+; CHECK-NEXT:    [[SEXT:%.*]] = sext <8 x i1> [[X:%.*]] to <8 x i16>
+; CHECK-NEXT:    [[BC:%.*]] = bitcast <8 x i16> [[SEXT]] to <4 x float>
+; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> [[BC]])
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %sext = sext <8 x i1> %x to <8 x i16>
+  %bc = bitcast <8 x i16> %sext to <4 x float>
+  %r = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> %bc)
+  ret i32 %r
+}
+
+; TODO: We could handle this by doing a bitcasted sign-bit test after the sext?
+; But need to make sure the backend handles that correctly.
+
+define i32 @sext_sse_movmsk_ps_must_replicate_bits(<2 x i1> %x) {
+; CHECK-LABEL: @sext_sse_movmsk_ps_must_replicate_bits(
+; CHECK-NEXT:    [[SEXT:%.*]] = sext <2 x i1> [[X:%.*]] to <2 x i64>
+; CHECK-NEXT:    [[BC:%.*]] = bitcast <2 x i64> [[SEXT]] to <4 x float>
+; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> [[BC]])
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %sext = sext <2 x i1> %x to <2 x i64>
+  %bc = bitcast <2 x i64> %sext to <4 x float>
+  %r = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> %bc)
+  ret i32 %r
+}
+
 declare i32 @llvm.x86.mmx.pmovmskb(x86_mmx)
 declare i32 @llvm.x86.sse.movmsk.ps(<4 x float>)
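Note: for reference only (not part of the patch), the new fold written out in LLVM IR, mirroring the sext_sse_movmsk_ps test above; the value names here are illustrative:

  ; Before: movmsk of a sign-extended <4 x i1> bool vector.
  %sext = sext <4 x i1> %x to <4 x i32>
  %bc = bitcast <4 x i32> %sext to <4 x float>
  %r = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> %bc)

  ; After: pack the i1 lanes directly and zero-extend to the i32 result.
  %mask = bitcast <4 x i1> %x to i4
  %r = zext i4 %mask to i32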