Index: lib/Transforms/Utils/CodeExtractor.cpp =================================================================== --- lib/Transforms/Utils/CodeExtractor.cpp +++ lib/Transforms/Utils/CodeExtractor.cpp @@ -50,9 +50,9 @@ // for functions produced by the code extractor. This is useful when converting // extracted functions to pthread-based code, as only one argument (void*) can // be passed in to pthread_create(). -static cl::opt -AggregateArgsOpt("aggregate-extracted-args", cl::Hidden, - cl::desc("Aggregate arguments to code-extracted functions")); +static cl::opt AggregateArgsOpt( + "aggregate-extracted-args", cl::Hidden, + cl::desc("Aggregate arguments to code-extracted functions")); /// \brief Test whether a block is valid for extraction. bool CodeExtractor::isBlockValidForExtraction(const BasicBlock &BB) { @@ -126,8 +126,7 @@ for (SetVector::iterator I = std::next(Result.begin()), E = Result.end(); I != E; ++I) - for (pred_iterator PI = pred_begin(*I), PE = pred_end(*I); - PI != PE; ++PI) + for (pred_iterator PI = pred_begin(*I), PE = pred_end(*I); PI != PE; ++PI) assert(Result.count(*PI) && "No blocks in this region may have entries from outside the region" " except for the first block!"); @@ -162,7 +161,8 @@ /// function being code extracted, but not in the region being extracted. /// These values must be passed in as live-ins to the function. static bool definedInCaller(const SetVector &Blocks, Value *V) { - if (isa(V)) return true; + if (isa(V)) + return true; if (Instruction *I = dyn_cast(V)) if (!Blocks.count(I->getParent())) return true; @@ -437,7 +437,8 @@ if (Header != &Header->getParent()->getEntryBlock()) { PHINode *PN = dyn_cast(Header->begin()); - if (!PN) return; // No PHI nodes. + if (!PN) + return; // No PHI nodes. // If the header node contains any PHI nodes, check to see if there is more // than one entry from outside the region. 
If so, we need to sever the @@ -450,7 +451,8 @@ // If there is one (or fewer) predecessor from outside the region, we don't // need to do anything special. - if (NumPredsOutsideRegion <= 1) return; + if (NumPredsOutsideRegion <= 1) + return; } // Otherwise, we need to split the header block into two pieces: one @@ -478,8 +480,8 @@ TI->replaceUsesOfWith(OldPred, NewBB); } - // Okay, everything within the region is now branching to the right block, we - // just have to update the PHI nodes now, inserting PHI nodes into NewBB. + // Okay, everything within the region is now branching to the right block, + // we just have to update the PHI nodes now, inserting PHI nodes into NewBB. BasicBlock::iterator AfterPHIs; for (AfterPHIs = OldPred->begin(); isa(AfterPHIs); ++AfterPHIs) { PHINode *PN = cast(AfterPHIs); @@ -531,20 +533,25 @@ BasicBlock *header, BasicBlock *newRootNode, BasicBlock *newHeader, - Function *oldFunction, - Module *M) { + Function *oldFunction, Module *M) { DEBUG(dbgs() << "inputs: " << inputs.size() << "\n"); DEBUG(dbgs() << "outputs: " << outputs.size() << "\n"); // This function returns unsigned, outputs will go back by reference. 
switch (NumExitBlocks) { case 0: - case 1: RetTy = Type::getVoidTy(header->getContext()); break; - case 2: RetTy = Type::getInt1Ty(header->getContext()); break; - default: RetTy = Type::getInt16Ty(header->getContext()); break; + case 1: + RetTy = Type::getVoidTy(header->getContext()); + break; + case 2: + RetTy = Type::getInt1Ty(header->getContext()); + break; + default: + RetTy = Type::getInt16Ty(header->getContext()); + break; } - std::vector paramTy; + std::vector paramTy; // Add the types of the input values to the function's argument list for (Value *value : inputs) { @@ -574,14 +581,12 @@ paramTy.clear(); paramTy.push_back(PointerType::getUnqual(StructTy)); } - FunctionType *funcType = - FunctionType::get(RetTy, paramTy, false); + FunctionType *funcType = FunctionType::get(RetTy, paramTy, false); // Create the new function - Function *newFunction = Function::Create(funcType, - GlobalValue::InternalLinkage, - oldFunction->getName() + "_" + - header->getName(), M); + Function *newFunction = + Function::Create(funcType, GlobalValue::InternalLinkage, + oldFunction->getName() + "_" + header->getName(), M); // If the old function is no-throw, so is the new one. if (oldFunction->doesNotThrow()) newFunction->setDoesNotThrow(); @@ -620,7 +625,7 @@ } else RewriteVal = &*AI++; - std::vector Users(inputs[i]->user_begin(), inputs[i]->user_end()); + std::vector Users(inputs[i]->user_begin(), inputs[i]->user_end()); for (User *use : Users) if (Instruction *inst = dyn_cast(use)) if (Blocks.count(inst->getParent())) @@ -633,13 +638,13 @@ for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++AI) AI->setName(inputs[i]->getName()); for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++AI) - AI->setName(outputs[i]->getName()+".out"); + AI->setName(outputs[i]->getName() + ".out"); } // Rewrite branches to basic blocks outside of the loop to new dummy blocks // within the new function. This must be done before we lose track of which // blocks were originally in the code region. 
- std::vector Users(header->user_begin(), header->user_end()); + std::vector Users(header->user_begin(), header->user_end()); for (unsigned i = 0, e = Users.size(); i != e; ++i) // The BasicBlock which contains the branch is not in the region // modify the branch target to a new block @@ -654,11 +659,11 @@ /// FindPhiPredForUseInBlock - Given a value and a basic block, find a PHI /// that uses the value within the basic block, and return the predecessor /// block associated with that use, or return 0 if none is found. -static BasicBlock* FindPhiPredForUseInBlock(Value* Used, BasicBlock* BB) { +static BasicBlock *FindPhiPredForUseInBlock(Value *Used, BasicBlock *BB) { for (Use &U : Used->uses()) { - PHINode *P = dyn_cast(U.getUser()); - if (P && P->getParent() == BB) - return P->getIncomingBlock(U); + PHINode *P = dyn_cast(U.getUser()); + if (P && P->getParent() == BB) + return P->getIncomingBlock(U); } return nullptr; @@ -667,12 +672,13 @@ /// emitCallAndSwitchStatement - This method sets up the caller side by adding /// the call instruction, splitting any PHI nodes in the header block as /// necessary. 
-void CodeExtractor:: -emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, - ValueSet &inputs, ValueSet &outputs) { +void CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, + BasicBlock *codeReplacer, + ValueSet &inputs, + ValueSet &outputs) { // Emit a call to the new function, passing in: *pointer to struct (if // aggregating parameters), or plan inputs and allocated memory for outputs - std::vector params, StructValues, ReloadOutputs, Reloads; + std::vector params, StructValues, ReloadOutputs, Reloads; Module *M = newFunction->getParent(); LLVMContext &Context = M->getContext(); @@ -691,9 +697,9 @@ StructValues.push_back(output); } else { AllocaInst *alloca = - new AllocaInst(output->getType(), DL.getAllocaAddrSpace(), - nullptr, output->getName() + ".loc", - &codeReplacer->getParent()->front().front()); + new AllocaInst(output->getType(), DL.getAllocaAddrSpace(), nullptr, + output->getName() + ".loc", + &codeReplacer->getParent()->front().front()); ReloadOutputs.push_back(alloca); params.push_back(alloca); } @@ -702,9 +708,9 @@ StructType *StructArgTy = nullptr; AllocaInst *Struct = nullptr; if (AggregateArgs && (inputs.size() + outputs.size() > 0)) { - std::vector ArgTypes; - for (ValueSet::iterator v = StructValues.begin(), - ve = StructValues.end(); v != ve; ++v) + std::vector ArgTypes; + for (ValueSet::iterator v = StructValues.begin(), ve = StructValues.end(); + v != ve; ++v) ArgTypes.push_back((*v)->getType()); // Allocate a struct at the beginning of this function @@ -737,6 +743,7 @@ std::advance(OutputArgBegin, inputs.size()); // Reload the outputs passed in by reference + Function::arg_iterator OAI = OutputArgBegin; for (unsigned i = 0, e = outputs.size(); i != e; ++i) { Value *Output = nullptr; if (AggregateArgs) { @@ -750,15 +757,38 @@ } else { Output = ReloadOutputs[i]; } - LoadInst *load = new LoadInst(Output, outputs[i]->getName()+".reload"); + LoadInst *load = new LoadInst(Output, outputs[i]->getName() + 
".reload"); Reloads.push_back(load); codeReplacer->getInstList().push_back(load); - std::vector Users(outputs[i]->user_begin(), outputs[i]->user_end()); + std::vector Users(outputs[i]->user_begin(), outputs[i]->user_end()); for (unsigned u = 0, e = Users.size(); u != e; ++u) { Instruction *inst = cast(Users[u]); if (!Blocks.count(inst->getParent())) inst->replaceUsesOfWith(outputs[i], load); } + + // Store the output value to its argument right after its definition. + auto *OutI = dyn_cast(outputs[i]); + if (!OutI) + continue; + // Find the insertion point. NOTE(review): getNextNode() is null when OutI is a terminator (e.g. an invoke); confirm such outputs cannot reach here. + Instruction *InsertPt = OutI->getNextNode(); + // If OutI is a PHI, the next node may be another PHI; stores cannot be + // placed among the PHIs, so move to the first non-PHI of the block. + if (isa(InsertPt)) + InsertPt = InsertPt->getParent()->getFirstNonPHI(); + + if (AggregateArgs) { + Value *Idx[2]; + Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); + Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i); + GetElementPtrInst *GEP = GetElementPtrInst::Create( + StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(), InsertPt); + new StoreInst(outputs[i], GEP, InsertPt); + } else { + new StoreInst(outputs[i], &*OAI, InsertPt); + ++OAI; + } } // Now we can emit a switch statement using the call as a value. @@ -771,7 +801,7 @@ // over all of the blocks in the extracted region, updating any terminator // instructions in the to-be-extracted region that branch to blocks that are // not in the region to be extracted. - std::map ExitBlockMap; + std::map ExitBlockMap; unsigned switchVal = 0; for (BasicBlock *Block : Blocks) { @@ -784,16 +814,16 @@ if (!NewTarget) { // If we don't already have an exit stub for this non-extracted // destination, create one now!
- NewTarget = BasicBlock::Create(Context, - OldTarget->getName() + ".exitStub", - newFunction); + NewTarget = BasicBlock::Create( + Context, OldTarget->getName() + ".exitStub", newFunction); unsigned SuccNum = switchVal++; Value *brVal = nullptr; switch (NumExitBlocks) { case 0: - case 1: break; // No value needed. - case 2: // Conditional branch, return a bool + case 1: + break; // No value needed. + case 2: // Conditional branch, return a bool brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum); break; default: @@ -801,75 +831,11 @@ break; } - ReturnInst *NTRet = ReturnInst::Create(Context, brVal, NewTarget); + ReturnInst::Create(Context, brVal, NewTarget); // Update the switch instruction. - TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context), - SuccNum), - OldTarget); - - // Restore values just before we exit - Function::arg_iterator OAI = OutputArgBegin; - for (unsigned out = 0, e = outputs.size(); out != e; ++out) { - // For an invoke, the normal destination is the only one that is - // dominated by the result of the invocation - BasicBlock *DefBlock = cast(outputs[out])->getParent(); - - bool DominatesDef = true; - - BasicBlock *NormalDest = nullptr; - if (auto *Invoke = dyn_cast(outputs[out])) - NormalDest = Invoke->getNormalDest(); - - if (NormalDest) { - DefBlock = NormalDest; - - // Make sure we are looking at the original successor block, not - // at a newly inserted exit block, which won't be in the dominator - // info. - for (const auto &I : ExitBlockMap) - if (DefBlock == I.second) { - DefBlock = I.first; - break; - } - - // In the extract block case, if the block we are extracting ends - // with an invoke instruction, make sure that we don't emit a - // store of the invoke value for the unwind block. 
- if (!DT && DefBlock != OldTarget) - DominatesDef = false; - } - - if (DT) { - DominatesDef = DT->dominates(DefBlock, OldTarget); - - // If the output value is used by a phi in the target block, - // then we need to test for dominance of the phi's predecessor - // instead. Unfortunately, this a little complicated since we - // have already rewritten uses of the value to uses of the reload. - BasicBlock* pred = FindPhiPredForUseInBlock(Reloads[out], - OldTarget); - if (pred && DT && DT->dominates(DefBlock, pred)) - DominatesDef = true; - } - - if (DominatesDef) { - if (AggregateArgs) { - Value *Idx[2]; - Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); - Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), - FirstOut+out); - GetElementPtrInst *GEP = GetElementPtrInst::Create( - StructArgTy, &*OAI, Idx, "gep_" + outputs[out]->getName(), - NTRet); - new StoreInst(outputs[out], GEP, NTRet); - } else { - new StoreInst(outputs[out], &*OAI, NTRet); - } - } - // Advance output iterator even if we don't emit a store - if (!AggregateArgs) ++OAI; - } + TheSwitch->addCase( + ConstantInt::get(Type::getInt16Ty(Context), SuccNum), OldTarget); } // rewrite the original branch instruction with this new target @@ -887,15 +853,15 @@ // Check if the function should return a value if (OldFnRetTy->isVoidTy()) { - ReturnInst::Create(Context, nullptr, TheSwitch); // Return void + ReturnInst::Create(Context, nullptr, TheSwitch); // Return void } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) { // return what we have ReturnInst::Create(Context, TheSwitch->getCondition(), TheSwitch); } else { // Otherwise we must have code extracted an unwind or something, just // return whatever we want. 
- ReturnInst::Create(Context, - Constant::getNullValue(OldFnRetTy), TheSwitch); + ReturnInst::Create(Context, Constant::getNullValue(OldFnRetTy), + TheSwitch); } TheSwitch->eraseFromParent(); @@ -917,7 +883,7 @@ TheSwitch->setCondition(call); TheSwitch->setDefaultDest(TheSwitch->getSuccessor(NumExitBlocks)); // Remove redundant case - TheSwitch->removeCase(SwitchInst::CaseIt(TheSwitch, NumExitBlocks-1)); + TheSwitch->removeCase(SwitchInst::CaseIt(TheSwitch, NumExitBlocks - 1)); break; } } @@ -1015,14 +981,13 @@ Function *oldFunction = header->getParent(); // This takes place of the original loop - BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(), - "codeRepl", oldFunction, - header); + BasicBlock *codeReplacer = + BasicBlock::Create(header->getContext(), "codeRepl", oldFunction, header); // The new function needs a root node because other nodes can branch to the // head of the region, but the entry node of a function cannot have preds. - BasicBlock *newFuncRoot = BasicBlock::Create(header->getContext(), - "newFuncRoot"); + BasicBlock *newFuncRoot = + BasicBlock::Create(header->getContext(), "newFuncRoot"); newFuncRoot->getInstList().push_back(BranchInst::Create(header)); findAllocas(SinkingCands, HoistingCands, CommonExit); @@ -1063,10 +1028,9 @@ NumExitBlocks = ExitBlocks.size(); // Construct new function based on inputs/outputs & add allocas for all defs. - Function *newFunction = constructFunction(inputs, outputs, header, - newFuncRoot, - codeReplacer, oldFunction, - oldFunction->getParent()); + Function *newFunction = + constructFunction(inputs, outputs, header, newFuncRoot, codeReplacer, + oldFunction, oldFunction->getParent()); // Update the entry count of the function. if (BFI) { @@ -1097,12 +1061,12 @@ // Look at all successors of the codeReplacer block. If any of these blocks // had PHI nodes in them, we need to update the "from" block to be the code // replacer, not the original block in the extracted region. 
- std::vector Succs(succ_begin(codeReplacer), - succ_end(codeReplacer)); + std::vector Succs(succ_begin(codeReplacer), + succ_end(codeReplacer)); for (unsigned i = 0, e = Succs.size(); i != e; ++i) for (BasicBlock::iterator I = Succs[i]->begin(); isa(I); ++I) { PHINode *PN = cast(I); - std::set ProcessedPreds; + std::set ProcessedPreds; for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) if (Blocks.count(PN->getIncomingBlock(i))) { if (ProcessedPreds.insert(PN->getIncomingBlock(i)).second) @@ -1111,12 +1075,13 @@ // There were multiple entries in the PHI for this block, now there // is only one, so remove the duplicated entries. PN->removeIncomingValue(i, false); - --i; --e; + --i; + --e; } } } - DEBUG(if (verifyFunction(*newFunction)) - report_fatal_error("verifyFunction failed!")); + DEBUG(if (verifyFunction(*newFunction)) + report_fatal_error("verifyFunction failed!")); return newFunction; } Index: unittests/Transforms/Utils/CMakeLists.txt =================================================================== --- unittests/Transforms/Utils/CMakeLists.txt +++ unittests/Transforms/Utils/CMakeLists.txt @@ -1,5 +1,6 @@ set(LLVM_LINK_COMPONENTS Analysis + AsmParser Core Support TransformUtils @@ -8,6 +9,7 @@ add_llvm_unittest(UtilsTests ASanStackFrameLayoutTest.cpp Cloning.cpp + CodeExtractor.cpp FunctionComparator.cpp IntegerDivision.cpp Local.cpp Index: unittests/Transforms/Utils/CodeExtractor.cpp =================================================================== --- /dev/null +++ unittests/Transforms/Utils/CodeExtractor.cpp @@ -0,0 +1,68 @@ +//===- CodeExtractor.cpp - Unit tests for CodeExtractor -------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/CodeExtractor.h" +#include "llvm/AsmParser/Parser.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Support/SourceMgr.h" +#include "gtest/gtest.h" + +using namespace llvm; + +namespace { +TEST(CodeExtractor, ExitStub) { + LLVMContext Ctx; + SMDiagnostic Err; + std::unique_ptr M(parseAssemblyString(R"invalid( + define i32 @foo(i32 %x, i32 %y, i32 %z) { + header: + %0 = icmp ugt i32 %x, %y + br i1 %0, label %body1, label %body2 + + body1: + %1 = add i32 %z, 2 + br label %notExtracted + + body2: + %2 = mul i32 %z, 7 + br label %notExtracted + + notExtracted: + %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ] + %4 = add i32 %3, %x + ret i32 %4 + } + )invalid", Err, Ctx)); + + Function *Func = M->getFunction("foo"); + SmallVector Candidates; + for (auto &BB : *Func) { + if (BB.getName() == "body1") + Candidates.push_back(&BB); + if (BB.getName() == "body2") + Candidates.push_back(&BB); + } + // CodeExtractor requires the first block of the region to dominate + // all the others, so put the entry block at the front. + Candidates.insert(Candidates.begin(), &Func->getEntryBlock()); + + DominatorTree DT(*Func); + CodeExtractor CE(Candidates, &DT); + EXPECT_TRUE(CE.isEligible()); + + Function *Outlined = CE.extractCodeRegion(); + EXPECT_TRUE(Outlined); + EXPECT_FALSE(verifyFunction(*Outlined)); +} +} // end anonymous namespace