diff --git a/llvm/lib/Analysis/ScalarEvolutionExpander.cpp b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp --- a/llvm/lib/Analysis/ScalarEvolutionExpander.cpp +++ b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp @@ -2196,20 +2196,27 @@ // UDivExpr is very likely a UDiv that ScalarEvolution's HowFarToZero or // HowManyLessThans produced to compute a precise expression, rather than a // UDiv from the user's code. If we can't find a UDiv in the code with some - // simple searching, assume the former consider UDivExpr expensive to - // compute. + // simple searching, we need to account for its cost. + BasicBlock *ExitingBB = L->getExitingBlock(); - if (!ExitingBB) - return true; + if (At || ExitingBB) { + if (!At) + At = &ExitingBB->back(); + + // At the beginning of this function we already tried to find existing + // value for plain 'S'. Now try to look up 'S + 1' since it is a common + // pattern involving division. This is just a simple search heuristic. + if (getRelatedExistingExpansion( + SE.getAddExpr(S, SE.getConstant(S->getType(), 1)), At, L)) + return false; // Consider it to be free. + } - // At the beginning of this function we already tried to find existing value - // for plain 'S'. Now try to lookup 'S + 1' since it is common pattern - // involving division. This is just a simple search heuristic. - if (!At) - At = &ExitingBB->back(); - if (!getRelatedExistingExpansion( - SE.getAddExpr(S, SE.getConstant(S->getType(), 1)), At, L)) - return true; + // Need to count the cost of this UDiv.
+ BudgetRemaining -= TTI.getOperationCost(Instruction::UDiv, S->getType()); + return isHighCostExpansionHelper(UDivExpr->getLHS(), L, At, BudgetRemaining, + TTI, Processed) || + isHighCostExpansionHelper(UDivExpr->getRHS(), L, At, BudgetRemaining, + TTI, Processed); } // HowManyLessThans uses a Max expression whenever the loop is not guarded by diff --git a/llvm/test/Transforms/IndVarSimplify/exit_value_test2.ll b/llvm/test/Transforms/IndVarSimplify/exit_value_test2.ll --- a/llvm/test/Transforms/IndVarSimplify/exit_value_test2.ll +++ b/llvm/test/Transforms/IndVarSimplify/exit_value_test2.ll @@ -19,6 +19,9 @@ ; CHECK-NEXT: [[CMP8:%.*]] = icmp ugt i32 [[LEN:%.*]], 11 ; CHECK-NEXT: br i1 [[CMP8]], label [[WHILE_BODY_LR_PH:%.*]], label [[WHILE_END:%.*]] ; CHECK: while.body.lr.ph: +; CHECK-NEXT: [[TMP0:%.*]] = add i32 [[LEN]], -12 +; CHECK-NEXT: [[TMP1:%.*]] = udiv i32 [[TMP0]], 12 +; CHECK-NEXT: [[TMP2:%.*]] = mul i32 [[TMP1]], 12 ; CHECK-NEXT: br label [[WHILE_BODY:%.*]] ; CHECK: while.body: ; CHECK-NEXT: [[KEYLEN_010:%.*]] = phi i32 [ [[LEN]], [[WHILE_BODY_LR_PH]] ], [ [[SUB:%.*]], [[WHILE_BODY]] ] @@ -36,10 +39,10 @@ ; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[SUB]], 11 ; CHECK-NEXT: br i1 [[CMP]], label [[WHILE_BODY]], label [[WHILE_COND_WHILE_END_CRIT_EDGE:%.*]] ; CHECK: while.cond.while.end_crit_edge: -; CHECK-NEXT: [[SUB_LCSSA:%.*]] = phi i32 [ [[SUB]], [[WHILE_BODY]] ] +; CHECK-NEXT: [[TMP3:%.*]] = sub i32 [[TMP0]], [[TMP2]] ; CHECK-NEXT: br label [[WHILE_END]] ; CHECK: while.end: -; CHECK-NEXT: [[KEYLEN_0_LCSSA:%.*]] = phi i32 [ [[SUB_LCSSA]], [[WHILE_COND_WHILE_END_CRIT_EDGE]] ], [ [[LEN]], [[ENTRY:%.*]] ] +; CHECK-NEXT: [[KEYLEN_0_LCSSA:%.*]] = phi i32 [ [[TMP3]], [[WHILE_COND_WHILE_END_CRIT_EDGE]] ], [ [[LEN]], [[ENTRY:%.*]] ] ; CHECK-NEXT: call void @_Z3mixRjj(i32* dereferenceable(4) [[A]], i32 [[KEYLEN_0_LCSSA]]) ; CHECK-NEXT: [[T4:%.*]] = load i32, i32* [[A]], align 4 ; CHECK-NEXT: call void @llvm.lifetime.end.p0i8(i64 4, i8* [[T]])