Index: lib/Transforms/InstCombine/InstCombine.h
===================================================================
--- lib/Transforms/InstCombine/InstCombine.h
+++ lib/Transforms/InstCombine/InstCombine.h
@@ -193,6 +193,8 @@
                               ConstantInt *DivRHS);
   Instruction *FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A,
                                  ConstantInt *CI1, ConstantInt *CI2);
+  Instruction *FoldICmpCstShlCst(ICmpInst &I, Value *Op, Value *A,
+                                 ConstantInt *CI1, ConstantInt *CI2);
   Instruction *FoldICmpAddOpCst(Instruction &ICI, Value *X, ConstantInt *CI,
                                 ICmpInst::Predicate Pred);
   Instruction *FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
Index: lib/Transforms/InstCombine/InstCombineCompares.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1118,6 +1118,72 @@
   return getConstant(false);
 }
 
+/// FoldICmpCstShlCst - Handle "(icmp eq/ne (shl const2, A), const1)" ->
+/// (icmp eq/ne A, TrailingZeros(const1) - TrailingZeros(const2)).
+///
+/// I is the equality compare being folded, A is the shift amount, CI1 is
+/// const1 and CI2 is const2. Op is not used by this fold; the signature
+/// mirrors FoldICmpCstShrCst. The result is either a newly created icmp on
+/// A or the value returned by ReplaceInstUsesWith for an i1 constant.
+Instruction *InstCombiner::FoldICmpCstShlCst(ICmpInst &I, Value *Op, Value *A,
+                                             ConstantInt *CI1,
+                                             ConstantInt *CI2) {
+  assert(I.isEquality() && "Cannot fold icmp gt/lt");
+
+  // Both helpers are written for the 'eq' form and flip their answer when I
+  // is an 'ne' compare.
+  auto getConstant = [&I, this](bool IsTrue) {
+    if (I.getPredicate() == I.ICMP_NE)
+      IsTrue = !IsTrue;
+    return ReplaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue));
+  };
+
+  auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) {
+    if (I.getPredicate() == I.ICMP_NE)
+      Pred = CmpInst::getInversePredicate(Pred);
+    return new ICmpInst(Pred, LHS, RHS);
+  };
+
+  // Bind by reference: the fold only reads the constants, so copying the
+  // APInts is unnecessary.
+  const APInt &AP1 = CI1->getValue();
+  const APInt &AP2 = CI2->getValue();
+
+  if (!AP1) {
+    // Both constants are 0.
+    if (!AP2)
+      return getConstant(true);
+
+    // The shl is zero only once the lowest set bit of const2 is shifted out,
+    // i.e. for A >= BitWidth - TrailingZeros(const2). Without trailing zeros
+    // that needs A >= BitWidth, which is not a valid shift amount.
+    unsigned AP2TrailingZeros = AP2.countTrailingZeros();
+    if (AP2TrailingZeros == 0)
+      return getConstant(false);
+
+    return getICmp(I.ICMP_UGE, A,
+                   ConstantInt::get(A->getType(),
+                                    AP2.getBitWidth() - AP2TrailingZeros));
+  }
+
+  // const1 is non-zero here, and shifting 0 by any value gives 0.
+  if (!AP2)
+    return getConstant(false);
+
+  // A non-zero constant only survives the shift unchanged when A is 0.
+  if (AP1 == AP2)
+    return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType()));
+
+  // Get the distance between the lowest bits that are set.
+  int Shift = AP1.countTrailingZeros() - AP2.countTrailingZeros();
+
+  if (Shift > 0 && AP2.shl(Shift) == AP1)
+    return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift));
+
+  // Shifting const2 will never be equal to const1.
+  return getConstant(false);
+}
+
 /// visitICmpInstWithInstAndIntCst - Handle "icmp (instr, intcst)".
 ///
 Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
@@ -2693,13 +2759,17 @@
                           Builder->getInt(CI->getValue()-1));
     }
 
-    // (icmp eq/ne (ashr/lshr const2, A), const1)
     if (I.isEquality()) {
       ConstantInt *CI2;
       if (match(Op0, m_AShr(m_ConstantInt(CI2), m_Value(A))) ||
           match(Op0, m_LShr(m_ConstantInt(CI2), m_Value(A)))) {
+        // (icmp eq/ne (ashr/lshr const2, A), const1)
         return FoldICmpCstShrCst(I, Op0, A, CI, CI2);
       }
+      if (match(Op0, m_Shl(m_ConstantInt(CI2), m_Value(A)))) {
+        // (icmp eq/ne (shl const2, A), const1)
+        return FoldICmpCstShlCst(I, Op0, A, CI, CI2);
+      }
     }
 
     // If this comparison is a normal comparison, it demands all
Index: test/Transforms/InstCombine/icmp.ll
===================================================================
--- test/Transforms/InstCombine/icmp.ll
+++ test/Transforms/InstCombine/icmp.ll
@@ -1418,3 +1418,96 @@
   %ret = icmp ne i32 %and, 0
   ret i1 %ret
 }
+
+; CHECK-LABEL: @shl_both_zero
+; CHECK-NEXT: ret i1 true
+define i1 @shl_both_zero(i32 %a) {
+  %shl = shl i32 0, %a
+  %cmp = icmp eq i32 %shl, 0
+  ret i1 %cmp
+}
+
+; CHECK-LABEL: @shl_ap1_zero_ap2_non_zero_1
+; CHECK-NEXT: ret i1 false
+define i1 @shl_ap1_zero_ap2_non_zero_1(i32 %a) {
+  %shl = shl i32 1, %a
+  %cmp = icmp eq i32 %shl, 0
+  ret i1 %cmp
+}
+
+; CHECK-LABEL: @shl_ap1_zero_ap2_non_zero_2
+; CHECK-NEXT: %cmp = icmp ugt i32 %a, 29
+; CHECK-NEXT: ret i1 %cmp
+define i1 @shl_ap1_zero_ap2_non_zero_2(i32 %a) {
+  %shl = shl i32 4, %a
+  %cmp = icmp eq i32 %shl, 0
+  ret i1 %cmp
+}
+
+; CHECK-LABEL: @shl_ap1_zero_ap2_non_zero_3
+; CHECK-NEXT: ret i1 false
+define i1 @shl_ap1_zero_ap2_non_zero_3(i32 %a) {
+  %shl = shl i32 -1, %a
+  %cmp = icmp eq i32 %shl, 0
+  ret i1 %cmp
+}
+
+; CHECK-LABEL: @shl_ap1_zero_ap2_non_zero_4
+; CHECK-NEXT: %cmp = icmp ugt i32 %a, 30
+; CHECK-NEXT: ret i1 %cmp
+define i1 @shl_ap1_zero_ap2_non_zero_4(i32 %a) {
+  %shl = shl i32 -2, %a
+  %cmp = icmp eq i32 %shl, 0
+  ret i1 %cmp
+}
+
+; CHECK-LABEL: @shl_ap1_non_zero_ap2_zero
+; CHECK-NEXT: ret i1 false
+define i1 @shl_ap1_non_zero_ap2_zero(i32 %a) {
+  %shl = shl i32 0, %a
+  %cmp = icmp eq i32 %shl, 5
+  ret i1 %cmp
+}
+
+; CHECK-LABEL: @shl_ap1_non_zero_ap2_non_zero_both_positive
+; CHECK-NEXT: %cmp = icmp eq i32 %a, 0
+; CHECK-NEXT: ret i1 %cmp
+define i1 @shl_ap1_non_zero_ap2_non_zero_both_positive(i32 %a) {
+  %shl = shl i32 50, %a
+  %cmp = icmp eq i32 %shl, 50
+  ret i1 %cmp
+}
+
+; CHECK-LABEL: @shl_ap1_non_zero_ap2_non_zero_both_negative
+; CHECK-NEXT: %cmp = icmp eq i32 %a, 0
+; CHECK-NEXT: ret i1 %cmp
+define i1 @shl_ap1_non_zero_ap2_non_zero_both_negative(i32 %a) {
+  %shl = shl i32 -50, %a
+  %cmp = icmp eq i32 %shl, -50
+  ret i1 %cmp
+}
+
+; CHECK-LABEL: @shl_ap1_non_zero_ap2_non_zero_ap1_1
+; CHECK-NEXT: ret i1 false
+define i1 @shl_ap1_non_zero_ap2_non_zero_ap1_1(i32 %a) {
+  %shl = shl i32 50, %a
+  %cmp = icmp eq i32 %shl, 25
+  ret i1 %cmp
+}
+
+; CHECK-LABEL: @shl_ap1_non_zero_ap2_non_zero_ap1_2
+; CHECK-NEXT: %cmp = icmp eq i32 %a, 1
+; CHECK-NEXT: ret i1 %cmp
+define i1 @shl_ap1_non_zero_ap2_non_zero_ap1_2(i32 %a) {
+  %shl = shl i32 25, %a
+  %cmp = icmp eq i32 %shl, 50
+  ret i1 %cmp
+}
+
+; CHECK-LABEL: @shl_ap1_non_zero_ap2_non_zero_ap1_3
+; CHECK-NEXT: ret i1 false
+define i1 @shl_ap1_non_zero_ap2_non_zero_ap1_3(i32 %a) {
+  %shl = shl i32 26, %a
+  %cmp = icmp eq i32 %shl, 50
+  ret i1 %cmp
+}