diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h --- a/llvm/include/llvm/Support/KnownBits.h +++ b/llvm/include/llvm/Support/KnownBits.h @@ -173,6 +173,10 @@ One.extractBits(NumBits, BitPosition)); } + /// Return KnownBits based on this, but updated given that the underlying + /// value is known to be greater than or equal to Val. + KnownBits makeGE(APInt Val) const; + /// Returns the minimum number of trailing zero bits. unsigned countMinTrailingZeros() const { return Zero.countTrailingOnes(); @@ -241,6 +245,18 @@ static KnownBits computeForAddSub(bool Add, bool NSW, const KnownBits &LHS, KnownBits RHS); + /// Compute known bits for umax(LHS, RHS). + static KnownBits umax(const KnownBits &LHS, const KnownBits &RHS); + + /// Compute known bits for umin(LHS, RHS). + static KnownBits umin(const KnownBits &LHS, const KnownBits &RHS); + + /// Compute known bits for smax(LHS, RHS). + static KnownBits smax(const KnownBits &LHS, const KnownBits &RHS); + + /// Compute known bits for smin(LHS, RHS). + static KnownBits smin(const KnownBits &LHS, const KnownBits &RHS); + /// Insert the bits from a smaller known bits starting at bitPosition. void insertBits(const KnownBits &SubBits, unsigned BitPosition) { Zero.insertBits(SubBits.Zero, BitPosition); diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp --- a/llvm/lib/Support/KnownBits.cpp +++ b/llvm/lib/Support/KnownBits.cpp @@ -83,6 +83,82 @@ return KnownOut; } +KnownBits KnownBits::makeGE(APInt Val) const { + // Find the highest bit position in which our underlying value might be + // strictly greater than Val. + unsigned N = (~Zero & ~Val).countLeadingZeros(); /* Equivalently (Zero | Val).countLeadingOnes(): N counts the leading bit positions where the value is known to be bitwise <= Val (known 0, or Val has a 1). */ + + // For every higher bit position, if Val has a 1 in that bit then our + // underlying value must also have a 1. 
+ Val.clearLowBits(getBitWidth() - N); /* Keeps only the top N bits of this by-value copy of Val; the caller's APInt is not modified. */ + return KnownBits(Zero, One | Val); +} + +KnownBits KnownBits::umax(const KnownBits &LHS, const KnownBits &RHS) { + // If the result of the umax is LHS then it must be greater than or equal to + // the minimum possible value of RHS. Likewise for RHS. Any known bits that + // are common to these two values are also known in the result. + KnownBits L = LHS.makeGE(RHS.getMinValue()); + KnownBits R = RHS.makeGE(LHS.getMinValue()); + KnownBits Res(L.Zero & R.Zero, L.One & R.One); + + unsigned BitWidth = LHS.getBitWidth(); + + // If LHS is known to be greater than RHS in some bit position, and greater + // than or equal to RHS in every higher bit position, then we know that the + // result of the umax is LHS, so copy its known bits. + APInt LHSgtRHS = LHS.One & RHS.Zero; + APInt LHSgeRHS = LHS.One | RHS.Zero; + LHSgeRHS.setLowBits(BitWidth - LHSgtRHS.countLeadingZeros()); /* The test below holds exactly when LHS.One >= ~RHS.Zero as unsigned values, i.e. the smallest value LHS can take is >= the largest value RHS can take. With no gt position, setLowBits(0) is a no-op and every position must be ge. */ + if (LHSgeRHS.isAllOnesValue()) { + Res.Zero |= LHS.Zero; + Res.One |= LHS.One; + } + + // Likewise for RHS. 
+ APInt RHSgtLHS = RHS.One & LHS.Zero; + APInt RHSgeLHS = RHS.One | LHS.Zero; + RHSgeLHS.setLowBits(BitWidth - RHSgtLHS.countLeadingZeros()); + if (RHSgeLHS.isAllOnesValue()) { + Res.Zero |= RHS.Zero; + Res.One |= RHS.One; + } + + return Res; +} + +KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) { + // Flip the range of values: [0, 0xFFFFFFFF] <-> [0xFFFFFFFF, 0] + auto Flip = [](KnownBits Val) { return KnownBits(Val.One, Val.Zero); }; /* Swapping Zero and One models bitwise NOT, which reverses unsigned order, so umin(a, b) == ~umax(~a, ~b). */ + return Flip(umax(Flip(LHS), Flip(RHS))); +} + +KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) { + // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF] + auto Flip = [](KnownBits Val) { /* Swaps Zero/One for the sign bit only, i.e. models inverting the sign bit, which maps signed order onto unsigned order. */ + unsigned SignBitPosition = Val.getBitWidth() - 1; + APInt Zero = Val.Zero; + APInt One = Val.One; + Zero.setBitTo(SignBitPosition, Val.One[SignBitPosition]); + One.setBitTo(SignBitPosition, Val.Zero[SignBitPosition]); + return KnownBits(Zero, One); + }; + return Flip(umax(Flip(LHS), Flip(RHS))); +} + +KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) { + // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0xFFFFFFFF, 0] + auto Flip = [](KnownBits Val) { /* Swaps Zero/One for every bit except the sign bit, i.e. models inverting all non-sign bits, which maps signed order onto reversed unsigned order. */ + unsigned SignBitPosition = Val.getBitWidth() - 1; + APInt Zero = Val.One; + APInt One = Val.Zero; + Zero.setBitTo(SignBitPosition, Val.Zero[SignBitPosition]); + One.setBitTo(SignBitPosition, Val.One[SignBitPosition]); + return KnownBits(Zero, One); + }; + return Flip(umax(Flip(LHS), Flip(RHS))); +} + KnownBits &KnownBits::operator&=(const KnownBits &RHS) { // Result bit is 0 if either operand bit is 0. 
Zero |= RHS.Zero; diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp --- a/llvm/unittests/Support/KnownBitsTest.cpp +++ b/llvm/unittests/Support/KnownBitsTest.cpp @@ -103,13 +103,15 @@ unsigned Bits = 4; ForeachKnownBits(Bits, [&](const KnownBits &Known1) { ForeachKnownBits(Bits, [&](const KnownBits &Known2) { - KnownBits KnownAnd(Bits), KnownOr(Bits), KnownXor(Bits); + KnownBits KnownAnd(Bits); KnownAnd.Zero.setAllBits(); KnownAnd.One.setAllBits(); - KnownOr.Zero.setAllBits(); - KnownOr.One.setAllBits(); - KnownXor.Zero.setAllBits(); - KnownXor.One.setAllBits(); + KnownBits KnownOr(KnownAnd); + KnownBits KnownXor(KnownAnd); + KnownBits KnownUMax(KnownAnd); + KnownBits KnownUMin(KnownAnd); + KnownBits KnownSMax(KnownAnd); + KnownBits KnownSMin(KnownAnd); /* Each accumulator starts with every bit in both Zero and One; the loops below intersect in each concrete result, leaving exactly the bits common to all results. */ ForeachNumInKnownBits(Known1, [&](const APInt &N1) { ForeachNumInKnownBits(Known2, [&](const APInt &N2) { @@ -126,6 +128,22 @@ Res = N1 ^ N2; KnownXor.One &= Res; KnownXor.Zero &= ~Res; + + Res = APIntOps::umax(N1, N2); + KnownUMax.One &= Res; + KnownUMax.Zero &= ~Res; + + Res = APIntOps::umin(N1, N2); + KnownUMin.One &= Res; + KnownUMin.Zero &= ~Res; + + Res = APIntOps::smax(N1, N2); + KnownSMax.One &= Res; + KnownSMax.Zero &= ~Res; + + Res = APIntOps::smin(N1, N2); + KnownSMin.One &= Res; + KnownSMin.Zero &= ~Res; }); }); @@ -140,6 +158,22 @@ KnownBits ComputedXor = Known1 ^ Known2; EXPECT_EQ(KnownXor.Zero, ComputedXor.Zero); EXPECT_EQ(KnownXor.One, ComputedXor.One); + + KnownBits ComputedUMax = KnownBits::umax(Known1, Known2); + EXPECT_EQ(KnownUMax.Zero, ComputedUMax.Zero); + EXPECT_EQ(KnownUMax.One, ComputedUMax.One); + + KnownBits ComputedUMin = KnownBits::umin(Known1, Known2); + EXPECT_EQ(KnownUMin.Zero, ComputedUMin.Zero); + EXPECT_EQ(KnownUMin.One, ComputedUMin.One); + + KnownBits ComputedSMax = KnownBits::smax(Known1, Known2); + EXPECT_EQ(KnownSMax.Zero, ComputedSMax.Zero); + EXPECT_EQ(KnownSMax.One, ComputedSMax.One); + + KnownBits ComputedSMin = 
KnownBits::smin(Known1, Known2); + EXPECT_EQ(KnownSMin.Zero, ComputedSMin.Zero); + EXPECT_EQ(KnownSMin.One, ComputedSMin.One); }); }); }