Index: include/llvm/Transforms/Utils/MemorySSA.h
===================================================================
--- include/llvm/Transforms/Utils/MemorySSA.h
+++ include/llvm/Transforms/Utils/MemorySSA.h
@@ -398,8 +398,6 @@
   MemoryAccess *getIncomingValue(unsigned I) const { return getOperand(I); }
   void setIncomingValue(unsigned I, MemoryAccess *V) {
     assert(V && "PHI node got a null value!");
-    assert(getType() == V->getType() &&
-           "All operands to PHI node must be the same type as the PHI node!");
     setOperand(I, V);
   }
   static unsigned getOperandNumForIncomingValue(unsigned I) { return I; }
@@ -536,6 +534,40 @@
     return It == PerBlockAccesses.end() ? nullptr : It->second.get();
   }
 
+  /// \brief Create an empty MemoryPhi for \p BB (which must not have one yet).
+  MemoryPhi *createMemoryPhi(BasicBlock *BB);
+
+  enum InsertionPlace { Beginning, End };
+
+  /// \brief Create a MemoryAccess in MemorySSA at a specified point in a block,
+  /// with a specified clobbering definition.
+  ///
+  /// Returns the new MemoryAccess.
+  /// This should be called when a memory instruction is created that is being
+  /// used to replace an existing memory instruction. It will *not* create PHI
+  /// nodes, or verify the clobbering definition. The insertion place is used
+  /// solely to determine where in the MemorySSA access lists the instruction
+  /// will be placed. The caller is expected to keep ordering the same as
+  /// instructions.
+  /// With \p Point == Beginning the access is placed after any MemoryPhi.
+  MemoryAccess *createMemoryAccessInBB(Instruction *I, MemoryAccess *Definition,
+                                       const BasicBlock *BB,
+                                       InsertionPlace Point);
+  /// \brief Create a MemoryAccess in MemorySSA before or after an existing
+  /// MemoryAccess.
+  ///
+  /// Returns the new MemoryAccess.
+  /// This should be called when a memory instruction is created that is being
+  /// used to replace an existing memory instruction. It will *not* create PHI
+  /// nodes, or verify the clobbering definition. The clobbering definition
+  /// must be non-null.
+  MemoryAccess *createMemoryAccessBefore(Instruction *I,
+                                         MemoryAccess *Definition,
+                                         MemoryAccess *InsertPt);
+  MemoryAccess *createMemoryAccessAfter(Instruction *I,
+                                        MemoryAccess *Definition,
+                                        MemoryAccess *InsertPt);
+
   /// \brief Remove a MemoryAccess from MemorySSA, including updating all
   /// definitions and uses.
   /// This should be called when a memory instruction that has a MemoryAccess
@@ -544,8 +576,6 @@
   /// on the MemoryAccess for that store/load.
   void removeMemoryAccess(MemoryAccess *);
 
-  enum InsertionPlace { Beginning, End };
-
   /// \brief Given two memory accesses in the same basic block, determine
   /// whether MemoryAccess \p A dominates MemoryAccess \p B.
   bool locallyDominates(const MemoryAccess *A, const MemoryAccess *B) const;
@@ -560,6 +590,7 @@
   friend class MemorySSAPrinterLegacyPass;
   void verifyDefUses(Function &F) const;
   void verifyDomination(Function &F) const;
+  void verifyOrdering(Function &F) const;
 
 private:
   void verifyUseInDefs(MemoryAccess *, MemoryAccess *) const;
@@ -571,13 +602,14 @@
   void markUnreachableAsLiveOnEntry(BasicBlock *BB);
   bool dominatesUse(const MemoryAccess *, const MemoryAccess *) const;
   MemoryUseOrDef *createNewAccess(Instruction *);
+  MemoryUseOrDef *createDefinedAccess(Instruction *, MemoryAccess *);
   MemoryAccess *findDominatingDef(BasicBlock *, enum InsertionPlace);
   void removeFromLookups(MemoryAccess *);
   MemoryAccess *renameBlock(BasicBlock *, MemoryAccess *);
   void renamePass(DomTreeNode *, MemoryAccess *IncomingVal,
                   SmallPtrSet<BasicBlock *, 16> &Visited);
 
-  AccessList *getOrCreateAccessList(BasicBlock *);
+  AccessList *getOrCreateAccessList(const BasicBlock *);
   AliasAnalysis *AA;
   DominatorTree *DT;
   Function &F;
Index: lib/Transforms/Utils/MemorySSA.cpp
===================================================================
--- lib/Transforms/Utils/MemorySSA.cpp
+++ lib/Transforms/Utils/MemorySSA.cpp
@@ -226,7 +226,7 @@
       MA.dropAllReferences();
 }
 
-MemorySSA::AccessList *MemorySSA::getOrCreateAccessList(BasicBlock *BB) {
+MemorySSA::AccessList *MemorySSA::getOrCreateAccessList(const BasicBlock *BB) {
   auto Res = PerBlockAccesses.insert(std::make_pair(BB, nullptr));
 
   if (Res.second)
@@ -320,7 +320,7 @@
   for (auto &BB : IDFBlocks) {
     // Insert phi node
     AccessList *Accesses = getOrCreateAccessList(BB);
-    MemoryPhi *Phi = new MemoryPhi(F.getContext(), BB, NextID++);
+    MemoryPhi *Phi = new MemoryPhi(BB->getContext(), BB, NextID++);
     ValueToMemoryAccess.insert(std::make_pair(BB, Phi));
     // Phi's always are placed at the front of the block.
     Accesses->push_front(Phi);
@@ -358,6 +358,69 @@
   return Walker.get();
 }
 
+MemoryPhi *MemorySSA::createMemoryPhi(BasicBlock *BB) {
+  assert(!getMemoryAccess(BB) && "MemoryPhi already exists for this BB");
+  AccessList *Accesses = getOrCreateAccessList(BB);
+  MemoryPhi *Phi = new MemoryPhi(BB->getContext(), BB, NextID++);
+  ValueToMemoryAccess.insert(std::make_pair(BB, Phi));
+  // Phi's always are placed at the front of the block.
+  Accesses->push_front(Phi);
+  return Phi;
+}
+
+MemoryUseOrDef *MemorySSA::createDefinedAccess(Instruction *I,
+                                               MemoryAccess *Definition) {
+  assert(!isa<PHINode>(I) && "Cannot create a defined access for a PHI");
+  MemoryUseOrDef *NewAccess = createNewAccess(I);
+  assert(
+      NewAccess != nullptr &&
+      "Tried to create a memory access for a non-memory touching instruction");
+  NewAccess->setDefiningAccess(Definition);
+  return NewAccess;
+}
+
+MemoryAccess *MemorySSA::createMemoryAccessInBB(Instruction *I,
+                                                MemoryAccess *Definition,
+                                                const BasicBlock *BB,
+                                                InsertionPlace Point) {
+  MemoryUseOrDef *NewAccess = createDefinedAccess(I, Definition);
+  auto *Accesses = getOrCreateAccessList(BB);
+  if (Point == Beginning) {
+    // It goes after any phi nodes
+    auto AI = std::find_if(
+        Accesses->begin(), Accesses->end(),
+        [](const MemoryAccess &MA) { return !isa<MemoryPhi>(MA); });
+
+    Accesses->insert(AI, NewAccess);
+  } else {
+    Accesses->push_back(NewAccess);
+  }
+
+  return NewAccess;
+}
+
+MemoryAccess *MemorySSA::createMemoryAccessBefore(Instruction *I,
+                                                  MemoryAccess *Definition,
+                                                  MemoryAccess *InsertPt) {
+  assert(I->getParent() == InsertPt->getBlock() &&
+         "New and old access must be in the same block");
+  MemoryUseOrDef *NewAccess = createDefinedAccess(I, Definition);
+  auto *Accesses = getOrCreateAccessList(InsertPt->getBlock());
+  Accesses->insert(AccessList::iterator(InsertPt), NewAccess);
+  return NewAccess;
+}
+
+MemoryAccess *MemorySSA::createMemoryAccessAfter(Instruction *I,
+                                                 MemoryAccess *Definition,
+                                                 MemoryAccess *InsertPt) {
+  assert(I->getParent() == InsertPt->getBlock() &&
+         "New and old access must be in the same block");
+  MemoryUseOrDef *NewAccess = createDefinedAccess(I, Definition);
+  auto *Accesses = getOrCreateAccessList(InsertPt->getBlock());
+  Accesses->insertAfter(AccessList::iterator(InsertPt), NewAccess);
+  return NewAccess;
+}
+
 /// \brief Helper function to create new memory accesses
 MemoryUseOrDef *MemorySSA::createNewAccess(Instruction *I) {
   // The assume intrinsic has a control dependency which we model by claiming
@@ -518,6 +581,47 @@
 void MemorySSA::verifyMemorySSA() const {
   verifyDefUses(F);
   verifyDomination(F);
+  verifyOrdering(F);
+}
+
+/// \brief Verify that the order and existence of MemoryAccesses matches the
+/// order and existence of memory affecting instructions.
+void MemorySSA::verifyOrdering(Function &F) const {
+  // Walk all the blocks, comparing what the lookups think and what the access
+  // lists think, as well as the order in the blocks vs the order in the access
+  // lists.
+  SmallVector<MemoryAccess *, 32> ActualAccesses;
+  for (BasicBlock &B : F) {
+    // Start every block from an empty list so that a block skipped by the
+    // `continue` below can never leak stale entries into the next block.
+    ActualAccesses.clear();
+    const AccessList *AL = getBlockAccesses(&B);
+    MemoryAccess *Phi = getMemoryAccess(&B);
+    if (Phi)
+      ActualAccesses.push_back(Phi);
+    for (Instruction &I : B) {
+      MemoryAccess *MA = getMemoryAccess(&I);
+      assert((!MA || AL) && "We have memory affecting instructions "
+                            "in this block but they are not in the "
+                            "access list");
+      if (MA)
+        ActualAccesses.push_back(MA);
+    }
+    // Either we hit the assert, really have no accesses, or we have both
+    // accesses and an access list
+    if (!AL)
+      continue;
+    assert(AL->size() == ActualAccesses.size() &&
+           "We don't have the same number of accesses in the block as on the "
+           "access list");
+    auto ALI = AL->begin();
+    auto AAI = ActualAccesses.begin();
+    while (ALI != AL->end() && AAI != ActualAccesses.end()) {
+      assert(&*ALI == *AAI && "Not the same accesses in the same order");
+      ++ALI;
+      ++AAI;
+    }
+  }
 }
 
 /// \brief Verify the domination properties of MemorySSA by checking that each
@@ -595,9 +699,13 @@
 void MemorySSA::verifyDefUses(Function &F) const {
   for (BasicBlock &B : F) {
     // Phi nodes are attached to basic blocks
-    if (MemoryPhi *Phi = getMemoryAccess(&B))
+    if (MemoryPhi *Phi = getMemoryAccess(&B)) {
+      assert(Phi->getNumOperands() == static_cast<unsigned>(std::distance(
+                                          pred_begin(&B), pred_end(&B))) &&
+             "Incomplete MemoryPhi Node");
       for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I)
         verifyUseInDefs(Phi->getIncomingValue(I), Phi);
+    }
 
     for (Instruction &I : B) {
       if (MemoryAccess *MA = getMemoryAccess(&I)) {
Index: unittests/Transforms/Utils/MemorySSA.cpp
===================================================================
--- unittests/Transforms/Utils/MemorySSA.cpp
+++ unittests/Transforms/Utils/MemorySSA.cpp
@@ -6,11 +6,11 @@
 // License. See LICENSE.TXT for details.
 //
 //===----------------------------------------------------------------------===//
-#include "llvm/IR/DataLayout.h"
 #include "llvm/Transforms/Utils/MemorySSA.h"
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
 #include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
@@ -65,6 +65,90 @@
       : M("MemorySSATest", C), B(C), DL(DLString), TLI(TLII), F(nullptr) {}
 };
 
+TEST_F(MemorySSATest, CreateALoadAndPhi) {
+  // We create a diamond where there is a store on one side, and then after
+  // running MemorySSA, create a load after the merge point, and use it to test
+  // updating by creating an access for the load and a MemoryPhi.
+  F = Function::Create(
+      FunctionType::get(B.getVoidTy(), {B.getInt8PtrTy()}, false),
+      GlobalValue::ExternalLinkage, "F", &M);
+  BasicBlock *Entry(BasicBlock::Create(C, "", F));
+  BasicBlock *Left(BasicBlock::Create(C, "", F));
+  BasicBlock *Right(BasicBlock::Create(C, "", F));
+  BasicBlock *Merge(BasicBlock::Create(C, "", F));
+  B.SetInsertPoint(Entry);
+  B.CreateCondBr(B.getTrue(), Left, Right);
+  B.SetInsertPoint(Left);
+  Argument *PointerArg = &*F->arg_begin();
+  StoreInst *StoreInst = B.CreateStore(B.getInt8(16), PointerArg);
+  BranchInst::Create(Merge, Left);
+  BranchInst::Create(Merge, Right);
+
+  setupAnalyses();
+  MemorySSA &MSSA = Analyses->MSSA;
+  // Add the load
+  B.SetInsertPoint(Merge);
+  LoadInst *LoadInst = B.CreateLoad(PointerArg);
+  // Should be no phi to start
+  EXPECT_EQ(MSSA.getMemoryAccess(Merge), nullptr);
+
+  // Create the phi
+  MemoryPhi *MP = MSSA.createMemoryPhi(Merge);
+  MemoryDef *StoreAccess = cast<MemoryDef>(MSSA.getMemoryAccess(StoreInst));
+  MP->addIncoming(StoreAccess, Left);
+  MP->addIncoming(MSSA.getLiveOnEntryDef(), Right);
+
+  // Create the load memory access
+  MemoryUse *LoadAccess = cast<MemoryUse>(
+      MSSA.createMemoryAccessInBB(LoadInst, MP, Merge, MemorySSA::Beginning));
+  MemoryAccess *DefiningAccess = LoadAccess->getDefiningAccess();
+  EXPECT_EQ(DefiningAccess, MP);
+  MSSA.verifyMemorySSA();
+}
+
+TEST_F(MemorySSATest, RemoveAPhi) {
+  // We create a diamond where there is a store on one side, and then a load
+  // after the merge point. This enables us to test a bunch of different
+  // removal cases.
+  F = Function::Create(
+      FunctionType::get(B.getVoidTy(), {B.getInt8PtrTy()}, false),
+      GlobalValue::ExternalLinkage, "F", &M);
+  BasicBlock *Entry(BasicBlock::Create(C, "", F));
+  BasicBlock *Left(BasicBlock::Create(C, "", F));
+  BasicBlock *Right(BasicBlock::Create(C, "", F));
+  BasicBlock *Merge(BasicBlock::Create(C, "", F));
+  B.SetInsertPoint(Entry);
+  B.CreateCondBr(B.getTrue(), Left, Right);
+  B.SetInsertPoint(Left);
+  Argument *PointerArg = &*F->arg_begin();
+  StoreInst *StoreInst = B.CreateStore(B.getInt8(16), PointerArg);
+  BranchInst::Create(Merge, Left);
+  BranchInst::Create(Merge, Right);
+  B.SetInsertPoint(Merge);
+  LoadInst *LoadInst = B.CreateLoad(PointerArg);
+
+  setupAnalyses();
+  MemorySSA &MSSA = Analyses->MSSA;
+  // Before, the load will be a use of a phi.
+  MemoryUse *LoadAccess = cast<MemoryUse>(MSSA.getMemoryAccess(LoadInst));
+  MemoryDef *StoreAccess = cast<MemoryDef>(MSSA.getMemoryAccess(StoreInst));
+  MemoryAccess *DefiningAccess = LoadAccess->getDefiningAccess();
+  EXPECT_TRUE(isa<MemoryPhi>(DefiningAccess));
+  // Kill the store
+  MSSA.removeMemoryAccess(StoreAccess);
+  MemoryPhi *MP = cast<MemoryPhi>(DefiningAccess);
+  // Verify the phi ended up as liveonentry, liveonentry
+  for (auto &Op : MP->incoming_values())
+    EXPECT_TRUE(MSSA.isLiveOnEntryDef(cast<MemoryAccess>(Op.get())));
+  // Replace the phi uses with the live on entry def
+  MP->replaceAllUsesWith(MSSA.getLiveOnEntryDef());
+  // Verify the load is now defined by liveOnEntryDef
+  EXPECT_TRUE(MSSA.isLiveOnEntryDef(LoadAccess->getDefiningAccess()));
+  // Remove the PHI
+  MSSA.removeMemoryAccess(MP);
+  MSSA.verifyMemorySSA();
+}
+
 TEST_F(MemorySSATest, RemoveMemoryAccess) {
   // We create a diamond where there is a store on one side, and then a load
   // after the merge point. This enables us to test a bunch of different
@@ -136,9 +220,8 @@
 //   store i8 2, i8* %A
 // }
 TEST_F(MemorySSATest, TestTripleStore) {
-  F = Function::Create(
-      FunctionType::get(B.getVoidTy(), {}, false),
-      GlobalValue::ExternalLinkage, "F", &M);
+  F = Function::Create(FunctionType::get(B.getVoidTy(), {}, false),
+                       GlobalValue::ExternalLinkage, "F", &M);
   B.SetInsertPoint(BasicBlock::Create(C, "", F));
   Type *Int8 = Type::getInt8Ty(C);
   Value *Alloca = B.CreateAlloca(Int8, ConstantInt::get(Int8, 1), "A");
@@ -169,9 +252,8 @@
 // mostly redundant) unless the initial node being walked is a clobber for the
 // query. In that case, we'd cache that the node clobbered itself.
 TEST_F(MemorySSATest, TestStoreAndLoad) {
-  F = Function::Create(
-      FunctionType::get(B.getVoidTy(), {}, false),
-      GlobalValue::ExternalLinkage, "F", &M);
+  F = Function::Create(FunctionType::get(B.getVoidTy(), {}, false),
+                       GlobalValue::ExternalLinkage, "F", &M);
   B.SetInsertPoint(BasicBlock::Create(C, "", F));
   Type *Int8 = Type::getInt8Ty(C);
   Value *Alloca = B.CreateAlloca(Int8, ConstantInt::get(Int8, 1), "A");
@@ -200,9 +282,8 @@
 // This test checks that repeated calls to either function returns what they're
 // meant to.
 TEST_F(MemorySSATest, TestStoreDoubleQuery) {
-  F = Function::Create(
-      FunctionType::get(B.getVoidTy(), {}, false),
-      GlobalValue::ExternalLinkage, "F", &M);
+  F = Function::Create(FunctionType::get(B.getVoidTy(), {}, false),
+                       GlobalValue::ExternalLinkage, "F", &M);
   B.SetInsertPoint(BasicBlock::Create(C, "", F));
   Type *Int8 = Type::getInt8Ty(C);
   Value *Alloca = B.CreateAlloca(Int8, ConstantInt::get(Int8, 1), "A");
@@ -214,7 +295,8 @@
 
   MemoryAccess *StoreAccess = MSSA.getMemoryAccess(SI);
   MemoryLocation StoreLoc = MemoryLocation::get(SI);
-  MemoryAccess *Clobber = Walker->getClobberingMemoryAccess(StoreAccess, StoreLoc);
+  MemoryAccess *Clobber =
+      Walker->getClobberingMemoryAccess(StoreAccess, StoreLoc);
   MemoryAccess *LiveOnEntry = Walker->getClobberingMemoryAccess(SI);
 
   EXPECT_EQ(Clobber, StoreAccess);