diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -5863,8 +5863,7 @@ /// If one operand of an icmp is effectively a bool (value range of {0,1}), /// then try to reduce patterns based on that limit. -static Instruction *foldICmpUsingBoolRange(ICmpInst &I, - InstCombiner::BuilderTy &Builder) { +Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) { Value *X, *Y; ICmpInst::Predicate Pred; @@ -5880,6 +5879,59 @@ Y->getType()->isIntOrIntVectorTy(1) && Pred == ICmpInst::ICMP_ULE) return BinaryOperator::CreateOr(Builder.CreateIsNull(X), Y); + const APInt *C; + if (match(I.getOperand(0), + m_OneUse(m_c_Add(m_ZExt(m_Value(X)), m_SExt(m_Value(Y))))) && + match(I.getOperand(1), m_APInt(C)) && + X->getType()->isIntOrIntVectorTy(1) && + Y->getType()->isIntOrIntVectorTy(1)) { + unsigned BitWidth = C->getBitWidth(); + Pred = I.getPredicate(); + APInt Zero = APInt::getZero(BitWidth); + APInt MinusOne = APInt::getAllOnes(BitWidth); + APInt One(BitWidth, 1); + if ((C->sgt(Zero) && Pred == ICmpInst::ICMP_SGT) || + (C->slt(Zero) && Pred == ICmpInst::ICMP_SLT)) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + if ((C->sgt(One) && Pred == ICmpInst::ICMP_SLT) || + (C->slt(MinusOne) && Pred == ICmpInst::ICMP_SGT)) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + + APInt NewC = *C; + // canonicalize predicate to eq/ne + if ((*C == Zero && Pred == ICmpInst::ICMP_SLT) || + (*C != Zero && *C != MinusOne && Pred == ICmpInst::ICMP_UGT)) { + // x s< 0 in [-1, 1] --> x == -1 + // x u> 1(or any const !=0 !=-1) in [-1, 1] --> x == -1 + NewC = MinusOne; + Pred = ICmpInst::ICMP_EQ; + } else if ((*C == MinusOne && Pred == ICmpInst::ICMP_SGT) || + (*C != Zero && *C != One && Pred == ICmpInst::ICMP_ULT)) { + // x s> -1 in [-1, 1] --> x != -1 + // x 
u< -1 in [-1, 1] --> x != -1 + Pred = ICmpInst::ICMP_NE; + } else if (*C == Zero && Pred == ICmpInst::ICMP_SGT) { + // x s> 0 in [-1, 1] --> x == 1 + NewC = One; + Pred = ICmpInst::ICMP_EQ; + } else if (*C == One && Pred == ICmpInst::ICMP_SLT) { + // x s< 1 in [-1, 1] --> x != 1 + Pred = ICmpInst::ICMP_NE; + } + + if (NewC == MinusOne) { + if (Pred == ICmpInst::ICMP_EQ) + return BinaryOperator::CreateAnd(Builder.CreateNot(X), Y); + if (Pred == ICmpInst::ICMP_NE) + return BinaryOperator::CreateOr(X, Builder.CreateNot(Y)); + } else if (NewC == One) { + if (Pred == ICmpInst::ICMP_EQ) + return BinaryOperator::CreateAnd(X, Builder.CreateNot(Y)); + if (Pred == ICmpInst::ICMP_NE) + return BinaryOperator::CreateOr(Builder.CreateNot(X), Y); + } + } + return nullptr; } @@ -6335,7 +6387,7 @@ if (Instruction *Res = foldICmpWithDominatingICmp(I)) return Res; - if (Instruction *Res = foldICmpUsingBoolRange(I, Builder)) + if (Instruction *Res = foldICmpUsingBoolRange(I)) return Res; if (Instruction *Res = foldICmpUsingKnownBits(I)) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -565,6 +565,7 @@ Instruction *foldICmpUsingKnownBits(ICmpInst &Cmp); Instruction *foldICmpWithDominatingICmp(ICmpInst &Cmp); Instruction *foldICmpWithConstant(ICmpInst &Cmp); + Instruction *foldICmpUsingBoolRange(ICmpInst &I); Instruction *foldICmpInstWithConstant(ICmpInst &Cmp); Instruction *foldICmpInstWithConstantNotInt(ICmpInst &Cmp); Instruction *foldICmpInstWithConstantAllowUndef(ICmpInst &Cmp, diff --git a/llvm/test/Transforms/InstCombine/icmp-range.ll b/llvm/test/Transforms/InstCombine/icmp-range.ll --- a/llvm/test/Transforms/InstCombine/icmp-range.ll +++ b/llvm/test/Transforms/InstCombine/icmp-range.ll @@ -629,13 +629,11 @@ ret i1 %r } +; ((zext i1 a) + (sext i1 b)) s< -1 --> false + define i1 
@zext_sext_add_icmp_slt_minus1(i1 %a, i1 %b) { ; CHECK-LABEL: @zext_sext_add_icmp_slt_minus1( -; CHECK-NEXT: [[ZEXT_A:%.*]] = zext i1 [[A:%.*]] to i8 -; CHECK-NEXT: [[SEXT_B:%.*]] = sext i1 [[B:%.*]] to i8 -; CHECK-NEXT: [[ADD:%.*]] = add nsw i8 [[ZEXT_A]], [[SEXT_B]] -; CHECK-NEXT: [[R:%.*]] = icmp slt i8 [[ADD]], -1 -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 false ; %zext.a = zext i1 %a to i8 %sext.b = sext i1 %b to i8 @@ -644,13 +642,11 @@ ret i1 %r } +; ((zext i1 a) + (sext i1 b)) s> 1 --> false + define i1 @zext_sext_add_icmp_sgt_1(i1 %a, i1 %b) { ; CHECK-LABEL: @zext_sext_add_icmp_sgt_1( -; CHECK-NEXT: [[ZEXT_A:%.*]] = zext i1 [[A:%.*]] to i8 -; CHECK-NEXT: [[SEXT_B:%.*]] = sext i1 [[B:%.*]] to i8 -; CHECK-NEXT: [[ADD:%.*]] = add nsw i8 [[ZEXT_A]], [[SEXT_B]] -; CHECK-NEXT: [[R:%.*]] = icmp sgt i8 [[ADD]], 1 -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 false ; %zext.a = zext i1 %a to i8 %sext.b = sext i1 %b to i8 @@ -659,13 +655,11 @@ ret i1 %r } +; ((zext i1 a) + (sext i1 b)) s> -2 --> true + define i1 @zext_sext_add_icmp_sgt_minus2(i1 %a, i1 %b) { ; CHECK-LABEL: @zext_sext_add_icmp_sgt_minus2( -; CHECK-NEXT: [[ZEXT_A:%.*]] = zext i1 [[A:%.*]] to i8 -; CHECK-NEXT: [[SEXT_B:%.*]] = sext i1 [[B:%.*]] to i8 -; CHECK-NEXT: [[ADD:%.*]] = add nsw i8 [[ZEXT_A]], [[SEXT_B]] -; CHECK-NEXT: [[R:%.*]] = icmp sgt i8 [[ADD]], -2 -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 true ; %zext.a = zext i1 %a to i8 %sext.b = sext i1 %b to i8 @@ -674,13 +668,11 @@ ret i1 %r } +; ((zext i1 a) + (sext i1 b)) s< 2 --> true + define i1 @zext_sext_add_icmp_slt_2(i1 %a, i1 %b) { ; CHECK-LABEL: @zext_sext_add_icmp_slt_2( -; CHECK-NEXT: [[ZEXT_A:%.*]] = zext i1 [[A:%.*]] to i8 -; CHECK-NEXT: [[SEXT_B:%.*]] = sext i1 [[B:%.*]] to i8 -; CHECK-NEXT: [[ADD:%.*]] = add nsw i8 [[ZEXT_A]], [[SEXT_B]] -; CHECK-NEXT: [[R:%.*]] = icmp slt i8 [[ADD]], 2 -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 true ; %zext.a = zext i1 %a to i8 %sext.b = sext i1 %b to i8 @@ -689,13 +681,11 @@ 
ret i1 %r } +; i128 operands: the compare still folds to false + define i1 @zext_sext_add_icmp_i128(i1 %a, i1 %b) { ; CHECK-LABEL: @zext_sext_add_icmp_i128( -; CHECK-NEXT: [[ZEXT_A:%.*]] = zext i1 [[A:%.*]] to i128 -; CHECK-NEXT: [[SEXT_B:%.*]] = sext i1 [[B:%.*]] to i128 -; CHECK-NEXT: [[ADD:%.*]] = add nsw i128 [[ZEXT_A]], [[SEXT_B]] -; CHECK-NEXT: [[R:%.*]] = icmp sgt i128 [[ADD]], 9223372036854775808 -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 false ; %zext.a = zext i1 %a to i128 %sext.b = sext i1 %b to i128 @@ -704,12 +694,12 @@ ret i1 %r } +; ((zext i1 a) + (sext i1 b)) == -1 --> ~a & b + define i1 @zext_sext_add_icmp_eq_minus1(i1 %a, i1 %b) { ; CHECK-LABEL: @zext_sext_add_icmp_eq_minus1( -; CHECK-NEXT: [[ZEXT_A:%.*]] = zext i1 [[A:%.*]] to i8 -; CHECK-NEXT: [[SEXT_B:%.*]] = sext i1 [[B:%.*]] to i8 -; CHECK-NEXT: [[ADD:%.*]] = add nsw i8 [[ZEXT_A]], [[SEXT_B]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[ADD]], -1 +; CHECK-NEXT: [[TMP1:%.*]] = xor i1 [[A:%.*]], true +; CHECK-NEXT: [[R:%.*]] = and i1 [[TMP1]], [[B:%.*]] ; CHECK-NEXT: ret i1 [[R]] ; %zext.a = zext i1 %a to i8 @@ -719,12 +709,13 @@ ret i1 %r } + +; ((zext i1 a) + (sext i1 b)) != -1 --> a | ~b + define i1 @zext_sext_add_icmp_ne_minus1(i1 %a, i1 %b) { ; CHECK-LABEL: @zext_sext_add_icmp_ne_minus1( -; CHECK-NEXT: [[ZEXT_A:%.*]] = zext i1 [[A:%.*]] to i8 -; CHECK-NEXT: [[SEXT_B:%.*]] = sext i1 [[B:%.*]] to i8 -; CHECK-NEXT: [[ADD:%.*]] = add nsw i8 [[ZEXT_A]], [[SEXT_B]] -; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[ADD]], -1 +; CHECK-NEXT: [[TMP1:%.*]] = xor i1 [[B:%.*]], true +; CHECK-NEXT: [[R:%.*]] = or i1 [[TMP1]], [[A:%.*]] ; CHECK-NEXT: ret i1 [[R]] ; %zext.a = zext i1 %a to i8 @@ -734,10 +725,12 @@ ret i1 %r } +; ((zext i1 a) + (sext i1 b)) s> -1 --> a | ~b + define i1 @zext_sext_add_icmp_sgt_minus1(i1 %a, i1 %b) { ; CHECK-LABEL: @zext_sext_add_icmp_sgt_minus1( -; CHECK-NEXT: [[B_NOT:%.*]] = xor i1 [[B:%.*]], true -; CHECK-NEXT: [[R:%.*]] = or i1 [[B_NOT]], [[A:%.*]] +; CHECK-NEXT: [[TMP1:%.*]] = xor i1 [[B:%.*]], true +; 
CHECK-NEXT: [[R:%.*]] = or i1 [[TMP1]], [[A:%.*]] ; CHECK-NEXT: ret i1 [[R]] ; %zext.a = zext i1 %a to i8 @@ -747,12 +740,12 @@ ret i1 %r } +; ((zext i1 a) + (sext i1 b)) u< -1 --> a | ~b + define i1 @zext_sext_add_icmp_ult_minus1(i1 %a, i1 %b) { ; CHECK-LABEL: @zext_sext_add_icmp_ult_minus1( -; CHECK-NEXT: [[ZEXT_A:%.*]] = zext i1 [[A:%.*]] to i8 -; CHECK-NEXT: [[SEXT_B:%.*]] = sext i1 [[B:%.*]] to i8 -; CHECK-NEXT: [[ADD:%.*]] = add nsw i8 [[ZEXT_A]], [[SEXT_B]] -; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[ADD]], -1 +; CHECK-NEXT: [[TMP1:%.*]] = xor i1 [[B:%.*]], true +; CHECK-NEXT: [[R:%.*]] = or i1 [[TMP1]], [[A:%.*]] ; CHECK-NEXT: ret i1 [[R]] ; %zext.a = zext i1 %a to i8 @@ -762,12 +755,12 @@ ret i1 %r } +; ((zext i1 a) + (sext i1 b)) s> 0 --> a & ~b + define i1 @zext_sext_add_icmp_sgt_0(i1 %a, i1 %b) { ; CHECK-LABEL: @zext_sext_add_icmp_sgt_0( -; CHECK-NEXT: [[ZEXT_A:%.*]] = zext i1 [[A:%.*]] to i8 -; CHECK-NEXT: [[SEXT_B:%.*]] = sext i1 [[B:%.*]] to i8 -; CHECK-NEXT: [[ADD:%.*]] = add nsw i8 [[ZEXT_A]], [[SEXT_B]] -; CHECK-NEXT: [[R:%.*]] = icmp sgt i8 [[ADD]], 0 +; CHECK-NEXT: [[TMP1:%.*]] = xor i1 [[B:%.*]], true +; CHECK-NEXT: [[R:%.*]] = and i1 [[TMP1]], [[A:%.*]] ; CHECK-NEXT: ret i1 [[R]] ; %zext.a = zext i1 %a to i8 @@ -777,11 +770,13 @@ ret i1 %r } +; ((zext i1 a) + (sext i1 b)) s< 0 --> ~a & b + define i1 @zext_sext_add_icmp_slt_0(i1 %a, i1 %b) { ; CHECK-LABEL: @zext_sext_add_icmp_slt_0( ; CHECK-NEXT: [[TMP1:%.*]] = xor i1 [[A:%.*]], true -; CHECK-NEXT: [[TMP2:%.*]] = and i1 [[TMP1]], [[B:%.*]] -; CHECK-NEXT: ret i1 [[TMP2]] +; CHECK-NEXT: [[R:%.*]] = and i1 [[TMP1]], [[B:%.*]] +; CHECK-NEXT: ret i1 [[R]] ; %zext.a = zext i1 %a to i8 %sext.b = sext i1 %b to i8 @@ -790,12 +785,12 @@ ret i1 %r } +; ((zext i1 a) + (sext i1 b)) == 1 --> a & ~b + define i1 @zext_sext_add_icmp_eq_1(i1 %a, i1 %b) { ; CHECK-LABEL: @zext_sext_add_icmp_eq_1( -; CHECK-NEXT: [[ZEXT_A:%.*]] = zext i1 [[A:%.*]] to i8 -; CHECK-NEXT: [[SEXT_B:%.*]] = sext i1 [[B:%.*]] to i8 -; CHECK-NEXT: 
[[ADD:%.*]] = add nsw i8 [[ZEXT_A]], [[SEXT_B]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[ADD]], 1 +; CHECK-NEXT: [[TMP1:%.*]] = xor i1 [[B:%.*]], true +; CHECK-NEXT: [[R:%.*]] = and i1 [[TMP1]], [[A:%.*]] ; CHECK-NEXT: ret i1 [[R]] ; %zext.a = zext i1 %a to i8 @@ -805,12 +800,12 @@ ret i1 %r } +; ((zext i1 a) + (sext i1 b)) != 1 --> ~a | b + define i1 @zext_sext_add_icmp_ne_1(i1 %a, i1 %b) { ; CHECK-LABEL: @zext_sext_add_icmp_ne_1( -; CHECK-NEXT: [[ZEXT_A:%.*]] = zext i1 [[A:%.*]] to i8 -; CHECK-NEXT: [[SEXT_B:%.*]] = sext i1 [[B:%.*]] to i8 -; CHECK-NEXT: [[ADD:%.*]] = add nsw i8 [[ZEXT_A]], [[SEXT_B]] -; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[ADD]], 1 +; CHECK-NEXT: [[TMP1:%.*]] = xor i1 [[A:%.*]], true +; CHECK-NEXT: [[R:%.*]] = or i1 [[TMP1]], [[B:%.*]] ; CHECK-NEXT: ret i1 [[R]] ; %zext.a = zext i1 %a to i8 @@ -820,12 +815,12 @@ ret i1 %r } +; ((zext i1 a) + (sext i1 b)) s< 1 --> ~a | b + define i1 @zext_sext_add_icmp_slt_1(i1 %a, i1 %b) { ; CHECK-LABEL: @zext_sext_add_icmp_slt_1( -; CHECK-NEXT: [[ZEXT_A:%.*]] = zext i1 [[A:%.*]] to i8 -; CHECK-NEXT: [[SEXT_B:%.*]] = sext i1 [[B:%.*]] to i8 -; CHECK-NEXT: [[ADD:%.*]] = add nsw i8 [[ZEXT_A]], [[SEXT_B]] -; CHECK-NEXT: [[R:%.*]] = icmp slt i8 [[ADD]], 1 +; CHECK-NEXT: [[TMP1:%.*]] = xor i1 [[A:%.*]], true +; CHECK-NEXT: [[R:%.*]] = or i1 [[TMP1]], [[B:%.*]] ; CHECK-NEXT: ret i1 [[R]] ; %zext.a = zext i1 %a to i8 @@ -835,11 +830,13 @@ ret i1 %r } +; ((zext i1 a) + (sext i1 b)) u> 1 --> ~a & b + define i1 @zext_sext_add_icmp_ugt_1(i1 %a, i1 %b) { ; CHECK-LABEL: @zext_sext_add_icmp_ugt_1( ; CHECK-NEXT: [[TMP1:%.*]] = xor i1 [[A:%.*]], true -; CHECK-NEXT: [[TMP2:%.*]] = and i1 [[TMP1]], [[B:%.*]] -; CHECK-NEXT: ret i1 [[TMP2]] +; CHECK-NEXT: [[R:%.*]] = and i1 [[TMP1]], [[B:%.*]] +; CHECK-NEXT: ret i1 [[R]] ; %zext.a = zext i1 %a to i8 %sext.b = sext i1 %b to i8 @@ -850,10 +847,8 @@ define <2 x i1> @vector_zext_sext_add_icmp_slt_1(<2 x i1> %a, <2 x i1> %b) { ; CHECK-LABEL: @vector_zext_sext_add_icmp_slt_1( -; CHECK-NEXT: 
[[ZEXT_A:%.*]] = zext <2 x i1> [[A:%.*]] to <2 x i8> -; CHECK-NEXT: [[SEXT_B:%.*]] = sext <2 x i1> [[B:%.*]] to <2 x i8> -; CHECK-NEXT: [[ADD:%.*]] = add nsw <2 x i8> [[ZEXT_A]], [[SEXT_B]] -; CHECK-NEXT: [[R:%.*]] = icmp slt <2 x i8> [[ADD]], <i8 1, i8 1> +; CHECK-NEXT: [[TMP1:%.*]] = xor <2 x i1> [[A:%.*]], <i1 true, i1 true> +; CHECK-NEXT: [[R:%.*]] = or <2 x i1> [[TMP1]], [[B:%.*]] ; CHECK-NEXT: ret <2 x i1> [[R]] ; %zext.a = zext <2 x i1> %a to <2 x i8> @@ -878,6 +873,8 @@ ret <2 x i1> %r } +; Negative test, more than one use for icmp LHS + define i1 @zext_sext_add_icmp_slt_1_no_oneuse(i1 %a, i1 %b) { ; CHECK-LABEL: @zext_sext_add_icmp_slt_1_no_oneuse( ; CHECK-NEXT: [[ZEXT_A:%.*]] = zext i1 [[A:%.*]] to i8 @@ -895,6 +892,8 @@ ret i1 %r } +; Negative test, icmp RHS is not a constant + define i1 @zext_sext_add_icmp_slt_1_rhs_not_const(i1 %a, i1 %b, i8 %c) { ; CHECK-LABEL: @zext_sext_add_icmp_slt_1_rhs_not_const( ; CHECK-NEXT: [[ZEXT_A:%.*]] = zext i1 [[A:%.*]] to i8 @@ -910,6 +909,8 @@ ret i1 %r } +; Negative test, ext source is not i1 + define i1 @zext_sext_add_icmp_slt_1_type_not_i1(i2 %a, i1 %b) { ; CHECK-LABEL: @zext_sext_add_icmp_slt_1_type_not_i1( ; CHECK-NEXT: [[ZEXT_A:%.*]] = zext i2 [[A:%.*]] to i8