Add BranchProbability::operator*=(uint32_t) and operator*(uint32_t), both
saturating at one. In BranchProbabilityInfo.cpp, use them to compute the tail
of the redistributed probability directly (ToDistribute minus PerEdge times the
number of reachable edges) instead of by repeated subtraction, and check the
WeightSum <= UINT32_MAX assertion right after WeightSum is accumulated rather
than at the end of the function.

Index: include/llvm/Support/BranchProbability.h
===================================================================
--- include/llvm/Support/BranchProbability.h
+++ include/llvm/Support/BranchProbability.h
@@ -112,6 +112,14 @@
     return *this;
   }
 
+  BranchProbability &operator*=(uint32_t RHS) {
+    assert(N != UnknownN &&
+           "Unknown probability cannot participate in arithmetics.");
+    // Saturate the result at D (probability one) instead of exceeding it.
+    N = (uint64_t(N) * RHS > D) ? D : N * RHS;
+    return *this;
+  }
+
   BranchProbability &operator/=(uint32_t RHS) {
     assert(N != UnknownN &&
            "Unknown probability cannot participate in arithmetics.");
@@ -135,6 +143,11 @@
     return Prob *= RHS;
   }
 
+  BranchProbability operator*(uint32_t RHS) const {
+    BranchProbability Prob(*this);
+    return Prob *= RHS;
+  }
+
   BranchProbability operator/(uint32_t RHS) const {
     BranchProbability Prob(*this);
     return Prob /= RHS;
Index: lib/Analysis/BranchProbabilityInfo.cpp
===================================================================
--- lib/Analysis/BranchProbabilityInfo.cpp
+++ lib/Analysis/BranchProbabilityInfo.cpp
@@ -301,6 +301,8 @@
       WeightSum += Weights[i];
     }
   }
+  assert(WeightSum <= UINT32_MAX &&
+         "Expected weights to scale down to 32 bits");
 
   if (WeightSum == 0 || ReachableIdxs.size() == 0) {
     for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
@@ -328,21 +330,18 @@
     // the difference between reachable blocks.
     if (ToDistribute > BranchProbability::getZero()) {
       BranchProbability PerEdge = ToDistribute / ReachableIdxs.size();
-      for (auto i : ReachableIdxs) {
+      for (auto i : ReachableIdxs)
         BP[i] += PerEdge;
-        ToDistribute -= PerEdge;
-      }
 
       // Tail goes to the first reachable edge.
-      BP[ReachableIdxs[0]] += ToDistribute;
+      BranchProbability Tail = ToDistribute - (PerEdge * ReachableIdxs.size());
+      assert(Tail.getNumerator() < ReachableIdxs.size() && "Tail is too big!");
+      BP[ReachableIdxs[0]] += Tail;
     }
   }
 
   for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
     setEdgeProbability(BB, i, BP[i]);
 
-  assert(WeightSum <= UINT32_MAX &&
-         "Expected weights to scale down to 32 bits");
-
   return true;
 }
Index: unittests/Support/BranchProbabilityTest.cpp
===================================================================
--- unittests/Support/BranchProbabilityTest.cpp
+++ unittests/Support/BranchProbabilityTest.cpp
@@ -115,6 +115,54 @@
   EXPECT_FALSE(BigZero >= BigOne);
 }
 
+TEST(BranchProbabilityTest, ArithmeticOperators) {
+  BP Z(0, 1);
+  BP O(1, 1);
+  BP H(1, 2);
+  BP Q(1, 4);
+  BP Q3(3, 4);
+
+  EXPECT_EQ(Z + O, O);
+  EXPECT_EQ(H + Z, H);
+  EXPECT_EQ(H + H, O);
+  EXPECT_EQ(Q + H, Q3);
+  EXPECT_EQ(Q + Q3, O);
+  EXPECT_EQ(H + Q3, O);
+  EXPECT_EQ(Q3 + Q3, O);
+
+  EXPECT_EQ(Z - O, Z);
+  EXPECT_EQ(O - Z, O);
+  EXPECT_EQ(O - H, H);
+  EXPECT_EQ(O - Q, Q3);
+  EXPECT_EQ(Q3 - H, Q);
+  EXPECT_EQ(Q - H, Z);
+  EXPECT_EQ(Q - Q3, Z);
+
+  EXPECT_EQ(Z * O, Z);
+  EXPECT_EQ(H * H, Q);
+  EXPECT_EQ(Q * O, Q);
+  EXPECT_EQ(O * O, O);
+  EXPECT_EQ(Z * Z, Z);
+
+  EXPECT_EQ(Z * 3, Z);
+  EXPECT_EQ(Q * 3, Q3);
+  EXPECT_EQ(H * 3, O);
+  EXPECT_EQ(Q3 * 2, O);
+  EXPECT_EQ(O * UINT32_MAX, O);
+
+  EXPECT_EQ(Z / 4, Z);
+  EXPECT_EQ(O / 4, Q);
+  EXPECT_EQ(Q3 / 3, Q);
+  EXPECT_EQ(H / 2, Q);
+  EXPECT_EQ(O / 2, H);
+  EXPECT_EQ(O / UINT32_MAX, Z);
+  EXPECT_EQ(H / UINT32_MAX, Z);
+
+  // Scaling 1/2^31 by UINT32_MAX must saturate at one, not overflow.
+  BP Min(1, 1u << 31);
+  EXPECT_EQ(Min * UINT32_MAX, O);
+}
+
 TEST(BranchProbabilityTest, getCompl) {
   EXPECT_EQ(BP(5, 7), BP(2, 7).getCompl());
   EXPECT_EQ(BP(2, 7), BP(5, 7).getCompl());