Index: llvm/include/llvm/Transforms/Utils/CodeExtractor.h
===================================================================
--- llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -110,7 +110,9 @@
     ///
-    /// Based on the blocks used when constructing the code extractor,
-    /// determine whether it is eligible for extraction.
-    bool isEligible() const { return !Blocks.empty(); }
+    /// Based on the blocks used when constructing the code extractor and on
+    /// the given \p inputs, determine whether it is eligible for extraction.
+    bool isEligible(const ValueSet &inputs) const {
+      return !Blocks.empty() && validateInputDataDependencies(inputs);
+    }
 
     /// Compute the set of input values and output values for the code.
     ///
@@ -121,7 +123,9 @@
     /// sets, before extraction occurs. These modifications won't have any
-    /// significant impact on the cost however.
+    /// significant impact on the cost however. When \p InputsOnly is true,
+    /// \p Outputs is left untouched.
     void findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs,
-                           const ValueSet &Allocas) const;
+                           const ValueSet &Allocas,
+                           bool InputsOnly = false) const;
 
     /// Check if life time marker nodes can be hoisted/sunk into the outline
     /// region.
@@ -165,6 +169,10 @@
     void severSplitPHINodesOfExits(const SmallPtrSetImpl<BasicBlock *> &Exits);
     void splitReturnBlocks();
 
+    /// Check whether every value in \p inputs can be passed as an argument
+    /// to the new function. Returns false when any argument is invalid.
+    bool validateInputDataDependencies(const ValueSet &inputs) const;
+
     Function *constructFunction(const ValueSet &inputs,
                                 const ValueSet &outputs,
                                 BasicBlock *header,
Index: llvm/lib/Transforms/Utils/CodeExtractor.cpp
===================================================================
--- llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -538,7 +538,8 @@
 }
 
 void CodeExtractor::findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs,
-                                      const ValueSet &SinkCands) const {
+                                      const ValueSet &SinkCands,
+                                      bool InputsOnly) const {
   for (BasicBlock *BB : Blocks) {
     // If a used value is defined outside the region, it's an input. If an
     // instruction is used outside the region, it's an output.
@@ -549,7 +550,8 @@
         if (!SinkCands.count(V) && definedInCaller(Blocks, V))
           Inputs.insert(V);
       }
-
+      if (InputsOnly)
+        continue; // Skip output collection, but keep scanning for inputs.
       for (User *U : II.users())
         if (!definedInRegion(Blocks, U)) {
           Outputs.insert(&II);
@@ -1333,8 +1335,14 @@
 }
 
 Function *CodeExtractor::extractCodeRegion() {
-  if (!isEligible())
+  ValueSet inputs, outputs, SinkingCands, HoistingCands;
+  findInputsOutputs(inputs, outputs, SinkingCands, true);
+
+  if (!isEligible(inputs))
     return nullptr;
+  // The eligibility check is done; clear inputs, as they are recomputed
+  // once the CFG has been prepared for splitting.
+  inputs.clear();
 
   // Assumption: this is a single-entry code region, and the header is the first
   // block in the region.
@@ -1359,7 +1367,6 @@
       return nullptr;
     }
   }
-  ValueSet inputs, outputs, SinkingCands, HoistingCands;
   BasicBlock *CommonExit = nullptr;
 
   // Calculate the entry frequency of the new function before we change the root
@@ -1563,5 +1570,16 @@
   });
   LLVM_DEBUG(if (verifyFunction(*oldFunction))
              report_fatal_error("verification of oldFunction failed!"));
+  LLVM_DEBUG(if (!validateInputDataDependencies(inputs))
+             report_fatal_error("verification of newFunction args failed!"));
   return newFunction;
 }
+
+bool CodeExtractor::validateInputDataDependencies(const ValueSet &inputs) const {
+  for (const Value *Arg : inputs) {
+    Type *Ty = Arg->getType();
+    if (!Ty->isFirstClassType() || Ty->isMetadataTy() || Ty->isTokenTy())
+      return false;
+  }
+  return true;
+}
Index: llvm/test/Transforms/HotColdSplit/token-arg.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/HotColdSplit/token-arg.ll
@@ -0,0 +1,37 @@
+; RUN: opt -S -hotcoldsplit < %s | FileCheck %s
+
+; CHECK-LABEL: @zot
+; CHECK-LABEL: bb
+; CHECK: invoke void @barney()
+; CHECK-LABEL: bb3
+; CHECK: call void @barney()
+; CHECK-LABEL: bb4
+; CHECK: call void @barney.1()
+; CHECK-NOT: .cold
+
+define void @zot() personality i8* bitcast (i32 (...)* @bar to i8*) {
+bb:
+  invoke void @barney()
+          to label %bb1 unwind label %bb2
+
+bb1: ; preds = %bb
+  ret void
+
+bb2: ; preds = %bb
+  %tmp = cleanuppad within none []
+  br label %bb3
+
+bb3: ; preds = %bb2
+  call void @barney() [ "funclet"(token %tmp) ]
+  br label %bb4
+
+bb4: ; preds = %bb3
+  call void @barney.1() [ "funclet"(token %tmp) ]
+  unreachable
+}
+
+declare void @barney()
+
+declare i32 @bar(...)
+
+declare void @barney.1()
Index: llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
===================================================================
--- llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
+++ llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
@@ -59,7 +59,9 @@
                                            getBlockByName(Func, "body2") };
 
   CodeExtractor CE(Candidates);
-  EXPECT_TRUE(CE.isEligible());
+  SetVector<Value *> Inputs, Outputs, Sinks;
+  CE.findInputsOutputs(Inputs, Outputs, Sinks, true);
+  EXPECT_TRUE(CE.isEligible(Inputs));
 
   Function *Outlined = CE.extractCodeRegion();
   EXPECT_TRUE(Outlined);
@@ -109,7 +111,9 @@
   };
 
   CodeExtractor CE(ExtractedBlocks);
-  EXPECT_TRUE(CE.isEligible());
+  SetVector<Value *> Inputs, Outputs, Sinks;
+  CE.findInputsOutputs(Inputs, Outputs, Sinks, true);
+  EXPECT_TRUE(CE.isEligible(Inputs));
 
   Function *Outlined = CE.extractCodeRegion();
   EXPECT_TRUE(Outlined);
@@ -183,7 +187,9 @@
   };
 
   CodeExtractor CE(ExtractedBlocks);
-  EXPECT_TRUE(CE.isEligible());
+  SetVector<Value *> Inputs, Outputs, Sinks;
+  CE.findInputsOutputs(Inputs, Outputs, Sinks, true);
+  EXPECT_TRUE(CE.isEligible(Inputs));
 
   Function *Outlined = CE.extractCodeRegion();
   EXPECT_TRUE(Outlined);
@@ -217,7 +223,9 @@
                                        getBlockByName(Func, "lpad") };
 
   CodeExtractor CE(Blocks);
-  EXPECT_TRUE(CE.isEligible());
+  SetVector<Value *> Inputs, Outputs, Sinks;
+  CE.findInputsOutputs(Inputs, Outputs, Sinks, true);
+  EXPECT_TRUE(CE.isEligible(Inputs));
 
   Function *Outlined = CE.extractCodeRegion();
   EXPECT_TRUE(Outlined);