Index: include/llvm/Analysis/ScalarEvolution.h =================================================================== --- include/llvm/Analysis/ScalarEvolution.h +++ include/llvm/Analysis/ScalarEvolution.h @@ -29,6 +29,7 @@ #include "llvm/ADT/Hashing.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/PointerIntPair.h" +#include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" @@ -1262,6 +1263,10 @@ /// Invalidate this result and free associated memory. void clear(); + + /// Insert all loops referred to by this BackedgeTakenCount into \p Result. + void findUsedLoops(ScalarEvolution &SE, + SmallPtrSet &Result) const; }; /// Cache the backedge-taken count of the loops for this function as they @@ -1764,14 +1769,20 @@ /// Find all of the loops transitively used in \p S, and update \c LoopUsers /// accordingly. void addToLoopUseLists(const SCEV *S); + void addToLoopUseLists(const BackedgeTakenInfo &BTI, const Loop *L); FoldingSet UniqueSCEVs; FoldingSet UniquePreds; BumpPtrAllocator SCEVAllocator; - /// This maps loops to a list of SCEV expressions that (transitively) use said - /// loop. - DenseMap> LoopUsers; + /// This maps loops to a list of entities that (transitively) use said loop. + /// A SCEV expression in the vector corresponding to a loop denotes that the + /// SCEV expression transitively uses said loop. A loop (LA) in the vector + /// corresponding to another loop (LB) denotes that LB is used in one of the + /// cached trip counts for LA. + DenseMap, 4>> + LoopUsers; /// Cache tentative mappings from UnknownSCEVs in a Loop, to a SCEV expression /// they can be rewritten into under certain predicates. Index: lib/Analysis/ScalarEvolution.cpp =================================================================== --- lib/Analysis/ScalarEvolution.cpp +++ lib/Analysis/ScalarEvolution.cpp @@ -6293,6 +6293,7 @@ BackedgeTakenInfo Result = computeBackedgeTakenCount(L, /*AllowPredicates=*/true); + addToLoopUseLists(Result, L); return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result); } @@ -6368,6 +6369,7 @@ // recusive call to getBackedgeTakenInfo (on a different // loop), which would invalidate the iterator computed // earlier. + addToLoopUseLists(Result, L); return BackedgeTakenCounts.find(L)->second = std::move(Result); } @@ -6405,8 +6407,14 @@ auto LoopUsersItr = LoopUsers.find(CurrL); if (LoopUsersItr != LoopUsers.end()) { - for (auto *S : LoopUsersItr->second) - forgetMemoizedResults(S); + for (auto LoopOrSCEV : LoopUsersItr->second) { + if (auto *S = LoopOrSCEV.dyn_cast()) + forgetMemoizedResults(S); + else { + BackedgeTakenCounts.erase(LoopOrSCEV.get()); + PredicatedBackedgeTakenCounts.erase(LoopOrSCEV.get()); + } + } LoopUsers.erase(LoopUsersItr); } @@ -6551,6 +6559,34 @@ return false; } +static void findUsedLoopsInSCEVExpr(const SCEV *S, + SmallPtrSet &Result) { + struct FindUsedLoops { + SmallPtrSetImpl &LoopsUsed; + FindUsedLoops(SmallPtrSetImpl &LoopsUsed) + : LoopsUsed(LoopsUsed) {} + bool follow(const SCEV *S) { + if (auto *AR = dyn_cast(S)) + LoopsUsed.insert(AR->getLoop()); + return true; + } + + bool isDone() const { return false; } + }; + FindUsedLoops F(Result); + SCEVTraversal(F).visitAll(S); +} + +void ScalarEvolution::BackedgeTakenInfo::findUsedLoops( + ScalarEvolution &SE, SmallPtrSet &Result) const { + if (auto *S = getMax()) + if (S != SE.getCouldNotCompute()) + findUsedLoopsInSCEVExpr(S, Result); + for (auto &ENT : ExitNotTaken) + if (ENT.ExactNotTaken != SE.getCouldNotCompute()) + findUsedLoopsInSCEVExpr(ENT.ExactNotTaken, Result); +} + ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E) : ExactNotTaken(E), MaxNotTaken(E) { assert((isa(MaxNotTaken) || @@ -11012,21 +11048,6 @@ ++I; } - auto RemoveSCEVFromBackedgeMap = - [S, this](DenseMap &Map) { - for (auto I = Map.begin(), E = Map.end(); I != E;) { - BackedgeTakenInfo &BEInfo = I->second; - if (BEInfo.hasOperand(S, this)) { - BEInfo.clear(); - Map.erase(I++); - } else - ++I; - } - }; - - RemoveSCEVFromBackedgeMap(BackedgeTakenCounts); - RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts); - // TODO: There is a suspicion that we only need to do it when there is a // SCEVUnknown somewhere inside S. Need to check this. if (EraseExitLimit) @@ -11036,22 +11057,20 @@ } void ScalarEvolution::addToLoopUseLists(const SCEV *S) { - struct FindUsedLoops { - SmallPtrSet LoopsUsed; - bool follow(const SCEV *S) { - if (auto *AR = dyn_cast(S)) - LoopsUsed.insert(AR->getLoop()); - return true; - } - - bool isDone() const { return false; } - }; + SmallPtrSet LoopsUsed; + findUsedLoopsInSCEVExpr(S, LoopsUsed); + for (auto *L : LoopsUsed) + LoopUsers[L].push_back({S}); +} - FindUsedLoops F; - SCEVTraversal(F).visitAll(S); +void ScalarEvolution::addToLoopUseLists( + const ScalarEvolution::BackedgeTakenInfo &BTI, const Loop *L) { + SmallPtrSet LoopsUsed; + BTI.findUsedLoops(*this, LoopsUsed); - for (auto *L : F.LoopsUsed) - LoopUsers[L].push_back(S); + for (auto *UsedL : LoopsUsed) { + LoopUsers[UsedL].push_back({L}); + } } void ScalarEvolution::verify() const { Index: unittests/Analysis/ScalarEvolutionTest.cpp =================================================================== --- unittests/Analysis/ScalarEvolutionTest.cpp +++ unittests/Analysis/ScalarEvolutionTest.cpp @@ -24,11 +24,19 @@ #include "llvm/IR/Module.h" #include "llvm/IR/Verifier.h" #include "llvm/Support/SourceMgr.h" +#include "gmock/gmock.h" #include "gtest/gtest.h" namespace llvm { namespace { +MATCHER_P3(IsAffineAddRec, S, X, L, "") { + if (auto *AR = dyn_cast(arg)) + return AR->isAffine() && AR->getLoop() == L && AR->getOperand(0) == S && + AR->getOperand(1) == X; + return false; +} + // We use this fixture to ensure that we clean up ScalarEvolution before // deleting the PassManager. class ScalarEvolutionsTest : public testing::Test { @@ -886,90 +894,6 @@ 2004u); } -// Make sure that SCEV invalidates exit limits after invalidating the values it -// depends on when we forget a value. -TEST_F(ScalarEvolutionsTest, SCEVExitLimitForgetValue) { - /* - * Create the following code: - * func(i64 addrspace(10)* %arg) - * top: - * br label %L.ph - * L.ph: - * %load = load i64 addrspace(10)* %arg - * br label %L - * L: - * %phi = phi i64 [i64 0, %L.ph], [ %add, %L2 ] - * %add = add i64 %phi2, 1 - * %cond = icmp slt i64 %add, %load ; then becomes 2000. - * br i1 %cond, label %post, label %L2 - * post: - * ret void - * - */ - - // Create a module with non-integral pointers in it's datalayout - Module NIM("nonintegral", Context); - std::string DataLayout = M.getDataLayoutStr(); - if (!DataLayout.empty()) - DataLayout += "-"; - DataLayout += "ni:10"; - NIM.setDataLayout(DataLayout); - - Type *T_int64 = Type::getInt64Ty(Context); - Type *T_pint64 = T_int64->getPointerTo(10); - - FunctionType *FTy = - FunctionType::get(Type::getVoidTy(Context), {T_pint64}, false); - Function *F = cast(NIM.getOrInsertFunction("foo", FTy)); - - Argument *Arg = &*F->arg_begin(); - - BasicBlock *Top = BasicBlock::Create(Context, "top", F); - BasicBlock *LPh = BasicBlock::Create(Context, "L.ph", F); - BasicBlock *L = BasicBlock::Create(Context, "L", F); - BasicBlock *Post = BasicBlock::Create(Context, "post", F); - - IRBuilder<> Builder(Top); - Builder.CreateBr(LPh); - - Builder.SetInsertPoint(LPh); - auto *Load = cast(Builder.CreateLoad(T_int64, Arg, "load")); - Builder.CreateBr(L); - - Builder.SetInsertPoint(L); - PHINode *Phi = Builder.CreatePHI(T_int64, 2); - auto *Add = cast( - Builder.CreateAdd(Phi, ConstantInt::get(T_int64, 1), "add")); - auto *Cond = cast( - Builder.CreateICmp(ICmpInst::ICMP_SLT, Add, Load, "cond")); - auto *Br = cast(Builder.CreateCondBr(Cond, L, Post)); - Phi->addIncoming(ConstantInt::get(T_int64, 0), LPh); - Phi->addIncoming(Add, L); - - Builder.SetInsertPoint(Post); - Builder.CreateRetVoid(); - - ScalarEvolution SE = buildSE(*F); - auto *Loop = LI->getLoopFor(L); - const SCEV *EC = SE.getBackedgeTakenCount(Loop); - EXPECT_FALSE(isa(EC)); - EXPECT_FALSE(isa(EC)); - - SE.forgetValue(Load); - Br->eraseFromParent(); - Cond->eraseFromParent(); - Load->eraseFromParent(); - - Builder.SetInsertPoint(L); - auto *NewCond = Builder.CreateICmp( - ICmpInst::ICMP_SLT, Add, ConstantInt::get(T_int64, 2000), "new.cond"); - Builder.CreateCondBr(NewCond, L, Post); - const SCEV *NewEC = SE.getBackedgeTakenCount(Loop); - EXPECT_FALSE(isa(NewEC)); - EXPECT_TRUE(isa(NewEC)); - EXPECT_EQ(cast(NewEC)->getAPInt().getLimitedValue(), 1999u); -} - TEST_F(ScalarEvolutionsTest, SCEVAddRecFromPHIwithLargeConstants) { // Reference: https://reviews.llvm.org/D37265 // Make sure that SCEV does not blow up when constructing an AddRec @@ -1082,6 +1006,75 @@ auto Result = SE.createAddRecFromPHIWithCasts(cast(Expr)); } +TEST_F(ScalarEvolutionsTest, SCEVForgetDependentLoop) { + LLVMContext C; + SMDiagnostic Err; + std::unique_ptr M = parseAssemblyString( + "target datalayout = \"e-m:e-p:32:32-f64:32:64-f80:32-n8:16:32-S128\" " + " " + "define void @f(i32 %first_limit, i1* %cond) { " + "entry: " + " br label %first_loop.ph " + " " + "first_loop.ph: " + " br label %first_loop " + " " + "first_loop: " + " %iv_first = phi i32 [0, %first_loop.ph], [%iv_first.inc, %first_loop] " + " %iv_first.inc = add i32 %iv_first, 1 " + " %known_cond = icmp slt i32 %iv_first, 2000 " + " %unknown_cond = load volatile i1, i1* %cond " + " br i1 %unknown_cond, label %first_loop, label %first_loop.exit " + " " + "first_loop.exit: " + " %iv_first.3x = mul i32 %iv_first, 3 " + " %iv_first.5x = mul i32 %iv_first, 5 " + " br label %second_loop.ph " + " " + "second_loop.ph: " + " br label %second_loop " + " " + "second_loop: " + " %iv_second = phi i32 [%iv_first.3x, %second_loop.ph], [%iv_second.inc, %second_loop] " + " %iv_second.inc = add i32 %iv_second, 1 " + " %second_loop.cond = icmp ne i32 %iv_second, %iv_first.5x " + " br i1 %second_loop.cond, label %second_loop, label %second_loop.exit " + " " + "second_loop.exit: " + " ret void " + "} " + " ", + Err, C); + + assert(M && "Could not parse module?"); + assert(!verifyModule(*M) && "Must have been well formed!"); + + runWithSE(*M, "f", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) { + auto &FirstIV = GetInstByName(F, "iv_first"); + auto &SecondIV = GetInstByName(F, "iv_second"); + + auto *FirstLoop = LI.getLoopFor(FirstIV.getParent()); + auto *SecondLoop = LI.getLoopFor(SecondIV.getParent()); + + auto *Zero = SE.getZero(FirstIV.getType()); + auto *Two = SE.getConstant(APInt(32, 2)); + + EXPECT_EQ(SE.getBackedgeTakenCount(FirstLoop), SE.getCouldNotCompute()); + EXPECT_THAT(SE.getBackedgeTakenCount(SecondLoop), + IsAffineAddRec(Zero, Two, FirstLoop)); + + auto &UnknownCond = GetInstByName(F, "unknown_cond"); + auto &KnownCond = GetInstByName(F, "known_cond"); + + UnknownCond.replaceAllUsesWith(&KnownCond); + + SE.forgetLoop(FirstLoop); + + EXPECT_EQ(SE.getBackedgeTakenCount(FirstLoop), SE.getConstant(APInt(32, 2000))); + EXPECT_EQ(SE.getBackedgeTakenCount(SecondLoop), SE.getConstant(APInt(32, 4000))); + }); +} + TEST_F(ScalarEvolutionsTest, SCEVFoldSumOfTruncs) { // Verify that the following SCEV gets folded to a zero: // (-1 * (trunc i64 (-1 * %0) to i32)) + (-1 * (trunc i64 %0 to i32)