Index: llvm/include/llvm/Transforms/Utils/CodeExtractor.h
===================================================================
--- llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -139,6 +139,20 @@
     /// returns false.
     Function *extractCodeRegion(const CodeExtractorAnalysisCache &CEAC);
 
+    /// Perform the extraction, returning the new function and reporting which
+    /// values were classified as inputs and outputs of the extracted region.
+    ///
+    /// \param CEAC - Same analysis cache as taken by the overload above.
+    /// \param Inputs [out] - cleared on entry, then filled with the values
+    /// marked as inputs to the newly outlined function.
+    /// \param Outputs [out] - cleared on entry, then filled with the values
+    /// marked as outputs of the newly outlined function.
+    /// \returns nullptr when called on a CodeExtractor instance where
+    /// isEligible returns false; both sets are then left empty.
+    Function *extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
+                                ValueSet &Inputs, ValueSet &Outputs);
+
     /// Verify that assumption cache isn't stale after a region is extracted.
     /// Returns true when verifier finds errors. AssumptionCache is passed as
     /// parameter to make this function stateless.
Index: llvm/lib/Transforms/Utils/CodeExtractor.cpp
===================================================================
--- llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -1569,6 +1569,17 @@
 
 Function *
 CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC) {
+  ValueSet Inputs, Outputs;
+  return extractCodeRegion(CEAC, Inputs, Outputs);
+}
+
+Function *
+CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
+                                 ValueSet &inputs, ValueSet &outputs) {
+  // These used to be fresh locals. Start from empty sets so that entries
+  // supplied by the caller cannot become extra parameters of the outlined
+  // function.
+  inputs.clear();
+  outputs.clear();
   if (!isEligible())
     return nullptr;
 
@@ -1657,7 +1668,7 @@
   }
   newFuncRoot->getInstList().push_back(BranchI);
 
-  ValueSet inputs, outputs, SinkingCands, HoistingCands;
+  ValueSet SinkingCands, HoistingCands;
   BasicBlock *CommonExit = nullptr;
   findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
   assert(HoistingCands.empty() || CommonExit);
Index: llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
===================================================================
--- llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
+++ llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
@@ -77,6 +77,69 @@
   EXPECT_FALSE(verifyFunction(*Func));
 }
 
+TEST(CodeExtractor, InputOutputMonitoring) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+    define i32 @foo(i32 %x, i32 %y, i32 %z) {
+    header:
+      %0 = icmp ugt i32 %x, %y
+      br i1 %0, label %body1, label %body2
+
+    body1:
+      %1 = add i32 %z, 2
+      br label %notExtracted
+
+    body2:
+      %2 = mul i32 %z, 7
+      br label %notExtracted
+
+    notExtracted:
+      %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ]
+      %4 = add i32 %3, %x
+      ret i32 %4
+    }
+  )invalid",
+                                                Err, Ctx));
+  ASSERT_TRUE(M);
+
+  Function *Func = M->getFunction("foo");
+  SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "header"),
+                                           getBlockByName(Func, "body1"),
+                                           getBlockByName(Func, "body2") };
+
+  CodeExtractor CE(Candidates);
+  EXPECT_TRUE(CE.isEligible());
+
+  CodeExtractorAnalysisCache CEAC(*Func);
+  SetVector<Value *> Inputs, Outputs;
+  Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
+  // Everything below dereferences the outlined function and indexes the
+  // reported sets, so stop here if extraction failed.
+  ASSERT_TRUE(Outlined);
+
+  ASSERT_EQ(Inputs.size(), 3u);
+  EXPECT_EQ(Inputs[0], Func->getArg(2));
+  EXPECT_EQ(Inputs[1], Func->getArg(0));
+  EXPECT_EQ(Inputs[2], Func->getArg(1));
+
+  ASSERT_EQ(Outputs.size(), 1u);
+  StoreInst *SI = cast<StoreInst>(Outlined->getArg(3)->user_back());
+  Value *OutputVal = SI->getValueOperand();
+  EXPECT_EQ(Outputs[0], OutputVal);
+
+  BasicBlock *Exit = getBlockByName(Func, "notExtracted");
+  BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split");
+  // Ensure that PHI in exit block has only one incoming value (from code
+  // replacer block).
+  EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1);
+  // Ensure that there is a PHI in outlined function with 2 incoming values.
+  EXPECT_TRUE(ExitSplit &&
+              cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2);
+  EXPECT_FALSE(verifyFunction(*Outlined));
+  EXPECT_FALSE(verifyFunction(*Func));
+}
+
 TEST(CodeExtractor, ExitPHIOnePredFromRegion) {
   LLVMContext Ctx;
   SMDiagnostic Err;