diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
--- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
+++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
@@ -478,7 +478,7 @@
       Args.insert(Ptr);
 
   // Instruction to lift before P.
-  SmallVector<Instruction *, 8> ToLift;
+  SmallVector<Instruction *, 8> ToLift{SI};
 
   // Memory locations of lifted instructions.
   SmallVector<MemoryLocation, 8> MemLocs{StoreLoc};
@@ -549,10 +549,40 @@
     }
   }
 
-  // We made it, we need to lift
+  // Find MSSA insertion point. Normally P will always have a corresponding
+  // memory access before which we can insert. However, with non-standard AA
+  // pipelines, there may be a mismatch between AA and MSSA, in which case we
+  // will scan for a memory access before P. In either case, we know for sure
+  // that at least the load will have a memory access.
+  // TODO: Simplify this once P will be determined by MSSA, in which case the
+  // discrepancy can no longer occur.
+  MemoryUseOrDef *MemInsertPoint = nullptr;
+  if (MSSAU) {
+    if (MemoryUseOrDef *MA = MSSAU->getMemorySSA()->getMemoryAccess(P)) {
+      MemInsertPoint = cast<MemoryUseOrDef>(--MA->getIterator());
+    } else {
+      const Instruction *ConstP = P;
+      for (const Instruction &I : make_range(++ConstP->getReverseIterator(),
+                                             ++LI->getReverseIterator())) {
+        if (MemoryUseOrDef *MA = MSSAU->getMemorySSA()->getMemoryAccess(&I)) {
+          MemInsertPoint = MA;
+          break;
+        }
+      }
+    }
+  }
+
+  // We made it, we need to lift.
   for (auto *I : llvm::reverse(ToLift)) {
     LLVM_DEBUG(dbgs() << "Lifting " << *I << " before " << *P << "\n");
     I->moveBefore(P);
+    if (MSSAU) {
+      assert(MemInsertPoint && "Must have found insert point");
+      if (MemoryUseOrDef *MA = MSSAU->getMemorySSA()->getMemoryAccess(I)) {
+        MSSAU->moveAfter(MA, MemInsertPoint);
+        MemInsertPoint = MA;
+      }
+    }
   }
 
   return true;
@@ -636,9 +666,8 @@
                       << *M << "\n");
 
     if (MSSAU) {
-      assert(isa<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(P)));
       auto *LastDef =
-          cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(P));
+          cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(SI));
       auto *NewAccess = MSSAU->createMemoryAccessAfter(M, LastDef, LastDef);
       MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true);
     }
diff --git a/llvm/test/Transforms/MemCpyOpt/preserve-memssa.ll b/llvm/test/Transforms/MemCpyOpt/preserve-memssa.ll
--- a/llvm/test/Transforms/MemCpyOpt/preserve-memssa.ll
+++ b/llvm/test/Transforms/MemCpyOpt/preserve-memssa.ll
@@ -148,6 +148,21 @@
   ret void
 }
 
+define void @test8(%t* noalias %src, %t* %dst) {
+; CHECK-LABEL: @test8(
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast %t* [[SRC:%.*]] to i8*
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast %t* [[DST:%.*]] to i8*
+; CHECK-NEXT:    [[TMP3:%.*]] = bitcast %t* [[SRC]] to i8*
+; CHECK-NEXT:    call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 1 [[TMP2]], i8* align 1 [[TMP3]], i64 8224, i1 false)
+; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* align 1 [[TMP1]], i8 0, i64 8224, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %1 = load %t, %t* %src
+  store %t zeroinitializer, %t* %src
+  store %t %1, %t* %dst
+  ret void
+}
+
 declare void @clobber()
 
 ; Function Attrs: argmemonly nounwind willreturn