diff --git a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp --- a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp +++ b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -445,11 +445,53 @@ return &*I; } -static CallInst *findTRECandidate(Instruction *TI, - bool CannotTailCallElimCallsMarkedTail, - const TargetTransformInfo *TTI) { +namespace { +class TailRecursionEliminator { + Function &F; + const TargetTransformInfo *TTI; + AliasAnalysis *AA; + OptimizationRemarkEmitter *ORE; + DomTreeUpdater &DTU; + + // The members below are shared state we want to have available when + // eliminating any calls in the function. Their values are populated by + // createTailRecurseLoopHeader the first time we find a call we can eliminate. + BasicBlock *HeaderBB = nullptr; // Former entry block; the loop header. + SmallVector ArgumentPHIs; // One PHI per formal argument. + bool RemovableCallsMustBeMarkedTail = false; + + TailRecursionEliminator(Function &F, const TargetTransformInfo *TTI, + AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, + DomTreeUpdater &DTU) + : F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU) {} + + CallInst *findTRECandidate(Instruction *TI, + bool CannotTailCallElimCallsMarkedTail); + + void createTailRecurseLoopHeader(CallInst *CI); + + PHINode *insertAccumulator(Value *AccumulatorRecursionEliminationInitVal); + + bool eliminateCall(CallInst *CI); + + bool foldReturnAndProcessPred(ReturnInst *Ret, + bool CannotTailCallElimCallsMarkedTail); + + bool processReturningBlock(ReturnInst *Ret, + bool CannotTailCallElimCallsMarkedTail); + + void cleanupAndFinalize(); // Simplify away trivial argument PHIs. + +public: + static bool eliminate(Function &F, const TargetTransformInfo *TTI, + AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, + DomTreeUpdater &DTU); +}; +} // namespace + +CallInst *TailRecursionEliminator::findTRECandidate( + Instruction *TI, bool CannotTailCallElimCallsMarkedTail) { BasicBlock *BB = TI->getParent(); - Function *F = BB->getParent(); if (&BB->front() == TI) // Make sure
there is something before the terminator. return nullptr; @@ -460,7 +502,7 @@ BasicBlock::iterator BBI(TI); while (true) { CI = dyn_cast(BBI); - if (CI && CI->getCalledFunction() == F) + if (CI && CI->getCalledFunction() == &F) break; if (BBI == BB->begin()) @@ -477,15 +519,14 @@ // double fabs(double f) { return __builtin_fabs(f); } // a 'fabs' call // and disable this xform in this case, because the code generator will // lower the call to fabs into inline code. - if (BB == &F->getEntryBlock() && + if (BB == &F.getEntryBlock() && firstNonDbg(BB->front().getIterator()) == CI && firstNonDbg(std::next(BB->begin())) == TI && CI->getCalledFunction() && !TTI->isLoweredToCall(CI->getCalledFunction())) { // A single-block function with just a call and a return. Check that // the arguments match. auto I = CI->arg_begin(), E = CI->arg_end(); - Function::arg_iterator FI = F->arg_begin(), - FE = F->arg_end(); + Function::arg_iterator FI = F.arg_begin(), FE = F.arg_end(); for (; I != E && FI != FE; ++I, ++FI) if (*I != &*FI) break; if (I == E && FI == FE) @@ -495,10 +536,81 @@ return CI; } -static bool eliminateRecursiveTailCall( - CallInst *CI, ReturnInst *Ret, BasicBlock *&OldEntry, - bool &TailCallsAreMarkedTail, SmallVectorImpl &ArgumentPHIs, - AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, DomTreeUpdater &DTU) { +void TailRecursionEliminator::createTailRecurseLoopHeader(CallInst *CI) { + HeaderBB = &F.getEntryBlock(); + BasicBlock *NewEntry = BasicBlock::Create(F.getContext(), "", &F, HeaderBB); + NewEntry->takeName(HeaderBB); + HeaderBB->setName("tailrecurse"); + BranchInst *BI = BranchInst::Create(HeaderBB, NewEntry); + BI->setDebugLoc(CI->getDebugLoc()); + + // If this function has self recursive calls in the tail position where some + // are marked tail and some are not, only transform one flavor or another. + // We have to choose whether we move allocas in the entry block to the new + // entry block or not, so we can't make a good choice for both.
We make this + // decision here based on whether the first call we found to remove is + // marked tail. + // NOTE: We could do slightly better here in the case that the function has + // no entry block allocas. + RemovableCallsMustBeMarkedTail = CI->isTailCall(); + + // If this tail call is marked 'tail' and if there are any allocas in the + // entry block, move them up to the new entry block. + if (RemovableCallsMustBeMarkedTail) + // Move all fixed sized allocas from HeaderBB to NewEntry. + for (BasicBlock::iterator OEBI = HeaderBB->begin(), E = HeaderBB->end(), + NEBI = NewEntry->begin(); + OEBI != E;) + if (AllocaInst *AI = dyn_cast(OEBI++)) + if (isa(AI->getArraySize())) + AI->moveBefore(&*NEBI); + + // Now that we have created a new block, which jumps to the entry + // block, insert a PHI node for each argument of the function. + // For now, we initialize each PHI to only have the real arguments + // which are passed in. + Instruction *InsertPos = &HeaderBB->front(); + for (Function::arg_iterator I = F.arg_begin(), E = F.arg_end(); I != E; ++I) { + PHINode *PN = + PHINode::Create(I->getType(), 2, I->getName() + ".tr", InsertPos); + I->replaceAllUsesWith(PN); // Everyone uses the PHI node now! + PN->addIncoming(&*I, NewEntry); + ArgumentPHIs.push_back(PN); + } + // The entry block was changed from HeaderBB to NewEntry. + // The forward DominatorTree needs to be recalculated when the EntryBB is + // changed. In this corner-case we recalculate the entire tree. + DTU.recalculate(*NewEntry->getParent()); +} + +PHINode *TailRecursionEliminator::insertAccumulator( + Value *AccumulatorRecursionEliminationInitVal) { + // Start by inserting a new PHI node for the accumulator. + pred_iterator PB = pred_begin(HeaderBB), PE = pred_end(HeaderBB); + PHINode *AccPN = PHINode::Create( + AccumulatorRecursionEliminationInitVal->getType(), + std::distance(PB, PE) + 1, "accumulator.tr", &HeaderBB->front()); + + // Loop over all of the predecessors of the tail recursion block.
For the + // real entry into the function we seed the PHI with the caller-provided + // initial value. For any other existing branches to this block (due to + // other tail recursions eliminated) the accumulator is not modified. + // Because we haven't added the branch in the current block to HeaderBB yet, + // it will not show up as a predecessor. + for (pred_iterator PI = PB; PI != PE; ++PI) { + BasicBlock *P = *PI; + if (P == &F.getEntryBlock()) + AccPN->addIncoming(AccumulatorRecursionEliminationInitVal, P); + else + AccPN->addIncoming(AccPN, P); + } + + return AccPN; +} + +bool TailRecursionEliminator::eliminateCall(CallInst *CI) { + ReturnInst *Ret = cast(CI->getParent()->getTerminator()); + + // If we are introducing accumulator recursion to eliminate operations after // the call instruction that are both associative and commutative, the initial // value for the accumulator is placed in this variable. If this value is set @@ -556,7 +668,6 @@ } BasicBlock *BB = Ret->getParent(); - Function *F = BB->getParent(); using namespace ore; ORE->emit([&]() { @@ -566,51 +677,10 @@ // OK! We can transform this tail call. If this is the first one found, // create the new entry block, allowing us to branch back to the old entry. - if (!OldEntry) { - OldEntry = &F->getEntryBlock(); - BasicBlock *NewEntry = BasicBlock::Create(F->getContext(), "", F, OldEntry); - NewEntry->takeName(OldEntry); - OldEntry->setName("tailrecurse"); - BranchInst *BI = BranchInst::Create(OldEntry, NewEntry); - BI->setDebugLoc(CI->getDebugLoc()); - - // If this tail call is marked 'tail' and if there are any allocas in the - // entry block, move them up to the new entry block. - TailCallsAreMarkedTail = CI->isTailCall(); - if (TailCallsAreMarkedTail) - // Move all fixed sized allocas from OldEntry to NewEntry.
- for (BasicBlock::iterator OEBI = OldEntry->begin(), E = OldEntry->end(), - NEBI = NewEntry->begin(); OEBI != E; ) - if (AllocaInst *AI = dyn_cast(OEBI++)) - if (isa(AI->getArraySize())) - AI->moveBefore(&*NEBI); - - // Now that we have created a new block, which jumps to the entry - // block, insert a PHI node for each argument of the function. - // For now, we initialize each PHI to only have the real arguments - // which are passed in. - Instruction *InsertPos = &OldEntry->front(); - for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); - I != E; ++I) { - PHINode *PN = PHINode::Create(I->getType(), 2, - I->getName() + ".tr", InsertPos); - I->replaceAllUsesWith(PN); // Everyone use the PHI node now! - PN->addIncoming(&*I, NewEntry); - ArgumentPHIs.push_back(PN); - } - // The entry block was changed from OldEntry to NewEntry. - // The forward DominatorTree needs to be recalculated when the EntryBB is - // changed. In this corner-case we recalculate the entire tree. - DTU.recalculate(*NewEntry->getParent()); - } + if (!HeaderBB) + createTailRecurseLoopHeader(CI); - // If this function has self recursive calls in the tail position where some - // are marked tail and some are not, only transform one flavor or another. We - // have to choose whether we move allocas in the entry block to the new entry - // block or not, so we can't make a good choice for both. NOTE: We could do - // slightly better here in the case that the function has no entry block - // allocas. - if (TailCallsAreMarkedTail && !CI->isTailCall()) + if (RemovableCallsMustBeMarkedTail && !CI->isTailCall()) // Tail-only mode. return false; // Ok, now that we know we have a pseudo-entry block WITH all of the @@ -625,27 +695,9 @@ // accumulator recursion predicate is set up. // if (AccumulatorRecursionEliminationInitVal) { - Instruction *AccRecInstr = AccumulatorRecursionInstr; - // Start by inserting a new PHI node for the accumulator.
- pred_iterator PB = pred_begin(OldEntry), PE = pred_end(OldEntry); - PHINode *AccPN = PHINode::Create( - AccumulatorRecursionEliminationInitVal->getType(), - std::distance(PB, PE) + 1, "accumulator.tr", &OldEntry->front()); - - // Loop over all of the predecessors of the tail recursion block. For the - // real entry into the function we seed the PHI with the initial value, - // computed earlier. For any other existing branches to this block (due to - // other tail recursions eliminated) the accumulator is not modified. - // Because we haven't added the branch in the current block to OldEntry yet, - // it will not show up as a predecessor. - for (pred_iterator PI = PB; PI != PE; ++PI) { - BasicBlock *P = *PI; - if (P == &F->getEntryBlock()) - AccPN->addIncoming(AccumulatorRecursionEliminationInitVal, P); - else - AccPN->addIncoming(AccPN, P); - } + PHINode *AccPN = insertAccumulator(AccumulatorRecursionEliminationInitVal); + Instruction *AccRecInstr = AccumulatorRecursionInstr; if (AccRecInstr) { // Add an incoming argument for the current block, which is computed by // our associative and commutative accumulator instruction. @@ -664,7 +716,7 @@ // Finally, rewrite any return instructions in the program to return the PHI // node instead of the "initval" that they do currently. This loop will // actually rewrite the return value we are destroying, but that's ok. - for (BasicBlock &BBI : *F) + for (BasicBlock &BBI : F) if (ReturnInst *RI = dyn_cast(BBI.getTerminator())) RI->setOperand(0, AccPN); ++NumAccumAdded; @@ -672,21 +724,20 @@ // Now that all of the PHI nodes are in place, remove the call and // ret instructions, replacing them with an unconditional branch. - BranchInst *NewBI = BranchInst::Create(OldEntry, Ret); + BranchInst *NewBI = BranchInst::Create(HeaderBB, Ret); // Loop back. NewBI->setDebugLoc(CI->getDebugLoc()); BB->getInstList().erase(Ret); // Remove return. BB->getInstList().erase(CI); // Remove call.
- DTU.applyUpdates({{DominatorTree::Insert, BB, OldEntry}}); + DTU.applyUpdates({{DominatorTree::Insert, BB, HeaderBB}}); ++NumEliminated; return true; } -static bool foldReturnAndProcessPred( - BasicBlock *BB, ReturnInst *Ret, BasicBlock *&OldEntry, - bool &TailCallsAreMarkedTail, SmallVectorImpl &ArgumentPHIs, - bool CannotTailCallElimCallsMarkedTail, const TargetTransformInfo *TTI, - AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, DomTreeUpdater &DTU) { +bool TailRecursionEliminator::foldReturnAndProcessPred( + ReturnInst *Ret, bool CannotTailCallElimCallsMarkedTail) { + BasicBlock *BB = Ret->getParent(); + bool Change = false; // Make sure this block is a trivial return block. @@ -709,10 +760,11 @@ while (!UncondBranchPreds.empty()) { BranchInst *BI = UncondBranchPreds.pop_back_val(); BasicBlock *Pred = BI->getParent(); - if (CallInst *CI = findTRECandidate(BI, CannotTailCallElimCallsMarkedTail, TTI)){ + if (CallInst *CI = + findTRECandidate(BI, CannotTailCallElimCallsMarkedTail)) { LLVM_DEBUG(dbgs() << "FOLDING: " << *BB << "INTO UNCOND BRANCH PRED: " << *Pred); - ReturnInst *RI = FoldReturnIntoUncondBranch(Ret, BB, Pred, &DTU); + FoldReturnIntoUncondBranch(Ret, BB, Pred, &DTU); // Pred now ends in a ret. // Cleanup: if all predecessors of BB have been eliminated by // FoldReturnIntoUncondBranch, delete it.
It is important to empty it, @@ -721,8 +773,7 @@ if (!BB->hasAddressTaken() && pred_begin(BB) == pred_end(BB)) DTU.deleteBB(BB); - eliminateRecursiveTailCall(CI, RI, OldEntry, TailCallsAreMarkedTail, - ArgumentPHIs, AA, ORE, DTU); + eliminateCall(CI); // Folding already changed the IR. ++NumRetDuped; Change = true; } @@ -731,23 +782,35 @@ return Change; } -static bool processReturningBlock( - ReturnInst *Ret, BasicBlock *&OldEntry, bool &TailCallsAreMarkedTail, - SmallVectorImpl &ArgumentPHIs, - bool CannotTailCallElimCallsMarkedTail, const TargetTransformInfo *TTI, - AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, DomTreeUpdater &DTU) { - CallInst *CI = findTRECandidate(Ret, CannotTailCallElimCallsMarkedTail, TTI); +bool TailRecursionEliminator::processReturningBlock( + ReturnInst *Ret, bool CannotTailCallElimCallsMarkedTail) { + CallInst *CI = findTRECandidate(Ret, CannotTailCallElimCallsMarkedTail); if (!CI) return false; - return eliminateRecursiveTailCall(CI, Ret, OldEntry, TailCallsAreMarkedTail, - ArgumentPHIs, AA, ORE, DTU); + return eliminateCall(CI); +} + +void TailRecursionEliminator::cleanupAndFinalize() { + // If we eliminated any tail recursions, it's possible that we inserted some + // silly PHI nodes which just merge an initial value (the incoming operand) + // with themselves. Check to see if we did and clean up our mess if so. This + // occurs when a function passes an argument straight through to its tail + // call. + for (PHINode *PN : ArgumentPHIs) { + // If the PHI node simplifies to an existing value, use that value instead.
+ if (Value *PNV = SimplifyInstruction(PN, F.getParent()->getDataLayout())) { + PN->replaceAllUsesWith(PNV); + PN->eraseFromParent(); + } + } } -static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI, - AliasAnalysis *AA, - OptimizationRemarkEmitter *ORE, - DomTreeUpdater &DTU) { +bool TailRecursionEliminator::eliminate(Function &F, + const TargetTransformInfo *TTI, + AliasAnalysis *AA, + OptimizationRemarkEmitter *ORE, + DomTreeUpdater &DTU) { if (F.getFnAttribute("disable-tail-calls").getValueAsString() == "true") return false; @@ -762,15 +825,13 @@ if (F.getFunctionType()->isVarArg()) return false; - BasicBlock *OldEntry = nullptr; - bool TailCallsAreMarkedTail = false; - SmallVector ArgumentPHIs; - // If false, we cannot perform TRE on tail calls marked with the 'tail' // attribute, because doing so would cause the stack size to increase (real // TRE would deallocate variable sized allocas, TRE doesn't). bool CanTRETailMarkedCall = canTRE(F); + TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU); // Per-function TRE state. + // Change any tail recursive calls to loops. // // FIXME: The code generator produces really bad code when an 'escaping @@ -780,29 +841,14 @@ for (Function::iterator BBI = F.begin(), E = F.end(); BBI != E; /*in loop*/) { BasicBlock *BB = &*BBI++; // foldReturnAndProcessPred may delete BB.
if (ReturnInst *Ret = dyn_cast(BB->getTerminator())) { - bool Change = processReturningBlock(Ret, OldEntry, TailCallsAreMarkedTail, - ArgumentPHIs, !CanTRETailMarkedCall, - TTI, AA, ORE, DTU); + bool Change = TRE.processReturningBlock(Ret, !CanTRETailMarkedCall); if (!Change && BB->getFirstNonPHIOrDbg() == Ret) - Change = foldReturnAndProcessPred( - BB, Ret, OldEntry, TailCallsAreMarkedTail, ArgumentPHIs, - !CanTRETailMarkedCall, TTI, AA, ORE, DTU); + Change = TRE.foldReturnAndProcessPred(Ret, !CanTRETailMarkedCall); MadeChange |= Change; } } - // If we eliminated any tail recursions, it's possible that we inserted some - // silly PHI nodes which just merge an initial value (the incoming operand) - // with themselves. Check to see if we did and clean up our mess if so. This - // occurs when a function passes an argument straight through to its tail - // call. - for (PHINode *PN : ArgumentPHIs) { - // If the PHI Node is a dynamic constant, replace it with the value it is. - if (Value *PNV = SimplifyInstruction(PN, F.getParent()->getDataLayout())) { - PN->replaceAllUsesWith(PNV); - PN->eraseFromParent(); - } - } + TRE.cleanupAndFinalize(); // Simplify away trivial argument PHIs. return MadeChange; } @@ -836,7 +882,7 @@ // UpdateStrategy to Lazy if we find it profitable later. DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager); - return eliminateTailRecursion( + return TailRecursionEliminator::eliminate( F, &getAnalysis().getTTI(F), &getAnalysis().getAAResults(), &getAnalysis().getORE(), DTU); @@ -869,7 +915,7 @@ // UpdateStrategy based on some test results. It is feasible to switch the // UpdateStrategy to Lazy if we find it profitable later. DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager); - bool Changed = eliminateTailRecursion(F, &TTI, &AA, &ORE, DTU); + bool Changed = TailRecursionEliminator::eliminate(F, &TTI, &AA, &ORE, DTU); if (!Changed) return PreservedAnalyses::all();