diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
--- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -496,17 +496,29 @@ for (Value *V : FI.LinearIVUses) {
     for (Value *U : V->users()) {
       if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
-        // The IV is used as the operand of a GEP, and the IV is at least as
-        // wide as the address space of the GEP. In this case, the GEP would
-        // wrap around the address space before the IV increment wraps, which
-        // would be UB.
-        if (GEP->isInBounds() &&
-            V->getType()->getIntegerBitWidth() >=
-                DL.getPointerTypeSizeInBits(GEP->getType())) {
-          LLVM_DEBUG(
-              dbgs() << "use of linear IV would be UB if overflow occurred: ";
-              GEP->dump());
-          return OverflowResult::NeverOverflows;
+        for (Value *GEPUser : U->users()) {
+          // Every user of a GEP instruction is itself an instruction; cast<>
+          // asserts that, so the isa<> checks below never see a null pointer.
+          auto *GEPUserInst = cast<Instruction>(GEPUser);
+          if (!isa<LoadInst>(GEPUserInst) &&
+              !(isa<StoreInst>(GEPUserInst) &&
+                GEP == GEPUserInst->getOperand(1)))
+            continue;
+          if (!isGuaranteedToExecuteForEveryIteration(GEPUserInst,
+                                                      FI.InnerLoop))
+            continue;
+          // The IV is used as the operand of a GEP which dominates the loop
+          // latch, and the IV is at least as wide as the address space of the
+          // GEP. In this case, the GEP would wrap around the address space
+          // before the IV increment wraps, which would be UB.
+          if (GEP->isInBounds() &&
+              V->getType()->getIntegerBitWidth() >=
+                  DL.getPointerTypeSizeInBits(GEP->getType())) {
+            LLVM_DEBUG(
+                dbgs() << "use of linear IV would be UB if overflow occurred: ";
+                GEP->dump());
+            return OverflowResult::NeverOverflows;
+          }
         }
       }
     }
   }
diff --git a/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll b/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll
--- a/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll
+++ b/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll
@@ -786,6 +786,54 @@
   ret void
 }
 
+; GEP doesn't dominate the loop latch so can't guarantee N*M won't overflow.
+@first = global i32 1, align 4 ; flag, initially nonzero; gates the GEP+load in if.then
+@a = external global [0 x i8], align 1
+define void @overflow(i32 %lim, i8* %a) { ; nested loop: %lim outer x 100000 inner iterations
+entry:
+  %cmp17.not = icmp eq i32 %lim, 0
+  br i1 %cmp17.not, label %for.cond.cleanup, label %for.cond1.preheader.preheader
+
+for.cond1.preheader.preheader:
+  br label %for.cond1.preheader
+
+for.cond1.preheader:
+  %i.018 = phi i32 [ %inc6, %for.cond.cleanup3 ], [ 0, %for.cond1.preheader.preheader ]
+  %mul = mul i32 %i.018, 100000 ; outer IV scaled by the inner trip count; no nuw/nsw, so it may wrap
+  br label %for.body4
+
+for.cond.cleanup.loopexit:
+  br label %for.cond.cleanup
+
+for.cond.cleanup:
+  ret void
+
+for.cond.cleanup3:
+  %inc6 = add i32 %i.018, 1
+  %cmp = icmp ult i32 %inc6, %lim
+  br i1 %cmp, label %for.cond1.preheader, label %for.cond.cleanup.loopexit
+
+for.body4:
+  %j.016 = phi i32 [ 0, %for.cond1.preheader ], [ %inc, %if.end ]
+  %add = add i32 %j.016, %mul ; linear index i*100000 + j
+  %0 = load i32, i32* @first, align 4
+  %tobool.not = icmp eq i32 %0, 0
+  br i1 %tobool.not, label %if.end, label %if.then ; if.then is entered only while @first is nonzero
+
+if.then:
+  %arrayidx = getelementptr inbounds [0 x i8], [0 x i8]* @a, i32 0, i32 %add ; the only inbounds GEP on the linear index, and it is conditional
+  %1 = load i8, i8* %arrayidx, align 1
+  tail call void asm sideeffect "", "r"(i8 %1)
+  store i32 0, i32* @first, align 4 ; clear the flag
+  br label %if.end
+
+if.end:
+  tail call void asm sideeffect "", "r"(i32 %add) ; unconditional use of the linear index
+  %inc = add nuw nsw i32 %j.016, 1
+  %cmp2 = icmp ult i32 %j.016, 99999 ; inner loop covers j = 0..99999
+  br i1 %cmp2, label %for.body4, label %for.cond.cleanup3
+}
+
 declare void @objc_enumerationMutation(i8*)
 declare dso_local void @f(i32*)
 declare dso_local void @g(...)