Index: include/llvm/Analysis/AssumptionCache.h =================================================================== --- include/llvm/Analysis/AssumptionCache.h +++ include/llvm/Analysis/AssumptionCache.h @@ -208,6 +208,10 @@ /// existing cache will be returned. AssumptionCache &getAssumptionCache(Function &F); + /// Return the cached assumptions for a function if it has already been + /// scanned. Otherwise return nullptr. + AssumptionCache *lookupAssumptionCache(Function &F); + AssumptionCacheTracker(); ~AssumptionCacheTracker() override; Index: include/llvm/Transforms/Utils/CodeExtractor.h =================================================================== --- include/llvm/Transforms/Utils/CodeExtractor.h +++ include/llvm/Transforms/Utils/CodeExtractor.h @@ -26,6 +26,7 @@ class BlockFrequency; class BlockFrequencyInfo; class BranchProbabilityInfo; +class AssumptionCache; class CallInst; class DominatorTree; class Function; @@ -56,6 +57,7 @@ const bool AggregateArgs; BlockFrequencyInfo *BFI; BranchProbabilityInfo *BPI; + AssumptionCache *AC; // If true, varargs functions can be extracted. bool AllowVarArgs; @@ -84,6 +86,7 @@ CodeExtractor(ArrayRef BBs, DominatorTree *DT = nullptr, bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr, BranchProbabilityInfo *BPI = nullptr, + AssumptionCache *AC = nullptr, bool AllowVarArgs = false, bool AllowAlloca = false, std::string Suffix = ""); @@ -94,6 +97,7 @@ CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr, BranchProbabilityInfo *BPI = nullptr, + AssumptionCache *AC = nullptr, std::string Suffix = ""); /// Perform the extraction, returning the new function. 
Index: lib/Analysis/AssumptionCache.cpp =================================================================== --- lib/Analysis/AssumptionCache.cpp +++ lib/Analysis/AssumptionCache.cpp @@ -240,6 +240,13 @@ return *IP.first->second; } +AssumptionCache *AssumptionCacheTracker::lookupAssumptionCache(Function &F) { + auto I = AssumptionCaches.find_as(&F); + if (I != AssumptionCaches.end()) + return I->second.get(); + return nullptr; +} + void AssumptionCacheTracker::verifyAnalysis() const { // FIXME: In the long term the verifier should not be controllable with a // flag. We should either fix all passes to correctly update the assumption Index: lib/Transforms/IPO/HotColdSplitting.cpp =================================================================== --- lib/Transforms/IPO/HotColdSplitting.cpp +++ lib/Transforms/IPO/HotColdSplitting.cpp @@ -164,8 +164,9 @@ HotColdSplitting(ProfileSummaryInfo *ProfSI, function_ref GBFI, function_ref GTTI, - std::function *GORE) - : PSI(ProfSI), GetBFI(GBFI), GetTTI(GTTI), GetORE(GORE) {} + std::function *GORE, + function_ref LAC) + : PSI(ProfSI), GetBFI(GBFI), GetTTI(GTTI), GetORE(GORE), LookupAC(LAC) {} bool run(Module &M); private: @@ -174,11 +175,13 @@ bool outlineColdRegions(Function &F, bool HasProfileSummary); Function *extractColdRegion(const BlockSequence &Region, DominatorTree &DT, BlockFrequencyInfo *BFI, TargetTransformInfo &TTI, - OptimizationRemarkEmitter &ORE, unsigned Count); + OptimizationRemarkEmitter &ORE, + AssumptionCache *AC, unsigned Count); ProfileSummaryInfo *PSI; function_ref GetBFI; function_ref GetTTI; std::function *GetORE; + function_ref LookupAC; }; class HotColdSplittingLegacyPass : public ModulePass { @@ -189,10 +192,10 @@ } void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired(); AU.addRequired(); AU.addRequired(); AU.addRequired(); + AU.addUsedIfAvailable(); } bool runOnModule(Module &M) override; @@ -307,12 +310,13 @@ BlockFrequencyInfo *BFI, TargetTransformInfo &TTI, 
OptimizationRemarkEmitter &ORE, + AssumptionCache *AC, unsigned Count) { assert(!Region.empty()); // TODO: Pass BFI and BPI to update profile information. CodeExtractor CE(Region, &DT, /* AggregateArgs */ false, /* BFI */ nullptr, - /* BPI */ nullptr, /* AllowVarArgs */ false, + /* BPI */ nullptr, AC, /* AllowVarArgs */ false, /* AllowAlloca */ false, /* Suffix */ "cold." + std::to_string(Count)); @@ -568,6 +572,7 @@ TargetTransformInfo &TTI = GetTTI(F); OptimizationRemarkEmitter &ORE = (*GetORE)(F); + AssumptionCache *AC = LookupAC(F); // Find all cold regions. for (BasicBlock *BB : RPOT) { @@ -629,8 +634,8 @@ BB->dump(); }); - Function *Outlined = - extractColdRegion(SubRegion, *DT, BFI, TTI, ORE, OutlinedFunctionID); + Function *Outlined = extractColdRegion(SubRegion, *DT, BFI, TTI, ORE, AC, + OutlinedFunctionID); if (Outlined) { ++OutlinedFunctionID; Changed = true; @@ -689,17 +694,21 @@ ORE.reset(new OptimizationRemarkEmitter(&F)); return *ORE.get(); }; + auto LookupAC = [this](Function &F) -> AssumptionCache * { + if (auto *ACT = getAnalysisIfAvailable()) + return ACT->lookupAssumptionCache(F); + return nullptr; + }; - return HotColdSplitting(PSI, GBFI, GTTI, &GetORE).run(M); + return HotColdSplitting(PSI, GBFI, GTTI, &GetORE, LookupAC).run(M); } PreservedAnalyses HotColdSplittingPass::run(Module &M, ModuleAnalysisManager &AM) { auto &FAM = AM.getResult(M).getManager(); - std::function GetAssumptionCache = - [&FAM](Function &F) -> AssumptionCache & { - return FAM.getResult(F); + auto LookupAC = [&FAM](Function &F) -> AssumptionCache * { + return FAM.getCachedResult(F); }; auto GBFI = [&FAM](Function &F) { @@ -720,7 +729,7 @@ ProfileSummaryInfo *PSI = &AM.getResult(M); - if (HotColdSplitting(PSI, GBFI, GTTI, &GetORE).run(M)) + if (HotColdSplitting(PSI, GBFI, GTTI, &GetORE, LookupAC).run(M)) return PreservedAnalyses::none(); return PreservedAnalyses::all(); } Index: lib/Transforms/IPO/LoopExtractor.cpp 
=================================================================== --- lib/Transforms/IPO/LoopExtractor.cpp +++ lib/Transforms/IPO/LoopExtractor.cpp @@ -14,6 +14,7 @@ //===----------------------------------------------------------------------===// #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" @@ -50,6 +51,7 @@ AU.addRequiredID(LoopSimplifyID); AU.addRequired(); AU.addRequired(); + AU.addUsedIfAvailable(); } }; } @@ -138,7 +140,10 @@ if (ShouldExtractLoop) { if (NumLoops == 0) return Changed; --NumLoops; - CodeExtractor Extractor(DT, *L); + AssumptionCache *AC = nullptr; + if (auto *ACT = getAnalysisIfAvailable()) + AC = ACT->lookupAssumptionCache(*L->getHeader()->getParent()); + CodeExtractor Extractor(DT, *L, false, nullptr, nullptr, AC); if (Extractor.extractCodeRegion() != nullptr) { Changed = true; // After extraction, the loop is replaced by a function call, so Index: lib/Transforms/IPO/PartialInlining.cpp =================================================================== --- lib/Transforms/IPO/PartialInlining.cpp +++ lib/Transforms/IPO/PartialInlining.cpp @@ -199,10 +199,12 @@ PartialInlinerImpl( std::function *GetAC, + function_ref LookupAC, std::function *GTTI, Optional> GBFI, ProfileSummaryInfo *ProfSI) - : GetAssumptionCache(GetAC), GetTTI(GTTI), GetBFI(GBFI), PSI(ProfSI) {} + : GetAssumptionCache(GetAC), LookupAssumptionCache(LookupAC), + GetTTI(GTTI), GetBFI(GBFI), PSI(ProfSI) {} bool run(Module &M); // Main part of the transformation that calls helper functions to find @@ -222,9 +224,11 @@ // Two constructors, one for single region outlining, the other for // multi-region outlining. 
FunctionCloner(Function *F, FunctionOutliningInfo *OI, - OptimizationRemarkEmitter &ORE); + OptimizationRemarkEmitter &ORE, + function_ref LookupAC); FunctionCloner(Function *F, FunctionOutliningMultiRegionInfo *OMRI, - OptimizationRemarkEmitter &ORE); + OptimizationRemarkEmitter &ORE, + function_ref LookupAC); ~FunctionCloner(); // Prepare for function outlining: making sure there is only @@ -260,11 +264,13 @@ std::unique_ptr ClonedOMRI = nullptr; std::unique_ptr ClonedFuncBFI = nullptr; OptimizationRemarkEmitter &ORE; + function_ref LookupAC; }; private: int NumPartialInlining = 0; std::function *GetAssumptionCache; + function_ref LookupAssumptionCache; std::function *GetTTI; Optional> GetBFI; ProfileSummaryInfo *PSI; @@ -365,12 +371,17 @@ return ACT->getAssumptionCache(F); }; + auto LookupAssumptionCache = [ACT](Function &F) -> AssumptionCache * { + return ACT->lookupAssumptionCache(F); + }; + std::function GetTTI = [&TTIWP](Function &F) -> TargetTransformInfo & { return TTIWP->getTTI(F); }; - return PartialInlinerImpl(&GetAssumptionCache, &GetTTI, NoneType::None, PSI) + return PartialInlinerImpl(&GetAssumptionCache, LookupAssumptionCache, + &GetTTI, NoneType::None, PSI) .run(M); } }; @@ -948,8 +959,9 @@ } PartialInlinerImpl::FunctionCloner::FunctionCloner( - Function *F, FunctionOutliningInfo *OI, OptimizationRemarkEmitter &ORE) - : OrigFunc(F), ORE(ORE) { + Function *F, FunctionOutliningInfo *OI, OptimizationRemarkEmitter &ORE, + function_ref LookupAC) + : OrigFunc(F), ORE(ORE), LookupAC(LookupAC) { ClonedOI = llvm::make_unique(); // Clone the function, so that we can hack away on it. @@ -972,8 +984,9 @@ PartialInlinerImpl::FunctionCloner::FunctionCloner( Function *F, FunctionOutliningMultiRegionInfo *OI, - OptimizationRemarkEmitter &ORE) - : OrigFunc(F), ORE(ORE) { + OptimizationRemarkEmitter &ORE, + function_ref LookupAC) + : OrigFunc(F), ORE(ORE), LookupAC(LookupAC) { ClonedOMRI = llvm::make_unique(); // Clone the function, so that we can hack away on it. 
@@ -1111,7 +1124,9 @@ int CurrentOutlinedRegionCost = ComputeRegionCost(RegionInfo.Region); CodeExtractor CE(RegionInfo.Region, &DT, /*AggregateArgs*/ false, - ClonedFuncBFI.get(), &BPI, /* AllowVarargs */ false); + ClonedFuncBFI.get(), &BPI, + LookupAC(*RegionInfo.EntryBlock->getParent()), + /* AllowVarargs */ false); CE.findInputsOutputs(Inputs, Outputs, Sinks); @@ -1193,7 +1208,7 @@ // Extract the body of the if. Function *OutlinedFunc = CodeExtractor(ToExtract, &DT, /*AggregateArgs*/ false, - ClonedFuncBFI.get(), &BPI, + ClonedFuncBFI.get(), &BPI, LookupAC(*ClonedFunc), /* AllowVarargs */ true) .extractCodeRegion(); @@ -1257,7 +1272,7 @@ std::unique_ptr OMRI = computeOutliningColdRegionsInfo(F, ORE); if (OMRI) { - FunctionCloner Cloner(F, OMRI.get(), ORE); + FunctionCloner Cloner(F, OMRI.get(), ORE, LookupAssumptionCache); #ifndef NDEBUG if (TracePartialInlining) { @@ -1290,7 +1305,7 @@ if (!OI) return {false, nullptr}; - FunctionCloner Cloner(F, OI.get(), ORE); + FunctionCloner Cloner(F, OI.get(), ORE, LookupAssumptionCache); Cloner.NormalizeReturnBlock(); Function *OutlinedFunction = Cloner.doSingleRegionFunctionOutlining(); @@ -1484,6 +1499,10 @@ return FAM.getResult(F); }; + auto LookupAssumptionCache = [&FAM](Function &F) -> AssumptionCache * { + return FAM.getCachedResult(F); + }; + std::function GetBFI = [&FAM](Function &F) -> BlockFrequencyInfo & { return FAM.getResult(F); @@ -1496,7 +1515,8 @@ ProfileSummaryInfo *PSI = &AM.getResult(M); - if (PartialInlinerImpl(&GetAssumptionCache, &GetTTI, {GetBFI}, PSI) + if (PartialInlinerImpl(&GetAssumptionCache, LookupAssumptionCache, &GetTTI, + {GetBFI}, PSI) .run(M)) return PreservedAnalyses::none(); return PreservedAnalyses::all(); Index: lib/Transforms/Utils/CodeExtractor.cpp =================================================================== --- lib/Transforms/Utils/CodeExtractor.cpp +++ lib/Transforms/Utils/CodeExtractor.cpp @@ -20,6 +20,7 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" 
#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/BlockFrequencyInfoImpl.h" #include "llvm/Analysis/BranchProbabilityInfo.h" @@ -235,18 +236,20 @@ CodeExtractor::CodeExtractor(ArrayRef BBs, DominatorTree *DT, bool AggregateArgs, BlockFrequencyInfo *BFI, - BranchProbabilityInfo *BPI, bool AllowVarArgs, - bool AllowAlloca, std::string Suffix) + BranchProbabilityInfo *BPI, AssumptionCache *AC, + bool AllowVarArgs, bool AllowAlloca, + std::string Suffix) : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), - BPI(BPI), AllowVarArgs(AllowVarArgs), + BPI(BPI), AC(AC), AllowVarArgs(AllowVarArgs), Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)), Suffix(Suffix) {} CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs, BlockFrequencyInfo *BFI, - BranchProbabilityInfo *BPI, std::string Suffix) + BranchProbabilityInfo *BPI, AssumptionCache *AC, + std::string Suffix) : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), - BPI(BPI), AllowVarArgs(false), + BPI(BPI), AC(AC), AllowVarArgs(false), Blocks(buildExtractionBlockSet(L.getBlocks(), &DT, /* AllowVarArgs */ false, /* AllowAlloca */ false)), @@ -1414,6 +1417,10 @@ if (oldFunction->hasPersonalityFn()) newFunction->setPersonalityFn(oldFunction->getPersonalityFn()); + // Invalidate function's assumption cache to force re-scanning. + if (AC) + AC->clear(); + // Update the branch weights for the exit block. 
if (BFI && NumExitBlocks > 1) calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI); Index: test/Transforms/CodeExtractor/extract-assume.ll =================================================================== --- test/Transforms/CodeExtractor/extract-assume.ll +++ test/Transforms/CodeExtractor/extract-assume.ll @@ -0,0 +1,29 @@ +; RUN: opt -passes="function(slp-vectorizer),module(hotcoldsplit),function(slp-vectorizer,print<assumptions>)" -disable-output %s 2>&1 | FileCheck %s +; +; Make sure this compiles. Check that function assumption cache is refreshed +; after extracting blocks with assume calls from the function. + +; CHECK: Cached assumptions for function: fun +; CHECK-NEXT: Cached assumptions for function: fun.cold +; CHECK-NEXT: %cmp = icmp uge i32 %x, 64 + +declare void @fun2(i32) #0 + +define void @fun(i32 %x) { +entry: + br i1 undef, label %if.then, label %if.else + +if.then: + ret void + +if.else: + %cmp = icmp uge i32 %x, 64 + call void @llvm.assume(i1 %cmp) + call void @fun2(i32 %x) + unreachable +} + +declare void @llvm.assume(i1) #1 + +attributes #0 = { alwaysinline } +attributes #1 = { nounwind }