Rework addWithOverflow/subWithOverflow and the div-by-constant compare range
computation in InstCombineCompares.cpp to operate on APInt values instead of
Constant expressions. The element-wise Constant helpers are removed, and a
ConstantInt is materialised only where an icmp operand is built.

NOTE(review): this text was recovered from a copy whose line breaks and
indentation had been collapsed; intra-line spacing is kept as recovered, so
apply with whitespace-tolerant matching (e.g. `patch -l`) and confirm.

Index: lib/Transforms/InstCombine/InstCombineCompares.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -37,77 +37,30 @@
 STATISTIC(NumSel, "Number of select opts");
 
 
-static ConstantInt *extractElement(Constant *V, Constant *Idx) {
-  return cast<ConstantInt>(ConstantExpr::getExtractElement(V, Idx));
-}
-
-static bool hasAddOverflow(ConstantInt *Result,
-                           ConstantInt *In1, ConstantInt *In2,
-                           bool IsSigned) {
-  if (!IsSigned)
-    return Result->getValue().ult(In1->getValue());
-
-  if (In2->isNegative())
-    return Result->getValue().sgt(In1->getValue());
-  return Result->getValue().slt(In1->getValue());
-}
-
 /// Compute Result = In1+In2, returning true if the result overflowed for this
 /// type.
-static bool addWithOverflow(Constant *&Result, Constant *In1,
-                            Constant *In2, bool IsSigned = false) {
-  Result = ConstantExpr::getAdd(In1, In2);
-
-  if (VectorType *VTy = dyn_cast<VectorType>(In1->getType())) {
-    for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) {
-      Constant *Idx = ConstantInt::get(Type::getInt32Ty(In1->getContext()), i);
-      if (hasAddOverflow(extractElement(Result, Idx),
-                         extractElement(In1, Idx),
-                         extractElement(In2, Idx),
-                         IsSigned))
-        return true;
-    }
-    return false;
-  }
-
-  return hasAddOverflow(cast<ConstantInt>(Result),
-                        cast<ConstantInt>(In1), cast<ConstantInt>(In2),
-                        IsSigned);
-}
-
-static bool hasSubOverflow(ConstantInt *Result,
-                           ConstantInt *In1, ConstantInt *In2,
-                           bool IsSigned) {
-  if (!IsSigned)
-    return Result->getValue().ugt(In1->getValue());
-
-  if (In2->isNegative())
-    return Result->getValue().slt(In1->getValue());
+static bool addWithOverflow(APInt &Result, const APInt &In1,
+                            const APInt &In2, bool IsSigned = false) {
+  bool Overflow;
+  if (IsSigned)
+    Result = In1.sadd_ov(In2, Overflow);
+  else
+    Result = In1.uadd_ov(In2, Overflow);
 
-  return Result->getValue().sgt(In1->getValue());
+  return Overflow;
 }
 
 /// Compute Result = In1-In2, returning true if the result overflowed for this
 /// type.
-static bool subWithOverflow(Constant *&Result, Constant *In1,
-                            Constant *In2, bool IsSigned = false) {
-  Result = ConstantExpr::getSub(In1, In2);
-
-  if (VectorType *VTy = dyn_cast<VectorType>(In1->getType())) {
-    for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) {
-      Constant *Idx = ConstantInt::get(Type::getInt32Ty(In1->getContext()), i);
-      if (hasSubOverflow(extractElement(Result, Idx),
-                         extractElement(In1, Idx),
-                         extractElement(In2, Idx),
-                         IsSigned))
-        return true;
-    }
-    return false;
-  }
+static bool subWithOverflow(APInt &Result, const APInt &In1,
+                            const APInt &In2, bool IsSigned = false) {
+  bool Overflow;
+  if (IsSigned)
+    Result = In1.ssub_ov(In2, Overflow);
+  else
+    Result = In1.usub_ov(In2, Overflow);
 
-  return hasSubOverflow(cast<ConstantInt>(Result),
-                        cast<ConstantInt>(In1), cast<ConstantInt>(In2),
-                        IsSigned);
+  return Overflow;
 }
 
 /// Given an icmp instruction, return true if any use of this comparison is a
@@ -2186,28 +2139,22 @@
       (DivIsSigned && C2->isAllOnesValue()))
     return nullptr;
 
-  // TODO: We could do all of the computations below using APInt.
-  Constant *CmpRHS = cast<Constant>(Cmp.getOperand(1));
-  Constant *DivRHS = cast<Constant>(Div->getOperand(1));
-
-  // Compute Prod = CmpRHS * DivRHS. We are essentially solving an equation of
-  // form X / C2 = C. We solve for X by multiplying C2 (DivRHS) and C (CmpRHS).
+  // Compute Prod = C * C2. We are essentially solving an equation of
+  // form X / C2 = C. We solve for X by multiplying C2 and C.
   // By solving for X, we can turn this into a range check instead of computing
   // a divide.
-  Constant *Prod = ConstantExpr::getMul(CmpRHS, DivRHS);
+  APInt Prod = *C * *C2;
 
   // Determine if the product overflows by seeing if the product is not equal to
   // the divide. Make sure we do the same kind of divide as in the LHS
   // instruction that we're folding.
-  bool ProdOV = (DivIsSigned ? ConstantExpr::getSDiv(Prod, DivRHS)
-                             : ConstantExpr::getUDiv(Prod, DivRHS)) != CmpRHS;
+  bool ProdOV = (DivIsSigned ? Prod.sdiv(*C2) : Prod.udiv(*C2)) != *C;
 
   ICmpInst::Predicate Pred = Cmp.getPredicate();
 
   // If the division is known to be exact, then there is no remainder from the
   // divide, so the covered range size is unit, otherwise it is the divisor.
-  Constant *RangeSize =
-      Div->isExact() ? ConstantInt::get(Div->getType(), 1) : DivRHS;
+  APInt RangeSize = Div->isExact() ? APInt(C2->getBitWidth(), 1) : *C2;
 
   // Figure out the interval that is being checked. For example, a comparison
   // like "X /u 5 == 0" is really checking that X is in the interval [0, 5).
@@ -2217,7 +2164,7 @@
   // overflow variable is set to 0 if it's corresponding bound variable is valid
   // -1 if overflowed off the bottom end, or +1 if overflowed off the top end.
   int LoOverflow = 0, HiOverflow = 0;
-  Constant *LoBound = nullptr, *HiBound = nullptr;
+  APInt LoBound, HiBound;
 
   if (!DivIsSigned) { // udiv
     // e.g. X/5 op 3 --> [15, 20)
@@ -2231,7 +2178,7 @@
   } else if (C2->isStrictlyPositive()) { // Divisor is > 0.
     if (C->isNullValue()) { // (X / pos) op 0
       // Can't overflow. e.g. X/2 op 0 --> [-1, 2)
-      LoBound = ConstantExpr::getNeg(SubOne(RangeSize));
+      LoBound = -(RangeSize - 1);
       HiBound = RangeSize;
     } else if (C->isStrictlyPositive()) { // (X / pos) op pos
       LoBound = Prod; // e.g. X/5 op 3 --> [15, 20)
@@ -2240,27 +2187,27 @@
         HiOverflow = addWithOverflow(HiBound, Prod, RangeSize, true);
     } else { // (X / pos) op neg
       // e.g. X/5 op -3 --> [-15-4, -15+1) --> [-19, -14)
-      HiBound = AddOne(Prod);
+      HiBound = Prod + 1;
       LoOverflow = HiOverflow = ProdOV ? -1 : 0;
       if (!LoOverflow) {
-        Constant *DivNeg = ConstantExpr::getNeg(RangeSize);
+        APInt DivNeg = -RangeSize;
         LoOverflow = addWithOverflow(LoBound, HiBound, DivNeg, true) ? -1 : 0;
       }
     }
   } else if (C2->isNegative()) { // Divisor is < 0.
     if (Div->isExact())
-      RangeSize = ConstantExpr::getNeg(RangeSize);
+      RangeSize.negate();
     if (C->isNullValue()) { // (X / neg) op 0
       // e.g. X/-5 op 0 --> [-4, 5)
-      LoBound = AddOne(RangeSize);
-      HiBound = ConstantExpr::getNeg(RangeSize);
-      if (HiBound == DivRHS) { // -INTMIN = INTMIN
+      LoBound = RangeSize + 1;
+      HiBound = -RangeSize;
+      if (HiBound == *C2) { // -INTMIN = INTMIN
         HiOverflow = 1; // [INTMIN+1, overflow)
-        HiBound = nullptr; // e.g. X/INTMIN = 0 --> X > INTMIN
+        HiBound = APInt(); // e.g. X/INTMIN = 0 --> X > INTMIN
       }
     } else if (C->isStrictlyPositive()) { // (X / neg) op pos
       // e.g. X/-5 op 3 --> [-19, -14)
-      HiBound = AddOne(Prod);
+      HiBound = Prod + 1;
       HiOverflow = LoOverflow = ProdOV ? -1 : 0;
       if (!LoOverflow)
         LoOverflow = addWithOverflow(LoBound, HiBound, RangeSize, true) ? -1:0;
@@ -2283,25 +2230,27 @@
       return replaceInstUsesWith(Cmp, Builder.getFalse());
     if (HiOverflow)
       return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE :
-                          ICmpInst::ICMP_UGE, X, LoBound);
+                          ICmpInst::ICMP_UGE, X,
+                          ConstantInt::get(Div->getType(), LoBound));
     if (LoOverflow)
       return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT :
-                          ICmpInst::ICMP_ULT, X, HiBound);
+                          ICmpInst::ICMP_ULT, X,
+                          ConstantInt::get(Div->getType(), HiBound));
     return replaceInstUsesWith(
-        Cmp, insertRangeTest(X, LoBound->getUniqueInteger(),
-                             HiBound->getUniqueInteger(), DivIsSigned, true));
+        Cmp, insertRangeTest(X, LoBound, HiBound, DivIsSigned, true));
   case ICmpInst::ICMP_NE:
     if (LoOverflow && HiOverflow)
       return replaceInstUsesWith(Cmp, Builder.getTrue());
     if (HiOverflow)
       return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT :
-                          ICmpInst::ICMP_ULT, X, LoBound);
+                          ICmpInst::ICMP_ULT, X,
+                          ConstantInt::get(Div->getType(), LoBound));
     if (LoOverflow)
       return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE :
-                          ICmpInst::ICMP_UGE, X, HiBound);
+                          ICmpInst::ICMP_UGE, X,
+                          ConstantInt::get(Div->getType(), HiBound));
     return replaceInstUsesWith(Cmp,
-                               insertRangeTest(X, LoBound->getUniqueInteger(),
-                                               HiBound->getUniqueInteger(),
+                               insertRangeTest(X, LoBound, HiBound,
                                                DivIsSigned, false));
   case ICmpInst::ICMP_ULT:
   case ICmpInst::ICMP_SLT:
@@ -2309,7 +2258,7 @@
       return replaceInstUsesWith(Cmp, Builder.getTrue());
     if (LoOverflow == -1) // Low bound is less than input range.
       return replaceInstUsesWith(Cmp, Builder.getFalse());
-    return new ICmpInst(Pred, X, LoBound);
+    return new ICmpInst(Pred, X, ConstantInt::get(Div->getType(), LoBound));
   case ICmpInst::ICMP_UGT:
   case ICmpInst::ICMP_SGT:
     if (HiOverflow == +1) // High bound greater than input range.
@@ -2317,8 +2266,10 @@
     if (HiOverflow == -1) // High bound less than input range.
       return replaceInstUsesWith(Cmp, Builder.getTrue());
     if (Pred == ICmpInst::ICMP_UGT)
-      return new ICmpInst(ICmpInst::ICMP_UGE, X, HiBound);
-    return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound);
+      return new ICmpInst(ICmpInst::ICMP_UGE, X,
+                          ConstantInt::get(Div->getType(), HiBound));
+    return new ICmpInst(ICmpInst::ICMP_SGE, X,
+                        ConstantInt::get(Div->getType(), HiBound));
   }
 
   return nullptr;