Index: lib/Analysis/ValueTracking.cpp
===================================================================
--- lib/Analysis/ValueTracking.cpp
+++ lib/Analysis/ValueTracking.cpp
@@ -350,21 +350,47 @@
     }
   }
 
-  // If low bits are zero in either operand, output low known-0 bits.
-  // Also compute a conservative estimate for high known-0 bits.
-  // More trickiness is possible, but this is sufficient for the
-  // interesting case of alignment computation.
-  unsigned TrailZ = Known.countMinTrailingZeros() +
-                    Known2.countMinTrailingZeros();
+  // The result of the bottom bits of an integer multiply can
+  // be inferred by looking at the bottom bits of both operands
+  // and multiplying them together.
+  assert(!Known.hasConflict() && !Known2.hasConflict());
+  APInt Bottom0 = Known.One;
+  APInt Bottom1 = Known2.One;
+
+  // Compute a conservative estimate for high known-0 bits.
   unsigned LeadZ = std::max(Known.countMinLeadingZeros() +
                             Known2.countMinLeadingZeros(),
                             BitWidth) - BitWidth;
-
-  TrailZ = std::min(TrailZ, BitWidth);
   LeadZ = std::min(LeadZ, BitWidth);
+
+  // If there are trailing zeros on either operand, we can infer
+  // extra bits of the multiplication result.
+  // Count the contiguous known low bits of each operand.
+  unsigned TrailBitsKnown0 = (Known.Zero | Known.One).countTrailingOnes();
+  unsigned TrailBitsKnown1 = (Known2.Zero | Known2.One).countTrailingOnes();
+  // How many times we'd be able to divide each argument by 2 (shr by 1).
+  unsigned TrailZero0 = Known.countMinTrailingZeros();
+  unsigned TrailZero1 = Known2.countMinTrailingZeros();
+  // Number of trailing zeros on the multiplication result.
+  unsigned TrailZ = TrailZero0 + TrailZero1;
+  // Write each operand as Lo + (Hi << TrailBitsKnown), where Lo is its
+  // known low part. The other partial products (Lo0*Hi1, Hi0*Lo1, Hi0*Hi1)
+  // are multiples of 2^min(TrailBitsKnown0 + TrailZero1,
+  // TrailBitsKnown1 + TrailZero0), so below that bit the product equals
+  // Lo0*Lo1. That bound is TrailZ plus the smaller count computed here;
+  // neither subtraction wraps because TrailZeroN <= TrailBitsKnownN.
+  unsigned ResultBitsKnown = std::min(TrailBitsKnown0 - TrailZero0,
+                                      TrailBitsKnown1 - TrailZero1);
+  ResultBitsKnown = std::min(ResultBitsKnown + TrailZ, BitWidth);
+
+  // Finally, these are the known bottom bits of the result.
+  APInt BottomKnown = Bottom0.getLoBits(TrailBitsKnown0) *
+                      Bottom1.getLoBits(TrailBitsKnown1);
+
   Known.resetAll();
-  Known.Zero.setLowBits(TrailZ);
   Known.Zero.setHighBits(LeadZ);
+  Known.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown);
+  Known.One |= BottomKnown.getLoBits(ResultBitsKnown);
 
   // Only make use of no-wrap flags if we failed to compute the sign bit
   // directly.  This matters if the multiplication always overflows, in
Index: unittests/Analysis/ValueTrackingTest.cpp
===================================================================
--- unittests/Analysis/ValueTrackingTest.cpp
+++ unittests/Analysis/ValueTrackingTest.cpp
@@ -15,6 +15,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/KnownBits.h"
 #include "llvm/Support/SourceMgr.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
@@ -258,3 +259,57 @@
       cast<ReturnInst>(F->getEntryBlock().getTerminator())->getOperand(0);
   EXPECT_EQ(ComputeNumSignBits(RVal, M->getDataLayout()), 1u);
 }
+
+TEST(ValueTracking, ComputeKnownBits) {
+  StringRef Assembly = "define i32 @f(i32 %a, i32 %b) { "
+                       "  %ash = mul i32 %a, 8 "
+                       "  %aad = add i32 %ash, 7 "
+                       "  %aan = and i32 %aad, 4095 "
+                       "  %bsh = shl i32 %b, 4 "
+                       "  %bad = or i32 %bsh, 6 "
+                       "  %ban = and i32 %bad, 4095 "
+                       "  %mul = mul i32 %aan, %ban "
+                       "  ret i32 %mul "
+                       "} ";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  assert(M && "Bad assembly?");
+
+  auto *F = M->getFunction("f");
+  assert(F && "Bad assembly?");
+
+  auto *RVal =
+      cast<ReturnInst>(F->getEntryBlock().getTerminator())->getOperand(0);
+  auto Known = computeKnownBits(RVal, M->getDataLayout());
+  ASSERT_FALSE(Known.hasConflict());
+  EXPECT_EQ(Known.One.getZExtValue(), 10u);
+  EXPECT_EQ(Known.Zero.getZExtValue(), 0xFF000005u);
+}
+
+TEST(ValueTracking, ComputeKnownMulBits) {
+  StringRef Assembly = "define i32 @f(i32 %a, i32 %b) { "
+                       "  %aa = shl i32 %a, 5 "
+                       "  %bb = shl i32 %b, 5 "
+                       "  %aaa = or i32 %aa, 24 "
+                       "  %bbb = or i32 %bb, 28 "
+                       "  %mul = mul i32 %aaa, %bbb "
+                       "  ret i32 %mul "
+                       "} ";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  assert(M && "Bad assembly?");
+
+  auto *F = M->getFunction("f");
+  assert(F && "Bad assembly?");
+
+  auto *RVal =
+      cast<ReturnInst>(F->getEntryBlock().getTerminator())->getOperand(0);
+  auto Known = computeKnownBits(RVal, M->getDataLayout());
+  ASSERT_FALSE(Known.hasConflict());
+  EXPECT_EQ(Known.One.getZExtValue(), 32u);
+  EXPECT_EQ(Known.Zero.getZExtValue(), 95u);
+}