Patch summary (free text ahead of the first "Index:" header):

  lib/Transforms/Scalar/CallSiteSplitting.cpp
    * Adds cloneInstForMustTail(), which clones an instruction, keeps its
      name, inserts the clone before a given instruction and, when a value
      is supplied, rewires operand 0 of the clone to that value.
    * Adds copyMustTailReturn(), which copies the optional `bitcast` and the
      `ret` that follow a `musttail` call into a split block, placing them
      before that block's terminator and linking them to the new call.
    * In the call-site splitting routine: no `phi.call` node is created for
      `musttail` calls; after the per-predecessor loop the split blocks'
      `br` terminators and the tail block are erased and the routine
      returns early. NumCallSiteSplit is now incremented before that early
      return instead of at the end of the routine.
    * The scanning loop records whether a call site is `musttail` before
      trying to split it and stops scanning the block afterwards, because
      the call site itself might have been erased.

  test/Transforms/CallSiteSplitting/musttail.ll (new)
    * FileCheck coverage for three shapes of `musttail` call: followed by a
      `bitcast` and `ret`, followed directly by `ret`, and a void call.

Index: llvm/trunk/lib/Transforms/Scalar/CallSiteSplitting.cpp
===================================================================
--- llvm/trunk/lib/Transforms/Scalar/CallSiteSplitting.cpp
+++ llvm/trunk/lib/Transforms/Scalar/CallSiteSplitting.cpp
@@ -209,8 +209,46 @@
   return CallSiteBB->canSplitPredecessors();
 }
 
-/// Return true if the CS is split into its new predecessors.
+static Instruction *cloneInstForMustTail(Instruction *I, Instruction *Before,
+                                         Value *V) {
+  Instruction *Copy = I->clone();
+  Copy->setName(I->getName());
+  Copy->insertBefore(Before);
+  if (V)
+    Copy->setOperand(0, V);
+  return Copy;
+}
+
+/// Copy mandatory `musttail` return sequence that follows original `CI`, and
+/// link it up to `NewCI` value instead:
+///
+///   * (optional) `bitcast NewCI to ...`
+///   * `ret bitcast or NewCI`
 ///
+/// Insert this sequence right before `SplitBB`'s terminator, which will be
+/// cleaned up later in `splitCallSite` below.
+static void copyMustTailReturn(BasicBlock *SplitBB, Instruction *CI,
+                               Instruction *NewCI) {
+  bool IsVoid = SplitBB->getParent()->getReturnType()->isVoidTy();
+  auto II = std::next(CI->getIterator());
+
+  BitCastInst *BCI = dyn_cast<BitCastInst>(&*II);
+  if (BCI)
+    ++II;
+
+  ReturnInst *RI = dyn_cast<ReturnInst>(&*II);
+  assert(RI && "`musttail` call must be followed by `ret` instruction");
+
+  TerminatorInst *TI = SplitBB->getTerminator();
+  Value *V = NewCI;
+  if (BCI)
+    V = cloneInstForMustTail(BCI, TI, V);
+  cloneInstForMustTail(RI, TI, IsVoid ? nullptr : V);
+
+  // FIXME: remove TI here, `DuplicateInstructionsInSplitBetween` has a bug
+  // that prevents doing this now.
+}
+
 /// For each (predecessor, conditions from predecessors) pair, it will split the
 /// basic block containing the call site, hook it up to the predecessor and
 /// replace the call instruction with new call instructions, which contain
@@ -257,9 +295,14 @@
     const SmallVectorImpl<std::pair<BasicBlock *, ConditionsTy>> &Preds) {
   Instruction *Instr = CS.getInstruction();
   BasicBlock *TailBB = Instr->getParent();
+  bool IsMustTailCall = CS.isMustTailCall();
 
   PHINode *CallPN = nullptr;
-  if (Instr->getNumUses())
+
+  // `musttail` calls must be followed by optional `bitcast`, and `ret`. The
+  // split blocks will be terminated right after that so there're no users for
+  // this phi in a `TailBB`.
+  if (!IsMustTailCall && Instr->getNumUses())
     CallPN = PHINode::Create(Instr->getType(), Preds.size(), "phi.call");
 
   DEBUG(dbgs() << "split call-site : " << *Instr << " into \n");
@@ -293,6 +336,23 @@
                  << "\n");
     if (CallPN)
       CallPN->addIncoming(NewCI, SplitBlock);
+
+    // Clone and place bitcast and return instructions before `TI`
+    if (IsMustTailCall)
+      copyMustTailReturn(SplitBlock, Instr, NewCI);
+  }
+
+  NumCallSiteSplit++;
+
+  // FIXME: remove TI in `copyMustTailReturn`
+  if (IsMustTailCall) {
+    // Remove superfluous `br` terminators from the end of the Split blocks
+    for (BasicBlock *SplitBlock : predecessors(TailBB))
+      SplitBlock->getTerminator()->eraseFromParent();
+
+    // Erase the tail block once done with musttail patching
+    TailBB->eraseFromParent();
+    return;
   }
 
   auto *OriginalBegin = &*TailBB->begin();
@@ -329,8 +389,6 @@
     if (CurrentI == OriginalBegin)
       break;
   }
-
-  NumCallSiteSplit++;
 }
 
 // Return true if the call-site has an argument which is a PHI with only
@@ -415,7 +473,17 @@
       Function *Callee = CS.getCalledFunction();
       if (!Callee || Callee->isDeclaration())
         continue;
+
+      // Successful musttail call-site splits result in erased CI and erased BB.
+      // Check if such path is possible before attempting the splitting.
+      bool IsMustTail = CS.isMustTailCall();
+
       Changed |= tryToSplitCallSite(CS, TTI);
+
+      // There're no interesting instructions after this. The call site
+      // itself might have been erased on splitting.
+      if (IsMustTail)
+        break;
     }
   }
   return Changed;
Index: llvm/trunk/test/Transforms/CallSiteSplitting/musttail.ll
===================================================================
--- llvm/trunk/test/Transforms/CallSiteSplitting/musttail.ll
+++ llvm/trunk/test/Transforms/CallSiteSplitting/musttail.ll
@@ -0,0 +1,75 @@
+; RUN: opt < %s -callsite-splitting -S | FileCheck %s
+
+;CHECK-LABEL: @caller
+;CHECK-LABEL: Top.split:
+;CHECK: %ca1 = musttail call i8* @callee(i8* null, i8* %b)
+;CHECK: %cb2 = bitcast i8* %ca1 to i8*
+;CHECK: ret i8* %cb2
+;CHECK-LABEL: TBB.split
+;CHECK: %ca3 = musttail call i8* @callee(i8* nonnull %a, i8* null)
+;CHECK: %cb4 = bitcast i8* %ca3 to i8*
+;CHECK: ret i8* %cb4
+define i8* @caller(i8* %a, i8* %b) {
+Top:
+  %c = icmp eq i8* %a, null
+  br i1 %c, label %Tail, label %TBB
+TBB:
+  %c2 = icmp eq i8* %b, null
+  br i1 %c2, label %Tail, label %End
+Tail:
+  %ca = musttail call i8* @callee(i8* %a, i8* %b)
+  %cb = bitcast i8* %ca to i8*
+  ret i8* %cb
+End:
+  ret i8* null
+}
+
+define i8* @callee(i8* %a, i8* %b) noinline {
+  ret i8* %a
+}
+
+;CHECK-LABEL: @no_cast_caller
+;CHECK-LABEL: Top.split:
+;CHECK: %ca1 = musttail call i8* @callee(i8* null, i8* %b)
+;CHECK: ret i8* %ca1
+;CHECK-LABEL: TBB.split
+;CHECK: %ca2 = musttail call i8* @callee(i8* nonnull %a, i8* null)
+;CHECK: ret i8* %ca2
+define i8* @no_cast_caller(i8* %a, i8* %b) {
+Top:
+  %c = icmp eq i8* %a, null
+  br i1 %c, label %Tail, label %TBB
+TBB:
+  %c2 = icmp eq i8* %b, null
+  br i1 %c2, label %Tail, label %End
+Tail:
+  %ca = musttail call i8* @callee(i8* %a, i8* %b)
+  ret i8* %ca
+End:
+  ret i8* null
+}
+
+;CHECK-LABEL: @void_caller
+;CHECK-LABEL: Top.split:
+;CHECK: musttail call void @void_callee(i8* null, i8* %b)
+;CHECK: ret void
+;CHECK-LABEL: TBB.split
+;CHECK: musttail call void @void_callee(i8* nonnull %a, i8* null)
+;CHECK: ret void
+define void @void_caller(i8* %a, i8* %b) {
+Top:
+  %c = icmp eq i8* %a, null
+  br i1 %c, label %Tail, label %TBB
+TBB:
+  %c2 = icmp eq i8* %b, null
+  br i1 %c2, label %Tail, label %End
+Tail:
+  musttail call void @void_callee(i8* %a, i8* %b)
+  ret void
+End:
+  ret void
+}
+
+define void @void_callee(i8* %a, i8* %b) noinline {
+  ret void
+}