diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp --- a/llvm/lib/Support/KnownBits.cpp +++ b/llvm/lib/Support/KnownBits.cpp @@ -538,7 +538,7 @@ KnownBits KnownBits::sdiv(const KnownBits &LHS, const KnownBits &RHS, bool Exact) { - // Equivilent of `udiv`. We must have caught this before it was folded. + // Equivalent of `udiv`. We must have caught this before it was folded. if (LHS.isNonNegative() && RHS.isNonNegative()) return udiv(LHS, RHS, Exact); @@ -546,41 +546,49 @@ assert(!LHS.hasConflict() && !RHS.hasConflict() && "Bad inputs"); KnownBits Known(BitWidth); - APInt Num, Denum; + APInt Num, Denom; // Positive -> true // Negative -> false // Unknown -> nullopt std::optional<bool> ResultSign; if (LHS.isNegative() && RHS.isNegative()) { - Denum = RHS.getSignedMaxValue(); + // Result non-negative. + Denom = RHS.getSignedMaxValue(); Num = LHS.getSignedMinValue(); ResultSign = true; - // Result non-negative. - } else if (LHS.isNegative() && RHS.isStrictlyPositive()) { - // Result is non-negative if Exact OR -LHS u>= RHS. + } else if (LHS.isNegative() && RHS.isNonNegative()) { + // Result is negative if Exact OR -LHS u>= RHS. if (Exact || (-LHS.getSignedMaxValue()).uge(RHS.getSignedMaxValue())) { - Denum = RHS.getSignedMinValue(); + Denom = RHS.getSignedMinValue(); Num = LHS.getSignedMinValue(); ResultSign = false; } } else if (LHS.isStrictlyPositive() && RHS.isNegative()) { - // Result is non-negative if Exact OR LHS u>= -RHS. + // Result is negative if Exact OR LHS u>= -RHS. if (Exact || LHS.getSignedMinValue().uge(-RHS.getSignedMinValue())) { - Denum = RHS.getSignedMaxValue(); + Denom = RHS.getSignedMaxValue(); Num = LHS.getSignedMaxValue(); ResultSign = false; } } if (ResultSign) { - APInt Res = Num.sdiv(Denum); + // Denom may be zero as we only check RHS non-negative / negative and RHS + // non-negative includes zero. 
+ std::optional<APInt> Res; + if (!Denom.isZero()) + Res = Num.sdiv(Denom); if (*ResultSign) { - unsigned LeadZ = Res.countLeadingZeros(); - Known.Zero.setHighBits(LeadZ); + if (Res) { + unsigned LeadZ = Res->countLeadingZeros(); + Known.Zero.setHighBits(LeadZ); + } Known.makeNonNegative(); } else { - unsigned LeadO = Res.countLeadingOnes(); - Known.One.setHighBits(LeadO); + if (Res) { + unsigned LeadO = Res->countLeadingOnes(); + Known.One.setHighBits(LeadO); + } Known.makeNegative(); } } @@ -613,24 +621,29 @@ // We can figure out the minimum number of upper zero bits by doing // MaxNumerator / MinDenominator. If the Numerator gets smaller or Denominator // gets larger, the number of upper zero bits increases. - APInt MinDenum = RHS.getMinValue(); + APInt MinDenom = RHS.getMinValue(); APInt MaxNum = LHS.getMaxValue(); - APInt MaxRes = MinDenum.isZero() ? MaxNum : MaxNum.udiv(MinDenum); + APInt MaxRes = MinDenom.isZero() ? MaxNum : MaxNum.udiv(MinDenom); unsigned LeadZ = MaxRes.countLeadingZeros(); Known.Zero.setHighBits(LeadZ); if (Exact) { // Odd / Odd -> Odd - if (LHS.One[0] && RHS.One[0]) + if (LHS.One[0] && RHS.One[0]) { + Known.Zero.clearBit(0); Known.One.setBit(0); + } // Even / Odd -> Even - else if (LHS.Zero[0] && RHS.One[0]) + else if (LHS.Zero[0] && RHS.One[0]) { + Known.One.clearBit(0); Known.Zero.setBit(0); + } // Odd / Even -> impossible // Even / Even -> unknown } + assert(!Known.hasConflict() && "Bad Output"); return Known; }