diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -422,6 +422,34 @@
     return KnownBits(Zero.reverseBits(), One.reverseBits());
   }
 
+  /// Compute known bits for X & -X, i.e. X with every bit except its lowest
+  /// set bit cleared. The name comes from the X86 BMI instruction BLSI.
+  ///
+  /// \p NegX must describe -X; a fully unknown value is always a valid
+  /// (conservative) argument.
+  KnownBits blsi(const KnownBits &NegX) const;
+  KnownBits blsi() const {
+    KnownBits NegX(getBitWidth());
+    return blsi(NegX);
+  }
+
+  /// Compute known bits for X ^ (X - 1), i.e. a mask of every bit up to and
+  /// including the lowest set bit of X (all ones when X is 0). The name comes
+  /// from the X86 BMI instruction BLSMSK.
+  ///
+  /// \p XMinusC must describe the decremented value X - 1; a fully unknown
+  /// value is always a valid (conservative) argument. \p C is not currently
+  /// used.
+  KnownBits blsmsk(const KnownBits &XMinusC, const APInt &C) const;
+  KnownBits blsmsk(const KnownBits &XMinusC) const {
+    APInt C = APInt::getAllOnes(getBitWidth());
+    return blsmsk(XMinusC, C);
+  }
+  KnownBits blsmsk() const {
+    KnownBits XMinusC(getBitWidth());
+    return blsmsk(XMinusC);
+  }
+
   bool operator==(const KnownBits &Other) const {
     return Zero == Other.Zero && One == Other.One;
   }
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -623,6 +623,39 @@
   return *this;
 }
 
+KnownBits KnownBits::blsi(const KnownBits &NegX) const {
+  // X & -X isolates the lowest set bit of X. X and -X always have the same
+  // number of trailing zeros, so both operands bound that bit's position.
+  KnownBits Known = *this & NegX;
+  unsigned BitWidth = getBitWidth();
+  unsigned Max =
+      std::min(countMaxTrailingZeros(), NegX.countMaxTrailingZeros());
+  // Every bit above the highest possible position of the lowest set bit is 0.
+  Known.Zero.setBitsFrom(std::min(Max + 1, BitWidth));
+  unsigned Min =
+      std::max(countMinTrailingZeros(), NegX.countMinTrailingZeros());
+  // If that position is pinned down (and X is known nonzero), it is a 1.
+  if (Max == Min && Max < BitWidth)
+    Known.One.setBit(Max);
+  return Known;
+}
+
+// TODO: It's possible for us to have more information on `XMinusC` than `X`
+// (if there is an assume or branch on `X-C` for example). The logic could be
+// more complete by trying to get min set bit from (XMinusOne + C). Note, `C`
+// is generally `-1` but in some cases can be shrunk by SimplifyDemandedBits.
+KnownBits KnownBits::blsmsk(const KnownBits &XMinusC, const APInt &C) const {
+  (void)C; // Not consulted yet; see the TODO above.
+  KnownBits Known = *this ^ XMinusC;
+  unsigned BitWidth = getBitWidth();
+  // X ^ (X - 1) is ones up to and including X's lowest set bit, zero above.
+  unsigned Max = countMaxTrailingZeros();
+  Known.Zero.setBitsFrom(std::min(Max + 1, BitWidth));
+  unsigned Min = countMinTrailingZeros();
+  Known.One.setLowBits(std::min(Min + 1, BitWidth));
+  return Known;
+}
+
 void KnownBits::print(raw_ostream &OS) const {
   OS << "{Zero=" << Zero << ", One=" << One << "}";
 }
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -284,6 +284,12 @@
     KnownAbs.Zero.setAllBits();
     KnownAbs.One.setAllBits();
     KnownBits KnownAbsPoison(KnownAbs);
+    KnownBits KnownBlsi(Bits);
+    KnownBlsi.Zero.setAllBits();
+    KnownBlsi.One.setAllBits();
+    KnownBits KnownBlsmsk(Bits);
+    KnownBlsmsk.Zero.setAllBits();
+    KnownBlsmsk.One.setAllBits();
 
     ForeachNumInKnownBits(Known, [&](const APInt &N) {
       APInt Res = N.abs();
@@ -294,6 +300,14 @@
         KnownAbsPoison.One &= Res;
         KnownAbsPoison.Zero &= ~Res;
       }
+
+      Res = N & -N;
+      KnownBlsi.One &= Res;
+      KnownBlsi.Zero &= ~Res;
+
+      Res = N ^ (N - 1);
+      KnownBlsmsk.One &= Res;
+      KnownBlsmsk.Zero &= ~Res;
     });
 
     // abs() is conservatively correct, but not guaranteed to be precise.
@@ -304,6 +318,12 @@
     KnownBits ComputedAbsPoison = Known.abs(true);
     EXPECT_TRUE(ComputedAbsPoison.Zero.isSubsetOf(KnownAbsPoison.Zero));
     EXPECT_TRUE(ComputedAbsPoison.One.isSubsetOf(KnownAbsPoison.One));
+
+    KnownBits ComputedBlsi = Known.blsi();
+    EXPECT_EQ(KnownBlsi, ComputedBlsi);
+
+    KnownBits ComputedBlsmsk = Known.blsmsk();
+    EXPECT_EQ(KnownBlsmsk, ComputedBlsmsk);
   });
 }
 