diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp --- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp +++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp @@ -38,6 +38,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopInfo.h" @@ -850,6 +851,96 @@ return false; } +using PhiTransPtr = std::pair; +using ExMemoryAccess = std::pair; +/// Phi-translate Ptr from FromBB up to ToBB. Returns an empty pair on failure. +static PhiTransPtr phiTranslatePtr(const PhiTransPtr &Ptr, BasicBlock *FromBB, + BasicBlock *ToBB, const DataLayout &DL, + DominatorTree &DT, AssumptionCache *AC) { + PhiTransPtr ResPtr; + DenseMap Visited; + SmallVector, 16> WorkList; + + if (FromBB == ToBB) { + return Ptr; + } + + WorkList.push_back({FromBB, Ptr}); + + while (!WorkList.empty()) { + auto CurrNode = WorkList.pop_back_val(); + BasicBlock *CurrBB = CurrNode.first; + const PhiTransPtr &CurrPtr = CurrNode.second; + int64_t Offset = 0; + Value *const BasePtr = + GetPointerBaseWithConstantOffset(CurrPtr.first, Offset, DL); + const int64_t BaseOffset = CurrPtr.second + Offset; + + for (pred_iterator PI = pred_begin(CurrBB), E = pred_end(CurrBB); PI != E; + ++PI) { + BasicBlock *PredBB = *PI; + PHITransAddr TransAddr{BasePtr, DL, AC}; + + // TODO: only predecessors dominated by ToBB are followed, others skipped. + if (!DT.dominates(ToBB, PredBB)) + continue; + + if (TransAddr.NeedsPHITranslationFromBlock(CurrBB) && + (!TransAddr.IsPotentiallyPHITranslatable() || + TransAddr.PHITranslateValue(CurrBB, PredBB, &DT, false))) + return PhiTransPtr{}; + + auto Inserted = + Visited.try_emplace(PredBB, TransAddr.getAddr(), BaseOffset); + auto &TransPtr = Inserted.first->second; + if (!Inserted.second) { + if (TransAddr.getAddr() != TransPtr.first || + BaseOffset != TransPtr.second) + // We already visited this block before.
If it was with a different + // address - bail out! + return PhiTransPtr{}; + continue; + } + + if (PredBB == ToBB) { + ResPtr = TransPtr; + continue; + } + + WorkList.push_back({PredBB, TransPtr}); + } + } + + assert(ResPtr.first.pointsToAliveValue() && + "PHI translation is expected to complete successfully"); + return ResPtr; +} + +static ExMemoryAccess +phiTransFromMemoryAccessTo(const ExMemoryAccess &FromAccess, + MemoryAccess *ToAccess, const DataLayout &DL, + DominatorTree &DT, AssumptionCache *AC) { + PhiTransPtr ResAddr; + + if (FromAccess.second.first.pointsToAliveValue()) { + + ResAddr = phiTranslatePtr(FromAccess.second, FromAccess.first->getBlock(), + ToAccess->getBlock(), DL, DT, AC); + } + + return std::make_pair(ToAccess, ResAddr); +} + +static ExMemoryAccess getExDefiningAccess(const ExMemoryAccess &CurrAccess, + const DataLayout &DL, + DominatorTree &DT, + AssumptionCache *AC) { + assert(isa(CurrAccess.first) && "TODO"); + MemoryAccess *Next = + cast(CurrAccess.first)->getDefiningAccess(); + return phiTransFromMemoryAccessTo(CurrAccess, Next, DL, DT, AC); +} + struct DSEState { Function &F; AliasAnalysis &AA; @@ -869,6 +960,7 @@ const TargetLibraryInfo &TLI; const DataLayout &DL; const LoopInfo &LI; + AssumptionCache *AC; // Whether the function contains any irreducible control flow, useful for // being accurately able to detect loops.
@@ -897,14 +989,16 @@ DSEState(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, DominatorTree &DT, PostDominatorTree &PDT, const TargetLibraryInfo &TLI, - const LoopInfo &LI) + const LoopInfo &LI, AssumptionCache *AC) : F(F), AA(AA), BatchAA(AA), MSSA(MSSA), DT(DT), PDT(PDT), TLI(TLI), - DL(F.getParent()->getDataLayout()), LI(LI) {} + DL(F.getParent()->getDataLayout()), LI(LI), AC(AC), + ContainsIrreducibleLoops(false) {} static DSEState get(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, DominatorTree &DT, PostDominatorTree &PDT, - const TargetLibraryInfo &TLI, const LoopInfo &LI) { - DSEState State(F, AA, MSSA, DT, PDT, TLI, LI); + const TargetLibraryInfo &TLI, const LoopInfo &LI, + AssumptionCache *AC) { + DSEState State(F, AA, MSSA, DT, PDT, TLI, LI, AC); // Collect blocks with throwing instructions not modeled in MemorySSA and // alloc-like objects. unsigned PO = 0; @@ -1359,8 +1453,8 @@ // such MemoryDef, return None. The returned value may not (completely) // overwrite \p DefLoc. Currently we bail out when we encounter an aliasing // MemoryUse (read). - Optional - getDomMemoryDef(MemoryDef *KillingDef, MemoryAccess *StartAccess, + Optional + getDomMemoryDef(MemoryDef *KillingDef, const ExMemoryAccess &ExStartAccess, const MemoryLocation &KillingLoc, const Value *KillingUndObj, unsigned &ScanLimit, unsigned &WalkerStepLimit, bool IsMemTerm, unsigned &PartialLimit) { @@ -1369,30 +1463,47 @@ return None; } - MemoryAccess *Current = StartAccess; + MemoryAccess *const StartAccess = ExStartAccess.first; Instruction *KillingI = KillingDef->getMemoryInst(); LLVM_DEBUG(dbgs() << " trying to get dominating access\n"); - // Find the next clobbering Mod access for DefLoc, starting at StartAccess. + ExMemoryAccess ExCurrentAccess = ExStartAccess; Optional CurrentLoc; - for (;; Current = cast(Current)->getDefiningAccess()) { + Optional> PhiTransKillingLocAndOffset; + + // Find the next clobbering Mod access for KillingLoc, starting at + // StartAccess. Overwrite checks use the phi-translated KillingLoc if known.
+ for (;; + ExCurrentAccess = getExDefiningAccess(ExCurrentAccess, DL, DT, AC)) { + MemoryAccess *CurrentAccess = ExCurrentAccess.first; + + PhiTransKillingLocAndOffset.reset(); + if (ExCurrentAccess.second.first.pointsToAliveValue()) { + PhiTransKillingLocAndOffset = { + KillingLoc.getWithNewPtr(ExCurrentAccess.second.first) + .getWithoutAATags(), + ExCurrentAccess.second.second}; + } + LLVM_DEBUG({ - dbgs() << " visiting " << *Current; - if (!MSSA.isLiveOnEntryDef(Current) && isa(Current)) - dbgs() << " (" << *cast(Current)->getMemoryInst() + dbgs() << " visiting " << *CurrentAccess; + if (!MSSA.isLiveOnEntryDef(CurrentAccess) && + isa(CurrentAccess)) + dbgs() << " (" + << *cast(CurrentAccess)->getMemoryInst() << ")"; dbgs() << "\n"; }); // Reached TOP. - if (MSSA.isLiveOnEntryDef(Current)) { + if (MSSA.isLiveOnEntryDef(CurrentAccess)) { LLVM_DEBUG(dbgs() << " ... found LiveOnEntryDef\n"); return None; } // Cost of a step. Accesses in the same block are more likely to be valid // candidates for elimination, hence consider them cheaper. - unsigned StepCost = KillingDef->getBlock() == Current->getBlock() + unsigned StepCost = KillingDef->getBlock() == CurrentAccess->getBlock() ? MemorySSASameBBStepCost : MemorySSAOtherBBStepCost; if (WalkerStepLimit <= StepCost) { @@ -1403,14 +1514,14 @@ // Return for MemoryPhis. They cannot be eliminated directly and the // caller is responsible for traversing them. - if (isa(Current)) { + if (isa(CurrentAccess)) { LLVM_DEBUG(dbgs() << " ... found MemoryPhi\n"); - return Current; + return ExCurrentAccess; } // Below, check if CurrentDef is a valid candidate to be eliminated by // KillingDef. If it is not, check the next candidate.
- MemoryDef *CurrentDef = cast(Current); + MemoryDef *CurrentDef = cast(CurrentAccess); Instruction *CurrentI = CurrentDef->getMemoryInst(); if (canSkipDef(CurrentDef, !isInvisibleToCallerBeforeRet(KillingUndObj))) @@ -1438,12 +1549,13 @@ return None; // Quick check if there are direct uses that are read-clobbers. - if (any_of(Current->uses(), [this, &KillingLoc, StartAccess](Use &U) { - if (auto *UseOrDef = dyn_cast(U.getUser())) - return !MSSA.dominates(StartAccess, UseOrDef) && - isReadClobber(KillingLoc, UseOrDef->getMemoryInst()); - return false; - })) { + if (any_of( + CurrentAccess->uses(), [this, &KillingLoc, StartAccess](Use &U) { + if (auto *UseOrDef = dyn_cast(U.getUser())) + return !MSSA.dominates(StartAccess, UseOrDef) && + isReadClobber(KillingLoc, UseOrDef->getMemoryInst()); + return false; + })) { LLVM_DEBUG(dbgs() << " ... found a read clobber\n"); return None; } @@ -1454,6 +1566,7 @@ continue; // If Current does not have an analyzable write location, skip it + int64_t CurrOffset = 0; CurrentLoc = getLocForWriteEx(CurrentI); if (!CurrentLoc) continue; @@ -1474,10 +1587,14 @@ if (!isMemTerminator(*CurrentLoc, CurrentI, KillingI)) continue; } else { - int64_t KillingOffset = 0; - int64_t KilledOffset = 0; - auto OR = isOverwrite(KillingI, CurrentI, KillingLoc, *CurrentLoc, - KillingOffset, KilledOffset); + auto &KillingMemLoc = PhiTransKillingLocAndOffset + ? (*PhiTransKillingLocAndOffset).first + : KillingLoc; + int64_t KillingMemOffset = PhiTransKillingLocAndOffset + ? (*PhiTransKillingLocAndOffset).second + : 0; + auto OR = isOverwrite(KillingI, CurrentI, KillingMemLoc, *CurrentLoc, + KillingMemOffset, CurrOffset); // If Current does not write to the same object as KillingDef, check // the next candidate. if (OR == OW_Unknown || OR == OW_None) @@ -1504,7 +1621,7 @@ // they cover all paths from KilledAccess to any function exit.
SmallPtrSet KillingDefs; KillingDefs.insert(KillingDef->getMemoryInst()); - MemoryAccess *KilledAccess = Current; + MemoryAccess *KilledAccess = ExCurrentAccess.first; MemoryLocation KilledLoc = *CurrentLoc; Instruction *KilledI = cast(KilledAccess)->getMemoryInst(); LLVM_DEBUG(dbgs() << " Checking for reads of " << *KilledAccess << " (" @@ -1656,7 +1773,7 @@ // post-dominates KilledAccess. if (KillingBlocks.count(CommonPred)) { if (PDT.dominates(CommonPred, KilledAccess->getBlock())) - return {KilledAccess}; + return {ExCurrentAccess}; return None; } @@ -1679,32 +1796,32 @@ // killing blocks before reaching KilledAccess. for (unsigned I = 0; I < WorkList.size(); I++) { NumCFGChecks++; - BasicBlock *Current = WorkList[I]; - if (KillingBlocks.count(Current)) + BasicBlock *CurrentBB = WorkList[I]; + if (KillingBlocks.count(CurrentBB)) continue; - if (Current == KilledAccess->getBlock()) + if (CurrentBB == KilledAccess->getBlock()) return None; // KilledAccess is reachable from the entry, so we don't have to // explore unreachable blocks further. - if (!DT.isReachableFromEntry(Current)) + if (!DT.isReachableFromEntry(CurrentBB)) continue; - for (BasicBlock *Pred : predecessors(Current)) + for (BasicBlock *Pred : predecessors(CurrentBB)) WorkList.insert(Pred); if (WorkList.size() >= MemorySSAPathCheckLimit) return None; } NumCFGSuccess++; - return {KilledAccess}; + return {ExCurrentAccess}; } return None; } // No aliasing MemoryUses of KilledAccess found, KilledAccess is // potentially dead.
- return {KilledAccess}; + return {ExCurrentAccess}; } // Delete dead memory defs @@ -1907,10 +2024,11 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, DominatorTree &DT, PostDominatorTree &PDT, const TargetLibraryInfo &TLI, - const LoopInfo &LI) { + const LoopInfo &LI, AssumptionCache *AC) { bool MadeChange = false; - DSEState State = DSEState::get(F, AA, MSSA, DT, PDT, TLI, LI); + const DataLayout &DL = F.getParent()->getDataLayout(); + DSEState State = DSEState::get(F, AA, MSSA, DT, PDT, TLI, LI, AC); // For each store: for (unsigned I = 0; I < State.MemDefs.size(); I++) { MemoryDef *KillingDef = State.MemDefs[I]; @@ -1934,7 +2052,6 @@ assert(KillingLoc.Ptr && "KillingLoc should not be null"); const Value *KillingUndObj = getUnderlyingObject(KillingLoc.Ptr); - MemoryAccess *Current = KillingDef; LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by " << *KillingDef << " (" << *KillingI << ")\n"); @@ -1942,21 +2059,25 @@ unsigned WalkerStepLimit = MemorySSAUpwardsStepLimit; unsigned PartialLimit = MemorySSAPartialStoreLimit; // Worklist of MemoryAccesses that may be killed by KillingDef. - SetVector ToCheck; + SmallVector ToCheck; - if (KillingUndObj) - ToCheck.insert(KillingDef->getDefiningAccess()); + if (KillingUndObj) { + ExMemoryAccess ExKillingDef{KillingDef, + {const_cast(KillingLoc.Ptr), 0}}; + ToCheck.push_back(getExDefiningAccess(ExKillingDef, DL, DT, AC)); + } bool Shortend = false; bool IsMemTerm = State.isMemTerminatorInst(KillingI); // Check if MemoryAccesses in the worklist are killed by KillingDef.
- for (unsigned I = 0; I < ToCheck.size(); I++) { - Current = ToCheck[I]; + while (!ToCheck.empty()) { + const ExMemoryAccess CurrentAccess = ToCheck.pop_back_val(); + MemoryAccess *Current = CurrentAccess.first; if (State.SkipStores.count(Current)) continue; - Optional MaybeKilledAccess = State.getDomMemoryDef( - KillingDef, Current, KillingLoc, KillingUndObj, ScanLimit, + Optional MaybeKilledAccess = State.getDomMemoryDef( + KillingDef, CurrentAccess, KillingLoc, KillingUndObj, ScanLimit, WalkerStepLimit, IsMemTerm, PartialLimit); if (!MaybeKilledAccess) { @@ -1964,7 +2085,8 @@ continue; } - MemoryAccess *KilledAccess = *MaybeKilledAccess; + ExMemoryAccess &ExKilledAccess = *MaybeKilledAccess; + MemoryAccess *KilledAccess = ExKilledAccess.first; LLVM_DEBUG(dbgs() << " Checking if we can kill " << *KilledAccess); if (isa(KilledAccess)) { LLVM_DEBUG(dbgs() << "\n ... adding incoming values to worklist\n"); @@ -1978,14 +2100,15 @@ // strictly dominate our starting def. if (State.PostOrderNumbers[IncomingBlock] > State.PostOrderNumbers[PhiBlock]) - ToCheck.insert(IncomingAccess); + ToCheck.push_back(phiTransFromMemoryAccessTo( + ExKilledAccess, IncomingAccess, DL, DT, AC)); } continue; } auto *KilledDefAccess = cast(KilledAccess); Instruction *KilledI = KilledDefAccess->getMemoryInst(); LLVM_DEBUG(dbgs() << " (" << *KilledI << ")\n"); - ToCheck.insert(KilledDefAccess->getDefiningAccess()); + ToCheck.push_back(getExDefiningAccess(ExKilledAccess, DL, DT, AC)); NumGetDomMemoryDefPassed++; if (!DebugCounter::shouldExecute(MemorySSACounter)) @@ -2006,6 +2129,18 @@ // Check if KilledI overwrites KillingI.
int64_t KillingOffset = 0; int64_t KilledOffset = 0; + + // Phi-translate into a copy; AA tags are dropped as in getDomMemoryDef. + // KillingLoc is reused for every ToCheck entry and must not be clobbered. + // NOTE(review): the translated pointer may be a base with its constant + // offset in KillingOffset; confirm isOverwrite honors the passed offset. + MemoryLocation TransKillingLoc = KillingLoc; + if (ExKilledAccess.second.first.pointsToAliveValue()) { + TransKillingLoc = KillingLoc.getWithNewPtr(ExKilledAccess.second.first) + .getWithoutAATags(); + KillingOffset = ExKilledAccess.second.second; + } + - OverwriteResult OR = State.isOverwrite(KillingI, KilledI, KillingLoc, + OverwriteResult OR = State.isOverwrite(KillingI, KilledI, TransKillingLoc, KilledLoc, KillingOffset, KilledOffset); @@ -2087,8 +2222,9 @@ MemorySSA &MSSA = AM.getResult(F).getMSSA(); PostDominatorTree &PDT = AM.getResult(F); LoopInfo &LI = AM.getResult(F); + AssumptionCache &AC = AM.getResult(F); - bool Changed = eliminateDeadStores(F, AA, MSSA, DT, PDT, TLI, LI); + bool Changed = eliminateDeadStores(F, AA, MSSA, DT, PDT, TLI, LI, &AC); #ifdef LLVM_ENABLE_STATS if (AreStatisticsEnabled()) @@ -2129,8 +2265,10 @@ PostDominatorTree &PDT = getAnalysis().getPostDomTree(); LoopInfo &LI = getAnalysis().getLoopInfo(); + AssumptionCache &AC = + getAnalysis().getAssumptionCache(F); - bool Changed = eliminateDeadStores(F, AA, MSSA, DT, PDT, TLI, LI); + bool Changed = eliminateDeadStores(F, AA, MSSA, DT, PDT, TLI, LI, &AC); #ifdef LLVM_ENABLE_STATS if (AreStatisticsEnabled()) @@ -2154,6 +2292,8 @@ AU.addPreserved(); AU.addRequired(); AU.addPreserved(); + AU.addRequired(); + AU.addPreserved(); } }; diff --git a/llvm/test/Transforms/DeadStoreElimination/phi-translation.ll b/llvm/test/Transforms/DeadStoreElimination/phi-translation.ll --- a/llvm/test/Transforms/DeadStoreElimination/phi-translation.ll +++ b/llvm/test/Transforms/DeadStoreElimination/phi-translation.ll @@ -1,7 +1,7 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py ; RUN: opt -dse -S %s | FileCheck %s -; TODO: Both the stores in %then and %else can be eliminated by translating %p +; Both the stores in %then and %else can be eliminated by translating %p ; through the phi.
define void @memoryphi_translate_1(i1 %c) { ; CHECK-LABEL: @memoryphi_translate_1( @@ -10,10 +10,8 @@ ; CHECK-NEXT: [[A_2:%.*]] = alloca i8, align 1 ; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[ELSE:%.*]] ; CHECK: then: -; CHECK-NEXT: store i8 0, i8* [[A_1]], align 1 ; CHECK-NEXT: br label [[END:%.*]] ; CHECK: else: -; CHECK-NEXT: store i8 9, i8* [[A_2]], align 1 ; CHECK-NEXT: br label [[END]] ; CHECK: end: ; CHECK-NEXT: [[P:%.*]] = phi i8* [ [[A_1]], [[THEN]] ], [ [[A_2]], [[ELSE]] ] @@ -39,7 +37,7 @@ ret void } -; TODO: The store in %else can be eliminated by translating %p through the phi. +; The store in %else can be eliminated by translating %p through the phi. ; The store in %then cannot be eliminated, because %a.1 is read before the final ; store. define i8 @memoryphi_translate_2(i1 %c) { @@ -52,7 +50,6 @@ ; CHECK-NEXT: store i8 0, i8* [[A_1]], align 1 ; CHECK-NEXT: br label [[END:%.*]] ; CHECK: else: -; CHECK-NEXT: store i8 9, i8* [[A_2]], align 1 ; CHECK-NEXT: br label [[END]] ; CHECK: end: ; CHECK-NEXT: [[P:%.*]] = phi i8* [ [[A_1]], [[THEN]] ], [ [[A_2]], [[ELSE]] ] @@ -80,7 +77,7 @@ ret i8 %l } -; TODO: The store in %then can be eliminated by translating %p through the phi. +; The store in %then can be eliminated by translating %p through the phi. ; The store in %else cannot be eliminated, because %a.2 is read before the final ; store.
define i8 @memoryphi_translate_3(i1 %c) { @@ -90,7 +87,6 @@ ; CHECK-NEXT: [[A_2:%.*]] = alloca i8, align 1 ; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[ELSE:%.*]] ; CHECK: then: -; CHECK-NEXT: store i8 0, i8* [[A_1]], align 1 ; CHECK-NEXT: br label [[END:%.*]] ; CHECK: else: ; CHECK-NEXT: store i8 9, i8* [[A_2]], align 1 @@ -241,8 +237,6 @@ ; CHECK: else: ; CHECK-NEXT: call void @fn() ; CHECK-NEXT: [[BC:%.*]] = bitcast i8* undef to i16* -; CHECK-NEXT: [[GEP_1:%.*]] = getelementptr inbounds i16, i16* [[BC]], i64 2 -; CHECK-NEXT: store i16 8, i16* [[GEP_1]], align 2 ; CHECK-NEXT: br label [[EXIT]] ; CHECK: exit: ; CHECK-NEXT: [[P:%.*]] = phi i16* [ [[PTR:%.*]], [[THEN]] ], [ [[BC]], [[ELSE]] ]