Teach UnrolledInstAnalyzer which loop it is simulating.

The analyzer now receives the loop (const Loop *L) at construction and
simplifyInstWithSCEV bails out when the AddRec's loop differs from L, so an
AddRec is only evaluated at IterationNumber for the loop being simulated.
Both call sites are updated, and a unit test on a two-level loop nest checks
that the outer IV is simplified while the inner IV is not.

Index: include/llvm/Analysis/LoopUnrollAnalyzer.h
===================================================================
--- include/llvm/Analysis/LoopUnrollAnalyzer.h
+++ include/llvm/Analysis/LoopUnrollAnalyzer.h
@@ -48,8 +48,8 @@
 public:
   UnrolledInstAnalyzer(unsigned Iteration,
                        DenseMap<Value *, Constant *> &SimplifiedValues,
-                       ScalarEvolution &SE)
-      : SimplifiedValues(SimplifiedValues), SE(SE) {
+                       ScalarEvolution &SE, const Loop *L)
+      : SimplifiedValues(SimplifiedValues), SE(SE), L(L) {
       IterationNumber = SE.getConstant(APInt(64, Iteration));
   }
 
@@ -80,6 +80,7 @@
   DenseMap<Value *, Constant *> &SimplifiedValues;
 
   ScalarEvolution &SE;
+  const Loop *L;
 
   bool simplifyInstWithSCEV(Instruction *I);
 
Index: lib/Analysis/LoopUnrollAnalyzer.cpp
===================================================================
--- lib/Analysis/LoopUnrollAnalyzer.cpp
+++ lib/Analysis/LoopUnrollAnalyzer.cpp
@@ -40,6 +40,9 @@
   if (!AR)
     return false;
 
+  if (AR->getLoop() != L)
+    return false;
+
   const SCEV *ValueAtIteration = AR->evaluateAtIteration(IterationNumber, SE);
   // Check if the AddRec expression becomes a constant.
   if (auto *SC = dyn_cast<SCEVConstant>(ValueAtIteration)) {
Index: lib/Transforms/Scalar/LoopUnrollPass.cpp
===================================================================
--- lib/Transforms/Scalar/LoopUnrollPass.cpp
+++ lib/Transforms/Scalar/LoopUnrollPass.cpp
@@ -265,7 +265,7 @@
     while (!SimplifiedInputValues.empty())
       SimplifiedValues.insert(SimplifiedInputValues.pop_back_val());
 
-    UnrolledInstAnalyzer Analyzer(Iteration, SimplifiedValues, SE);
+    UnrolledInstAnalyzer Analyzer(Iteration, SimplifiedValues, SE, L);
 
     BBWorklist.clear();
     BBWorklist.insert(L->getHeader());
Index: unittests/Analysis/UnrollAnalyzer.cpp
===================================================================
--- unittests/Analysis/UnrollAnalyzer.cpp
+++ unittests/Analysis/UnrollAnalyzer.cpp
@@ -38,7 +38,7 @@
     TripCount = SE->getSmallConstantTripCount(L, Exiting);
     for (unsigned Iteration = 0; Iteration < TripCount; Iteration++) {
       DenseMap<Value *, Constant *> SimplifiedValues;
-      UnrolledInstAnalyzer Analyzer(Iteration, SimplifiedValues, *SE);
+      UnrolledInstAnalyzer Analyzer(Iteration, SimplifiedValues, *SE, L);
       for (auto *BB : L->getBlocks())
         for (Instruction &I : *BB)
           Analyzer.visit(I);
@@ -124,6 +124,53 @@
   EXPECT_TRUE(I2 != SimplifiedValuesVector[TripCount - 1].end());
   EXPECT_TRUE(dyn_cast<ConstantInt>((*I2).second)->getZExtValue());
 }
+
+TEST(UnrollAnalyzerTest, OuterLoopSimplification) {
+  const char *ModuleStr =
+      "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
+      "define void @foo() {\n"
+      "entry:\n"
+      " br label %outer.loop\n"
+      "outer.loop:\n"
+      " %iv.outer = phi i64 [ 0, %entry ], [ %iv.outer.next, %outer.loop.latch ]\n"
+      " br label %inner.loop\n"
+      "inner.loop:\n"
+      " %iv.inner = phi i64 [ 0, %outer.loop ], [ %iv.inner.next, %inner.loop ]\n"
+      " %iv.inner.next = add nuw nsw i64 %iv.inner, 1\n"
+      " %exitcond.inner = icmp eq i64 %iv.inner.next, 1000\n"
+      " br i1 %exitcond.inner, label %outer.loop.latch, label %inner.loop\n"
+      "outer.loop.latch:\n"
+      " %iv.outer.next = add nuw nsw i64 %iv.outer, 1\n"
+      " %exitcond.outer = icmp eq i64 %iv.outer.next, 40\n"
+      " br i1 %exitcond.outer, label %exit, label %outer.loop\n"
+      "exit:\n"
+      " ret void\n"
+      "}\n";
+
+  UnrollAnalyzerTest *P = new UnrollAnalyzerTest();
+  std::unique_ptr<Module> M = makeLLVMModule(P, ModuleStr);
+  legacy::PassManager Passes;
+  Passes.add(P);
+  Passes.run(*M);
+
+  Module::iterator MI = M->begin();
+  Function *F = &*MI++;
+  Function::iterator FI = F->begin();
+  FI++;
+  BasicBlock *Header = &*FI++;
+  BasicBlock *InnerBody = &*FI++;
+
+  BasicBlock::iterator BBI = Header->begin();
+  Instruction *Y1 = &*BBI++;
+  BBI = InnerBody->begin();
+  Instruction *Y2 = &*BBI++;
+  // Check that we can simplify IV of the outer loop, but can't simplify the IV
+  // of the inner loop if we only know the iteration number of the outer loop.
+  auto I1 = SimplifiedValuesVector[0].find(Y1);
+  EXPECT_TRUE(I1 != SimplifiedValuesVector[0].end());
+  auto I2 = SimplifiedValuesVector[0].find(Y2);
+  EXPECT_TRUE(I2 == SimplifiedValuesVector[0].end());
+}
 } // end namespace llvm
 
 INITIALIZE_PASS_BEGIN(UnrollAnalyzerTest, "unrollanalyzertestpass",