diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -311,6 +311,18 @@
   static KnownBits computeForAddSub(bool Add, bool NSW, const KnownBits &LHS,
                                     KnownBits RHS);
 
+  /// Compute known bits for the signed saturating addition of LHS and RHS.
+  static KnownBits sadd_sat(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Compute known bits for the unsigned saturating addition of LHS and RHS.
+  static KnownBits uadd_sat(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Compute known bits for the signed saturating subtraction LHS - RHS.
+  static KnownBits ssub_sat(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Compute known bits for the unsigned saturating subtraction LHS - RHS.
+  static KnownBits usub_sat(const KnownBits &LHS, const KnownBits &RHS);
+
   /// Compute known bits resulting from multiplying LHS and RHS.
   static KnownBits mul(const KnownBits &LHS, const KnownBits &RHS,
                        bool NoUndefSelfMultiply = false);
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -430,6 +430,172 @@
   return KnownAbs;
 }
 
+/// Shared implementation of the saturating add/sub known-bits functions.
+/// \p Add selects addition over subtraction and \p Signed selects signed over
+/// unsigned saturation. The result starts from the known bits of the plain
+/// (non-NSW) add/sub and is then kept, replaced by the saturation constant, or
+/// reduced, depending on whether overflow is known absent, present or unknown.
+static KnownBits computeForSatAddSub(bool Add, bool Signed,
+                                     const KnownBits &LHS,
+                                     const KnownBits &RHS) {
+  assert(!LHS.hasConflict() && !RHS.hasConflict() && "Bad inputs");
+  // We don't set NSW even for sadd/ssub as we want to check if the result has
+  // signed overflow.
+  KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW*/ false, LHS, RHS);
+  unsigned BitWidth = Res.getBitWidth();
+  auto SignBitKnown = [&](const KnownBits &K) {
+    return K.Zero[BitWidth - 1] || K.One[BitWidth - 1];
+  };
+  std::optional<bool> Overflow;
+
+  if (Signed) {
+    // If we can actually detect overflow do so. Otherwise leave Overflow as
+    // nullopt (we assume it may have happened).
+    if (SignBitKnown(LHS) && SignBitKnown(RHS) && SignBitKnown(Res)) {
+      if (Add) {
+        // sadd.sat
+        Overflow = (LHS.isNonNegative() == RHS.isNonNegative() &&
+                    Res.isNonNegative() != LHS.isNonNegative());
+      } else {
+        // ssub.sat
+        Overflow = (LHS.isNonNegative() != RHS.isNonNegative() &&
+                    Res.isNonNegative() != LHS.isNonNegative());
+      }
+    }
+  } else if (Add) {
+    // uadd.sat
+    Overflow = KnownBits::ult(Res, RHS);
+    if (!Overflow)
+      Overflow = KnownBits::ult(Res, LHS);
+    if (!Overflow) {
+      bool Of;
+      (void)LHS.getMinValue().uadd_ov(RHS.getMinValue(), Of);
+      if (Of)
+        Overflow = true;
+      (void)LHS.getMaxValue().uadd_ov(RHS.getMaxValue(), Of);
+      if (!Of)
+        Overflow = false;
+    }
+  } else {
+    // usub.sat
+    Overflow = KnownBits::ugt(Res, LHS);
+    if (!Overflow) {
+      bool Of;
+      (void)LHS.getMaxValue().usub_ov(RHS.getMinValue(), Of);
+      if (Of)
+        Overflow = Of;
+      (void)LHS.getMinValue().usub_ov(RHS.getMaxValue(), Of);
+      if (!Of)
+        Overflow = Of;
+    }
+  }
+
+  if (Signed) {
+    if (Add) {
+      if (LHS.isNonNegative() && RHS.isNonNegative()) {
+        // Pos + Pos -> Pos
+        Res.One.clearSignBit();
+        Res.Zero.setSignBit();
+      }
+      if (LHS.isNegative() && RHS.isNegative()) {
+        // Neg + Neg -> Neg
+        Res.One.setSignBit();
+        Res.Zero.clearSignBit();
+      }
+    } else {
+      if (LHS.isNegative() && RHS.isNonNegative()) {
+        // Neg - Pos -> Neg
+        Res.One.setSignBit();
+        Res.Zero.clearSignBit();
+      } else if (LHS.isNonNegative() && RHS.isNegative()) {
+        // Pos - Neg -> Pos
+        Res.One.clearSignBit();
+        Res.Zero.setSignBit();
+      }
+    }
+  } else {
+    // Add: Leading ones of either operand are preserved.
+    // Sub: Leading zeros of LHS and leading ones of RHS are preserved
+    // as leading zeros in the result.
+    unsigned LeadingKnown;
+    if (Add)
+      LeadingKnown =
+          std::max(LHS.countMinLeadingOnes(), RHS.countMinLeadingOnes());
+    else
+      LeadingKnown =
+          std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingOnes());
+
+    // We select between the operation result and all-ones/zero
+    // respectively, so we can preserve known ones/zeros.
+    APInt Mask = APInt::getHighBitsSet(BitWidth, LeadingKnown);
+    if (Add) {
+      Res.One |= Mask;
+      Res.Zero &= ~Mask;
+    } else {
+      Res.Zero |= Mask;
+      Res.One &= ~Mask;
+    }
+  }
+
+  if (Overflow) {
+    // We know whether or not we overflowed.
+    if (!(*Overflow)) {
+      // No overflow.
+      return Res;
+    }
+
+    // We overflowed
+    APInt C;
+    if (Signed) {
+      // sadd.sat / ssub.sat
+      assert(SignBitKnown(LHS) &&
+             "We somehow know overflow without knowing input sign");
+      C = LHS.isNegative() ? APInt::getSignedMinValue(BitWidth)
+                           : APInt::getSignedMaxValue(BitWidth);
+    } else if (Add) {
+      // uadd.sat
+      C = APInt::getMaxValue(BitWidth);
+    } else {
+      // usub.sat
+      C = APInt::getMinValue(BitWidth);
+    }
+
+    Res.One = C;
+    Res.Zero = ~C;
+    return Res;
+  }
+
+  // We don't know if we overflowed.
+  if (Signed) {
+    // sadd.sat/ssub.sat
+    // We can keep our information about the sign bits.
+    Res.Zero.clearLowBits(BitWidth - 1);
+    Res.One.clearLowBits(BitWidth - 1);
+  } else if (Add) {
+    // uadd.sat
+    // We need to clear all the known zeros as we can only use the leading ones.
+    Res.Zero.clearAllBits();
+  } else {
+    // usub.sat
+    // We need to clear all the known ones as we can only use the leading zero.
+    Res.One.clearAllBits();
+  }
+  return Res;
+}
+
+KnownBits KnownBits::sadd_sat(const KnownBits &LHS, const KnownBits &RHS) {
+  return computeForSatAddSub(/*Add*/ true, /*Signed*/ true, LHS, RHS);
+}
+KnownBits KnownBits::ssub_sat(const KnownBits &LHS, const KnownBits &RHS) {
+  return computeForSatAddSub(/*Add*/ false, /*Signed*/ true, LHS, RHS);
+}
+KnownBits KnownBits::uadd_sat(const KnownBits &LHS, const KnownBits &RHS) {
+  return computeForSatAddSub(/*Add*/ true, /*Signed*/ false, LHS, RHS);
+}
+KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) {
+  return computeForSatAddSub(/*Add*/ false, /*Signed*/ false, LHS, RHS);
+}
+
 KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
                          bool NoUndefSelfMultiply) {
   unsigned BitWidth = LHS.getBitWidth();
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -122,6 +122,10 @@
       KnownBits KnownShl(KnownAnd);
       KnownBits KnownLShr(KnownAnd);
       KnownBits KnownAShr(KnownAnd);
+      KnownBits KnownUAddSat(KnownAnd);
+      KnownBits KnownUSubSat(KnownAnd);
+      KnownBits KnownSAddSat(KnownAnd);
+      KnownBits KnownSSubSat(KnownAnd);
 
       ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
         ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
@@ -167,6 +171,22 @@
           KnownMulHU.One &= Res;
           KnownMulHU.Zero &= ~Res;
 
+          Res = N1.uadd_sat(N2);
+          KnownUAddSat.One &= Res;
+          KnownUAddSat.Zero &= ~Res;
+
+          Res = N1.usub_sat(N2);
+          KnownUSubSat.One &= Res;
+          KnownUSubSat.Zero &= ~Res;
+
+          Res = N1.sadd_sat(N2);
+          KnownSAddSat.One &= Res;
+          KnownSAddSat.Zero &= ~Res;
+
+          Res = N1.ssub_sat(N2);
+          KnownSSubSat.One &= Res;
+          KnownSSubSat.Zero &= ~Res;
+
           if (!N2.isZero()) {
             Res = N1.udiv(N2);
             KnownUDiv.One &= Res;
@@ -288,6 +308,22 @@
       KnownBits ComputedAShr = KnownBits::ashr(Known1, Known2);
       EXPECT_TRUE(ComputedAShr.Zero.isSubsetOf(KnownAShr.Zero));
       EXPECT_TRUE(ComputedAShr.One.isSubsetOf(KnownAShr.One));
+
+      KnownBits ComputedUAddSat = KnownBits::uadd_sat(Known1, Known2);
+      EXPECT_TRUE(ComputedUAddSat.Zero.isSubsetOf(KnownUAddSat.Zero));
+      EXPECT_TRUE(ComputedUAddSat.One.isSubsetOf(KnownUAddSat.One));
+
+      KnownBits ComputedUSubSat = KnownBits::usub_sat(Known1, Known2);
+      EXPECT_TRUE(ComputedUSubSat.Zero.isSubsetOf(KnownUSubSat.Zero));
+      EXPECT_TRUE(ComputedUSubSat.One.isSubsetOf(KnownUSubSat.One));
+
+      KnownBits ComputedSAddSat = KnownBits::sadd_sat(Known1, Known2);
+      EXPECT_TRUE(ComputedSAddSat.Zero.isSubsetOf(KnownSAddSat.Zero));
+      EXPECT_TRUE(ComputedSAddSat.One.isSubsetOf(KnownSAddSat.One));
+
+      KnownBits ComputedSSubSat = KnownBits::ssub_sat(Known1, Known2);
+      EXPECT_TRUE(ComputedSSubSat.Zero.isSubsetOf(KnownSSubSat.Zero));
+      EXPECT_TRUE(ComputedSSubSat.One.isSubsetOf(KnownSSubSat.One));
     });
   });
 