[CodeExtractor] Split PHI nodes in exit blocks that receive more than one
incoming value from the extracted region.

* New CodeExtractor::severSplitPHINodesOfExits(): for every exit block, each
  PHI with two or more incoming values from the region has those values moved
  into a new PHI ("<name>.ce") placed in a new "<exit>.split" block. Region
  predecessors of the exit are redirected to that block, which branches to the
  exit and is added to the outlined region; the original PHI keeps its other
  inputs plus the new PHI as the input from "<exit>.split".
* severSplitPHINodes() is renamed to severSplitPHINodesOfEntry().
* extractCodeRegion(): ExitBlocks/ExitWeights are now computed right after
  splitReturnBlocks(), and both PHI-splitting steps run after that (the entry
  split previously ran before splitReturnBlocks()). The codeRepl fix-up walks
  ExitBlocks instead of successors(codeReplacer) and no longer removes
  duplicate incoming values; it still asserts that any further incoming value
  from the region equals the first one.
* Under LLVM_DEBUG, oldFunction is now verified as well as newFunction.
* Tests: duplicate-phi-preds-crash.ll CHECK lines are updated for the new
  %retval.0.ce.loc / %retval.0.ce.reload values; CodeExtractor.ExitStub is
  re-enabled and ExitPHIOnePredFromRegion is added. Both unit tests use
  ASSERT_TRUE(Outlined) because Outlined is dereferenced afterwards, so a
  null result must stop the test rather than crash it.

NOTE(review): this copy of the patch appears damaged: line breaks and
indentation were collapsed, and angle-bracketed text is missing (the header
name in the bare #include, and template arguments after SmallPtrSetImpl,
SmallVector, DenseMap, SmallPtrSet, std::unique_ptr, cast and dyn_cast). It
will not apply or compile as-is; regenerate it from the original revision.

Index: llvm/trunk/include/llvm/Transforms/Utils/CodeExtractor.h =================================================================== --- llvm/trunk/include/llvm/Transforms/Utils/CodeExtractor.h +++ llvm/trunk/include/llvm/Transforms/Utils/CodeExtractor.h @@ -18,6 +18,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" #include namespace llvm { @@ -146,7 +147,8 @@ BasicBlock *findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock); private: - void severSplitPHINodes(BasicBlock *&Header); + void severSplitPHINodesOfEntry(BasicBlock *&Header); + void severSplitPHINodesOfExits(const SmallPtrSetImpl &Exits); void splitReturnBlocks(); Function *constructFunction(const ValueSet &inputs, Index: llvm/trunk/lib/Transforms/Utils/CodeExtractor.cpp =================================================================== --- llvm/trunk/lib/Transforms/Utils/CodeExtractor.cpp +++ llvm/trunk/lib/Transforms/Utils/CodeExtractor.cpp @@ -531,10 +531,10 @@ } } -/// severSplitPHINodes - If a PHI node has multiple inputs from outside of the -/// region, we need to split the entry block of the region so that the PHI node -/// is easier to deal with. -void CodeExtractor::severSplitPHINodes(BasicBlock *&Header) { +/// severSplitPHINodesOfEntry - If a PHI node has multiple inputs from outside +/// of the region, we need to split the entry block of the region so that the +/// PHI node is easier to deal with. +void CodeExtractor::severSplitPHINodesOfEntry(BasicBlock *&Header) { unsigned NumPredsFromRegion = 0; unsigned NumPredsOutsideRegion = 0; @@ -606,6 +606,56 @@ } } +/// severSplitPHINodesOfExits - If PHI nodes in exit blocks have inputs from +/// the outlined region, we split these PHIs into two: one with the inputs +/// from the region and another with the remaining incoming blocks; the first +/// PHIs are then placed in the outlined region.
+void CodeExtractor::severSplitPHINodesOfExits( + const SmallPtrSetImpl &Exits) { + for (BasicBlock *ExitBB : Exits) { + BasicBlock *NewBB = nullptr; + + for (PHINode &PN : ExitBB->phis()) { + // Find all incoming values from the outlined region. + SmallVector IncomingVals; + for (unsigned i = 0; i < PN.getNumIncomingValues(); ++i) + if (Blocks.count(PN.getIncomingBlock(i))) + IncomingVals.push_back(i); + + // Do not process PHI if there is one (or fewer) predecessor from region. + // If PHI has exactly one predecessor from region, only this one incoming + // block will be replaced with codeRepl, so it should be safe to skip PHI. + if (IncomingVals.size() <= 1) + continue; + + // Create a block for the new PHIs and add it to the list of outlined + // blocks if that wasn't done before. + if (!NewBB) { + NewBB = BasicBlock::Create(ExitBB->getContext(), + ExitBB->getName() + ".split", + ExitBB->getParent(), ExitBB); + SmallVector Preds(pred_begin(ExitBB), + pred_end(ExitBB)); + for (BasicBlock *PredBB : Preds) + if (Blocks.count(PredBB)) + PredBB->getTerminator()->replaceUsesOfWith(ExitBB, NewBB); + BranchInst::Create(ExitBB, NewBB); + Blocks.insert(NewBB); + } + + // Split this PHI: move the region's incoming values into NewPN. + PHINode *NewPN = + PHINode::Create(PN.getType(), IncomingVals.size(), + PN.getName() + ".ce", NewBB->getFirstNonPHI()); + for (unsigned i : IncomingVals) + NewPN->addIncoming(PN.getIncomingValue(i), PN.getIncomingBlock(i)); + for (unsigned i : reverse(IncomingVals)) + PN.removeIncomingValue(i, false); + PN.addIncoming(NewPN, NewBB); + } + } +} + void CodeExtractor::splitReturnBlocks() { for (BasicBlock *Block : Blocks) if (ReturnInst *RI = dyn_cast(Block->getTerminator())) { @@ -1173,13 +1223,33 @@ } } - // If we have to split PHI nodes or the entry block, do so now. - severSplitPHINodes(header); - // If we have any return instructions in the region, split those blocks so // that the return is not in the region.
splitReturnBlocks(); + // Calculate the exit blocks for the extracted region and the total exit + // weights for each of those blocks. + DenseMap ExitWeights; + SmallPtrSet ExitBlocks; + for (BasicBlock *Block : Blocks) { + for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE; + ++SI) { + if (!Blocks.count(*SI)) { + // Update the branch weight for this successor. + if (BFI) { + BlockFrequency &BF = ExitWeights[*SI]; + BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI); + } + ExitBlocks.insert(*SI); + } + } + } + NumExitBlocks = ExitBlocks.size(); + + // If we have to split PHI nodes of the entry or exit blocks, do so now. + severSplitPHINodesOfEntry(header); + severSplitPHINodesOfExits(ExitBlocks); + // This takes place of the original loop BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(), "codeRepl", oldFunction, @@ -1224,25 +1294,6 @@ cast(II)->moveBefore(TI); } - // Calculate the exit blocks for the extracted region and the total exit - // weights for each of those blocks. - DenseMap ExitWeights; - SmallPtrSet ExitBlocks; - for (BasicBlock *Block : Blocks) { - for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE; - ++SI) { - if (!Blocks.count(*SI)) { - // Update the branch weight for this successor. - if (BFI) { - BlockFrequency &BF = ExitWeights[*SI]; - BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI); - } - ExitBlocks.insert(*SI); - } - } - } - NumExitBlocks = ExitBlocks.size(); - // Construct new function based on inputs/outputs & add allocas for all defs. Function *newFunction = constructFunction(inputs, outputs, header, newFuncRoot, @@ -1270,8 +1321,8 @@ if (BFI && NumExitBlocks > 1) calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI); - // Loop over all of the PHI nodes in the header block, and change any - // references to the old incoming edge to be the new incoming edge.
+ // Loop over all of the PHI nodes in the header and exit blocks, and change + // any references to the old incoming edge to be the new incoming edge. for (BasicBlock::iterator I = header->begin(); isa(I); ++I) { PHINode *PN = cast(I); for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) @@ -1279,35 +1330,23 @@ PN->setIncomingBlock(i, newFuncRoot); } - // Look at all successors of the codeReplacer block. If any of these blocks - // had PHI nodes in them, we need to update the "from" block to be the code - // replacer, not the original block in the extracted region. - for (BasicBlock *SuccBB : successors(codeReplacer)) { - for (PHINode &PN : SuccBB->phis()) { + for (BasicBlock *ExitBB : ExitBlocks) + for (PHINode &PN : ExitBB->phis()) { Value *IncomingCodeReplacerVal = nullptr; - SmallVector IncomingValsToRemove; - for (unsigned I = 0, E = PN.getNumIncomingValues(); I != E; ++I) { - BasicBlock *IncomingBB = PN.getIncomingBlock(I); - + for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) { // Ignore incoming values from outside of the extracted region. - if (!Blocks.count(IncomingBB)) + if (!Blocks.count(PN.getIncomingBlock(i))) continue; // Ensure that there is only one incoming value from codeReplacer. if (!IncomingCodeReplacerVal) { - PN.setIncomingBlock(I, codeReplacer); - IncomingCodeReplacerVal = PN.getIncomingValue(I); - } else { - assert(IncomingCodeReplacerVal == PN.getIncomingValue(I) && + PN.setIncomingBlock(i, codeReplacer); + IncomingCodeReplacerVal = PN.getIncomingValue(i); + } else + assert(IncomingCodeReplacerVal == PN.getIncomingValue(i) && "PHI has two incompatbile incoming values from codeRepl"); - IncomingValsToRemove.push_back(I); - } } - - for (unsigned I : reverse(IncomingValsToRemove)) - PN.removeIncomingValue(I, /*DeletePHIIfEmpty=*/false); } - } // Erase debug info intrinsics. Variable updates within the new function are // invisible to debuggers.
This could be improved by defining a DISubprogram @@ -1338,6 +1377,8 @@ newFunction->setDoesNotReturn(); LLVM_DEBUG(if (verifyFunction(*newFunction)) - report_fatal_error("verifyFunction failed!")); + report_fatal_error("verification of newFunction failed!")); + LLVM_DEBUG(if (verifyFunction(*oldFunction)) + report_fatal_error("verification of oldFunction failed!")); return newFunction; } Index: llvm/trunk/test/Transforms/HotColdSplit/duplicate-phi-preds-crash.ll =================================================================== --- llvm/trunk/test/Transforms/HotColdSplit/duplicate-phi-preds-crash.ll +++ llvm/trunk/test/Transforms/HotColdSplit/duplicate-phi-preds-crash.ll @@ -15,9 +15,9 @@ ; CHECK: call {{.*}}@sideeffect( ; CHECK: call {{.*}}@realloc( ; CHECK-LABEL: codeRepl: -; CHECK-NEXT: call {{.*}}@realloc2.cold.1(i64 %size, i8* %ptr) +; CHECK-NEXT: call {{.*}}@realloc2.cold.1(i64 %size, i8* %ptr, i8** %retval.0.ce.loc) ; CHECK-LABEL: cleanup: -; CHECK-NEXT: phi i8* [ null, %if.then ], [ null, %codeRepl ], [ %call, %if.end ] +; CHECK-NEXT: phi i8* [ null, %if.then ], [ %call, %if.end ], [ %retval.0.ce.reload, %codeRepl ] define i8* @realloc2(i8* %ptr, i64 %size) { entry: %0 = add i64 %size, -1 Index: llvm/trunk/unittests/Transforms/Utils/CodeExtractorTest.cpp =================================================================== --- llvm/trunk/unittests/Transforms/Utils/CodeExtractorTest.cpp +++ llvm/trunk/unittests/Transforms/Utils/CodeExtractorTest.cpp @@ -11,6 +11,7 @@ #include "llvm/AsmParser/Parser.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Verifier.h" @@ -21,7 +22,14 @@ using namespace llvm; namespace { -TEST(CodeExtractor, DISABLED_ExitStub) { +BasicBlock *getBlockByName(Function *F, StringRef name) { + for (auto &BB : *F) + if (BB.getName() == name) + return &BB; + return nullptr; +} + +TEST(CodeExtractor, ExitStub) {
LLVMContext Ctx; SMDiagnostic Err; std::unique_ptr M(parseAssemblyString(R"invalid( @@ -46,36 +54,10 @@ )invalid", Err, Ctx)); - // CodeExtractor miscompiles this function. There appear to be some issues - // with the handling of outlined regions with live output values. - // - // In the original function, CE adds two reloads in the codeReplacer block: - // - // codeRepl: ; preds = %header - // call void @foo_header.split(i32 %z, i32 %x, i32 %y, i32* %.loc, i32* %.loc1) - // %.reload = load i32, i32* %.loc - // %.reload2 = load i32, i32* %.loc1 - // br label %notExtracted - // - // These reloads must flow into the notExtracted block: - // - // notExtracted: ; preds = %codeRepl - // %0 = phi i32 [ %.reload, %codeRepl ], [ %.reload2, %body2 ] - // - // The problem is that the PHI node in notExtracted now has an incoming - // value from a BasicBlock that's in a different function. - Function *Func = M->getFunction("foo"); - SmallVector Candidates; - for (auto &BB : *Func) { - if (BB.getName() == "body1") - Candidates.push_back(&BB); - if (BB.getName() == "body2") - Candidates.push_back(&BB); - } - // CodeExtractor requires the first basic block - // to dominate all the other ones. - Candidates.insert(Candidates.begin(), &Func->getEntryBlock()); + SmallVector Candidates{ getBlockByName(Func, "header"), + getBlockByName(Func, "body1"), + getBlockByName(Func, "body2") }; DominatorTree DT(*Func); CodeExtractor CE(Candidates, &DT); @@ -83,6 +65,66 @@ Function *Outlined = CE.extractCodeRegion(); - EXPECT_TRUE(Outlined); + ASSERT_TRUE(Outlined); + BasicBlock *Exit = getBlockByName(Func, "notExtracted"); + BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split"); + // Ensure that PHI in exit block has only one incoming value (from code + // replacer block). + EXPECT_TRUE(Exit && cast(Exit->front()).getNumIncomingValues() == 1); + // Ensure that there is a PHI in the outlined function with 2 incoming values.
+ EXPECT_TRUE(ExitSplit && + cast(ExitSplit->front()).getNumIncomingValues() == 2); + EXPECT_FALSE(verifyFunction(*Outlined)); + EXPECT_FALSE(verifyFunction(*Func)); +} + +TEST(CodeExtractor, ExitPHIOnePredFromRegion) { + LLVMContext Ctx; + SMDiagnostic Err; + std::unique_ptr M(parseAssemblyString(R"invalid( + define i32 @foo() { + header: + br i1 undef, label %extracted1, label %pred + + pred: + br i1 undef, label %exit1, label %exit2 + + extracted1: + br i1 undef, label %extracted2, label %exit1 + + extracted2: + br label %exit2 + + exit1: + %0 = phi i32 [ 1, %extracted1 ], [ 2, %pred ] + ret i32 %0 + + exit2: + %1 = phi i32 [ 3, %extracted2 ], [ 4, %pred ] + ret i32 %1 + } + )invalid", Err, Ctx)); + + Function *Func = M->getFunction("foo"); + SmallVector ExtractedBlocks{ + getBlockByName(Func, "extracted1"), + getBlockByName(Func, "extracted2") + }; + + DominatorTree DT(*Func); + CodeExtractor CE(ExtractedBlocks, &DT); + EXPECT_TRUE(CE.isEligible()); + + Function *Outlined = CE.extractCodeRegion(); + ASSERT_TRUE(Outlined); + BasicBlock *Exit1 = getBlockByName(Func, "exit1"); + BasicBlock *Exit2 = getBlockByName(Func, "exit2"); + // Ensure that PHIs in exits are not split (since they have only one + // incoming value from the extracted region). + EXPECT_TRUE(Exit1 && + cast(Exit1->front()).getNumIncomingValues() == 2); + EXPECT_TRUE(Exit2 && + cast(Exit2->front()).getNumIncomingValues() == 2); EXPECT_FALSE(verifyFunction(*Outlined)); + EXPECT_FALSE(verifyFunction(*Func)); } } // end anonymous namespace