Index: lib/Transforms/Utils/CodeExtractor.cpp
===================================================================
--- lib/Transforms/Utils/CodeExtractor.cpp
+++ lib/Transforms/Utils/CodeExtractor.cpp
@@ -737,6 +737,7 @@
     std::advance(OutputArgBegin, inputs.size());
 
   // Reload the outputs passed in by reference
+  Function::arg_iterator OAI = OutputArgBegin;
   for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
     Value *Output = nullptr;
     if (AggregateArgs) {
@@ -759,6 +760,47 @@
       if (!Blocks.count(inst->getParent()))
         inst->replaceUsesOfWith(outputs[i], load);
     }
+
+    // Store the output value into its output argument right after its
+    // definition, so that the store is dominated by the definition.
+    auto *OutI = dyn_cast<Instruction>(outputs[i]);
+    if (!OutI)
+      continue;
+
+    // Find the proper insertion point:
+    //  - an invoke is a terminator, so nothing follows it in its block; its
+    //    result is available at the start of its normal destination.
+    //  - PHI nodes must stay grouped at the top of their block, so store at
+    //    the block's first insertion point.
+    //  - otherwise, store right after the defining instruction.
+    Instruction *InsertPt;
+    if (auto *InvokeI = dyn_cast<InvokeInst>(OutI))
+      InsertPt = &*InvokeI->getNormalDest()->getFirstInsertionPt();
+    else if (isa<PHINode>(OutI))
+      InsertPt = &*OutI->getParent()->getFirstInsertionPt();
+    else
+      InsertPt = OutI->getNextNode();
+
+    // An invoke whose normal destination leaves the region would need its
+    // store in an exit stub, which is not handled here.
+    assert(Blocks.count(InsertPt->getParent()) &&
+           "Output store must be inserted inside the extracted region");
+    assert(OAI != newFunction->arg_end() &&
+           "Number of output arguments should match "
+           "the number of defined values");
+    if (AggregateArgs) {
+      Value *Idx[2];
+      Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
+      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
+      GetElementPtrInst *GEP = GetElementPtrInst::Create(
+          StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(), InsertPt);
+      new StoreInst(outputs[i], GEP, InsertPt);
+      // All output values are aggregated into a single struct argument, so
+      // OAI keeps pointing at it and is not advanced in this case.
+    } else {
+      new StoreInst(outputs[i], &*OAI, InsertPt);
+      ++OAI;
+    }
   }
 
   // Now we can emit a switch statement using the call as a value.
@@ -801,75 +843,12 @@
             break;
           }
 
-          ReturnInst *NTRet = ReturnInst::Create(Context, brVal, NewTarget);
+          ReturnInst::Create(Context, brVal, NewTarget);
 
           // Update the switch instruction.
           TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context),
                                               SuccNum),
                              OldTarget);
-
-          // Restore values just before we exit
-          Function::arg_iterator OAI = OutputArgBegin;
-          for (unsigned out = 0, e = outputs.size(); out != e; ++out) {
-            // For an invoke, the normal destination is the only one that is
-            // dominated by the result of the invocation
-            BasicBlock *DefBlock = cast<Instruction>(outputs[out])->getParent();
-
-            bool DominatesDef = true;
-
-            BasicBlock *NormalDest = nullptr;
-            if (auto *Invoke = dyn_cast<InvokeInst>(outputs[out]))
-              NormalDest = Invoke->getNormalDest();
-
-            if (NormalDest) {
-              DefBlock = NormalDest;
-
-              // Make sure we are looking at the original successor block, not
-              // at a newly inserted exit block, which won't be in the dominator
-              // info.
-              for (const auto &I : ExitBlockMap)
-                if (DefBlock == I.second) {
-                  DefBlock = I.first;
-                  break;
-                }
-
-              // In the extract block case, if the block we are extracting ends
-              // with an invoke instruction, make sure that we don't emit a
-              // store of the invoke value for the unwind block.
-              if (!DT && DefBlock != OldTarget)
-                DominatesDef = false;
-            }
-
-            if (DT) {
-              DominatesDef = DT->dominates(DefBlock, OldTarget);
-
-              // If the output value is used by a phi in the target block,
-              // then we need to test for dominance of the phi's predecessor
-              // instead. Unfortunately, this a little complicated since we
-              // have already rewritten uses of the value to uses of the reload.
-              BasicBlock* pred = FindPhiPredForUseInBlock(Reloads[out],
-                                                          OldTarget);
-              if (pred && DT && DT->dominates(DefBlock, pred))
-                DominatesDef = true;
-            }
-
-            if (DominatesDef) {
-              if (AggregateArgs) {
-                Value *Idx[2];
-                Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
-                Idx[1] = ConstantInt::get(Type::getInt32Ty(Context),
-                                          FirstOut+out);
-                GetElementPtrInst *GEP = GetElementPtrInst::Create(
-                    StructArgTy, &*OAI, Idx, "gep_" + outputs[out]->getName(),
-                    NTRet);
-                new StoreInst(outputs[out], GEP, NTRet);
-              } else {
-                new StoreInst(outputs[out], &*OAI, NTRet);
-              }
-            }
-            // Advance output iterator even if we don't emit a store
-            if (!AggregateArgs) ++OAI;
-          }
         }
 
         // rewrite the original branch instruction with this new target
Index: unittests/Transforms/Utils/CMakeLists.txt
===================================================================
--- unittests/Transforms/Utils/CMakeLists.txt
+++ unittests/Transforms/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
 set(LLVM_LINK_COMPONENTS
   Analysis
+  AsmParser
   Core
   Support
   TransformUtils
@@ -8,6 +9,7 @@
 add_llvm_unittest(UtilsTests
   ASanStackFrameLayoutTest.cpp
   Cloning.cpp
+  CodeExtractor.cpp
   FunctionComparator.cpp
   IntegerDivision.cpp
   Local.cpp
Index: unittests/Transforms/Utils/CodeExtractor.cpp
===================================================================
--- /dev/null
+++ unittests/Transforms/Utils/CodeExtractor.cpp
@@ -0,0 +1,69 @@
+//===- CodeExtractor.cpp - Unit tests for CodeExtractor -------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/CodeExtractor.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Verifier.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+TEST(CodeExtractor, ExitStub) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+    define i32 @foo(i32 %x, i32 %y, i32 %z) {
+    header:
+      %0 = icmp ugt i32 %x, %y
+      br i1 %0, label %body1, label %body2
+
+    body1:
+      %1 = add i32 %z, 2
+      br label %notExtracted
+
+    body2:
+      %2 = mul i32 %z, 7
+      br label %notExtracted
+
+    notExtracted:
+      %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ]
+      %4 = add i32 %3, %x
+      ret i32 %4
+    }
+  )invalid",
+                                                Err, Ctx));
+  ASSERT_TRUE(M != nullptr) << Err.getMessage().str();
+
+  Function *Func = M->getFunction("foo");
+  SmallVector<BasicBlock *, 3> Candidates;
+  for (auto &BB : *Func) {
+    if (BB.getName() == "body1")
+      Candidates.push_back(&BB);
+    if (BB.getName() == "body2")
+      Candidates.push_back(&BB);
+  }
+  // CodeExtractor requires the first basic block
+  // to dominate all the other ones.
+  Candidates.insert(Candidates.begin(), &Func->getEntryBlock());
+
+  DominatorTree DT(*Func);
+  CodeExtractor CE(Candidates, &DT);
+  ASSERT_TRUE(CE.isEligible());
+
+  Function *Outlined = CE.extractCodeRegion();
+  ASSERT_TRUE(Outlined);
+  EXPECT_FALSE(verifyFunction(*Outlined));
+}
+} // end anonymous namespace