diff --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
--- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -139,6 +139,20 @@
     /// returns false.
     Function *extractCodeRegion(const CodeExtractorAnalysisCache &CEAC);
 
+    /// Perform the extraction, returning the new function and providing an
+    /// interface to see what was categorized as inputs and outputs.
+    ///
+    /// \param CEAC - Cache to speed up operations for the CodeExtractor when
+    /// hoisting and extracting lifetime values and assumes.
+    /// \param Inputs [out] - filled with values marked as inputs to the
+    /// newly outlined function. Pass an empty set.
+    /// \param Outputs [out] - filled with values marked as outputs to the
+    /// newly outlined function. Pass an empty set.
+    /// \returns nullptr, leaving \p Inputs and \p Outputs untouched, when
+    /// called on a CodeExtractor instance where isEligible returns false.
+    Function *extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
+                                ValueSet &Inputs, ValueSet &Outputs);
+
     /// Verify that assumption cache isn't stale after a region is extracted.
     /// Returns true when verifier finds errors. AssumptionCache is passed as
     /// parameter to make this function stateless.
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -1575,6 +1575,13 @@
 
 Function *
 CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC) {
+  ValueSet Inputs, Outputs;
+  return extractCodeRegion(CEAC, Inputs, Outputs);
+}
+
+Function *
+CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
+                                 ValueSet &inputs, ValueSet &outputs) {
   if (!isEligible())
     return nullptr;
 
@@ -1663,7 +1670,11 @@
   }
   newFuncRoot->getInstList().push_back(BranchI);
 
-  ValueSet inputs, outputs, SinkingCands, HoistingCands;
+  // inputs/outputs are caller-owned now; reset them so extraction starts from
+  // the same empty state the former local declarations provided.
+  inputs.clear();
+  outputs.clear();
+  ValueSet SinkingCands, HoistingCands;
   BasicBlock *CommonExit = nullptr;
   findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
   assert(HoistingCands.empty() || CommonExit);
diff --git a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
--- a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
@@ -77,6 +77,64 @@
   EXPECT_FALSE(verifyFunction(*Func));
 }
 
+TEST(CodeExtractor, InputOutputMonitoring) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+    define i32 @foo(i32 %x, i32 %y, i32 %z) {
+    header:
+      %0 = icmp ugt i32 %x, %y
+      br i1 %0, label %body1, label %body2
+
+    body1:
+      %1 = add i32 %z, 2
+      br label %notExtracted
+
+    body2:
+      %2 = mul i32 %z, 7
+      br label %notExtracted
+
+    notExtracted:
+      %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ]
+      %4 = add i32 %3, %x
+      ret i32 %4
+    }
+  )invalid",
+                                                Err, Ctx));
+
+  Function *Func = M->getFunction("foo");
+  SmallVector<BasicBlock *, 3> Candidates{getBlockByName(Func, "header"),
+                                          getBlockByName(Func, "body1"),
+                                          getBlockByName(Func, "body2")};
+
+  CodeExtractor CE(Candidates);
+  EXPECT_TRUE(CE.isEligible());
+
+  CodeExtractorAnalysisCache CEAC(*Func);
+  SetVector<Value *> Inputs, Outputs;
+  Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
+  ASSERT_TRUE(Outlined);
+
+  ASSERT_EQ(Inputs.size(), 3u);
+  EXPECT_EQ(Inputs[0], Func->getArg(2));
+  EXPECT_EQ(Inputs[1], Func->getArg(0));
+  EXPECT_EQ(Inputs[2], Func->getArg(1));
+  ASSERT_EQ(Outputs.size(), 1u);
+  StoreInst *SI = cast<StoreInst>(Outlined->getArg(3)->user_back());
+  Value *OutputVal = SI->getValueOperand();
+  EXPECT_EQ(Outputs[0], OutputVal);
+  BasicBlock *Exit = getBlockByName(Func, "notExtracted");
+  BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split");
+  // Ensure that PHI in exit block has only one incoming value (from code
+  // replacer block).
+  EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1);
+  // Ensure that there is a PHI in outlined function with 2 incoming values.
+  EXPECT_TRUE(ExitSplit &&
+              cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2);
+  EXPECT_FALSE(verifyFunction(*Outlined));
+  EXPECT_FALSE(verifyFunction(*Func));
+}
+
 TEST(CodeExtractor, ExitPHIOnePredFromRegion) {
   LLVMContext Ctx;
   SMDiagnostic Err;