diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -332,6 +332,18 @@
   static KnownBits computeForAddSub(bool Add, bool NSW, const KnownBits &LHS,
                                     KnownBits RHS);
 
+  /// Compute known bits resulting from llvm.sadd.sat(LHS, RHS)
+  static KnownBits sadd_sat(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Compute known bits resulting from llvm.uadd.sat(LHS, RHS)
+  static KnownBits uadd_sat(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Compute known bits resulting from llvm.ssub.sat(LHS, RHS)
+  static KnownBits ssub_sat(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Compute known bits resulting from llvm.usub.sat(LHS, RHS)
+  static KnownBits usub_sat(const KnownBits &LHS, const KnownBits &RHS);
+
   /// Compute known bits resulting from multiplying LHS and RHS.
   static KnownBits mul(const KnownBits &LHS, const KnownBits &RHS,
                        bool NoUndefSelfMultiply = false);
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -465,6 +465,179 @@
   return KnownAbs;
 }
 
+/// Shared implementation of KnownBits::{sadd,ssub,uadd,usub}_sat.
+///
+/// The wrapping add/sub is computed first. If the operation is then known to
+/// always (or never) overflow, the result is the saturation constant (or the
+/// wrapping result); otherwise only facts valid on both paths are kept.
+static KnownBits computeForSatAddSub(bool Add, bool Signed,
+                                     const KnownBits &LHS,
+                                     const KnownBits &RHS) {
+  assert(!LHS.hasConflict() && !RHS.hasConflict() && "Bad inputs");
+  // We don't set NSW even for sadd/ssub as we want to check if the result has
+  // signed overflow.
+  KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW*/ false, LHS, RHS);
+  unsigned BitWidth = Res.getBitWidth();
+  auto SignBitKnown = [&](const KnownBits &K) {
+    return K.Zero[BitWidth - 1] || K.One[BitWidth - 1];
+  };
+  // Tri-state: true = always overflows, false = never overflows,
+  // nullopt = may or may not overflow.
+  std::optional<bool> Overflow;
+
+  if (Signed) {
+    // If we can actually detect overflow do so. Otherwise leave Overflow as
+    // nullopt (we assume it may have happened).
+    if (SignBitKnown(LHS) && SignBitKnown(RHS) && SignBitKnown(Res)) {
+      if (Add) {
+        // sadd.sat
+        Overflow = (LHS.isNonNegative() == RHS.isNonNegative() &&
+                    Res.isNonNegative() != LHS.isNonNegative());
+      } else {
+        // ssub.sat
+        Overflow = (LHS.isNonNegative() != RHS.isNonNegative() &&
+                    Res.isNonNegative() != LHS.isNonNegative());
+      }
+    }
+  } else if (Add) {
+    // uadd.sat
+    // getMinValue()/getMaxValue() are attainable and the sum is monotonic in
+    // each operand, so: if LHS.max + RHS.max does not wrap, nothing does; if
+    // LHS.min + RHS.min wraps, everything does.
+    bool Of;
+    (void)LHS.getMaxValue().uadd_ov(RHS.getMaxValue(), Of);
+    if (!Of) {
+      Overflow = false;
+    } else {
+      (void)LHS.getMinValue().uadd_ov(RHS.getMinValue(), Of);
+      if (Of)
+        Overflow = true;
+    }
+  } else {
+    // usub.sat
+    // Likewise: if LHS.min - RHS.max does not wrap, nothing does; if
+    // LHS.max - RHS.min wraps, everything does.
+    bool Of;
+    (void)LHS.getMinValue().usub_ov(RHS.getMaxValue(), Of);
+    if (!Of) {
+      Overflow = false;
+    } else {
+      (void)LHS.getMaxValue().usub_ov(RHS.getMinValue(), Of);
+      if (Of)
+        Overflow = true;
+    }
+  }
+
+  if (Signed) {
+    if (Add) {
+      if (LHS.isNonNegative() && RHS.isNonNegative()) {
+        // Pos + Pos -> Pos
+        Res.One.clearSignBit();
+        Res.Zero.setSignBit();
+      }
+      if (LHS.isNegative() && RHS.isNegative()) {
+        // Neg + Neg -> Neg
+        Res.One.setSignBit();
+        Res.Zero.clearSignBit();
+      }
+    } else {
+      if (LHS.isNegative() && RHS.isNonNegative()) {
+        // Neg - Pos -> Neg
+        Res.One.setSignBit();
+        Res.Zero.clearSignBit();
+      } else if (LHS.isNonNegative() && RHS.isNegative()) {
+        // Pos - Neg -> Pos
+        Res.One.clearSignBit();
+        Res.Zero.setSignBit();
+      }
+    }
+  } else {
+    // Add: Leading ones of either operand are preserved.
+    // Sub: Leading zeros of LHS and leading ones of RHS are preserved
+    // as leading zeros in the result.
+    unsigned LeadingKnown;
+    if (Add)
+      LeadingKnown =
+          std::max(LHS.countMinLeadingOnes(), RHS.countMinLeadingOnes());
+    else
+      LeadingKnown =
+          std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingOnes());
+
+    // We select between the operation result and all-ones/zero
+    // respectively, so we can preserve known ones/zeros.
+    APInt Mask = APInt::getHighBitsSet(BitWidth, LeadingKnown);
+    if (Add) {
+      Res.One |= Mask;
+      Res.Zero &= ~Mask;
+    } else {
+      Res.Zero |= Mask;
+      Res.One &= ~Mask;
+    }
+  }
+
+  if (Overflow) {
+    // We know whether or not we overflowed.
+    if (!(*Overflow)) {
+      // No overflow.
+      assert(!Res.hasConflict() && "Bad Output");
+      return Res;
+    }
+
+    // We overflowed
+    APInt C;
+    if (Signed) {
+      // sadd.sat / ssub.sat
+      assert(SignBitKnown(LHS) &&
+             "We somehow know overflow without knowing input sign");
+      C = LHS.isNegative() ? APInt::getSignedMinValue(BitWidth)
+                           : APInt::getSignedMaxValue(BitWidth);
+    } else if (Add) {
+      // uadd.sat
+      C = APInt::getMaxValue(BitWidth);
+    } else {
+      // usub.sat
+      C = APInt::getMinValue(BitWidth);
+    }
+
+    Res.One = C;
+    Res.Zero = ~C;
+    assert(!Res.hasConflict() && "Bad Output");
+    return Res;
+  }
+
+  // We don't know if we overflowed.
+  if (Signed) {
+    // sadd.sat/ssub.sat
+    // We can keep our information about the sign bits.
+    Res.Zero.clearLowBits(BitWidth - 1);
+    Res.One.clearLowBits(BitWidth - 1);
+  } else if (Add) {
+    // uadd.sat
+    // We need to clear all the known zeros as we can only use the leading ones.
+    Res.Zero.clearAllBits();
+  } else {
+    // usub.sat
+    // We need to clear all the known ones as we can only use the leading zeros.
+    Res.One.clearAllBits();
+  }
+
+  assert(!Res.hasConflict() && "Bad Output");
+  return Res;
+}
+
+KnownBits KnownBits::sadd_sat(const KnownBits &LHS, const KnownBits &RHS) {
+  return computeForSatAddSub(/*Add*/ true, /*Signed*/ true, LHS, RHS);
+}
+KnownBits KnownBits::ssub_sat(const KnownBits &LHS, const KnownBits &RHS) {
+  return computeForSatAddSub(/*Add*/ false, /*Signed*/ true, LHS, RHS);
+}
+KnownBits KnownBits::uadd_sat(const KnownBits &LHS, const KnownBits &RHS) {
+  return computeForSatAddSub(/*Add*/ true, /*Signed*/ false, LHS, RHS);
+}
+KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) {
+  return computeForSatAddSub(/*Add*/ false, /*Signed*/ false, LHS, RHS);
+}
+
 KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
                          bool NoUndefSelfMultiply) {
   unsigned BitWidth = LHS.getBitWidth();
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -300,7 +300,38 @@
         return N1.srem(N2);
       },
       checkCorrectnessOnlyBinary);
-
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::sadd_sat(Known1, Known2);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        return N1.sadd_sat(N2);
+      },
+      checkCorrectnessOnlyBinary);
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::uadd_sat(Known1, Known2);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        return N1.uadd_sat(N2);
+      },
+      checkCorrectnessOnlyBinary);
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::ssub_sat(Known1, Known2);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        return N1.ssub_sat(N2);
+      },
+      checkCorrectnessOnlyBinary);
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::usub_sat(Known1, Known2);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        return N1.usub_sat(N2);
+      },
+      checkCorrectnessOnlyBinary);
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::shl(Known1, Known2);