diff --git a/llvm/include/llvm/Transforms/IPO/HotColdSplitting.h b/llvm/include/llvm/Transforms/IPO/HotColdSplitting.h --- a/llvm/include/llvm/Transforms/IPO/HotColdSplitting.h +++ b/llvm/include/llvm/Transforms/IPO/HotColdSplitting.h @@ -12,7 +12,9 @@ #ifndef LLVM_TRANSFORMS_IPO_HOTCOLDSPLITTING_H #define LLVM_TRANSFORMS_IPO_HOTCOLDSPLITTING_H +#include "llvm/ADT/StringSet.h" #include "llvm/IR/PassManager.h" +#include "llvm/Support/SpecialCaseList.h" namespace llvm { @@ -37,7 +39,8 @@ function_ref GTTI, std::function *GORE, function_ref LAC) - : PSI(ProfSI), GetBFI(GBFI), GetTTI(GTTI), GetORE(GORE), LookupAC(LAC) {} + : PSI(ProfSI), GetBFI(GBFI), GetTTI(GTTI), GetORE(GORE), LookupAC(LAC), + FileMarkedColdFunctions(nullptr) {} bool run(Module &M); private: @@ -55,6 +58,8 @@ function_ref GetTTI; std::function *GetORE; function_ref LookupAC; + StringSet<> CmdMarkedColdFunctions; + std::unique_ptr FileMarkedColdFunctions; }; /// Pass to outline cold regions. @@ -66,4 +71,3 @@ } // end namespace llvm #endif // LLVM_TRANSFORMS_IPO_HOTCOLDSPLITTING_H - diff --git a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp --- a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp +++ b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp @@ -29,6 +29,7 @@ #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringSet.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/BranchProbabilityInfo.h" @@ -59,6 +60,8 @@ #include "llvm/Support/BranchProbability.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/SpecialCaseList.h" +#include "llvm/Support/VirtualFileSystem.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Scalar.h" @@ -69,6 +72,10 @@ #include "llvm/Transforms/Utils/ValueMapper.h" #include #include +#include +#include 
+#include <sstream>
+#include <string>
 
 #define DEBUG_TYPE "hotcoldsplit"
 
@@ -85,6 +92,18 @@
                        cl::desc("Base penalty for splitting cold code (as a "
                                 "multiple of TCC_Basic)"));
 
+// Optional user-supplied lists of functions to treat as cold.
+static cl::opt<std::string>
+    ColdFunctionsList("cold-functions-list", cl::init(""), cl::Hidden,
+                      cl::desc("Comma-separated list of functions to mark"
+                               " as cold during hot/cold splitting."));
+
+static cl::opt<std::string>
+    ColdFunctionsFile("cold-functions-file", cl::init(""), cl::Hidden,
+                      cl::desc("File name containing a newline-separated list"
+                               " of function names to mark as cold during"
+                               " hot/cold splitting."));
+
 namespace {
 // Same as blockEndsInUnreachable in CodeGen/BranchFolding.cpp. Do not modify
 // this function unless you modify the MBB version as well.
@@ -202,6 +221,22 @@
   if (PSI->isFunctionEntryCold(&F))
     return true;
 
+  // Alternatively, if user supplies any extra information
+  // on cold functions via command-line or file input,
+  // use them to determine if function is cold or not.
+  // Names given on the command line are matched exactly.
+  if (CmdMarkedColdFunctions.count(F.getName())) {
+    LLVM_DEBUG(dbgs() << "isFunctionCold: " << F.getName() << " is cold "
+                      << " via command line info.\n");
+    return true;
+  }
+
+  if (FileMarkedColdFunctions &&
+      FileMarkedColdFunctions->inSection("", "", F.getName())) {
+    LLVM_DEBUG(dbgs() << "isFunctionCold: " << F.getName() << " is cold "
+                      << " via file info.\n");
+    return true;
+  }
   return false;
 }
 
@@ -656,6 +691,32 @@
 bool HotColdSplitting::run(Module &M) {
   bool Changed = false;
   bool HasProfileSummary = (M.getProfileSummary(/* IsCS */ false) != nullptr);
+
+  // Read in user-defined cold function names, if any.
+  if (!ColdFunctionsList.empty()) {
+    LLVM_DEBUG(dbgs() << "Reading in cold functions from command line.\n");
+    // Skip empty entries (e.g. after a trailing comma): an empty string in
+    // the set would match any function whose name is empty.
+    SmallVector<StringRef, 8> CFNames;
+    StringRef(ColdFunctionsList).split(CFNames, ',', -1, /*KeepEmpty=*/false);
+    for (StringRef CFName : CFNames) {
+      LLVM_DEBUG(dbgs() << " Function " << CFName
+                        << " listed as cold from command line.\n");
+      CmdMarkedColdFunctions.insert(CFName);
+    }
+  }
+
+  // Read in user-defined cold function names supplied by a file.
+  if (!ColdFunctionsFile.empty()) {
+    // Parsed with SpecialCaseList. NOTE(review): confirm the line syntax it
+    // expects; createOrDie() presumably aborts on an unreadable/invalid file.
+    LLVM_DEBUG(dbgs() << "Reading in functions from file "
+                      << ColdFunctionsFile << "\n");
+    std::unique_ptr<vfs::FileSystem> FS = vfs::createPhysicalFileSystem();
+    FileMarkedColdFunctions =
+        SpecialCaseList::createOrDie({ColdFunctionsFile}, *FS);
+  }
+
   for (auto It = M.begin(), End = M.end(); It != End; ++It) {
     Function &F = *It;