Index: include/llvm/Transforms/Utils/MemorySSA.h =================================================================== --- include/llvm/Transforms/Utils/MemorySSA.h +++ include/llvm/Transforms/Utils/MemorySSA.h @@ -73,6 +73,7 @@ #define LLVM_TRANSFORMS_UTILS_MEMORYSSA_H #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/GraphTraits.h" #include "llvm/ADT/iterator_range.h" #include "llvm/ADT/SmallPtrSet.h" @@ -80,6 +81,7 @@ #include "llvm/ADT/ilist.h" #include "llvm/ADT/ilist_node.h" #include "llvm/ADT/iterator.h" +#include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/MemoryLocation.h" #include "llvm/Analysis/PHITransAddr.h" @@ -109,6 +111,9 @@ class LLVMContext; class raw_ostream; +struct AllAccessTag {}; +struct DefsOnlyTag {}; + enum { // Used to signify what the default invalid ID is for MemoryAccess's // getID() @@ -122,8 +127,13 @@ // \brief The base for all memory accesses. All memory accesses in a block are // linked together using an intrusive list. -class MemoryAccess : public User, public ilist_node { +class MemoryAccess : public User, + public ilist_node>, + public ilist_node> { public: + using AllAccessType = ilist_node>; + using DefsOnlyType = ilist_node>; + // Methods for support type inquiry through isa, cast, and // dyn_cast static inline bool classof(const MemoryAccess *) { return true; } @@ -156,6 +166,33 @@ memoryaccess_def_iterator defs_end(); const_memoryaccess_def_iterator defs_end() const; + /// \brief Get the iterators for the all access list and the defs only list + /// We default to the all access list. + AllAccessType::self_iterator getIterator() { + return this->AllAccessType::getIterator(); + } + AllAccessType::const_self_iterator getIterator() const { + return this->AllAccessType::getIterator(); + } + AllAccessType::reverse_self_iterator getReverseIterator() { + return this->AllAccessType::getReverseIterator(); + } + AllAccessType::const_reverse_self_iterator getReverseIterator() const { + return this->AllAccessType::getReverseIterator(); + } + DefsOnlyType::self_iterator getDefsIterator() { + return this->DefsOnlyType::getIterator(); + } + DefsOnlyType::const_self_iterator getDefsIterator() const { + return this->DefsOnlyType::getIterator(); + } + DefsOnlyType::reverse_self_iterator getReverseDefsIterator() { + return this->DefsOnlyType::getReverseIterator(); + } + DefsOnlyType::const_reverse_self_iterator getReverseDefsIterator() const { + return this->DefsOnlyType::getReverseIterator(); + } + protected: friend class MemorySSA; friend class MemoryUseOrDef; @@ -531,14 +568,21 @@ return LiveOnEntryDef.get(); } - using AccessList = iplist; - + using AccessList = iplist>; + using DefsList = simple_ilist>; /// \brief Return the list of MemoryAccess's for a given basic block. /// /// This list is not modifiable by the user. const AccessList *getBlockAccesses(const BasicBlock *BB) const { return getWritableBlockAccesses(BB); } + /// \brief Return the list of MemoryDef's and MemoryPhi's for a given basic + /// block. + /// + /// This list is not modifiable by the user. + const DefsList *getBlockDefs(const BasicBlock *BB) const { + return getWritableBlockDefs(BB); + } /// \brief Create an empty MemoryPhi in MemorySSA for a given basic block. /// Only one MemoryPhi for a block exists at a time, so this function will @@ -623,12 +667,20 @@ void verifyDomination(Function &F) const; void verifyOrdering(Function &F) const; - // This is used by the use optimizer class + // This is used by the use optimizer and updater. AccessList *getWritableBlockAccesses(const BasicBlock *BB) const { auto It = PerBlockAccesses.find(BB); return It == PerBlockAccesses.end() ? nullptr : It->second.get(); +} + + // This is used by the use optimizer and updater. + DefsList *getWritableBlockDefs(const BasicBlock *BB) const { + auto It = PerBlockDefs.find(BB); + return It == PerBlockDefs.end() ? nullptr : It->second.get(); } + + private: class CachingWalker; class OptimizeUses; @@ -639,6 +691,7 @@ void verifyUseInDefs(MemoryAccess *, MemoryAccess *) const; using AccessMap = DenseMap>; + using DefsMap = DenseMap>; void determineInsertionPoint(const SmallPtrSetImpl &DefiningBlocks); @@ -656,6 +709,7 @@ void renamePass(DomTreeNode *, MemoryAccess *IncomingVal, SmallPtrSet &Visited); AccessList *getOrCreateAccessList(const BasicBlock *); + DefsList *getOrCreateDefsList(const BasicBlock *); void renumberBlock(const BasicBlock *) const; AliasAnalysis *AA; @@ -665,6 +719,7 @@ // Memory SSA mappings DenseMap ValueToMemoryAccess; AccessMap PerBlockAccesses; + DefsMap PerBlockDefs; std::unique_ptr LiveOnEntryDef; // Domination mappings Index: lib/Transforms/Utils/MemorySSA.cpp =================================================================== --- lib/Transforms/Utils/MemorySSA.cpp +++ lib/Transforms/Utils/MemorySSA.cpp @@ -1247,6 +1247,13 @@ Res.first->second = make_unique(); return Res.first->second.get(); } +MemorySSA::DefsList *MemorySSA::getOrCreateDefsList(const BasicBlock *BB) { + auto Res = PerBlockDefs.insert(std::make_pair(BB, nullptr)); + + if (Res.second) + Res.first->second = make_unique(); + return Res.first->second.get(); +} /// This class is a batch walker of all MemoryUse's in the program, and points /// their defining access at the thing that actually clobbers them. Because it @@ -1480,10 +1487,12 @@ for (auto &BB : IDFBlocks) { // Insert phi node AccessList *Accesses = getOrCreateAccessList(BB); + DefsList *Defs = getOrCreateDefsList(BB); MemoryPhi *Phi = new MemoryPhi(BB->getContext(), BB, NextID++); ValueToMemoryAccess[BB] = Phi; // Phi's always are placed at the front of the block. Accesses->push_front(Phi); + Defs->push_front(*Phi); } } @@ -1511,15 +1520,20 @@ BBNumbers[&B] = NextBBNum++; bool InsertIntoDef = false; AccessList *Accesses = nullptr; + DefsList *Defs = nullptr; for (Instruction &I : B) { MemoryUseOrDef *MUD = createNewAccess(&I); if (!MUD) continue; - InsertIntoDef |= isa(MUD); if (!Accesses) Accesses = getOrCreateAccessList(&B); Accesses->push_back(MUD); + if (isa(MUD)) { + InsertIntoDef |= true; + Defs = getOrCreateDefsList(&B); + Defs->push_back(*MUD); + } } if (InsertIntoDef) DefiningBlocks.insert(&B); @@ -1561,10 +1575,12 @@ MemoryPhi *MemorySSA::createMemoryPhi(BasicBlock *BB) { assert(!getMemoryAccess(BB) && "MemoryPhi already exists for this BB"); AccessList *Accesses = getOrCreateAccessList(BB); + DefsList *Defs = getOrCreateDefsList(BB); MemoryPhi *Phi = new MemoryPhi(BB->getContext(), BB, NextID++); ValueToMemoryAccess[BB] = Phi; // Phi's always are placed at the front of the block. Accesses->push_front(Phi); + Defs->push_front(*Phi); BlockNumberingValid.erase(BB); return Phi; } @@ -1592,8 +1608,18 @@ *Accesses, [](const MemoryAccess &MA) { return !isa(MA); }); Accesses->insert(AI, NewAccess); + if (!isa(NewAccess)) { + auto *Defs = getOrCreateDefsList(BB); + Defs->push_front(*NewAccess); + + } } else { Accesses->push_back(NewAccess); + if (!isa(NewAccess)) { + auto *Defs = getOrCreateDefsList(BB); + Defs->push_back(*NewAccess); + + } } BlockNumberingValid.erase(BB); return NewAccess; @@ -1605,9 +1631,18 @@ assert(I->getParent() == InsertPt->getBlock() && "New and old access must be in the same block"); MemoryUseOrDef *NewAccess = createDefinedAccess(I, Definition); - auto *Accesses = getOrCreateAccessList(InsertPt->getBlock()); + // Should already exist, since insertpt is in that block. + auto *Accesses = getWritableBlockAccesses(InsertPt->getBlock()); Accesses->insert(AccessList::iterator(InsertPt), NewAccess); BlockNumberingValid.erase(InsertPt->getBlock()); + if (!isa(NewAccess)) { + // May not exist, since we are not guaranteed insertion point is a def. + auto *Defs = getOrCreateDefsList(InsertPt->getBlock()); + auto DefsPoint = + isa(InsertPt) ? InsertPt->getDefsIterator() : Defs->end(); + Defs->insert(DefsPoint, *NewAccess); + } + return NewAccess; } @@ -1619,7 +1654,15 @@ MemoryUseOrDef *NewAccess = createDefinedAccess(I, Definition); auto *Accesses = getOrCreateAccessList(InsertPt->getBlock()); Accesses->insertAfter(AccessList::iterator(InsertPt), NewAccess); + if (!isa(NewAccess)) { + // May not exist, since we are not guaranteed insertion point is a def. + auto *Defs = getOrCreateDefsList(InsertPt->getBlock()); + auto DefsPoint = isa(InsertPt) ? ++(InsertPt->getDefsIterator()) + : Defs->end(); + Defs->insert(DefsPoint, *NewAccess); + } BlockNumberingValid.erase(InsertPt->getBlock()); + return NewAccess; } @@ -1761,6 +1804,16 @@ if (VMA->second == MA) ValueToMemoryAccess.erase(VMA); + // The access list owns the reference, so we erase it from the non-owning list + // first. + if (!isa(MA)) { + auto DefsIt = PerBlockDefs.find(MA->getBlock()); + std::unique_ptr &Defs = DefsIt->second; + Defs->remove(*MA); + if (Defs->empty()) + PerBlockDefs.erase(DefsIt); + } + auto AccessIt = PerBlockAccesses.find(MA->getBlock()); std::unique_ptr &Accesses = AccessIt->second; Accesses->erase(MA); @@ -1838,26 +1891,38 @@ // lists think, as well as the order in the blocks vs the order in the access // lists. SmallVector ActualAccesses; + SmallVector ActualDefs; for (BasicBlock &B : F) { const AccessList *AL = getBlockAccesses(&B); + const auto *DL = getBlockDefs(&B); MemoryAccess *Phi = getMemoryAccess(&B); - if (Phi) + if (Phi) { ActualAccesses.push_back(Phi); + ActualDefs.push_back(Phi); + } + for (Instruction &I : B) { MemoryAccess *MA = getMemoryAccess(&I); - assert((!MA || AL) && "We have memory affecting instructions " - "in this block but they are not in the " - "access list"); - if (MA) + assert((!MA || (AL && DL)) && "We have memory affecting instructions " + "in this block but they are not in the " + "access list or defs list"); + if (MA) { ActualAccesses.push_back(MA); + if (isa(MA)) + ActualDefs.push_back(MA); + } } // Either we hit the assert, really have no accesses, or we have both - // accesses and an access list - if (!AL) + // accesses and an access list. + // Same with defs. + if (!AL && !DL) continue; assert(AL->size() == ActualAccesses.size() && "We don't have the same number of accesses in the block as on the " "access list"); + assert(DL->size() == ActualDefs.size() && + "We don't have the same number of defs in the block as on the " + "def list"); auto ALI = AL->begin(); auto AAI = ActualAccesses.begin(); while (ALI != AL->end() && AAI != ActualAccesses.end()) { @@ -1866,6 +1931,15 @@ ++AAI; } ActualAccesses.clear(); + + auto DLI = DL->begin(); + auto ADI = ActualDefs.begin(); + while (DLI != DL->end() && ADI != ActualDefs.end()) { + assert(&*DLI == *ADI && "Not the same defs in the same order"); + ++DLI; + ++ADI; + } + ActualDefs.clear(); } } @@ -2302,4 +2376,5 @@ return Use->getDefiningAccess(); return StartingAccess; } + } // namespace llvm