Index: llvm/trunk/include/llvm/Support/BranchProbability.h =================================================================== --- llvm/trunk/include/llvm/Support/BranchProbability.h +++ llvm/trunk/include/llvm/Support/BranchProbability.h @@ -183,17 +183,32 @@ if (Begin == End) return; - auto UnknownProbCount = - std::count(Begin, End, BranchProbability::getUnknown()); - assert((UnknownProbCount == 0 || - UnknownProbCount == std::distance(Begin, End)) && - "Cannot normalize probabilities with known and unknown ones."); - (void)UnknownProbCount; - - uint64_t Sum = std::accumulate( - Begin, End, uint64_t(0), - [](uint64_t S, const BranchProbability &BP) { return S + BP.N; }); + unsigned UnknownProbCount = 0; + uint64_t Sum = std::accumulate(Begin, End, uint64_t(0), + [&](uint64_t S, const BranchProbability &BP) { + if (!BP.isUnknown()) + return S + BP.N; + UnknownProbCount++; + return S; + }); + + if (UnknownProbCount > 0) { + BranchProbability ProbForUnknown = BranchProbability::getZero(); + // If the sum of all known probabilities is less than one, evenly distribute + // the complement of sum to unknown probabilities. Otherwise, set unknown + // probabilities to zeros and continue to normalize known probabilities. + if (Sum < BranchProbability::getDenominator()) + ProbForUnknown = BranchProbability::getRaw( + (BranchProbability::getDenominator() - Sum) / UnknownProbCount); + + std::replace_if(Begin, End, + [](const BranchProbability &BP) { return BP.isUnknown(); }, + ProbForUnknown); + if (Sum <= BranchProbability::getDenominator()) + return; + } + if (Sum == 0) { BranchProbability BP(1, std::distance(Begin, End)); std::fill(Begin, End, BP); Index: llvm/trunk/unittests/Support/BranchProbabilityTest.cpp =================================================================== --- llvm/trunk/unittests/Support/BranchProbabilityTest.cpp +++ llvm/trunk/unittests/Support/BranchProbabilityTest.cpp @@ -288,6 +288,7 @@ } TEST(BranchProbabilityTest, NormalizeProbabilities) { + const auto UnknownProb = BranchProbability::getUnknown(); { SmallVector Probs{{0, 1}, {0, 1}}; BranchProbability::normalizeProbabilities(Probs.begin(), Probs.end()); @@ -322,6 +323,36 @@ EXPECT_EQ(BranchProbability::getDenominator() / 3 + 1, Probs[2].getNumerator()); } + { + SmallVector Probs{{0, 1}, UnknownProb}; + BranchProbability::normalizeProbabilities(Probs.begin(), Probs.end()); + EXPECT_EQ(0, Probs[0].getNumerator()); + EXPECT_EQ(BranchProbability::getDenominator(), Probs[1].getNumerator()); + } + { + SmallVector Probs{{1, 1}, UnknownProb}; + BranchProbability::normalizeProbabilities(Probs.begin(), Probs.end()); + EXPECT_EQ(BranchProbability::getDenominator(), Probs[0].getNumerator()); + EXPECT_EQ(0, Probs[1].getNumerator()); + } + { + SmallVector Probs{{1, 2}, UnknownProb}; + BranchProbability::normalizeProbabilities(Probs.begin(), Probs.end()); + EXPECT_EQ(BranchProbability::getDenominator() / 2, Probs[0].getNumerator()); + EXPECT_EQ(BranchProbability::getDenominator() / 2, Probs[1].getNumerator()); + } + { + SmallVector Probs{ + {1, 2}, {1, 2}, {1, 2}, UnknownProb}; + BranchProbability::normalizeProbabilities(Probs.begin(), Probs.end()); + EXPECT_EQ(BranchProbability::getDenominator() / 3 + 1, + Probs[0].getNumerator()); + EXPECT_EQ(BranchProbability::getDenominator() / 3 + 1, + Probs[1].getNumerator()); + EXPECT_EQ(BranchProbability::getDenominator() / 3 + 1, + Probs[2].getNumerator()); + EXPECT_EQ(0, Probs[3].getNumerator()); + } } }