diff --git a/llvm/lib/Analysis/LoopInfo.cpp b/llvm/lib/Analysis/LoopInfo.cpp
--- a/llvm/lib/Analysis/LoopInfo.cpp
+++ b/llvm/lib/Analysis/LoopInfo.cpp
@@ -301,15 +301,16 @@
   if (!CmpInst)
     return nullptr;
 
-  Instruction *LatchCmpOp0 = dyn_cast<Instruction>(CmpInst->getOperand(0));
-  Instruction *LatchCmpOp1 = dyn_cast<Instruction>(CmpInst->getOperand(1));
+  Value *LatchCmpOp0 = CmpInst->getOperand(0);
+  Value *LatchCmpOp1 = CmpInst->getOperand(1);
 
   for (PHINode &IndVar : Header->phis()) {
     InductionDescriptor IndDesc;
     if (!InductionDescriptor::isInductionPHI(&IndVar, this, &SE, IndDesc))
       continue;
 
-    Instruction *StepInst = IndDesc.getInductionBinOp();
+    BasicBlock *Latch = getLoopLatch();
+    Value *StepInst = IndVar.getIncomingValueForBlock(Latch);
 
     // case 1:
     // IndVar = phi[{InitialValue, preheader}, {StepInst, latch}]
diff --git a/llvm/unittests/Analysis/LoopInfoTest.cpp b/llvm/unittests/Analysis/LoopInfoTest.cpp
--- a/llvm/unittests/Analysis/LoopInfoTest.cpp
+++ b/llvm/unittests/Analysis/LoopInfoTest.cpp
@@ -1547,3 +1547,48 @@
     EXPECT_EQ(L->getLoopGuardBranch(), nullptr);
   });
 }
+
+// Loop::getInductionVariable should pick the header phi that the latch
+// compare uses. Besides %count.07 (an i32 stepped by -1), the header holds
+// %addr.addr.06, a pointer advanced by a constant-offset GEP, which comes
+// first; the latch compares %count.07 against the constant 1, so %count.07
+// is the expected result.
+TEST(LoopInfoTest, LoopInductionVariable) {
+  const char *ModuleStr =
+      "define i32 @foo(i32* %addr) {\n"
+      "entry:\n"
+      " br label %for.body\n"
+      "for.body:\n"
+      " %sum.08 = phi i32 [ 0, %entry ], [ %add, %for.body ]\n"
+      " %addr.addr.06 = phi i32* [ %addr, %entry ], [ %incdec.ptr, %for.body "
+      "]\n"
+      " %count.07 = phi i32 [ 6000, %entry ], [ %dec, %for.body ]\n"
+      " %0 = load i32, i32* %addr.addr.06, align 4\n"
+      " %add = add nsw i32 %0, %sum.08\n"
+      " %dec = add nsw i32 %count.07, -1\n"
+      " %incdec.ptr = getelementptr inbounds i32, i32* %addr.addr.06, i64 1\n"
+      " %cmp = icmp ugt i32 %count.07, 1\n"
+      " br i1 %cmp, label %for.body, label %for.end\n"
+      "for.end:\n"
+      " %cmp1 = icmp eq i32 %add, -1\n"
+      " %conv = zext i1 %cmp1 to i32\n"
+      " ret i32 %conv\n"
+      "}\n";
+
+  // Parse the module.
+  LLVMContext Context;
+  std::unique_ptr<Module> M = makeLLVMModule(Context, ModuleStr);
+
+  runWithLoopInfoPlus(
+      *M, "foo", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+        Function::iterator FI = F.begin();
+        // The first block is %entry; the loop header %for.body follows it.
+        BasicBlock *Header = &*(++FI);
+        Loop *L = LI.getLoopFor(Header);
+        // ASSERT_* rather than EXPECT_*: both pointers are dereferenced next.
+        ASSERT_NE(L, nullptr);
+        PHINode *IndVar = L->getInductionVariable(SE);
+        ASSERT_NE(IndVar, nullptr);
+        EXPECT_EQ(IndVar->getName(), "count.07");
+      });
+}