diff --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h --- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h +++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h @@ -36,6 +36,7 @@ class Module; class Type; class Value; +class StructType; /// A cache for the CodeExtractor analysis. The operation \ref /// CodeExtractor::extractCodeRegion is guaranteed not to invalidate this @@ -102,12 +103,13 @@ // Bits of intermediate state computed at various phases of extraction. SetVector Blocks; - unsigned NumExitBlocks = std::numeric_limits::max(); - Type *RetTy; - // Mapping from the original exit blocks, to the new blocks inside - // the function. - SmallVector OldTargets; + /// List of exit blocks, i.e. the blocks outside the code region to be + /// extracted that the region branches to. Each block appears at most once. + /// The order defines the return value of the extracted function: when + /// leaving the extracted function via the first block it returns 0, via + /// the second block it returns 1, etc. + SmallVector SwitchCases; // Suffix to use when creating extracted function (appended to the original // function name + "."). If empty, the default is to use the entry block @@ -241,26 +243,61 @@ getLifetimeMarkers(const CodeExtractorAnalysisCache &CEAC, Instruction *Addr, BasicBlock *ExitBlock) const; + /// Updates the list of SwitchCases (corresponding to exit blocks) after + /// changes of the control flow or the Blocks list. + void recomputeSwitchCases(); + + /// Return the type used for the return code of the extracted function to + /// indicate which exit block to jump to.
+ Type *getSwitchType(); + void severSplitPHINodesOfEntry(BasicBlock *&Header); - void severSplitPHINodesOfExits(const SmallPtrSetImpl &Exits); + void severSplitPHINodesOfExits(); void splitReturnBlocks(); - Function *constructFunction(const ValueSet &inputs, - const ValueSet &outputs, - BasicBlock *header, - BasicBlock *newRootNode, BasicBlock *newHeader, - Function *oldFunction, Module *M); - void moveCodeToFunction(Function *newFunction); void calculateNewCallTerminatorWeights( BasicBlock *CodeReplacer, - DenseMap &ExitWeights, + const DenseMap &ExitWeights, BranchProbabilityInfo *BPI); - CallInst *emitCallAndSwitchStatement(Function *newFunction, - BasicBlock *newHeader, - ValueSet &inputs, ValueSet &outputs); + /// Normalizes the control flow of the extracted regions, such as ensuring + /// that the extracted region does not contain a return instruction. + void normalizeCFGForExtraction(BasicBlock *&header); + + /// Generates the function declaration for the function containing the + /// extracted code. + Function *constructFunctionDeclaration(const ValueSet &inputs, + const ValueSet &outputs, + BlockFrequency EntryFreq, + const Twine &Name, + ValueSet &StructValues, + StructType *&StructTy); + + /// Generates the code for the extracted function. That is: a prolog, the + /// moved or copied code from the original function, and epilogs for each + /// exit. + void emitFunctionBody(const ValueSet &inputs, const ValueSet &outputs, + const ValueSet &StructValues, Function *newFunction, + StructType *StructArgTy, BasicBlock *header, + const ValueSet &SinkingCands); + + /// Generates a Basic Block that calls the extracted function. 
+ CallInst *emitReplacerCall(const ValueSet &inputs, const ValueSet &outputs, + const ValueSet &StructValues, + Function *newFunction, StructType *StructArgTy, + Function *oldFunction, BasicBlock *ReplIP, + BlockFrequency EntryFreq, + ArrayRef LifetimesStart, + std::vector &Reloads); + + /// Connects the basic block containing the call to the extracted function + /// into the original function's control flow. + void insertReplacerCall( + Function *oldFunction, BasicBlock *header, BasicBlock *codeReplacer, + const ValueSet &outputs, ArrayRef Reloads, + const DenseMap &ExitWeights); }; } // end namespace llvm diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp --- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -433,7 +433,6 @@ } // Now add the old exit block to the outline region. Blocks.insert(CommonExitBlock); - OldTargets.push_back(NewExitBlock); return CommonExitBlock; } @@ -743,9 +742,8 @@ /// outlined region, we split these PHIs on two: one with inputs from region /// and other with remaining incoming blocks; then first PHIs are placed in /// outlined region. 
-void CodeExtractor::severSplitPHINodesOfExits( - const SmallPtrSetImpl &Exits) { - for (BasicBlock *ExitBB : Exits) { +void CodeExtractor::severSplitPHINodesOfExits() { + for (BasicBlock *ExitBB : SwitchCases) { BasicBlock *NewBB = nullptr; for (PHINode &PN : ExitBB->phis()) { @@ -808,29 +806,18 @@ } } -/// constructFunction - make a function based on inputs and outputs, as follows: -/// f(in0, ..., inN, out0, ..., outN) -Function *CodeExtractor::constructFunction(const ValueSet &inputs, - const ValueSet &outputs, - BasicBlock *header, - BasicBlock *newRootNode, - BasicBlock *newHeader, - Function *oldFunction, - Module *M) { +Function *CodeExtractor::constructFunctionDeclaration( + const ValueSet &inputs, const ValueSet &outputs, BlockFrequency EntryFreq, + const Twine &Name, ValueSet &StructValues, StructType *&StructTy) { LLVM_DEBUG(dbgs() << "inputs: " << inputs.size() << "\n"); LLVM_DEBUG(dbgs() << "outputs: " << outputs.size() << "\n"); - // This function returns unsigned, outputs will go back by reference. - switch (NumExitBlocks) { - case 0: - case 1: RetTy = Type::getVoidTy(header->getContext()); break; - case 2: RetTy = Type::getInt1Ty(header->getContext()); break; - default: RetTy = Type::getInt16Ty(header->getContext()); break; - } + Function *oldFunction = Blocks.front()->getParent(); + Module *M = Blocks.front()->getModule(); + // Assemble the function's parameter lists. std::vector ParamTy; std::vector AggParamTy; - ValueSet StructValues; const DataLayout &DL = M->getDataLayout(); // Add the types of the input values to the function's argument list @@ -862,13 +849,12 @@ "Expeced StructValues only with AggregateArgs set"); // Concatenate scalar and aggregate params in ParamTy. 
- size_t NumScalarParams = ParamTy.size(); - StructType *StructTy = nullptr; - if (AggregateArgs && !AggParamTy.empty()) { + if (!AggParamTy.empty()) { StructTy = StructType::get(M->getContext(), AggParamTy); ParamTy.push_back(PointerType::get(StructTy, DL.getAllocaAddrSpace())); } + Type *RetTy = getSwitchType(); LLVM_DEBUG({ dbgs() << "Function type: " << *RetTy << " f("; for (Type *i : ParamTy) @@ -879,14 +865,14 @@ FunctionType *funcType = FunctionType::get( RetTy, ParamTy, AllowVarArgs && oldFunction->isVarArg()); - std::string SuffixToUse = - Suffix.empty() - ? (header->getName().empty() ? "extracted" : header->getName().str()) - : Suffix; // Create the new function - Function *newFunction = Function::Create( - funcType, GlobalValue::InternalLinkage, oldFunction->getAddressSpace(), - oldFunction->getName() + "." + SuffixToUse, M); + Function *newFunction = + Function::Create(funcType, GlobalValue::InternalLinkage, + oldFunction->getAddressSpace(), Name, M); + + // Propagate personality info to the new function if there is one. + if (oldFunction->hasPersonalityFn()) + newFunction->setPersonalityFn(oldFunction->getPersonalityFn()); // Inherit all of the target dependent attributes and white-listed // target independent attributes. @@ -1000,63 +986,57 @@ newFunction->addFnAttr(Attr); } - newFunction->getBasicBlockList().push_back(newRootNode); // Create scalar and aggregate iterators to name all of the arguments we // inserted. Function::arg_iterator ScalarAI = newFunction->arg_begin(); - Function::arg_iterator AggAI = std::next(ScalarAI, NumScalarParams); - // Rewrite all users of the inputs in the extracted region to use the - // arguments (or appropriate addressing into struct) instead. 
- for (unsigned i = 0, e = inputs.size(), aggIdx = 0; i != e; ++i) { - Value *RewriteVal; - if (AggregateArgs && StructValues.contains(inputs[i])) { - Value *Idx[2]; - Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext())); - Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), aggIdx); - Instruction *TI = newFunction->begin()->getTerminator(); - GetElementPtrInst *GEP = GetElementPtrInst::Create( - StructTy, &*AggAI, Idx, "gep_" + inputs[i]->getName(), TI); - RewriteVal = new LoadInst(StructTy->getElementType(aggIdx), GEP, - "loadgep_" + inputs[i]->getName(), TI); - ++aggIdx; - } else - RewriteVal = &*ScalarAI++; + // Set names and attributes for input and output arguments. + ScalarAI = newFunction->arg_begin(); + for (Value *input : inputs) { + if (StructValues.contains(input)) + continue; - std::vector Users(inputs[i]->user_begin(), inputs[i]->user_end()); - for (User *use : Users) - if (Instruction *inst = dyn_cast(use)) - if (Blocks.count(inst->getParent())) - inst->replaceUsesOfWith(inputs[i], RewriteVal); + ScalarAI->setName(input->getName()); + if (input->isSwiftError()) + newFunction->addParamAttr(ScalarAI - newFunction->arg_begin(), + Attribute::SwiftError); + ++ScalarAI; } + for (Value *output : outputs) { + if (StructValues.contains(output)) + continue; - // Set names for input and output arguments. - if (NumScalarParams) { - ScalarAI = newFunction->arg_begin(); - for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++ScalarAI) - if (!StructValues.contains(inputs[i])) - ScalarAI->setName(inputs[i]->getName()); - for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++ScalarAI) - if (!StructValues.contains(outputs[i])) - ScalarAI->setName(outputs[i]->getName() + ".out"); + ScalarAI->setName(output->getName() + ".out"); + ++ScalarAI; } - // Rewrite branches to basic blocks outside of the loop to new dummy blocks - // within the new function. 
This must be done before we lose track of which - // blocks were originally in the code region. - std::vector Users(header->user_begin(), header->user_end()); - for (auto &U : Users) - // The BasicBlock which contains the branch is not in the region - // modify the branch target to a new block - if (Instruction *I = dyn_cast(U)) - if (I->isTerminator() && I->getFunction() == oldFunction && - !Blocks.count(I->getParent())) - I->replaceUsesOfWith(header, newHeader); + // Update the entry count of the function. + if (BFI) { + auto Count = BFI->getProfileCountFromFreq(EntryFreq.getFrequency()); + if (Count.has_value()) + newFunction->setEntryCount( + ProfileCount(Count.value(), Function::PCT_Real)); // FIXME + } return newFunction; } +static void applyFirstDebugLoc(Function *oldFunction, + ArrayRef Blocks, + Instruction *BranchI) { + if (oldFunction->getSubprogram()) { + any_of(Blocks, [&BranchI](const BasicBlock *BB) { + return any_of(*BB, [&BranchI](const Instruction &I) { + if (!I.getDebugLoc()) + return false; + BranchI->setDebugLoc(I.getDebugLoc()); + return true; + }); + }); + } +} + /// Erase lifetime.start markers which reference inputs to the extraction /// region, and insert the referenced memory into \p LifetimesStart. /// @@ -1138,309 +1118,12 @@ } } -/// emitCallAndSwitchStatement - This method sets up the caller side by adding -/// the call instruction, splitting any PHI nodes in the header block as -/// necessary. 
-CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, - BasicBlock *codeReplacer, - ValueSet &inputs, - ValueSet &outputs) { - // Emit a call to the new function, passing in: *pointer to struct (if - // aggregating parameters), or plan inputs and allocated memory for outputs - std::vector params, ReloadOutputs, Reloads; - ValueSet StructValues; - - Module *M = newFunction->getParent(); - LLVMContext &Context = M->getContext(); - const DataLayout &DL = M->getDataLayout(); - CallInst *call = nullptr; - - // Add inputs as params, or to be filled into the struct - unsigned ScalarInputArgNo = 0; - SmallVector SwiftErrorArgs; - for (Value *input : inputs) { - if (AggregateArgs && !ExcludeArgsFromAggregate.contains(input)) - StructValues.insert(input); - else { - params.push_back(input); - if (input->isSwiftError()) - SwiftErrorArgs.push_back(ScalarInputArgNo); - } - ++ScalarInputArgNo; - } - - // Create allocas for the outputs - unsigned ScalarOutputArgNo = 0; - for (Value *output : outputs) { - if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) { - StructValues.insert(output); - } else { - AllocaInst *alloca = - new AllocaInst(output->getType(), DL.getAllocaAddrSpace(), - nullptr, output->getName() + ".loc", - &codeReplacer->getParent()->front().front()); - ReloadOutputs.push_back(alloca); - params.push_back(alloca); - ++ScalarOutputArgNo; - } - } - - StructType *StructArgTy = nullptr; - AllocaInst *Struct = nullptr; - unsigned NumAggregatedInputs = 0; - if (AggregateArgs && !StructValues.empty()) { - std::vector ArgTypes; - for (Value *V : StructValues) - ArgTypes.push_back(V->getType()); - - // Allocate a struct at the beginning of this function - StructArgTy = StructType::get(newFunction->getContext(), ArgTypes); - Struct = new AllocaInst( - StructArgTy, DL.getAllocaAddrSpace(), nullptr, "structArg", - AllocationBlock ? 
&*AllocationBlock->getFirstInsertionPt() - : &codeReplacer->getParent()->front().front()); - params.push_back(Struct); - - // Store aggregated inputs in the struct. - for (unsigned i = 0, e = StructValues.size(); i != e; ++i) { - if (inputs.contains(StructValues[i])) { - Value *Idx[2]; - Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); - Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i); - GetElementPtrInst *GEP = GetElementPtrInst::Create( - StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName()); - codeReplacer->getInstList().push_back(GEP); - new StoreInst(StructValues[i], GEP, codeReplacer); - NumAggregatedInputs++; - } - } - } - - // Emit the call to the function - call = CallInst::Create(newFunction, params, - NumExitBlocks > 1 ? "targetBlock" : ""); - // Add debug location to the new call, if the original function has debug - // info. In that case, the terminator of the entry block of the extracted - // function contains the first debug location of the extracted function, - // set in extractCodeRegion. - if (codeReplacer->getParent()->getSubprogram()) { - if (auto DL = newFunction->getEntryBlock().getTerminator()->getDebugLoc()) - call->setDebugLoc(DL); - } - codeReplacer->getInstList().push_back(call); - - // Set swifterror parameter attributes. - for (unsigned SwiftErrArgNo : SwiftErrorArgs) { - call->addParamAttr(SwiftErrArgNo, Attribute::SwiftError); - newFunction->addParamAttr(SwiftErrArgNo, Attribute::SwiftError); - } - - // Reload the outputs passed in by reference, use the struct if output is in - // the aggregate or reload from the scalar argument. 
- for (unsigned i = 0, e = outputs.size(), scalarIdx = 0, - aggIdx = NumAggregatedInputs; - i != e; ++i) { - Value *Output = nullptr; - if (AggregateArgs && StructValues.contains(outputs[i])) { - Value *Idx[2]; - Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); - Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx); - GetElementPtrInst *GEP = GetElementPtrInst::Create( - StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName()); - codeReplacer->getInstList().push_back(GEP); - Output = GEP; - ++aggIdx; - } else { - Output = ReloadOutputs[scalarIdx]; - ++scalarIdx; - } - LoadInst *load = new LoadInst(outputs[i]->getType(), Output, - outputs[i]->getName() + ".reload", - codeReplacer); - Reloads.push_back(load); - std::vector Users(outputs[i]->user_begin(), outputs[i]->user_end()); - for (User *U : Users) { - Instruction *inst = cast(U); - if (!Blocks.count(inst->getParent())) - inst->replaceUsesOfWith(outputs[i], load); - } - } - - // Now we can emit a switch statement using the call as a value. - SwitchInst *TheSwitch = - SwitchInst::Create(Constant::getNullValue(Type::getInt16Ty(Context)), - codeReplacer, 0, codeReplacer); - - // Since there may be multiple exits from the original region, make the new - // function return an unsigned, switch on that number. This loop iterates - // over all of the blocks in the extracted region, updating any terminator - // instructions in the to-be-extracted region that branch to blocks that are - // not in the region to be extracted. - std::map ExitBlockMap; - - // Iterate over the previously collected targets, and create new blocks inside - // the function to branch to. - unsigned switchVal = 0; - for (BasicBlock *OldTarget : OldTargets) { - if (Blocks.count(OldTarget)) - continue; - BasicBlock *&NewTarget = ExitBlockMap[OldTarget]; - if (NewTarget) - continue; - - // If we don't already have an exit stub for this non-extracted - // destination, create one now! 
- NewTarget = BasicBlock::Create(Context, - OldTarget->getName() + ".exitStub", - newFunction); - unsigned SuccNum = switchVal++; - - Value *brVal = nullptr; - assert(NumExitBlocks < 0xffff && "too many exit blocks for switch"); - switch (NumExitBlocks) { - case 0: - case 1: break; // No value needed. - case 2: // Conditional branch, return a bool - brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum); - break; - default: - brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum); - break; - } - - ReturnInst::Create(Context, brVal, NewTarget); - - // Update the switch instruction. - TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context), - SuccNum), - OldTarget); - } - - for (BasicBlock *Block : Blocks) { - Instruction *TI = Block->getTerminator(); - for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) { - if (Blocks.count(TI->getSuccessor(i))) - continue; - BasicBlock *OldTarget = TI->getSuccessor(i); - // add a new basic block which returns the appropriate value - BasicBlock *NewTarget = ExitBlockMap[OldTarget]; - assert(NewTarget && "Unknown target block!"); - - // rewrite the original branch instruction with this new target - TI->setSuccessor(i, NewTarget); - } - } - - // Store the arguments right after the definition of output value. - // This should be proceeded after creating exit stubs to be ensure that invoke - // result restore will be placed in the outlined function. - Function::arg_iterator ScalarOutputArgBegin = newFunction->arg_begin(); - std::advance(ScalarOutputArgBegin, ScalarInputArgNo); - Function::arg_iterator AggOutputArgBegin = newFunction->arg_begin(); - std::advance(AggOutputArgBegin, ScalarInputArgNo + ScalarOutputArgNo); - - for (unsigned i = 0, e = outputs.size(), aggIdx = NumAggregatedInputs; i != e; - ++i) { - auto *OutI = dyn_cast(outputs[i]); - if (!OutI) - continue; - - // Find proper insertion point. 
- BasicBlock::iterator InsertPt; - // In case OutI is an invoke, we insert the store at the beginning in the - // 'normal destination' BB. Otherwise we insert the store right after OutI. - if (auto *InvokeI = dyn_cast(OutI)) - InsertPt = InvokeI->getNormalDest()->getFirstInsertionPt(); - else if (auto *Phi = dyn_cast(OutI)) - InsertPt = Phi->getParent()->getFirstInsertionPt(); - else - InsertPt = std::next(OutI->getIterator()); - - Instruction *InsertBefore = &*InsertPt; - assert((InsertBefore->getFunction() == newFunction || - Blocks.count(InsertBefore->getParent())) && - "InsertPt should be in new function"); - if (AggregateArgs && StructValues.contains(outputs[i])) { - assert(AggOutputArgBegin != newFunction->arg_end() && - "Number of aggregate output arguments should match " - "the number of defined values"); - Value *Idx[2]; - Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); - Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx); - GetElementPtrInst *GEP = GetElementPtrInst::Create( - StructArgTy, &*AggOutputArgBegin, Idx, "gep_" + outputs[i]->getName(), - InsertBefore); - new StoreInst(outputs[i], GEP, InsertBefore); - ++aggIdx; - // Since there should be only one struct argument aggregating - // all the output values, we shouldn't increment AggOutputArgBegin, which - // always points to the struct argument, in this case. - } else { - assert(ScalarOutputArgBegin != newFunction->arg_end() && - "Number of scalar output arguments should match " - "the number of defined values"); - new StoreInst(outputs[i], &*ScalarOutputArgBegin, InsertBefore); - ++ScalarOutputArgBegin; - } - } - - // Now that we've done the deed, simplify the switch instruction. 
- Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType(); - switch (NumExitBlocks) { - case 0: - // There are no successors (the block containing the switch itself), which - // means that previously this was the last part of the function, and hence - // this should be rewritten as a `ret' - - // Check if the function should return a value - if (OldFnRetTy->isVoidTy()) { - ReturnInst::Create(Context, nullptr, TheSwitch); // Return void - } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) { - // return what we have - ReturnInst::Create(Context, TheSwitch->getCondition(), TheSwitch); - } else { - // Otherwise we must have code extracted an unwind or something, just - // return whatever we want. - ReturnInst::Create(Context, - Constant::getNullValue(OldFnRetTy), TheSwitch); - } - - TheSwitch->eraseFromParent(); - break; - case 1: - // Only a single destination, change the switch into an unconditional - // branch. - BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch); - TheSwitch->eraseFromParent(); - break; - case 2: - BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getSuccessor(2), - call, TheSwitch); - TheSwitch->eraseFromParent(); - break; - default: - // Otherwise, make the default destination of the switch instruction be one - // of the other successors. - TheSwitch->setCondition(call); - TheSwitch->setDefaultDest(TheSwitch->getSuccessor(NumExitBlocks)); - // Remove redundant case - TheSwitch->removeCase(SwitchInst::CaseIt(TheSwitch, NumExitBlocks-1)); - break; - } - - // Insert lifetime markers around the reloads of any output values. The - // allocas output values are stored in are only in-use in the codeRepl block. 
- insertLifetimeMarkersSurroundingCall(M, ReloadOutputs, ReloadOutputs, call); - - return call; -} - void CodeExtractor::moveCodeToFunction(Function *newFunction) { - Function *oldFunc = (*Blocks.begin())->getParent(); + Function *oldFunc = Blocks.front()->getParent(); Function::BasicBlockListType &oldBlocks = oldFunc->getBasicBlockList(); Function::BasicBlockListType &newBlocks = newFunction->getBasicBlockList(); - auto newFuncIt = newFunction->front().getIterator(); + auto newFuncIt = newFunction->begin(); for (BasicBlock *Block : Blocks) { // Delete the basic block from the old function, and the list of blocks oldBlocks.remove(Block); @@ -1456,7 +1139,7 @@ void CodeExtractor::calculateNewCallTerminatorWeights( BasicBlock *CodeReplacer, - DenseMap &ExitWeights, + const DenseMap &ExitWeights, BranchProbabilityInfo *BPI) { using Distribution = BlockFrequencyInfoImplBase::Distribution; using BlockNode = BlockFrequencyInfoImplBase::BlockNode; @@ -1474,7 +1157,7 @@ // Add each of the frequencies of the successors. for (unsigned i = 0, e = TI->getNumSuccessors(); i < e; ++i) { BlockNode ExitNode(i); - uint64_t ExitFreq = ExitWeights[TI->getSuccessor(i)].getFrequency(); + uint64_t ExitFreq = ExitWeights.lookup(TI->getSuccessor(i)).getFrequency(); if (ExitFreq != 0) BranchDist.addExit(ExitNode, ExitFreq); else @@ -1641,18 +1324,7 @@ BasicBlock *header = *Blocks.begin(); Function *oldFunction = header->getParent(); - // Calculate the entry frequency of the new function before we change the root - // block. - BlockFrequency EntryFreq; - if (BFI) { - assert(BPI && "Both BPI and BFI are required to preserve profile info"); - for (BasicBlock *Pred : predecessors(header)) { - if (Blocks.count(Pred)) - continue; - EntryFreq += - BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header); - } - } + normalizeCFGForExtraction(header); // Remove @llvm.assume calls that will be moved to the new function from the // old function's assumption cache. 
@@ -1666,140 +1338,270 @@ } } - // If we have any return instructions in the region, split those blocks so - // that the return is not in the region. - splitReturnBlocks(); + ValueSet SinkingCands, HoistingCands; + BasicBlock *CommonExit = nullptr; + findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit); + assert(HoistingCands.empty() || CommonExit); - // Calculate the exit blocks for the extracted region and the total exit - // weights for each of those blocks. - DenseMap ExitWeights; - SmallPtrSet ExitBlocks; - for (BasicBlock *Block : Blocks) { - for (BasicBlock *Succ : successors(Block)) { - if (!Blocks.count(Succ)) { - // Update the branch weight for this successor. - if (BFI) { - BlockFrequency &BF = ExitWeights[Succ]; - BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, Succ); - } - ExitBlocks.insert(Succ); - } - } - } - NumExitBlocks = ExitBlocks.size(); + // Find inputs to, outputs from the code region. + findInputsOutputs(inputs, outputs, SinkingCands); - for (BasicBlock *Block : Blocks) { - Instruction *TI = Block->getTerminator(); - for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) { - if (Blocks.count(TI->getSuccessor(i))) + // Collect objects which are inputs to the extraction region and also + // referenced by lifetime start markers within it. The effects of these + // markers must be replicated in the calling function to prevent the stack + // coloring pass from merging slots which store input objects. + ValueSet LifetimesStart; + eraseLifetimeMarkersOnInputs(Blocks, SinkingCands, LifetimesStart); + + if (!HoistingCands.empty()) { + auto *HoistToBlock = findOrCreateBlockForHoisting(CommonExit); + Instruction *TI = HoistToBlock->getTerminator(); + for (auto *II : HoistingCands) + cast(II)->moveBefore(TI); + recomputeSwitchCases(); + } + + // CFG/ExitBlocks must not change hereafter + + // Calculate the entry frequency of the new function before we change the root + // block. 
+ BlockFrequency EntryFreq; + DenseMap ExitWeights; + if (BFI) { + assert(BPI && "Both BPI and BFI are required to preserve profile info"); + for (BasicBlock *Pred : predecessors(header)) { + if (Blocks.count(Pred)) continue; - BasicBlock *OldTarget = TI->getSuccessor(i); - OldTargets.push_back(OldTarget); + EntryFreq += + BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header); + } + + for (BasicBlock *Succ : SwitchCases) { + for (BasicBlock *Block : predecessors(Succ)) { + if (!Blocks.count(Block)) + continue; + + // Update the branch weight for this successor. + BlockFrequency &BF = ExitWeights[Succ]; + BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, Succ); + } } } + // Determine position for the replacement code. Do so before header is moved + // to the new function. + BasicBlock *ReplIP = header; + while (ReplIP && Blocks.count(ReplIP)) + ReplIP = ReplIP->getNextNode(); + + // Construct new function based on inputs/outputs & add allocas for all defs. + std::string SuffixToUse = + Suffix.empty() + ? (header->getName().empty() ? "extracted" : header->getName().str()) + : Suffix; + + ValueSet StructValues; + StructType *StructTy = nullptr; // Null unless args are aggregated. + Function *newFunction = constructFunctionDeclaration( + inputs, outputs, EntryFreq, oldFunction->getName() + "." + SuffixToUse, + StructValues, StructTy); + + emitFunctionBody(inputs, outputs, StructValues, newFunction, StructTy, header, + SinkingCands); + + std::vector Reloads; + CallInst *TheCall = emitReplacerCall( + inputs, outputs, StructValues, newFunction, StructTy, oldFunction, ReplIP, + EntryFreq, LifetimesStart.getArrayRef(), Reloads); + + insertReplacerCall(oldFunction, header, TheCall->getParent(), outputs, + Reloads, ExitWeights); + + fixupDebugInfoPostExtraction(*oldFunction, *newFunction, *TheCall); + + // Mark the new function `noreturn` if applicable. Terminators which resume + // exception propagation are treated as returning instructions.
This is to + // avoid inserting traps after calls to outlined functions which unwind. + bool doesNotReturn = none_of(*newFunction, [](const BasicBlock &BB) { + const Instruction *Term = BB.getTerminator(); + return isa(Term) || isa(Term); + }); + if (doesNotReturn) + newFunction->setDoesNotReturn(); + + LLVM_DEBUG(if (verifyFunction(*newFunction, &errs())) { + newFunction->dump(); + report_fatal_error("verification of newFunction failed!"); + }); + LLVM_DEBUG(if (verifyFunction(*oldFunction)) + report_fatal_error("verification of oldFunction failed!")); + LLVM_DEBUG(if (AC && verifyAssumptionCache(*oldFunction, *newFunction, AC)) + report_fatal_error("Stale Asumption cache for old Function!")); + return newFunction; +} + +void CodeExtractor::normalizeCFGForExtraction(BasicBlock *&header) { + // If we have any return instructions in the region, split those blocks so + // that the return is not in the region. + splitReturnBlocks(); + // If we have to split PHI nodes of the entry or exit blocks, do so now. severSplitPHINodesOfEntry(header); - severSplitPHINodesOfExits(ExitBlocks); - // This takes place of the original loop - BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(), - "codeRepl", oldFunction, - header); + // If a PHI in an exit block has multiple incoming values from the outlined + // region, create a new PHI for those values within the region such that only + // the PHI itself becomes an output value, not each of its incoming values + // individually. + recomputeSwitchCases(); + severSplitPHINodesOfExits(); +} - // The new function needs a root node because other nodes can branch to the - // head of the region, but the entry node of a function cannot have preds. - BasicBlock *newFuncRoot = BasicBlock::Create(header->getContext(), - "newFuncRoot"); - auto *BranchI = BranchInst::Create(header); - // If the original function has debug info, we have to add a debug location - // to the new branch instruction from the artificial entry block.
- // We use the debug location of the first instruction in the extracted - // blocks, as there is no other equivalent line in the source code. - if (oldFunction->getSubprogram()) { - any_of(Blocks, [&BranchI](const BasicBlock *BB) { - return any_of(*BB, [&BranchI](const Instruction &I) { - if (!I.getDebugLoc()) - return false; - BranchI->setDebugLoc(I.getDebugLoc()); - return true; - }); - }); +void CodeExtractor::recomputeSwitchCases() { + SwitchCases.clear(); + + SmallPtrSet ExitBlocks; + for (BasicBlock *Block : Blocks) { + for (BasicBlock *Succ : successors(Block)) { + if (Blocks.count(Succ)) + continue; + + bool IsNew = ExitBlocks.insert(Succ).second; + if (IsNew) + SwitchCases.push_back(Succ); + } } - newFuncRoot->getInstList().push_back(BranchI); +} - ValueSet SinkingCands, HoistingCands; - BasicBlock *CommonExit = nullptr; - findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit); - assert(HoistingCands.empty() || CommonExit); +Type *CodeExtractor::getSwitchType() { + LLVMContext &Context = Blocks.front()->getContext(); - // Find inputs to, outputs from the code region. - findInputsOutputs(inputs, outputs, SinkingCands); + assert(SwitchCases.size() < 0xffff && "too many exit blocks for switch"); + switch (SwitchCases.size()) { + case 0: + case 1: + return Type::getVoidTy(Context); + case 2: + // Conditional branch, return a bool + return Type::getInt1Ty(Context); + default: + return Type::getInt16Ty(Context); + } +} + +void CodeExtractor::emitFunctionBody( + const ValueSet &inputs, const ValueSet &outputs, + const ValueSet &StructValues, Function *newFunction, + StructType *StructArgTy, BasicBlock *header, const ValueSet &SinkingCands) { + Function *oldFunction = header->getParent(); + LLVMContext &Context = oldFunction->getContext(); + + // The new function needs a root node because other nodes can branch to the + // head of the region, but the entry node of a function cannot have preds. 
+ BasicBlock *newFuncRoot = + BasicBlock::Create(Context, "newFuncRoot", newFunction); // Now sink all instructions which only have non-phi uses inside the region. // Group the allocas at the start of the block, so that any bitcast uses of // the allocas are well-defined. - AllocaInst *FirstSunkAlloca = nullptr; for (auto *II : SinkingCands) { - if (auto *AI = dyn_cast(II)) { - AI->moveBefore(*newFuncRoot, newFuncRoot->getFirstInsertionPt()); - if (!FirstSunkAlloca) - FirstSunkAlloca = AI; + if (!isa(II)) { + cast(II)->moveBefore(*newFuncRoot, + newFuncRoot->getFirstInsertionPt()); } } - assert((SinkingCands.empty() || FirstSunkAlloca) && - "Did not expect a sink candidate without any allocas"); for (auto *II : SinkingCands) { - if (!isa(II)) { - cast(II)->moveAfter(FirstSunkAlloca); + if (auto *AI = dyn_cast(II)) { + AI->moveBefore(*newFuncRoot, newFuncRoot->getFirstInsertionPt()); } } - if (!HoistingCands.empty()) { - auto *HoistToBlock = findOrCreateBlockForHoisting(CommonExit); - Instruction *TI = HoistToBlock->getTerminator(); - for (auto *II : HoistingCands) - cast(II)->moveBefore(TI); + Function::arg_iterator ScalarAI = newFunction->arg_begin(); + Argument *AggArg = StructValues.empty() + ? nullptr + : newFunction->getArg(newFunction->arg_size() - 1); + + // Rewrite all users of the inputs in the extracted region to use the + // arguments (or appropriate addressing into struct) instead. 
+ SmallVector NewValues; + for (unsigned i = 0, e = inputs.size(), aggIdx = 0; i != e; ++i) { + Value *RewriteVal; + if (StructValues.contains(inputs[i])) { + Value *Idx[2]; + Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext())); + Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), aggIdx); + GetElementPtrInst *GEP = GetElementPtrInst::Create( + StructArgTy, AggArg, Idx, "gep_" + inputs[i]->getName(), newFuncRoot); + RewriteVal = new LoadInst(StructArgTy->getElementType(aggIdx), GEP, + "loadgep_" + inputs[i]->getName(), newFuncRoot); + ++aggIdx; + } else + RewriteVal = &*ScalarAI++; + + NewValues.push_back(RewriteVal); } - // Collect objects which are inputs to the extraction region and also - // referenced by lifetime start markers within it. The effects of these - // markers must be replicated in the calling function to prevent the stack - // coloring pass from merging slots which store input objects. - ValueSet LifetimesStart; - eraseLifetimeMarkersOnInputs(Blocks, SinkingCands, LifetimesStart); + moveCodeToFunction(newFunction); - // Construct new function based on inputs/outputs & add allocas for all defs. - Function *newFunction = - constructFunction(inputs, outputs, header, newFuncRoot, codeReplacer, - oldFunction, oldFunction->getParent()); + for (unsigned i = 0, e = inputs.size(); i != e; ++i) { + Value *RewriteVal = NewValues[i]; - // Update the entry count of the function. 
- if (BFI) { - auto Count = BFI->getProfileCountFromFreq(EntryFreq.getFrequency()); - if (Count) - newFunction->setEntryCount( - ProfileCount(Count.value(), Function::PCT_Real)); // FIXME - BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency()); + std::vector Users(inputs[i]->user_begin(), inputs[i]->user_end()); + for (User *use : Users) + if (Instruction *inst = dyn_cast(use)) + if (Blocks.count(inst->getParent())) + inst->replaceUsesOfWith(inputs[i], RewriteVal); } - CallInst *TheCall = - emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs); + // Since there may be multiple exits from the original region, make the new + // function return an unsigned, switch on that number. This loop iterates + // over all of the blocks in the extracted region, updating any terminator + // instructions in the to-be-extracted region that branch to blocks that are + // not in the region to be extracted. + std::map ExitBlockMap; - moveCodeToFunction(newFunction); + // Iterate over the previously collected targets, and create new blocks inside + // the function to branch to. + for (auto P : enumerate(SwitchCases)) { + BasicBlock *OldTarget = P.value(); + size_t SuccNum = P.index(); - // Replicate the effects of any lifetime start/end markers which referenced - // input objects in the extraction region by placing markers around the call. - insertLifetimeMarkersSurroundingCall( - oldFunction->getParent(), LifetimesStart.getArrayRef(), {}, TheCall); + BasicBlock *NewTarget = BasicBlock::Create( + Context, OldTarget->getName() + ".exitStub", newFunction); + ExitBlockMap[OldTarget] = NewTarget; - // Propagate personality info to the new function if there is one. 
- if (oldFunction->hasPersonalityFn()) - newFunction->setPersonalityFn(oldFunction->getPersonalityFn()); + Value *brVal = nullptr; + Type *RetTy = getSwitchType(); + assert(SwitchCases.size() < 0xffff && "too many exit blocks for switch"); + switch (SwitchCases.size()) { + case 0: + case 1: + // No value needed. + break; + case 2: // Conditional branch, return a bool + brVal = ConstantInt::get(RetTy, !SuccNum); + break; + default: + brVal = ConstantInt::get(RetTy, SuccNum); + break; + } - // Update the branch weights for the exit block. - if (BFI && NumExitBlocks > 1) - calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI); + ReturnInst::Create(Context, brVal, NewTarget); + } + + for (BasicBlock *Block : Blocks) { + Instruction *TI = Block->getTerminator(); + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) { + if (Blocks.count(TI->getSuccessor(i))) + continue; + BasicBlock *OldTarget = TI->getSuccessor(i); + // add a new basic block which returns the appropriate value + BasicBlock *NewTarget = ExitBlockMap[OldTarget]; + assert(NewTarget && "Unknown target block!"); + + // rewrite the original branch instruction with this new target + TI->setSuccessor(i, NewTarget); + } + } // Loop over all of the PHI nodes in the header and exit blocks, and change // any references to the old incoming edge to be the new incoming edge. @@ -1810,7 +1612,274 @@ PN->setIncomingBlock(i, newFuncRoot); } - for (BasicBlock *ExitBB : ExitBlocks) + // Connect newFunction entry block to new header. + BranchInst *BranchI = BranchInst::Create(header, newFuncRoot); + applyFirstDebugLoc(oldFunction, Blocks.getArrayRef(), BranchI); + + // Store the arguments right after the definition of output value. + // This should be proceeded after creating exit stubs to be ensure that invoke + // result restore will be placed in the outlined function. 
+ ScalarAI = newFunction->arg_begin(); + unsigned AggIdx = 0; + + for (Value *Input : inputs) { + if (StructValues.contains(Input)) + ++AggIdx; + else + ++ScalarAI; + } + + for (Value *Output : outputs) { + // Find proper insertion point. + // In case Output is an invoke, we insert the store at the beginning in the + // 'normal destination' BB. Otherwise we insert the store right after + // Output. + Instruction *InsertBefore = nullptr; + if (auto *InvokeI = dyn_cast(Output)) + InsertBefore = &*InvokeI->getNormalDest()->getFirstInsertionPt(); + else if (auto *Phi = dyn_cast(Output)) + InsertBefore = &*Phi->getParent()->getFirstInsertionPt(); + else if (auto *OutI = dyn_cast(Output)) + InsertBefore = &*std::next(OutI->getIterator()); + + assert((!InsertBefore || InsertBefore->getFunction() == newFunction || + Blocks.count(InsertBefore->getParent())) && + "InsertPt should be in new function"); + + if (StructValues.contains(Output)) { + if (InsertBefore) { + assert(AggArg && "Number of aggregate output arguments should match " + "the number of defined values"); + Value *Idx[2]; + Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); + Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), AggIdx); + GetElementPtrInst *GEP = GetElementPtrInst::Create( + StructArgTy, AggArg, Idx, "gep_" + Output->getName(), InsertBefore); + new StoreInst(Output, GEP, InsertBefore); + } + ++AggIdx; + } else { + if (InsertBefore) { + assert(ScalarAI != newFunction->arg_end() && + "Number of scalar output arguments should match " + "the number of defined values"); + new StoreInst(Output, &*ScalarAI, InsertBefore); + } + ++ScalarAI; + } + } +} + +CallInst *CodeExtractor::emitReplacerCall( + const ValueSet &inputs, const ValueSet &outputs, + const ValueSet &StructValues, Function *newFunction, + StructType *StructArgTy, Function *oldFunction, BasicBlock *ReplIP, + BlockFrequency EntryFreq, ArrayRef LifetimesStart, + std::vector &Reloads) { + LLVMContext &Context = 
oldFunction->getContext(); + Module *M = oldFunction->getParent(); + const DataLayout &DL = M->getDataLayout(); + + // This takes place of the original loop + BasicBlock *codeReplacer = + BasicBlock::Create(Context, "codeRepl", oldFunction, ReplIP); + BasicBlock *AllocaBlock = + AllocationBlock ? AllocationBlock : &oldFunction->getEntryBlock(); + + // Update the entry count of the function. + if (BFI) + BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency()); + + std::vector params; + + // Add inputs as params, or to be filled into the struct + for (Value *input : inputs) { + if (StructValues.contains(input)) + continue; + + params.push_back(input); + } + + // Create allocas for the outputs + std::vector ReloadOutputs; + for (Value *output : outputs) { + if (StructValues.contains(output)) + continue; + + AllocaInst *alloca = new AllocaInst( + output->getType(), DL.getAllocaAddrSpace(), nullptr, + output->getName() + ".loc", &*AllocaBlock->getFirstInsertionPt()); + params.push_back(alloca); + ReloadOutputs.push_back(alloca); + } + + AllocaInst *Struct = nullptr; + if (!StructValues.empty()) { + Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr, + "structArg", &*AllocaBlock->getFirstInsertionPt()); + params.push_back(Struct); + + unsigned AggIdx = 0; + for (Value *input : inputs) { + if (!StructValues.contains(input)) + continue; + + Value *Idx[2]; + Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); + Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), AggIdx); + GetElementPtrInst *GEP = GetElementPtrInst::Create( + StructArgTy, Struct, Idx, "gep_" + input->getName()); + codeReplacer->getInstList().push_back(GEP); + new StoreInst(input, GEP, codeReplacer); + + ++AggIdx; + } + } + + // Emit the call to the function + CallInst *call = CallInst::Create(newFunction, params, + SwitchCases.size() > 1 ? "targetBlock" : "", + codeReplacer); + + // Set swifterror parameter attributes. 
+ unsigned ParamIdx = 0; + unsigned AggIdx = 0; + for (auto input : inputs) { + if (StructValues.contains(input)) { + ++AggIdx; + } else { + if (input->isSwiftError()) + call->addParamAttr(ParamIdx, Attribute::SwiftError); + ++ParamIdx; + } + } + + // Add debug location to the new call, if the original function has debug + // info. In that case, the terminator of the entry block of the extracted + // function contains the first debug location of the extracted function, + // set in extractCodeRegion. + if (codeReplacer->getParent()->getSubprogram()) { + if (auto DL = newFunction->getEntryBlock().getTerminator()->getDebugLoc()) + call->setDebugLoc(DL); + } + + // Reload the outputs passed in by reference, use the struct if output is in + // the aggregate or reload from the scalar argument. + for (unsigned i = 0, e = outputs.size(), scalarIdx = 0; i != e; ++i) { + Value *Output = nullptr; + if (StructValues.contains(outputs[i])) { + Value *Idx[2]; + Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); + Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), AggIdx); + GetElementPtrInst *GEP = GetElementPtrInst::Create( + StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName()); + codeReplacer->getInstList().push_back(GEP); + Output = GEP; + ++AggIdx; + } else { + Output = ReloadOutputs[scalarIdx]; + ++scalarIdx; + } + LoadInst *load = + new LoadInst(outputs[i]->getType(), Output, + outputs[i]->getName() + ".reload", codeReplacer); + Reloads.push_back(load); + } + + // Now we can emit a switch statement using the call as a value. + SwitchInst *TheSwitch = + SwitchInst::Create(Constant::getNullValue(Type::getInt16Ty(Context)), + codeReplacer, 0, codeReplacer); + for (auto P : enumerate(SwitchCases)) { + BasicBlock *OldTarget = P.value(); + size_t SuccNum = P.index(); + + TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context), SuccNum), + OldTarget); + } + + // Now that we've done the deed, simplify the switch instruction. 
+ Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType(); + switch (SwitchCases.size()) { + case 0: + // There are no successors (the block containing the switch itself), which + // means that previously this was the last part of the function, and hence + // this should be rewritten as a `ret' + + // Check if the function should return a value + if (OldFnRetTy->isVoidTy()) { + ReturnInst::Create(Context, nullptr, TheSwitch); // Return void + } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) { + // return what we have + ReturnInst::Create(Context, TheSwitch->getCondition(), TheSwitch); + } else { + // Otherwise we must have code extracted an unwind or something, just + // return whatever we want. + ReturnInst::Create(Context, Constant::getNullValue(OldFnRetTy), + TheSwitch); + } + + TheSwitch->eraseFromParent(); + break; + case 1: + // Only a single destination, change the switch into an unconditional + // branch. + BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch); + TheSwitch->eraseFromParent(); + break; + case 2: + // Only two destinations, convert to a condition branch. + // Remark: This also swaps the target branches: + // 0 -> false -> getSuccessor(2); 1 -> true -> getSuccessor(1) + BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getSuccessor(2), + call, TheSwitch); + TheSwitch->eraseFromParent(); + break; + default: + // Otherwise, make the default destination of the switch instruction be one + // of the other successors. + TheSwitch->setCondition(call); + TheSwitch->setDefaultDest(TheSwitch->getSuccessor(SwitchCases.size())); + // Remove redundant case + TheSwitch->removeCase( + SwitchInst::CaseIt(TheSwitch, SwitchCases.size() - 1)); + break; + } + + // Insert lifetime markers around the reloads of any output values. The + // allocas output values are stored in are only in-use in the codeRepl block. 
+ insertLifetimeMarkersSurroundingCall(M, ReloadOutputs, ReloadOutputs, call); + + // Replicate the effects of any lifetime start/end markers which referenced + // input objects in the extraction region by placing markers around the call. + insertLifetimeMarkersSurroundingCall(oldFunction->getParent(), LifetimesStart, + {}, call); + + return call; +} + +void CodeExtractor::insertReplacerCall( + Function *oldFunction, BasicBlock *header, BasicBlock *codeReplacer, + const ValueSet &outputs, ArrayRef Reloads, + const DenseMap &ExitWeights) { + + // Rewrite branches to basic blocks outside of the loop to new dummy blocks + // within the new function. This must be done before we lose track of which + // blocks were originally in the code region. + std::vector Users(header->user_begin(), header->user_end()); + for (auto &U : Users) + // The BasicBlock which contains the branch is not in the region + // modify the branch target to a new block + if (Instruction *I = dyn_cast(U)) + if (I->isTerminator() && I->getFunction() == oldFunction && + !Blocks.count(I->getParent())) + I->replaceUsesOfWith(header, codeReplacer); + + // When moving the code region it is sufficient to replace all uses to the + // extracted function values. Since the original definition's block + // dominated its use, it will also be dominated by codeReplacer's switch + // which joined multiple exit blocks. + for (BasicBlock *ExitBB : SwitchCases) for (PHINode &PN : ExitBB->phis()) { Value *IncomingCodeReplacerVal = nullptr; for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) { @@ -1828,27 +1897,19 @@ } } - fixupDebugInfoPostExtraction(*oldFunction, *newFunction, *TheCall); - - // Mark the new function `noreturn` if applicable. Terminators which resume - // exception propagation are treated as returning instructions. This is to - // avoid inserting traps after calls to outlined functions which unwind. 
- bool doesNotReturn = none_of(*newFunction, [](const BasicBlock &BB) { - const Instruction *Term = BB.getTerminator(); - return isa(Term) || isa(Term); - }); - if (doesNotReturn) - newFunction->setDoesNotReturn(); + for (unsigned i = 0, e = outputs.size(); i != e; ++i) { + Value *load = Reloads[i]; + std::vector Users(outputs[i]->user_begin(), outputs[i]->user_end()); + for (User *U : Users) { + Instruction *inst = cast(U); + if (inst->getParent()->getParent() == oldFunction) + inst->replaceUsesOfWith(outputs[i], load); + } + } - LLVM_DEBUG(if (verifyFunction(*newFunction, &errs())) { - newFunction->dump(); - report_fatal_error("verification of newFunction failed!"); - }); - LLVM_DEBUG(if (verifyFunction(*oldFunction)) - report_fatal_error("verification of oldFunction failed!")); - LLVM_DEBUG(if (AC && verifyAssumptionCache(*oldFunction, *newFunction, AC)) - report_fatal_error("Stale Asumption cache for old Function!")); - return newFunction; + // Update the branch weights for the exit block. 
+ if (BFI && SwitchCases.size() > 1) + calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI); } bool CodeExtractor::verifyAssumptionCache(const Function &OldFunc, diff --git a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp --- a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp +++ b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp @@ -7,11 +7,12 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/CodeExtractor.h" -#include "llvm/AsmParser/Parser.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/AsmParser/Parser.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" @@ -30,6 +31,13 @@ return nullptr; } +Instruction *getInstByName(Function *F, StringRef Name) { + for (Instruction &I : instructions(F)) + if (I.getName() == Name) + return &I; + return nullptr; +} + TEST(CodeExtractor, ExitStub) { LLVMContext Ctx; SMDiagnostic Err; @@ -513,19 +521,28 @@ target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" target triple = "x86_64-unknown-linux-gnu" - declare void @use(i32) + ; use different types such that an index mismatch will result in a type mismatch during verification. 
+ declare void @use16(i16) + declare void @use32(i32) + declare void @use64(i64) - define void @foo(i32 %a, i32 %b, i32 %c) { + define void @foo(i16 %a, i32 %b, i64 %c) { entry: br label %extract extract: - call void @use(i32 %a) - call void @use(i32 %b) - call void @use(i32 %c) + call void @use16(i16 %a) + call void @use32(i32 %b) + call void @use64(i64 %c) + %d = add i16 21, 21 + %e = add i32 21, 21 + %f = add i64 21, 21 br label %exit exit: + call void @use16(i16 %d) + call void @use32(i32 %e) + call void @use64(i64 %f) ret void } )ir", @@ -544,15 +561,68 @@ BasicBlock *CommonExit = nullptr; CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit); CE.findInputsOutputs(Inputs, Outputs, SinkingCands); - // Exclude the first input from the argument aggregate. - CE.excludeArgFromAggregate(Inputs[0]); + // Exclude the middle input and output from the argument aggregate. + CE.excludeArgFromAggregate(Inputs[1]); + CE.excludeArgFromAggregate(Outputs[1]); + + Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs); + EXPECT_TRUE(Outlined); + // Expect 3 arguments in the outlined function: the excluded input, the + // excluded output, and the struct aggregate for the remaining inputs. 
+ EXPECT_EQ(Outlined->arg_size(), 3U); + EXPECT_FALSE(verifyFunction(*Outlined)); + EXPECT_FALSE(verifyFunction(*Func)); +} + +TEST(CodeExtractor, AllocaBlock) { + LLVMContext Ctx; + SMDiagnostic Err; + std::unique_ptr M(parseAssemblyString(R"invalid( + define i32 @foo(i32 %x, i32 %y, i32 %z) { + entry: + br label %allocas + allocas: + br label %body + + body: + %w = add i32 %x, %y + br label %notExtracted + + notExtracted: + %r = add i32 %w, %x + ret i32 %r + } + )invalid", + Err, Ctx)); + + Function *Func = M->getFunction("foo"); + SmallVector Candidates{getBlockByName(Func, "body")}; + + BasicBlock *AllocaBlock = getBlockByName(Func, "allocas"); + CodeExtractor CE(Candidates, nullptr, true, nullptr, nullptr, nullptr, false, + false, AllocaBlock); + CE.excludeArgFromAggregate(Func->getArg(0)); + CE.excludeArgFromAggregate(getInstByName(Func, "w")); + EXPECT_TRUE(CE.isEligible()); + + CodeExtractorAnalysisCache CEAC(*Func); + SetVector Inputs, Outputs; Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs); EXPECT_TRUE(Outlined); - // Expect 2 arguments in the outlined function: the excluded input and the - // struct aggregate for the remaining inputs. - EXPECT_EQ(Outlined->arg_size(), 2U); EXPECT_FALSE(verifyFunction(*Outlined)); EXPECT_FALSE(verifyFunction(*Func)); + + // The only added allocas may be in the dedicated alloca block. There should + // be one alloca for the struct, and another one for the reload value. + int NumAllocas = 0; + for (Instruction &I : instructions(Func)) { + if (!isa(I)) + continue; + EXPECT_EQ(I.getParent(), AllocaBlock); + NumAllocas += 1; + } + EXPECT_EQ(NumAllocas, 2); } + } // end anonymous namespace