Index: llvm/lib/Transforms/Scalar/GuardWidening.cpp =================================================================== --- llvm/lib/Transforms/Scalar/GuardWidening.cpp +++ llvm/lib/Transforms/Scalar/GuardWidening.cpp @@ -42,6 +42,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" @@ -116,6 +117,7 @@ DominatorTree &DT; PostDominatorTree *PDT; LoopInfo &LI; + AssumptionCache &AC; MemorySSAUpdater *MSSAU; /// Together, these describe the region of interest. This might be all of @@ -273,10 +275,10 @@ public: explicit GuardWideningImpl(DominatorTree &DT, PostDominatorTree *PDT, - LoopInfo &LI, MemorySSAUpdater *MSSAU, - DomTreeNode *Root, - std::function<bool(BasicBlock *)> BlockFilter) - : DT(DT), PDT(PDT), LI(LI), MSSAU(MSSAU), Root(Root), + LoopInfo &LI, AssumptionCache &AC, + MemorySSAUpdater *MSSAU, DomTreeNode *Root, + std::function<bool(BasicBlock *)> BlockFilter) + : DT(DT), PDT(PDT), LI(LI), AC(AC), MSSAU(MSSAU), Root(Root), BlockFilter(BlockFilter) {} /// The entry point for this pass. 
@@ -468,7 +470,7 @@ if (!Inst || DT.dominates(Inst, Loc) || Visited.count(Inst)) return true; - if (!isSafeToSpeculativelyExecute(Inst, Loc, nullptr, &DT) || + if (!isSafeToSpeculativelyExecute(Inst, Loc, &AC, &DT) || Inst->mayReadFromMemory()) return false; @@ -488,7 +490,7 @@ if (!Inst || DT.dominates(Inst, Loc)) return; - assert(isSafeToSpeculativelyExecute(Inst, Loc, nullptr, &DT) && + assert(isSafeToSpeculativelyExecute(Inst, Loc, &AC, &DT) && !Inst->mayReadFromMemory() && "Should've checked with isAvailableAt!"); for (Value *Op : Inst->operands()) @@ -764,11 +766,12 @@ auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &LI = AM.getResult<LoopAnalysis>(F); auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); + auto &AC = AM.getResult<AssumptionAnalysis>(F); auto *MSSAA = AM.getCachedResult<MemorySSAAnalysis>(F); std::unique_ptr<MemorySSAUpdater> MSSAU; if (MSSAA) MSSAU = std::make_unique<MemorySSAUpdater>(&MSSAA->getMSSA()); - if (!GuardWideningImpl(DT, &PDT, LI, MSSAU ? MSSAU.get() : nullptr, + if (!GuardWideningImpl(DT, &PDT, LI, AC, MSSAU ? MSSAU.get() : nullptr, DT.getRootNode(), [](BasicBlock *) { return true; }) .run()) return PreservedAnalyses::all(); @@ -791,8 +794,10 @@ std::unique_ptr<MemorySSAUpdater> MSSAU; if (AR.MSSA) MSSAU = std::make_unique<MemorySSAUpdater>(AR.MSSA); - if (!GuardWideningImpl(AR.DT, nullptr, AR.LI, MSSAU ? MSSAU.get() : nullptr, - AR.DT.getNode(RootBB), BlockFilter).run()) + if (!GuardWideningImpl(AR.DT, nullptr, AR.LI, AR.AC, + MSSAU ? MSSAU.get() : nullptr, AR.DT.getNode(RootBB), + BlockFilter) + .run()) return PreservedAnalyses::all(); auto PA = getLoopPassPreservedAnalyses(); @@ -814,12 +819,13 @@ return false; auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); auto *MSSAWP = getAnalysisIfAvailable<MemorySSAWrapperPass>(); std::unique_ptr<MemorySSAUpdater> MSSAU; if (MSSAWP) MSSAU = std::make_unique<MemorySSAUpdater>(&MSSAWP->getMSSA()); - return GuardWideningImpl(DT, &PDT, LI, MSSAU ? MSSAU.get() : nullptr, + return GuardWideningImpl(DT, &PDT, LI, AC, MSSAU ? 
MSSAU.get() : nullptr, DT.getRootNode(), [](BasicBlock *) { return true; }) .run(); @@ -848,6 +854,8 @@ return false; auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache( + *L->getHeader()->getParent()); auto *PDTWP = getAnalysisIfAvailable<PostDominatorTreeWrapperPass>(); auto *PDT = PDTWP ? &PDTWP->getPostDomTree() : nullptr; auto *MSSAWP = getAnalysisIfAvailable<MemorySSAWrapperPass>(); @@ -861,8 +869,9 @@ auto BlockFilter = [&](BasicBlock *BB) { return BB == RootBB || L->contains(BB); }; - return GuardWideningImpl(DT, PDT, LI, MSSAU ? MSSAU.get() : nullptr, - DT.getNode(RootBB), BlockFilter).run(); + return GuardWideningImpl(DT, PDT, LI, AC, MSSAU ? MSSAU.get() : nullptr, + DT.getNode(RootBB), BlockFilter) + .run(); } void getAnalysisUsage(AnalysisUsage &AU) const override {