diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -538,7 +538,7 @@
 
 KnownBits KnownBits::sdiv(const KnownBits &LHS, const KnownBits &RHS,
                           bool Exact) {
-  // Equivilent of `udiv`. We must have caught this before it was folded.
+  // Equivalent of `udiv`. We must have caught this before it was folded.
   if (LHS.isNonNegative() && RHS.isNonNegative())
     return udiv(LHS, RHS, Exact);
 
@@ -546,42 +546,39 @@
   assert(!LHS.hasConflict() && !RHS.hasConflict() && "Bad inputs");
   KnownBits Known(BitWidth);
 
-  APInt Num, Denum;
-  // Positive -> true
-  // Negative -> false
-  // Unknown -> nullopt
-  std::optional<bool> ResultSign;
+  std::optional<APInt> Res;
   if (LHS.isNegative() && RHS.isNegative()) {
-    Denum = RHS.getSignedMaxValue();
-    Num = LHS.getSignedMinValue();
-    ResultSign = true;
     // Result non-negative.
-  } else if (LHS.isNegative() && RHS.isStrictlyPositive()) {
-    // Result is non-negative if Exact OR -LHS u>= RHS.
+    APInt Denom = RHS.getSignedMaxValue();
+    APInt Num = LHS.getSignedMinValue();
+    // INT_MIN/-1 would be a poison result (impossible). Estimate the division
+    // as signed max (we will only set sign bit in the result).
+    Res = (Num.isMinSignedValue() && Denom.isAllOnes())
+              ? APInt::getSignedMaxValue(BitWidth)
+              : Num.sdiv(Denom);
+  } else if (LHS.isNegative() && RHS.isNonNegative()) {
+    // Result is negative if Exact OR -LHS u>= RHS.
     if (Exact || (-LHS.getSignedMaxValue()).uge(RHS.getSignedMaxValue())) {
-      Denum = RHS.getSignedMinValue();
-      Num = LHS.getSignedMinValue();
-      ResultSign = false;
+      APInt Denom = RHS.getSignedMinValue();
+      APInt Num = LHS.getSignedMinValue();
+      Res = Denom.isZero() ? Num : Num.sdiv(Denom);
     }
   } else if (LHS.isStrictlyPositive() && RHS.isNegative()) {
-    // Result is non-negative if Exact OR LHS u>= -RHS.
+    // Result is negative if Exact OR LHS u>= -RHS.
     if (Exact || LHS.getSignedMinValue().uge(-RHS.getSignedMinValue())) {
-      Denum = RHS.getSignedMaxValue();
-      Num = LHS.getSignedMaxValue();
-      ResultSign = false;
+      APInt Denom = RHS.getSignedMaxValue();
+      APInt Num = LHS.getSignedMaxValue();
+      Res = Num.sdiv(Denom);
     }
   }
 
-  if (ResultSign) {
-    APInt Res = Num.sdiv(Denum);
-    if (*ResultSign) {
-      unsigned LeadZ = Res.countLeadingZeros();
+  if (Res) {
+    if (Res->isNonNegative()) {
+      unsigned LeadZ = Res->countLeadingZeros();
       Known.Zero.setHighBits(LeadZ);
-      Known.makeNonNegative();
     } else {
-      unsigned LeadO = Res.countLeadingOnes();
+      unsigned LeadO = Res->countLeadingOnes();
       Known.One.setHighBits(LeadO);
-      Known.makeNegative();
     }
   }
 
@@ -613,24 +610,29 @@
   // We can figure out the minimum number of upper zero bits by doing
   // MaxNumerator / MinDenominator. If the Numerator gets smaller or Denominator
   // gets larger, the number of upper zero bits increases.
-  APInt MinDenum = RHS.getMinValue();
+  APInt MinDenom = RHS.getMinValue();
   APInt MaxNum = LHS.getMaxValue();
-  APInt MaxRes = MinDenum.isZero() ? MaxNum : MaxNum.udiv(MinDenum);
+  APInt MaxRes = MinDenom.isZero() ? MaxNum : MaxNum.udiv(MinDenom);
 
   unsigned LeadZ = MaxRes.countLeadingZeros();
 
   Known.Zero.setHighBits(LeadZ);
   if (Exact) {
     // Odd / Odd -> Odd
-    if (LHS.One[0] && RHS.One[0])
+    if (LHS.One[0] && RHS.One[0]) {
+      Known.Zero.clearBit(0);
       Known.One.setBit(0);
+    }
     // Even / Odd -> Even
-    else if (LHS.Zero[0] && RHS.One[0])
+    else if (LHS.Zero[0] && RHS.One[0]) {
+      Known.One.clearBit(0);
       Known.Zero.setBit(0);
+    }
     // Odd / Even -> impossible
     // Even / Even -> unknown
   }
 
+  assert(!Known.hasConflict() && "Bad Output");
   return Known;
 }
 