diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -301,10 +301,13 @@
 
   /// Compute known bits resulting from adding LHS and RHS.
+  /// If SelfAdd is true, LHS and RHS must be the same value (X + X).
   static KnownBits computeForAddSub(bool Add, bool NSW, const KnownBits &LHS,
-                                    KnownBits RHS);
+                                    KnownBits RHS, bool SelfAdd = false);
 
   /// Compute known bits resulting from multiplying LHS and RHS.
-  static KnownBits mul(const KnownBits &LHS, const KnownBits &RHS);
+  /// If SelfMultiply is true, LHS and RHS must be the same value (X * X).
+  static KnownBits mul(const KnownBits &LHS, const KnownBits &RHS,
+                       bool SelfMultiply = false);
 
   /// Compute known bits from sign-extended multiply-hi.
   static KnownBits mulhs(const KnownBits &LHS, const KnownBits &RHS);
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -54,13 +54,24 @@
       LHS, RHS, Carry.Zero.getBoolValue(), Carry.One.getBoolValue());
 }
 
-KnownBits KnownBits::computeForAddSub(bool Add, bool NSW,
-                                      const KnownBits &LHS, KnownBits RHS) {
+KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, const KnownBits &LHS,
+                                      KnownBits RHS, bool SelfAdd) {
   KnownBits KnownOut;
   if (Add) {
-    // Sum = LHS + RHS + 0
-    KnownOut = ::computeForAddCarry(
-        LHS, RHS, /*CarryZero*/true, /*CarryOne*/false);
+    if (SelfAdd) {
+      // Sum = LHS + LHS = LHS << 1, so shifting the known bits is exact and
+      // bit 0 of the sum is always zero. Fall through instead of returning
+      // early so the NSW post-processing after this if/else is not skipped.
+      assert(LHS.One == RHS.One && LHS.Zero == RHS.Zero &&
+             "Self-addition knownbits mismatch");
+      KnownOut.One = LHS.One << 1;
+      KnownOut.Zero = LHS.Zero << 1;
+      KnownOut.Zero.setLowBits(1);
+    } else {
+      // Sum = LHS + RHS + 0
+      KnownOut = ::computeForAddCarry(LHS, RHS, /*CarryZero*/ true,
+                                      /*CarryOne*/ false);
+    }
   } else {
     // Sum = LHS + ~RHS + 1
     std::swap(RHS.Zero, RHS.One);
@@ -412,10 +423,13 @@
   return KnownAbs;
 }
 
-KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS) {
+KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
+                         bool SelfMultiply) {
   unsigned BitWidth = LHS.getBitWidth();
   assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() &&
          !RHS.hasConflict() && "Operand mismatch");
+  assert((!SelfMultiply || (LHS.One == RHS.One && LHS.Zero == RHS.Zero)) &&
+         "Self multiplication knownbits mismatch");
 
   // Compute a conservative estimate for high known-0 bits.
   unsigned LeadZ =
@@ -489,6 +503,15 @@
   Res.Zero.setHighBits(LeadZ);
   Res.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown);
   Res.One = BottomKnown.getLoBits(ResultBitsKnown);
+
+  // If we're self-multiplying then bit[1] is guaranteed to be zero: every
+  // square is congruent to 0 or 1 modulo 4.
+  if (SelfMultiply && BitWidth > 1) {
+    assert(Res.One[1] == 0 &&
+           "Self-multiplication failed Quadratic Reciprocity!");
+    Res.Zero.setBit(1);
+  }
+
   return Res;
 }
 
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -92,6 +92,47 @@
       EXPECT_TRUE(KnownNSWComputed.One.isSubsetOf(KnownNSW.One));
     });
   });
+
+  // Also test 'SelfAdd' cases where the same argument is repeated.
+  if (IsAdd) {
+    ForeachKnownBits(Bits, [&](const KnownBits &Known) {
+      KnownBits KnownAdd(Bits), KnownNSW(Bits);
+      KnownAdd.Zero.setAllBits();
+      KnownAdd.One.setAllBits();
+      KnownNSW.Zero.setAllBits();
+      KnownNSW.One.setAllBits();
+
+      ForeachNumInKnownBits(Known, [&](const APInt &N) {
+        bool Overflow;
+        APInt Res = N.sadd_ov(N, Overflow);
+
+        KnownAdd.One &= Res;
+        KnownAdd.Zero &= ~Res;
+
+        if (!Overflow) {
+          KnownNSW.One &= Res;
+          KnownNSW.Zero &= ~Res;
+        }
+      });
+
+      KnownBits KnownComputed = KnownBits::computeForAddSub(
+          IsAdd, /*NSW*/ false, Known, Known, /*SelfAdd*/ true);
+      EXPECT_EQ(KnownAdd.Zero, KnownComputed.Zero);
+      EXPECT_EQ(KnownAdd.One, KnownComputed.One);
+
+      KnownBits KnownNSWComputed = KnownBits::computeForAddSub(
+          IsAdd, /*NSW*/ true, Known, Known, /*SelfAdd*/ true);
+      EXPECT_TRUE(KnownNSWComputed.Zero.isSubsetOf(KnownNSW.Zero));
+      EXPECT_TRUE(KnownNSWComputed.One.isSubsetOf(KnownNSW.One));
+
+      // Knowing that both operands are the same value must never give a less
+      // precise result than the generic analysis of Known + Known.
+      KnownBits KnownNSWGeneric =
+          KnownBits::computeForAddSub(IsAdd, /*NSW*/ true, Known, Known);
+      EXPECT_TRUE(KnownNSWGeneric.Zero.isSubsetOf(KnownNSWComputed.Zero));
+      EXPECT_TRUE(KnownNSWGeneric.One.isSubsetOf(KnownNSWComputed.One));
+    });
+  }
 }
 
 TEST(KnownBitsTest, AddSubExhaustive) {
@@ -267,6 +308,23 @@
       EXPECT_TRUE(ComputedAShr.One.isSubsetOf(KnownAShr.One));
     });
   });
+
+  // Also test 'unary' binary cases where the same argument is repeated.
+  ForeachKnownBits(Bits, [&](const KnownBits &Known) {
+    KnownBits KnownMul(Bits);
+    KnownMul.Zero.setAllBits();
+    KnownMul.One.setAllBits();
+
+    ForeachNumInKnownBits(Known, [&](const APInt &N) {
+      APInt Res = N * N;
+      KnownMul.One &= Res;
+      KnownMul.Zero &= ~Res;
+    });
+
+    KnownBits ComputedMul = KnownBits::mul(Known, Known, /*SelfMultiply*/ true);
+    EXPECT_TRUE(ComputedMul.Zero.isSubsetOf(KnownMul.Zero));
+    EXPECT_TRUE(ComputedMul.One.isSubsetOf(KnownMul.One));
+  });
 }
 
 TEST(KnownBitsTest, UnaryExhaustive) {