diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -321,6 +321,11 @@
   /// Compute known bits from zero-extended multiply-hi.
   static KnownBits mulhu(const KnownBits &LHS, const KnownBits &RHS);
 
+  /// Compute known bits for sdiv(LHS, RHS). If \p Exact is true, the division
+  /// is assumed to leave no remainder.
+  static KnownBits sdiv(const KnownBits &LHS, const KnownBits &RHS,
+                        bool Exact = false);
+
   /// Compute known bits for udiv(LHS, RHS).
   static KnownBits udiv(const KnownBits &LHS, const KnownBits &RHS);
 
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -528,6 +528,87 @@
   return mul(WideLHS, WideRHS).extractBits(BitWidth, BitWidth);
 }
 
+KnownBits KnownBits::sdiv(const KnownBits &LHS, const KnownBits &RHS,
+                          bool Exact) {
+  unsigned BitWidth = LHS.getBitWidth();
+  assert(!LHS.hasConflict() && !RHS.hasConflict() && "Bad inputs");
+  KnownBits Known(BitWidth);
+
+  // Operands whose quotient bounds every well-defined result; only
+  // meaningful when ResultNonNegative is set.
+  APInt Num, Denom;
+  // true    -> every quotient is non-negative and Num / Denom is the largest.
+  // false   -> every quotient is negative and Num / Denom is the smallest.
+  // nullopt -> no range information from the signs.
+  std::optional<bool> ResultNonNegative;
+  if (LHS.isNonNegative() && RHS.isNonNegative()) {
+    // Both operands are non-negative, so sdiv computes the same value as
+    // udiv. Don't return early: the exact low-bit refinement below still
+    // applies.
+    Known = udiv(LHS, RHS);
+  } else if (LHS.isNegative() && RHS.isNegative()) {
+    // Result is non-negative. The largest quotient pairs the most negative
+    // LHS with the RHS closest to zero.
+    Num = LHS.getSignedMinValue();
+    Denom = RHS.getSignedMaxValue();
+    ResultNonNegative = true;
+  } else if (LHS.isNegative() && RHS.isStrictlyPositive()) {
+    // Result is negative (never zero) if Exact OR |LHS| >= RHS for every
+    // operand value. The magnitudes are compared unsigned so that -INT_MIN,
+    // which wraps to INT_MIN, still reads as 2^(BitWidth-1). The smallest
+    // quotient pairs the most negative LHS with the smallest RHS.
+    if (Exact || (-LHS.getSignedMaxValue()).uge(RHS.getSignedMaxValue())) {
+      Num = LHS.getSignedMinValue();
+      Denom = RHS.getSignedMinValue();
+      ResultNonNegative = false;
+    }
+  } else if (LHS.isStrictlyPositive() && RHS.isNegative()) {
+    // Result is negative (never zero) if Exact OR LHS >= |RHS| for every
+    // operand value (again compared unsigned to cope with -INT_MIN). The
+    // smallest quotient pairs the largest LHS with the RHS closest to zero.
+    if (Exact || LHS.getSignedMinValue().uge(-RHS.getSignedMinValue())) {
+      Num = LHS.getSignedMaxValue();
+      Denom = RHS.getSignedMaxValue();
+      ResultNonNegative = false;
+    }
+  }
+
+  if (ResultNonNegative) {
+    // Denom is non-zero on every path that sets ResultNonNegative.
+    APInt Res = Num.sdiv(Denom);
+    if (*ResultNonNegative) {
+      // Every result lies in [0, Res] and so has at least Res's leading
+      // zeros. If Res wrapped (INT_MIN / -1, which is UB) it has none, and
+      // only the sign bit is claimed, which holds for all defined quotients.
+      Known.Zero.setHighBits(Res.countLeadingZeros());
+      Known.makeNonNegative();
+    } else {
+      // Every result lies in [Res, -1] and so has at least Res's leading ones.
+      Known.One.setHighBits(Res.countLeadingOnes());
+      Known.makeNegative();
+    }
+  }
+
+  // An exact division leaves no remainder: LHS == Quotient * RHS. With an
+  // odd RHS the quotient therefore has the same parity as LHS (Odd / Even is
+  // never exact, and Even / Even may give either parity). The bit is
+  // overwritten rather than merged; a clash with the reasoning above is only
+  // possible when no defined exact operand pair exists, i.e. the result is
+  // poison anyway.
+  if (Exact && RHS.One[0]) {
+    if (LHS.One[0]) {
+      Known.Zero.clearBit(0);
+      Known.One.setBit(0);
+    } else if (LHS.Zero[0]) {
+      Known.One.clearBit(0);
+      Known.Zero.setBit(0);
+    }
+  }
+
+  assert(!Known.hasConflict() && "Bad Output");
+  return Known;
+}
+
 KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS) {
   unsigned BitWidth = LHS.getBitWidth();
   assert(!LHS.hasConflict() && !RHS.hasConflict());
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -114,6 +114,8 @@
       KnownBits KnownMulHS(KnownAnd);
       KnownBits KnownMulHU(KnownAnd);
       KnownBits KnownUDiv(KnownAnd);
+      KnownBits KnownSDiv(KnownAnd);
+      KnownBits KnownSDivExact(KnownAnd);
       KnownBits KnownURem(KnownAnd);
       KnownBits KnownSRem(KnownAnd);
       KnownBits KnownShl(KnownAnd);
@@ -173,9 +175,21 @@
             KnownURem.One &= Res;
             KnownURem.Zero &= ~Res;
 
-            Res = N1.srem(N2);
-            KnownSRem.One &= Res;
-            KnownSRem.Zero &= ~Res;
+            // INT_MIN / -1 is UB.
+            if (!N1.isMinSignedValue() || !N2.isAllOnes()) {
+              Res = N1.srem(N2);
+              KnownSRem.One &= Res;
+              KnownSRem.Zero &= ~Res;
+
+              Res = N1.sdiv(N2);
+              KnownSDiv.One &= Res;
+              KnownSDiv.Zero &= ~Res;
+
+              if (N1.srem(N2).isZero()) {
+                KnownSDivExact.One &= Res;
+                KnownSDivExact.Zero &= ~Res;
+              }
+            }
           }
 
           if (N2.ult(1ULL << N1.getBitWidth())) {
@@ -237,6 +251,14 @@
       EXPECT_TRUE(ComputedUDiv.Zero.isSubsetOf(KnownUDiv.Zero));
       EXPECT_TRUE(ComputedUDiv.One.isSubsetOf(KnownUDiv.One));
 
+      KnownBits ComputedSDiv = KnownBits::sdiv(Known1, Known2, false);
+      EXPECT_TRUE(ComputedSDiv.Zero.isSubsetOf(KnownSDiv.Zero));
+      EXPECT_TRUE(ComputedSDiv.One.isSubsetOf(KnownSDiv.One));
+
+      KnownBits ComputedSDivExact = KnownBits::sdiv(Known1, Known2, true);
+      EXPECT_TRUE(ComputedSDivExact.Zero.isSubsetOf(KnownSDivExact.Zero));
+      EXPECT_TRUE(ComputedSDivExact.One.isSubsetOf(KnownSDivExact.One));
+
       KnownBits ComputedURem = KnownBits::urem(Known1, Known2);
       EXPECT_TRUE(ComputedURem.Zero.isSubsetOf(KnownURem.Zero));
       EXPECT_TRUE(ComputedURem.One.isSubsetOf(KnownURem.One));