Index: lib/Transforms/Scalar/CallSiteSplitting.cpp
===================================================================
--- lib/Transforms/Scalar/CallSiteSplitting.cpp
+++ lib/Transforms/Scalar/CallSiteSplitting.cpp
@@ -209,8 +209,6 @@
   return CallSiteBB->canSplitPredecessors();
 }
 
-/// Return true if the CS is split into its new predecessors.
-///
 /// For each (predecessor, conditions from predecessors) pair, it will split the
 /// basic block containing the call site, hook it up to the predecessor and
 /// replace the call instruction with new call instructions, which contain
@@ -255,12 +253,24 @@
 static void splitCallSite(
     CallSite CS,
     const SmallVectorImpl<std::pair<BasicBlock *, ConditionsTy>> &Preds) {
+  bool IsMustTail = CS.isMustTailCall();
   Instruction *Instr = CS.getInstruction();
   BasicBlock *TailBB = Instr->getParent();
 
   PHINode *CallPN = nullptr;
-  if (Instr->getNumUses())
-    CallPN = PHINode::Create(Instr->getType(), Preds.size(), "phi.call");
+
+  Instruction *StopAt = nullptr;
+  if (IsMustTail) {
+    // We need to copy everything including the terminator of the Tail block,
+    // since musttail must be followed by an optional bitcast and a ret.
+    // Append a temporary unreachable instruction to serve as the stop marker
+    // for `DuplicateInstructionsInSplitBetween`; it goes away with TailBB.
+    StopAt = new UnreachableInst(TailBB->getContext(), TailBB);
+  } else {
+    StopAt = &*std::next(Instr->getIterator());
+    if (Instr->getNumUses())
+      CallPN = PHINode::Create(Instr->getType(), Preds.size(), "phi.call");
+  }
 
   DEBUG(dbgs() << "split call-site : " << *Instr << " into \n");
 
@@ -271,11 +281,18 @@
   for (unsigned i = 0; i < Preds.size(); i++) {
     BasicBlock *PredBB = Preds[i].first;
     BasicBlock *SplitBlock = DuplicateInstructionsInSplitBetween(
-        TailBB, PredBB, &*std::next(Instr->getIterator()), ValueToValueMaps[i]);
+        TailBB, PredBB, StopAt, ValueToValueMaps[i]);
     assert(SplitBlock && "Unexpected new basic block split.");
 
-    Instruction *NewCI =
-        &*std::prev(SplitBlock->getTerminator()->getIterator());
+    Instruction *NewCI;
+
+    if (IsMustTail) {
+      // The split block ends with the cloned `ret` (then `br`), not with the
+      // cloned call, so look the call's clone up in the value map instead.
+      NewCI = cast<Instruction>(&*ValueToValueMaps[i][Instr]);
+    } else {
+      NewCI = &*std::prev(SplitBlock->getTerminator()->getIterator());
+    }
     CallSite NewCS(NewCI);
     addConditions(NewCS, Preds[i].second);
 
@@ -309,7 +326,7 @@
   // instruction, so we do not end up deleting them. By using reverse-order, we
   // do not introduce unnecessary PHI nodes for def-use chains from the call
   // instruction to the beginning of the block.
-  auto I = Instr->getReverseIterator();
+  auto I = std::next(StopAt->getReverseIterator());
   while (I != TailBB->rend()) {
     Instruction *CurrentI = &*I++;
     if (!CurrentI->use_empty()) {
@@ -330,6 +347,26 @@
       break;
   }
 
+  // At this point each split ends with the following instructions:
+  //
+  //   %c = musttail call ...   ; optionally followed by a bitcast of %c
+  //   ret ... %c
+  //   br label %Tail
+  //
+  // Remove the `br`, unlinking the splits from the Tail block; once that is
+  // done the Tail block is unreachable and can be erased.
+  if (IsMustTail) {
+    DEBUG(dbgs() << "clean up musttail splits\n");
+
+    SmallVector<BasicBlock *, 2> Splits(predecessors(TailBB));
+    assert(Splits.size() == Preds.size() && "Expected one split per pred!");
+
+    for (BasicBlock *Split : Splits)
+      Split->getTerminator()->eraseFromParent();
+
+    TailBB->eraseFromParent();
+  }
+
   NumCallSiteSplit++;
 }
 
@@ -415,7 +452,17 @@
       Function *Callee = CS.getCalledFunction();
       if (!Callee || Callee->isDeclaration())
         continue;
+
+      // Record this before splitting: a successful musttail split erases both
+      // the call and its basic block, so CS must not be queried afterwards.
+      bool IsMustTail = CS.isMustTailCall();
+
       Changed |= tryToSplitCallSite(CS, TTI);
+
+      // A musttail call is only followed by an optional bitcast and a ret, so
+      // there is nothing left to visit here (and the call may be gone).
+      if (IsMustTail)
+        break;
     }
   }
   return Changed;
Index: test/Transforms/CallSiteSplitting/musttail.ll
===================================================================
--- /dev/null
+++ test/Transforms/CallSiteSplitting/musttail.ll
@@ -0,0 +1,29 @@
+; RUN: opt < %s -callsite-splitting -S | FileCheck %s
+
+;CHECK-LABEL: @caller
+;CHECK-LABEL: Top.split:
+;CHECK: %ca1 = musttail call i8* @callee(i8* null, i8* %b)
+;CHECK: %cb2 = bitcast i8* %ca1 to i8*
+;CHECK: ret i8* %cb2
+;CHECK-LABEL: TBB.split
+;CHECK: %ca3 = musttail call i8* @callee(i8* nonnull %a, i8* null)
+;CHECK: %cb4 = bitcast i8* %ca3 to i8*
+;CHECK: ret i8* %cb4
+define i8* @caller(i8* %a, i8* %b) {
+Top:
+  %c = icmp eq i8* %a, null
+  br i1 %c, label %Tail, label %TBB
+TBB:
+  %c2 = icmp eq i8* %b, null
+  br i1 %c2, label %Tail, label %End
+Tail:
+  %ca = musttail call i8* @callee(i8* %a, i8* %b)
+  %cb = bitcast i8* %ca to i8*
+  ret i8* %cb
+End:
+  ret i8* null
+}
+
+define i8* @callee(i8* %a, i8* %b) noinline {
+  ret i8* %a
+}