diff --git a/llvm/include/llvm/IR/Instruction.h b/llvm/include/llvm/IR/Instruction.h
--- a/llvm/include/llvm/IR/Instruction.h
+++ b/llvm/include/llvm/IR/Instruction.h
@@ -355,6 +355,11 @@
   /// Returns false if no metadata or invalid metadata was found.
   bool extractProfMetadata(uint64_t &TrueVal, uint64_t &FalseVal) const;
 
+  /// Retrieve the raw weight values of a conditional branch, select, or switch.
+  /// Returns true on success with \p Weights filled in; false if the metadata
+  /// is missing or invalid. NOTE(review): prefer SmallVectorImpl<uint32_t> &.
+  bool extractProfMetadata(SmallVector<uint32_t> &Weights) const;
+
   /// Retrieve total raw weight values of a branch.
   /// Returns true on success with profile total weights filled in.
   /// Returns false if no metadata was found.
diff --git a/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
--- a/llvm/lib/Analysis/BranchProbabilityInfo.cpp
+++ b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
@@ -30,6 +30,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
 #include "llvm/InitializePasses.h"
@@ -400,24 +401,18 @@
   SmallVector<uint32_t, 2> Weights;
   SmallVector<unsigned, 2> UnreachableIdxs;
   SmallVector<unsigned, 2> ReachableIdxs;
-  Weights.reserve(TI->getNumSuccessors());
-  for (unsigned I = 1, E = WeightsNode->getNumOperands(); I != E; ++I) {
-    ConstantInt *Weight =
-        mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(I));
-    if (!Weight)
-      return false;
-    assert(Weight->getValue().getActiveBits() <= 32 &&
-           "Too many bits for uint32_t");
-    Weights.push_back(Weight->getZExtValue());
-    WeightSum += Weights.back();
+  if (!extractBranchWeights(*TI, Weights))
+    return false; // No usable branch weights.
+  for (unsigned I = 0, E = Weights.size(); I != E; ++I) {
+    WeightSum += Weights[I];
     const LoopBlock SrcLoopBB = getLoopBlock(BB);
-    const LoopBlock DstLoopBB = getLoopBlock(TI->getSuccessor(I - 1));
+    const LoopBlock DstLoopBB = getLoopBlock(TI->getSuccessor(I));
     auto EstimatedWeight = getEstimatedEdgeWeight({SrcLoopBB, DstLoopBB});
     if (EstimatedWeight &&
         *EstimatedWeight <= static_cast<uint32_t>(BlockExecWeight::UNREACHABLE))
-      UnreachableIdxs.push_back(I - 1);
+      UnreachableIdxs.push_back(I);
     else
-      ReachableIdxs.push_back(I - 1);
+      ReachableIdxs.push_back(I);
   }
   assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
 
diff --git a/llvm/lib/IR/Metadata.cpp b/llvm/lib/IR/Metadata.cpp
--- a/llvm/lib/IR/Metadata.cpp
+++ b/llvm/lib/IR/Metadata.cpp
@@ -40,6 +40,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/TrackingMDRef.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
@@ -1505,26 +1506,24 @@
   assert(
       (getOpcode() == Instruction::Br || getOpcode() == Instruction::Select) &&
       "Looking for branch weights on something besides branch or select");
-
-  auto *ProfileData = getMetadata(LLVMContext::MD_prof);
-  if (!ProfileData || ProfileData->getNumOperands() != 3)
-    return false;
-
-  auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
-  if (!ProfDataName || !ProfDataName->getString().equals("branch_weights"))
-    return false;
-
-  auto *CITrue = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(1));
-  auto *CIFalse = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2));
-  if (!CITrue || !CIFalse)
+  SmallVector<uint32_t> Weights; // Expect exactly a true and a false weight.
+  if (!extractBranchWeights(*this, Weights) || Weights.size() != 2)
     return false;
 
-  TrueVal = CITrue->getValue().getZExtValue();
-  FalseVal = CIFalse->getValue().getZExtValue();
+  TrueVal = Weights[0];
+  FalseVal = Weights[1];
 
   return true;
 }
 
+bool Instruction::extractProfMetadata(SmallVector<uint32_t> &Weights) const {
+  assert(
+      (getOpcode() == Instruction::Br || getOpcode() == Instruction::Select ||
+       getOpcode() == Instruction::Switch) &&
+      "Looking for branch weights on something besides branch/select/switch");
+  return extractBranchWeights(*this, Weights);
+}
+
 bool Instruction::extractProfTotalWeight(uint64_t &TotalVal) const {
   assert(
       (getOpcode() == Instruction::Br || getOpcode() == Instruction::Select ||
@@ -1533,32 +1532,7 @@
        getOpcode() == Instruction::Switch) &&
       "Looking for branch weights on something besides branch");
 
-  TotalVal = 0;
-  auto *ProfileData = getMetadata(LLVMContext::MD_prof);
-  if (!ProfileData)
-    return false;
-
-  auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
-  if (!ProfDataName)
-    return false;
-
-  if (ProfDataName->getString().equals("branch_weights")) {
-    TotalVal = 0;
-    for (unsigned i = 1; i < ProfileData->getNumOperands(); i++) {
-      auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(i));
-      if (!V)
-        return false;
-      TotalVal += V->getValue().getZExtValue();
-    }
-    return true;
-  } else if (ProfDataName->getString().equals("VP") &&
-             ProfileData->getNumOperands() > 3) {
-    TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2))
-                   ->getValue()
-                   .getZExtValue();
-    return true;
-  }
-  return false;
+  return extractTotalBranchWeights(getMetadata(LLVMContext::MD_prof), TotalVal);
 }
 
 void GlobalObject::copyMetadata(const GlobalObject *Other, unsigned Offset) {
diff --git a/llvm/lib/Transforms/Utils/MisExpect.cpp b/llvm/lib/Transforms/Utils/MisExpect.cpp
--- a/llvm/lib/Transforms/Utils/MisExpect.cpp
+++ b/llvm/lib/Transforms/Utils/MisExpect.cpp
@@ -35,6 +35,7 @@
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/BranchProbability.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -118,34 +119,6 @@
 namespace llvm {
 namespace misexpect {
 
-// Helper function to extract branch weights into a vector
-Optional<SmallVector<uint32_t, 4>> extractWeights(Instruction *I,
-                                                  LLVMContext &Ctx) {
-  assert(I && "MisExpect::extractWeights given invalid pointer");
-
-  auto *ProfileData = I->getMetadata(LLVMContext::MD_prof);
-  if (!ProfileData)
-    return None;
-
-  unsigned NOps = ProfileData->getNumOperands();
-  if (NOps < 3)
-    return None;
-
-  auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
-  if (!ProfDataName || !ProfDataName->getString().equals("branch_weights"))
-    return None;
-
-  SmallVector<uint32_t, 4> Weights(NOps - 1);
-  for (unsigned Idx = 1; Idx < NOps; Idx++) {
-    ConstantInt *Value =
-        mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
-    uint32_t V = Value->getZExtValue();
-    Weights[Idx - 1] = V;
-  }
-
-  return Weights;
-}
-
 // TODO: when clang allows c++17, use std::clamp instead
 uint32_t clamp(uint64_t value, uint32_t low, uint32_t hi) {
   if (value > hi)
@@ -218,19 +191,17 @@
 
 void checkBackendInstrumentation(Instruction &I,
                                  const ArrayRef<uint32_t> RealWeights) {
-  auto ExpectedWeightsOpt = extractWeights(&I, I.getContext());
-  if (!ExpectedWeightsOpt)
+  SmallVector<uint32_t> ExpectedWeights;
+  if (!extractBranchWeights(I, ExpectedWeights))
     return;
-  auto ExpectedWeights = ExpectedWeightsOpt.getValue();
   verifyMisExpect(I, RealWeights, ExpectedWeights);
 }
 
 void checkFrontendInstrumentation(Instruction &I,
                                   const ArrayRef<uint32_t> ExpectedWeights) {
-  auto RealWeightsOpt = extractWeights(&I, I.getContext());
-  if (!RealWeightsOpt)
+  SmallVector<uint32_t> RealWeights;
+  if (!extractBranchWeights(I, RealWeights))
     return;
-  auto RealWeights = RealWeightsOpt.getValue();
   verifyMisExpect(I, RealWeights, ExpectedWeights);
 }
 
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -57,6 +57,7 @@
 #include "llvm/IR/NoFolder.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
 #include "llvm/IR/User.h"
@@ -1056,15 +1057,6 @@
   return LHS->getValue().ult(RHS->getValue()) ? 1 : -1;
 }
 
-static inline bool HasBranchWeights(const Instruction *I) {
-  MDNode *ProfMD = I->getMetadata(LLVMContext::MD_prof);
-  if (ProfMD && ProfMD->getOperand(0))
-    if (MDString *MDS = dyn_cast<MDString>(ProfMD->getOperand(0)))
-      return MDS->getString().equals("branch_weights");
-
-  return false;
-}
-
 /// Get Weights of a given terminator, the default weight is at the front
 /// of the vector. If TI is a conditional eq, we need to swap the branch-weight
 /// metadata.
@@ -1183,8 +1175,8 @@
 
   // Update the branch weight metadata along the way
   SmallVector<uint64_t, 8> Weights;
-  bool PredHasWeights = HasBranchWeights(PTI);
-  bool SuccHasWeights = HasBranchWeights(TI);
+  bool PredHasWeights = hasBranchWeightMD(*PTI);
+  bool SuccHasWeights = hasBranchWeightMD(*TI);
 
   if (PredHasWeights) {
     GetBranchWeights(PTI, Weights);
@@ -4408,7 +4400,7 @@
   // Get weight for TrueBB and FalseBB.
   uint32_t TrueWeight = 0, FalseWeight = 0;
   SmallVector<uint64_t, 8> Weights;
-  bool HasWeights = HasBranchWeights(SI);
+  bool HasWeights = hasBranchWeightMD(*SI);
   if (HasWeights) {
     GetBranchWeights(SI, Weights);
     if (Weights.size() == 1 + SI->getNumCases()) {
@@ -5268,7 +5260,7 @@
   BranchInst *NewBI = Builder.CreateCondBr(Cmp, ContiguousDest, OtherDest);
 
   // Update weight for the newly-created conditional branch.
-  if (HasBranchWeights(SI)) {
+  if (hasBranchWeightMD(*SI)) {
     SmallVector<uint64_t, 8> Weights;
     GetBranchWeights(SI, Weights);
     if (Weights.size() == 1 + SI->getNumCases()) {