Index: llvm/lib/Analysis/LoopInfo.cpp
===================================================================
--- llvm/lib/Analysis/LoopInfo.cpp
+++ llvm/lib/Analysis/LoopInfo.cpp
@@ -301,15 +301,20 @@
   if (!CmpInst)
     return nullptr;
 
-  Instruction *LatchCmpOp0 = dyn_cast<Instruction>(CmpInst->getOperand(0));
-  Instruction *LatchCmpOp1 = dyn_cast<Instruction>(CmpInst->getOperand(1));
+  Value *LatchCmpOp0 = CmpInst->getOperand(0);
+  Value *LatchCmpOp1 = CmpInst->getOperand(1);
+  BasicBlock *Latch = getLoopLatch();
+  assert(Latch && "Expected a loop latch");
 
   for (PHINode &IndVar : Header->phis()) {
     InductionDescriptor IndDesc;
     if (!InductionDescriptor::isInductionPHI(&IndVar, this, &SE, IndDesc))
       continue;
 
-    Instruction *StepInst = IndDesc.getInductionBinOp();
+    // Use the value fed back along the latch edge rather than
+    // IndDesc.getInductionBinOp(), which is not set for every induction kind
+    // (e.g. pointer inductions).
+    Value *StepInst = IndVar.getIncomingValueForBlock(Latch);
 
     // case 1:
     // IndVar = phi[{InitialValue, preheader}, {StepInst, latch}]
Index: llvm/unittests/Analysis/LoopInfoTest.cpp
===================================================================
--- llvm/unittests/Analysis/LoopInfoTest.cpp
+++ llvm/unittests/Analysis/LoopInfoTest.cpp
@@ -1547,3 +1547,46 @@
         EXPECT_EQ(L->getLoopGuardBranch(), nullptr);
       });
 }
+
+TEST(LoopInfoTest, LoopInductionVariable) {
+  // The header has a pointer induction (%addr.010) before the integer one
+  // (%count.08), and the latch compare has a constant operand. Only
+  // %count.08 feeds the compare, so it must be the reported IV.
+  const char *ModuleStr =
+      "define i32 @foo(i16* %hdrptr) {\n"
+      "entry:\n"
+      "  br label %for.body\n"
+      "for.body:\n"
+      "  %addr.010 = phi i16* [ %hdrptr, %entry ], [ %incdec.ptr, %for.body ]\n"
+      "  %sum.09 = phi i64 [ 0, %entry ], [ %add, %for.body ]\n"
+      "  %count.08 = phi i32 [ 6000, %entry ], [ %dec, %for.body ]\n"
+      "  %0 = load i16, i16* %addr.010, align 2\n"
+      "  %conv = zext i16 %0 to i64\n"
+      "  %add = add i64 %sum.09, %conv\n"
+      "  %dec = add nsw i32 %count.08, -1\n"
+      "  %incdec.ptr = getelementptr inbounds i16, i16* %addr.010, i64 1\n"
+      "  %cmp = icmp ugt i32 %count.08, 1\n"
+      "  br i1 %cmp, label %for.body, label %for.end\n"
+      "for.end:\n"
+      "  %conv27 = and i64 %add, 65535\n"
+      "  %cmp3 = icmp eq i64 %conv27, 65535\n"
+      "  %conv4 = zext i1 %cmp3 to i32\n"
+      "  ret i32 %conv4\n"
+      "}\n";
+
+  // Parse the module.
+  LLVMContext Context;
+  std::unique_ptr<Module> M = makeLLVMModule(Context, ModuleStr);
+
+  runWithLoopInfoPlus(
+      *M, "foo", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+        Function::iterator FI = F.begin();
+        // First basic block is entry - skip it.
+        BasicBlock *Header = &*(++FI);
+        Loop *L = LI.getLoopFor(Header);
+        ASSERT_NE(L, nullptr);
+        PHINode *IndVar = L->getInductionVariable(SE);
+        ASSERT_NE(IndVar, nullptr);
+        EXPECT_EQ(IndVar->getName(), "count.08");
+      });
+}