diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
--- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
+++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
@@ -352,13 +352,59 @@
 
 // Check for mod of Loc between Start and End, excluding both boundaries.
 // Start and End can be in different blocks.
-static bool writtenBetween(MemorySSA *MSSA, MemoryLocation Loc,
-                           const MemoryUseOrDef *Start,
+static bool writtenBetween(MemorySSA *MSSA, AliasAnalysis &AA,
+                           MemoryLocation Loc, const MemoryUseOrDef *Start,
                            const MemoryUseOrDef *End) {
   // TODO: Only walk until we hit Start.
   MemoryAccess *Clobber = MSSA->getWalker()->getClobberingMemoryAccess(
       End->getDefiningAccess(), Loc);
-  return !MSSA->dominates(Clobber, Start);
+
+  if (!MSSA->dominates(Clobber, Start))
+    return true;
+
+  // Scan uses of Clobber between Clobber and End to find any potential
+  // write-clobbers of Loc. The number of visited accesses is capped; once the
+  // cap is exceeded a write is conservatively reported.
+  const unsigned MaxUsesToVisit = 32;
+  SmallPtrSet<MemoryAccess *, 16> Seen;
+  SmallVector<MemoryAccess *, 16> Worklist;
+  auto PushMemUses = [&Worklist, &Seen](MemoryAccess *Acc) {
+    for (Use &U : Acc->uses()) {
+      auto *MU = cast<MemoryAccess>(U.getUser());
+      // Queue every access at most once.
+      if (Seen.insert(MU).second)
+        Worklist.push_back(MU);
+    }
+  };
+  PushMemUses(Clobber);
+  while (!Worklist.empty()) {
+    // Cap number of uses to visit.
+    if (Seen.size() > MaxUsesToVisit)
+      return true;
+
+    MemoryAccess *Curr = Worklist.pop_back_val();
+    // A MemoryPhi has no instruction to query; look through to its users.
+    if (isa<MemoryPhi>(Curr)) {
+      PushMemUses(Curr);
+      continue;
+    }
+
+    // Read-only use, skip it.
+    if (isa<MemoryUse>(Curr))
+      continue;
+
+    // Use is past End, skip it.
+    if (MSSA->dominates(End, Curr))
+      continue;
+
+    // Ask alias analysis whether the underlying instruction may write Loc.
+    // If it does not, keep following the accesses that use it.
+    Instruction *CurrInst = cast<MemoryUseOrDef>(Curr)->getMemoryInst();
+    if (isModSet(AA.getModRefInfo(CurrInst, Loc)))
+      return true;
+    PushMemUses(Curr);
+  }
+  return false;
 }
 
 /// When scanning forward over instructions, we look for some other patterns to
@@ -1118,7 +1164,7 @@
   // then we could still perform the xform by moving M up to the first memcpy.
   // TODO: It would be sufficient to check the MDep source up to the memcpy
   // size of M, rather than MDep.
-  if (writtenBetween(MSSA, MemoryLocation::getForSource(MDep),
+  if (writtenBetween(MSSA, *AA, MemoryLocation::getForSource(MDep),
                      MSSA->getMemoryAccess(MDep), MSSA->getMemoryAccess(M)))
     return false;
 
@@ -1557,7 +1603,7 @@
   //    *b = 42;
   //    foo(*a)
   // It would be invalid to transform the second memcpy into foo(*b).
-  if (writtenBetween(MSSA, MemoryLocation::getForSource(MDep),
+  if (writtenBetween(MSSA, *AA, MemoryLocation::getForSource(MDep),
                      MSSA->getMemoryAccess(MDep), MSSA->getMemoryAccess(&CB)))
     return false;
 
diff --git a/llvm/test/Transforms/MemCpyOpt/memcpy-byval-forwarding-clobbers.ll b/llvm/test/Transforms/MemCpyOpt/memcpy-byval-forwarding-clobbers.ll
--- a/llvm/test/Transforms/MemCpyOpt/memcpy-byval-forwarding-clobbers.ll
+++ b/llvm/test/Transforms/MemCpyOpt/memcpy-byval-forwarding-clobbers.ll
@@ -13,7 +13,6 @@
 
 ; %a.2's lifetime ends before the call to @check. Cannot replace
 ; %a.1 with %a.2 in the call to @check.
-; FIXME: Find lifetime.end, prevent optimization.
 define i1 @alloca_forwarding_lifetime_end_clobber() {
 ; CHECK-LABEL: @alloca_forwarding_lifetime_end_clobber(
 ; CHECK-NEXT:  entry:
@@ -26,7 +25,7 @@
 ; CHECK-NEXT:    store i8 0, i8* [[BC_A_2]], align 1
 ; CHECK-NEXT:    call void @llvm.memcpy.p0i8.p0i8.i64(i8* [[BC_A_1]], i8* [[BC_A_2]], i64 8, i1 false)
 ; CHECK-NEXT:    call void @llvm.lifetime.end.p0i8(i64 8, i8* [[BC_A_2]])
-; CHECK-NEXT:    [[CALL:%.*]] = call i1 @check(i64* byval(i64) align 8 [[A_2]])
+; CHECK-NEXT:    [[CALL:%.*]] = call i1 @check(i64* byval(i64) align 8 [[A_1]])
 ; CHECK-NEXT:    ret i1 [[CALL]]
 ;
 entry:
@@ -46,7 +45,6 @@
 
 ; There is a call clobbering %a.2 before the call to @check. Cannot replace
 ; %a.1 with %a.2 in the call to @check.
-; FIXME: Find clobber, prevent optimization.
 define i1 @alloca_forwarding_call_clobber() {
 ; CHECK-LABEL: @alloca_forwarding_call_clobber(
 ; CHECK-NEXT:  entry:
@@ -59,7 +57,7 @@
 ; CHECK-NEXT:    store i8 0, i8* [[BC_A_2]], align 1
 ; CHECK-NEXT:    call void @llvm.memcpy.p0i8.p0i8.i64(i8* [[BC_A_1]], i8* [[BC_A_2]], i64 8, i1 false)
 ; CHECK-NEXT:    call void @clobber(i8* [[BC_A_2]])
-; CHECK-NEXT:    [[CALL:%.*]] = call i1 @check(i64* byval(i64) align 8 [[A_2]])
+; CHECK-NEXT:    [[CALL:%.*]] = call i1 @check(i64* byval(i64) align 8 [[A_1]])
 ; CHECK-NEXT:    ret i1 [[CALL]]
 ;
 entry:
diff --git a/llvm/test/Transforms/MemCpyOpt/nonlocal-memcpy-memcpy.ll b/llvm/test/Transforms/MemCpyOpt/nonlocal-memcpy-memcpy.ll
--- a/llvm/test/Transforms/MemCpyOpt/nonlocal-memcpy-memcpy.ll
+++ b/llvm/test/Transforms/MemCpyOpt/nonlocal-memcpy-memcpy.ll
@@ -30,7 +30,7 @@
 ; CHECK-NEXT:    call void @qux()
 ; CHECK-NEXT:    unreachable
 ; CHECK:       more:
-; CHECK-NEXT:    call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 [[DST:%.*]], i8* align 8 [[SRC]], i64 64, i1 false)
+; CHECK-NEXT:    call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 [[DST:%.*]], i8* align 8 [[TEMP]], i64 64, i1 false)
 ; CHECK-NEXT:    ret void
 ;
 bb: