diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -4806,6 +4806,48 @@
   return nullptr;
 }
 
+// Canonicalize checking for a power-of-2-or-zero value:
+//   (A & (A-1)) == 0 or (A & -A) == A  -->  ctpop(A) < 2
+//   (A & (A-1)) != 0 or (A & -A) != A  -->  ctpop(A) > 1
+// Only eq/ne predicates are folded; the patterns are not equivalent to a
+// ctpop test under relational predicates.
+static Instruction *foldICmpPow2Test(ICmpInst &I,
+                                     InstCombiner::BuilderTy &Builder) {
+  // This fold is reached for every icmp (not only via foldICmpEquality,
+  // which filters out relational predicates), so reject non-equality
+  // predicates here. E.g. `icmp slt (A & (A-1)), 0` is not `icmp ne`.
+  if (!I.isEquality())
+    return nullptr;
+
+  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+  const CmpInst::Predicate Pred = I.getPredicate();
+  Value *A = nullptr;
+  // (A & (A-1)) == 0 --> ctpop(A) < 2 (two commuted variants)
+  // ((A-1) & A) != 0 --> ctpop(A) > 1 (two commuted variants)
+  if (!match(Op0, m_OneUse(m_c_And(m_Add(m_Value(A), m_AllOnes()),
+                                   m_Deferred(A)))) ||
+      !match(Op1, m_ZeroInt()))
+    A = nullptr;
+
+  // (A & -A) == A --> ctpop(A) < 2 (four commuted variants)
+  // (-A & A) != A --> ctpop(A) > 1 (four commuted variants)
+  if (match(Op0, m_OneUse(m_c_And(m_Neg(m_Specific(Op1)), m_Specific(Op1)))))
+    A = Op1;
+  else if (match(Op1,
+                 m_OneUse(m_c_And(m_Neg(m_Specific(Op0)), m_Specific(Op0)))))
+    A = Op0;
+
+  if (!A)
+    return nullptr;
+
+  // Pred is EQ or NE here, per the guard at the top of the function.
+  Type *Ty = A->getType();
+  CallInst *CtPop = Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, A);
+  return Pred == ICmpInst::ICMP_EQ
+             ? new ICmpInst(ICmpInst::ICMP_ULT, CtPop, ConstantInt::get(Ty, 2))
+             : new ICmpInst(ICmpInst::ICMP_UGT, CtPop, ConstantInt::get(Ty, 1));
+}
+
 Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
   if (!I.isEquality())
     return nullptr;
@@ -4991,30 +5033,6 @@
   if (Instruction *ICmp = foldICmpIntrinsicWithIntrinsic(I, Builder))
     return ICmp;
 
-  // Canonicalize checking for a power-of-2-or-zero value:
-  // (A & (A-1)) == 0 --> ctpop(A) < 2 (two commuted variants)
-  // ((A-1) & A) != 0 --> ctpop(A) > 1 (two commuted variants)
-  if (!match(Op0, m_OneUse(m_c_And(m_Add(m_Value(A), m_AllOnes()),
-                                   m_Deferred(A)))) ||
-      !match(Op1, m_ZeroInt()))
-    A = nullptr;
-
-  // (A & -A) == A --> ctpop(A) < 2 (four commuted variants)
-  // (-A & A) != A --> ctpop(A) > 1 (four commuted variants)
-  if (match(Op0, m_OneUse(m_c_And(m_Neg(m_Specific(Op1)), m_Specific(Op1)))))
-    A = Op1;
-  else if (match(Op1,
-                 m_OneUse(m_c_And(m_Neg(m_Specific(Op0)), m_Specific(Op0)))))
-    A = Op0;
-
-  if (A) {
-    Type *Ty = A->getType();
-    CallInst *CtPop = Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, A);
-    return Pred == ICmpInst::ICMP_EQ
-        ? new ICmpInst(ICmpInst::ICMP_ULT, CtPop, ConstantInt::get(Ty, 2))
-        : new ICmpInst(ICmpInst::ICMP_UGT, CtPop, ConstantInt::get(Ty, 1));
-  }
-
   // Match icmp eq (trunc (lshr A, BW), (ashr (trunc A), BW-1)), which checks the
   // top BW/2 + 1 bits are all the same. Create "A >=s INT_MIN && A <=s INT_MAX",
   // which we generate as "icmp ult (add A, 2^(BW-1)), 2^BW" to skip a few steps
@@ -6740,6 +6758,9 @@
   if (Instruction *Res = foldICmpEquality(I))
     return Res;
 
+  if (Instruction *Res = foldICmpPow2Test(I, Builder))
+    return Res;
+
   if (Instruction *Res = foldICmpOfUAddOv(I))
     return Res;
 