Index: lib/Transforms/Scalar/CallSiteSplitting.cpp
===================================================================
--- lib/Transforms/Scalar/CallSiteSplitting.cpp
+++ lib/Transforms/Scalar/CallSiteSplitting.cpp
@@ -209,8 +209,19 @@
   return CallSiteBB->canSplitPredecessors();
 }
 
-/// Return true if the CS is split into its new predecessors.
-///
+/// Clone \p I and insert the copy (named like \p I) right before \p Before.
+/// If \p V is non-null, it replaces operand 0 of the copy. Used to rewire the
+/// cloned `bitcast`/`ret` that must follow a `musttail` call to the new call.
+static Instruction *cloneInstForMustTail(Instruction *I, Instruction *Before,
+                                         Value *V) {
+  Instruction *Copy = I->clone();
+  Copy->setName(I->getName());
+  Copy->insertBefore(Before);
+  if (V)
+    Copy->setOperand(0, V);
+  return Copy;
+}
+
 /// For each (predecessor, conditions from predecessors) pair, it will split the
 /// basic block containing the call site, hook it up to the predecessor and
 /// replace the call instruction with new call instructions, which contain
@@ -257,9 +268,29 @@
     const SmallVectorImpl<std::pair<BasicBlock *, ConditionsTy>> &Preds) {
   Instruction *Instr = CS.getInstruction();
   BasicBlock *TailBB = Instr->getParent();
+  bool IsMustTailCall = CS.isMustTailCall();
+  BitCastInst *MustTailBCI = nullptr;
+  ReturnInst *MustTailRI = nullptr;
+  bool IsVoid = TailBB->getParent()->getReturnType()->isVoidTy();
 
   PHINode *CallPN = nullptr;
-  if (Instr->getNumUses())
+
+  // `musttail` calls must be followed by an optional `bitcast` and a `ret`
+  // instruction. Save both for later patching in the loop below.
+  if (IsMustTailCall) {
+    auto II = std::next(Instr->getIterator());
+
+    MustTailBCI = dyn_cast<BitCastInst>(&*II);
+    if (MustTailBCI) {
+      ++II;
+      assert(!IsVoid && "bitcast after void musttail call");
+    }
+
+    MustTailRI = dyn_cast<ReturnInst>(&*II);
+    assert(MustTailRI &&
+           "`musttail` call must be followed by `ret` instruction");
+    assert(std::next(II) == TailBB->end() && "`ret` doesn't end the block");
+  } else if (Instr->getNumUses())
     CallPN = PHINode::Create(Instr->getType(), Preds.size(), "phi.call");
 
   DEBUG(dbgs() << "split call-site : " << *Instr << " into \n");
@@ -293,6 +324,36 @@
                  << "\n");
     if (CallPN)
       CallPN->addIncoming(NewCI, SplitBlock);
+
+    // Clone and place bitcast and return instructions before `TI`
+    if (IsMustTailCall) {
+      TerminatorInst *TI = SplitBlock->getTerminator();
+      Value *V = NewCI;
+      if (MustTailBCI)
+        V = cloneInstForMustTail(MustTailBCI, TI, V);
+      cloneInstForMustTail(MustTailRI, TI, IsVoid ? nullptr : V);
+
+      // FIXME: remove TI here, once https://reviews.llvm.org/D43822 lands
+    }
+  }
+
+  // FIXME: remove TI above, once https://reviews.llvm.org/D43822 lands
+  if (IsMustTailCall) {
+    // Snapshot the predecessors first: erasing a split block's terminator
+    // removes that block from TailBB's predecessor list.
+    SmallVector<BasicBlock *, 2> Splits(predecessors(TailBB));
+    assert(Splits.size() == 2 && "Expected exactly 2 splits!");
+
+    for (BasicBlock *Split : Splits)
+      Split->getTerminator()->eraseFromParent();
+  }
+
+  NumCallSiteSplit++;
+
+  // Erase the tail block once done with musttail patching
+  if (IsMustTailCall) {
+    TailBB->eraseFromParent();
+    return;
   }
 
   auto *OriginalBegin = &*TailBB->begin();
@@ -329,8 +390,6 @@
     if (CurrentI == OriginalBegin)
       break;
   }
-
-  NumCallSiteSplit++;
 }
 
 // Return true if the call-site has an argument which is a PHI with only
@@ -415,7 +474,17 @@
       Function *Callee = CS.getCalledFunction();
       if (!Callee || Callee->isDeclaration())
         continue;
+
+      // A successful musttail split erases both the call and its block, so
+      // record whether this is a musttail call before attempting to split.
+      bool IsMustTail = CS.isMustTailCall();
+
       Changed |= tryToSplitCallSite(CS, TTI);
+
+      // There are no interesting instructions after this. The call site
+      // itself might have been erased on splitting.
+      if (IsMustTail)
+        break;
     }
   }
   return Changed;
Index: test/Transforms/CallSiteSplitting/musttail.ll
===================================================================
--- /dev/null
+++ test/Transforms/CallSiteSplitting/musttail.ll
@@ -0,0 +1,75 @@
+; RUN: opt < %s -callsite-splitting -S | FileCheck %s
+
+;CHECK-LABEL: @caller
+;CHECK-LABEL: Top.split:
+;CHECK: %ca1 = musttail call i8* @callee(i8* null, i8* %b)
+;CHECK: %cb2 = bitcast i8* %ca1 to i8*
+;CHECK: ret i8* %cb2
+;CHECK-LABEL: TBB.split:
+;CHECK: %ca3 = musttail call i8* @callee(i8* nonnull %a, i8* null)
+;CHECK: %cb4 = bitcast i8* %ca3 to i8*
+;CHECK: ret i8* %cb4
+define i8* @caller(i8* %a, i8* %b) {
+Top:
+  %c = icmp eq i8* %a, null
+  br i1 %c, label %Tail, label %TBB
+TBB:
+  %c2 = icmp eq i8* %b, null
+  br i1 %c2, label %Tail, label %End
+Tail:
+  %ca = musttail call i8* @callee(i8* %a, i8* %b)
+  %cb = bitcast i8* %ca to i8*
+  ret i8* %cb
+End:
+  ret i8* null
+}
+
+define i8* @callee(i8* %a, i8* %b) noinline {
+  ret i8* %a
+}
+
+;CHECK-LABEL: @no_cast_caller
+;CHECK-LABEL: Top.split:
+;CHECK: %ca1 = musttail call i8* @callee(i8* null, i8* %b)
+;CHECK: ret i8* %ca1
+;CHECK-LABEL: TBB.split:
+;CHECK: %ca2 = musttail call i8* @callee(i8* nonnull %a, i8* null)
+;CHECK: ret i8* %ca2
+define i8* @no_cast_caller(i8* %a, i8* %b) {
+Top:
+  %c = icmp eq i8* %a, null
+  br i1 %c, label %Tail, label %TBB
+TBB:
+  %c2 = icmp eq i8* %b, null
+  br i1 %c2, label %Tail, label %End
+Tail:
+  %ca = musttail call i8* @callee(i8* %a, i8* %b)
+  ret i8* %ca
+End:
+  ret i8* null
+}
+
+;CHECK-LABEL: @void_caller
+;CHECK-LABEL: Top.split:
+;CHECK: musttail call void @void_callee(i8* null, i8* %b)
+;CHECK: ret void
+;CHECK-LABEL: TBB.split:
+;CHECK: musttail call void @void_callee(i8* nonnull %a, i8* null)
+;CHECK: ret void
+define void @void_caller(i8* %a, i8* %b) {
+Top:
+  %c = icmp eq i8* %a, null
+  br i1 %c, label %Tail, label %TBB
+TBB:
+  %c2 = icmp eq i8* %b, null
+  br i1 %c2, label %Tail, label %End
+Tail:
+  musttail call void @void_callee(i8* %a, i8* %b)
+  ret void
+End:
+  ret void
+}
+
+define void @void_callee(i8* %a, i8* %b) noinline {
+  ret void
+}