Index: llvm/include/llvm/Transforms/Utils/CodeExtractor.h
===================================================================
--- llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -100,6 +100,10 @@
     unsigned NumExitBlocks = std::numeric_limits<unsigned>::max();
     Type *RetTy;
 
+    // Blocks outside the region that the region branches to, in the order they
+    // were collected; exit stubs are created by iterating over this list.
+    SmallVector<BasicBlock *, 4> OldTargets;
+
     // Suffix to use when creating extracted function (appended to the original
     // function name + "."). If empty, the default is to use the entry block
     // label, if non-empty, otherwise "extracted".
Index: llvm/lib/Transforms/Utils/CodeExtractor.cpp
===================================================================
--- llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -434,6 +434,7 @@
   }
   // Now add the old exit block to the outline region.
   Blocks.insert(CommonExitBlock);
+  OldTargets.push_back(NewExitBlock);
   return CommonExitBlock;
 }
 
@@ -1247,41 +1248,52 @@
   // not in the region to be extracted.
   std::map<BasicBlock *, BasicBlock *> ExitBlockMap;
 
+  // Iterate over the previously collected targets, and create new blocks inside
+  // the function to branch to.
   unsigned switchVal = 0;
+  for (BasicBlock *OldTarget : OldTargets) {
+    if (Blocks.count(OldTarget))
+      continue;
+    BasicBlock *&NewTarget = ExitBlockMap[OldTarget];
+    if (NewTarget)
+      continue;
+
+    // If we don't already have an exit stub for this non-extracted
+    // destination, create one now!
+    NewTarget = BasicBlock::Create(Context,
+                                   OldTarget->getName() + ".exitStub",
+                                   newFunction);
+    unsigned SuccNum = switchVal++;
+
+    Value *brVal = nullptr;
+    assert(NumExitBlocks < 0xffff && "too many exit blocks for switch");
+    switch (NumExitBlocks) {
+    case 0:
+    case 1: break;  // No value needed.
+    case 2:         // Conditional branch, return a bool
+      brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum);
+      break;
+    default:
+      brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum);
+      break;
+    }
+
+    ReturnInst::Create(Context, brVal, NewTarget);
+
+    // Update the switch instruction.
+    TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context),
+                                        SuccNum),
+                       OldTarget);
+  }
+
   for (BasicBlock *Block : Blocks) {
     Instruction *TI = Block->getTerminator();
     for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
       if (!Blocks.count(TI->getSuccessor(i))) {
         BasicBlock *OldTarget = TI->getSuccessor(i);
         // add a new basic block which returns the appropriate value
-        BasicBlock *&NewTarget = ExitBlockMap[OldTarget];
-        if (!NewTarget) {
-          // If we don't already have an exit stub for this non-extracted
-          // destination, create one now!
-          NewTarget = BasicBlock::Create(Context,
-                                         OldTarget->getName() + ".exitStub",
-                                         newFunction);
-          unsigned SuccNum = switchVal++;
-
-          Value *brVal = nullptr;
-          switch (NumExitBlocks) {
-          case 0:
-          case 1: break;  // No value needed.
-          case 2:         // Conditional branch, return a bool
-            brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum);
-            break;
-          default:
-            brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum);
-            break;
-          }
-
-          ReturnInst::Create(Context, brVal, NewTarget);
-
-          // Update the switch instruction.
-          TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context),
-                                              SuccNum),
-                             OldTarget);
-        }
+        BasicBlock *NewTarget = ExitBlockMap[OldTarget];
+        assert(NewTarget && "Unknown target block!");
 
         // rewrite the original branch instruction with this new target
         TI->setSuccessor(i, NewTarget);
@@ -1639,6 +1651,17 @@
   }
   NumExitBlocks = ExitBlocks.size();
 
+  // Record each successor outside the region, in iteration order of Blocks.
+  // Duplicates are fine: exit stub creation skips targets it has already seen.
+  for (BasicBlock *Block : Blocks) {
+    Instruction *TI = Block->getTerminator();
+    for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
+      if (!Blocks.count(TI->getSuccessor(i))) {
+        BasicBlock *OldTarget = TI->getSuccessor(i);
+        OldTargets.push_back(OldTarget);
+      }
+  }
+
   // If we have to split PHI nodes of the entry or exit blocks, do so now.
   severSplitPHINodesOfEntry(header);
   severSplitPHINodesOfExits(ExitBlocks);
Index: llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
===================================================================
--- llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
+++ llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
@@ -115,11 +115,11 @@
   Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
   EXPECT_TRUE(Outlined);
 
-  EXPECT_EQ(Inputs.size(), 3);
+  EXPECT_EQ(Inputs.size(), 3u);
   EXPECT_EQ(Inputs[0], Func->getArg(2));
   EXPECT_EQ(Inputs[1], Func->getArg(0));
   EXPECT_EQ(Inputs[2], Func->getArg(1));
-  EXPECT_EQ(Outputs.size() == 1);
+  EXPECT_EQ(Outputs.size(), 1u);
   StoreInst *SI = cast<StoreInst>(Outlined->getArg(3)->user_back());
   Value *OutputVal = SI->getValueOperand();
   EXPECT_EQ(Outputs[0], OutputVal);
@@ -135,6 +135,133 @@
   EXPECT_FALSE(verifyFunction(*Func));
 }
 
+// Outline test0/test1/test, whose exits are %first (a PHI fed from the region)
+// and %next, then check the constants returned by the two exit stubs.
+TEST(CodeExtractor, ExitBlockOrderingPhis) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+    define void @foo(i32 %a, i32 %b) {
+    entry:
+      %0 = alloca i32, align 4
+      br label %test0
+    test0:
+      %c = load i32, i32* %0, align 4
+      br label %test1
+    test1:
+      %e = load i32, i32* %0, align 4
+      br i1 true, label %first, label %test
+    test:
+      %d = load i32, i32* %0, align 4
+      br i1 true, label %first, label %next
+    first:
+      %1 = phi i32 [ %c, %test ], [ %e, %test1 ]
+      ret void
+    next:
+      %2 = add i32 %d, 1
+      %3 = add i32 %e, 1
+      ret void
+    }
+  )invalid",
+                                                Err, Ctx));
+  Function *Func = M->getFunction("foo");
+  SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "test0"),
+                                           getBlockByName(Func, "test1"),
+                                           getBlockByName(Func, "test") };
+
+  CodeExtractor CE(Candidates);
+  EXPECT_TRUE(CE.isEligible());
+
+  CodeExtractorAnalysisCache CEAC(*Func);
+  Function *Outlined = CE.extractCodeRegion(CEAC);
+  // Fatal assertions: every pointer checked below is dereferenced afterwards.
+  ASSERT_TRUE(Outlined);
+
+  BasicBlock *FirstExitStub = getBlockByName(Outlined, "first.exitStub");
+  BasicBlock *NextExitStub = getBlockByName(Outlined, "next.exitStub");
+  ASSERT_TRUE(FirstExitStub);
+  ASSERT_TRUE(NextExitStub);
+
+  Instruction *FirstTerm = FirstExitStub->getTerminator();
+  ReturnInst *FirstReturn = dyn_cast<ReturnInst>(FirstTerm);
+  ASSERT_TRUE(FirstReturn);
+  ConstantInt *CIFirst = dyn_cast<ConstantInt>(FirstReturn->getReturnValue());
+  ASSERT_TRUE(CIFirst);
+  EXPECT_EQ(CIFirst->getLimitedValue(), 1u);
+
+  Instruction *NextTerm = NextExitStub->getTerminator();
+  ReturnInst *NextReturn = dyn_cast<ReturnInst>(NextTerm);
+  ASSERT_TRUE(NextReturn);
+  ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue());
+  ASSERT_TRUE(CINext);
+  EXPECT_EQ(CINext->getLimitedValue(), 0u);
+
+  EXPECT_FALSE(verifyFunction(*Outlined));
+  EXPECT_FALSE(verifyFunction(*Func));
+}
+
+// Same region as ExitBlockOrderingPhis, but %first has no PHI node; the exit
+// stubs are expected to return the same constants.
+TEST(CodeExtractor, ExitBlockOrdering) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+    define void @foo(i32 %a, i32 %b) {
+    entry:
+      %0 = alloca i32, align 4
+      br label %test0
+    test0:
+      %c = load i32, i32* %0, align 4
+      br label %test1
+    test1:
+      %e = load i32, i32* %0, align 4
+      br i1 true, label %first, label %test
+    test:
+      %d = load i32, i32* %0, align 4
+      br i1 true, label %first, label %next
+    first:
+      ret void
+    next:
+      %1 = add i32 %d, 1
+      %2 = add i32 %e, 1
+      ret void
+    }
+  )invalid",
+                                                Err, Ctx));
+  Function *Func = M->getFunction("foo");
+  SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "test0"),
+                                           getBlockByName(Func, "test1"),
+                                           getBlockByName(Func, "test") };
+
+  CodeExtractor CE(Candidates);
+  EXPECT_TRUE(CE.isEligible());
+
+  CodeExtractorAnalysisCache CEAC(*Func);
+  Function *Outlined = CE.extractCodeRegion(CEAC);
+  // Fatal assertions: every pointer checked below is dereferenced afterwards.
+  ASSERT_TRUE(Outlined);
+
+  BasicBlock *FirstExitStub = getBlockByName(Outlined, "first.exitStub");
+  BasicBlock *NextExitStub = getBlockByName(Outlined, "next.exitStub");
+  ASSERT_TRUE(FirstExitStub);
+  ASSERT_TRUE(NextExitStub);
+
+  Instruction *FirstTerm = FirstExitStub->getTerminator();
+  ReturnInst *FirstReturn = dyn_cast<ReturnInst>(FirstTerm);
+  ASSERT_TRUE(FirstReturn);
+  ConstantInt *CIFirst = dyn_cast<ConstantInt>(FirstReturn->getReturnValue());
+  ASSERT_TRUE(CIFirst);
+  EXPECT_EQ(CIFirst->getLimitedValue(), 1u);
+
+  Instruction *NextTerm = NextExitStub->getTerminator();
+  ReturnInst *NextReturn = dyn_cast<ReturnInst>(NextTerm);
+  ASSERT_TRUE(NextReturn);
+  ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue());
+  ASSERT_TRUE(CINext);
+  EXPECT_EQ(CINext->getLimitedValue(), 0u);
+
+  EXPECT_FALSE(verifyFunction(*Outlined));
+  EXPECT_FALSE(verifyFunction(*Func));
+}
+
 TEST(CodeExtractor, ExitPHIOnePredFromRegion) {
   LLVMContext Ctx;
   SMDiagnostic Err;