Index: lib/Transforms/Scalar/CallSiteSplitting.cpp
===================================================================
--- lib/Transforms/Scalar/CallSiteSplitting.cpp
+++ lib/Transforms/Scalar/CallSiteSplitting.cpp
@@ -209,8 +209,89 @@
   return CallSiteBB->canSplitPredecessors();
 }
 
-/// Return true if the CS is split into its new predecessors.
+/// Clone \p I (a `bitcast` or `ret`) and insert the copy before \p Before.
+/// If \p V is non-null, it replaces operand 0 of the copy. Returns the copy.
+static Instruction *cloneReturnInstForMustTail(Instruction *I,
+                                               Instruction *Before, Value *V) {
+  Instruction *Copy = I->clone();
+  Copy->setName(I->getName());
+  Copy->insertBefore(Before);
+  if (V)
+    Copy->setOperand(0, V);
+  return Copy;
+}
+
+/// `musttail` calls must be followed by optional `bitcast`, and `ret`
+/// instructions. `splitCallSite` however produces blocks with copies of the
+/// call followed by `br` instructions, which violates the invariant.
 ///
+/// Restore the invariant by copying the tail block's `bitcast` and `ret` into
+/// each predecessor block; the then-unreachable \p TailBB is erased.
+static void fixMustTailCallSite(BasicBlock *TailBB) {
+  DEBUG(dbgs() << "fix musttail splits:\n");
+
+  bool IsVoid = TailBB->getParent()->getReturnType()->isVoidTy();
+
+  // Skip all PHIs, not just the first one: besides the PHI merging the call
+  // results (non-void case), PHIs that fed the call's arguments may remain.
+  BasicBlock::iterator BI = TailBB->getFirstNonPHIOrDbg()->getIterator();
+
+  // Optional BitCastInst
+  BitCastInst *BCI = dyn_cast<BitCastInst>(&*BI);
+  if (BCI)
+    ++BI;
+
+  // Mandatory ReturnInst
+  ReturnInst *RI = dyn_cast<ReturnInst>(&*BI);
+  assert(RI && "`musttail` call must be followed by `ret` instruction");
+
+  assert(std::next(BI) == TailBB->end() && "`ret` doesn't end the block");
+
+  // Copy the predecessor list; erasing each split's `br` below changes it.
+  SmallVector<BasicBlock *, 2> Splits(predecessors(TailBB));
+  assert(Splits.size() == 2 && "Expected exactly 2 splits!");
+
+  for (BasicBlock *Split : Splits) {
+    TerminatorInst *TI = Split->getTerminator();
+
+    CallInst *CI = dyn_cast<CallInst>(&*std::prev(TI->getIterator()));
+    assert(CI && "pred blocks must have the call");
+    assert(IsVoid == CI->getFunctionType()->getReturnType()->isVoidTy() &&
+        "Return type mismatch between musttail call and the caller function");
+    DEBUG(dbgs() << "    feeding CI : " << *CI << " into\n");
+
+    Value *V = CI;
+
+    // Add bitcast
+    if (BCI) {
+      assert(!IsVoid && "Can\'t bitcast void");
+      V = cloneReturnInstForMustTail(BCI, TI, V);
+      DEBUG(dbgs() << "        " << *V << " into\n");
+    }
+
+    // Add return
+    Instruction *NewRI = cloneReturnInstForMustTail(RI, TI,
+                                                    IsVoid ? nullptr : V);
+    DEBUG(dbgs() << "        " << *NewRI << "\n");
+
+    // Remove terminating `br` instruction
+    TI->eraseFromParent();
+
+    // Drop Split's entry from every PHI in TailBB; removeIncomingValue
+    // deletes a PHI once its last entry is gone.
+    for (BasicBlock::iterator PI = TailBB->begin(); isa<PHINode>(*PI);)
+      cast<PHINode>(*PI++).removeIncomingValue(Split, true);
+  }
+
+  // Remove unused instructions (the `ret` first: it may use the `bitcast`)
+  RI->eraseFromParent();
+  if (BCI)
+    BCI->eraseFromParent();
+
+  // The block is no longer reachable
+  TailBB->eraseFromParent();
+}
+
 /// For each (predecessor, conditions from predecessors) pair, it will split the
 /// basic block containing the call site, hook it up to the predecessor and
 /// replace the call instruction with new call instructions, which contain
@@ -257,6 +338,7 @@
     const SmallVectorImpl<std::pair<BasicBlock *, ConditionsTy>> &Preds) {
   Instruction *Instr = CS.getInstruction();
   BasicBlock *TailBB = Instr->getParent();
+  bool IsMustTailCall = CS.isMustTailCall();
 
   PHINode *CallPN = nullptr;
   if (Instr->getNumUses())
@@ -330,6 +412,9 @@
       break;
   }
 
+  if (IsMustTailCall)
+    fixMustTailCallSite(TailBB);
+
   NumCallSiteSplit++;
 }
 
@@ -415,7 +500,17 @@
       Function *Callee = CS.getCalledFunction();
       if (!Callee || Callee->isDeclaration())
         continue;
+
+      // A successful musttail split erases both the call and its block, so
+      // query isMustTailCall() now, before CS can be left dangling.
+      bool IsMustTail = CS.isMustTailCall();
+
       Changed |= tryToSplitCallSite(CS, TTI);
+
+      // Nothing of interest follows a musttail call (only `bitcast`/`ret`),
+      // and the call site itself may have been erased by the split.
+      if (IsMustTail)
+        break;
     }
   }
   return Changed;
Index: test/Transforms/CallSiteSplitting/musttail.ll
===================================================================
--- /dev/null
+++ test/Transforms/CallSiteSplitting/musttail.ll
@@ -0,0 +1,75 @@
+; RUN: opt < %s -callsite-splitting -S | FileCheck %s
+; The musttail call, its bitcast and the ret must all be copied into each split block.
+;CHECK-LABEL: @caller
+;CHECK-LABEL: Top.split:
+;CHECK: %ca1 = musttail call i8* @callee(i8* null, i8* %b)
+;CHECK: %cb4 = bitcast i8* %ca1 to i8*
+;CHECK: ret i8* %cb4
+;CHECK-LABEL: TBB.split
+;CHECK: %ca2 = musttail call i8* @callee(i8* nonnull %a, i8* null)
+;CHECK: %cb3 = bitcast i8* %ca2 to i8*
+;CHECK: ret i8* %cb3
+define i8* @caller(i8* %a, i8* %b) {
+Top:
+  %c = icmp eq i8* %a, null
+  br i1 %c, label %Tail, label %TBB
+TBB:
+  %c2 = icmp eq i8* %b, null
+  br i1 %c2, label %Tail, label %End
+Tail:
+  %ca = musttail call i8* @callee(i8* %a, i8* %b)
+  %cb = bitcast i8* %ca to i8*
+  ret i8* %cb
+End:
+  ret i8* null
+}
+; Non-void callee shared by @caller and @no_cast_caller.
+define i8* @callee(i8* %a, i8* %b) noinline {
+  ret i8* %a
+}
+; As @caller, but without the bitcast: each split block must end in the musttail call and ret.
+;CHECK-LABEL: @no_cast_caller
+;CHECK-LABEL: Top.split:
+;CHECK: %ca1 = musttail call i8* @callee(i8* null, i8* %b)
+;CHECK: ret i8* %ca1
+;CHECK-LABEL: TBB.split
+;CHECK: %ca2 = musttail call i8* @callee(i8* nonnull %a, i8* null)
+;CHECK: ret i8* %ca2
+define i8* @no_cast_caller(i8* %a, i8* %b) {
+Top:
+  %c = icmp eq i8* %a, null
+  br i1 %c, label %Tail, label %TBB
+TBB:
+  %c2 = icmp eq i8* %b, null
+  br i1 %c2, label %Tail, label %End
+Tail:
+  %ca = musttail call i8* @callee(i8* %a, i8* %b)
+  ret i8* %ca
+End:
+  ret i8* null
+}
+; Void musttail call: each split block must end in the call followed by `ret void`.
+;CHECK-LABEL: @void_caller
+;CHECK-LABEL: Top.split:
+;CHECK: musttail call void @void_callee(i8* null, i8* %b)
+;CHECK: ret void
+;CHECK-LABEL: TBB.split
+;CHECK: musttail call void @void_callee(i8* nonnull %a, i8* null)
+;CHECK: ret void
+define void @void_caller(i8* %a, i8* %b) {
+Top:
+  %c = icmp eq i8* %a, null
+  br i1 %c, label %Tail, label %TBB
+TBB:
+  %c2 = icmp eq i8* %b, null
+  br i1 %c2, label %Tail, label %End
+Tail:
+  musttail call void @void_callee(i8* %a, i8* %b)
+  ret void
+End:
+  ret void
+}
+; Void callee used by @void_caller.
+define void @void_callee(i8* %a, i8* %b) noinline {
+  ret void
+}