diff --git a/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h b/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h --- a/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h +++ b/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h @@ -30,6 +30,8 @@ class MemCpyInst; class MemMoveInst; class MemoryDependenceResults; +class MemorySSA; +class MemorySSAUpdater; class MemSetInst; class StoreInst; class TargetLibraryInfo; @@ -41,6 +43,7 @@ AliasAnalysis *AA = nullptr; AssumptionCache *AC = nullptr; DominatorTree *DT = nullptr; + MemorySSAUpdater *MSSAU = nullptr; public: MemCpyOptPass() = default; @@ -50,7 +53,7 @@ // Glue for the old PM. bool runImpl(Function &F, MemoryDependenceResults *MD_, TargetLibraryInfo *TLI_, AliasAnalysis *AA_, - AssumptionCache *AC_, DominatorTree *DT_); + AssumptionCache *AC_, DominatorTree *DT_, MemorySSA *MSSA_); private: // Helper functions diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp --- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -23,6 +23,8 @@ #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/MemoryDependenceAnalysis.h" #include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Argument.h" @@ -278,6 +280,7 @@ AU.addPreserved<MemoryDependenceWrapperPass>(); AU.addRequired<AAResultsWrapperPass>(); AU.addPreserved<AAResultsWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); } }; @@ -315,7 +318,27 @@ MemsetRanges Ranges(DL); BasicBlock::iterator BI(StartInst); + + // Keeps track of the last memory use or def before the insertion point for + // the new memset. The new MemoryDef for the inserted memsets will be inserted + // after MemInsertPoint. It points to either LastMemDef or to the last user + // before the insertion point of the memset, if there are any such users. 
+ MemoryUseOrDef *MemInsertPoint = nullptr; + // Keeps track of the last MemoryDef between StartInst and the insertion point + // for the new memset. This will become the defining access of the inserted + // memsets. + MemoryDef *LastMemDef = nullptr; for (++BI; !BI->isTerminator(); ++BI) { + if (MSSAU) { + auto *CurrentAcc = cast_or_null<MemoryUseOrDef>( + MSSAU->getMemorySSA()->getMemoryAccess(&*BI)); + if (CurrentAcc) { + MemInsertPoint = CurrentAcc; + if (auto *CurrentDef = dyn_cast<MemoryDef>(CurrentAcc)) + LastMemDef = CurrentDef; + } + } + if (!isa<StoreInst>(BI) && !isa<MemSetInst>(BI)) { // If the instruction is readnone, ignore it, otherwise bail out. We // don't even allow readonly here because we don't want something like: @@ -394,15 +417,27 @@ : Range.TheStores) dbgs() << *SI << '\n'; dbgs() << "With: " << *AMemSet << '\n'); - if (!Range.TheStores.empty()) AMemSet->setDebugLoc(Range.TheStores[0]->getDebugLoc()); + if (MSSAU) { + assert(LastMemDef && MemInsertPoint && + "Both LastMemDef and MemInsertPoint need to be set"); + auto *NewDef = cast<MemoryDef>( + MSSAU->createMemoryAccessAfter(AMemSet, LastMemDef, MemInsertPoint)); + MSSAU->insertDef(NewDef, /*RenameUses=*/true); + LastMemDef = NewDef; + MemInsertPoint = NewDef; + } + // Zap all the stores. 
for (Instruction *SI : Range.TheStores) { + if (MSSAU) + MSSAU->removeMemoryAccess(SI); MD->removeInstruction(SI); SI->eraseFromParent(); } + ++NumMemSetInfer; } @@ -573,6 +608,17 @@ LLVM_DEBUG(dbgs() << "Promoting " << *LI << " to " << *SI << " => " << *M << "\n"); + if (MSSAU) { + assert(isa<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(P))); + auto *LastDef = + cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(P)); + auto *NewAccess = + MSSAU->createMemoryAccessAfter(M, LastDef, LastDef); + MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true); + MSSAU->removeMemoryAccess(SI); + MSSAU->removeMemoryAccess(LI); + } + MD->removeInstruction(SI); SI->eraseFromParent(); MD->removeInstruction(LI); @@ -621,6 +667,11 @@ DL.getTypeStoreSize(SI->getOperand(0)->getType()), commonAlignment(SI->getAlign(), LI->getAlign()), C); if (changed) { + if (MSSAU) { + MSSAU->removeMemoryAccess(SI); + MSSAU->removeMemoryAccess(LI); + } + MD->removeInstruction(SI); SI->eraseFromParent(); MD->removeInstruction(LI); @@ -658,6 +709,15 @@ LLVM_DEBUG(dbgs() << "Promoting " << *SI << " to " << *M << "\n"); + if (MSSAU) { + assert(isa<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(SI))); + auto *LastDef = + cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(SI)); + auto *NewAccess = MSSAU->createMemoryAccessAfter(M, LastDef, LastDef); + MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true); + MSSAU->removeMemoryAccess(SI); + } + MD->removeInstruction(SI); SI->eraseFromParent(); NumMemSetInfer++; @@ -939,14 +999,23 @@ // TODO: Is this worth it if we're creating a less aligned memcpy? For // example we could be moving from movaps -> movq on x86. 
IRBuilder<> Builder(M); + Instruction *NewM; if (UseMemMove) - Builder.CreateMemMove(M->getRawDest(), M->getDestAlign(), - MDep->getRawSource(), MDep->getSourceAlign(), - M->getLength(), M->isVolatile()); + NewM = Builder.CreateMemMove(M->getRawDest(), M->getDestAlign(), + MDep->getRawSource(), MDep->getSourceAlign(), + M->getLength(), M->isVolatile()); else - Builder.CreateMemCpy(M->getRawDest(), M->getDestAlign(), - MDep->getRawSource(), MDep->getSourceAlign(), - M->getLength(), M->isVolatile()); + NewM = Builder.CreateMemCpy(M->getRawDest(), M->getDestAlign(), + MDep->getRawSource(), MDep->getSourceAlign(), + M->getLength(), M->isVolatile()); + + if (MSSAU) { + assert(isa<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(M))); + auto *LastDef = cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(M)); + auto *NewAccess = MSSAU->createMemoryAccessAfter(NewM, LastDef, LastDef); + MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true); + MSSAU->removeMemoryAccess(M); + } // Remove the instruction we're replacing. MD->removeInstruction(M); @@ -1012,11 +1081,25 @@ Value *SizeDiff = Builder.CreateSub(DestSize, SrcSize); Value *MemsetLen = Builder.CreateSelect( Ule, ConstantInt::getNullValue(DestSize->getType()), SizeDiff); - Builder.CreateMemSet( + Instruction *NewMemSet = Builder.CreateMemSet( Builder.CreateGEP(Dest->getType()->getPointerElementType(), Dest, SrcSize), MemSet->getOperand(1), MemsetLen, MaybeAlign(Align)); + if (MSSAU) { + assert(isa<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(MemCpy)) && + "MemCpy must be a MemoryDef"); + // The new memset is inserted after the memcpy, but it is known that its + // defining access is the memset about to be removed which immediately + // precedes the memcpy. 
+ auto *LastDef = + cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(MemCpy)); + auto *NewAccess = MSSAU->createMemoryAccessBefore( + NewMemSet, LastDef->getDefiningAccess(), LastDef); + MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true); + MSSAU->removeMemoryAccess(MemSet); + } + MD->removeInstruction(MemSet); MemSet->eraseFromParent(); return true; @@ -1081,8 +1164,16 @@ } IRBuilder<> Builder(MemCpy); - Builder.CreateMemSet(MemCpy->getRawDest(), MemSet->getOperand(1), CopySize, - MaybeAlign(MemCpy->getDestAlignment())); + Instruction *NewM = + Builder.CreateMemSet(MemCpy->getRawDest(), MemSet->getOperand(1), + CopySize, MaybeAlign(MemCpy->getDestAlignment())); + if (MSSAU) { + auto *LastDef = + cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(MemCpy)); + auto *NewAccess = MSSAU->createMemoryAccessAfter(NewM, LastDef, LastDef); + MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true); + } + return true; } @@ -1098,6 +1189,9 @@ // If the source and destination of the memcpy are the same, then zap it. 
if (M->getSource() == M->getDest()) { ++BBI; + if (MSSAU) + MSSAU->removeMemoryAccess(M); + MD->removeInstruction(M); M->eraseFromParent(); return true; @@ -1109,8 +1203,18 @@ if (Value *ByteVal = isBytewiseValue(GV->getInitializer(), M->getModule()->getDataLayout())) { IRBuilder<> Builder(M); - Builder.CreateMemSet(M->getRawDest(), ByteVal, M->getLength(), - MaybeAlign(M->getDestAlignment()), false); + Instruction *NewM = + Builder.CreateMemSet(M->getRawDest(), ByteVal, M->getLength(), + MaybeAlign(M->getDestAlignment()), false); + if (MSSAU) { + auto *LastDef = + cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(M)); + auto *NewAccess = + MSSAU->createMemoryAccessAfter(NewM, LastDef, LastDef); + MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true); + MSSAU->removeMemoryAccess(M); + } + MD->removeInstruction(M); M->eraseFromParent(); ++NumCpyToSet; @@ -1145,6 +1249,9 @@ M->getSourceAlign().valueOrOne()); if (performCallSlotOptzn(M, M->getDest(), M->getSource(), CopySize->getZExtValue(), Alignment, C)) { + if (MSSAU) + MSSAU->removeMemoryAccess(M); + MD->removeInstruction(M); M->eraseFromParent(); return true; @@ -1161,6 +1268,9 @@ return processMemCpyMemCpyDependence(M, MDep); } else if (SrcDepInfo.isDef()) { if (hasUndefContents(SrcDepInfo.getInst(), CopySize)) { + if (MSSAU) + MSSAU->removeMemoryAccess(M); + MD->removeInstruction(M); M->eraseFromParent(); ++NumMemCpyInstr; @@ -1171,6 +1281,8 @@ if (SrcDepInfo.isClobber()) if (MemSetInst *MDep = dyn_cast<MemSetInst>(SrcDepInfo.getInst())) if (performMemCpyToMemSetOptzn(M, MDep)) { + if (MSSAU) + MSSAU->removeMemoryAccess(M); MD->removeInstruction(M); M->eraseFromParent(); ++NumCpyToSet; @@ -1201,6 +1313,9 @@ M->setCalledFunction(Intrinsic::getDeclaration(M->getModule(), Intrinsic::memcpy, ArgTys)); + // For MemorySSA nothing really changes (except that memcpy may imply stricter + // aliasing guarantees). + // MemDep may have over conservative information about this instruction, just // conservatively flush it from the cache. 
MD->removeInstruction(M); @@ -1338,8 +1453,10 @@ auto *AA = &AM.getResult<AAManager>(F); auto *AC = &AM.getResult<AssumptionAnalysis>(F); auto *DT = &AM.getResult<DominatorTreeAnalysis>(F); + auto *MSSA = AM.getCachedResult<MemorySSAAnalysis>(F); - bool MadeChange = runImpl(F, &MD, &TLI, AA, AC, DT); + bool MadeChange = + runImpl(F, &MD, &TLI, AA, AC, DT, MSSA ? &MSSA->getMSSA() : nullptr); if (!MadeChange) return PreservedAnalyses::all(); @@ -1347,18 +1464,23 @@ PA.preserveSet<CFGAnalyses>(); PA.preserve<GlobalsAA>(); PA.preserve<MemoryDependenceAnalysis>(); + if (MSSA) + PA.preserve<MemorySSAAnalysis>(); return PA; } bool MemCpyOptPass::runImpl(Function &F, MemoryDependenceResults *MD_, TargetLibraryInfo *TLI_, AliasAnalysis *AA_, - AssumptionCache *AC_, DominatorTree *DT_) { + AssumptionCache *AC_, DominatorTree *DT_, + MemorySSA *MSSA_) { bool MadeChange = false; MD = MD_; TLI = TLI_; AA = AA_; AC = AC_; DT = DT_; + MemorySSAUpdater MSSAU_(MSSA_); + MSSAU = MSSA_ ? &MSSAU_ : nullptr; // If we don't have at least memset and memcpy, there is little point of doing // anything here. These are required by a freestanding implementation, so if // even they are disabled, there is no point in trying hard. @@ -1371,6 +1493,9 @@ MadeChange = true; } + if (MSSA_ && VerifyMemorySSA) + MSSA_->verifyMemorySSA(); + MD = nullptr; return MadeChange; } @@ -1385,6 +1510,8 @@ auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto *MSSAWP = getAnalysisIfAvailable<MemorySSAWrapperPass>(); - return Impl.runImpl(F, MD, TLI, AA, AC, DT); + return Impl.runImpl(F, MD, TLI, AA, AC, DT, + MSSAWP ? 
&MSSAWP->getMSSA() : nullptr); } diff --git a/llvm/test/Transforms/MemCpyOpt/preserve-memssa.ll b/llvm/test/Transforms/MemCpyOpt/preserve-memssa.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/MemCpyOpt/preserve-memssa.ll @@ -0,0 +1,139 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -aa-pipeline=basic-aa -passes='require<memoryssa>,memcpyopt' -verify-memoryssa -S %s | FileCheck %s + +; REQUIRES: asserts + +target datalayout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-apple-macosx10.15.0" + +%t = type <{ i8*, [4 x i8], i8*, i8*, i32, [8192 x i8] }> + + +define i32 @test1(%t* %ptr) { +; CHECK-LABEL: @test1( +; CHECK-NEXT: invoke.cont6: +; CHECK-NEXT: [[P_1:%.*]] = getelementptr inbounds [[T:%.*]], %t* [[PTR:%.*]], i64 0, i32 0 +; CHECK-NEXT: [[P_1_C:%.*]] = bitcast i8** [[P_1]] to i8* +; CHECK-NEXT: [[P_2:%.*]] = getelementptr inbounds [[T]], %t* [[PTR]], i64 0, i32 4 +; CHECK-NEXT: [[P_3:%.*]] = getelementptr inbounds [[T]], %t* [[PTR]], i64 0, i32 5, i64 0 +; CHECK-NEXT: [[TMP0:%.*]] = bitcast i8** [[P_1]] to i8* +; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* align 8 [[TMP0]], i8 0, i64 20, i1 false) +; CHECK-NEXT: [[TMP1:%.*]] = bitcast i32* [[P_2]] to i8* +; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* align 8 [[TMP1]], i8 0, i64 8195, i1 false) +; CHECK-NEXT: ret i32 0 +; +invoke.cont6: + %p.1 = getelementptr inbounds %t, %t* %ptr, i64 0, i32 0 + %p.1.c = bitcast i8** %p.1 to i8* + call void @llvm.memset.p0i8.i64(i8* %p.1.c, i8 0, i64 20, i1 false) + store i8* null, i8** %p.1, align 8 + %p.2 = getelementptr inbounds %t, %t* %ptr, i64 0, i32 4 + store i32 0, i32* %p.2, align 8 + %p.3 = getelementptr inbounds %t, %t* %ptr, i64 0, i32 5, i64 0 + call void @llvm.memset.p0i8.i64(i8* %p.3, i8 0, i64 8191, i1 false) + ret i32 0 +} + +declare i8* @get_ptr() + +define void @test2(i8 *%in) { +; CHECK-LABEL: @test2( +; CHECK-NEXT: entry: +; CHECK-NEXT: 
[[CALL_I1_I:%.*]] = tail call i8* @get_ptr() +; CHECK-NEXT: [[TMP0:%.*]] = getelementptr i8, i8* [[CALL_I1_I]], i64 10 +; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* align 1 [[TMP0]], i8 0, i64 0, i1 false) +; CHECK-NEXT: tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* [[CALL_I1_I]], i8* [[IN:%.*]], i64 10, i1 false) +; CHECK-NEXT: ret void +; +entry: + %call.i1.i = tail call i8* @get_ptr() + tail call void @llvm.memset.p0i8.i64(i8* %call.i1.i, i8 0, i64 10, i1 false) + tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* %call.i1.i, i8* %in, i64 10, i1 false) + ret void +} + +declare i8* @malloc(i64) + +define i32 @test3(i8* noalias %in) { +; CHECK-LABEL: @test3( +; CHECK-NEXT: [[CALL_I_I_I:%.*]] = tail call i8* @malloc(i64 20) +; CHECK-NEXT: tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* [[CALL_I_I_I]], i8* [[IN:%.*]], i64 20, i1 false) +; CHECK-NEXT: ret i32 10 +; + %call.i.i.i = tail call i8* @malloc(i64 20) + tail call void @llvm.memmove.p0i8.p0i8.i64(i8* %call.i.i.i, i8* %in, i64 20, i1 false) + ret i32 10 +} + +define void @test4(i32 %n, i8* noalias %ptr.0, i8* noalias %ptr.1, i32* %ptr.2) unnamed_addr { +; CHECK-LABEL: @test4( +; CHECK-NEXT: [[ELEM_I:%.*]] = getelementptr i8, i8* [[PTR_0:%.*]], i64 8 +; CHECK-NEXT: store i32 [[N:%.*]], i32* [[PTR_2:%.*]], align 8 +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i8, i8* [[ELEM_I]], i64 10 +; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* align 1 [[TMP1]], i8 0, i64 0, i1 false) +; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* [[ELEM_I]], i8* [[PTR_1:%.*]], i64 10, i1 false) +; CHECK-NEXT: ret void +; + %elem.i = getelementptr i8, i8* %ptr.0, i64 8 + call void @llvm.memset.p0i8.i64(i8* %elem.i, i8 0, i64 10, i1 false) + store i32 %n, i32* %ptr.2, align 8 + call void @llvm.memcpy.p0i8.p0i8.i64(i8* %elem.i, i8* %ptr.1, i64 10, i1 false) + ret void +} + +declare void @decompose(%t* nocapture) + +define void @test5(i32* %ptr) { +; CHECK-LABEL: @test5( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[EARLY_DATA:%.*]] = alloca 
[128 x i8], align 8 +; CHECK-NEXT: [[TMP:%.*]] = alloca [[T:%.*]], align 8 +; CHECK-NEXT: [[TMP0:%.*]] = bitcast [128 x i8]* [[EARLY_DATA]] to i8* +; CHECK-NEXT: [[TMP1:%.*]] = bitcast %t* [[TMP]] to i8* +; CHECK-NEXT: call void @llvm.lifetime.start.p0i8(i64 32, i8* [[TMP0]]) +; CHECK-NEXT: [[TMP2:%.*]] = load i32, i32* [[PTR:%.*]], align 8 +; CHECK-NEXT: call fastcc void @decompose(%t* [[TMP]]) +; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* [[TMP0]], i8* [[TMP1]], i64 32, i1 false) +; CHECK-NEXT: ret void +; +entry: + %early_data = alloca [128 x i8], align 8 + %tmp = alloca %t, align 8 + %0 = bitcast [128 x i8]* %early_data to i8* + %1 = bitcast %t* %tmp to i8* + call void @llvm.lifetime.start.p0i8(i64 32, i8* %0) + %2 = load i32, i32* %ptr, align 8 + call fastcc void @decompose(%t* %tmp) + call void @llvm.memcpy.p0i8.p0i8.i64(i8* %0, i8* %1, i64 32, i1 false) + ret void +} + +define i8 @test6(i8* %ptr, i8* noalias %ptr.1) { +; CHECK-LABEL: @test6( +; CHECK-NEXT: entry: +; CHECK-NEXT: call void @llvm.lifetime.start.p0i8(i64 24, i8* [[PTR:%.*]]) +; CHECK-NEXT: [[TMP0:%.*]] = load i8, i8* [[PTR]], align 8 +; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* [[PTR]], i8* [[PTR_1:%.*]], i64 24, i1 false) +; CHECK-NEXT: ret i8 [[TMP0]] +; +entry: + call void @llvm.lifetime.start.p0i8(i64 24, i8* %ptr) + %0 = load i8, i8* %ptr, align 8 + call void @llvm.memmove.p0i8.p0i8.i64(i8* %ptr, i8* %ptr.1, i64 24, i1 false) + ret i8 %0 +} + +; Function Attrs: argmemonly nounwind willreturn +declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture) #0 + +; Function Attrs: argmemonly nounwind willreturn +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* noalias nocapture writeonly, i8* noalias nocapture readonly, i64, i1 immarg) #0 + +; Function Attrs: argmemonly nounwind willreturn writeonly +declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1 immarg) #1 + +; Function Attrs: argmemonly nounwind willreturn +declare void 
@llvm.memmove.p0i8.p0i8.i64(i8* nocapture, i8* nocapture readonly, i64, i1 immarg) #0 + +attributes #0 = { argmemonly nounwind willreturn } +attributes #1 = { argmemonly nounwind willreturn writeonly }