Index: include/llvm/Transforms/Utils/MemorySSA.h
===================================================================
--- include/llvm/Transforms/Utils/MemorySSA.h
+++ include/llvm/Transforms/Utils/MemorySSA.h
@@ -602,11 +602,20 @@
   void verifyDomination(Function &F) const;
   void verifyOrdering(Function &F) const;
 
+  // This is used by the use optimizer class.
+  AccessList *getWritableBlockAccesses(const BasicBlock *BB) const {
+    auto It = PerBlockAccesses.find(BB);
+    return It == PerBlockAccesses.end() ? nullptr : It->second.get();
+  }
+
 private:
   class CachingWalker;
+  class OptimizeUses;
   CachingWalker *getWalkerImpl();
   void buildMemorySSA();
+  void optimizeUses();
+  void verifyUseInDefs(MemoryAccess *, MemoryAccess *) const;
 
   using AccessMap = DenseMap<const BasicBlock *, std::unique_ptr<AccessList>>;
Index: lib/Transforms/Utils/MemorySSA.cpp
===================================================================
--- lib/Transforms/Utils/MemorySSA.cpp
+++ lib/Transforms/Utils/MemorySSA.cpp
@@ -61,6 +61,11 @@
 INITIALIZE_PASS_END(MemorySSAPrinterLegacyPass, "print-memoryssa",
                     "Memory SSA Printer", false, false)
 
+static cl::opt<unsigned> MaxCheckLimit(
+    "memssa-check-limit", cl::Hidden, cl::init(100),
+    cl::desc("The maximum number of stores/phis MemorySSA "
+             "will consider trying to walk past (default = 100)"));
+
 static cl::opt<bool>
     VerifyMemorySSA("verify-memoryssa", cl::init(false), cl::Hidden,
                     cl::desc("Verify MemorySSA in legacy printer pass."));
@@ -111,16 +116,17 @@
   }
 };
 
-static bool instructionClobbersQuery(MemoryDef *MD, const MemoryLocation &Loc,
-                                     const UpwardsMemoryQuery &Query,
+static bool instructionClobbersQuery(MemoryDef *MD,
+                                     const MemoryLocation &UseLoc,
+                                     const Instruction *UseInst,
                                      AliasAnalysis &AA) {
-  Instruction *DefMemoryInst = MD->getMemoryInst();
-  assert(DefMemoryInst && "Defining instruction not actually an instruction");
-
-  if (!Query.IsCall)
-    return AA.getModRefInfo(DefMemoryInst, Loc) & MRI_Mod;
+  Instruction *DefInst = MD->getMemoryInst();
+  assert(DefInst && "Defining instruction not actually an instruction");
+  ImmutableCallSite UseCS(UseInst);
+  if (!UseCS)
+    return AA.getModRefInfo(DefInst, UseLoc) & MRI_Mod;
 
-  ModRefInfo I = AA.getModRefInfo(DefMemoryInst, ImmutableCallSite(Query.Inst));
+  ModRefInfo I = AA.getModRefInfo(DefInst, UseCS);
   return I != MRI_NoModRef;
 }
 
@@ -257,8 +263,9 @@
       //
      // Also, note that this can't be hoisted out of the `Worklist` loop,
      // since MD may only act as a clobber for 1 of N MemoryLocations.
-      FoundClobber = FoundClobber || MSSA.isLiveOnEntryDef(MD) ||
-                     instructionClobbersQuery(MD, MAP.second, Query, AA);
+      FoundClobber =
+          FoundClobber || MSSA.isLiveOnEntryDef(MD) ||
+          instructionClobbersQuery(MD, MAP.second, Query.Inst, AA);
     }
     break;
   }
@@ -268,7 +275,7 @@
     if (auto *MD = dyn_cast<MemoryDef>(MA)) {
       (void)MD;
-      assert(!instructionClobbersQuery(MD, MAP.second, Query, AA) &&
+      assert(!instructionClobbersQuery(MD, MAP.second, Query.Inst, AA) &&
              "Found clobber before reaching ClobberAt!");
       continue;
     }
@@ -428,7 +435,7 @@
     if (auto *MD = dyn_cast<MemoryDef>(Current))
       if (MSSA.isLiveOnEntryDef(MD) ||
-          instructionClobbersQuery(MD, Desc.Loc, *Query, AA))
+          instructionClobbersQuery(MD, Desc.Loc, Query->Inst, AA))
         return {MD, true, false};
 
     // Cache checks must be done last, because if Current is a clobber, the
@@ -1065,6 +1072,287 @@
   return Res.first->second.get();
 }
 
+/// Our current alias analysis API differentiates heavily between calls and
+/// non-calls, and functions called on one usually assert on the other.
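+/// (For example, MemoryLocation::get is only defined for loads, stores, and a
+/// handful of other non-call instructions, while the call-site overloads of
+/// getModRefInfo require an actual call.)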
+/// This class encapsulates the distinction to simplify other code that wants
+/// "Memory affecting instructions and related data" to use as a key.
+/// For example, this class is used as a DenseMap key in the use optimizer.
+class MemoryLocOrCall {
+public:
+  MemoryLocOrCall() : isCall(false) {}
+  MemoryLocOrCall(MemoryUseOrDef *MUD)
+      : MemoryLocOrCall(MUD->getMemoryInst()) {}
+
+  MemoryLocOrCall(Instruction *Inst) {
+    if (ImmutableCallSite(Inst)) {
+      isCall = true;
+      CS = ImmutableCallSite(Inst);
+    } else {
+      isCall = false;
+      // There is no such thing as a MemoryLocation for a fence instruction,
+      // and fences are unique in that regard.
+      if (!isa<FenceInst>(Inst))
+        Loc = MemoryLocation::get(Inst);
+    }
+  }
+
+  explicit MemoryLocOrCall(MemoryLocation Loc) : isCall(false), Loc(Loc) {}
+
+  bool isCall;
+
+  ImmutableCallSite getCS() const {
+    assert(isCall);
+    return CS;
+  }
+
+  MemoryLocation getLoc() const {
+    assert(!isCall);
+    return Loc;
+  }
+
+  bool operator==(const MemoryLocOrCall &Other) const {
+    if (isCall != Other.isCall)
+      return false;
+    if (isCall)
+      return CS.getCalledValue() == Other.CS.getCalledValue();
+    return Loc == Other.Loc;
+  }
+
+private:
+  union {
+    ImmutableCallSite CS;
+    MemoryLocation Loc;
+  };
+};
+
+template <> struct DenseMapInfo<MemoryLocOrCall> {
+  static inline MemoryLocOrCall getEmptyKey() {
+    return MemoryLocOrCall(DenseMapInfo<MemoryLocation>::getEmptyKey());
+  }
+  static inline MemoryLocOrCall getTombstoneKey() {
+    return MemoryLocOrCall(DenseMapInfo<MemoryLocation>::getTombstoneKey());
+  }
+  static unsigned getHashValue(const MemoryLocOrCall &MLOC) {
+    if (MLOC.isCall)
+      return hash_combine(MLOC.isCall,
+                          DenseMapInfo<const Value *>::getHashValue(
+                              MLOC.getCS().getCalledValue()));
+    return hash_combine(
+        MLOC.isCall,
+        DenseMapInfo<MemoryLocation>::getHashValue(MLOC.getLoc()));
+  }
+  static bool isEqual(const MemoryLocOrCall &LHS, const MemoryLocOrCall &RHS) {
+    return LHS == RHS;
+  }
+};
+
+/// This class is a batch walker of all MemoryUse's in the program, and points
+/// their defining access at the thing that actually clobbers them. Because it
+/// is a batch walker that touches everything, it does not operate like the
+/// other walkers. This walker is basically performing a top-down SSA renaming
+/// pass, where the version stack is used as the cache. This enables it to be
+/// significantly more time and memory efficient than the regular walker,
+/// which walks bottom-up.
+class MemorySSA::OptimizeUses {
+public:
+  OptimizeUses(MemorySSA *MSSA, MemorySSAWalker *Walker, AliasAnalysis *AA,
+               DominatorTree *DT)
+      : MSSA(MSSA), Walker(Walker), AA(AA), DT(DT) {}
+
+  void optimizeUses();
+
+private:
+  /// This represents where a given MemoryLocation is in the stack.
+  struct MemlocStackInfo {
+    // These are essentially versions of the stack. Whenever the stack changes
+    // due to pushes or pops, these versions increase.
+    unsigned long StackEpoch;
+    unsigned long PopEpoch;
+    // This is the lower bound of places on the stack to check. It is equal to
+    // the place the last stack walk for this location ended.
+    // Note: Correctness depends on this being initialized to 0, which
+    // DenseMap does.
+    unsigned long LowerBound;
+    // This is where the last walk for this memory location ended.
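+    // It is only meaningful while LastKillValid is true; we invalidate it
+    // whenever we give up on a walk (see MaxCheckLimit) or the stack entry
+    // it referred to may have been popped.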
+    unsigned long LastKill;
+    bool LastKillValid;
+  };
+
+  void optimizeUsesInBlock(const BasicBlock *, unsigned long &,
+                           unsigned long &, SmallVectorImpl<MemoryAccess *> &,
+                           DenseMap<MemoryLocOrCall, MemlocStackInfo> &);
+
+  MemorySSA *MSSA;
+  MemorySSAWalker *Walker;
+  AliasAnalysis *AA;
+  DominatorTree *DT;
+};
+
+static bool instructionClobbersQuery(MemoryDef *MD,
+                                     const MemoryLocOrCall &UseMLOC,
+                                     AliasAnalysis &AA) {
+  Instruction *DefInst = MD->getMemoryInst();
+  assert(DefInst && "Defining instruction not actually an instruction");
+  if (!UseMLOC.isCall)
+    return AA.getModRefInfo(DefInst, UseMLOC.getLoc()) & MRI_Mod;
+
+  ModRefInfo I = AA.getModRefInfo(DefInst, UseMLOC.getCS());
+  return I != MRI_NoModRef;
+}
+
+/// Optimize the uses in a given block. This is basically the SSA renaming
+/// algorithm, with one caveat: we are able to use a single stack for all
+/// MemoryUses. This is because the set of *possible* reaching MemoryDefs is
+/// the same for every MemoryUse. The *actual* clobbering MemoryDef is just
+/// going to be some position in that stack of possible ones.
+///
+/// We track the stack positions that each MemoryLocation needs
+/// to check, and the place its last walk ended. This is because we only want
+/// to check the things that have changed since last time. The same
+/// MemoryLocation is always clobbered by the same store (getModRefInfo does
+/// not currently use invariance information or the like; if it starts to, we
+/// can extend MemoryLocOrCall to carry the relevant data).
+void MemorySSA::OptimizeUses::optimizeUsesInBlock(
+    const BasicBlock *BB, unsigned long &StackEpoch, unsigned long &PopEpoch,
+    SmallVectorImpl<MemoryAccess *> &VersionStack,
+    DenseMap<MemoryLocOrCall, MemlocStackInfo> &LocStackInfo) {
+  // If there are no accesses, nothing to do.
+  MemorySSA::AccessList *Accesses = MSSA->getWritableBlockAccesses(BB);
+  if (Accesses == nullptr)
+    return;
+
+  // Pop everything that doesn't dominate the current block off the stack,
+  // and increment the PopEpoch to account for this.
+  while (!DT->dominates(VersionStack.back()->getBlock(), BB)) {
+    VersionStack.pop_back();
+    ++PopEpoch;
+  }
+
+  for (MemoryAccess &MA : *Accesses) {
+    if (auto *MU = dyn_cast<MemoryUse>(&MA)) {
+      MemoryLocOrCall UseMLOC(MU);
+      auto &LocInfo = LocStackInfo[UseMLOC];
+      // If the pop epoch changed, it means we've removed stuff from the top
+      // of the stack due to changing blocks. We may have to reset the lower
+      // bound or last kill info.
+      if (LocInfo.PopEpoch != PopEpoch) {
+        LocInfo.PopEpoch = PopEpoch;
+        LocInfo.StackEpoch = StackEpoch;
+        // If the lower bound was in the info we popped, we have to reset it.
+        if (LocInfo.LowerBound >= VersionStack.size()) {
+          // Reset the lower bound of things to check.
+          // TODO: Some day we should be able to reset to last kill, rather
+          // than 0.
+          LocInfo.LowerBound = 0;
+          LocInfo.LastKillValid = false;
+        }
+      } else if (LocInfo.StackEpoch != StackEpoch) {
+        // If all that has changed is the StackEpoch, we only have to check
+        // the new things on the stack, because we've checked everything
+        // before. In this case, the lower bound of things to check remains
+        // the same.
+        LocInfo.PopEpoch = PopEpoch;
+        LocInfo.StackEpoch = StackEpoch;
+      }
+      if (!LocInfo.LastKillValid) {
+        LocInfo.LastKill = VersionStack.size() - 1;
+        LocInfo.LastKillValid = true;
+      }
+
+      // At this point, we should have corrected LastKill and LowerBound to
+      // be in bounds.
+      assert(LocInfo.LowerBound < VersionStack.size() &&
+             "Lower bound out of range");
+      assert(LocInfo.LastKill < VersionStack.size() &&
+             "Last kill info out of range");
+      // In any case, the new upper bound is the top of the stack.
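+      // Everything on the stack in (LowerBound, UpperBound] is new since the
+      // last time this location was checked, so the walk below only needs to
+      // inspect that range.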
+      unsigned long UpperBound = VersionStack.size() - 1;
+
+      if ((UpperBound - LocInfo.LowerBound) > MaxCheckLimit) {
+        DEBUG(dbgs() << "We are being asked to check up to "
+                     << UpperBound - LocInfo.LowerBound
+                     << " loads and stores, so we are skipping the walk.\n");
+        // Because we did not walk, LastKill is no longer valid, as this may
+        // have been a kill.
+        LocInfo.LastKillValid = false;
+        continue;
+      }
+
+      bool FoundClobberResult = false;
+      while (UpperBound > LocInfo.LowerBound) {
+        if (isa<MemoryPhi>(VersionStack[UpperBound])) {
+          // For phis, use the walker, see where we ended up, and go there.
+          Instruction *UseInst = MU->getMemoryInst();
+          MemoryAccess *Result = Walker->getClobberingMemoryAccess(UseInst);
+          // We are guaranteed to find it on the stack, or something is wrong.
+          while (VersionStack[UpperBound] != Result) {
+            assert(UpperBound != 0);
+            --UpperBound;
+          }
+          FoundClobberResult = true;
+          break;
+        }
+
+        // Anything that is not a phi on the version stack is a MemoryDef.
+        MemoryDef *MD = cast<MemoryDef>(VersionStack[UpperBound]);
+        if (instructionClobbersQuery(MD, UseMLOC, *AA)) {
+          FoundClobberResult = true;
+          break;
+        }
+        --UpperBound;
+      }
+
+      // At the end of this loop, UpperBound is either a clobber or the lower
+      // bound. PHI walking may cause it to be < LowerBound, and in fact
+      // < LastKill.
+      if (FoundClobberResult || UpperBound < LocInfo.LastKill) {
+        MU->setDefiningAccess(VersionStack[UpperBound]);
+        // We are now last killed by where we got to.
+        LocInfo.LastKill = UpperBound;
+      } else {
+        // Otherwise, we checked all the new ones, and now we know we can get
+        // to LastKill.
+        MU->setDefiningAccess(VersionStack[LocInfo.LastKill]);
+      }
+      LocInfo.LowerBound = VersionStack.size() - 1;
+    } else {
+      VersionStack.push_back(&MA);
+      ++StackEpoch;
+    }
+  }
+}
+
+/// Optimize uses to point to their actually clobbering definitions.
+void MemorySSA::OptimizeUses::optimizeUses() {
+  // We perform a non-recursive top-down dominator tree walk.
+  struct StackInfo {
+    const DomTreeNode *Node;
+    DomTreeNode::const_iterator Iter;
+  };
+
+  SmallVector<MemoryAccess *, 16> VersionStack;
+  SmallVector<StackInfo, 16> DomTreeWorklist;
+  DenseMap<MemoryLocOrCall, MemlocStackInfo> LocStackInfo;
+  DomTreeWorklist.push_back({DT->getRootNode(), DT->getRootNode()->begin()});
+  // The bottom of the version stack is always live on entry.
+  VersionStack.push_back(MSSA->getLiveOnEntryDef());
+
+  unsigned long StackEpoch = 1;
+  unsigned long PopEpoch = 1;
+  while (!DomTreeWorklist.empty()) {
+    const auto *DomNode = DomTreeWorklist.back().Node;
+    const auto DomIter = DomTreeWorklist.back().Iter;
+    BasicBlock *BB = DomNode->getBlock();
+    // Only process a block's accesses the first time we visit its node;
+    // later visits just resume iterating over its children.
+    if (DomIter == DomNode->begin())
+      optimizeUsesInBlock(BB, StackEpoch, PopEpoch, VersionStack,
+                          LocStackInfo);
+    if (DomIter == DomNode->end()) {
+      // Hit the end, pop the worklist.
+      DomTreeWorklist.pop_back();
+      continue;
+    }
+    // Move the iterator to the next child for the next time we get to
+    // process children.
+    ++DomTreeWorklist.back().Iter;
+
+    // Now visit the next child.
+    DomTreeWorklist.push_back({*DomIter, (*DomIter)->begin()});
+  }
+}
+
 void MemorySSA::buildMemorySSA() {
   // We create an access to represent "live on entry", for things like
   // arguments or users of globals, where the memory they use is defined before
@@ -1161,25 +1449,7 @@
   // We're doing a batch of updates; don't drop useful caches between them.
   Walker->setAutoResetWalker(false);
-
-  // Now optimize the MemoryUse's defining access to point to the nearest
-  // dominating clobbering def.
-  // This ensures that MemoryUse's that are killed by the same store are
-  // immediate users of that store, one of the invariants we guarantee.
-  for (auto DomNode : depth_first(DT)) {
-    BasicBlock *BB = DomNode->getBlock();
-    auto AI = PerBlockAccesses.find(BB);
-    if (AI == PerBlockAccesses.end())
-      continue;
-    AccessList *Accesses = AI->second.get();
-    for (auto &MA : *Accesses) {
-      if (auto *MU = dyn_cast<MemoryUse>(&MA)) {
-        Instruction *Inst = MU->getMemoryInst();
-        MU->setDefiningAccess(Walker->getClobberingMemoryAccess(Inst));
-      }
-    }
-  }
-
+  OptimizeUses(this, Walker, AA, DT).optimizeUses();
   Walker->setAutoResetWalker(true);
   Walker->resetClobberWalker();
 
@@ -1842,10 +2112,12 @@
                                      : StartingUseOrDef;
 
   MemoryAccess *Clobber = getClobberingMemoryAccess(DefiningAccess, Q);
+#if 0
   DEBUG(dbgs() << "Starting Memory SSA clobber for " << *I << " is ");
   DEBUG(dbgs() << *StartingUseOrDef << "\n");
   DEBUG(dbgs() << "Final Memory SSA clobber for " << *I << " is ");
   DEBUG(dbgs() << *Clobber << "\n");
+#endif
   return Clobber;
 }
 
@@ -1876,11 +2148,12 @@
     return DefiningAccess;
 
   MemoryAccess *Result = getClobberingMemoryAccess(DefiningAccess, Q);
+#if 0
   DEBUG(dbgs() << "Starting Memory SSA clobber for " << *I << " is ");
   DEBUG(dbgs() << *DefiningAccess << "\n");
   DEBUG(dbgs() << "Final Memory SSA clobber for " << *I << " is ");
   DEBUG(dbgs() << *Result << "\n");
-
+#endif
   return Result;
 }
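For readers unfamiliar with the renaming scheme the patch comments reference, the sketch below shows the core dominator-tree version-stack walk that OptimizeUses builds on. It is illustrative only: the types (Node, Def, Use) and the clobbers() predicate are invented stand-ins, not MemorySSA's API, and uses and defs are kept in separate per-block lists for brevity where real access lists interleave them. The patch's contribution is the StackEpoch/PopEpoch/LowerBound memoization layered on exactly this traversal, plus an explicit worklist in place of the recursion.

#include <cstddef>
#include <vector>

// Toy IR: each block holds its uses and defs plus dominator-tree children.
struct Def { int LocId; };
struct Use { int LocId; const Def *DefiningAccess = nullptr; };
struct Node {
  std::vector<Use *> Uses;
  std::vector<Def *> Defs;
  std::vector<Node *> Children;
};

// Stand-in for the per-entry alias query the patch performs.
static bool clobbers(const Def *D, const Use *U) { return D->LocId == U->LocId; }

// Pre-order dominator-tree walk. The stack holds exactly the defs that
// dominate the current block, so a use's nearest clobber is the topmost
// stack entry that clobbers it.
static void renameUses(Node *N, std::vector<const Def *> &VersionStack) {
  const size_t Mark = VersionStack.size();
  for (Use *U : N->Uses)
    for (size_t I = VersionStack.size(); I-- > 0;)
      if (clobbers(VersionStack[I], U)) {
        U->DefiningAccess = VersionStack[I];
        break;
      }
  for (const Def *D : N->Defs)
    VersionStack.push_back(D);
  for (Node *C : N->Children)
    renameUses(C, VersionStack);
  // Pop the defs that do not dominate the siblings we return to.
  VersionStack.resize(Mark);
}

Without the memoization, each use re-scans the whole stack; the patch's epochs and per-location lower bounds restrict that scan to entries pushed or popped since the location was last checked.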