diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h --- a/llvm/include/llvm/Support/KnownBits.h +++ b/llvm/include/llvm/Support/KnownBits.h @@ -269,6 +269,8 @@ One.extractBits(NumBits, BitPosition)); } + static KnownBits computeForMul(const KnownBits &LHS, const KnownBits &RHS); + /// Update known bits based on ANDing with RHS. KnownBits &operator&=(const KnownBits &RHS); diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -415,7 +415,6 @@ const APInt &DemandedElts, KnownBits &Known, KnownBits &Known2, unsigned Depth, const Query &Q) { - unsigned BitWidth = Known.getBitWidth(); computeKnownBits(Op1, DemandedElts, Known, Depth + 1, Q); computeKnownBits(Op0, DemandedElts, Known2, Depth + 1, Q); @@ -433,7 +432,7 @@ bool isKnownNegativeOp0 = Known2.isNegative(); // The product of two numbers with the same sign is non-negative. isKnownNonNegative = (isKnownNegativeOp1 && isKnownNegativeOp0) || - (isKnownNonNegativeOp1 && isKnownNonNegativeOp0); + (isKnownNonNegativeOp1 && isKnownNonNegativeOp0); // The product of a negative number and a non-negative number is either // negative or zero. if (!isKnownNonNegative) @@ -444,78 +443,7 @@ } } - assert(!Known.hasConflict() && !Known2.hasConflict()); - // Compute a conservative estimate for high known-0 bits. - unsigned LeadZ = std::max(Known.countMinLeadingZeros() + - Known2.countMinLeadingZeros(), - BitWidth) - BitWidth; - LeadZ = std::min(LeadZ, BitWidth); - - // The result of the bottom bits of an integer multiply can be - // inferred by looking at the bottom bits of both operands and - // multiplying them together. - // We can infer at least the minimum number of known trailing bits - // of both operands. 
Depending on number of trailing zeros, we can - // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming - // a and b are divisible by m and n respectively. - // We then calculate how many of those bits are inferrable and set - // the output. For example, the i8 mul: - // a = XXXX1100 (12) - // b = XXXX1110 (14) - // We know the bottom 3 bits are zero since the first can be divided by - // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4). - // Applying the multiplication to the trimmed arguments gets: - // XX11 (3) - // X111 (7) - // ------- - // XX11 - // XX11 - // XX11 - // XX11 - // ------- - // XXXXX01 - // Which allows us to infer the 2 LSBs. Since we're multiplying the result - // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits. - // The proof for this can be described as: - // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) && - // (C7 == (1 << (umin(countTrailingZeros(C1), C5) + - // umin(countTrailingZeros(C2), C6) + - // umin(C5 - umin(countTrailingZeros(C1), C5), - // C6 - umin(countTrailingZeros(C2), C6)))) - 1) - // %aa = shl i8 %a, C5 - // %bb = shl i8 %b, C6 - // %aaa = or i8 %aa, C1 - // %bbb = or i8 %bb, C2 - // %mul = mul i8 %aaa, %bbb - // %mask = and i8 %mul, C7 - // => - // %mask = i8 ((C1*C2)&C7) - // Where C5, C6 describe the known bits of %a, %b - // C1, C2 describe the known bottom bits of %a, %b. - // C7 describes the mask of the known bits of the result. - APInt Bottom0 = Known.One; - APInt Bottom1 = Known2.One; - - // How many times we'd be able to divide each argument by 2 (shr by 1). - // This gives us the number of trailing zeros on the multiplication result. 
- unsigned TrailBitsKnown0 = (Known.Zero | Known.One).countTrailingOnes(); - unsigned TrailBitsKnown1 = (Known2.Zero | Known2.One).countTrailingOnes(); - unsigned TrailZero0 = Known.countMinTrailingZeros(); - unsigned TrailZero1 = Known2.countMinTrailingZeros(); - unsigned TrailZ = TrailZero0 + TrailZero1; - - // Figure out the fewest known-bits operand. - unsigned SmallestOperand = std::min(TrailBitsKnown0 - TrailZero0, - TrailBitsKnown1 - TrailZero1); - unsigned ResultBitsKnown = std::min(SmallestOperand + TrailZ, BitWidth); - - APInt BottomKnown = Bottom0.getLoBits(TrailBitsKnown0) * - Bottom1.getLoBits(TrailBitsKnown1); - - Known.resetAll(); - Known.Zero.setHighBits(LeadZ); - Known.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown); - Known.One |= BottomKnown.getLoBits(ResultBitsKnown); + Known = KnownBits::computeForMul(Known, Known2); // Only make use of no-wrap flags if we failed to compute the sign bit // directly. This matters if the multiplication always overflows, in diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp --- a/llvm/lib/Support/KnownBits.cpp +++ b/llvm/lib/Support/KnownBits.cpp @@ -163,6 +163,92 @@ return KnownAbs; } +/// Compute the known bits of the product LHS * RHS. +/// High bits are marked known-zero from the operands' minimum leading-zero +/// counts; low bits come from multiplying the operands' known trailing bits. +/// Neither operand may have conflicting known bits (asserted below). The +/// result has the bit width of LHS. +/// NOTE(review): RHS presumably must have the same bit width as LHS, but that +/// is not checked here -- confirm against callers. +KnownBits KnownBits::computeForMul(const KnownBits &LHS, const KnownBits &RHS) { + unsigned BitWidth = LHS.getBitWidth(); + + assert(!LHS.hasConflict() && !RHS.hasConflict()); + // Compute a conservative estimate for high known-0 bits. + unsigned LeadZ = + std::max(LHS.countMinLeadingZeros() + RHS.countMinLeadingZeros(), + BitWidth) - + BitWidth; + LeadZ = std::min(LeadZ, BitWidth); + + // The result of the bottom bits of an integer multiply can be + // inferred by looking at the bottom bits of both operands and + // multiplying them together. + // We can infer at least the minimum number of known trailing bits + // of both operands. Depending on number of trailing zeros, we can + // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming + // a and b are divisible by m and n respectively.
+ // We then calculate how many of those bits are inferrable and set + // the output. For example, the i8 mul: + // a = XXXX1100 (12) + // b = XXXX1110 (14) + // We know the bottom 3 bits are zero since the first can be divided by + // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4). + // Applying the multiplication to the trimmed arguments gets: + // XX11 (3) + // X111 (7) + // ------- + // XX11 + // XX11 + // XX11 + // XX11 + // ------- + // XXXXX01 + // Which allows us to infer the 2 LSBs. Since we're multiplying the result + // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits. + // The proof for this can be described as: + // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) && + // (C7 == (1 << (umin(countTrailingZeros(C1), C5) + + // umin(countTrailingZeros(C2), C6) + + // umin(C5 - umin(countTrailingZeros(C1), C5), + // C6 - umin(countTrailingZeros(C2), C6)))) - 1) + // %aa = shl i8 %a, C5 + // %bb = shl i8 %b, C6 + // %aaa = or i8 %aa, C1 + // %bbb = or i8 %bb, C2 + // %mul = mul i8 %aaa, %bbb + // %mask = and i8 %mul, C7 + // => + // %mask = i8 ((C1*C2)&C7) + // Where C5, C6 describe the known bits of %a, %b + // C1, C2 describe the known bottom bits of %a, %b. + // C7 describes the mask of the known bits of the result. + APInt Bottom0 = LHS.One; + APInt Bottom1 = RHS.One; + + // How many times we'd be able to divide each argument by 2 (shr by 1). + // This gives us the number of trailing zeros on the multiplication result. + unsigned TrailBitsKnown0 = (LHS.Zero | LHS.One).countTrailingOnes(); + unsigned TrailBitsKnown1 = (RHS.Zero | RHS.One).countTrailingOnes(); + unsigned TrailZero0 = LHS.countMinTrailingZeros(); + unsigned TrailZero1 = RHS.countMinTrailingZeros(); + unsigned TrailZ = TrailZero0 + TrailZero1; + + // Figure out the fewest known-bits operand.
+ unsigned SmallestOperand = + std::min(TrailBitsKnown0 - TrailZero0, TrailBitsKnown1 - TrailZero1); + unsigned ResultBitsKnown = std::min(SmallestOperand + TrailZ, BitWidth); + + APInt BottomKnown = + Bottom0.getLoBits(TrailBitsKnown0) * Bottom1.getLoBits(TrailBitsKnown1); + + KnownBits Res(BitWidth); + Res.Zero.setHighBits(LeadZ); + Res.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown); + Res.One = BottomKnown.getLoBits(ResultBitsKnown); + return Res; +} + KnownBits &KnownBits::operator&=(const KnownBits &RHS) { // Result bit is 0 if either operand bit is 0. Zero |= RHS.Zero; diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp --- a/llvm/unittests/Support/KnownBitsTest.cpp +++ b/llvm/unittests/Support/KnownBitsTest.cpp @@ -112,6 +112,7 @@ KnownBits KnownUMin(KnownAnd); KnownBits KnownSMax(KnownAnd); KnownBits KnownSMin(KnownAnd); + KnownBits KnownMul(KnownAnd); ForeachNumInKnownBits(Known1, [&](const APInt &N1) { ForeachNumInKnownBits(Known2, [&](const APInt &N2) { @@ -144,6 +145,10 @@ Res = APIntOps::smin(N1, N2); KnownSMin.One &= Res; KnownSMin.Zero &= ~Res; + + Res = N1 * N2; + KnownMul.One &= Res; + KnownMul.Zero &= ~Res; }); }); @@ -174,6 +179,11 @@ KnownBits ComputedSMin = KnownBits::smin(Known1, Known2); EXPECT_EQ(KnownSMin.Zero, ComputedSMin.Zero); EXPECT_EQ(KnownSMin.One, ComputedSMin.One); + + // ComputedMul should have at most all the information from KnownMul. + KnownBits ComputedMul = KnownBits::computeForMul(Known1, Known2); + EXPECT_TRUE(ComputedMul.Zero.isSubsetOf(KnownMul.Zero)); + EXPECT_TRUE(ComputedMul.One.isSubsetOf(KnownMul.One)); }); }); }