diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -911,6 +911,26 @@
                               CxtI.getName() + ".simplified");
 }
 
+/// Fold (icmp eq ctpop(X) 1) | (icmp eq X 0) into (icmp ult ctpop(X) 2) and
+/// fold (icmp ne ctpop(X) 1) & (icmp ne X 0) into (icmp uge ctpop(X) 2).
+static Value *foldIsPowerOf2OrZero(ICmpInst *Cmp0, ICmpInst *Cmp1, bool IsAnd,
+                                   InstCombiner::BuilderTy &Builder) {
+  CmpInst::Predicate Pred0, Pred1;
+  Value *X;
+  if (!match(Cmp0, m_ICmp(Pred0, m_Intrinsic<Intrinsic::ctpop>(m_Value(X)),
+                          m_SpecificInt(1))) ||
+      !match(Cmp1, m_ICmp(Pred1, m_Specific(X), m_ZeroInt())))
+    return nullptr;
+
+  Value *CtPop = Cmp0->getOperand(0);
+  if (IsAnd && Pred0 == ICmpInst::ICMP_NE && Pred1 == ICmpInst::ICMP_NE)
+    return Builder.CreateICmpUGE(CtPop, ConstantInt::get(CtPop->getType(), 2));
+  if (!IsAnd && Pred0 == ICmpInst::ICMP_EQ && Pred1 == ICmpInst::ICMP_EQ)
+    return Builder.CreateICmpULT(CtPop, ConstantInt::get(CtPop->getType(), 2));
+
+  return nullptr;
+}
+
 /// Reduce a pair of compares that check if a value has exactly 1 bit set.
 static Value *foldIsPowerOf2(ICmpInst *Cmp0, ICmpInst *Cmp1, bool JoinedByAnd,
                              InstCombiner::BuilderTy &Builder) {
@@ -1237,6 +1257,11 @@
   if (Value *V = foldAndOrOfICmpsWithConstEq(RHS, LHS, And, Builder, Q))
     return V;
 
+  if (Value *V = foldIsPowerOf2OrZero(LHS, RHS, /*IsAnd=*/true, Builder))
+    return V;
+  if (Value *V = foldIsPowerOf2OrZero(RHS, LHS, /*IsAnd=*/true, Builder))
+    return V;
+
   // E.g. 
(icmp sge x, 0) & (icmp slt x, n) --> icmp ult x, n
   if (Value *V = simplifyRangeCheck(LHS, RHS, /*Inverted=*/false))
     return V;
@@ -2595,6 +2620,11 @@
   if (Value *V = foldAndOrOfICmpsWithConstEq(RHS, LHS, Or, Builder, Q))
     return V;
 
+  if (Value *V = foldIsPowerOf2OrZero(LHS, RHS, /*IsAnd=*/false, Builder))
+    return V;
+  if (Value *V = foldIsPowerOf2OrZero(RHS, LHS, /*IsAnd=*/false, Builder))
+    return V;
+
   // E.g. (icmp slt x, 0) | (icmp sgt x, n) --> icmp ugt x, n
   if (Value *V = simplifyRangeCheck(LHS, RHS, /*Inverted=*/true))
     return V;
diff --git a/llvm/test/Transforms/InstCombine/icmp-or.ll b/llvm/test/Transforms/InstCombine/icmp-or.ll
--- a/llvm/test/Transforms/InstCombine/icmp-or.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-or.ll
@@ -363,3 +363,131 @@
   %r = icmp sgt i8 %or, -1
   ret i1 %r
 }
+
+declare i32 @llvm.ctpop.i32(i32)
+
+; (icmp eq ctpop(X) 1) | (icmp eq X 0) --> icmp ult ctpop(X) 2
+
+define i1 @ctpop_or_eq_to_ctpop(i32 %0) {
+; CHECK-LABEL: @ctpop_or_eq_to_ctpop(
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.ctpop.i32(i32 [[TMP0:%.*]]), !range [[RNG0:![0-9]+]]
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp ult i32 [[TMP2]], 2
+; CHECK-NEXT:    ret i1 [[TMP3]]
+;
+  %2 = call i32 @llvm.ctpop.i32(i32 %0)
+  %3 = icmp eq i32 %2, 1
+  %4 = icmp eq i32 %0, 0
+  %5 = or i1 %4, %3
+  ret i1 %5
+}
+
+; negative test - wrong constants
+
+define i1 @not_ctpop_or_eq_to_ctpop(i32 %0) {
+; CHECK-LABEL: @not_ctpop_or_eq_to_ctpop(
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.ctpop.i32(i32 [[TMP0:%.*]]), !range [[RNG0]]
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i32 [[TMP2]], 2
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp eq i32 [[TMP0]], 0
+; CHECK-NEXT:    [[TMP5:%.*]] = or i1 [[TMP4]], [[TMP3]]
+; CHECK-NEXT:    ret i1 [[TMP5]]
+;
+  %2 = call i32 @llvm.ctpop.i32(i32 %0)
+  %3 = icmp eq i32 %2, 2
+  %4 = icmp eq i32 %0, 0
+  %5 = or i1 %4, %3
+  ret i1 %5
+}
+
+declare <2 x i32> @llvm.ctpop.v2i32(<2 x i32>)
+
+define <2 x i1> @ctpop_or_eq_to_ctpop_vec(<2 x i32> %0) {
+; CHECK-LABEL: @ctpop_or_eq_to_ctpop_vec(
+; 
CHECK-NEXT:    [[TMP2:%.*]] = call <2 x i32> @llvm.ctpop.v2i32(<2 x i32> [[TMP0:%.*]])
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp ult <2 x i32> [[TMP2]], <i32 2, i32 2>
+; CHECK-NEXT:    ret <2 x i1> [[TMP3]]
+;
+  %2 = call <2 x i32> @llvm.ctpop.v2i32(<2 x i32> %0)
+  %3 = icmp eq <2 x i32> %2, <i32 1, i32 1>
+  %4 = icmp eq <2 x i32> %0, <i32 0, i32 0>
+  %5 = or <2 x i1> %4, %3
+  ret <2 x i1> %5
+}
+
+; negative test - wrong constants
+
+define <2 x i1> @not_ctpop_or_eq_to_ctpop_vec(<2 x i32> %0) {
+; CHECK-LABEL: @not_ctpop_or_eq_to_ctpop_vec(
+; CHECK-NEXT:    [[TMP2:%.*]] = call <2 x i32> @llvm.ctpop.v2i32(<2 x i32> [[TMP0:%.*]])
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq <2 x i32> [[TMP2]], <i32 2, i32 2>
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp eq <2 x i32> [[TMP0]], zeroinitializer
+; CHECK-NEXT:    [[TMP5:%.*]] = or <2 x i1> [[TMP4]], [[TMP3]]
+; CHECK-NEXT:    ret <2 x i1> [[TMP5]]
+;
+  %2 = call <2 x i32> @llvm.ctpop.v2i32(<2 x i32> %0)
+  %3 = icmp eq <2 x i32> %2, <i32 2, i32 2>
+  %4 = icmp eq <2 x i32> %0, <i32 0, i32 0>
+  %5 = or <2 x i1> %4, %3
+  ret <2 x i1> %5
+}
+
+; (icmp ne ctpop(X) 1) & (icmp ne X 0) --> icmp uge ctpop(X) 2
+
+define i1 @ctpop_and_ne_to_ctpop(i32 %0) {
+; CHECK-LABEL: @ctpop_and_ne_to_ctpop(
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.ctpop.i32(i32 [[TMP0:%.*]]), !range [[RNG0]]
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp ugt i32 [[TMP2]], 1
+; CHECK-NEXT:    ret i1 [[TMP3]]
+;
+  %2 = call i32 @llvm.ctpop.i32(i32 %0)
+  %3 = icmp ne i32 %2, 1
+  %4 = icmp ne i32 %0, 0
+  %5 = and i1 %4, %3
+  ret i1 %5
+}
+
+; negative test - wrong constants
+
+define i1 @not_ctpop_and_ne_to_ctpop(i32 %0) {
+; CHECK-LABEL: @not_ctpop_and_ne_to_ctpop(
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.ctpop.i32(i32 [[TMP0:%.*]]), !range [[RNG0]]
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp ne i32 [[TMP2]], 2
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp ne i32 [[TMP0]], 0
+; CHECK-NEXT:    [[TMP5:%.*]] = and i1 [[TMP4]], [[TMP3]]
+; CHECK-NEXT:    ret i1 [[TMP5]]
+;
+  %2 = call i32 @llvm.ctpop.i32(i32 %0)
+  %3 = icmp ne i32 %2, 2
+  %4 = icmp ne i32 %0, 0
+  %5 = and i1 %4, %3
+  ret i1 %5
+}
+
+define <2 x i1> 
@ctpop_and_ne_to_ctpop_vec(<2 x i32> %0) {
+; CHECK-LABEL: @ctpop_and_ne_to_ctpop_vec(
+; CHECK-NEXT:    [[TMP2:%.*]] = call <2 x i32> @llvm.ctpop.v2i32(<2 x i32> [[TMP0:%.*]])
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp ugt <2 x i32> [[TMP2]], <i32 1, i32 1>
+; CHECK-NEXT:    ret <2 x i1> [[TMP3]]
+;
+  %2 = call <2 x i32> @llvm.ctpop.v2i32(<2 x i32> %0)
+  %3 = icmp ne <2 x i32> %2, <i32 1, i32 1>
+  %4 = icmp ne <2 x i32> %0, <i32 0, i32 0>
+  %5 = and <2 x i1> %4, %3
+  ret <2 x i1> %5
+}
+
+; negative test - wrong constants
+
+define <2 x i1> @not_ctpop_and_ne_to_ctpop_vec(<2 x i32> %0) {
+; CHECK-LABEL: @not_ctpop_and_ne_to_ctpop_vec(
+; CHECK-NEXT:    [[TMP2:%.*]] = call <2 x i32> @llvm.ctpop.v2i32(<2 x i32> [[TMP0:%.*]])
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp ne <2 x i32> [[TMP2]], <i32 2, i32 2>
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp ne <2 x i32> [[TMP0]], zeroinitializer
+; CHECK-NEXT:    [[TMP5:%.*]] = and <2 x i1> [[TMP4]], [[TMP3]]
+; CHECK-NEXT:    ret <2 x i1> [[TMP5]]
+;
+  %2 = call <2 x i32> @llvm.ctpop.v2i32(<2 x i32> %0)
+  %3 = icmp ne <2 x i32> %2, <i32 2, i32 2>
+  %4 = icmp ne <2 x i32> %0, <i32 0, i32 0>
+  %5 = and <2 x i1> %4, %3
+  ret <2 x i1> %5
+}
diff --git a/llvm/test/Transforms/InstCombine/ispow2.ll b/llvm/test/Transforms/InstCombine/ispow2.ll
--- a/llvm/test/Transforms/InstCombine/ispow2.ll
+++ b/llvm/test/Transforms/InstCombine/ispow2.ll
@@ -535,15 +535,11 @@
   ret i1 %r
 }
 
-; Negative test - wrong predicate (but this could reduce). 
- -define i1 @isnot_pow2_ctpop_wrong_pred1(i32 %x) { -; CHECK-LABEL: @isnot_pow2_ctpop_wrong_pred1( +define i1 @is_pow2_or_zero(i32 %x) { +; CHECK-LABEL: @is_pow2_or_zero( ; CHECK-NEXT: [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]] -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[T0]], 1 -; CHECK-NEXT: [[ISZERO:%.*]] = icmp eq i32 [[X]], 0 -; CHECK-NEXT: [[R:%.*]] = or i1 [[ISZERO]], [[CMP]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i32 [[T0]], 2 +; CHECK-NEXT: ret i1 [[TMP1]] ; %t0 = tail call i32 @llvm.ctpop.i32(i32 %x) %cmp = icmp eq i32 %t0, 1 @@ -552,13 +548,11 @@ ret i1 %r } -define i1 @isnot_pow2_ctpop_wrong_pred1_logical(i32 %x) { -; CHECK-LABEL: @isnot_pow2_ctpop_wrong_pred1_logical( +define i1 @is_pow2_or_zero_logical(i32 %x) { +; CHECK-LABEL: @is_pow2_or_zero_logical( ; CHECK-NEXT: [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]] -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[T0]], 1 -; CHECK-NEXT: [[ISZERO:%.*]] = icmp eq i32 [[X]], 0 -; CHECK-NEXT: [[R:%.*]] = or i1 [[ISZERO]], [[CMP]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i32 [[T0]], 2 +; CHECK-NEXT: ret i1 [[TMP1]] ; %t0 = tail call i32 @llvm.ctpop.i32(i32 %x) %cmp = icmp eq i32 %t0, 1