// Patch: add KnownBits::mulhs and KnownBits::mulhu, the known-bits transfer
// functions for the high half of a signed / unsigned multiply, plus unit tests.
//
// llvm/include/llvm/Support/KnownBits.h
//   Declares the two new static members directly after computeForMul.
//
// llvm/lib/Support/KnownBits.cpp
//   Both functions first assert that LHS and RHS have the same bit width and
//   carry no conflicting bits. Each operand is then extended to 2 * BitWidth
//   bits (sext for mulhs, zext for mulhu) so that the full double-width
//   product is representable. Known bits of that product come from
//   computeForMul, and the upper BitWidth bits are returned via
//   extractBits(BitWidth, BitWidth). Precision is therefore whatever
//   computeForMul achieves on the widened operands.
//
// llvm/unittests/Support/KnownBitsTest.cpp
//   For each concrete pair (N1, N2) the test visits, a reference high half is
//   computed with APInt arithmetic at 2 * Bits and intersected into
//   KnownMulHS / KnownMulHU. (NOTE(review): the enumeration loop is outside
//   the hunks shown; presumably it covers every value consistent with
//   Known1/Known2 - confirm.) The computed KnownBits must only be a subset of
//   the reference bits, i.e. conservatively correct but not necessarily
//   precise. This matches how computeForMul is already checked.
//
diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -296,6 +296,12 @@
   /// Compute known bits resulting from multiplying LHS and RHS.
   static KnownBits computeForMul(const KnownBits &LHS, const KnownBits &RHS);
 
+  /// Compute known bits from sign-extended multiply-hi.
+  static KnownBits mulhs(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Compute known bits from zero-extended multiply-hi.
+  static KnownBits mulhu(const KnownBits &LHS, const KnownBits &RHS);
+
   /// Compute known bits for udiv(LHS, RHS).
   static KnownBits udiv(const KnownBits &LHS, const KnownBits &RHS);
 
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -489,6 +489,24 @@
   return Res;
 }
 
+KnownBits KnownBits::mulhs(const KnownBits &LHS, const KnownBits &RHS) {
+  unsigned BitWidth = LHS.getBitWidth();
+  assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() &&
+         !RHS.hasConflict() && "Operand mismatch");
+  KnownBits WideLHS = LHS.sext(2 * BitWidth);
+  KnownBits WideRHS = RHS.sext(2 * BitWidth);
+  return computeForMul(WideLHS, WideRHS).extractBits(BitWidth, BitWidth);
+}
+
+KnownBits KnownBits::mulhu(const KnownBits &LHS, const KnownBits &RHS) {
+  unsigned BitWidth = LHS.getBitWidth();
+  assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() &&
+         !RHS.hasConflict() && "Operand mismatch");
+  KnownBits WideLHS = LHS.zext(2 * BitWidth);
+  KnownBits WideRHS = RHS.zext(2 * BitWidth);
+  return computeForMul(WideLHS, WideRHS).extractBits(BitWidth, BitWidth);
+}
+
 KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS) {
   unsigned BitWidth = LHS.getBitWidth();
   assert(!LHS.hasConflict() && !RHS.hasConflict());
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -113,6 +113,8 @@
       KnownBits KnownSMax(KnownAnd);
       KnownBits KnownSMin(KnownAnd);
       KnownBits KnownMul(KnownAnd);
+      KnownBits KnownMulHS(KnownAnd);
+      KnownBits KnownMulHU(KnownAnd);
       KnownBits KnownUDiv(KnownAnd);
       KnownBits KnownURem(KnownAnd);
       KnownBits KnownSRem(KnownAnd);
@@ -156,6 +158,14 @@
           KnownMul.One &= Res;
           KnownMul.Zero &= ~Res;
 
+          Res = (N1.sext(2 * Bits) * N2.sext(2 * Bits)).extractBits(Bits, Bits);
+          KnownMulHS.One &= Res;
+          KnownMulHS.Zero &= ~Res;
+
+          Res = (N1.zext(2 * Bits) * N2.zext(2 * Bits)).extractBits(Bits, Bits);
+          KnownMulHU.One &= Res;
+          KnownMulHU.Zero &= ~Res;
+
           if (!N2.isNullValue()) {
             Res = N1.udiv(N2);
             KnownUDiv.One &= Res;
@@ -218,12 +228,20 @@
       EXPECT_EQ(KnownSMin.Zero, ComputedSMin.Zero);
       EXPECT_EQ(KnownSMin.One, ComputedSMin.One);
 
-      // ComputedMul is conservatively correct, but not guaranteed to be
+      // The following are conservatively correct, but not guaranteed to be
       // precise.
       KnownBits ComputedMul = KnownBits::computeForMul(Known1, Known2);
       EXPECT_TRUE(ComputedMul.Zero.isSubsetOf(KnownMul.Zero));
       EXPECT_TRUE(ComputedMul.One.isSubsetOf(KnownMul.One));
 
+      KnownBits ComputedMulHS = KnownBits::mulhs(Known1, Known2);
+      EXPECT_TRUE(ComputedMulHS.Zero.isSubsetOf(KnownMulHS.Zero));
+      EXPECT_TRUE(ComputedMulHS.One.isSubsetOf(KnownMulHS.One));
+
+      KnownBits ComputedMulHU = KnownBits::mulhu(Known1, Known2);
+      EXPECT_TRUE(ComputedMulHU.Zero.isSubsetOf(KnownMulHU.Zero));
+      EXPECT_TRUE(ComputedMulHU.One.isSubsetOf(KnownMulHU.One));
+
       KnownBits ComputedUDiv = KnownBits::udiv(Known1, Known2);
       EXPECT_TRUE(ComputedUDiv.Zero.isSubsetOf(KnownUDiv.Zero));
       EXPECT_TRUE(ComputedUDiv.One.isSubsetOf(KnownUDiv.One));