diff --git a/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h b/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h --- a/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h +++ b/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h @@ -30,6 +30,8 @@ class MemCpyInst; class MemMoveInst; class MemoryDependenceResults; +class MemorySSA; +class MemorySSAUpdater; class MemSetInst; class StoreInst; class TargetLibraryInfo; @@ -38,6 +40,7 @@ class MemCpyOptPass : public PassInfoMixin { MemoryDependenceResults *MD = nullptr; TargetLibraryInfo *TLI = nullptr; + MemorySSAUpdater *MSSAU = nullptr; std::function LookupAliasAnalysis; std::function LookupAssumptionCache; std::function LookupDomTree; @@ -49,7 +52,7 @@ // Glue for the old PM. bool runImpl(Function &F, MemoryDependenceResults *MD_, - TargetLibraryInfo *TLI_, + TargetLibraryInfo *TLI_, MemorySSA *MSSA_, std::function LookupAliasAnalysis_, std::function LookupAssumptionCache_, std::function LookupDomTree_); diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp --- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -23,6 +23,8 @@ #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/MemoryDependenceAnalysis.h" #include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Argument.h" @@ -276,6 +278,7 @@ AU.addRequired(); AU.addPreserved(); AU.addPreserved(); + AU.addPreserved(); } }; @@ -312,6 +315,7 @@ // are stored. MemsetRanges Ranges(DL); + Instruction *LastWrite = nullptr; BasicBlock::iterator BI(StartInst); for (++BI; !BI->isTerminator(); ++BI) { if (!isa(BI) && !isa(BI)) { @@ -341,6 +345,7 @@ break; Ranges.addStore(*Offset, NextStore); + LastWrite = NextStore; } else { MemSetInst *MSI = cast(BI); @@ -354,6 +359,7 @@ break; Ranges.addMemSet(*Offset, MSI); + LastWrite = MSI; } } @@ -375,6 +381,11 @@ // Now that we have full information about ranges, loop over the ranges and // emit memset's for anything big enough to be worthwhile. Instruction *AMemSet = nullptr; + + MemoryAccess *LastAcc = nullptr; + if (MSSAU) { + LastAcc = MSSAU->getMemorySSA()->getMemoryAccess(LastWrite); + } for (const MemsetRange &Range : Ranges) { if (Range.TheStores.size() == 1) continue; @@ -392,15 +403,27 @@ : Range.TheStores) dbgs() << *SI << '\n'; dbgs() << "With: " << *AMemSet << '\n'); - if (!Range.TheStores.empty()) AMemSet->setDebugLoc(Range.TheStores[0]->getDebugLoc()); + if (MSSAU) { + auto *NewAccess = + MSSAU->createMemoryAccessAfter(AMemSet, LastAcc, LastAcc); + MSSAU->insertDef(cast(NewAccess), /*RenameUses=*/true); + } + // Zap all the stores. for (Instruction *SI : Range.TheStores) { + if (MSSAU) { + if (isa(LastAcc) && + SI == cast(LastAcc)->getMemoryInst()) + LastAcc = cast(LastAcc)->getDefiningAccess(); + MSSAU->removeMemoryAccess(SI); + } MD->removeInstruction(SI); SI->eraseFromParent(); } + ++NumMemSetInfer; } @@ -572,6 +595,17 @@ LLVM_DEBUG(dbgs() << "Promoting " << *LI << " to " << *SI << " => " << *M << "\n"); + if (MSSAU) { + assert(isa(MSSAU->getMemorySSA()->getMemoryAccess(P))); + auto *LastDef = + cast(MSSAU->getMemorySSA()->getMemoryAccess(P)); + auto *NewAccess = + MSSAU->createMemoryAccessAfter(M, LastDef, LastDef); + MSSAU->insertDef(cast(NewAccess), /*RenameUses=*/true); + MSSAU->removeMemoryAccess(SI); + MSSAU->removeMemoryAccess(LI); + } + MD->removeInstruction(SI); SI->eraseFromParent(); MD->removeInstruction(LI); @@ -621,6 +655,11 @@ DL.getTypeStoreSize(SI->getOperand(0)->getType()), commonAlignment(SI->getAlign(), LI->getAlign()), C); if (changed) { + if (MSSAU) { + MSSAU->removeMemoryAccess(SI); + MSSAU->removeMemoryAccess(LI); + } + MD->removeInstruction(SI); SI->eraseFromParent(); MD->removeInstruction(LI); @@ -658,6 +697,15 @@ LLVM_DEBUG(dbgs() << "Promoting " << *SI << " to " << *M << "\n"); + if (MSSAU) { + assert(isa(MSSAU->getMemorySSA()->getMemoryAccess(SI))); + auto *LastDef = + cast(MSSAU->getMemorySSA()->getMemoryAccess(SI)); + auto *NewAccess = MSSAU->createMemoryAccessAfter(M, LastDef, LastDef); + MSSAU->insertDef(cast(NewAccess), /*RenameUses=*/true); + MSSAU->removeMemoryAccess(SI); + } + MD->removeInstruction(SI); SI->eraseFromParent(); NumMemSetInfer++; @@ -879,6 +927,22 @@ // Remove the memcpy. MD->removeInstruction(cpy); + + if (MSSAU) { + MSSAU->removeMemoryAccess(C); + auto *Def = MSSAU->getMemorySSA()->getMemoryAccess(C)->getDefiningAccess(); + if (Def->getBlock() == C->getParent()) { + auto *NewAccess = MSSAU->createMemoryAccessAfter(C, Def, Def); + MSSAU->insertDef(cast(NewAccess), /*RenameUses=*/true); + } else { + auto *NewAccess = MSSAU->createMemoryAccessInBB( + C, nullptr, C->getParent(), MemorySSA::Beginning); + MSSAU->insertDef(cast(NewAccess), /*RenameUses=*/true); + } + + MSSAU->removeMemoryAccess(cpy); + } + ++NumMemCpyInstr; return true; @@ -943,14 +1007,23 @@ // TODO: Is this worth it if we're creating a less aligned memcpy? For // example we could be moving from movaps -> movq on x86. IRBuilder<> Builder(M); + Instruction *NewM; if (UseMemMove) - Builder.CreateMemMove(M->getRawDest(), M->getDestAlign(), - MDep->getRawSource(), MDep->getSourceAlign(), - M->getLength(), M->isVolatile()); + NewM = Builder.CreateMemMove(M->getRawDest(), M->getDestAlign(), + MDep->getRawSource(), MDep->getSourceAlign(), + M->getLength(), M->isVolatile()); else - Builder.CreateMemCpy(M->getRawDest(), M->getDestAlign(), - MDep->getRawSource(), MDep->getSourceAlign(), - M->getLength(), M->isVolatile()); + NewM = Builder.CreateMemCpy(M->getRawDest(), M->getDestAlign(), + MDep->getRawSource(), MDep->getSourceAlign(), + M->getLength(), M->isVolatile()); + + if (MSSAU) { + assert(isa(MSSAU->getMemorySSA()->getMemoryAccess(M))); + auto *LastDef = cast(MSSAU->getMemorySSA()->getMemoryAccess(M)); + auto *NewAccess = MSSAU->createMemoryAccessAfter(NewM, LastDef, LastDef); + MSSAU->insertDef(cast(NewAccess), /*RenameUses=*/true); + MSSAU->removeMemoryAccess(M); + } // Remove the instruction we're replacing. MD->removeInstruction(M); @@ -1016,11 +1089,20 @@ Value *SizeDiff = Builder.CreateSub(DestSize, SrcSize); Value *MemsetLen = Builder.CreateSelect( Ule, ConstantInt::getNullValue(DestSize->getType()), SizeDiff); - Builder.CreateMemSet( + Instruction *NewM = Builder.CreateMemSet( Builder.CreateGEP(Dest->getType()->getPointerElementType(), Dest, SrcSize), MemSet->getOperand(1), MemsetLen, MaybeAlign(Align)); + if (MSSAU) { + assert(isa(MSSAU->getMemorySSA()->getMemoryAccess(MemCpy))); + auto *LastDef = + cast(MSSAU->getMemorySSA()->getMemoryAccess(MemCpy)); + auto *NewAccess = MSSAU->createMemoryAccessAfter(NewM, LastDef, LastDef); + MSSAU->insertDef(cast(NewAccess), /*RenameUses=*/true); + MSSAU->removeMemoryAccess(MemSet); + } + MD->removeInstruction(MemSet); MemSet->eraseFromParent(); return true; @@ -1087,8 +1169,16 @@ } IRBuilder<> Builder(MemCpy); - Builder.CreateMemSet(MemCpy->getRawDest(), MemSet->getOperand(1), CopySize, - MaybeAlign(MemCpy->getDestAlignment())); + Instruction *NewM = + Builder.CreateMemSet(MemCpy->getRawDest(), MemSet->getOperand(1), + CopySize, MaybeAlign(MemCpy->getDestAlignment())); + if (MSSAU) { + auto *LastDef = + cast(MSSAU->getMemorySSA()->getMemoryAccess(MemCpy)); + auto *NewAccess = MSSAU->createMemoryAccessAfter(NewM, LastDef, LastDef); + MSSAU->insertDef(cast(NewAccess), /*RenameUses=*/true); + } + return true; } @@ -1104,6 +1194,9 @@ // If the source and destination of the memcpy are the same, then zap it. if (M->getSource() == M->getDest()) { ++BBI; + if (MSSAU) + MSSAU->removeMemoryAccess(M); + MD->removeInstruction(M); M->eraseFromParent(); return true; @@ -1115,8 +1208,18 @@ if (Value *ByteVal = isBytewiseValue(GV->getInitializer(), M->getModule()->getDataLayout())) { IRBuilder<> Builder(M); - Builder.CreateMemSet(M->getRawDest(), ByteVal, M->getLength(), - MaybeAlign(M->getDestAlignment()), false); + Instruction *NewM = + Builder.CreateMemSet(M->getRawDest(), ByteVal, M->getLength(), + MaybeAlign(M->getDestAlignment()), false); + if (MSSAU) { + auto *LastDef = + cast(MSSAU->getMemorySSA()->getMemoryAccess(M)); + auto *NewAccess = + MSSAU->createMemoryAccessAfter(NewM, LastDef, LastDef); + MSSAU->insertDef(cast(NewAccess), /*RenameUses=*/true); + MSSAU->removeMemoryAccess(M); + } + MD->removeInstruction(M); M->eraseFromParent(); ++NumCpyToSet; @@ -1151,6 +1254,9 @@ M->getSourceAlign().valueOrOne()); if (performCallSlotOptzn(M, M->getDest(), M->getSource(), CopySize->getZExtValue(), Alignment, C)) { + if (MSSAU) + MSSAU->removeMemoryAccess(M); + MD->removeInstruction(M); M->eraseFromParent(); return true; @@ -1167,6 +1273,9 @@ return processMemCpyMemCpyDependence(M, MDep); } else if (SrcDepInfo.isDef()) { if (hasUndefContents(SrcDepInfo.getInst(), CopySize)) { + if (MSSAU) + MSSAU->removeMemoryAccess(M); + MD->removeInstruction(M); M->eraseFromParent(); ++NumMemCpyInstr; @@ -1177,6 +1286,8 @@ if (SrcDepInfo.isClobber()) if (MemSetInst *MDep = dyn_cast(SrcDepInfo.getInst())) if (performMemCpyToMemSetOptzn(M, MDep)) { + if (MSSAU) + MSSAU->removeMemoryAccess(M); MD->removeInstruction(M); M->eraseFromParent(); ++NumCpyToSet; @@ -1209,6 +1320,20 @@ M->setCalledFunction(Intrinsic::getDeclaration(M->getModule(), Intrinsic::memcpy, ArgTys)); + if (MSSAU) { + auto *Def = MSSAU->getMemorySSA()->getMemoryAccess(M)->getDefiningAccess(); + if (Def->getBlock() == M->getParent()) { + auto *NewAccess = MSSAU->createMemoryAccessAfter(M, Def, Def); + MSSAU->insertDef(cast(NewAccess), /*RenameUses=*/true); + } else { + auto *NewAccess = MSSAU->createMemoryAccessInBB( + M, nullptr, M->getParent(), MemorySSA::Beginning); + MSSAU->insertDef(cast(NewAccess), /*RenameUses=*/true); + } + + MSSAU->removeMemoryAccess(M); + } + // MemDep may have over conservative information about this instruction, just // conservatively flush it from the cache. MD->removeInstruction(M); @@ -1358,8 +1483,11 @@ return AM.getResult(F); }; - bool MadeChange = runImpl(F, &MD, &TLI, LookupAliasAnalysis, - LookupAssumptionCache, LookupDomTree); + auto *MSSA = AM.getCachedResult(F); + + bool MadeChange = + runImpl(F, &MD, &TLI, MSSA ? &MSSA->getMSSA() : nullptr, + LookupAliasAnalysis, LookupAssumptionCache, LookupDomTree); if (!MadeChange) return PreservedAnalyses::all(); @@ -1372,12 +1500,14 @@ bool MemCpyOptPass::runImpl( Function &F, MemoryDependenceResults *MD_, TargetLibraryInfo *TLI_, - std::function LookupAliasAnalysis_, + MemorySSA *MSSA_, std::function LookupAliasAnalysis_, std::function LookupAssumptionCache_, std::function LookupDomTree_) { + MemorySSAUpdater MSSAU_(MSSA_); bool MadeChange = false; MD = MD_; TLI = TLI_; + MSSAU = MSSA_ ? &MSSAU_ : nullptr; LookupAliasAnalysis = std::move(LookupAliasAnalysis_); LookupAssumptionCache = std::move(LookupAssumptionCache_); LookupDomTree = std::move(LookupDomTree_); @@ -1416,6 +1546,9 @@ return getAnalysis().getDomTree(); }; - return Impl.runImpl(F, MD, TLI, LookupAliasAnalysis, LookupAssumptionCache, + auto *MSSAWP = getAnalysisIfAvailable(); + + return Impl.runImpl(F, MD, TLI, MSSAWP ? &MSSAWP->getMSSA() : nullptr, + LookupAliasAnalysis, LookupAssumptionCache, LookupDomTree); }