Index: lib/Transforms/Scalar/CallSiteSplitting.cpp
===================================================================
--- lib/Transforms/Scalar/CallSiteSplitting.cpp
+++ lib/Transforms/Scalar/CallSiteSplitting.cpp
@@ -209,8 +209,113 @@
   return CallSiteBB->canSplitPredecessors();
 }
 
-/// Return true if the CS is split into its new predecessors.
+/// `musttail` calls must be followed only by an optional `bitcast` of the
+/// call result and a `ret` instruction. `splitCallSite` however produces
+/// blocks with copies of the call followed by `br` instructions, which
+/// violates this invariant.
 ///
+/// Restore the invariant in each of the predecessor blocks by cloning the
+/// `bitcast` and `ret` from the tail block into the predecessor and rewiring
+/// them to that predecessor's copy of the call. The tail block is no longer
+/// reachable afterwards and is erased.
+static void fixMustTailCallSite(BasicBlock *TailBB) {
+  DEBUG(dbgs() << "fix musttail splits:\n");
+
+  // After splitting, the tail block is expected to hold only an optional PHI
+  // merging the split calls (present for non-void calls), an optional
+  // `bitcast`, and the `ret`.
+  BasicBlock::iterator BI = TailBB->begin();
+  PHINode *PN = dyn_cast<PHINode>(&*BI);
+  if (PN != nullptr)
+    ++BI;
+
+  // Optional BitCastInst.
+  BitCastInst *BCI = dyn_cast<BitCastInst>(&*BI);
+  if (BCI != nullptr) {
+    ++BI;
+    DEBUG(dbgs() << " `bitcast` : " << *BCI << "\n");
+  }
+
+  // Mandatory ReturnInst, which must end the block. Use std::next rather than
+  // ++BI inside the assert so the check has no side effects.
+  ReturnInst *RI = dyn_cast<ReturnInst>(&*BI);
+  assert(RI != nullptr &&
+         "`musttail` call must be followed by `ret` instruction");
+  assert(std::next(BI) == TailBB->end() && "`ret` doesn't end the block");
+  DEBUG(dbgs() << " `ret` : " << *RI << "\n");
+
+  // Snapshot the predecessors: erasing each predecessor's terminator below
+  // removes a use of TailBB, which would invalidate a live predecessors()
+  // range.
+  SmallVector<BasicBlock *, 2> Splits(predecessors(TailBB));
+  assert(Splits.size() == 2 && "Expected exactly 2 splits!");
+
+  for (BasicBlock *Split : Splits) {
+    TerminatorInst *TI = Split->getTerminator();
+
+    // The call copy must immediately precede the terminator so that the
+    // `bitcast`/`ret` inserted before TI directly follow it.
+    CallInst *CI = dyn_cast_or_null<CallInst>(TI->getPrevNode());
+    assert(CI != nullptr && CI->isMustTailCall() &&
+           "pred blocks must end with the musttail call");
+    DEBUG(dbgs() << " feeding CI : " << *CI << " into\n");
+
+    Value *V = CI;
+
+    // Add bitcast.
+    if (BCI != nullptr) {
+      assert(!CI->getType()->isVoidTy() && "Can't bitcast void");
+      Instruction *NewBCI = BCI->clone();
+      NewBCI->setName(BCI->getName());
+      NewBCI->insertBefore(TI);
+      NewBCI->setOperand(0, V);
+      DEBUG(dbgs() << " " << *NewBCI << " into\n");
+      V = NewBCI;
+    }
+
+    // Add return. A `ret` is void-typed, so there is no name to copy, and
+    // only a `ret` with a value has an operand to rewire.
+    Instruction *NewRI = RI->clone();
+    NewRI->insertBefore(TI);
+    if (RI->getReturnValue() != nullptr)
+      NewRI->setOperand(0, V);
+    DEBUG(dbgs() << " " << *NewRI << "\n");
+
+    // Remove terminating `br` instruction.
+    DEBUG(dbgs() << " removing terminator : " << *TI << "\n");
+    TI->eraseFromParent();
+
+    // Drop this edge from the PHI; with DeletePHIIfEmpty = true the PHI
+    // erases itself once its last incoming value is removed.
+    if (PN != nullptr)
+      PN->removeIncomingValue(Split, true);
+  }
+
+  // Erase the now-dead originals; `ret` first, since it may use the `bitcast`.
+  RI->eraseFromParent();
+  if (BCI != nullptr)
+    BCI->eraseFromParent();
+
+  // The block is no longer reachable.
+  TailBB->eraseFromParent();
+}
+
+/// Return true if the call-site copies that `splitCallSite` placed at the end
+/// of the predecessors of \p TailBB are `musttail` calls.
+///
+/// Once splitting is done the original call instruction has been erased, so
+/// this inspects the copies (which keep the `musttail` marker) instead of the
+/// original `CallSite`.
+static bool hasMustTailSplitCalls(BasicBlock *TailBB) {
+  for (BasicBlock *Pred : predecessors(TailBB)) {
+    Instruction *Last = Pred->getTerminator()->getPrevNode();
+    auto *CI = dyn_cast_or_null<CallInst>(Last);
+    if (CI != nullptr && CI->isMustTailCall())
+      return true;
+  }
+  return false;
+}
+
 /// For each (predecessor, conditions from predecessors) pair, it will split the
 /// basic block containing the call site, hook it up to the predecessor and
 /// replace the call instruction with new call instructions, which contain
@@ -330,6 +435,12 @@
       break;
   }
 
+  // By now the original call has been erased from TailBB (fixMustTailCallSite
+  // depends on that), so `CS` must not be queried here; inspect the split
+  // copies instead.
+  if (hasMustTailSplitCalls(TailBB))
+    fixMustTailCallSite(TailBB);
+
   NumCallSiteSplit++;
 }
 
@@ -415,7 +526,18 @@
       Function *Callee = CS.getCalledFunction();
       if (!Callee || Callee->isDeclaration())
         continue;
+
+      // Record this before splitting: a successful split of a `musttail` call
+      // erases the call and its whole block, so `CS` must not be used after.
+      bool IsMustTail = CS.isMustTailCall();
+
       Changed |= tryToSplitCallSite(CS, TTI);
+
+      // Stop scanning this block: only an optional `bitcast` and the `ret` may
+      // follow a `musttail` call, and after a successful split the block (and
+      // the iterator into it) no longer exists.
+      if (IsMustTail)
+        break;
     }
   }
   return Changed;
Index: test/Transforms/CallSiteSplitting/musttail.ll
===================================================================
--- /dev/null
+++ test/Transforms/CallSiteSplitting/musttail.ll
@@ -0,0 +1,75 @@
+; RUN: opt < %s -callsite-splitting -S | FileCheck %s
+
+;CHECK-LABEL: @caller
+;CHECK-LABEL: Top.split:
+;CHECK: %ca1 = musttail call i8* @callee(i8* null, i8* %b)
+;CHECK: %cb4 = bitcast i8* %ca1 to i8*
+;CHECK: ret i8* %cb4
+;CHECK-LABEL: TBB.split:
+;CHECK: %ca2 = musttail call i8* @callee(i8* nonnull %a, i8* null)
+;CHECK: %cb3 = bitcast i8* %ca2 to i8*
+;CHECK: ret i8* %cb3
+define i8* @caller(i8* %a, i8* %b) {
+Top:
+  %c = icmp eq i8* %a, null
+  br i1 %c, label %Tail, label %TBB
+TBB:
+  %c2 = icmp eq i8* %b, null
+  br i1 %c2, label %Tail, label %End
+Tail:
+  %ca = musttail call i8* @callee(i8* %a, i8* %b)
+  %cb = bitcast i8* %ca to i8*
+  ret i8* %cb
+End:
+  ret i8* null
+}
+
+define i8* @callee(i8* %a, i8* %b) noinline {
+  ret i8* %a
+}
+
+;CHECK-LABEL: @no_cast_caller
+;CHECK-LABEL: Top.split:
+;CHECK: %ca1 = musttail call i8* @callee(i8* null, i8* %b)
+;CHECK: ret i8* %ca1
+;CHECK-LABEL: TBB.split:
+;CHECK: %ca2 = musttail call i8* @callee(i8* nonnull %a, i8* null)
+;CHECK: ret i8* %ca2
+define i8* @no_cast_caller(i8* %a, i8* %b) {
+Top:
+  %c = icmp eq i8* %a, null
+  br i1 %c, label %Tail, label %TBB
+TBB:
+  %c2 = icmp eq i8* %b, null
+  br i1 %c2, label %Tail, label %End
+Tail:
+  %ca = musttail call i8* @callee(i8* %a, i8* %b)
+  ret i8* %ca
+End:
+  ret i8* null
+}
+
+;CHECK-LABEL: @void_caller
+;CHECK-LABEL: Top.split:
+;CHECK: musttail call void @void_callee(i8* null, i8* %b)
+;CHECK: ret void
+;CHECK-LABEL: TBB.split:
+;CHECK: musttail call void @void_callee(i8* nonnull %a, i8* null)
+;CHECK: ret void
+define void @void_caller(i8* %a, i8* %b) {
+Top:
+  %c = icmp eq i8* %a, null
+  br i1 %c, label %Tail, label %TBB
+TBB:
+  %c2 = icmp eq i8* %b, null
+  br i1 %c2, label %Tail, label %End
+Tail:
+  musttail call void @void_callee(i8* %a, i8* %b)
+  ret void
+End:
+  ret void
+}
+
+define void @void_callee(i8* %a, i8* %b) noinline {
+  ret void
+}