DSE: PHI-translate the killing store's address while walking MemorySSA.

While walking upwards from a killing store, every visited MemorySSA access
is paired with the killing store's pointer (plus a constant byte offset)
PHI-translated into that access's block. The translated location is used
for the read-clobber, partial-overwrite and memory-terminator checks; when
no translated pointer is available the untranslated location is used. The
pass now requires AssumptionCache, which PHITransAddr takes as a parameter.

diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -38,6 +38,7 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/CaptureTracking.h"
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/LoopInfo.h"
@@ -848,6 +849,107 @@
   return false;
 }
 
+/// A pointer together with an extra constant byte offset applied to it. A
+/// null (or no longer alive) pointer means "no translated address".
+using PhiTransPtr = std::pair<WeakTrackingVH, int64_t>;
+/// A MemorySSA access paired with the killing store's address as translated
+/// into the block of that access.
+using ExMemoryAccess = std::pair<MemoryAccess *, PhiTransPtr>;
+
+/// PHI-translate \p Ptr, valid in \p FromBB, into the address valid in \p ToBB
+/// by walking the predecessors of \p FromBB backwards until \p ToBB is reached.
+/// Returns \p Ptr unchanged if both blocks are the same, and an empty
+/// PhiTransPtr if translation fails or two paths disagree on the address.
+static PhiTransPtr phiTranslatePtr(const PhiTransPtr &Ptr, BasicBlock *FromBB,
+                                   BasicBlock *ToBB, const DataLayout &DL,
+                                   DominatorTree &DT, AssumptionCache *AC) {
+  if (FromBB == ToBB)
+    return Ptr;
+
+  PhiTransPtr ResPtr;
+  DenseMap<BasicBlock *, PhiTransPtr> Visited;
+  SmallVector<std::pair<BasicBlock *, PhiTransPtr>, 16> WorkList;
+  WorkList.push_back({FromBB, Ptr});
+
+  while (!WorkList.empty()) {
+    // Pop by value. A reference to the popped slot would dangle, and the
+    // push_back below reuses that slot, silently changing CurrPtr.
+    const auto CurrNode = WorkList.pop_back_val();
+    BasicBlock *CurrBB = CurrNode.first;
+    const PhiTransPtr &CurrPtr = CurrNode.second;
+
+    // The base/offset split of the current address is the same for every
+    // predecessor, so compute it once per block.
+    int64_t Offset = 0;
+    Value *BasePtr =
+        GetPointerBaseWithConstantOffset(CurrPtr.first, Offset, DL);
+    Offset += CurrPtr.second;
+
+    for (BasicBlock *PredBB : predecessors(CurrBB)) {
+      // Only walk the region dominated by ToBB; predecessors outside it are
+      // skipped.
+      if (!DT.dominates(ToBB, PredBB))
+        continue;
+
+      // PHITranslateValue returns true on failure.
+      PHITransAddr TransAddr{BasePtr, DL, AC};
+      if (TransAddr.NeedsPHITranslationFromBlock(CurrBB) &&
+          (!TransAddr.IsPotentiallyPHITranslatable() ||
+           TransAddr.PHITranslateValue(CurrBB, PredBB, &DT,
+                                       /*MustDominate=*/false)))
+        return PhiTransPtr{};
+
+      auto Inserted = Visited.try_emplace(PredBB, TransAddr.getAddr(), Offset);
+      const PhiTransPtr &TransPtr = Inserted.first->second;
+      if (!Inserted.second) {
+        // We already visited this block before. If it was with a different
+        // address - bail out!
+        if (TransAddr.getAddr() != TransPtr.first || Offset != TransPtr.second)
+          return PhiTransPtr{};
+        continue;
+      }
+
+      if (PredBB == ToBB) {
+        ResPtr = TransPtr;
+        continue;
+      }
+
+      WorkList.push_back({PredBB, TransPtr});
+    }
+  }
+
+  assert(ResPtr.first.pointsToAliveValue() &&
+         "PHI translation is expected to complete successfully");
+  return ResPtr;
+}
+
+/// Translate the address carried by \p FromAccess into the block of
+/// \p ToAccess. If \p FromAccess carries no address, neither does the
+/// result.
+static ExMemoryAccess
+phiTransFromMemoryAccessTo(const ExMemoryAccess &FromAccess,
+                           MemoryAccess *ToAccess, const DataLayout &DL,
+                           DominatorTree &DT, AssumptionCache *AC) {
+  PhiTransPtr ResAddr;
+  if (FromAccess.second.first.pointsToAliveValue())
+    ResAddr = phiTranslatePtr(FromAccess.second, FromAccess.first->getBlock(),
+                              ToAccess->getBlock(), DL, DT, AC);
+  return std::make_pair(ToAccess, ResAddr);
+}
+
+/// Return the defining access of \p CurrAccess, paired with the address
+/// translated into that access's block.
+static ExMemoryAccess getExDefiningAccess(const ExMemoryAccess &CurrAccess,
+                                          const DataLayout &DL,
+                                          DominatorTree &DT,
+                                          AssumptionCache *AC) {
+  assert(isa<MemoryUseOrDef>(CurrAccess.first) &&
+         "expected a MemoryUse or MemoryDef");
+  MemoryAccess *Next =
+      cast<MemoryUseOrDef>(CurrAccess.first)->getDefiningAccess();
+  return phiTransFromMemoryAccessTo(CurrAccess, Next, DL, DT, AC);
+}
+
 struct DSEState {
   Function &F;
   AliasAnalysis &AA;
@@ -867,6 +969,7 @@
   const TargetLibraryInfo &TLI;
   const DataLayout &DL;
   const LoopInfo &LI;
+  AssumptionCache *AC;
 
   // Whether the function contains any irreducible control flow, useful for
   // being accurately able to detect loops.
@@ -895,14 +986,16 @@
 
   DSEState(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, DominatorTree &DT,
            PostDominatorTree &PDT, const TargetLibraryInfo &TLI,
-           const LoopInfo &LI)
+           const LoopInfo &LI, AssumptionCache *AC)
       : F(F), AA(AA), BatchAA(AA), MSSA(MSSA), DT(DT), PDT(PDT), TLI(TLI),
-        DL(F.getParent()->getDataLayout()), LI(LI) {}
+        DL(F.getParent()->getDataLayout()), LI(LI), AC(AC),
+        ContainsIrreducibleLoops(false) {}
 
   static DSEState get(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
                       DominatorTree &DT, PostDominatorTree &PDT,
-                      const TargetLibraryInfo &TLI, const LoopInfo &LI) {
-    DSEState State(F, AA, MSSA, DT, PDT, TLI, LI);
+                      const TargetLibraryInfo &TLI, const LoopInfo &LI,
+                      AssumptionCache *AC) {
+    DSEState State(F, AA, MSSA, DT, PDT, TLI, LI, AC);
     // Collect blocks with throwing instructions not modeled in MemorySSA and
     // alloc-like objects.
     unsigned PO = 0;
@@ -1372,8 +1465,10 @@
   // such MemoryDef, return None. The returned value may not (completely)
   // overwrite \p DefLoc. Currently we bail out when we encounter an aliasing
   // MemoryUse (read).
-  Optional<MemoryAccess *>
-  getDomMemoryDef(MemoryDef *KillingDef, MemoryAccess *StartAccess,
+  // \p ExStartAccess and the returned access are paired with DefLoc's pointer
+  // translated into their block (empty if no translation is available).
+  Optional<ExMemoryAccess>
+  getDomMemoryDef(MemoryDef *KillingDef, const ExMemoryAccess &ExStartAccess,
                   const MemoryLocation &DefLoc, const Value *DefUO,
                   unsigned &ScanLimit, unsigned &WalkerStepLimit,
                   bool IsMemTerm, unsigned &PartialLimit) {
@@ -1382,13 +1477,24 @@
       return None;
     }
 
-    MemoryAccess *Current = StartAccess;
+    MemoryAccess *const StartAccess = ExStartAccess.first;
     Instruction *KillingI = KillingDef->getMemoryInst();
     LLVM_DEBUG(dbgs() << " trying to get dominating access\n");
+    ExMemoryAccess CurrentAccess = ExStartAccess;
 
     // Find the next clobbering Mod access for DefLoc, starting at StartAccess.
     Optional<MemoryLocation> CurrentLoc;
-    for (;; Current = cast<MemoryDef>(Current)->getDefiningAccess()) {
+    for (;; CurrentAccess = getExDefiningAccess(CurrentAccess, DL, DT, AC)) {
+      MemoryAccess *Current = CurrentAccess.first;
+      // DefLoc with its pointer translated into Current's block; falls back
+      // to the untranslated pointer (offset 0) if none is available.
+      int64_t DefOffset = 0;
+      const Value *ResDefPtr = DefLoc.Ptr;
+      if (CurrentAccess.second.first.pointsToAliveValue()) {
+        ResDefPtr = CurrentAccess.second.first;
+        DefOffset = CurrentAccess.second.second;
+      }
+      MemoryLocation ResDefLoc(ResDefPtr, DefLoc.Size, DefLoc.AATags);
       LLVM_DEBUG({
         dbgs() << " visiting " << *Current;
         if (!MSSA.isLiveOnEntryDef(Current) && isa<MemoryUseOrDef>(Current))
@@ -1418,7 +1524,7 @@
       // caller is responsible for traversing them.
       if (isa<MemoryPhi>(Current)) {
         LLVM_DEBUG(dbgs() << " ... found MemoryPhi\n");
-        return Current;
+        return CurrentAccess;
       }
 
       // Below, check if CurrentDef is a valid candidate to be eliminated by
@@ -1533,7 +1639,19 @@
     // they cover all paths from EarlierAccess to any function exit.
     SmallPtrSet<Instruction *, 16> KillingDefs;
     KillingDefs.insert(KillingDef->getMemoryInst());
-    MemoryAccess *EarlierAccess = Current;
+    const ExMemoryAccess &ExEarlierAccess = CurrentAccess;
+    MemoryAccess *EarlierAccess = CurrentAccess.first;
+    // DefLoc with its pointer translated into EarlierAccess's block, used for
+    // read-clobber checks of uses in that block. Falls back to the
+    // untranslated pointer (offset 0) if no translated pointer is available.
+    int64_t EarlierDefOffset = 0;
+    const Value *EarlierDefPtr = DefLoc.Ptr;
+    if (ExEarlierAccess.second.first.pointsToAliveValue()) {
+      EarlierDefPtr = ExEarlierAccess.second.first;
+      EarlierDefOffset = ExEarlierAccess.second.second;
+    }
+    const MemoryLocation EarlierDefLoc(EarlierDefPtr, DefLoc.Size,
+                                       DefLoc.AATags);
     Instruction *EarlierMemInst =
         cast<MemoryDef>(EarlierAccess)->getMemoryInst();
     LLVM_DEBUG(dbgs() << " Checking for reads of " << *EarlierAccess << " ("
@@ -1610,9 +1728,16 @@
 
       // Uses which may read the original MemoryDef mean we cannot eliminate the
       // original MD. Stop walk.
-      if (isReadClobber(*CurrentLoc, UseInst)) {
+      // Uses in EarlierMemInst's block are checked against EarlierDefLoc;
+      // other uses are checked against *CurrentLoc.
+      const bool UseIsReadClobber =
+          UseInst->getParent() == EarlierMemInst->getParent()
+              ? isReadClobber(KillingI, EarlierDefLoc, EarlierDefOffset,
+                              UseInst)
+              : isReadClobber(EarlierMemInst, *CurrentLoc, 0, UseInst);
+      if (UseIsReadClobber) {
         LLVM_DEBUG(dbgs() << " ... found read clobber\n");
         return None;
       }
 
       // If this worklist walks back to the original memory access (and the
@@ -1685,7 +1810,7 @@
       // post-dominates EarlierAccess.
       if (KillingBlocks.count(CommonPred)) {
         if (PDT.dominates(CommonPred, EarlierAccess->getBlock()))
-          return {EarlierAccess};
+          return {ExEarlierAccess};
         return None;
       }
 
@@ -1726,14 +1851,14 @@
             return None;
         }
         NumCFGSuccess++;
-        return {EarlierAccess};
+        return {ExEarlierAccess};
       }
       return None;
     }
 
     // No aliasing MemoryUses of EarlierAccess found, EarlierAccess is
     // potentially dead.
-    return {EarlierAccess};
+    return {ExEarlierAccess};
   }
 
   // Delete dead memory defs
@@ -1935,10 +2060,11 @@
 static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
                                 DominatorTree &DT, PostDominatorTree &PDT,
                                 const TargetLibraryInfo &TLI,
-                                const LoopInfo &LI) {
+                                const LoopInfo &LI, AssumptionCache *AC) {
   bool MadeChange = false;
 
-  DSEState State = DSEState::get(F, AA, MSSA, DT, PDT, TLI, LI);
+  const DataLayout &DL = F.getParent()->getDataLayout();
+  DSEState State = DSEState::get(F, AA, MSSA, DT, PDT, TLI, LI, AC);
   // For each store:
   for (unsigned I = 0; I < State.MemDefs.size(); I++) {
     MemoryDef *KillingDef = State.MemDefs[I];
@@ -1958,41 +2084,49 @@
                         << *SI << "\n");
       continue;
     }
-    MemoryLocation SILoc = *MaybeSILoc;
+    const MemoryLocation SILoc = *MaybeSILoc;
     assert(SILoc.Ptr && "SILoc should not be null");
     const Value *SILocUnd = getUnderlyingObject(SILoc.Ptr);
 
-    MemoryAccess *Current = KillingDef;
     LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by "
-                      << *Current << " (" << *SI << ")\n");
+                      << *KillingDef << " (" << *SI << ")\n");
 
     unsigned ScanLimit = MemorySSAScanLimit;
     unsigned WalkerStepLimit = MemorySSAUpwardsStepLimit;
     unsigned PartialLimit = MemorySSAPartialStoreLimit;
     // Worklist of MemoryAccesses that may be killed by KillingDef.
-    SetVector<MemoryAccess *> ToCheck;
+    // Each entry is paired with SI's pointer translated into its block.
+    // NOTE(review): unlike the previous SetVector, entries are not
+    // de-duplicated, so an access reachable along several paths may be
+    // checked more than once.
+    SmallVector<ExMemoryAccess, 8> ToCheck;
 
-    if (SILocUnd)
-      ToCheck.insert(KillingDef->getDefiningAccess());
+    if (SILocUnd) {
+      const ExMemoryAccess ExKillingDef{KillingDef,
+                                        {const_cast<Value *>(SILoc.Ptr), 0}};
+      ToCheck.push_back(getExDefiningAccess(ExKillingDef, DL, DT, AC));
+    }
 
     bool Shortend = false;
     bool IsMemTerm = State.isMemTerminatorInst(SI);
     // Check if MemoryAccesses in the worklist are killed by KillingDef.
-    for (unsigned I = 0; I < ToCheck.size(); I++) {
-      Current = ToCheck[I];
+    while (!ToCheck.empty()) {
+      const ExMemoryAccess CurrentAccess = ToCheck.pop_back_val();
+      MemoryAccess *Current = CurrentAccess.first;
       if (State.SkipStores.count(Current))
         continue;
 
-      Optional<MemoryAccess *> Next = State.getDomMemoryDef(
-          KillingDef, Current, SILoc, SILocUnd, ScanLimit, WalkerStepLimit,
-          IsMemTerm, PartialLimit);
+      Optional<ExMemoryAccess> Next = State.getDomMemoryDef(
+          KillingDef, CurrentAccess, SILoc, SILocUnd, ScanLimit,
+          WalkerStepLimit, IsMemTerm, PartialLimit);
 
       if (!Next) {
         LLVM_DEBUG(dbgs() << " finished walk\n");
         continue;
       }
 
-      MemoryAccess *EarlierAccess = *Next;
+      const ExMemoryAccess &ExEarlierAccess = *Next;
+      MemoryAccess *EarlierAccess = ExEarlierAccess.first;
       LLVM_DEBUG(dbgs() << " Checking if we can kill " << *EarlierAccess);
       if (isa<MemoryPhi>(EarlierAccess)) {
         LLVM_DEBUG(dbgs() << "\n ... adding incoming values to worklist\n");
@@ -2006,14 +2140,15 @@
           // strictly dominate our starting def.
           if (State.PostOrderNumbers[IncomingBlock] >
               State.PostOrderNumbers[PhiBlock])
-            ToCheck.insert(IncomingAccess);
+            ToCheck.push_back(phiTransFromMemoryAccessTo(
+                ExEarlierAccess, IncomingAccess, DL, DT, AC));
         }
         continue;
       }
       auto *NextDef = cast<MemoryDef>(EarlierAccess);
       Instruction *NI = NextDef->getMemoryInst();
       LLVM_DEBUG(dbgs() << " (" << *NI << ")\n");
-      ToCheck.insert(NextDef->getDefiningAccess());
+      ToCheck.push_back(getExDefiningAccess(ExEarlierAccess, DL, DT, AC));
       NumGetDomMemoryDefPassed++;
 
       if (!DebugCounter::shouldExecute(MemorySSACounter))
@@ -2021,9 +2156,22 @@
 
       MemoryLocation NILoc = *State.getLocForWriteEx(NI);
 
+      // SI's location with its pointer PHI-translated into NI's block, plus
+      // the constant offset accumulated by the translation. Falls back to
+      // SILoc (offset 0) when no translated pointer is available.
+      int64_t ResOffset = 0;
+      int64_t CurOffset = 0;
+      const Value *ResSIPtr = SILoc.Ptr;
+      if (ExEarlierAccess.second.first.pointsToAliveValue()) {
+        ResSIPtr = ExEarlierAccess.second.first;
+        ResOffset = ExEarlierAccess.second.second;
+      }
+      const MemoryLocation ResSILoc(ResSIPtr, SILoc.Size, SILoc.AATags);
+      const Value *ResSILocUnd = getUnderlyingObject(ResSIPtr);
+
       if (IsMemTerm) {
         const Value *NIUnd = getUnderlyingObject(NILoc.Ptr);
-        if (SILocUnd != NIUnd)
+        if (ResSILocUnd != NIUnd)
           continue;
         LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: " << *NI
                           << "\n KILLER: " << *SI << '\n');
@@ -2032,17 +2180,17 @@
         MadeChange = true;
       } else {
         // Check if NI overwrites SI.
-        int64_t InstWriteOffset = 0;
-        int64_t DepWriteOffset = 0;
-        OverwriteResult OR = State.isOverwrite(SI, NI, SILoc, NILoc,
-                                               DepWriteOffset, InstWriteOffset);
+        // NOTE(review): isOverwrite gets the untranslated SILoc while
+        // isPartialOverwrite below gets ResSILoc; confirm this is intended.
+        OverwriteResult OR =
+            State.isOverwrite(SI, NI, SILoc, NILoc, CurOffset, ResOffset);
         if (OR == OW_MaybePartial) {
           auto Iter = State.IOLs.insert(
               std::make_pair<BasicBlock *, InstOverlapIntervalsTy>(
                   NI->getParent(), InstOverlapIntervalsTy()));
           auto &IOL = Iter.first->second;
-          OR = isPartialOverwrite(SILoc, NILoc, DepWriteOffset, InstWriteOffset,
-                                  NI, IOL);
+          OR = isPartialOverwrite(ResSILoc, NILoc, CurOffset, ResOffset, NI,
+                                  IOL);
         }
 
         if (EnablePartialStoreMerging && OR == OW_PartialEarlierWithFullLater) {
@@ -2053,7 +2201,7 @@
           // TODO: implement tryToMergeParialOverlappingStores using MemorySSA.
           if (Earlier && Later && DT.dominates(Earlier, Later)) {
             if (Constant *Merged = tryToMergePartialOverlappingStores(
-                    Earlier, Later, InstWriteOffset, DepWriteOffset, State.DL,
+                    Earlier, Later, ResOffset, CurOffset, State.DL,
                     State.BatchAA, &DT)) {
 
               // Update stored value of earlier store to merged constant.
@@ -2113,8 +2261,9 @@
   MemorySSA &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA();
   PostDominatorTree &PDT = AM.getResult<PostDominatorTreeAnalysis>(F);
   LoopInfo &LI = AM.getResult<LoopAnalysis>(F);
+  AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
 
-  bool Changed = eliminateDeadStores(F, AA, MSSA, DT, PDT, TLI, LI);
+  bool Changed = eliminateDeadStores(F, AA, MSSA, DT, PDT, TLI, LI, &AC);
 
 #ifdef LLVM_ENABLE_STATS
   if (AreStatisticsEnabled())
@@ -2155,8 +2304,10 @@
     PostDominatorTree &PDT =
         getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
     LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+    AssumptionCache &AC =
+        getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
 
-    bool Changed = eliminateDeadStores(F, AA, MSSA, DT, PDT, TLI, LI);
+    bool Changed = eliminateDeadStores(F, AA, MSSA, DT, PDT, TLI, LI, &AC);
 
 #ifdef LLVM_ENABLE_STATS
     if (AreStatisticsEnabled())
@@ -2180,6 +2331,8 @@
     AU.addPreserved<MemorySSAWrapperPass>();
     AU.addRequired<LoopInfoWrapperPass>();
     AU.addPreserved<LoopInfoWrapperPass>();
+    AU.addRequired<AssumptionCacheTracker>();
+    AU.addPreserved<AssumptionCacheTracker>();
   }
 };
 