Index: lib/Analysis/ValueTracking.cpp
===================================================================
--- lib/Analysis/ValueTracking.cpp
+++ lib/Analysis/ValueTracking.cpp
@@ -342,21 +342,79 @@
     }
   }
 
-  // If low bits are zero in either operand, output low known-0 bits.
-  // Also compute a conservative estimate for high known-0 bits.
-  // More trickiness is possible, but this is sufficient for the
-  // interesting case of alignment computation.
-  unsigned TrailZ = Known.countMinTrailingZeros() +
-                    Known2.countMinTrailingZeros();
+  assert(!Known.hasConflict() && !Known2.hasConflict());
+  // Compute a conservative estimate for high known-0 bits.
   unsigned LeadZ = std::max(Known.countMinLeadingZeros() +
                             Known2.countMinLeadingZeros(),
                             BitWidth) - BitWidth;
-
-  TrailZ = std::min(TrailZ, BitWidth);
   LeadZ = std::min(LeadZ, BitWidth);
+
+  // The result of the bottom bits of an integer multiply can be
+  // inferred by looking at the bottom bits of both operands and
+  // multiplying them together.
+  // We can infer at least the minimum number of known trailing bits
+  // of both operands. Depending on number of trailing zeros, we can
+  // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming
+  // a and b are divisible by m and n respectively.
+  // We then calculate how many of those bits are inferrable and set
+  // the output. For example, the i8 mul:
+  //  a = XXXX1100 (12)
+  //  b = XXXX1110 (14)
+  // We know the bottom 3 bits are zero since the first can be divided by
+  // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4).
+  // Applying the multiplication to the trimmed arguments gets:
+  //    XX11 (3)
+  //    X111 (7)
+  // -------
+  //    XX11
+  //   XX11
+  //  XX11
+  // XX11
+  // -------
+  // XXXXX01
+  // Which allows us to infer the 2 LSBs. Since we're multiplying the result
+  // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits.
+  // The proof for this can be described as:
+  // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) &&
+  //      (C7 == (1 << (umin(countTrailingZeros(C1), C5) +
+  //                    umin(countTrailingZeros(C2), C6) +
+  //                    umin(C5 - umin(countTrailingZeros(C1), C5),
+  //                         C6 - umin(countTrailingZeros(C2), C6)))) - 1)
+  // %aa = shl i8 %a, C5
+  // %bb = shl i8 %b, C6
+  // %aaa = or i8 %aa, C1
+  // %bbb = or i8 %bb, C2
+  // %mul = mul i8 %aaa, %bbb
+  // %mask = and i8 %mul, C7
+  //   =>
+  // %mask = i8 ((C1*C2)&C7)
+  // Where C5, C6 describe the known bits of %a, %b
+  // C1, C2 describe the known bottom bits of %a, %b.
+  // C7 describes the mask of the known bits of the result.
+  // References suffice: both values are consumed before Known.resetAll().
+  const APInt &Bottom0 = Known.One;
+  const APInt &Bottom1 = Known2.One;
+
+  // How many times we'd be able to divide each argument by 2 (shr by 1).
+  // This gives us the number of trailing zeros on the multiplication result.
+  unsigned TrailBitsKnown0 = (Known.Zero | Known.One).countTrailingOnes();
+  unsigned TrailBitsKnown1 = (Known2.Zero | Known2.One).countTrailingOnes();
+  unsigned TrailZero0 = Known.countMinTrailingZeros();
+  unsigned TrailZero1 = Known2.countMinTrailingZeros();
+  unsigned TrailZ = TrailZero0 + TrailZero1;
+
+  // Known bits above the trailing zeros, taken from the weaker operand.
+  unsigned SmallestOperand = std::min(TrailBitsKnown0 - TrailZero0,
+                                      TrailBitsKnown1 - TrailZero1);
+  unsigned ResultBitsKnown = std::min(SmallestOperand + TrailZ, BitWidth);
+
+  APInt BottomKnown = Bottom0.getLoBits(TrailBitsKnown0) *
+                      Bottom1.getLoBits(TrailBitsKnown1);
+
   Known.resetAll();
-  Known.Zero.setLowBits(TrailZ);
   Known.Zero.setHighBits(LeadZ);
+  Known.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown);
+  Known.One |= BottomKnown.getLoBits(ResultBitsKnown);
 
   // Only make use of no-wrap flags if we failed to compute the sign bit
   // directly. This matters if the multiplication always overflows, in
Index: unittests/Analysis/ValueTrackingTest.cpp
===================================================================
--- unittests/Analysis/ValueTrackingTest.cpp
+++ unittests/Analysis/ValueTrackingTest.cpp
@@ -15,6 +15,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/KnownBits.h"
 #include "llvm/Support/SourceMgr.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
@@ -258,3 +259,59 @@
       cast<ReturnInst>(F->getEntryBlock().getTerminator())->getOperand(0);
   EXPECT_EQ(ComputeNumSignBits(RVal, M->getDataLayout()), 1u);
 }
+
+// Low bits of a mul are inferred from the known low bits of both operands.
+TEST(ValueTracking, ComputeKnownBits) {
+  StringRef Assembly = "define i32 @f(i32 %a, i32 %b) { "
+                       " %ash = mul i32 %a, 8 "
+                       " %aad = add i32 %ash, 7 "
+                       " %aan = and i32 %aad, 4095 "
+                       " %bsh = shl i32 %b, 4 "
+                       " %bad = or i32 %bsh, 6 "
+                       " %ban = and i32 %bad, 4095 "
+                       " %mul = mul i32 %aan, %ban "
+                       " ret i32 %mul "
+                       "} ";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  assert(M && "Bad assembly?");
+
+  auto *F = M->getFunction("f");
+  assert(F && "Bad assembly?");
+
+  auto *RVal =
+      cast<ReturnInst>(F->getEntryBlock().getTerminator())->getOperand(0);
+  auto Known = computeKnownBits(RVal, M->getDataLayout());
+  ASSERT_FALSE(Known.hasConflict());
+  EXPECT_EQ(Known.One.getZExtValue(), 10u);
+  EXPECT_EQ(Known.Zero.getZExtValue(), 4278190085u);
+}
+
+// Trailing zeros of both operands add up and extend the known low bits.
+TEST(ValueTracking, ComputeKnownMulBits) {
+  StringRef Assembly = "define i32 @f(i32 %a, i32 %b) { "
+                       " %aa = shl i32 %a, 5 "
+                       " %bb = shl i32 %b, 5 "
+                       " %aaa = or i32 %aa, 24 "
+                       " %bbb = or i32 %bb, 28 "
+                       " %mul = mul i32 %aaa, %bbb "
+                       " ret i32 %mul "
+                       "} ";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  assert(M && "Bad assembly?");
+
+  auto *F = M->getFunction("f");
+  assert(F && "Bad assembly?");
+
+  auto *RVal =
+      cast<ReturnInst>(F->getEntryBlock().getTerminator())->getOperand(0);
+  auto Known = computeKnownBits(RVal, M->getDataLayout());
+  ASSERT_FALSE(Known.hasConflict());
+  EXPECT_EQ(Known.One.getZExtValue(), 32u);
+  EXPECT_EQ(Known.Zero.getZExtValue(), 95u);
+}