Index: include/llvm/Transforms/Utils/CodeExtractor.h
===================================================================
--- include/llvm/Transforms/Utils/CodeExtractor.h
+++ include/llvm/Transforms/Utils/CodeExtractor.h
@@ -52,13 +52,21 @@
     SetVector<BasicBlock *> Blocks;
     unsigned NumExitBlocks;
     Type *RetTy;
+    /// Block created by extractCodeRegion() in the original function to take
+    /// the place of the extracted region; the call to the extracted function
+    /// is emitted into it. Null until extractCodeRegion() creates it.
+    BasicBlock *CodeReplacer;

   public:
     /// \brief Create a code extractor for a single basic block.
     ///
     /// In this formation, we don't require a dominator tree. The given basic
     /// block is set up for extraction.
     CodeExtractor(BasicBlock *BB, bool AggregateArgs = false);
+
+    /// \brief Return the block that replaced the extracted region in the
+    /// original function, or null if extractCodeRegion() has not created it.
+    BasicBlock *getCodeReplacerIfAvailable() const;

     /// \brief Create a code extractor for a sequence of blocks.
     ///
@@ -111,13 +119,12 @@
     Function *constructFunction(const ValueSet &inputs,
                                 const ValueSet &outputs,
                                 BasicBlock *header,
-                                BasicBlock *newRootNode, BasicBlock *newHeader,
+                                BasicBlock *newRootNode,
                                 Function *oldFunction, Module *M);

     void moveCodeToFunction(Function *newFunction);

     void emitCallAndSwitchStatement(Function *newFunction,
-                                    BasicBlock *newHeader,
                                     ValueSet &inputs,
                                     ValueSet &outputs);
   };
Index: lib/Transforms/Utils/CodeExtractor.cpp
===================================================================
--- lib/Transforms/Utils/CodeExtractor.cpp
+++ lib/Transforms/Utils/CodeExtractor.cpp
@@ -121,21 +121,25 @@

 CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs)
   : DT(nullptr), AggregateArgs(AggregateArgs||AggregateArgsOpt),
-    Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {}
+    Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U),
+    CodeReplacer(nullptr) {}

 CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
                              bool AggregateArgs)
   : DT(DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
-    Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {}
+    Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U),
+    CodeReplacer(nullptr) {}

 CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs)
   : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
-    Blocks(buildExtractionBlockSet(L.getBlocks())), NumExitBlocks(~0U) {}
+    Blocks(buildExtractionBlockSet(L.getBlocks())), NumExitBlocks(~0U),
+    CodeReplacer(nullptr) {}

 CodeExtractor::CodeExtractor(DominatorTree &DT, const RegionNode &RN,
                              bool AggregateArgs)
   : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
-    Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {}
+    Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U),
+    CodeReplacer(nullptr) {}

 /// definedInRegion - Return true if the specified value is defined in the
 /// extracted region.
@@ -292,7 +296,6 @@
                                            const ValueSet &outputs,
                                            BasicBlock *header,
                                            BasicBlock *newRootNode,
-                                           BasicBlock *newHeader,
                                            Function *oldFunction,
                                            Module *M) {
   DEBUG(dbgs() << "inputs: " << inputs.size() << "\n");
@@ -397,7 +400,7 @@
     if (TerminatorInst *TI = dyn_cast<TerminatorInst>(Users[i]))
       if (!Blocks.count(TI->getParent()) &&
           TI->getParent()->getParent() == oldFunction)
-        TI->replaceUsesOfWith(header, newHeader);
+        TI->replaceUsesOfWith(header, CodeReplacer);

   return newFunction;
 }
@@ -419,8 +422,8 @@
 /// the call instruction, splitting any PHI nodes in the header block as
 /// necessary.
 void CodeExtractor::
-emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer,
-                           ValueSet &inputs, ValueSet &outputs) {
+emitCallAndSwitchStatement(Function *newFunction, ValueSet &inputs,
+                           ValueSet &outputs) {
   // Emit a call to the new function, passing in: *pointer to struct (if
   // aggregating parameters), or plan inputs and allocated memory for outputs
   std::vector<Value*> params, StructValues, ReloadOutputs, Reloads;
@@ -441,7 +444,7 @@
     } else {
       AllocaInst *alloca =
         new AllocaInst((*i)->getType(), nullptr, (*i)->getName()+".loc",
-                       codeReplacer->getParent()->begin()->begin());
+                       CodeReplacer->getParent()->begin()->begin());
       ReloadOutputs.push_back(alloca);
       params.push_back(alloca);
     }
@@ -458,7 +461,7 @@
     Type *StructArgTy = StructType::get(newFunction->getContext(), ArgTypes);
     Struct =
       new AllocaInst(StructArgTy, nullptr, "structArg",
-                     codeReplacer->getParent()->begin()->begin());
+                     CodeReplacer->getParent()->begin()->begin());
     params.push_back(Struct);

     for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
@@ -468,16 +471,16 @@
       GetElementPtrInst *GEP =
         GetElementPtrInst::Create(Struct, Idx,
                                   "gep_" + StructValues[i]->getName());
-      codeReplacer->getInstList().push_back(GEP);
+      CodeReplacer->getInstList().push_back(GEP);
       StoreInst *SI = new StoreInst(StructValues[i], GEP);
-      codeReplacer->getInstList().push_back(SI);
+      CodeReplacer->getInstList().push_back(SI);
     }
   }

   // Emit the call to the function
   CallInst *call = CallInst::Create(newFunction, params,
                                     NumExitBlocks > 1 ? "targetBlock" : "");
-  codeReplacer->getInstList().push_back(call);
+  CodeReplacer->getInstList().push_back(call);

   Function::arg_iterator OutputArgBegin = newFunction->arg_begin();
   unsigned FirstOut = inputs.size();
@@ -494,14 +497,14 @@
       GetElementPtrInst *GEP =
         GetElementPtrInst::Create(Struct, Idx,
                                   "gep_reload_" + outputs[i]->getName());
-      codeReplacer->getInstList().push_back(GEP);
+      CodeReplacer->getInstList().push_back(GEP);
       Output = GEP;
     } else {
       Output = ReloadOutputs[i];
     }
     LoadInst *load = new LoadInst(Output, outputs[i]->getName()+".reload");
     Reloads.push_back(load);
-    codeReplacer->getInstList().push_back(load);
+    CodeReplacer->getInstList().push_back(load);
     std::vector<User*> Users(outputs[i]->user_begin(), outputs[i]->user_end());
     for (unsigned u = 0, e = Users.size(); u != e; ++u) {
       Instruction *inst = cast<Instruction>(Users[u]);
@@ -513,7 +516,7 @@
   // Now we can emit a switch statement using the call as a value.
   SwitchInst *TheSwitch =
       SwitchInst::Create(Constant::getNullValue(Type::getInt16Ty(Context)),
-                         codeReplacer, 0, codeReplacer);
+                         CodeReplacer, 0, CodeReplacer);

   // Since there may be multiple exits from the original region, make the new
   // function return an unsigned, switch on that number.  This loop iterates
@@ -685,6 +688,12 @@
     }
   }
 }
+/// Return the block that replaced the extracted region in the original
+/// function, or null if extractCodeRegion() has not created it yet.
+BasicBlock *CodeExtractor::getCodeReplacerIfAvailable() const {
+  return CodeReplacer;
+}
+
 Function *CodeExtractor::extractCodeRegion() {
   if (!isEligible())
     return nullptr;
@@ -705,9 +714,9 @@
   Function *oldFunction = header->getParent();

   // This takes place of the original loop
-  BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(),
-                                                "codeRepl", oldFunction,
-                                                header);
+  CodeReplacer = BasicBlock::Create(header->getContext(),
+                                    "codeRepl", oldFunction,
+                                    header);

   // The new function needs a root node because other nodes can branch to the
   // head of the region, but the entry node of a function cannot have preds.
@@ -729,10 +738,10 @@
   // Construct new function based on inputs/outputs & add allocas for all defs.
   Function *newFunction = constructFunction(inputs, outputs, header,
                                             newFuncRoot,
-                                            codeReplacer, oldFunction,
+                                            oldFunction,
                                             oldFunction->getParent());

-  emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs);
+  emitCallAndSwitchStatement(newFunction, inputs, outputs);

   moveCodeToFunction(newFunction);

@@ -745,11 +754,11 @@
         PN->setIncomingBlock(i, newFuncRoot);
   }

-  // Look at all successors of the codeReplacer block.  If any of these blocks
+  // Look at all successors of the CodeReplacer block.  If any of these blocks
   // had PHI nodes in them, we need to update the "from" block to be the code
   // replacer, not the original block in the extracted region.
-  std::vector<BasicBlock*> Succs(succ_begin(codeReplacer),
-                                 succ_end(codeReplacer));
+  std::vector<BasicBlock*> Succs(succ_begin(CodeReplacer),
+                                 succ_end(CodeReplacer));
   for (unsigned i = 0, e = Succs.size(); i != e; ++i)
     for (BasicBlock::iterator I = Succs[i]->begin(); isa<PHINode>(I); ++I) {
       PHINode *PN = cast<PHINode>(I);
@@ -757,7 +766,7 @@
       for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
         if (Blocks.count(PN->getIncomingBlock(i))) {
           if (ProcessedPreds.insert(PN->getIncomingBlock(i)).second)
-            PN->setIncomingBlock(i, codeReplacer);
+            PN->setIncomingBlock(i, CodeReplacer);
           else {
             // There were multiple entries in the PHI for this block, now there
             // is only one, so remove the duplicated entries.