diff --git a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
--- a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
+++ b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
@@ -180,6 +180,23 @@
     ChainedPhis.clear();
   }
 
+  /// Return a vector containing all instructions inserted during expansion.
+  SmallVector<Instruction *, 32> getAllInsertedInstructions() const {
+    SmallVector<Instruction *, 32> Result;
+    for (auto &VH : InsertedValues) {
+      Value *V = VH;
+      if (auto *Inst = dyn_cast<Instruction>(V))
+        Result.push_back(Inst);
+    }
+    for (auto &VH : InsertedPostIncValues) {
+      Value *V = VH;
+      if (auto *Inst = dyn_cast<Instruction>(V))
+        Result.push_back(Inst);
+    }
+
+    return Result;
+  }
+
   /// Return true for expressions that can't be evaluated at runtime
   /// within given \b Budget.
   ///
@@ -452,6 +469,29 @@
   /// If no PHIs have been created, return the unchanged operand \p OpIdx.
   Value *fixupLCSSAFormFor(Instruction *User, unsigned OpIdx);
 };
+
+/// Helper to remove instructions inserted during SCEV expansion, unless they
+/// are marked as used.
+class SCEVExpanderCleaner {
+  SCEVExpander &Expander;
+
+  /// Indicates whether the result of the expansion is used. If false, the
+  /// instructions added during expansion are removed.
+  bool ResultUsed;
+
+public:
+  /// The DominatorTree argument is not needed to remove the inserted
+  /// instructions and is ignored; it is accepted for source compatibility.
+  SCEVExpanderCleaner(SCEVExpander &Expander, DominatorTree & /*DT*/)
+      : Expander(Expander), ResultUsed(false) {}
+
+  /// Remove all instructions inserted by the expander, unless
+  /// markResultUsed() has been called.
+  ~SCEVExpanderCleaner();
+
+  /// Indicate that the result of the expansion is used.
+  void markResultUsed() { ResultUsed = true; }
+};
 } // namespace llvm
 
 #endif
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -295,31 +295,6 @@
   I->eraseFromParent();
 }
 
-namespace {
-class ExpandedValuesCleaner {
-  SCEVExpander &Expander;
-  TargetLibraryInfo *TLI;
-  SmallVector<Value *, 4> ExpandedValues;
-  bool Commit = false;
-
-public:
-  ExpandedValuesCleaner(SCEVExpander &Expander, TargetLibraryInfo *TLI)
-      : Expander(Expander), TLI(TLI) {}
-
-  void add(Value *V) { ExpandedValues.push_back(V); }
-
-  void commit() { Commit = true; }
-
-  ~ExpandedValuesCleaner() {
-    if (!Commit) {
-      Expander.clear();
-      for (auto *V : ExpandedValues)
-        RecursivelyDeleteTriviallyDeadInstructions(V, TLI);
-    }
-  }
-};
-} // namespace
-
 //===----------------------------------------------------------------------===//
 //
 // Implementation of LoopIdiomRecognize
@@ -933,7 +908,7 @@
   BasicBlock *Preheader = CurLoop->getLoopPreheader();
   IRBuilder<> Builder(Preheader->getTerminator());
   SCEVExpander Expander(*SE, *DL, "loop-idiom");
-  ExpandedValuesCleaner EVC(Expander, TLI);
+  SCEVExpanderCleaner ExpCleaner(Expander, *DT);
 
   Type *DestInt8PtrTy = Builder.getInt8PtrTy(DestAS);
   Type *IntIdxTy = DL->getIndexType(DestPtr->getType());
@@ -956,7 +931,6 @@
   // base pointer and checking the region.
   Value *BasePtr =
       Expander.expandCodeFor(Start, DestInt8PtrTy, Preheader->getTerminator());
-  EVC.add(BasePtr);
 
   // From here on out, conservatively report to the pass manager that we've
   // changed the IR, even if we later clean up these added instructions. There
@@ -1041,7 +1015,7 @@
   if (MSSAU && VerifyMemorySSA)
     MSSAU->getMemorySSA()->verifyMemorySSA();
   ++NumMemSet;
-  EVC.commit();
+  ExpCleaner.markResultUsed();
   return true;
 }
 
@@ -1075,7 +1049,7 @@
 
   IRBuilder<> Builder(Preheader->getTerminator());
   SCEVExpander Expander(*SE, *DL, "loop-idiom");
-  ExpandedValuesCleaner EVC(Expander, TLI);
+  SCEVExpanderCleaner ExpCleaner(Expander, *DT);
 
   bool Changed = false;
   const SCEV *StrStart = StoreEv->getStart();
@@ -1094,7 +1068,6 @@
   // checking everything.
   Value *StoreBasePtr = Expander.expandCodeFor(
       StrStart, Builder.getInt8PtrTy(StrAS), Preheader->getTerminator());
-  EVC.add(StoreBasePtr);
 
   // From here on out, conservatively report to the pass manager that we've
   // changed the IR, even if we later clean up these added instructions. There
@@ -1122,7 +1095,6 @@
   // mutated by the loop.
   Value *LoadBasePtr = Expander.expandCodeFor(
       LdStart, Builder.getInt8PtrTy(LdAS), Preheader->getTerminator());
-  EVC.add(LoadBasePtr);
 
   if (mayLoopAccessLocation(LoadBasePtr, ModRefInfo::Mod, CurLoop, BECount,
                             StoreSize, *AA, Stores))
@@ -1138,7 +1110,6 @@
 
   Value *NumBytes =
       Expander.expandCodeFor(NumBytesS, IntIdxTy, Preheader->getTerminator());
-  EVC.add(NumBytes);
 
   CallInst *NewCall = nullptr;
   // Check whether to generate an unordered atomic memcpy:
@@ -1198,7 +1169,7 @@
   if (MSSAU && VerifyMemorySSA)
     MSSAU->getMemorySSA()->verifyMemorySSA();
   ++NumMemCpy;
-  EVC.commit();
+  ExpCleaner.markResultUsed();
   return true;
 }
 
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -2630,4 +2630,39 @@
   }
   return false;
 }
+
+SCEVExpanderCleaner::~SCEVExpanderCleaner() {
+  // Result is used, nothing to remove.
+  if (ResultUsed)
+    return;
+
+  auto InsertedInstructions = Expander.getAllInsertedInstructions();
+#ifndef NDEBUG
+  SmallPtrSet<Instruction *, 8> InsertedSet(InsertedInstructions.begin(),
+                                            InsertedInstructions.end());
+  (void)InsertedSet;
+#endif
+  // Remove sets with value handles.
+  Expander.clear();
+
+  // Remove all inserted instructions. Each one is RAUW'd with undef before it
+  // is erased, so no dangling use can remain and the visiting order does not
+  // matter. Do not sort by dominance here: instructions in sibling blocks are
+  // incomparable, so dominance is not a strict weak ordering and cannot be
+  // used as a sort comparator.
+  for (Instruction *I : InsertedInstructions) {
+#ifndef NDEBUG
+    assert(all_of(I->users(),
+                  [&InsertedSet](Value *U) {
+                    return InsertedSet.contains(cast<Instruction>(U));
+                  }) &&
+           "removed instruction should only be used by instructions inserted "
+           "during expansion");
+#endif
+    assert(!I->getType()->isVoidTy() &&
+           "inserted instruction should have non-void types");
+    I->replaceAllUsesWith(UndefValue::get(I->getType()));
+    I->eraseFromParent();
+  }
+}
 }