Index: llvm/include/llvm/Transforms/Utils/CodeExtractor.h
===================================================================
--- llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -110,7 +110,25 @@
     ///
     /// Based on the blocks used when constructing the code extractor,
     /// determine whether it is eligible for extraction.
-    bool isEligible() const { return !Blocks.empty(); }
+    ///
+    /// A region is also ineligible when any of its inputs cannot be passed as
+    /// an argument of the outlined function (e.g. token or metadata values).
+    /// This overload computes the inputs itself; see findInputs.
+    bool isEligible() const {
+      ValueSet Inputs, SinkCands;
+      findInputs(Inputs, SinkCands);
+      return isEligible(Inputs);
+    }
+
+    /// Like isEligible(), but takes the region's input values from \p inputs
+    /// (as computed by findInputs) instead of recomputing them.
+    bool isEligible(const ValueSet &inputs) const {
+      return !Blocks.empty() && validateInputDataDependencies(inputs);
+    }
+
+    /// Add to \p Inputs every value that is used in the region but defined
+    /// outside of it, except for values contained in \p SinkCands.
+    void findInputs(ValueSet &Inputs, const ValueSet &SinkCands) const;
 
     /// Compute the set of input values and output values for the code.
     ///
@@ -165,6 +183,11 @@
     void severSplitPHINodesOfExits(const SmallPtrSetImpl<BasicBlock *> &Exits);
     void splitReturnBlocks();
 
+    /// Returns true if every value in \p inputs can become an argument of the
+    /// outlined function, i.e. it has a first-class type that is neither a
+    /// token nor metadata.
+    bool validateInputDataDependencies(const ValueSet &inputs) const;
+
     Function *constructFunction(const ValueSet &inputs,
                                 const ValueSet &outputs,
                                 BasicBlock *header,
Index: llvm/lib/Transforms/Utils/CodeExtractor.cpp
===================================================================
--- llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -537,6 +537,19 @@
   }
 }
 
+void CodeExtractor::findInputs(ValueSet &Inputs,
+                               const ValueSet &SinkCands) const {
+  for (BasicBlock *BB : Blocks) {
+    // If a used value is defined outside the region, it's an input.
+    for (Instruction &II : *BB) {
+      for (Value *V : II.operands()) {
+        if (!SinkCands.count(V) && definedInCaller(Blocks, V))
+          Inputs.insert(V);
+      }
+    }
+  }
+}
+
 void CodeExtractor::findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs,
                                       const ValueSet &SinkCands) const {
   for (BasicBlock *BB : Blocks) {
@@ -1563,5 +1576,20 @@
   });
   LLVM_DEBUG(if (verifyFunction(*oldFunction))
              report_fatal_error("verification of oldFunction failed!"));
+  LLVM_DEBUG(if (!validateInputDataDependencies(inputs))
+             report_fatal_error("verification of newFunction args failed!"));
   return newFunction;
 }
+
+bool CodeExtractor::validateInputDataDependencies(
+    const ValueSet &inputs) const {
+  for (const Value *Arg : inputs) {
+    // Token and metadata types count as first-class, but the IR verifier only
+    // accepts parameters of those types on intrinsics, so such a value cannot
+    // be passed to the outlined function.
+    Type *Ty = Arg->getType();
+    if (!Ty->isFirstClassType() || Ty->isMetadataTy() || Ty->isTokenTy())
+      return false;
+  }
+  return true;
+}
Index: llvm/test/Transforms/HotColdSplit/token-arg.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/HotColdSplit/token-arg.ll
@@ -0,0 +1,40 @@
+; RUN: opt -S -hotcoldsplit < %s | FileCheck %s
+
+; %bb3 and %bb4 use the token %tmp produced by the cleanuppad in %bb2. A token
+; cannot be an argument of an outlined function, so no .cold function is made.
+
+; CHECK-LABEL: @zot
+; CHECK-LABEL: bb:
+; CHECK: invoke void @barney()
+; CHECK-LABEL: bb3:
+; CHECK: call void @barney()
+; CHECK-LABEL: bb4:
+; CHECK: call void @barney.1()
+; CHECK-NOT: .cold
+
+define void @zot() personality i8* bitcast (i32 (...)* @bar to i8*) {
+bb:
+  invoke void @barney()
+          to label %bb1 unwind label %bb2
+
+bb1:                                              ; preds = %bb
+  ret void
+
+bb2:                                              ; preds = %bb
+  %tmp = cleanuppad within none []
+  br label %bb3
+
+bb3:                                              ; preds = %bb2
+  call void @barney() [ "funclet"(token %tmp) ]
+  br label %bb4
+
+bb4:                                              ; preds = %bb3
+  call void @barney.1() [ "funclet"(token %tmp) ]
+  unreachable
+}
+
+declare void @barney()
+
+declare i32 @bar(...)
+
+declare void @barney.1()
Index: llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
===================================================================
--- llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
+++ llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
@@ -59,7 +59,9 @@
     getBlockByName(Func, "body2")
   };
   CodeExtractor CE(Candidates);
-  EXPECT_TRUE(CE.isEligible());
+  SetVector<Value *> Inputs, Sinks;
+  CE.findInputs(Inputs, Sinks);
+  EXPECT_TRUE(CE.isEligible(Inputs));
 
   Function *Outlined = CE.extractCodeRegion();
   EXPECT_TRUE(Outlined);
@@ -109,7 +111,9 @@
   };
 
   CodeExtractor CE(ExtractedBlocks);
-  EXPECT_TRUE(CE.isEligible());
+  SetVector<Value *> Inputs, Sinks;
+  CE.findInputs(Inputs, Sinks);
+  EXPECT_TRUE(CE.isEligible(Inputs));
 
   Function *Outlined = CE.extractCodeRegion();
   EXPECT_TRUE(Outlined);
@@ -183,7 +187,9 @@
   };
 
   CodeExtractor CE(ExtractedBlocks);
-  EXPECT_TRUE(CE.isEligible());
+  SetVector<Value *> Inputs, Sinks;
+  CE.findInputs(Inputs, Sinks);
+  EXPECT_TRUE(CE.isEligible(Inputs));
 
   Function *Outlined = CE.extractCodeRegion();
   EXPECT_TRUE(Outlined);
@@ -217,7 +223,9 @@
     getBlockByName(Func, "lpad")
   };
   CodeExtractor CE(Blocks);
-  EXPECT_TRUE(CE.isEligible());
+  SetVector<Value *> Inputs, Sinks;
+  CE.findInputs(Inputs, Sinks);
+  EXPECT_TRUE(CE.isEligible(Inputs));
 
   Function *Outlined = CE.extractCodeRegion();
   EXPECT_TRUE(Outlined);