Index: include/llvm/Transforms/Utils/MemorySSA.h
===================================================================
--- include/llvm/Transforms/Utils/MemorySSA.h
+++ include/llvm/Transforms/Utils/MemorySSA.h
@@ -689,6 +689,7 @@
   // for moves. It does not always leave the IR in a correct state, and relies
   // on the updater to fixup what it breaks, so it is not public.
   void moveTo(MemoryUseOrDef *What, BasicBlock *BB, AccessList::iterator Where);
+  void moveTo(MemoryUseOrDef *What, BasicBlock *BB, InsertionPlace Point);
 
 private:
   class CachingWalker;
Index: include/llvm/Transforms/Utils/MemorySSAUpdater.h
===================================================================
--- include/llvm/Transforms/Utils/MemorySSAUpdater.h
+++ include/llvm/Transforms/Utils/MemorySSAUpdater.h
@@ -24,10 +24,9 @@
 // That's it.
 //
 // For moving, first, move the instruction itself using the normal SSA
-// instruction moving API, then just call moveBefore or moveAfter with the right
-// arguments.
+// instruction moving API, then just call moveBefore, moveAfter, or moveToPlace
+// with the right arguments.
 //
-// walk memory instructions using a use/def graph.
 //===----------------------------------------------------------------------===//
 
 #ifndef LLVM_TRANSFORMS_UTILS_MEMORYSSAUPDATER_H
@@ -68,10 +67,16 @@
   void insertUse(MemoryUse *Use);
   void moveBefore(MemoryUseOrDef *What, MemoryUseOrDef *Where);
   void moveAfter(MemoryUseOrDef *What, MemoryUseOrDef *Where);
+  // Move What to the insertion point Where (e.g. MemorySSA::Beginning) of BB.
+  void moveToPlace(MemoryUseOrDef *What, BasicBlock *BB,
+                   MemorySSA::InsertionPlace Where);
 
 private:
+  // Move What to Where in BB in the MemorySSA IR. Where is either an
+  // AccessList iterator or a MemorySSA::InsertionPlace.
+  template <class WhereType>
   void moveTo(MemoryUseOrDef *What, BasicBlock *BB,
-              MemorySSA::AccessList::iterator Where);
+              WhereType Where);
   MemoryAccess *getPreviousDef(MemoryAccess *);
   MemoryAccess *getPreviousDefInBlock(MemoryAccess *);
   MemoryAccess *getPreviousDefFromEnd(BasicBlock *);
Index: lib/Transforms/Utils/MemorySSAUpdater.cpp
===================================================================
--- lib/Transforms/Utils/MemorySSAUpdater.cpp
+++ lib/Transforms/Utils/MemorySSAUpdater.cpp
@@ -349,8 +349,10 @@
 }
 
-// Move What before Where in the MemorySSA IR.
+// Move What to Where in BB in the MemorySSA IR. Where is either an
+// AccessList iterator or a MemorySSA::InsertionPlace.
+template <class WhereType>
 void MemorySSAUpdater::moveTo(MemoryUseOrDef *What, BasicBlock *BB,
-                              MemorySSA::AccessList::iterator Where) {
+                              WhereType Where) {
   // Replace all our users with our defining access.
   What->replaceAllUsesWith(What->getDefiningAccess());
 
@@ -363,6 +365,7 @@
   else
     insertUse(cast<MemoryUse>(What));
 }
+
 // Move What before Where in the MemorySSA IR.
 void MemorySSAUpdater::moveBefore(MemoryUseOrDef *What, MemoryUseOrDef *Where) {
   moveTo(What, Where->getBlock(), Where->getIterator());
@@ -373,4 +376,10 @@
   moveTo(What, Where->getBlock(), ++Where->getIterator());
 }
 
+// Move What to the insertion point Where of BB in the MemorySSA IR.
+void MemorySSAUpdater::moveToPlace(MemoryUseOrDef *What, BasicBlock *BB,
+                                   MemorySSA::InsertionPlace Where) {
+  moveTo(What, BB, Where);
+}
+
 } // namespace llvm
Index: unittests/Transforms/Utils/MemorySSA.cpp
===================================================================
--- unittests/Transforms/Utils/MemorySSA.cpp
+++ unittests/Transforms/Utils/MemorySSA.cpp
@@ -365,6 +365,63 @@
   MSSA.verifyMemorySSA();
 }
 
+TEST_F(MemorySSATest, MoveAStoreAllAround) {
+  // We create a diamond where there is a store in the entry, a store on one
+  // side, and a load at the end. After building MemorySSA, we test updating by
+  // moving the side store to the entry block, then to the other side block,
+  // then to before the load. This does not destroy the old access.
+  F = Function::Create(
+      FunctionType::get(B.getVoidTy(), {B.getInt8PtrTy()}, false),
+      GlobalValue::ExternalLinkage, "F", &M);
+  BasicBlock *Entry(BasicBlock::Create(C, "", F));
+  BasicBlock *Left(BasicBlock::Create(C, "", F));
+  BasicBlock *Right(BasicBlock::Create(C, "", F));
+  BasicBlock *Merge(BasicBlock::Create(C, "", F));
+  B.SetInsertPoint(Entry);
+  Argument *PointerArg = &*F->arg_begin();
+  StoreInst *EntryStore = B.CreateStore(B.getInt8(16), PointerArg);
+  B.CreateCondBr(B.getTrue(), Left, Right);
+  B.SetInsertPoint(Left);
+  auto *SideStore = B.CreateStore(B.getInt8(16), PointerArg);
+  BranchInst::Create(Merge, Left);
+  BranchInst::Create(Merge, Right);
+  B.SetInsertPoint(Merge);
+  auto *MergeLoad = B.CreateLoad(PointerArg);
+  setupAnalyses();
+  MemorySSA &MSSA = *Analyses->MSSA;
+  MemorySSAUpdater Updater(&MSSA);
+
+  // Look up the accesses for both stores.
+  auto *EntryStoreAccess = MSSA.getMemoryAccess(EntryStore);
+  auto *SideStoreAccess = MSSA.getMemoryAccess(SideStore);
+  // Before, the load will point to a phi of the EntryStore and SideStore.
+  // Incoming value 0 is the Left edge and incoming value 1 the Right edge.
+  auto *LoadAccess = cast<MemoryUse>(MSSA.getMemoryAccess(MergeLoad));
+  ASSERT_TRUE(isa<MemoryPhi>(LoadAccess->getDefiningAccess()));
+  MemoryPhi *MergePhi = cast<MemoryPhi>(LoadAccess->getDefiningAccess());
+  EXPECT_EQ(MergePhi->getIncomingValue(1), EntryStoreAccess);
+  EXPECT_EQ(MergePhi->getIncomingValue(0), SideStoreAccess);
+  // Move the side store before the entry store.
+  SideStore->moveBefore(*EntryStore->getParent(), EntryStore->getIterator());
+  Updater.moveBefore(SideStoreAccess, EntryStoreAccess);
+  // After, both incoming values of the phi are the entry store.
+  EXPECT_EQ(MergePhi->getIncomingValue(0), EntryStoreAccess);
+  EXPECT_EQ(MergePhi->getIncomingValue(1), EntryStoreAccess);
+  MSSA.verifyMemorySSA();
+  // Now move the store to the right branch
+  SideStore->moveBefore(*Right, Right->begin());
+  Updater.moveToPlace(SideStoreAccess, Right, MemorySSA::Beginning);
+  MSSA.verifyMemorySSA();
+  EXPECT_EQ(MergePhi->getIncomingValue(0), EntryStoreAccess);
+  EXPECT_EQ(MergePhi->getIncomingValue(1), SideStoreAccess);
+  // Now move it before the load
+  SideStore->moveBefore(MergeLoad);
+  Updater.moveBefore(SideStoreAccess, LoadAccess);
+  EXPECT_EQ(MergePhi->getIncomingValue(0), EntryStoreAccess);
+  EXPECT_EQ(MergePhi->getIncomingValue(1), EntryStoreAccess);
+  MSSA.verifyMemorySSA();
+}
+
 TEST_F(MemorySSATest, RemoveAPhi) {
   // We create a diamond where there is a store on one side, and then a load
   // after the merge point. This enables us to test a bunch of different