diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -746,6 +746,45 @@ return Builder.CreateICmp(NewPred, Input, RangeEnd); } +// (or (icmp eq X, 0), (icmp eq X, Pow2OrZero)) +// -> (icmp eq (and X, Pow2OrZero), X) +// (and (icmp ne X, 0), (icmp ne X, Pow2OrZero)) +// -> (icmp ne (and X, Pow2OrZero), X) +static Value *foldAndOrOfICmpsWithPow2AndWithZero( + InstCombiner::BuilderTy &Builder, InstCombinerImpl::ICmpComponents LHS, + InstCombinerImpl::ICmpComponents RHS, bool IsAnd, const SimplifyQuery &Q) { + CmpInst::Predicate Pred = IsAnd ? CmpInst::ICMP_NE : CmpInst::ICMP_EQ; + // Make sure we have the right compares for our op. + if (LHS.getPredicate() != Pred || RHS.getPredicate() != Pred) + return nullptr; + + // Make it so we can match LHS against the (icmp eq/ne X, 0) just for + // simplicity. 
+ if (match(RHS.getOperand(1), m_Zero())) + std::swap(LHS, RHS); + + Value *Pow2, *Op; + // Match the desired pattern: + // LHS: (icmp eq/ne X, 0) + // RHS: (icmp eq/ne X, Pow2OrZero) + Op = LHS.getOperand(0); + if (!match(LHS.getOperand(1), m_Zero())) + return nullptr; + if (RHS.getOperand(0) == Op) + Pow2 = RHS.getOperand(1); + else if (RHS.getOperand(1) == Op) + Pow2 = RHS.getOperand(0); + else + return nullptr; + + if (!isKnownToBeAPowerOfTwo(Pow2, Q.DL, /*OrZero=*/ true, /*Depth=*/ 0, Q.AC, + Q.CxtI, Q.DT)) + return nullptr; + + Value *And = Builder.CreateAnd(Op, Pow2); + return Builder.CreateICmp(Pred, And, Op); +} + // Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2) // Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2) Value *InstCombinerImpl::foldAndOrOfICmpsOfAndWithPow2( @@ -3114,9 +3153,15 @@ if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, &I, IsAnd, IsLogical)) return V; + if (!IsLogical) + if (Value *V = + foldAndOrOfICmpsWithPow2AndWithZero(Builder, LHS, RHS, IsAnd, Q)) + return V; + ICmpInst::Predicate PredL = LHS.getPredicate(), PredR = RHS.getPredicate(); Value *LHS0 = LHS.getOperand(0), *RHS0 = RHS.getOperand(0); Value *LHS1 = LHS.getOperand(1), *RHS1 = RHS.getOperand(1); + const APInt *LHSC = nullptr, *RHSC = nullptr; match(LHS1, m_APInt(LHSC)); match(RHS1, m_APInt(RHSC)); diff --git a/llvm/test/Transforms/InstCombine/and-or-icmps.ll b/llvm/test/Transforms/InstCombine/and-or-icmps.ll --- a/llvm/test/Transforms/InstCombine/and-or-icmps.ll +++ b/llvm/test/Transforms/InstCombine/and-or-icmps.ll @@ -3051,9 +3051,8 @@ ; CHECK-LABEL: @icmp_eq_or_z_or_pow2orz( ; CHECK-NEXT: [[NY:%.*]] = sub i8 0, [[Y:%.*]] ; CHECK-NEXT: [[POW2ORZ:%.*]] = and i8 [[NY]], [[Y]] -; CHECK-NEXT: [[C0:%.*]] = icmp eq i8 [[X:%.*]], 0 -; CHECK-NEXT: [[CP2:%.*]] = icmp eq i8 [[POW2ORZ]], [[X]] -; CHECK-NEXT: [[R:%.*]] = or i1 [[C0]], [[CP2]] +; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[POW2ORZ]], [[X:%.*]] +; CHECK-NEXT: [[R:%.*]] = 
icmp eq i8 [[TMP1]], [[X]] ; CHECK-NEXT: ret i1 [[R]] ; %ny = sub i8 0, %y @@ -3089,9 +3088,8 @@ ; CHECK-LABEL: @icmp_ne_and_z_and_pow2orz( ; CHECK-NEXT: [[NY:%.*]] = sub <2 x i8> zeroinitializer, [[Y:%.*]] ; CHECK-NEXT: [[POW2ORZ:%.*]] = and <2 x i8> [[NY]], [[Y]] -; CHECK-NEXT: [[C0:%.*]] = icmp ne <2 x i8> [[X:%.*]], zeroinitializer -; CHECK-NEXT: [[CP2:%.*]] = icmp ne <2 x i8> [[POW2ORZ]], [[X]] -; CHECK-NEXT: [[R:%.*]] = and <2 x i1> [[C0]], [[CP2]] +; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i8> [[POW2ORZ]], [[X:%.*]] +; CHECK-NEXT: [[R:%.*]] = icmp ne <2 x i8> [[TMP1]], [[X]] ; CHECK-NEXT: ret <2 x i1> [[R]] ; %ny = sub <2 x i8> zeroinitializer, %y