Avoid hoisting possibly-trapping divisions in SCEVExpander (PR30935).

SCEVExpander::InsertBinop now hoists a UDiv out of loops only when its RHS is
a non-zero ConstantInt, and the insertion-point computation in the expander
leaves the insertion point where it is for any SCEV that contains a udiv whose
divisor is not a non-zero constant. Also fixes an "an multiply" comment typo
and adds a LoopIdiom regression test plus a SCEVExpander unit test.

Index: llvm/lib/Analysis/ScalarEvolution.cpp
===================================================================
--- llvm/lib/Analysis/ScalarEvolution.cpp
+++ llvm/lib/Analysis/ScalarEvolution.cpp
@@ -2130,7 +2130,7 @@
   }
 
   // Okay, check to see if the same value occurs in the operand list more than
-  // once. If so, merge them together into an multiply expression. Since we
+  // once. If so, merge them together into a multiply expression. Since we
   // sorted the list, these values are required to be adjacent.
   Type *Ty = Ops[0]->getType();
   bool FoundMatch = false;
Index: llvm/lib/Analysis/ScalarEvolutionExpander.cpp
===================================================================
--- llvm/lib/Analysis/ScalarEvolutionExpander.cpp
+++ llvm/lib/Analysis/ScalarEvolutionExpander.cpp
@@ -166,6 +166,16 @@
   return ReuseOrCreateCast(I, Ty, Op, IP);
 }
 
+// Return true when V may have the value zero.
+static inline bool mayBeValueZero(Value *V) {
+  if (ConstantInt *C = dyn_cast<ConstantInt>(V))
+    if (!C->isZero())
+      return false;
+
+  // All other expressions may have a zero value.
+  return true;
+}
+
 /// InsertBinop - Insert the specified binary operator, doing a small amount
 /// of work to avoid inserting an obviously redundant operation.
 Value *SCEVExpander::InsertBinop(Instruction::BinaryOps Opcode,
@@ -198,14 +208,17 @@
   DebugLoc Loc = Builder.GetInsertPoint()->getDebugLoc();
   SCEVInsertPointGuard Guard(Builder, this);
 
-  // Move the insertion point out of as many loops as we can.
-  while (const Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock())) {
-    if (!L->isLoopInvariant(LHS) || !L->isLoopInvariant(RHS)) break;
-    BasicBlock *Preheader = L->getLoopPreheader();
-    if (!Preheader) break;
+  // Only move the insertion point up when it is not a division by zero.
+  if (Opcode != Instruction::UDiv || !mayBeValueZero(RHS)) {
+    // Move the insertion point out of as many loops as we can.
+    while (const Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock())) {
+      if (!L->isLoopInvariant(LHS) || !L->isLoopInvariant(RHS)) break;
+      BasicBlock *Preheader = L->getLoopPreheader();
+      if (!Preheader) break;
 
-    // Ok, move up a level.
-    Builder.SetInsertPoint(Preheader->getTerminator());
+      // Ok, move up a level.
+      Builder.SetInsertPoint(Preheader->getTerminator());
+    }
   }
 
   // If we haven't found this binop, insert it.
@@ -1663,31 +1676,46 @@
   // Compute an insertion point for this SCEV object. Hoist the instructions
   // as far out in the loop nest as possible.
   Instruction *InsertPt = &*Builder.GetInsertPoint();
-  for (Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock());;
-       L = L->getParentLoop())
-    if (SE.isLoopInvariant(S, L)) {
-      if (!L) break;
-      if (BasicBlock *Preheader = L->getLoopPreheader())
-        InsertPt = Preheader->getTerminator();
-      else {
-        // LSR sets the insertion point for AddRec start/step values to the
-        // block start to simplify value reuse, even though it's an invalid
-        // position. SCEVExpander must correct for this in all cases.
-        InsertPt = &*L->getHeader()->getFirstInsertionPt();
-      }
-    } else {
-      // If the SCEV is computable at this level, insert it into the header
-      // after the PHIs (and after any other instructions that we've inserted
-      // there) so that it is guaranteed to dominate any user inside the loop.
-      if (L && SE.hasComputableLoopEvolution(S, L) && !PostIncLoops.count(L))
-        InsertPt = &*L->getHeader()->getFirstInsertionPt();
-      while (InsertPt->getIterator() != Builder.GetInsertPoint() &&
-             (isInsertedInstruction(InsertPt) ||
-              isa<DbgInfoIntrinsic>(InsertPt))) {
-        InsertPt = &*std::next(InsertPt->getIterator());
+  if (!SCEVExprContains(S, [](const SCEV *S) {
+        if (const auto *D = dyn_cast<SCEVUDivExpr>(S)) {
+          if (const auto *SC = dyn_cast<SCEVConstant>(D->getRHS()))
+            if (!SC->getValue()->isZero())
+              // Division by non-zero constants can be hoisted.
+              return false;
+
+          // All other divisions should not be moved as they may be divisions by
+          // zero and should be kept within the conditions of the surrounding
+          // loops that guard their execution (see PR30935.)
+          return true;
+        }
+        return false;
+      })) {
+    for (Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock());;
+         L = L->getParentLoop())
+      if (SE.isLoopInvariant(S, L)) {
+        if (!L) break;
+        if (BasicBlock *Preheader = L->getLoopPreheader())
+          InsertPt = Preheader->getTerminator();
+        else {
+          // LSR sets the insertion point for AddRec start/step values to the
+          // block start to simplify value reuse, even though it's an invalid
+          // position. SCEVExpander must correct for this in all cases.
+          InsertPt = &*L->getHeader()->getFirstInsertionPt();
+        }
+      } else {
+        // If the SCEV is computable at this level, insert it into the header
+        // after the PHIs (and after any other instructions that we've inserted
+        // there) so that it is guaranteed to dominate any user inside the loop.
+        if (L && SE.hasComputableLoopEvolution(S, L) && !PostIncLoops.count(L))
+          InsertPt = &*L->getHeader()->getFirstInsertionPt();
+        while (InsertPt->getIterator() != Builder.GetInsertPoint() &&
+               (isInsertedInstruction(InsertPt) ||
+                isa<DbgInfoIntrinsic>(InsertPt))) {
+          InsertPt = &*std::next(InsertPt->getIterator());
+        }
+        break;
       }
-      break;
-    }
+  }
 
   // Check to see if we already expanded this here.
   auto I = InsertedExpressions.find(std::make_pair(S, InsertPt));
Index: llvm/test/Transforms/LoopIdiom/pr30935.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/LoopIdiom/pr30935.ll
@@ -0,0 +1,94 @@
+; RUN: opt -loop-idiom -S < %s | FileCheck %s
+
+; CHECK-LABEL: define i32 @main(
+; CHECK: udiv
+; CHECK-NOT: udiv
+; CHECK: call void @llvm.memset.p0i8.i64
+
+@a = common local_unnamed_addr global [4 x i8] zeroinitializer, align 1
+@b = common local_unnamed_addr global i32 0, align 4
+@c = common local_unnamed_addr global i32 0, align 4
+@d = common local_unnamed_addr global i32 0, align 4
+@e = common local_unnamed_addr global i32 0, align 4
+@f = common local_unnamed_addr global i32 0, align 4
+@g = common local_unnamed_addr global i32 0, align 4
+@h = common local_unnamed_addr global i64 0, align 8
+
+
+define i32 @main() local_unnamed_addr #0 {
+entry:
+  %0 = load i32, i32* @e, align 4
+  %tobool19 = icmp eq i32 %0, 0
+  %1 = load i32, i32* @c, align 4
+  %cmp10 = icmp eq i32 %1, 0
+  %2 = load i32, i32* @g, align 4
+  %3 = load i32, i32* @b, align 4
+  %tobool = icmp eq i32 %0, 0
+  br label %for.cond
+
+for.cond.loopexit:                                ; preds = %for.inc14
+  br label %for.cond.backedge
+
+for.cond:                                         ; preds = %for.cond.backedge, %entry
+  %.pr = load i32, i32* @f, align 4
+  %cmp20 = icmp eq i32 %.pr, 0
+  br i1 %cmp20, label %for.cond2.preheader.preheader, label %for.cond.backedge
+
+for.cond.backedge:                                ; preds = %for.cond, %for.cond.loopexit
+  br label %for.cond
+
+for.cond2.preheader.preheader:                    ; preds = %for.cond
+  br label %for.cond2.preheader
+
+for.cond2.preheader:                              ; preds = %for.cond2.preheader.preheader, %for.inc14
+  br i1 %tobool19, label %for.cond9, label %for.body3.lr.ph
+
+for.body3.lr.ph:                                  ; preds = %for.cond2.preheader
+  %div = udiv i32 %2, %3
+  %conv = zext i32 %div to i64
+  br label %for.body3
+
+for.cond4.for.cond2.loopexit_crit_edge:           ; preds = %for.body6
+  store i32 0, i32* @d, align 4
+  br label %for.cond2.loopexit
+
+for.cond2.loopexit:                               ; preds = %for.cond4.for.cond2.loopexit_crit_edge, %for.body3
+  br i1 %tobool, label %for.cond2.for.cond9_crit_edge, label %for.body3
+
+for.body3:                                        ; preds = %for.body3.lr.ph, %for.cond2.loopexit
+  %.pr17 = load i32, i32* @d, align 4
+  %tobool518 = icmp eq i32 %.pr17, 0
+  br i1 %tobool518, label %for.cond2.loopexit, label %for.body6.preheader
+
+for.body6.preheader:                              ; preds = %for.body3
+  %4 = zext i32 %.pr17 to i64
+  br label %for.body6
+
+for.body6:                                        ; preds = %for.body6.preheader, %for.body6
+  %indvars.iv = phi i64 [ %4, %for.body6.preheader ], [ %indvars.iv.next, %for.body6 ]
+  %add = add nuw nsw i64 %conv, %indvars.iv
+  %arrayidx = getelementptr inbounds [4 x i8], [4 x i8]* @a, i64 0, i64 %add
+  store i8 1, i8* %arrayidx, align 1
+  %5 = trunc i64 %indvars.iv to i32
+  %inc = add i32 %5, 1
+  %tobool5 = icmp eq i32 %inc, 0
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  br i1 %tobool5, label %for.cond4.for.cond2.loopexit_crit_edge, label %for.body6
+
+for.cond2.for.cond9_crit_edge:                    ; preds = %for.cond2.loopexit
+  store i64 %conv, i64* @h, align 8
+  br label %for.cond9
+
+for.cond9:                                        ; preds = %for.cond2.for.cond9_crit_edge, %for.cond2.preheader
+  br i1 %cmp10, label %for.body12, label %for.inc14
+
+for.body12:                                       ; preds = %for.cond9
+  ret i32 0
+
+for.inc14:                                        ; preds = %for.cond9
+  %6 = load i32, i32* @f, align 4
+  %inc15 = add i32 %6, 1
+  store i32 %inc15, i32* @f, align 4
+  %cmp = icmp eq i32 %inc15, 0
+  br i1 %cmp, label %for.cond2.preheader, label %for.cond.loopexit
+}
Index: llvm/unittests/Analysis/ScalarEvolutionTest.cpp
===================================================================
--- llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -349,6 +349,13 @@
   llvm_unreachable("Expected to find instruction!");
 }
 
+static Argument *getArgByName(Function &F, StringRef Name) {
+  for (auto &A : F.args())
+    if (A.getName() == Name)
+      return &A;
+  llvm_unreachable("Expected to find argument!");
+}
+
 TEST_F(ScalarEvolutionsTest, CommutativeExprOperandOrder) {
   LLVMContext C;
   SMDiagnostic Err;
@@ -532,5 +539,74 @@
   EXPECT_NE(nullptr, SE.getSCEV(Acc[0]));
 }
 
+TEST_F(ScalarEvolutionsTest, BadHoistingSCEVExpander_PR30942) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M = parseAssemblyString(
+      "target datalayout = \"e-m:e-p:32:32-f64:32:64-f80:32-n8:16:32-S128\" "
+      " "
+      "define void @f_1(i32 %x, i32 %y, i32 %n, i1* %cond_buf) "
+      " local_unnamed_addr { "
+      "entry: "
+      " %entrycond = icmp sgt i32 %n, 0 "
+      " br i1 %entrycond, label %loop.ph, label %for.end "
+      " "
+      "loop.ph: "
+      " br label %loop "
+      " "
+      "loop: "
+      " %iv1 = phi i32 [ %iv1.inc, %right ], [ 0, %loop.ph ] "
+      " %iv1.inc = add nuw nsw i32 %iv1, 1 "
+      " %cond = load volatile i1, i1* %cond_buf "
+      " br i1 %cond, label %left, label %right "
+      " "
+      "left: "
+      " %div = udiv i32 %x, %y "
+      " br label %right "
+      " "
+      "right: "
+      " %exitcond = icmp eq i32 %iv1.inc, %n "
+      " br i1 %exitcond, label %for.end.loopexit, label %loop "
+      " "
+      "for.end.loopexit: "
+      " br label %for.end "
+      " "
+      "for.end: "
+      " ret void "
+      "} ",
+      Err, C);
+
+  assert(M && "Could not parse module?");
+  assert(!verifyModule(*M) && "Must have been well formed!");
+
+  runWithFunctionAndSE(*M, "f_1", [&](Function &F, ScalarEvolution &SE) {
+    SCEVExpander Expander(SE, M->getDataLayout(), "unittests");
+    auto *DivInst = getInstructionByName(F, "div");
+
+    {
+      auto *DivSCEV = SE.getSCEV(DivInst);
+      auto *DivExpansion = Expander.expandCodeFor(
+          DivSCEV, DivSCEV->getType(), DivInst->getParent()->getTerminator());
+      auto *DivExpansionInst = dyn_cast<Instruction>(DivExpansion);
+      ASSERT_NE(DivExpansionInst, nullptr);
+      EXPECT_EQ(DivInst->getParent(), DivExpansionInst->getParent());
+    }
+
+    {
+      auto *ArgY = getArgByName(F, "y");
+      auto *DivFromScratchSCEV =
+          SE.getUDivExpr(SE.getOne(ArgY->getType()), SE.getSCEV(ArgY));
+
+      auto *DivFromScratchExpansion = Expander.expandCodeFor(
+          DivFromScratchSCEV, DivFromScratchSCEV->getType(),
+          DivInst->getParent()->getTerminator());
+      auto *DivFromScratchExpansionInst =
+          dyn_cast<Instruction>(DivFromScratchExpansion);
+      ASSERT_NE(DivFromScratchExpansionInst, nullptr);
+      EXPECT_EQ(DivInst->getParent(), DivFromScratchExpansionInst->getParent());
+    }
+  });
+}
+
 } // end anonymous namespace
 } // end namespace llvm