diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -1216,13 +1216,19 @@ // Collect information about PHI nodes which can be transformed in // rewriteLoopExitValues. struct RewritePhi { - PHINode *PN; - unsigned Ith; // Ith incoming value. - Value *Val; // Exit value after expansion. - bool HighCost; // High Cost when expansion. - - RewritePhi(PHINode *P, unsigned I, Value *V, bool H) - : PN(P), Ith(I), Val(V), HighCost(H) {} + PHINode *PN; // For which PHI node is this replacement? + unsigned Ith; // For which incoming value? + const SCEV *ExpansionSCEV; // The SCEV of the incoming value we are rewriting. + Instruction *ExpansionPoint; // Where we'd like to expand that SCEV? + bool HighCost; // Is this expansion high-cost? + + Value *Expansion = nullptr; // The expansion; null until performed. + bool ValidRewrite = false; // Did isValidRewrite() accept the expansion? + + RewritePhi(PHINode *P, unsigned I, const SCEV *Val, Instruction *ExpansionPt, + bool H) + : PN(P), Ith(I), ExpansionSCEV(Val), ExpansionPoint(ExpansionPt), + HighCost(H) {} }; // Check whether it is possible to delete the loop after rewriting exit @@ -1255,6 +1261,8 @@ // phase later. Skip it in the loop invariant check below. bool found = false; for (const RewritePhi &Phi : RewritePhiSet) { + if (!Phi.ValidRewrite) + continue; unsigned i = Phi.Ith; if (Phi.PN == P && (Phi.PN)->getIncomingValue(i) == Incoming) { found = true; @@ -1372,42 +1380,66 @@ !isa(ExitValue) && hasHardUserWithinLoop(L, Inst)) continue; + // Check if expansions of this SCEV would count as being high cost.
bool HighCost = Rewriter.isHighCostExpansion( ExitValue, L, SCEVCheapExpansionBudget, TTI, Inst); - Value *ExitVal = Rewriter.expandCodeFor(ExitValue, PN->getType(), Inst); - - LLVM_DEBUG(dbgs() << "rewriteLoopExitValues: AfterLoopVal = " - << *ExitVal << '\n' << " LoopVal = " << *Inst - << "\n"); - - if (!isValidRewrite(SE, Inst, ExitVal)) { - DeadInsts.push_back(ExitVal); - continue; - } -#ifndef NDEBUG - // If we reuse an instruction from a loop which is neither L nor one of - // its containing loops, we end up breaking LCSSA form for this loop by - // creating a new use of its instruction. - if (auto *ExitInsn = dyn_cast<Instruction>(ExitVal)) - if (auto *EVL = LI->getLoopFor(ExitInsn->getParent())) - if (EVL != L) - assert(EVL->contains(L) && "LCSSA breach detected!"); -#endif + // Note that we must not perform expansions until after + // we query *all* the costs, because if we perform temporary expansion + // in between, one that we might not intend to keep, said expansion + // *may* affect cost calculation of the next SCEVs we'll query, + // and the next SCEV may erroneously get a smaller cost. // Collect all the candidate PHINodes to be rewritten. - RewritePhiSet.emplace_back(PN, i, ExitVal, HighCost); + RewritePhiSet.emplace_back(PN, i, ExitValue, Inst, HighCost); } } } + // Now that we've done preliminary filtering and billed all the SCEVs, + // we can perform the last sanity check - the expansion must be valid. + for (RewritePhi &Phi : RewritePhiSet) { + Phi.Expansion = Rewriter.expandCodeFor(Phi.ExpansionSCEV, Phi.PN->getType(), + Phi.ExpansionPoint); + + LLVM_DEBUG(dbgs() << "rewriteLoopExitValues: AfterLoopVal = " + << *(Phi.Expansion) << '\n' + << " LoopVal = " << *(Phi.ExpansionPoint) << "\n"); + + // FIXME: isValidRewrite() is a hack. It should be an assert, eventually.
+ Phi.ValidRewrite = isValidRewrite(SE, Phi.ExpansionPoint, Phi.Expansion); + if (!Phi.ValidRewrite) { + DeadInsts.push_back(Phi.Expansion); + continue; + } + +#ifndef NDEBUG + // If we reuse an instruction from a loop which is neither L nor one of + // its containing loops, we end up breaking LCSSA form for this loop by + // creating a new use of its instruction. + if (auto *ExitInsn = dyn_cast<Instruction>(Phi.Expansion)) + if (auto *EVL = LI->getLoopFor(ExitInsn->getParent())) + if (EVL != L) + assert(EVL->contains(L) && "LCSSA breach detected!"); +#endif + } + + // TODO: after isValidRewrite() is an assertion, evaluate whether + // it is beneficial to change how we calculate high-cost: + // if we have SCEV 'A' which we know we will expand, should we calculate + // the cost of other SCEVs after expanding SCEV 'A', + // thus potentially giving a cost bonus to those other SCEVs? + bool LoopCanBeDel = canLoopBeDeleted(L, RewritePhiSet); int NumReplaced = 0; // Transformation. for (const RewritePhi &Phi : RewritePhiSet) { + if (!Phi.ValidRewrite) + continue; + PHINode *PN = Phi.PN; - Value *ExitVal = Phi.Val; + Value *ExitVal = Phi.Expansion; // Only do the rewrite when the ExitValue can be expanded cheaply. // If LoopCanBeDel is true, rewrite exit value aggressively. diff --git a/llvm/test/Transforms/IndVarSimplify/pr45835.ll b/llvm/test/Transforms/IndVarSimplify/pr45835.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/IndVarSimplify/pr45835.ll @@ -0,0 +1,38 @@ +; RUN: opt < %s -indvars -replexitval=always -S | FileCheck %s --check-prefix=ALWAYS +; RUN: opt < %s -indvars -replexitval=never -S | FileCheck %s --check-prefix=NEVER +; RUN: opt < %s -indvars -replexitval=cheap -scev-cheap-expansion-budget=1 -S | FileCheck %s --check-prefix=CHEAP + +; rewriteLoopExitValues() must rewrite all or none of a PHI's values from a given block.
+ +target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" + +@a = common global i8 0, align 1 + +define internal fastcc void @d(i8* %c) unnamed_addr #0 { +entry: + %cmp = icmp ule i8* %c, getelementptr inbounds (i8, i8* @a, i64 65535) + %add.ptr = getelementptr inbounds i8, i8* %c, i64 -65535 + br label %while.cond + +while.cond: + br i1 icmp ne (i8 0, i8 0), label %cont, label %while.end + +cont: + %a.mux = select i1 %cmp, i8* @a, i8* %add.ptr + switch i64 0, label %while.cond [ + i64 -1, label %handler.pointer_overflow.i + i64 0, label %handler.pointer_overflow.i + ] + +handler.pointer_overflow.i: + %a.mux.lcssa4 = phi i8* [ %a.mux, %cont ], [ %a.mux, %cont ] +; ALWAYS: [ %scevgep, %cont ], [ %scevgep, %cont ] +; NEVER: [ %a.mux, %cont ], [ %a.mux, %cont ] +; In cheap mode, either value is acceptable as long as both incoming values are the same. +; CHEAP: [ %[[VAL:.*]], %cont ], [ %[[VAL]], %cont ] + %x5 = ptrtoint i8* %a.mux.lcssa4 to i64 + br label %while.end + +while.end: + ret void +}