diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -304,7 +304,13 @@
                                     KnownBits RHS);
 
   /// Compute known bits resulting from multiplying LHS and RHS.
-  static KnownBits mul(const KnownBits &LHS, const KnownBits &RHS);
+  /// If \p SelfMultiply is true, LHS and RHS must describe the very same
+  /// value X (so the result is X * X); bit[1] of the result is then known
+  /// zero, because every square is 0 or 1 modulo 4. Equal known bits alone
+  /// are not enough (1 * 3 == 3 has bit[1] set), and for IR callers X must
+  /// not be undef, since each use of undef may observe a different value.
+  static KnownBits mul(const KnownBits &LHS, const KnownBits &RHS,
+                       bool SelfMultiply = false);
 
   /// Compute known bits from sign-extended multiply-hi.
   static KnownBits mulhs(const KnownBits &LHS, const KnownBits &RHS);
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -412,10 +412,13 @@
   return KnownAbs;
 }
 
-KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS) {
+KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
+                         bool SelfMultiply) {
   unsigned BitWidth = LHS.getBitWidth();
   assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() &&
          !RHS.hasConflict() && "Operand mismatch");
+  assert((!SelfMultiply || (LHS.One == RHS.One && LHS.Zero == RHS.Zero)) &&
+         "Self multiplication knownbits mismatch");
 
   // Compute a conservative estimate for high known-0 bits.
   unsigned LeadZ =
@@ -489,6 +492,16 @@
   Res.Zero.setHighBits(LeadZ);
   Res.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown);
   Res.One = BottomKnown.getLoBits(ResultBitsKnown);
+
+  // If we're self-multiplying then bit[1] is guaranteed to be zero: every
+  // square is 0 or 1 modulo 4 (even X gives 0, odd X gives 1), so bit[1] of
+  // X * X is always clear.
+  if (SelfMultiply && BitWidth > 1) {
+    assert(Res.One[1] == 0 &&
+           "Self-multiplication set bit[1], but squares are 0 or 1 mod 4!");
+    Res.Zero.setBit(1);
+  }
+
   return Res;
 }
 
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -267,6 +267,28 @@
       EXPECT_TRUE(ComputedAShr.One.isSubsetOf(KnownAShr.One));
     });
   });
+
+  // Also test 'unary' binary cases where the same argument is repeated.
+  ForeachKnownBits(Bits, [&](const KnownBits &Known) {
+    KnownBits KnownMul(Bits);
+    KnownMul.Zero.setAllBits();
+    KnownMul.One.setAllBits();
+
+    ForeachNumInKnownBits(Known, [&](const APInt &N) {
+      APInt Res = N * N;
+      KnownMul.One &= Res;
+      KnownMul.Zero &= ~Res;
+    });
+
+    KnownBits ComputedMul = KnownBits::mul(Known, Known, /*SelfMultiply=*/true);
+    EXPECT_TRUE(ComputedMul.Zero.isSubsetOf(KnownMul.Zero));
+    EXPECT_TRUE(ComputedMul.One.isSubsetOf(KnownMul.One));
+
+    // The subset checks above only prove soundness; also check that the
+    // SelfMultiply fact (a square is 0 or 1 mod 4) is actually exploited.
+    if (Bits > 1)
+      EXPECT_TRUE(ComputedMul.Zero[1]);
+  });
 }
 
 TEST(KnownBitsTest, UnaryExhaustive) {