Index: include/llvm/ADT/APInt.h
===================================================================
--- include/llvm/ADT/APInt.h
+++ include/llvm/ADT/APInt.h
@@ -865,7 +865,14 @@
   /// \brief Logical right-shift function.
   ///
   /// Logical right-shift this APInt by shiftAmt.
-  APInt lshr(unsigned shiftAmt) const;
+  APInt lshr(unsigned shiftAmt) const {
+    APInt R(*this);
+    R.lshrInPlace(shiftAmt);
+    return R;
+  }
+
+  /// Logical right-shift this APInt by shiftAmt in place.
+  void lshrInPlace(unsigned shiftAmt);
 
   /// \brief Left-shift function.
   ///
@@ -1937,7 +1944,7 @@
   return A.ugt(B) ? A : B;
 }
 
-/// \brief Compute GCD of two APInt values.
+/// \brief Compute GCD of two unsigned APInt values.
 ///
 /// This function returns the greatest common divisor of the two APInt values
-/// using Euclid's algorithm.
+/// using Stein's algorithm.
Index: lib/Support/APInt.cpp
===================================================================
--- lib/Support/APInt.cpp
+++ lib/Support/APInt.cpp
@@ -770,16 +770,27 @@
   return Count;
 }
 
-/// Perform a logical right-shift from Src to Dst, which must be equal or
-/// non-overlapping, of Words words, by Shift, which must be less than 64.
+/// Perform a logical right-shift from Src to Dst of Words words, by Shift,
+/// which must be less than 64. If the source and destination ranges overlap,
+/// we require that Src >= Dst (put another way, we require that the overall
+/// operation is a right shift).
 static void lshrNear(uint64_t *Dst, uint64_t *Src, unsigned Words,
                      unsigned Shift) {
-  uint64_t Carry = 0;
-  for (int I = Words - 1; I >= 0; --I) {
-    uint64_t Tmp = Src[I];
-    Dst[I] = (Tmp >> Shift) | Carry;
-    Carry = Tmp << (64 - Shift);
+  if (!Words)
+    return;
+
+  if (Shift == 0) {
+    std::memmove(Dst, Src, Words * sizeof(uint64_t));
+    return;
   }
+
+  uint64_t Low = Src[0];
+  for (unsigned I = 1; I != Words; ++I) {
+    uint64_t High = Src[I];
+    Dst[I - 1] = (Low >> Shift) | (High << (64 - Shift));
+    Low = High;
+  }
+  Dst[Words - 1] = Low >> Shift;
 }
 
 APInt APInt::byteSwap() const {
@@ -840,11 +851,45 @@
 }
 
 APInt llvm::APIntOps::GreatestCommonDivisor(APInt A, APInt B) {
-  while (!!B) {
-    APInt R = A.urem(B);
-    A = std::move(B);
-    B = std::move(R);
+  // Fast-path a common case.
+  if (A == B) return A;
+
+  // Corner cases: if either operand is zero, the other is the gcd.
+  if (!A) return B;
+  if (!B) return A;
+
+  // Count common powers of 2 and remove all other powers of 2.
+  unsigned Pow2;
+  {
+    unsigned Pow2_A = A.countTrailingZeros();
+    unsigned Pow2_B = B.countTrailingZeros();
+    if (Pow2_A > Pow2_B) {
+      A.lshrInPlace(Pow2_A - Pow2_B);
+      Pow2 = Pow2_B;
+    } else if (Pow2_B > Pow2_A) {
+      B.lshrInPlace(Pow2_B - Pow2_A);
+      Pow2 = Pow2_A;
+    } else {
+      Pow2 = Pow2_A;
+    }
+  }
+
+  // Both operands are odd multiples of 2^Pow2:
+  //
+  //   gcd(a, b) = gcd(|a - b| / 2^i, min(a, b)), with i = ctz(|a - b|) - Pow2
+  //
+  // This is a modified version of Stein's algorithm, taking advantage of
+  // efficient countTrailingZeros().
+  while (A != B) {
+    if (A.ugt(B)) {
+      A -= B;
+      A.lshrInPlace(A.countTrailingZeros() - Pow2);
+    } else {
+      B -= A;
+      B.lshrInPlace(B.countTrailingZeros() - Pow2);
+    }
   }
+
   return A;
 }
 
@@ -1158,66 +1203,30 @@
 
 /// Logical right-shift this APInt by shiftAmt.
 /// @brief Logical right-shift function.
-APInt APInt::lshr(unsigned shiftAmt) const {
+void APInt::lshrInPlace(unsigned shiftAmt) {
   if (isSingleWord()) {
     if (shiftAmt >= BitWidth)
-      return APInt(BitWidth, 0);
+      VAL = 0;
     else
-      return APInt(BitWidth, this->VAL >> shiftAmt);
-  }
-
-  // If all the bits were shifted out, the result is 0. This avoids issues
-  // with shifting by the size of the integer type, which produces undefined
-  // results. We define these "undefined results" to always be 0.
-  if (shiftAmt >= BitWidth)
-    return APInt(BitWidth, 0);
-
-  // If none of the bits are shifted out, the result is *this. This avoids
-  // issues with shifting by the size of the integer type, which produces
-  // undefined results in the code below. This is also an optimization.
-  if (shiftAmt == 0)
-    return *this;
-
-  // Create some space for the result.
-  uint64_t * val = new uint64_t[getNumWords()];
-
-  // If we are shifting less than a word, compute the shift with a simple carry
-  if (shiftAmt < APINT_BITS_PER_WORD) {
-    lshrNear(val, pVal, getNumWords(), shiftAmt);
-    APInt Result(val, BitWidth);
-    Result.clearUnusedBits();
-    return Result;
-  }
-
-  // Compute some values needed by the remaining shift algorithms
-  unsigned wordShift = shiftAmt % APINT_BITS_PER_WORD;
-  unsigned offset = shiftAmt / APINT_BITS_PER_WORD;
-
-  // If we are shifting whole words, just move whole words
-  if (wordShift == 0) {
-    for (unsigned i = 0; i < getNumWords() - offset; ++i)
-      val[i] = pVal[i+offset];
-    for (unsigned i = getNumWords()-offset; i < getNumWords(); i++)
-      val[i] = 0;
-    APInt Result(val, BitWidth);
-    Result.clearUnusedBits();
-    return Result;
+      VAL >>= shiftAmt;
+    return;
   }
 
-  // Shift the low order words
-  unsigned breakWord = getNumWords() - offset -1;
-  for (unsigned i = 0; i < breakWord; ++i)
-    val[i] = (pVal[i+offset] >> wordShift) |
-             (pVal[i+offset+1] << (APINT_BITS_PER_WORD - wordShift));
-  // Shift the break word.
-  val[breakWord] = pVal[breakWord+offset] >> wordShift;
+  // Don't bother performing a no-op shift.
+  if (!shiftAmt)
+    return;
 
-  // Remaining words are 0
-  for (unsigned i = breakWord+1; i < getNumWords(); ++i)
-    val[i] = 0;
-  APInt Result(val, BitWidth);
-  Result.clearUnusedBits();
-  return Result;
+  // Find number of complete words being shifted out and zeroed.
+  const unsigned Words = getNumWords();
+  const unsigned ShiftFullWords =
+      std::min(shiftAmt / APINT_BITS_PER_WORD, Words);
+  // Fill in the low Words - ShiftFullWords words by shifting (Shift < 64).
+  lshrNear(pVal, pVal + ShiftFullWords, Words - ShiftFullWords,
+           shiftAmt % APINT_BITS_PER_WORD);
+
+  // The remaining high words are all zero.
+  for (unsigned I = Words - ShiftFullWords; I != Words; ++I)
+    pVal[I] = 0;
 }
 
 /// Left-shift this APInt by shiftAmt.
Index: unittests/ADT/APIntTest.cpp
===================================================================
--- unittests/ADT/APIntTest.cpp
+++ unittests/ADT/APIntTest.cpp
@@ -1977,3 +1977,47 @@
   i128.setHighBits(2);
   EXPECT_EQ(0xc, i128.getHiBits(4));
 }
+
+TEST(APIntTest, GCD) {
+  using APIntOps::GreatestCommonDivisor;
+
+  for (unsigned Bits : {1, 2, 32, 63, 64, 65}) {
+    // Test some corner cases near zero.
+    APInt Zero(Bits, 0), One(Bits, 1);
+    EXPECT_EQ(GreatestCommonDivisor(Zero, Zero), Zero);
+    EXPECT_EQ(GreatestCommonDivisor(Zero, One), One);
+    EXPECT_EQ(GreatestCommonDivisor(One, Zero), One);
+    EXPECT_EQ(GreatestCommonDivisor(One, One), One);
+
+    if (Bits > 1) {
+      APInt Two(Bits, 2);
+      EXPECT_EQ(GreatestCommonDivisor(Zero, Two), Two);
+      EXPECT_EQ(GreatestCommonDivisor(One, Two), One);
+      EXPECT_EQ(GreatestCommonDivisor(Two, Two), Two);
+
+      // Test some corner cases near the highest representable value.
+      APInt Max(Bits, 0);
+      Max.setAllBits();
+      EXPECT_EQ(GreatestCommonDivisor(Zero, Max), Max);
+      EXPECT_EQ(GreatestCommonDivisor(One, Max), One);
+      EXPECT_EQ(GreatestCommonDivisor(Two, Max), One);
+      EXPECT_EQ(GreatestCommonDivisor(Max, Max), Max);
+
+      APInt MaxOver2 = Max.udiv(Two);
+      EXPECT_EQ(GreatestCommonDivisor(MaxOver2, Max), One);
+      // Max - 1 == Max / 2 * 2, because Max is odd.
+      EXPECT_EQ(GreatestCommonDivisor(MaxOver2, Max - 1), MaxOver2);
+    }
+  }
+
+  // Compute the 20th Mersenne prime, 2^4423 - 1.
+  APInt HugePrime(4423, 0);
+  HugePrime.setAllBits();
+  HugePrime = HugePrime.zext(4450);
+
+  // 9931 and 123456 are coprime, so gcd(A, B) must be exactly HugePrime.
+  APInt A = HugePrime * APInt(4450, 9931);
+  APInt B = HugePrime * APInt(4450, 123456);
+  APInt C = GreatestCommonDivisor(A, B);
+  EXPECT_EQ(C, HugePrime);
+}