diff --git a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h --- a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h +++ b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h @@ -32,14 +32,14 @@ /// all materialized values are safe to speculate anywhere their operands are /// defined, and the expander is capable of expanding the expression. /// CanonicalMode indicates whether the expander will be used in canonical mode. -bool isSafeToExpand(const SCEV *S, ScalarEvolution &SE, - bool CanonicalMode = true); +bool isSafeToExpand(const SCEV *S, ScalarEvolution &SE, bool CanonicalMode); /// Return true if the given expression is safe to expand in the sense that /// all materialized values are defined and safe to speculate at the specified /// location and their operands are defined at this location. +/// CanonicalMode indicates whether the expander will be used in canonical mode. bool isSafeToExpandAt(const SCEV *S, const Instruction *InsertionPoint, - ScalarEvolution &SE); + ScalarEvolution &SE, bool CanonicalMode); /// struct for holding enough information to help calculate the cost of the /// given SCEV when expanded into IR. @@ -270,6 +270,16 @@ SmallVectorImpl &DeadInsts, const TargetTransformInfo *TTI = nullptr); + /// Return true if the given expression is safe to expand in the sense that + /// all materialized values are safe to speculate anywhere their operands are + /// defined, and the expander is capable of expanding the expression. + bool isSafeToExpand(const SCEV *S) const; + + /// Return true if the given expression is safe to expand in the sense that + /// all materialized values are defined and safe to speculate at the specified + /// location and their operands are defined at this location. + bool isSafeToExpandAt(const SCEV *S, const Instruction *InsertionPoint) const; + /// Insert code to directly compute the specified SCEV expression into the /// program. 
The code is inserted into the specified block. Value *expandCodeFor(const SCEV *SH, Type *Ty, Instruction *I) { diff --git a/llvm/lib/CodeGen/HardwareLoops.cpp b/llvm/lib/CodeGen/HardwareLoops.cpp --- a/llvm/lib/CodeGen/HardwareLoops.cpp +++ b/llvm/lib/CodeGen/HardwareLoops.cpp @@ -407,13 +407,13 @@ BasicBlock *Predecessor = BB->getSinglePredecessor(); // If it's not safe to create a while loop then don't force it and create a // do-while loop instead - if (!isSafeToExpandAt(ExitCount, Predecessor->getTerminator(), SE)) + if (!SCEVE.isSafeToExpandAt(ExitCount, Predecessor->getTerminator())) UseLoopGuard = false; else BB = Predecessor; } - if (!isSafeToExpandAt(ExitCount, BB->getTerminator(), SE)) { + if (!SCEVE.isSafeToExpandAt(ExitCount, BB->getTerminator())) { LLVM_DEBUG(dbgs() << "- Bailing, unsafe to expand ExitCount " << *ExitCount << "\n"); return nullptr; diff --git a/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp b/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp --- a/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp +++ b/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp @@ -568,7 +568,7 @@ const SCEVAddRecExpr *BasePtrSCEV = cast(BaseSCEV); // Make sure the base is able to expand. - if (!isSafeToExpand(BasePtrSCEV->getStart(), *SE)) + if (!SCEVE.isSafeToExpand(BasePtrSCEV->getStart())) return MadeChange; assert(BasePtrSCEV->isAffine() && @@ -602,7 +602,7 @@ // Make sure offset is able to expand. Only need to check one time as the // offsets are reused between different chains. 
if (!BaseElemIdx) - if (!isSafeToExpand(OffsetSCEV, *SE)) + if (!SCEVE.isSafeToExpand(OffsetSCEV)) return false; Value *OffsetValue = SCEVE.expandCodeFor( @@ -1018,14 +1018,13 @@ if (!BasePtrSCEV->isAffine()) return MadeChange; - if (!isSafeToExpand(BasePtrSCEV->getStart(), *SE)) - return MadeChange; - - SmallPtrSet DeletedPtrs; - BasicBlock *Header = L->getHeader(); SCEVExpander SCEVE(*SE, Header->getModule()->getDataLayout(), "loopprepare-formrewrite"); + if (!SCEVE.isSafeToExpand(BasePtrSCEV->getStart())) + return MadeChange; + + SmallPtrSet DeletedPtrs; // For some DS form load/store instructions, it can also be an update form, // if the stride is constant and is a multipler of 4. Use update form if diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp --- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -1738,7 +1738,7 @@ // through *explicit* control flow. We have to eliminate the possibility of // implicit exits (see below) before we know it's truly exact. 
const SCEV *ExactBTC = SE->getBackedgeTakenCount(L); - if (isa(ExactBTC) || !isSafeToExpand(ExactBTC, *SE)) + if (isa(ExactBTC) || !Rewriter.isSafeToExpand(ExactBTC)) return false; assert(SE->isLoopInvariant(ExactBTC, L) && "BTC must be loop invariant"); @@ -1769,7 +1769,8 @@ return true; const SCEV *ExitCount = SE->getExitCount(L, ExitingBB); - if (isa(ExitCount) || !isSafeToExpand(ExitCount, *SE)) + if (isa(ExitCount) || + !Rewriter.isSafeToExpand(ExitCount)) return true; assert(SE->isLoopInvariant(ExitCount, L) && diff --git a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp --- a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp +++ b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -1451,7 +1451,7 @@ return false; } - if (!isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt, SE)) { + if (!Expander.isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt)) { LLVM_DEBUG(dbgs() << "irce: could not prove that it is safe to expand the" << " preloop exit limit " << *ExitPreLoopAtSCEV << " at block " << InsertPt->getParent()->getName() @@ -1478,7 +1478,7 @@ return false; } - if (!isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt, SE)) { + if (!Expander.isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt)) { LLVM_DEBUG(dbgs() << "irce: could not prove that it is safe to expand the" << " main loop exit limit " << *ExitMainLoopAtSCEV << " at block " << InsertPt->getParent()->getName() diff --git a/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp b/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp --- a/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp +++ b/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp @@ -388,15 +388,15 @@ if (!isStrideLargeEnough(P.LSCEVAddRec, TargetMinStride)) continue; + BasicBlock *BB = P.InsertPt->getParent(); + SCEVExpander SCEVE(*SE, BB->getModule()->getDataLayout(), "prefaddr"); const SCEV *NextLSCEV = SE->getAddExpr(P.LSCEVAddRec, SE->getMulExpr( 
SE->getConstant(P.LSCEVAddRec->getType(), ItersAhead), P.LSCEVAddRec->getStepRecurrence(*SE))); - if (!isSafeToExpand(NextLSCEV, *SE)) + if (!SCEVE.isSafeToExpand(NextLSCEV)) continue; - BasicBlock *BB = P.InsertPt->getParent(); Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), 0/*PtrAddrSpace*/); - SCEVExpander SCEVE(*SE, BB->getModule()->getDataLayout(), "prefaddr"); Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, P.InsertPt); IRBuilder<> Builder(P.InsertPt); diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp --- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -1129,7 +1129,7 @@ // TODO: ideally we should still be able to generate memset if SCEV expander // is taught to generate the dependencies at the latest point. - if (!isSafeToExpand(Start, *SE)) + if (!Expander.isSafeToExpand(Start)) return Changed; // Okay, we have a strided store "p[i]" of a splattable value. We can turn @@ -1163,7 +1163,7 @@ // TODO: ideally we should still be able to generate memset if SCEV expander // is taught to generate the dependencies at the latest point. - if (!isSafeToExpand(NumBytesS, *SE)) + if (!Expander.isSafeToExpand(NumBytesS)) return Changed; Value *NumBytes = diff --git a/llvm/lib/Transforms/Scalar/LoopPredication.cpp b/llvm/lib/Transforms/Scalar/LoopPredication.cpp --- a/llvm/lib/Transforms/Scalar/LoopPredication.cpp +++ b/llvm/lib/Transforms/Scalar/LoopPredication.cpp @@ -523,7 +523,8 @@ // evaluate outside the loop, which is what we actually need here. 
for (const SCEV *Op : Ops) if (!SE->isLoopInvariant(Op, L) || - !isSafeToExpandAt(Op, Preheader->getTerminator(), *SE)) + !isSafeToExpandAt(Op, Preheader->getTerminator(), *SE, + /* CanonicalMode */ true)) return Use; return Preheader->getTerminator(); } @@ -589,8 +590,8 @@ LLVM_DEBUG(dbgs() << "Can't expand limit check!\n"); return None; } - if (!isSafeToExpandAt(LatchStart, Guard, *SE) || - !isSafeToExpandAt(LatchLimit, Guard, *SE)) { + if (!Expander.isSafeToExpandAt(LatchStart, Guard) || + !Expander.isSafeToExpandAt(LatchLimit, Guard)) { LLVM_DEBUG(dbgs() << "Can't expand limit check!\n"); return None; } @@ -632,8 +633,8 @@ LLVM_DEBUG(dbgs() << "Can't expand limit check!\n"); return None; } - if (!isSafeToExpandAt(LatchStart, Guard, *SE) || - !isSafeToExpandAt(LatchLimit, Guard, *SE)) { + if (!Expander.isSafeToExpandAt(LatchStart, Guard) || + !Expander.isSafeToExpandAt(LatchLimit, Guard)) { LLVM_DEBUG(dbgs() << "Can't expand limit check!\n"); return None; } @@ -1159,7 +1160,7 @@ const SCEV *MinEC = getMinAnalyzeableBackedgeTakenCount(*SE, *DT, L); if (isa(MinEC) || MinEC->getType()->isPointerTy() || !SE->isLoopInvariant(MinEC, L) || - !isSafeToExpandAt(MinEC, WidenableBR, *SE)) + !Rewriter.isSafeToExpandAt(MinEC, WidenableBR)) return ChangedLoop; // Subtlety: We need to avoid inserting additional uses of the WC. 
We know @@ -1198,7 +1199,7 @@ const SCEV *ExitCount = SE->getExitCount(L, ExitingBB); if (isa(ExitCount) || ExitCount->getType()->isPointerTy() || - !isSafeToExpandAt(ExitCount, WidenableBR, *SE)) + !Rewriter.isSafeToExpandAt(ExitCount, WidenableBR)) continue; const bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB)); diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp --- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -3335,7 +3335,8 @@ // x == y --> x - y == 0 const SCEV *N = SE.getSCEV(NV); - if (SE.isLoopInvariant(N, L) && isSafeToExpand(N, SE) && + if (SE.isLoopInvariant(N, L) && + isSafeToExpand(N, SE, /* CanonicalMode */ false) && (!NV->getType()->isPointerTy() || SE.getPointerBase(N) == SE.getPointerBase(S))) { // S is normalized, so normalize N before folding it into S diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -1357,7 +1357,7 @@ const SCEV *ExitValue = SE->getSCEVAtScope(Inst, L->getParentLoop()); if (isa(ExitValue) || !SE->isLoopInvariant(ExitValue, L) || - !isSafeToExpand(ExitValue, *SE)) { + !Rewriter.isSafeToExpand(ExitValue)) { // TODO: This should probably be sunk into SCEV in some way; maybe a // getSCEVForExit(SCEV*, L, ExitingBB)? It can be generalized for // most SCEV expressions and other recurrence types (e.g. 
shift @@ -1370,7 +1370,7 @@ ExitValue = AddRec->evaluateAtIteration(ExitCount, *SE); if (isa(ExitValue) || !SE->isLoopInvariant(ExitValue, L) || - !isSafeToExpand(ExitValue, *SE)) + !Rewriter.isSafeToExpand(ExitValue)) continue; } diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp --- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp +++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp @@ -2557,6 +2557,15 @@ return User->getOperand(OpIdx); } +bool SCEVExpander::isSafeToExpand(const SCEV *S) const { + return llvm::isSafeToExpand(S, SE, CanonicalMode); +} + +bool SCEVExpander::isSafeToExpandAt(const SCEV *S, + const Instruction *InsertionPoint) const { + return llvm::isSafeToExpandAt(S, InsertionPoint, SE, CanonicalMode); +} + namespace { // Search for a SCEV subexpression that is not safe to expand. Any expression // that may expand to a !isSafeToSpeculativelyExecute value is unsafe, namely @@ -2623,8 +2632,8 @@ } bool isSafeToExpandAt(const SCEV *S, const Instruction *InsertionPoint, - ScalarEvolution &SE) { - if (!isSafeToExpand(S, SE)) + ScalarEvolution &SE, bool CanonicalMode) { + if (!isSafeToExpand(S, SE, CanonicalMode)) return false; // We have to prove that the expanded site of S dominates InsertionPoint. 
// This is easy when not in the same block, but hard when S is an instruction diff --git a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp --- a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp @@ -660,7 +660,7 @@ auto *IP = GetLoopInvariantInsertPosition(L, I); - if (!isSafeToExpandAt(S, IP, *SE)) { + if (!Rewriter.isSafeToExpandAt(S, IP)) { LLVM_DEBUG(dbgs() << "INDVARS: Can not replace IV user: " << *I << " with non-speculable loop invariant: " << *S << '\n'); return false; diff --git a/llvm/test/Transforms/IndVarSimplify/pr50506.ll b/llvm/test/Transforms/IndVarSimplify/pr50506.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/IndVarSimplify/pr50506.ll @@ -0,0 +1,43 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -S -indvars < %s | FileCheck %s + +; This test used to assert when expanding an addrec into a loop without +; preheader. + +define void @test(ptr %tgt) { +; CHECK-LABEL: @test( +; CHECK-NEXT: bb31: +; CHECK-NEXT: indirectbr ptr [[TGT:%.*]], [label [[EXIT:%.*]], label %bb33] +; CHECK: exit: +; CHECK-NEXT: ret void +; CHECK: bb33: +; CHECK-NEXT: [[TMP34:%.*]] = phi i32 [ [[TMP50:%.*]], [[BB49:%.*]] ], [ 0, [[BB31:%.*]] ] +; CHECK-NEXT: br i1 false, label [[BB40_PREHEADER:%.*]], label [[BB49]] +; CHECK: bb40.preheader: +; CHECK-NEXT: br label [[BB40:%.*]] +; CHECK: bb40: +; CHECK-NEXT: br label [[BB40]] +; CHECK: bb49: +; CHECK-NEXT: [[TMP50]] = add i32 [[TMP34]], 1 +; CHECK-NEXT: br label [[BB33:%.*]] +; +bb31: + indirectbr ptr %tgt, [label %exit, label %bb33] + +exit: + ret void + +bb33: ; preds = %bb49, %bb31 + %tmp34 = phi i32 [ %tmp50, %bb49 ], [ 0, %bb31 ] + %tmp36 = add i32 %tmp34, 1 + br i1 false, label %bb40, label %bb49 + +bb40: ; preds = %bb40, %bb33 + %tmp41 = phi i32 [ %tmp36, %bb33 ], [ %tmp39, %bb40 ] + %tmp39 = add i32 %tmp41, 0 + br label %bb40 + +bb49: ; preds = %bb33 + %tmp50 = add i32
%tmp34, 1 + br label %bb33 +} diff --git a/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp b/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp --- a/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp +++ b/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp @@ -259,17 +259,17 @@ Instruction *Ret = Builder.CreateRetVoid(); ScalarEvolution SE = buildSE(*F); + SCEVExpander Exp(SE, M.getDataLayout(), "expander"); const SCEV *S = SE.getSCEV(Phi); EXPECT_TRUE(isa(S)); const SCEVAddRecExpr *AR = cast(S); EXPECT_TRUE(AR->isAffine()); - EXPECT_FALSE(isSafeToExpandAt(AR, Top->getTerminator(), SE)); - EXPECT_FALSE(isSafeToExpandAt(AR, LPh->getTerminator(), SE)); - EXPECT_TRUE(isSafeToExpandAt(AR, L->getTerminator(), SE)); - EXPECT_TRUE(isSafeToExpandAt(AR, Post->getTerminator(), SE)); + EXPECT_FALSE(Exp.isSafeToExpandAt(AR, Top->getTerminator())); + EXPECT_FALSE(Exp.isSafeToExpandAt(AR, LPh->getTerminator())); + EXPECT_TRUE(Exp.isSafeToExpandAt(AR, L->getTerminator())); + EXPECT_TRUE(Exp.isSafeToExpandAt(AR, Post->getTerminator())); EXPECT_TRUE(LI->getLoopFor(L)->isLCSSAForm(*DT)); - SCEVExpander Exp(SE, M.getDataLayout(), "expander"); Exp.expandCodeFor(SE.getSCEV(Add), nullptr, Ret); EXPECT_TRUE(LI->getLoopFor(L)->isLCSSAForm(*DT)); }