Patch notes (APInt: concat() and zero-bit-width cleanups)

This patch touches llvm/include/llvm/ADT/APInt.h, llvm/lib/Support/APInt.cpp
and llvm/unittests/ADT/APIntTest.cpp:
  * isAllOnes() computes its shift amount modulo APINT_BITS_PER_WORD, so the
    zero-bit-wide case no longer needs a separate branch.
  * The out-of-line helpers AssignSlowCase, EqualSlowCase, AndAssignSlowCase,
    OrAssignSlowCase and XorAssignSlowCase are renamed to start with a
    lower-case letter.
  * APInt::concat() is added, with an inline single-word path and an
    out-of-line concatSlowCase(); the concatMSB unit test covers it.
  * The remaining BitWidth == 0 checks in the touched functions are wrapped in
    LLVM_UNLIKELY.

Review fix-ups applied on top of the submitted patch:
  * concat(): the single-word path shifted U.VAL by NewLSB.getBitWidth(),
    which is a shift by the full word size (undefined behaviour) when *this
    is zero-bit wide and NewLSB is APINT_BITS_PER_WORD bits wide. The shift
    amount is now reduced modulo APINT_BITS_PER_WORD, as isAllOnes() does.
    That hunk grows by four lines, so the new-file offsets of the later
    APInt.h hunks are bumped by four.
  * concat(): the comment inside the function body uses // rather than ///.
  * concatMSB: the first expectation compares against APInt(8, 0x31) so the
    result width is checked as well as the value.

diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h --- a/llvm/include/llvm/ADT/APInt.h +++ b/llvm/include/llvm/ADT/APInt.h @@ -348,9 +348,10 @@ /// Determine if all bits are set. bool isAllOnes() const { if (isSingleWord()) { - if (BitWidth == 0) - return false; - return U.VAL == WORDTYPE_MAX >> (APINT_BITS_PER_WORD - BitWidth); + // Calculate the shift amount, handling the zero-bit wide case without UB. + unsigned ShiftAmt = + (APINT_BITS_PER_WORD - BitWidth) % APINT_BITS_PER_WORD; + return U.VAL == WORDTYPE_MAX >> ShiftAmt; } return countTrailingOnesSlowCase() == BitWidth; } @@ -580,7 +581,7 @@ return *this; } - AssignSlowCase(RHS); + assignSlowCase(RHS); return *this; } @@ -632,7 +633,7 @@ if (isSingleWord()) U.VAL &= RHS.U.VAL; else - AndAssignSlowCase(RHS); + andAssignSlowCase(RHS); return *this; } @@ -662,7 +663,7 @@ if (isSingleWord()) U.VAL |= RHS.U.VAL; else - OrAssignSlowCase(RHS); + orAssignSlowCase(RHS); return *this; } @@ -691,7 +692,7 @@ if (isSingleWord()) U.VAL ^= RHS.U.VAL; else - XorAssignSlowCase(RHS); + xorAssignSlowCase(RHS); return *this; } @@ -877,6 +878,21 @@ /// Rotate right by rotateAmt. APInt rotr(const APInt &rotateAmt) const; + /// Concatenate the bits from "NewLSB" onto the bottom of *this. This is + /// equivalent to: + /// (this->zext(NewWidth) << NewLSB.getBitWidth()) | NewLSB.zext(NewWidth) + APInt concat(const APInt &NewLSB) const { + // If the result will be small, then both the merged values are small. + unsigned NewWidth = getBitWidth() + NewLSB.getBitWidth(); + if (NewWidth <= APINT_BITS_PER_WORD) { + // A full-word NewLSB only gets here when *this is zero-bit wide; reduce + // the shift amount modulo the word size so it is never a UB shift. + unsigned ShiftAmt = NewLSB.getBitWidth() % APINT_BITS_PER_WORD; + return APInt(NewWidth, (U.VAL << ShiftAmt) | NewLSB.U.VAL); + } + return concatSlowCase(NewLSB); + } + /// Unsigned division operation. /// /// Perform an unsigned divide operation on this APInt by RHS. 
Both this and @@ -971,7 +987,7 @@ assert(BitWidth == RHS.BitWidth && "Comparison requires equal bit widths"); if (isSingleWord()) return U.VAL == RHS.U.VAL; - return EqualSlowCase(RHS); + return equalSlowCase(RHS); } /// Equality operator. @@ -1491,7 +1507,7 @@ /// of 1 bits from the most significant to the least unsigned countLeadingOnes() const { if (isSingleWord()) { - if (BitWidth == 0) + if (LLVM_UNLIKELY(BitWidth == 0)) return 0; return llvm::countLeadingOnes(U.VAL << (APINT_BITS_PER_WORD - BitWidth)); } @@ -1799,7 +1815,6 @@ unsigned BitWidth; ///< The number of bits in this APInt. friend struct DenseMapInfo; - friend class APSInt; /// This constructor is used only internally for speed of construction of @@ -1841,7 +1856,7 @@ // Mask out the high bits. uint64_t mask = WORDTYPE_MAX >> (APINT_BITS_PER_WORD - WordBits); - if (BitWidth == 0) + if (LLVM_UNLIKELY(BitWidth == 0)) mask = 0; if (isSingleWord()) @@ -1905,10 +1920,10 @@ void ashrSlowCase(unsigned ShiftAmt); /// out-of-line slow case for operator= - void AssignSlowCase(const APInt &RHS); + void assignSlowCase(const APInt &RHS); /// out-of-line slow case for operator== - bool EqualSlowCase(const APInt &RHS) const LLVM_READONLY; + bool equalSlowCase(const APInt &RHS) const LLVM_READONLY; /// out-of-line slow case for countLeadingZeros unsigned countLeadingZerosSlowCase() const LLVM_READONLY; @@ -1937,14 +1952,17 @@ /// out-of-line slow case for flipAllBits. void flipAllBitsSlowCase(); + /// out-of-line slow case for concat. + APInt concatSlowCase(const APInt &NewLSB) const; + /// out-of-line slow case for operator&=. - void AndAssignSlowCase(const APInt &RHS); + void andAssignSlowCase(const APInt &RHS); /// out-of-line slow case for operator|=. - void OrAssignSlowCase(const APInt &RHS); + void orAssignSlowCase(const APInt &RHS); /// out-of-line slow case for operator^=. - void XorAssignSlowCase(const APInt &RHS); + void xorAssignSlowCase(const APInt &RHS); /// Unsigned comparison. 
Returns -1, 0, or 1 if this APInt is less than, equal /// to, or greater than RHS. diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp --- a/llvm/lib/Support/APInt.cpp +++ b/llvm/lib/Support/APInt.cpp @@ -137,7 +137,7 @@ U.pVal = getMemory(getNumWords()); } -void APInt::AssignSlowCase(const APInt& RHS) { +void APInt::assignSlowCase(const APInt &RHS) { // Don't do anything for X = X if (this == &RHS) return; @@ -235,19 +235,19 @@ return Result; } -void APInt::AndAssignSlowCase(const APInt &RHS) { +void APInt::andAssignSlowCase(const APInt &RHS) { WordType *dst = U.pVal, *rhs = RHS.U.pVal; for (size_t i = 0, e = getNumWords(); i != e; ++i) dst[i] &= rhs[i]; } -void APInt::OrAssignSlowCase(const APInt &RHS) { +void APInt::orAssignSlowCase(const APInt &RHS) { WordType *dst = U.pVal, *rhs = RHS.U.pVal; for (size_t i = 0, e = getNumWords(); i != e; ++i) dst[i] |= rhs[i]; } -void APInt::XorAssignSlowCase(const APInt &RHS) { +void APInt::xorAssignSlowCase(const APInt &RHS) { WordType *dst = U.pVal, *rhs = RHS.U.pVal; for (size_t i = 0, e = getNumWords(); i != e; ++i) dst[i] ^= rhs[i]; @@ -268,7 +268,7 @@ return clearUnusedBits(); } -bool APInt::EqualSlowCase(const APInt& RHS) const { +bool APInt::equalSlowCase(const APInt &RHS) const { return std::equal(U.pVal, U.pVal + getNumWords(), RHS.U.pVal); } @@ -339,6 +339,17 @@ clearUnusedBits(); } +/// Concatenate the bits from "NewLSB" onto the bottom of *this. This is +/// equivalent to: +/// (this->zext(NewWidth) << NewLSB.getBitWidth()) | NewLSB.zext(NewWidth) +/// In the slow case, we know the result is large. +APInt APInt::concatSlowCase(const APInt &NewLSB) const { + unsigned NewWidth = getBitWidth() + NewLSB.getBitWidth(); + APInt Result = NewLSB.zext(NewWidth); + Result.insertBits(*this, NewLSB.getBitWidth()); + return Result; +} + /// Toggle a given bit to its opposite value whose position is given /// as "bitPosition". /// Toggles a given bit to its opposite value. 
@@ -1064,7 +1075,7 @@ // Calculate the rotate amount modulo the bit width. static unsigned rotateModulo(unsigned BitWidth, const APInt &rotateAmt) { - if (BitWidth == 0) + if (LLVM_UNLIKELY(BitWidth == 0)) return 0; unsigned rotBitWidth = rotateAmt.getBitWidth(); APInt rot = rotateAmt; @@ -1082,7 +1093,7 @@ } APInt APInt::rotl(unsigned rotateAmt) const { - if (BitWidth == 0) + if (LLVM_UNLIKELY(BitWidth == 0)) return *this; rotateAmt %= BitWidth; if (rotateAmt == 0) diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp --- a/llvm/unittests/ADT/APIntTest.cpp +++ b/llvm/unittests/ADT/APIntTest.cpp @@ -2587,6 +2587,18 @@ EXPECT_EQ(0xFFFFFFFF, val.truncOrSelf(64)); } +TEST(APIntTest, concatMSB) { + APInt Int1(4, 0x1ULL); + APInt Int3(4, 0x3ULL); + + EXPECT_EQ(APInt(8, 0x31), Int3.concat(Int1)); + EXPECT_EQ(APInt(12, 0x313), Int3.concat(Int1).concat(Int3)); + EXPECT_EQ(APInt(16, 0x3313), Int3.concat(Int3).concat(Int1).concat(Int3)); + + APInt I64(64, 0x3ULL); + EXPECT_EQ(I64, I64.concat(I64).lshr(64).trunc(64)); +} + TEST(APIntTest, multiply) { APInt i64(64, 1234);