diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -342,6 +342,11 @@
   /// Compute known bits from zero-extended multiply-hi.
   static KnownBits mulhu(const KnownBits &LHS, const KnownBits &RHS);
 
+  /// Compute known bits for sdiv(LHS, RHS). If \p Exact is set, the division
+  /// is assumed to leave no remainder.
+  static KnownBits sdiv(const KnownBits &LHS, const KnownBits &RHS,
+                        bool Exact = false);
+
   /// Compute known bits for udiv(LHS, RHS).
   static KnownBits udiv(const KnownBits &LHS, const KnownBits &RHS);
 
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -536,6 +536,80 @@
   return mul(WideLHS, WideRHS).extractBits(BitWidth, BitWidth);
 }
 
+KnownBits KnownBits::sdiv(const KnownBits &LHS, const KnownBits &RHS,
+                          bool Exact) {
+  assert(!LHS.hasConflict() && !RHS.hasConflict() && "Bad inputs");
+
+  // Equivalent of `udiv` when both operands are known non-negative; none of
+  // the sign cases below apply to that combination.
+  if (LHS.isNonNegative() && RHS.isNonNegative())
+    return udiv(LHS, RHS);
+
+  unsigned BitWidth = LHS.getBitWidth();
+  KnownBits Known(BitWidth);
+
+  // Operands of the single division used to bound the quotient. They are only
+  // assigned (and only read) when ResultSign is set.
+  APInt Num, Denom;
+  // Known sign of the quotient:
+  //   true    -> non-negative
+  //   false   -> negative
+  //   nullopt -> unknown
+  std::optional<bool> ResultSign;
+  if (LHS.isNegative() && RHS.isNegative()) {
+    // Result is non-negative.
+    Denom = RHS.getSignedMaxValue();
+    Num = LHS.getSignedMinValue();
+    ResultSign = true;
+  } else if (LHS.isNegative() && RHS.isStrictlyPositive()) {
+    // Result is negative if Exact OR -LHS u>= RHS.
+    if (Exact || (-LHS.getSignedMaxValue()).uge(RHS.getSignedMaxValue())) {
+      Denom = RHS.getSignedMinValue();
+      Num = LHS.getSignedMinValue();
+      ResultSign = false;
+    }
+  } else if (LHS.isStrictlyPositive() && RHS.isNegative()) {
+    // Result is negative if Exact OR LHS u>= -RHS.
+    if (Exact || LHS.getSignedMinValue().uge(-RHS.getSignedMinValue())) {
+      Denom = RHS.getSignedMaxValue();
+      Num = LHS.getSignedMaxValue();
+      ResultSign = false;
+    }
+  }
+
+  if (ResultSign) {
+    // Num / Denom pairs the largest-magnitude numerator with the
+    // smallest-magnitude denominator, so its magnitude bounds every quotient.
+    APInt Res = Num.sdiv(Denom);
+    if (*ResultSign) {
+      unsigned LeadZ = Res.countLeadingZeros();
+      Known.Zero.setHighBits(LeadZ);
+      Known.makeNonNegative();
+    } else {
+      unsigned LeadO = Res.countLeadingOnes();
+      Known.One.setHighBits(LeadO);
+      Known.makeNegative();
+    }
+  }
+
+  if (Exact) {
+    if (LHS.One[0] && RHS.One[0]) {
+      // Odd / Odd -> Odd
+      Known.Zero.clearBit(0);
+      Known.One.setBit(0);
+    } else if (LHS.Zero[0] && RHS.One[0]) {
+      // Even / Odd -> Even
+      Known.One.clearBit(0);
+      Known.Zero.setBit(0);
+    }
+    // Odd / Even -> impossible
+    // Even / Even -> unknown
+  }
+
+  assert(!Known.hasConflict() && "Bad Output");
+  return Known;
+}
+
 KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS) {
   unsigned BitWidth = LHS.getBitWidth();
   assert(!LHS.hasConflict() && !RHS.hasConflict());
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -249,6 +249,27 @@
         return N1.udiv(N2);
       },
       checkCorrectnessOnlyBinary);
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::sdiv(Known1, Known2);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        if (N2.isZero() || (N1.isMinSignedValue() && N2.isAllOnes()))
+          return std::nullopt;
+        return N1.sdiv(N2);
+      },
+      checkCorrectnessOnlyBinary);
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::sdiv(Known1, Known2, /*Exact*/ true);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        if (N2.isZero() || (N1.isMinSignedValue() && N2.isAllOnes()) ||
+            !N1.srem(N2).isZero())
+          return std::nullopt;
+        return N1.sdiv(N2);
+      },
+      checkCorrectnessOnlyBinary);
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::urem(Known1, Known2);