diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -326,7 +326,9 @@
                         bool Exact = false);
 
-  /// Compute known bits for udiv(LHS, RHS).
-  static KnownBits udiv(const KnownBits &LHS, const KnownBits &RHS);
+  /// Compute known bits for udiv(LHS, RHS). If \p Exact is true, the division
+  /// is assumed to have no remainder.
+  static KnownBits udiv(const KnownBits &LHS, const KnownBits &RHS,
+                        bool Exact = false);
 
   /// Compute known bits for urem(LHS, RHS).
   static KnownBits urem(const KnownBits &LHS, const KnownBits &RHS);
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -532,7 +532,7 @@
                           bool Exact) {
   // Equivilent of `udiv`. We must have caught this before it was folded.
   if (LHS.isNonNegative() && RHS.isNonNegative())
-    return udiv(LHS, RHS);
+    return udiv(LHS, RHS, Exact);
 
   unsigned BitWidth = LHS.getBitWidth();
   assert(!LHS.hasConflict() && !RHS.hasConflict() && "Bad inputs");
@@ -596,21 +596,41 @@
   return Known;
 }
 
-KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS) {
+KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS,
+                          bool Exact) {
   unsigned BitWidth = LHS.getBitWidth();
   assert(!LHS.hasConflict() && !RHS.hasConflict());
   KnownBits Known(BitWidth);
 
-  // For the purposes of computing leading zeros we can conservatively
-  // treat a udiv as a logical right shift by the power of 2 known to
-  // be less than the denominator.
-  unsigned LeadZ = LHS.countMinLeadingZeros();
-  unsigned RHSMaxLeadingZeros = RHS.countMaxLeadingZeros();
+  // We can figure out the minimum number of upper zero bits by doing
+  // MaxNumerator / MinDenominator. If the Numerator gets smaller or Denominator
+  // gets larger, the number of upper zero bits increases.
+  // If the denominator may be zero, MaxNumerator itself is used as the bound,
+  // since no quotient by a non-zero denominator can exceed the numerator.
+  APInt MinDenum = RHS.getMinValue();
+  APInt MaxNum = LHS.getMaxValue();
+  APInt MaxRes = MinDenum.isZero() ? MaxNum : MaxNum.udiv(MinDenum);
 
-  if (RHSMaxLeadingZeros != BitWidth)
-    LeadZ = std::min(BitWidth, LeadZ + BitWidth - RHSMaxLeadingZeros - 1);
+  unsigned LeadZ = MaxRes.countLeadingZeros();
 
   Known.Zero.setHighBits(LeadZ);
+  if (Exact) {
+    // Odd / Odd -> Odd
+    if (LHS.One[0] && RHS.One[0])
+      Known.One.setBit(0);
+    // Even / Odd -> Even
+    else if (LHS.Zero[0] && RHS.One[0])
+      Known.Zero.setBit(0);
+    // Odd / Even -> impossible
+    // Even / Even -> unknown
+
+    // If no pair of operand values divides exactly (e.g. 1 / 3), the parity
+    // bit can contradict the leading-zero bound. Such a result is poison, so
+    // any answer is valid; return a conflict-free one.
+    if (Known.hasConflict())
+      Known.setAllZero();
+  }
+
   return Known;
 }
 
diff --git a/llvm/test/Analysis/ValueTracking/knownbits-div.ll b/llvm/test/Analysis/ValueTracking/knownbits-div.ll
--- a/llvm/test/Analysis/ValueTracking/knownbits-div.ll
+++ b/llvm/test/Analysis/ValueTracking/knownbits-div.ll
@@ -185,12 +185,7 @@
 
 define i1 @udiv_high_bits(i8 %x, i8 %y) {
 ; CHECK-LABEL: @udiv_high_bits(
-; CHECK-NEXT:    [[NUM:%.*]] = and i8 [[X:%.*]], -127
-; CHECK-NEXT:    [[DENUM:%.*]] = or i8 [[Y:%.*]], 31
-; CHECK-NEXT:    [[DIV:%.*]] = udiv i8 [[NUM]], [[DENUM]]
-; CHECK-NEXT:    [[AND:%.*]] = and i8 [[DIV]], 8
-; CHECK-NEXT:    [[R:%.*]] = icmp eq i8 [[AND]], 8
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    ret i1 false
 ;
   %num = and i8 %x, 129
   %denum = or i8 %y, 31
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -249,6 +249,16 @@
         return N1.udiv(N2);
       },
       checkCorrectnessOnlyBinary);
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::udiv(Known1, Known2, /*Exact*/ true);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        if (N2.isZero() || !N1.urem(N2).isZero())
+          return std::nullopt;
+        return N1.udiv(N2);
+      },
+      checkCorrectnessOnlyBinary);
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::sdiv(Known1, Known2);