Index: llvm/include/llvm/Transforms/IPO/IROutliner.h =================================================================== --- llvm/include/llvm/Transforms/IPO/IROutliner.h +++ llvm/include/llvm/Transforms/IPO/IROutliner.h @@ -179,8 +179,12 @@ /// for extraction. bool IgnoreGroup = false; - /// The return block for the overall function. - BasicBlock *EndBB = nullptr; + /// The return blocks for the overall function. + DenseMap EndBBs; + + /// The PHIBlocks with their corresponding return block based on the return + /// value as the key. + DenseMap PHIBlocks; /// A set containing the different GVN store sets needed. Each array contains /// a sorted list of the different values that need to be stored into output Index: llvm/lib/Transforms/IPO/IROutliner.cpp =================================================================== --- llvm/lib/Transforms/IPO/IROutliner.cpp +++ llvm/lib/Transforms/IPO/IROutliner.cpp @@ -18,6 +18,7 @@ #include "llvm/IR/Attributes.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DIBuilder.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/Mangler.h" #include "llvm/IR/PassManager.h" #include "llvm/InitializePasses.h" @@ -372,10 +373,11 @@ /// /// \param [in] Old - The function to move the basic blocks from. /// \param [in] New - The function to move the basic blocks to. +/// \param [out] NewEnds - The return blocks of the new overall function. /// \returns the first return block for the function in New. 
-static BasicBlock *moveFunctionData(Function &Old, Function &New) { +static void moveFunctionData(Function &Old, Function &New, + DenseMap &NewEnds) { Function::iterator CurrBB, NextBB, FinalBB; - BasicBlock *NewEnd = nullptr; std::vector DebugInsts; for (CurrBB = Old.begin(), FinalBB = Old.end(); CurrBB != FinalBB; CurrBB = NextBB) { @@ -383,8 +385,10 @@ CurrBB->removeFromParent(); CurrBB->insertInto(&New); Instruction *I = CurrBB->getTerminator(); - if (isa(I)) - NewEnd = &(*CurrBB); + if (ReturnInst *RI = dyn_cast(I)) { + Value *RetVal = RI->getReturnValue(); + NewEnds.insert(std::make_pair(RetVal, &(*CurrBB))); + } for (Instruction &Val : *CurrBB) { // We must handle the scoping of called functions differently than @@ -418,8 +422,7 @@ I->eraseFromParent(); } - assert(NewEnd && "No return instruction for new function?"); - return NewEnd; + assert(NewEnds.size() > 0 && "No return instruction for new function?"); } /// Find the the constants that will need to be lifted into arguments @@ -887,10 +890,16 @@ /// \param [in,out] OutputBB - The BasicBlock for the output stores for this /// region. 
static void replaceArgumentUses(OutlinableRegion &Region, - BasicBlock *OutputBB) { + DenseMap &OutputBBs, + bool FirstFunction=false) { OutlinableGroup &Group = *Region.Parent; assert(Region.ExtractedFunction && "Region has no extracted function?"); + Function *DominatingFunction = Region.ExtractedFunction; + if (FirstFunction) + DominatingFunction = Group.OutlinedFunction; + DominatorTree DT(*DominatingFunction); + for (unsigned ArgIdx = 0; ArgIdx < Region.ExtractedFunction->arg_size(); ArgIdx++) { assert(Region.ExtractedArgToAgg.find(ArgIdx) != @@ -917,11 +926,42 @@ assert(InstAsUser && "User is nullptr!"); Instruction *I = cast(InstAsUser); - I->setDebugLoc(DebugLoc()); + BasicBlock *BB = I->getParent(); + SmallVector Descendants; + DT.getDescendants(BB, Descendants); + bool EdgeAdded = false; + if (Descendants.size() == 0) { + EdgeAdded = true; + DT.insertEdge(&DominatingFunction->getEntryBlock(), BB); + DT.getDescendants(BB, Descendants); + } + + // Iterate over the following blocks, looking for return instructions, + // if we find one, find the corresponding output block for the return value + // and move our store instruction there. + for (BasicBlock *DescendBB : Descendants) { + Instruction *Term = DescendBB->getTerminator(); + if (!isa(Term)) + continue; + ReturnInst *RI = cast(Term); + Value *RetVal = RI->getReturnValue(); + DenseMap::iterator VBBIt = OutputBBs.find(RetVal); + assert(VBBIt != OutputBBs.end() && "Could not find output value!"); + + Instruction *NewI = I->clone(); + NewI->setDebugLoc(DebugLoc()); + BasicBlock *OutputBB = VBBIt->second; + OutputBB->getInstList().push_back(NewI); LLVM_DEBUG(dbgs() << "Move store for instruction " << *I << " to " << *OutputBB << "\n"); - I->moveBefore(*OutputBB, OutputBB->end()); + } + + // If we added an edge for basic blocks without a predecessor, we remove it + // here. 
+ if (EdgeAdded) + DT.deleteEdge(&DominatingFunction->getEntryBlock(), BB); + I->eraseFromParent(); LLVM_DEBUG(dbgs() << "Replacing uses of output " << *Arg << " in function " << *Region.ExtractedFunction << " with " << *AggArg @@ -995,19 +1035,32 @@ /// \param OutputBB [in] the block we are looking for a duplicate of. /// \param OutputStoreBBs [in] The existing output blocks. /// \returns an optional value with the number output block if there is a match. -Optional -findDuplicateOutputBlock(BasicBlock *OutputBB, - ArrayRef OutputStoreBBs) { +Optional findDuplicateOutputBlock( + DenseMap &OutputBBs, + std::vector> &OutputStoreBBs) { bool WrongInst = false; bool WrongSize = false; unsigned MatchingNum = 0; - for (BasicBlock *CompBB : OutputStoreBBs) { + // We compare the new set of output blocks to the other sets of output blocks. + // If they are the same number, and have identical instructions, they are + // considered to be the same. + for (DenseMap &CompBBs : OutputStoreBBs) { WrongInst = false; + WrongSize = false; + for (std::pair &VToB : CompBBs) { + DenseMap::iterator OutputBBIt = + OutputBBs.find(VToB.first); + if (OutputBBIt == OutputBBs.end()) { + WrongSize = true; + break; + } + + BasicBlock *CompBB = VToB.second; + BasicBlock *OutputBB = OutputBBIt->second; if (CompBB->size() - 1 != OutputBB->size()) { WrongSize = true; - MatchingNum++; - continue; + break; } WrongSize = false; @@ -1023,6 +1076,8 @@ NIt++; } + } + if (!WrongInst && !WrongSize) return MatchingNum; @@ -1044,11 +1099,12 @@ /// \param [in] OutputMappings - OutputMappings the mapping of values that have /// been replaced by a new output value. /// \param [in,out] OutputStoreBBs - The existing output blocks. 
-static void -alignOutputBlockWithAggFunc(OutlinableGroup &OG, OutlinableRegion &Region, - BasicBlock *OutputBB, BasicBlock *EndBB, +static void alignOutputBlockWithAggFunc( + OutlinableGroup &OG, OutlinableRegion &Region, + DenseMap OutputBBs, + DenseMap EndBBs, const DenseMap &OutputMappings, - std::vector &OutputStoreBBs) { + std::vector> &OutputStoreBBs) { DenseSet ValuesToFind(Region.GVNStores.begin(), Region.GVNStores.end()); @@ -1057,9 +1113,12 @@ // be contained in a store, we replace the uses of the value with the value // from the overall function, so that the store is storing the correct // value from the overall function. - DenseSet ExcludeBBs(OutputStoreBBs.begin(), - OutputStoreBBs.end()); - ExcludeBBs.insert(OutputBB); + DenseSet ExcludeBBs; + for (DenseMap &BBMap : OutputStoreBBs) + for (std::pair &VBPair : BBMap) + ExcludeBBs.insert(VBPair.second); + for (std::pair &VBPair : OutputBBs) + ExcludeBBs.insert(VBPair.second); std::vector ExtractedFunctionInsts = collectRelevantInstructions(*(Region.ExtractedFunction), ExcludeBBs); std::vector OverallFunctionInsts = @@ -1092,35 +1151,62 @@ // If the size of the block is 0, then there are no stores, and we do not // need to save this block. - if (OutputBB->size() == 0) { + bool AllRemoved = true; + Value *RetValueForBB; + BasicBlock *NewBB; + SmallVector ToRemove; + for (std::pair &VtoBB : OutputBBs) { + RetValueForBB = VtoBB.first; + NewBB = VtoBB.second; + + if (NewBB->size() == 0) { + NewBB->eraseFromParent(); + ToRemove.push_back(RetValueForBB); + continue; + } + + AllRemoved = false; + } + + for (Value *V : ToRemove) + OutputBBs.erase(V); + + if (AllRemoved) { Region.OutputBlockNum = -1; - OutputBB->eraseFromParent(); return; } - // Determine is there is a duplicate block. + // Determine if there is a duplicate set of blocks. 
Optional MatchingBB = - findDuplicateOutputBlock(OutputBB, OutputStoreBBs); + findDuplicateOutputBlock(OutputBBs, OutputStoreBBs); - // If there is, we remove the new output block. If it does not, - // we add it to our list of output blocks. + // If there is, we remove the new output blocks. If it does not, + // we add it to our list of sets of output blocks. if (MatchingBB.hasValue()) { LLVM_DEBUG(dbgs() << "Set output block for region in function" << Region.ExtractedFunction << " to " << MatchingBB.getValue()); Region.OutputBlockNum = MatchingBB.getValue(); - OutputBB->eraseFromParent(); + for (std::pair &VtoBB : OutputBBs) + VtoBB.second->eraseFromParent(); return; } Region.OutputBlockNum = OutputStoreBBs.size(); + OutputStoreBBs.push_back(DenseMap()); + for (std::pair &VtoBB : OutputBBs) { + RetValueForBB = VtoBB.first; + NewBB = VtoBB.second; + DenseMap::iterator VBBIt = + EndBBs.find(RetValueForBB); LLVM_DEBUG(dbgs() << "Create output block for region in" << Region.ExtractedFunction << " to " - << *OutputBB); - OutputStoreBBs.push_back(OutputBB); - BranchInst::Create(EndBB, OutputBB); + << *NewBB); + BranchInst::Create(VBBIt->second, NewBB); + OutputStoreBBs.back().insert(std::make_pair(RetValueForBB, NewBB)); + } } /// Create the switch statement for outlined function to differentiate between @@ -1132,19 +1218,25 @@ /// \param [in] OG - The group of regions to be outlined. /// \param [in] EndBB - The final block of the extracted function. /// \param [in,out] OutputStoreBBs - The existing output blocks. -void createSwitchStatement(Module &M, OutlinableGroup &OG, BasicBlock *EndBB, - ArrayRef OutputStoreBBs) { +void createSwitchStatement( + Module &M, OutlinableGroup &OG, DenseMap &EndBBs, + ArrayRef> OutputStoreBBs) { // We only need the switch statement if there is more than one store // combination. 
if (OG.OutputGVNCombinations.size() > 1) { Function *AggFunc = OG.OutlinedFunction; - // Create a final block - BasicBlock *ReturnBlock = - BasicBlock::Create(M.getContext(), "final_block", AggFunc); + // Create a final block for each different return block. + unsigned FinalBlockIdx = 0; + for (std::pair &OutputBlock : EndBBs) { + BasicBlock *EndBB = OutputBlock.second; + BasicBlock *ReturnBlock = BasicBlock::Create( + M.getContext(), + "final_block_" + Twine(static_cast(FinalBlockIdx++)), + AggFunc); Instruction *Term = EndBB->getTerminator(); Term->moveBefore(*ReturnBlock, ReturnBlock->end()); - // Put the switch statement in the old end basic block for the function with - // a fall through to the new return block + // Put the switch statement in the old end basic block for the function + // with a fall through to the new return block. LLVM_DEBUG(dbgs() << "Create switch statement in " << *AggFunc << " for " << OutputStoreBBs.size() << "\n"); SwitchInst *SwitchI = @@ -1152,28 +1244,41 @@ ReturnBlock, OutputStoreBBs.size(), EndBB); unsigned Idx = 0; - for (BasicBlock *BB : OutputStoreBBs) { - SwitchI->addCase(ConstantInt::get(Type::getInt32Ty(M.getContext()), Idx), - BB); + for (DenseMap OutputStoreBB : OutputStoreBBs) { + DenseMap::iterator OSBBIt = + OutputStoreBB.find(OutputBlock.first); + + if (OSBBIt == OutputStoreBB.end()) + continue; + BasicBlock *BB = OSBBIt->second; + SwitchI->addCase( + ConstantInt::get(Type::getInt32Ty(M.getContext()), Idx), BB); Term = BB->getTerminator(); Term->setSuccessor(0, ReturnBlock); Idx++; + } } return; } - // If there needs to be stores, move them from the output block to the end - // block to save on branching instructions. + // If there needs to be stores, move them from the output blocks to their + // corresponding ending block. 
if (OutputStoreBBs.size() == 1) { LLVM_DEBUG(dbgs() << "Move store instructions to the end block in " << *OG.OutlinedFunction << "\n"); - BasicBlock *OutputBlock = OutputStoreBBs[0]; - Instruction *Term = OutputBlock->getTerminator(); + DenseMap OutputBlocks = OutputStoreBBs[0]; + for (std::pair &VBPair : OutputBlocks) { + DenseMap::iterator EndBBIt = EndBBs.find(VBPair.first); + assert(EndBBIt != EndBBs.end() && "Could not find end block"); + BasicBlock *EndBB = EndBBIt->second; + BasicBlock *OutputBB = VBPair.second; + Instruction *Term = OutputBB->getTerminator(); Term->eraseFromParent(); Term = EndBB->getTerminator(); - moveBBContents(*OutputBlock, *EndBB); + moveBBContents(*OutputBB, *EndBB); Term->moveBefore(*EndBB, EndBB->end()); - OutputBlock->eraseFromParent(); + OutputBB->eraseFromParent(); + } } } @@ -1188,9 +1293,10 @@ /// set of stores needed for the different functions. /// \param [in,out] FuncsToRemove - Extracted functions to erase from module /// once outlining is complete. -static void fillOverallFunction(Module &M, OutlinableGroup &CurrentGroup, - std::vector &OutputStoreBBs, - std::vector &FuncsToRemove) { +static void fillOverallFunction( + Module &M, OutlinableGroup &CurrentGroup, + std::vector> &OutputStoreBBs, + std::vector &FuncsToRemove) { OutlinableRegion *CurrentOS = CurrentGroup.Regions[0]; // Move first extracted function's instructions into new function. @@ -1198,34 +1304,60 @@ << *CurrentOS->ExtractedFunction << " to instruction " << *CurrentGroup.OutlinedFunction << "\n"); - CurrentGroup.EndBB = moveFunctionData(*CurrentOS->ExtractedFunction, - *CurrentGroup.OutlinedFunction); + moveFunctionData(*CurrentOS->ExtractedFunction, + *CurrentGroup.OutlinedFunction, CurrentGroup.EndBBs); // Transfer the attributes from the function to the new function. for (Attribute A : CurrentOS->ExtractedFunction->getAttributes().getFnAttributes()) CurrentGroup.OutlinedFunction->addFnAttr(A); - // Create an output block for the first extracted function. 
+ // Create a new set of output blocks for the first extracted function. + DenseMap NewBBs; + unsigned Idx = 0; + for (std::pair &VtoBB : CurrentGroup.EndBBs) { BasicBlock *NewBB = BasicBlock::Create( - M.getContext(), Twine("output_block_") + Twine(static_cast(0)), + M.getContext(), + Twine("output_block_") + Twine(static_cast(0)) + Twine("_") + + Twine(static_cast(Idx++)), CurrentGroup.OutlinedFunction); + NewBBs.insert(std::make_pair(VtoBB.first, NewBB)); + } CurrentOS->OutputBlockNum = 0; - replaceArgumentUses(*CurrentOS, NewBB); + replaceArgumentUses(*CurrentOS, NewBBs, true); replaceConstants(*CurrentOS); - // If the new basic block has no new stores, we can erase it from the module. - // It it does, we create a branch instruction to the last basic block from the - // new one. + // If a new basic block has no new stores, we can erase it from the module. + // If it does, we create a branch instruction to the basic block to the return + // block for the function. + BasicBlock *NewBB; + Value *RetValueForBB; + OutputStoreBBs.push_back(DenseMap()); + bool AllRemoved = true; + SmallVector ToRemove; + for (std::pair &VtoBB : NewBBs) { + RetValueForBB = VtoBB.first; + NewBB = VtoBB.second; + if (NewBB->size() == 0) { CurrentOS->OutputBlockNum = -1; NewBB->eraseFromParent(); - } else { - BranchInst::Create(CurrentGroup.EndBB, NewBB); - OutputStoreBBs.push_back(NewBB); + ToRemove.push_back(RetValueForBB); + continue; + } + + AllRemoved = false; + DenseMap::iterator VBBIt = + CurrentGroup.EndBBs.find(RetValueForBB); + BasicBlock *EndBB = VBBIt->second; + BranchInst::Create(EndBB, NewBB); + OutputStoreBBs.back().insert(std::make_pair(RetValueForBB, NewBB)); } + if (AllRemoved) + CurrentOS->OutputBlockNum = -1; + // Replace the call to the extracted function with the outlined function. 
CurrentOS->Call = replaceCalledFunction(M, *CurrentOS); @@ -1239,7 +1371,7 @@ std::vector &FuncsToRemove, unsigned &OutlinedFunctionNum) { createFunction(M, CurrentGroup, OutlinedFunctionNum); - std::vector OutputStoreBBs; + std::vector> OutputStoreBBs; OutlinableRegion *CurrentOS; @@ -1250,14 +1382,21 @@ AttributeFuncs::mergeAttributesForOutlining(*CurrentGroup.OutlinedFunction, *CurrentOS->ExtractedFunction); - // Create a new BasicBlock to hold the needed store instructions. + // Create a set of BasicBlocks, one for each return block, to hold the + // needed store instructions. + DenseMap NewBBs; + unsigned BBIdx = 0; + for (std::pair &VtoBB : CurrentGroup.EndBBs) { BasicBlock *NewBB = BasicBlock::Create( - M.getContext(), "output_block_" + std::to_string(Idx), + M.getContext(), + Twine("output_block_") + Twine(static_cast(Idx)) + + Twine("_") + Twine(static_cast(BBIdx++)), CurrentGroup.OutlinedFunction); - replaceArgumentUses(*CurrentOS, NewBB); - - alignOutputBlockWithAggFunc(CurrentGroup, *CurrentOS, NewBB, - CurrentGroup.EndBB, OutputMappings, + NewBBs.insert(std::make_pair(VtoBB.first, NewBB)); + } + replaceArgumentUses(*CurrentOS, NewBBs); + alignOutputBlockWithAggFunc(CurrentGroup, *CurrentOS, NewBBs, + CurrentGroup.EndBBs, OutputMappings, OutputStoreBBs); CurrentOS->Call = replaceCalledFunction(M, *CurrentOS); @@ -1265,7 +1404,7 @@ } // Create a switch statement to handle the different output schemes. - createSwitchStatement(M, CurrentGroup, CurrentGroup.EndBB, OutputStoreBBs); + createSwitchStatement(M, CurrentGroup, CurrentGroup.EndBBs, OutputStoreBBs); OutlinedFunctionNum++; }