Index: lib/Transforms/Scalar/CallSiteSplitting.cpp
===================================================================
--- lib/Transforms/Scalar/CallSiteSplitting.cpp
+++ lib/Transforms/Scalar/CallSiteSplitting.cpp
@@ -209,8 +209,95 @@
   return CallSiteBB->canSplitPredecessors();
 }
 
-/// Return true if the CS is split into its new predecessors.
+/// Clone \p I right before \p Before, keeping its name. When \p V is
+/// non-null, the clone's first operand is replaced with \p V.
+static Instruction *cloneInstForMustTail(Instruction *I, Instruction *Before,
+                                         Value *V) {
+  Instruction *Copy = I->clone();
+  Copy->setName(I->getName());
+  Copy->insertBefore(Before);
+  if (V)
+    Copy->setOperand(0, V);
+  return Copy;
+}
+
+/// `musttail` calls must be followed by an optional `bitcast` and a `ret`
+/// instruction. `splitCallSite`, however, produces blocks with copies of the
+/// call followed by `br` instructions, which violates this invariant.
 ///
+/// Restore the invariant in each of the predecessor blocks by copying the
+/// `bitcast` and `ret` from the tail block into the predecessor, then erase
+/// the tail block, which is no longer reachable.
+static void fixMustTailCallSite(BasicBlock *TailBB) {
+  DEBUG(dbgs() << "fix musttail splits:\n");
+
+  bool IsVoid = TailBB->getParent()->getReturnType()->isVoidTy();
+
+  // Collect all leading PHIs. For non-void calls one of them merges the
+  // split call results; do not assume it is the only one.
+  SmallVector<PHINode *, 2> PNs;
+  BasicBlock::iterator BI = TailBB->begin();
+  while (auto *PN = dyn_cast<PHINode>(&*BI)) {
+    PNs.push_back(PN);
+    ++BI;
+  }
+
+  // Optional BitCastInst
+  BitCastInst *BCI = dyn_cast<BitCastInst>(&*BI);
+  if (BCI)
+    ++BI;
+
+  // Mandatory ReturnInst
+  ReturnInst *RI = dyn_cast<ReturnInst>(&*BI);
+  assert(RI && "`musttail` call must be followed by `ret` instruction");
+  assert(std::next(BI) == TailBB->end() && "`ret` doesn't end the block");
+
+  // Snapshot the predecessors first: erasing their `br` terminators below
+  // removes them from TailBB's predecessor list.
+  SmallVector<BasicBlock *, 2> Splits(predecessors(TailBB));
+  assert(Splits.size() == 2 && "Expected exactly 2 splits!");
+
+  for (BasicBlock *Split : Splits) {
+    TerminatorInst *TI = Split->getTerminator();
+
+    // The cloned call is expected right before the `br` into TailBB.
+    CallInst *CI = dyn_cast<CallInst>(&*std::prev(TI->getIterator()));
+    assert(CI && "pred blocks must have the call");
+    assert(IsVoid == CI->getFunctionType()->getReturnType()->isVoidTy() &&
+           "Return type mismatch between musttail call and the caller function");
+    DEBUG(dbgs() << "  feeding CI : " << *CI << " into\n");
+
+    Value *V = CI;
+
+    // Add bitcast
+    if (BCI) {
+      assert(!IsVoid && "Can't bitcast void");
+      V = cloneInstForMustTail(BCI, TI, V);
+      DEBUG(dbgs() << "  " << *V << " into\n");
+    }
+
+    // Add return
+    Instruction *NewRI = cloneInstForMustTail(RI, TI, IsVoid ? nullptr : V);
+    DEBUG(dbgs() << "  " << *NewRI << "\n");
+
+    // Remove terminating `br` instruction
+    TI->eraseFromParent();
+
+    // Drop Split's incoming entries; a PHI left with none is deleted.
+    for (PHINode *PN : PNs)
+      PN->removeIncomingValue(Split, /*DeletePHIIfEmpty=*/true);
+  }
+
+  // Remove the original `ret` and `bitcast`; `ret` goes first since it may
+  // use the `bitcast`.
+  RI->eraseFromParent();
+  if (BCI)
+    BCI->eraseFromParent();
+
+  // The block is no longer reachable
+  TailBB->eraseFromParent();
+}
+
 /// For each (predecessor, conditions from predecessors) pair, it will split the
 /// basic block containing the call site, hook it up to the predecessor and
 /// replace the call instruction with new call instructions, which contain
@@ -257,6 +344,7 @@
     const SmallVectorImpl<std::pair<BasicBlock *, ConditionsTy>> &Preds) {
   Instruction *Instr = CS.getInstruction();
   BasicBlock *TailBB = Instr->getParent();
+  bool IsMustTailCall = CS.isMustTailCall();
 
   PHINode *CallPN = nullptr;
   if (Instr->getNumUses())
@@ -330,6 +418,9 @@
     break;
   }
 
+  if (IsMustTailCall)
+    fixMustTailCallSite(TailBB);
+
   NumCallSiteSplit++;
 }
 
@@ -415,7 +506,17 @@
       Function *Callee = CS.getCalledFunction();
       if (!Callee || Callee->isDeclaration())
         continue;
+
+      // A successful musttail split erases both the call and its block, so
+      // query CS before attempting the split.
+      bool IsMustTail = CS.isMustTailCall();
+
       Changed |= tryToSplitCallSite(CS, TTI);
+
+      // A musttail call may only be followed by `bitcast`/`ret`, so there is
+      // nothing left to visit, and the call and its block may be gone.
+      if (IsMustTail)
+        break;
     }
   }
   return Changed;
Index: test/Transforms/CallSiteSplitting/musttail.ll
===================================================================
--- /dev/null
+++ test/Transforms/CallSiteSplitting/musttail.ll
@@ -0,0 +1,75 @@
+; RUN: opt < %s -callsite-splitting -S | FileCheck %s
+
+;CHECK-LABEL: @caller
+;CHECK-LABEL: Top.split:
+;CHECK: %ca1 = musttail call i8* @callee(i8* null, i8* %b)
+;CHECK-NEXT: %cb4 = bitcast i8* %ca1 to i8*
+;CHECK-NEXT: ret i8* %cb4
+;CHECK-LABEL: TBB.split:
+;CHECK: %ca2 = musttail call i8* @callee(i8* nonnull %a, i8* null)
+;CHECK-NEXT: %cb3 = bitcast i8* %ca2 to i8*
+;CHECK-NEXT: ret i8* %cb3
+define i8* @caller(i8* %a, i8* %b) {
+Top:
+  %c = icmp eq i8* %a, null
+  br i1 %c, label %Tail, label %TBB
+TBB:
+  %c2 = icmp eq i8* %b, null
+  br i1 %c2, label %Tail, label %End
+Tail:
+  %ca = musttail call i8* @callee(i8* %a, i8* %b)
+  %cb = bitcast i8* %ca to i8*
+  ret i8* %cb
+End:
+  ret i8* null
+}
+
+define i8* @callee(i8* %a, i8* %b) noinline {
+  ret i8* %a
+}
+
+;CHECK-LABEL: @no_cast_caller
+;CHECK-LABEL: Top.split:
+;CHECK: %ca1 = musttail call i8* @callee(i8* null, i8* %b)
+;CHECK-NEXT: ret i8* %ca1
+;CHECK-LABEL: TBB.split:
+;CHECK: %ca2 = musttail call i8* @callee(i8* nonnull %a, i8* null)
+;CHECK-NEXT: ret i8* %ca2
+define i8* @no_cast_caller(i8* %a, i8* %b) {
+Top:
+  %c = icmp eq i8* %a, null
+  br i1 %c, label %Tail, label %TBB
+TBB:
+  %c2 = icmp eq i8* %b, null
+  br i1 %c2, label %Tail, label %End
+Tail:
+  %ca = musttail call i8* @callee(i8* %a, i8* %b)
+  ret i8* %ca
+End:
+  ret i8* null
+}
+
+;CHECK-LABEL: @void_caller
+;CHECK-LABEL: Top.split:
+;CHECK: musttail call void @void_callee(i8* null, i8* %b)
+;CHECK-NEXT: ret void
+;CHECK-LABEL: TBB.split:
+;CHECK: musttail call void @void_callee(i8* nonnull %a, i8* null)
+;CHECK-NEXT: ret void
+define void @void_caller(i8* %a, i8* %b) {
+Top:
+  %c = icmp eq i8* %a, null
+  br i1 %c, label %Tail, label %TBB
+TBB:
+  %c2 = icmp eq i8* %b, null
+  br i1 %c2, label %Tail, label %End
+Tail:
+  musttail call void @void_callee(i8* %a, i8* %b)
+  ret void
+End:
+  ret void
+}
+
+define void @void_callee(i8* %a, i8* %b) noinline {
+  ret void
+}