Index: lib/Analysis/ValueTracking.cpp =================================================================== --- lib/Analysis/ValueTracking.cpp +++ lib/Analysis/ValueTracking.cpp @@ -350,6 +350,16 @@ } } + // The result of the bottom bits of an integer multiply can + // be inferred by looking at the bottom bits of both operands + // and multiplying them together. + assert(!Known.hasConflict() && !Known2.hasConflict()); + APInt Bottom0 = Known.One; + APInt Bottom1 = Known2.One; + // Find the last bit known on both operands. + unsigned TrailBitsKnown = (Known.Zero | Known.One).countTrailingOnes(); + TrailBitsKnown = std::min(TrailBitsKnown, (Known2.Zero | Known2.One).countTrailingOnes()); + // If low bits are zero in either operand, output low known-0 bits. // Also compute a conservative estimate for high known-0 bits. // More trickiness is possible, but this is sufficient for the @@ -366,6 +376,12 @@ Known.Zero.setLowBits(TrailZ); Known.Zero.setHighBits(LeadZ); + // Finally, these are the known bottom bits of the result. + APInt BottomKnown = Bottom0.getLoBits(TrailBitsKnown) * + Bottom1.getLoBits(TrailBitsKnown); + Known.Zero |= (~BottomKnown).getLoBits(TrailBitsKnown); + Known.One |= BottomKnown.getLoBits(TrailBitsKnown); + // Only make use of no-wrap flags if we failed to compute the sign bit // directly. 
This matters if the multiplication always overflows, in // which case we prefer to follow the result of the direct computation,
Index: unittests/Analysis/ValueTrackingTest.cpp
===================================================================
--- unittests/Analysis/ValueTrackingTest.cpp
+++ unittests/Analysis/ValueTrackingTest.cpp
@@ -15,6 +15,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/KnownBits.h"
 #include "llvm/Support/SourceMgr.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
@@ -258,3 +259,64 @@
       cast<ReturnInst>(F->getEntryBlock().getTerminator())->getOperand(0);
   EXPECT_EQ(ComputeNumSignBits(RVal, M->getDataLayout()), 1u);
 }
+
+// Known low bits of both mul operands determine the low bits of the product,
+// and the operands' known leading zeros bound the product's high bits.
+TEST(ValueTracking, ComputeKnownBits) {
+  StringRef Assembly = "define i32 @f(i32 %a, i32 %b) { "
+                       "  %ash = mul i32 %a, 8 "
+                       "  %aad = add i32 %ash, 7 "
+                       "  %aan = and i32 %aad, 4095 "
+                       "  %bsh = shl i32 %b, 4 "
+                       "  %bad = or i32 %bsh, 6 "
+                       "  %ban = and i32 %bad, 4095 "
+                       "  %mul = mul i32 %aan, %ban "
+                       "  ret i32 %mul "
+                       "} ";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  assert(M && "Bad assembly?");
+
+  auto *F = M->getFunction("f");
+  assert(F && "Bad assembly?");
+
+  auto *RVal =
+      cast<ReturnInst>(F->getEntryBlock().getTerminator())->getOperand(0);
+  auto Known = computeKnownBits(RVal, M->getDataLayout());
+  ASSERT_FALSE(Known.hasConflict());
+  // %aan ends in 0b111 and %ban in 0b0110, so the low 3 bits of %mul are
+  // those of 7 * 6 = 42 = 0b101010, i.e. 0b010.
+  EXPECT_EQ(Known.One.getZExtValue(), 2u);
+  // Bits 0 and 2 are known zero from the product above; both operands fit
+  // in 12 bits, so the product fits in 24 and the top 8 bits are zero.
+  EXPECT_EQ(Known.Zero.getZExtValue(), 0xFF000005u);
+}
+
+// Only as many low bits of a mul are known as are known in both operands.
+TEST(ValueTracking, ComputeKnownMulBits) {
+  StringRef Assembly = "define i32 @f(i32 %a, i32 %b) { "
+                       "  %aa = or i32 %a, 15 "
+                       "  %bb = or i32 %b, 3 "
+                       "  %mul = mul i32 %aa, %bb "
+                       "  ret i32 %mul "
+                       "} ";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  assert(M && "Bad assembly?");
+
+  auto *F = M->getFunction("f");
+  assert(F && "Bad assembly?");
+
+  auto *RVal =
+      cast<ReturnInst>(F->getEntryBlock().getTerminator())->getOperand(0);
+  auto Known = computeKnownBits(RVal, M->getDataLayout());
+  ASSERT_FALSE(Known.hasConflict());
+  // %aa ends in 0b1111 but %bb only in 0b11, so just the low 2 bits of %mul
+  // are known: those of 3 * 3 = 9 = 0b1001, i.e. 0b01.
+  EXPECT_EQ(Known.One.getZExtValue(), 1u);
+  EXPECT_EQ(Known.Zero.getZExtValue(), 2u);
+}