Index: include/llvm/ADT/APInt.h =================================================================== --- include/llvm/ADT/APInt.h +++ include/llvm/ADT/APInt.h @@ -223,6 +223,9 @@ /// out-of-line slow case for countPopulation unsigned countPopulationSlowCase() const; + /// out-of-line slow case for setBits (range not within the first word). + void setBitsSlowCase(unsigned loBit, unsigned hiBit); + public: /// \name Constructors /// @{ @@ -1239,7 +1242,32 @@ void setBit(unsigned bitPosition); /// Set the bits from loBit (inclusive) to hiBit (exclusive) to 1. - void setBits(unsigned loBit, unsigned hiBit); + void setBits(unsigned loBit, unsigned hiBit) { + assert(hiBit <= BitWidth && "hiBit out of range"); + assert(loBit <= hiBit && "loBit out of range"); + if (loBit == hiBit) + return; /* empty range: nothing to set */ + if (loBit < APINT_BITS_PER_WORD && hiBit <= APINT_BITS_PER_WORD) { + uint64_t mask = UINT64_MAX >> (APINT_BITS_PER_WORD - (hiBit - loBit)); + mask <<= loBit; /* line the mask up with loBit */ + if (isSingleWord()) + VAL |= mask; + else + pVal[0] |= mask; + } else { + setBitsSlowCase(loBit, hiBit); /* range not within the first word */ + } + } + + /// Set the bottom loBits bits to 1; same as setBits(0, loBits). + void setLowBits(unsigned loBits) { + return setBits(0, loBits); + } + + /// Set the top hiBits bits to 1; hiBits must not exceed BitWidth. + void setHighBits(unsigned hiBits) { + return setBits(BitWidth - hiBits, BitWidth); + } /// \brief Set every bit to 0. 
void clearAllBits() { Index: lib/Support/APInt.cpp =================================================================== --- lib/Support/APInt.cpp +++ lib/Support/APInt.cpp @@ -566,37 +566,31 @@ pVal[whichWord(bitPosition)] |= maskBit(bitPosition); } -void APInt::setBits(unsigned loBit, unsigned hiBit) { - assert(hiBit <= BitWidth && "hiBit out of range"); - assert(loBit <= hiBit && loBit <= BitWidth && "loBit out of range"); - - if (loBit == hiBit) - return; - - if (isSingleWord()) - *this |= APInt::getBitsSet(BitWidth, loBit, hiBit); - else { - unsigned hiBit1 = hiBit - 1; - unsigned loWord = whichWord(loBit); - unsigned hiWord = whichWord(hiBit1); - if (loWord == hiWord) { - // Set bits are all within the same word, create a [loBit,hiBit) mask. - uint64_t mask = UINT64_MAX; - mask >>= (APINT_BITS_PER_WORD - (hiBit - loBit)); - mask <<= whichBit(loBit); - pVal[loWord] |= mask; - } else { - // Set bits span multiple words, create a lo mask with set bits starting - // at loBit, a hi mask with set bits below hiBit and set all bits of the - // words in between. - uint64_t loMask = UINT64_MAX << whichBit(loBit); - uint64_t hiMask = UINT64_MAX >> (64 - whichBit(hiBit1) - 1); - pVal[loWord] |= loMask; +void APInt::setBitsSlowCase(unsigned loBit, unsigned hiBit) { + unsigned loWord = whichWord(loBit); + unsigned hiWord = whichWord(hiBit); + + // Create an initial mask for the low word with zeros below loBit. + uint64_t loMask = UINT64_MAX << whichBit(loBit); + + // If hiBit is not word-aligned, we need a high mask. + unsigned hiShiftAmt = whichBit(hiBit); + if (hiShiftAmt != 0) { + // Create a high mask with zeros at and above hiBit. + uint64_t hiMask = UINT64_MAX >> (APINT_BITS_PER_WORD - hiShiftAmt); + // If loWord and hiWord are equal, then we combine the masks. Otherwise, + // set the bits in hiWord. 
+ if (hiWord == loWord) + loMask &= hiMask; + else pVal[hiWord] |= hiMask; - for (unsigned word = loWord + 1; word < hiWord; ++word) - pVal[word] = UINT64_MAX; - } } + // Apply the mask to the low word. + pVal[loWord] |= loMask; + + // Fill any words strictly between loWord and hiWord with all ones. + for (unsigned word = loWord + 1; word < hiWord; ++word) + pVal[word] = UINT64_MAX; } /// Set the given bit to 0 whose position is given as "bitPosition". Index: unittests/ADT/APIntTest.cpp =================================================================== --- unittests/ADT/APIntTest.cpp +++ unittests/ADT/APIntTest.cpp @@ -1541,3 +1541,125 @@ EXPECT_EQ(0u, i64hi32.countTrailingOnes()); EXPECT_EQ(32u, i64hi32.countPopulation()); } + +TEST(APIntTest, setLowBits) { + APInt i64lo32(64, 0); + i64lo32.setLowBits(32); + EXPECT_EQ(0u, i64lo32.countLeadingOnes()); + EXPECT_EQ(32u, i64lo32.countLeadingZeros()); + EXPECT_EQ(32u, i64lo32.getActiveBits()); + EXPECT_EQ(0u, i64lo32.countTrailingZeros()); + EXPECT_EQ(32u, i64lo32.countTrailingOnes()); + EXPECT_EQ(32u, i64lo32.countPopulation()); + + APInt i128lo64(128, 0); + i128lo64.setLowBits(64); + EXPECT_EQ(0u, i128lo64.countLeadingOnes()); + EXPECT_EQ(64u, i128lo64.countLeadingZeros()); + EXPECT_EQ(64u, i128lo64.getActiveBits()); + EXPECT_EQ(0u, i128lo64.countTrailingZeros()); + EXPECT_EQ(64u, i128lo64.countTrailingOnes()); + EXPECT_EQ(64u, i128lo64.countPopulation()); + + APInt i128lo24(128, 0); + i128lo24.setLowBits(24); + EXPECT_EQ(0u, i128lo24.countLeadingOnes()); + EXPECT_EQ(104u, i128lo24.countLeadingZeros()); + EXPECT_EQ(24u, i128lo24.getActiveBits()); + EXPECT_EQ(0u, i128lo24.countTrailingZeros()); + EXPECT_EQ(24u, i128lo24.countTrailingOnes()); + EXPECT_EQ(24u, i128lo24.countPopulation()); + + APInt i128lo104(128, 0); + i128lo104.setLowBits(104); + EXPECT_EQ(0u, i128lo104.countLeadingOnes()); + EXPECT_EQ(24u, i128lo104.countLeadingZeros()); + EXPECT_EQ(104u, i128lo104.getActiveBits()); + EXPECT_EQ(0u, 
i128lo104.countTrailingZeros()); + EXPECT_EQ(104u, i128lo104.countTrailingOnes()); + EXPECT_EQ(104u, i128lo104.countPopulation()); + + APInt i128lo0(128, 0); + i128lo0.setLowBits(0); /* empty range: no bit may change */ + EXPECT_EQ(0u, i128lo0.countLeadingOnes()); + EXPECT_EQ(128u, i128lo0.countLeadingZeros()); + EXPECT_EQ(0u, i128lo0.getActiveBits()); + EXPECT_EQ(128u, i128lo0.countTrailingZeros()); + EXPECT_EQ(0u, i128lo0.countTrailingOnes()); + EXPECT_EQ(0u, i128lo0.countPopulation()); + + APInt i80lo79(80, 0); + i80lo79.setLowBits(79); /* width is not a multiple of 64 */ + EXPECT_EQ(0u, i80lo79.countLeadingOnes()); + EXPECT_EQ(1u, i80lo79.countLeadingZeros()); + EXPECT_EQ(79u, i80lo79.getActiveBits()); + EXPECT_EQ(0u, i80lo79.countTrailingZeros()); + EXPECT_EQ(79u, i80lo79.countTrailingOnes()); + EXPECT_EQ(79u, i80lo79.countPopulation()); +} + +TEST(APIntTest, setHighBits) { + APInt i64hi32(64, 0); + i64hi32.setHighBits(32); + EXPECT_EQ(32u, i64hi32.countLeadingOnes()); + EXPECT_EQ(0u, i64hi32.countLeadingZeros()); + EXPECT_EQ(64u, i64hi32.getActiveBits()); + EXPECT_EQ(32u, i64hi32.countTrailingZeros()); + EXPECT_EQ(0u, i64hi32.countTrailingOnes()); + EXPECT_EQ(32u, i64hi32.countPopulation()); + + APInt i128hi64(128, 0); + i128hi64.setHighBits(64); + EXPECT_EQ(64u, i128hi64.countLeadingOnes()); + EXPECT_EQ(0u, i128hi64.countLeadingZeros()); + EXPECT_EQ(128u, i128hi64.getActiveBits()); + EXPECT_EQ(64u, i128hi64.countTrailingZeros()); + EXPECT_EQ(0u, i128hi64.countTrailingOnes()); + EXPECT_EQ(64u, i128hi64.countPopulation()); + + APInt i128hi24(128, 0); + i128hi24.setHighBits(24); + EXPECT_EQ(24u, i128hi24.countLeadingOnes()); + EXPECT_EQ(0u, i128hi24.countLeadingZeros()); + EXPECT_EQ(128u, i128hi24.getActiveBits()); + EXPECT_EQ(104u, i128hi24.countTrailingZeros()); + EXPECT_EQ(0u, i128hi24.countTrailingOnes()); + EXPECT_EQ(24u, i128hi24.countPopulation()); + + APInt i128hi104(128, 0); + i128hi104.setHighBits(104); + EXPECT_EQ(104u, i128hi104.countLeadingOnes()); + EXPECT_EQ(0u, i128hi104.countLeadingZeros()); + EXPECT_EQ(128u, 
i128hi104.getActiveBits()); + EXPECT_EQ(24u, i128hi104.countTrailingZeros()); + EXPECT_EQ(0u, i128hi104.countTrailingOnes()); + EXPECT_EQ(104u, i128hi104.countPopulation()); + + APInt i128hi0(128, 0); + i128hi0.setHighBits(0); /* empty range: no bit may change */ + EXPECT_EQ(0u, i128hi0.countLeadingOnes()); + EXPECT_EQ(128u, i128hi0.countLeadingZeros()); + EXPECT_EQ(0u, i128hi0.getActiveBits()); + EXPECT_EQ(128u, i128hi0.countTrailingZeros()); + EXPECT_EQ(0u, i128hi0.countTrailingOnes()); + EXPECT_EQ(0u, i128hi0.countPopulation()); + + APInt i80hi1(80, 0); + i80hi1.setHighBits(1); /* only bit 79, the top bit */ + EXPECT_EQ(1u, i80hi1.countLeadingOnes()); + EXPECT_EQ(0u, i80hi1.countLeadingZeros()); + EXPECT_EQ(80u, i80hi1.getActiveBits()); + EXPECT_EQ(79u, i80hi1.countTrailingZeros()); + EXPECT_EQ(0u, i80hi1.countTrailingOnes()); + EXPECT_EQ(1u, i80hi1.countPopulation()); + + APInt i32hi16(32, 0); + i32hi16.setHighBits(16); /* bits [16, 32) */ + EXPECT_EQ(16u, i32hi16.countLeadingOnes()); + EXPECT_EQ(0u, i32hi16.countLeadingZeros()); + EXPECT_EQ(32u, i32hi16.getActiveBits()); + EXPECT_EQ(16u, i32hi16.countTrailingZeros()); + EXPECT_EQ(0u, i32hi16.countTrailingOnes()); + EXPECT_EQ(16u, i32hi16.countPopulation()); + +}