diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
--- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
+++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
@@ -350,15 +350,88 @@
   return false;
 }
 
+//
+// Check for mod of Loc between Start and End, excluding both boundaries, if \p
+// End is not null. If \p End is null, check the full block \p BB. If \p Start
+// isn't in the current block, recursively check all predecessors.
+static bool writtenBetweenRecursive(MemorySSA *MSSA, AliasAnalysis &AA,
+                                    MemoryLocation Loc,
+                                    const MemoryAccess *Start,
+                                    const MemoryAccess *End,
+                                    const BasicBlock *BB,
+                                    SmallPtrSetImpl<const BasicBlock *> &Seen) {
+  if (End) {
+    assert(!BB);
+    auto *Defs = MSSA->getBlockDefs(End->getBlock());
+    for (auto I = End->getReverseDefsIterator(), E = Defs->rend(); I != E;
+         ++I) {
+      const MemoryAccess *Curr = &*I;
+      if (Curr == Start)
+        return false;
+
+      if (isa<MemoryPhi>(Curr))
+        continue;
+
+      Instruction *CurrInst = cast<MemoryUseOrDef>(Curr)->getMemoryInst();
+      if (isModSet(AA.getModRefInfo(CurrInst, Loc)))
+        return true;
+    }
+    BB = End->getBlock();
+  } else {
+    assert(BB && "If End is null, BB must be set");
+    // If the block has any defs, check them.
+    if (auto *Defs = MSSA->getBlockDefs(BB))
+      return writtenBetweenRecursive(MSSA, AA, Loc, Start, &*Defs->rbegin(),
+                                     nullptr, Seen);
+  }
+
+  // Already processed the block.
+  if (!Seen.insert(BB).second)
+    return false;
+
+  if (Seen.size() > 8)
+    return true;
+
+  // Check all predecessors.
+  return any_of(predecessors(BB), [&](const BasicBlock *Pred) {
+    return writtenBetweenRecursive(MSSA, AA, Loc, Start, nullptr, Pred, Seen);
+  });
+}
+
 // Check for mod of Loc between Start and End, excluding both boundaries.
 // Start and End can be in different blocks.
-static bool writtenBetween(MemorySSA *MSSA, MemoryLocation Loc,
-                           const MemoryUseOrDef *Start,
+static bool writtenBetween(MemorySSA *MSSA, AliasAnalysis &AA,
+                           MemoryLocation Loc, const MemoryUseOrDef *Start,
                            const MemoryUseOrDef *End) {
   // TODO: Only walk until we hit Start.
   MemoryAccess *Clobber = MSSA->getWalker()->getClobberingMemoryAccess(
       End->getDefiningAccess(), Loc);
-  return !MSSA->dominates(Clobber, Start);
+
+  if (!MSSA->dominates(Clobber, Start))
+    return true;
+
+  // If there's an access before the End in the current block, check from that
+  // access.
+  const MemoryAccess *PrevAcc = nullptr;
+  if (isa<MemoryDef>(End)) {
+    auto *Defs = MSSA->getBlockDefs(End->getBlock());
+    auto PrevIter = std::next(End->getReverseDefsIterator());
+    if (PrevIter != Defs->rend())
+      PrevAcc = &*PrevIter;
+  } else {
+    auto *Defs = MSSA->getBlockAccesses(End->getBlock());
+    auto PrevIter = std::next(End->getReverseIterator());
+    if (PrevIter != Defs->rend())
+      PrevAcc = &*PrevIter;
+  }
+
+  SmallPtrSet<const BasicBlock *, 8> Seen;
+  if (PrevAcc)
+    return writtenBetweenRecursive(MSSA, AA, Loc, Start, PrevAcc, nullptr,
+                                   Seen);
+  return any_of(predecessors(End->getBlock()), [&](const BasicBlock *Pred) {
+    return writtenBetweenRecursive(MSSA, AA, Loc, Start, nullptr, Pred, Seen);
+  });
 }
 
 /// When scanning forward over instructions, we look for some other patterns to
@@ -1118,7 +1191,7 @@
   // then we could still perform the xform by moving M up to the first memcpy.
   // TODO: It would be sufficient to check the MDep source up to the memcpy
   // size of M, rather than MDep.
-  if (writtenBetween(MSSA, MemoryLocation::getForSource(MDep),
+  if (writtenBetween(MSSA, *AA, MemoryLocation::getForSource(MDep),
                      MSSA->getMemoryAccess(MDep), MSSA->getMemoryAccess(M)))
     return false;
 
@@ -1557,7 +1630,7 @@
   //    *b = 42;
   //    foo(*a)
   // It would be invalid to transform the second memcpy into foo(*b).
-  if (writtenBetween(MSSA, MemoryLocation::getForSource(MDep),
+  if (writtenBetween(MSSA, *AA, MemoryLocation::getForSource(MDep),
                      MSSA->getMemoryAccess(MDep),
                      MSSA->getMemoryAccess(&CB)))
     return false;
diff --git a/llvm/test/Transforms/MemCpyOpt/memcpy-byval-forwarding-clobbers.ll b/llvm/test/Transforms/MemCpyOpt/memcpy-byval-forwarding-clobbers.ll
--- a/llvm/test/Transforms/MemCpyOpt/memcpy-byval-forwarding-clobbers.ll
+++ b/llvm/test/Transforms/MemCpyOpt/memcpy-byval-forwarding-clobbers.ll
@@ -13,7 +13,6 @@
 
 ; %a.2's lifetime ends before the call to @check. Cannot replace
 ; %a.1 with %a.2 in the call to @check.
-; FIXME: Find lifetime.end, prevent optimization.
 define i1 @alloca_forwarding_lifetime_end_clobber() {
 ; CHECK-LABEL: @alloca_forwarding_lifetime_end_clobber(
 ; CHECK-NEXT:  entry:
@@ -26,7 +25,7 @@
 ; CHECK-NEXT:    store i8 0, i8* [[BC_A_2]], align 1
 ; CHECK-NEXT:    call void @llvm.memcpy.p0i8.p0i8.i64(i8* [[BC_A_1]], i8* [[BC_A_2]], i64 8, i1 false)
 ; CHECK-NEXT:    call void @llvm.lifetime.end.p0i8(i64 8, i8* [[BC_A_2]])
-; CHECK-NEXT:    [[CALL:%.*]] = call i1 @check(i64* byval(i64) align 8 [[A_2]])
+; CHECK-NEXT:    [[CALL:%.*]] = call i1 @check(i64* byval(i64) align 8 [[A_1]])
 ; CHECK-NEXT:    ret i1 [[CALL]]
 ;
 entry:
@@ -46,7 +45,6 @@
 
 ; There is a call clobbering %a.2 before the call to @check. Cannot replace
 ; %a.1 with %a.2 in the call to @check.
-; FIXME: Find clobber, prevent optimization.
 define i1 @alloca_forwarding_call_clobber() {
 ; CHECK-LABEL: @alloca_forwarding_call_clobber(
 ; CHECK-NEXT:  entry:
@@ -59,7 +57,7 @@
 ; CHECK-NEXT:    store i8 0, i8* [[BC_A_2]], align 1
 ; CHECK-NEXT:    call void @llvm.memcpy.p0i8.p0i8.i64(i8* [[BC_A_1]], i8* [[BC_A_2]], i64 8, i1 false)
 ; CHECK-NEXT:    call void @clobber(i8* [[BC_A_2]])
-; CHECK-NEXT:    [[CALL:%.*]] = call i1 @check(i64* byval(i64) align 8 [[A_2]])
+; CHECK-NEXT:    [[CALL:%.*]] = call i1 @check(i64* byval(i64) align 8 [[A_1]])
 ; CHECK-NEXT:    ret i1 [[CALL]]
 ;
 entry:
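
Note (illustrative, not part of the patch): a reduced sketch of the call-clobber case the
extended writtenBetween() check is meant to reject, adapted from the
alloca_forwarding_call_clobber test above. The function name @sketch and the bare
declarations are placeholders for illustration only.

declare void @clobber(i8*)
declare i1 @check(i64* byval(i64) align 8)
declare void @llvm.memcpy.p0i8.p0i8.i64(i8*, i8*, i64, i1)

define i1 @sketch() {
entry:
  %a.1 = alloca i64, align 8
  %a.2 = alloca i64, align 8
  %bc.a.1 = bitcast i64* %a.1 to i8*
  %bc.a.2 = bitcast i64* %a.2 to i8*
  store i8 0, i8* %bc.a.2, align 1
  ; %a.1 holds a copy of %a.2 at this point ...
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* %bc.a.1, i8* %bc.a.2, i64 8, i1 false)
  ; ... but @clobber may modify %a.2 between the memcpy and the byval use, so
  ; forwarding %a.2 into the call to @check below would be invalid.
  call void @clobber(i8* %bc.a.2)
  %call = call i1 @check(i64* byval(i64) align 8 %a.1)
  ret i1 %call
}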