diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp --- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -20,6 +20,7 @@ #include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/MemoryLocation.h" @@ -945,12 +946,56 @@ return false; } - // Check that src isn't captured by the called function since the - // transformation can cause aliasing issues in that case. - for (unsigned ArgI = 0, E = C->arg_size(); ArgI != E; ++ArgI) - if (C->getArgOperand(ArgI) == cpySrc && !C->doesNotCapture(ArgI)) + // Check whether src is captured by the called function, in which case there + // may be further indirect uses of src. + bool SrcIsCaptured = any_of(C->args(), [&](Use &U) { + return U == cpySrc && !C->doesNotCapture(C->getArgOperandNo(&U)); + }); + + // If src is captured, then check whether there are any potential uses of + // src through the captured pointer before the lifetime of src ends, either + // due to a lifetime.end or a return from the function. + if (SrcIsCaptured) { + // Check that dest is not captured before/at the call. We have already + // checked that src is not captured before it. If either had been captured, + // then the call might be comparing the argument against the captured dest + // or src pointer. + Value *DestObj = getUnderlyingObject(cpyDest); + if (!isIdentifiedFunctionLocal(DestObj) || + PointerMayBeCapturedBefore(DestObj, /* ReturnCaptures */ true, + /* StoreCaptures */ true, C, DT, + /* IncludeI */ true)) return false; + MemoryLocation SrcLoc = + MemoryLocation(srcAlloca, LocationSize::precise(srcSize)); + for (Instruction &I : + make_range(++C->getIterator(), C->getParent()->end())) { + // Lifetime of srcAlloca ends at lifetime.end. 
+ if (auto *II = dyn_cast(&I)) { + if (II->getIntrinsicID() == Intrinsic::lifetime_end && + II->getArgOperand(1)->stripPointerCasts() == srcAlloca && + cast(II->getArgOperand(0))->uge(srcSize)) + break; + } + + // Lifetime of srcAlloca ends at return. + if (isa(&I)) + break; + + // Ignore the direct read of src in the load. + if (&I == cpyLoad) + continue; + + // Check whether this instruction may mod/ref src through the captured + // pointer (we have already checked any direct mod/refs in the loop above). + // Also bail if we hit a terminator, as we don't want to scan into other + // blocks. + if (isModOrRefSet(AA->getModRefInfo(&I, SrcLoc)) || I.isTerminator()) + return false; + } + } + // Since we're changing the parameter to the callsite, we need to make sure // that what would be the new parameter dominates the callsite. if (!DT->dominates(cpyDest, C)) { diff --git a/llvm/test/Transforms/MemCpyOpt/capturing-func.ll b/llvm/test/Transforms/MemCpyOpt/capturing-func.ll --- a/llvm/test/Transforms/MemCpyOpt/capturing-func.ll +++ b/llvm/test/Transforms/MemCpyOpt/capturing-func.ll @@ -57,8 +57,7 @@ ; CHECK-NEXT: [[PTR1:%.*]] = alloca i8, align 1 ; CHECK-NEXT: [[PTR2:%.*]] = alloca i8, align 1 ; CHECK-NEXT: call void @llvm.lifetime.start.p0i8(i64 1, i8* [[PTR2]]) -; CHECK-NEXT: call void @foo(i8* [[PTR2]]) -; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i32(i8* [[PTR1]], i8* [[PTR2]], i32 1, i1 false) +; CHECK-NEXT: call void @foo(i8* [[PTR1]]) ; CHECK-NEXT: call void @llvm.lifetime.end.p0i8(i64 1, i8* [[PTR2]]) ; CHECK-NEXT: call void @foo(i8* [[PTR1]]) ; CHECK-NEXT: ret void @@ -101,8 +100,7 @@ ; CHECK-LABEL: define {{[^@]+}}@test_function_end() { ; CHECK-NEXT: [[PTR1:%.*]] = alloca i8, align 1 ; CHECK-NEXT: [[PTR2:%.*]] = alloca i8, align 1 -; CHECK-NEXT: call void @foo(i8* [[PTR2]]) -; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i32(i8* [[PTR1]], i8* [[PTR2]], i32 1, i1 false) +; CHECK-NEXT: call void @foo(i8* [[PTR1]]) ; CHECK-NEXT: ret void ; %ptr1 = alloca i8