diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -911,6 +911,33 @@
                            CxtI.getName() + ".simplified");
 }
 
+/// Fold (icmp eq ctpop(X) 1) | (icmp eq X 0) into (icmp ult ctpop(X) 2).
+///
+/// X is zero or a power of two exactly when it has at most one bit set, so
+/// the two equality compares collapse into one unsigned compare of the
+/// population count. \p Cmp0 must be the ctpop compare and \p Cmp1 the zero
+/// compare; callers try both operand orders. Returns the new compare, or
+/// nullptr if the pattern does not match.
+static Value *foldOrOfCtpop(ICmpInst *Cmp0, ICmpInst *Cmp1,
+                            InstCombiner::BuilderTy &Builder) {
+  CmpInst::Predicate Pred0, Pred1;
+  Value *X;
+  // Cmp0: icmp Pred0 (ctpop X), 1. Only ctpop qualifies here; the fold is not
+  // valid for an arbitrary intrinsic call.
+  if (!match(Cmp0, m_ICmp(Pred0, m_Intrinsic<Intrinsic::ctpop>(m_Value(X)),
+                          m_SpecificInt(1))))
+    return nullptr;
+  // Cmp1: icmp Pred1 X, 0 on the same X that feeds the ctpop.
+  if (!match(Cmp1, m_ICmp(Pred1, m_Specific(X), m_ZeroInt())))
+    return nullptr;
+  if (Pred0 != ICmpInst::ICMP_EQ || Pred1 != ICmpInst::ICMP_EQ)
+    return nullptr;
+
+  // The match above pins the ctpop call to operand 0 of Cmp0.
+  Value *CtPop = Cmp0->getOperand(0);
+  return Builder.CreateICmpULT(CtPop, ConstantInt::get(CtPop->getType(), 2));
+}
+
 /// Reduce a pair of compares that check if a value has exactly 1 bit set.
 static Value *foldIsPowerOf2(ICmpInst *Cmp0, ICmpInst *Cmp1, bool JoinedByAnd,
                              InstCombiner::BuilderTy &Builder) {
@@ -2595,6 +2622,11 @@
   if (Value *V = foldAndOrOfICmpsWithConstEq(RHS, LHS, Or, Builder, Q))
     return V;
 
+  if (Value *V = foldOrOfCtpop(LHS, RHS, Builder))
+    return V;
+  if (Value *V = foldOrOfCtpop(RHS, LHS, Builder))
+    return V;
+
   // E.g. (icmp slt x, 0) | (icmp sgt x, n) --> icmp ugt x, n
   if (Value *V = simplifyRangeCheck(LHS, RHS, /*Inverted=*/true))
     return V;
diff --git a/llvm/test/Transforms/InstCombine/icmp-or.ll b/llvm/test/Transforms/InstCombine/icmp-or.ll
--- a/llvm/test/Transforms/InstCombine/icmp-or.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-or.ll
@@ -363,3 +363,37 @@
   %r = icmp sgt i8 %or, -1
   ret i1 %r
 }
+
+declare i32 @llvm.ctpop.i32(i32)
+
+; (icmp eq ctpop(X) 1) | (icmp eq X 0) --> icmp ult ctpop(X) 2
+
+define i1 @ctpop_or_eq_to_ctpop(i32 %0) {
+; CHECK-LABEL: @ctpop_or_eq_to_ctpop(
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.ctpop.i32(i32 [[TMP0:%.*]]), !range [[RNG0:![0-9]+]]
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp ult i32 [[TMP2]], 2
+; CHECK-NEXT:    ret i1 [[TMP3]]
+;
+  %2 = call i32 @llvm.ctpop.i32(i32 %0)
+  %3 = icmp eq i32 %2, 1
+  %4 = icmp eq i32 %0, 0
+  %5 = or i1 %4, %3
+  ret i1 %5
+}
+
+; negative test - wrong constants
+
+define i1 @not_ctpop_or_eq_to_ctpop(i32 %0) {
+; CHECK-LABEL: @not_ctpop_or_eq_to_ctpop(
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.ctpop.i32(i32 [[TMP0:%.*]]), !range [[RNG0]]
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i32 [[TMP2]], 2
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp eq i32 [[TMP0]], 0
+; CHECK-NEXT:    [[TMP5:%.*]] = or i1 [[TMP4]], [[TMP3]]
+; CHECK-NEXT:    ret i1 [[TMP5]]
+;
+  %2 = call i32 @llvm.ctpop.i32(i32 %0)
+  %3 = icmp eq i32 %2, 2
+  %4 = icmp eq i32 %0, 0
+  %5 = or i1 %4, %3
+  ret i1 %5
+}
diff --git a/llvm/test/Transforms/InstCombine/ispow2.ll b/llvm/test/Transforms/InstCombine/ispow2.ll
--- a/llvm/test/Transforms/InstCombine/ispow2.ll
+++ b/llvm/test/Transforms/InstCombine/ispow2.ll
@@ -540,10 +540,8 @@
 define i1 @isnot_pow2_ctpop_wrong_pred1(i32 %x) {
 ; CHECK-LABEL: @isnot_pow2_ctpop_wrong_pred1(
 ; CHECK-NEXT:    [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[T0]], 1
-; CHECK-NEXT:    [[ISZERO:%.*]] = icmp eq i32 [[X]], 0
-; CHECK-NEXT:    [[R:%.*]] = or i1 [[ISZERO]], [[CMP]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ult i32 [[T0]], 2
+; CHECK-NEXT:    ret i1 [[TMP1]]
 ;
   %t0 = tail call i32 @llvm.ctpop.i32(i32 %x)
   %cmp = icmp eq i32 %t0, 1
@@ -555,10 +553,8 @@
 define i1 @isnot_pow2_ctpop_wrong_pred1_logical(i32 %x) {
 ; CHECK-LABEL: @isnot_pow2_ctpop_wrong_pred1_logical(
 ; CHECK-NEXT:    [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[T0]], 1
-; CHECK-NEXT:    [[ISZERO:%.*]] = icmp eq i32 [[X]], 0
-; CHECK-NEXT:    [[R:%.*]] = or i1 [[ISZERO]], [[CMP]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ult i32 [[T0]], 2
+; CHECK-NEXT:    ret i1 [[TMP1]]
 ;
   %t0 = tail call i32 @llvm.ctpop.i32(i32 %x)
   %cmp = icmp eq i32 %t0, 1