diff --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
--- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -151,7 +151,17 @@
     ///
     /// Checks that varargs handling (with vastart and vaend) is only done in
     /// the outlined blocks.
-    bool isEligible() const;
+    ///
+    /// Also checks that every value in \p inputs (as computed by
+    /// findInputsOutputs) can legally be passed to the outlined function.
+    bool isEligible(const ValueSet &inputs) const;
+
+    /// Convenience overload that computes the region's inputs itself.
+    bool isEligible() const {
+      ValueSet Inputs, Outputs;
+      findInputsOutputs(Inputs, Outputs, ValueSet(), /*InputsOnly=*/true);
+      return isEligible(Inputs);
+    }
 
     /// Compute the set of input values and output values for the code.
     ///
@@ -162,7 +172,10 @@
     /// sets, before extraction occurs. These modifications won't have any
     /// significant impact on the cost however.
+    ///
+    /// If \p InputsOnly is true, only \p Inputs is populated.
     void findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs,
-                           const ValueSet &Allocas) const;
+                           const ValueSet &Allocas,
+                           bool InputsOnly = false) const;
 
     /// Check if life time marker nodes can be hoisted/sunk into the outline
     /// region.
@@ -210,6 +223,11 @@
     void severSplitPHINodesOfExits(const SmallPtrSetImpl<BasicBlock *> &Exits);
     void splitReturnBlocks();
 
+    /// Check that every value in \p inputs can be passed as an argument to the
+    /// outlined function, i.e. it is a first-class value that is neither
+    /// metadata nor a token. Returns false if any input fails this check.
+    bool validateInputDataDependencies(const ValueSet &inputs) const;
+
     Function *constructFunction(const ValueSet &inputs,
                                 const ValueSet &outputs,
                                 BasicBlock *header,
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -568,7 +568,7 @@
   }
 }
 
-bool CodeExtractor::isEligible() const {
+bool CodeExtractor::isEligible(const ValueSet &inputs) const {
   if (Blocks.empty())
     return false;
   BasicBlock *Header = *Blocks.begin();
@@ -592,11 +592,12 @@
       return false;
     }
   }
-  return true;
+  return validateInputDataDependencies(inputs);
 }
 
 void CodeExtractor::findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs,
-                                      const ValueSet &SinkCands) const {
+                                      const ValueSet &SinkCands,
+                                      bool InputsOnly) const {
   for (BasicBlock *BB : Blocks) {
     // If a used value is defined outside the region, it's an input. If an
     // instruction is used outside the region, it's an output.
@@ -606,7 +607,8 @@
         if (!SinkCands.count(V) && definedInCaller(Blocks, V))
           Inputs.insert(V);
       }
-
+      if (InputsOnly)
+        continue;
       for (User *U : II.users())
         if (!definedInRegion(Blocks, U)) {
           Outputs.insert(&II);
@@ -1382,10 +1384,15 @@
       MDBuilder(TI->getContext()).createBranchWeights(BranchWeights));
 }
 
 Function *
 CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC) {
-  if (!isEligible())
+  ValueSet inputs, outputs, SinkingCands, HoistingCands;
+  findInputsOutputs(inputs, outputs, SinkingCands, /*InputsOnly=*/true);
+  if (!isEligible(inputs))
     return nullptr;
+  // Clear inputs after the eligibility check: they are recomputed once the
+  // CFG has been prepared for splitting.
+  inputs.clear();
 
   // Assumption: this is a single-entry code region, and the header is the first
   // block in the region.
@@ -1467,7 +1474,6 @@
   }
   newFuncRoot->getInstList().push_back(BranchI);
 
-  ValueSet inputs, outputs, SinkingCands, HoistingCands;
   BasicBlock *CommonExit = nullptr;
   findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
   assert(HoistingCands.empty() || CommonExit);
@@ -1606,6 +1612,8 @@
              report_fatal_error("verification of oldFunction failed!"));
   LLVM_DEBUG(if (AC && verifyAssumptionCache(*oldFunction, AC))
              report_fatal_error("Stale Asumption cache for old Function!"));
+  LLVM_DEBUG(if (!validateInputDataDependencies(inputs))
+             report_fatal_error("verification of newFunction args failed!"));
   return newFunction;
 }
 
@@ -1618,3 +1626,13 @@
   }
   return false;
 }
+
+bool
+CodeExtractor::validateInputDataDependencies(const ValueSet &inputs) const {
+  // Every input becomes a parameter of the outlined function, so each one
+  // must be a first-class value that is neither metadata nor a token.
+  return llvm::all_of(inputs, [](const Value *Arg) {
+    Type *Ty = Arg->getType();
+    return Ty->isFirstClassType() && !Ty->isMetadataTy() && !Ty->isTokenTy();
+  });
+}
diff --git a/llvm/test/Transforms/HotColdSplit/token-arg.ll b/llvm/test/Transforms/HotColdSplit/token-arg.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/HotColdSplit/token-arg.ll
@@ -0,0 +1,37 @@
+; RUN: opt -S -hotcoldsplit < %s | FileCheck %s
+
+; CHECK-LABEL: define void @zot(
+; CHECK: invoke void @barney()
+; CHECK-NOT: @zot.cold
+; CHECK-LABEL: bb3:
+; CHECK: call void @barney() [ "funclet"(token %tmp) ]
+; CHECK-LABEL: bb4:
+; CHECK: call void @barney.1() [ "funclet"(token %tmp) ]
+; CHECK-NOT: @zot.cold
+
+define void @zot() personality i8* bitcast (i32 (...)* @bar to i8*) {
+bb:
+  invoke void @barney()
+          to label %bb1 unwind label %bb2
+
+bb1:                                              ; preds = %bb
+  ret void
+
+bb2:                                              ; preds = %bb
+  %tmp = cleanuppad within none []
+  br label %bb3
+
+bb3:                                              ; preds = %bb2
+  call void @barney() [ "funclet"(token %tmp) ]
+  br label %bb4
+
+bb4:                                              ; preds = %bb3
+  call void @barney.1() [ "funclet"(token %tmp) ]
+  unreachable
+}
+
+declare void @barney()
+
+declare i32 @bar(...)
+
+declare void @barney.1()
diff --git a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
--- a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
@@ -60,7 +60,9 @@
                                            getBlockByName(Func, "body2") };
 
   CodeExtractor CE(Candidates);
-  EXPECT_TRUE(CE.isEligible());
+  SetVector<Value *> Inputs, Outputs, Sinks;
+  CE.findInputsOutputs(Inputs, Outputs, Sinks, /*InputsOnly=*/true);
+  EXPECT_TRUE(CE.isEligible(Inputs));
 
   CodeExtractorAnalysisCache CEAC(*Func);
   Function *Outlined = CE.extractCodeRegion(CEAC);
@@ -111,7 +113,9 @@
   };
 
   CodeExtractor CE(ExtractedBlocks);
-  EXPECT_TRUE(CE.isEligible());
+  SetVector<Value *> Inputs, Outputs, Sinks;
+  CE.findInputsOutputs(Inputs, Outputs, Sinks, /*InputsOnly=*/true);
+  EXPECT_TRUE(CE.isEligible(Inputs));
 
   CodeExtractorAnalysisCache CEAC(*Func);
   Function *Outlined = CE.extractCodeRegion(CEAC);
@@ -186,7 +190,9 @@
   };
 
   CodeExtractor CE(ExtractedBlocks);
-  EXPECT_TRUE(CE.isEligible());
+  SetVector<Value *> Inputs, Outputs, Sinks;
+  CE.findInputsOutputs(Inputs, Outputs, Sinks, /*InputsOnly=*/true);
+  EXPECT_TRUE(CE.isEligible(Inputs));
 
   CodeExtractorAnalysisCache CEAC(*Func);
   Function *Outlined = CE.extractCodeRegion(CEAC);
@@ -221,7 +227,9 @@
                                        getBlockByName(Func, "lpad") };
 
   CodeExtractor CE(Blocks);
-  EXPECT_TRUE(CE.isEligible());
+  SetVector<Value *> Inputs, Outputs, Sinks;
+  CE.findInputsOutputs(Inputs, Outputs, Sinks, /*InputsOnly=*/true);
+  EXPECT_TRUE(CE.isEligible(Inputs));
 
   CodeExtractorAnalysisCache CEAC(*Func);
   Function *Outlined = CE.extractCodeRegion(CEAC);
@@ -273,7 +281,9 @@
   SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "if.else") };
   AssumptionCache AC(*Func);
   CodeExtractor CE(Blocks, nullptr, false, nullptr, nullptr, &AC);
-  EXPECT_TRUE(CE.isEligible());
+  SetVector<Value *> Inputs, Outputs, Sinks;
+  CE.findInputsOutputs(Inputs, Outputs, Sinks, /*InputsOnly=*/true);
+  EXPECT_TRUE(CE.isEligible(Inputs));
 
   CodeExtractorAnalysisCache CEAC(*Func);
   Function *Outlined = CE.extractCodeRegion(CEAC);