diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp --- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -161,10 +161,11 @@ class LoopIdiomRecognize { Loop *CurLoop = nullptr; - LoopNest *LN; + LoopNest *LN = nullptr; Loop *TopLoop = nullptr; Loop *FallBackLoop = nullptr; BasicBlock *RuntimeCheckBB = nullptr; + BranchInst *RuntimeCheckBI = nullptr; AliasAnalysis *AA; DominatorTree *DT; LoopInfo *LI; @@ -461,9 +462,7 @@ // the RuntimeCheckBB. Conditions are stored when: // - detect runtime store size in StridedStore (SizeAddrSpacePairList) if (Changed && isTopLoopVersioned()) { - // Get the branch instruction in the runtime check basic block. - BranchInst *BI = dyn_cast(RuntimeCheckBB->getTerminator()); - assert(BI && "Expects a BranchInst"); + assert(RuntimeCheckBI && "should be fetched when calling versionTopLoop()"); // Create conditional branch instructions with conditions: // - Store size overflow half of the width of the pointer @@ -472,20 +471,18 @@ LLVMContext &Context = TopLoop->getHeader()->getContext(); Value *Cond = ConstantInt::getFalse(Context); - IRBuilder<> Builder(BI); + IRBuilder<> Builder(RuntimeCheckBI); for (auto Pair : *SizeAddrSpacePairList) { const SCEV *Ev = Pair.first; unsigned AddrSpace = Pair.second; Value *NewCond0 = - generateOverflowPredicate(Ev, AddrSpace, BI, DL, SE, Builder); - Value *NewCond1 = generateSltZeroPredicate(Ev, BI, DL, SE, Builder); + generateOverflowPredicate(Ev, AddrSpace, RuntimeCheckBI, DL, SE, Builder); + Value *NewCond1 = generateSltZeroPredicate(Ev, RuntimeCheckBI, DL, SE, Builder); Cond = Builder.CreateOr(Cond, NewCond0); Cond = Builder.CreateOr(Cond, NewCond1); } - BranchInst::Create(FallBackLoop->getLoopPreheader(), - LN->getOutermostLoop().getLoopPreheader(), Cond, BI); - deleteDeadInstruction(BI); + RuntimeCheckBI->setCondition(Cond); } return Changed; @@ -1689,9 +1686,10 @@ LoopVersioning LV(LAI, LAI.getRuntimePointerChecking()->getChecks(), TopLoop, LI, DT, SE); - LV.versionLoopWithPlainRuntimeCheck(); + LV.versionLoop(); RuntimeCheckBB = LV.getRuntimeCheckBB(); + RuntimeCheckBI = LV.getRuntimeCheckBI(); FallBackLoop = LV.getNonVersionedLoop(); }