diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -166,6 +166,17 @@
     return *this;
   }
 
+  /// Return known bits for a sign extension or truncation of the value we're
+  /// tracking: sext() when \p BitWidth is wider than the current width,
+  /// trunc() when it is narrower, and an unchanged copy when they are equal.
+  KnownBits sextOrTrunc(unsigned BitWidth) const {
+    if (BitWidth > getBitWidth())
+      return sext(BitWidth);
+    if (BitWidth < getBitWidth())
+      return trunc(BitWidth);
+    return *this;
+  }
+
   /// Return a KnownBits with the extracted bits
   /// [bitPosition,bitPosition+numBits).
   KnownBits extractBits(unsigned NumBits, unsigned BitPosition) const {
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -203,4 +203,38 @@
   });
 }
 
+TEST(KnownBitsTest, SExtOrTrunc) {
+  // For fully-known inputs, KnownBits::sextOrTrunc must agree bit-for-bit
+  // with APInt::sextOrTrunc at narrower, equal and wider target widths.
+  const unsigned NarrowerSize = 4;
+  const unsigned BaseSize = 6;
+  const unsigned WiderSize = 8;
+  APInt NegativeFitsNarrower(BaseSize, -4, /*isSigned=*/true);
+  APInt NegativeDoesntFitNarrower(BaseSize, -28, /*isSigned=*/true);
+  APInt PositiveFitsNarrower(BaseSize, 14);
+  // Must stay below 32 so the 6-bit sign bit is clear: 36 is 0b100100, the
+  // same bit pattern as -28, which would merely repeat the case above.
+  APInt PositiveDoesntFitNarrower(BaseSize, 28);
+  // Build a KnownBits in which every bit of Input is known.
+  auto InitKnownBits = [&](KnownBits &Res, const APInt &Input) {
+    Res = KnownBits(Input.getBitWidth());
+    Res.One = Input;
+    Res.Zero = ~Input;
+  };
+
+  for (unsigned Size : {NarrowerSize, BaseSize, WiderSize}) {
+    for (const APInt &Input :
+         {NegativeFitsNarrower, NegativeDoesntFitNarrower, PositiveFitsNarrower,
+          PositiveDoesntFitNarrower}) {
+      KnownBits Test;
+      InitKnownBits(Test, Input);
+      KnownBits Baseline;
+      InitKnownBits(Baseline, Input.sextOrTrunc(Size));
+      Test = Test.sextOrTrunc(Size);
+      EXPECT_EQ(Test.One, Baseline.One);
+      EXPECT_EQ(Test.Zero, Baseline.Zero);
+    }
+  }
+}
+
 } // end anonymous namespace