diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
--- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
+++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
@@ -350,88 +350,34 @@
   return false;
 }
 
-//
-// Check for mod of Loc between Start and End, excluding both boundaries, if \p
-// End is not null. If \p End is null, check the full block \p BB. If \p Start
-// isn't in the current block, recursively check all predecessors.
-static bool writtenBetweenRecursive(MemorySSA *MSSA, AliasAnalysis &AA,
-                                    MemoryLocation Loc,
-                                    const MemoryAccess *Start,
-                                    const MemoryAccess *End,
-                                    const BasicBlock *BB,
-                                    SmallPtrSetImpl<const BasicBlock *> &Seen) {
-  if (End) {
-    assert(!BB);
-    auto *Defs = MSSA->getBlockDefs(End->getBlock());
-    for (auto I = End->getReverseDefsIterator(), E = Defs->rend(); I != E;
-         ++I) {
-      const MemoryAccess *Curr = &*I;
-      if (Curr == Start)
-        return false;
-
-      if (isa<MemoryPhi>(Curr))
-        continue;
-
-      Instruction *CurrInst = cast<MemoryUseOrDef>(Curr)->getMemoryInst();
-      if (isModSet(AA.getModRefInfo(CurrInst, Loc)))
-        return true;
-    }
-    BB = End->getBlock();
-  } else {
-    assert(BB && "If end is null, BB must be set");
-    // If the block has any defs, check them.
-    if (auto *Defs = MSSA->getBlockDefs(BB))
-      return writtenBetweenRecursive(MSSA, AA, Loc, Start, &*Defs->rbegin(),
-                                     nullptr, Seen);
-  }
-
-  // Already processed the block.
-  if (!Seen.insert(BB).second)
-    return false;
-
-  if (Seen.size() > 8)
-    return true;
-
-  // Check all predecessors.
-  return any_of(predecessors(BB), [&](const BasicBlock *Pred) {
-    return writtenBetweenRecursive(MSSA, AA, Loc, Start, nullptr, Pred, Seen);
+template <typename IterTy>
+static bool writtenBetween(AliasAnalysis &AA, MemoryLocation Loc, IterTy Start,
+                           IterTy End) {
+  return any_of(make_range(Start, End), [&AA, Loc](const MemoryAccess &Acc) {
+    if (isa<MemoryPhi>(&Acc))
+      return false;
+    Instruction *AccInst = cast<MemoryUseOrDef>(&Acc)->getMemoryInst();
+    return isModSet(AA.getModRefInfo(AccInst, Loc));
   });
 }
 
 // Check for mod of Loc between Start and End, excluding both boundaries.
 // Start and End can be in different blocks.
 static bool writtenBetween(MemorySSA *MSSA, AliasAnalysis &AA,
-                           MemoryLocation Loc, const MemoryUseOrDef *Start,
+                           MemoryLocation Loc, const MemoryDef *Start,
                            const MemoryUseOrDef *End) {
   // TODO: Only walk until we hit Start.
   MemoryAccess *Clobber = MSSA->getWalker()->getClobberingMemoryAccess(
       End->getDefiningAccess(), Loc);
-  if (!MSSA->dominates(Clobber, Start))
+  if (!MSSA->dominates(Clobber, Start) || Start->getBlock() != End->getBlock())
     return true;
 
-  // If there's an access before the End in the current block, check from that
-  // access.
-  const MemoryAccess *PrevAcc = nullptr;
-  if (isa<MemoryDef>(End)) {
-    auto *Defs = MSSA->getBlockDefs(End->getBlock());
-    auto PrevIter = std::next(End->getReverseDefsIterator());
-    if (PrevIter != Defs->rend())
-      PrevAcc = &*PrevIter;
-  } else {
-    auto *Defs = MSSA->getBlockAccesses(End->getBlock());
-    auto PrevIter = std::next(End->getReverseIterator());
-    if (PrevIter != Defs->rend())
-      PrevAcc = &*PrevIter;
-  }
-
-  SmallPtrSet<const BasicBlock *, 8> Seen;
-  if (PrevAcc)
-    return writtenBetweenRecursive(MSSA, AA, Loc, Start, PrevAcc, nullptr,
-                                   Seen);
-  return any_of(predecessors(End->getBlock()), [&](const BasicBlock *Pred) {
-    return writtenBetweenRecursive(MSSA, AA, Loc, Start, nullptr, Pred, Seen);
-  });
+  if (isa<MemoryDef>(End))
+    return writtenBetween(AA, Loc, std::next(Start->getDefsIterator()),
+                          End->getDefsIterator());
+  return writtenBetween(AA, Loc, std::next(Start->getIterator()),
+                        End->getIterator());
 }
 
 /// When scanning forward over instructions, we look for some other patterns to
@@ -1192,7 +1138,8 @@
     // TODO: It would be sufficient to check the MDep source up to the memcpy
     // size of M, rather than MDep.
     if (writtenBetween(MSSA, *AA, MemoryLocation::getForSource(MDep),
-                       MSSA->getMemoryAccess(MDep), MSSA->getMemoryAccess(M)))
+                       cast<MemoryDef>(MSSA->getMemoryAccess(MDep)),
+                       MSSA->getMemoryAccess(M)))
       return false;
 
     // If the dest of the second might alias the source of the first, then the
@@ -1631,7 +1578,8 @@
   //    foo(*a)
   // It would be invalid to transform the second memcpy into foo(*b).
   if (writtenBetween(MSSA, *AA, MemoryLocation::getForSource(MDep),
-                     MSSA->getMemoryAccess(MDep), MSSA->getMemoryAccess(&CB)))
+                     cast<MemoryDef>(MSSA->getMemoryAccess(MDep)),
+                     MSSA->getMemoryAccess(&CB)))
     return false;
 
   Value *TmpCast = MDep->getSource();
diff --git a/llvm/test/Transforms/MemCpyOpt/memcpy-invoke-memcpy.ll b/llvm/test/Transforms/MemCpyOpt/memcpy-invoke-memcpy.ll
--- a/llvm/test/Transforms/MemCpyOpt/memcpy-invoke-memcpy.ll
+++ b/llvm/test/Transforms/MemCpyOpt/memcpy-invoke-memcpy.ll
@@ -17,7 +17,7 @@
 ; CHECK-NEXT: catch i8* null
 ; CHECK-NEXT: ret void
 ; CHECK: try.cont:
-; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 [[DST:%.*]], i8* align 8 [[SRC]], i64 64, i1 false)
+; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 [[DST:%.*]], i8* align 8 [[TEMP]], i64 64, i1 false)
 ; CHECK-NEXT: ret void
 ;
 entry:
@@ -48,7 +48,7 @@
 ; CHECK: lpad:
 ; CHECK-NEXT: [[TMP0:%.*]] = landingpad { i8*, i32 }
 ; CHECK-NEXT: catch i8* null
-; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 [[DST:%.*]], i8* align 8 [[SRC]], i64 64, i1 false)
+; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 [[DST:%.*]], i8* align 8 [[TEMP]], i64 64, i1 false)
 ; CHECK-NEXT: ret void
 ; CHECK: try.cont:
 ; CHECK-NEXT: ret void
diff --git a/llvm/test/Transforms/MemCpyOpt/memcpy.ll b/llvm/test/Transforms/MemCpyOpt/memcpy.ll
--- a/llvm/test/Transforms/MemCpyOpt/memcpy.ll
+++ b/llvm/test/Transforms/MemCpyOpt/memcpy.ll
@@ -230,7 +230,7 @@
 ; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 4 [[A2]], i8* align 4 [[P:%.*]], i64 8, i1 false)
 ; CHECK-NEXT: br i1 [[C:%.*]], label [[CALL:%.*]], label [[EXIT:%.*]]
 ; CHECK: call:
-; CHECK-NEXT: call void @test4a(i8* byval(i8) align 1 [[P]])
+; CHECK-NEXT: call void @test4a(i8* byval(i8) align 1 [[A2]])
 ; CHECK-NEXT: br label [[EXIT]]
 ; CHECK: exit:
 ; CHECK-NEXT: ret void
diff --git a/llvm/test/Transforms/MemCpyOpt/nonlocal-memcpy-memcpy.ll b/llvm/test/Transforms/MemCpyOpt/nonlocal-memcpy-memcpy.ll
--- a/llvm/test/Transforms/MemCpyOpt/nonlocal-memcpy-memcpy.ll
+++ b/llvm/test/Transforms/MemCpyOpt/nonlocal-memcpy-memcpy.ll
@@ -30,7 +30,7 @@
 ; CHECK-NEXT: call void @qux()
 ; CHECK-NEXT: unreachable
 ; CHECK: more:
-; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 [[DST:%.*]], i8* align 8 [[SRC]], i64 64, i1 false)
+; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 [[DST:%.*]], i8* align 8 [[TEMP]], i64 64, i1 false)
 ; CHECK-NEXT: ret void
 ;
 bb:
@@ -62,7 +62,7 @@
 ; CHECK: bb4:
 ; CHECK-NEXT: [[T5:%.*]] = bitcast %struct.s* [[T]] to i8*
 ; CHECK-NEXT: [[S6:%.*]] = bitcast %struct.s* [[S]] to i8*
-; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 4 [[T5]], i8* align 4 bitcast (%struct.s* @s_foo to i8*), i64 8, i1 false)
+; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 4 [[T5]], i8* align 4 [[S6]], i64 8, i1 false)
 ; CHECK-NEXT: br label [[BB7]]
 ; CHECK: bb7:
 ; CHECK-NEXT: [[T8:%.*]] = getelementptr [[STRUCT_S]], %struct.s* [[T]], i32 0, i32 0
@@ -121,7 +121,7 @@
 ; CHECK: bb22:
 ; CHECK-NEXT: [[T23:%.*]] = bitcast %struct.s* [[T]] to i8*
 ; CHECK-NEXT: [[S24:%.*]] = bitcast %struct.s* [[S]] to i8*
-; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 4 [[T23]], i8* align 4 bitcast (%struct.s* @s_baz to i8*), i64 8, i1 false)
+; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 4 [[T23]], i8* align 4 [[S24]], i64 8, i1 false)
 ; CHECK-NEXT: br label [[BB23]]
 ; CHECK: bb23:
 ; CHECK-NEXT: [[T17:%.*]] = getelementptr inbounds [[STRUCT_S]], %struct.s* [[T]], i32 0, i32 0
@@ -180,7 +180,7 @@
 ; CHECK-NEXT: store i64 0, i64* [[ARG:%.*]], align 4
 ; CHECK-NEXT: br label [[EXIT]]
 ; CHECK: exit:
-; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* [[C:%.*]], i8* [[B]], i64 16, i1 false)
+; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* [[C:%.*]], i8* [[A]], i64 16, i1 false)
 ; CHECK-NEXT: ret void
 ;
 entry: