Index: lib/Transforms/Scalar/MemCpyOptimizer.cpp =================================================================== --- lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -800,9 +800,30 @@ if (cpyLen < srcSize) return false; - // Check that accessing the first srcSize bytes of dest will not cause a - // trap. Otherwise the transform is invalid since it might cause a trap - // to occur earlier than it otherwise would. + // Check that the destination is a valid target. It must be valid to access + // the first srcSize bytes of the destination, and dest must not be accessed + // in the case of exceptional control flow. + // + // If the destination is an alloca of sufficient size, accessing the first + // srcSize bytes of dest is well-defined. If the call does not return + // normally and this function does not call a returns_twice function, any + // unwinding will jump past this function, so dest will be deallocated before + // it could be accessed. + // + // If the destination is an sret pointer with an appropriate type, accessing the + // first srcSize bytes of dest is well-defined. If the call does not return + // normally and this function does not call a returns_twice function, any + // unwinding will jump past this function, and the caller won't access dest + // because sret values are only defined when a function returns normally. + // + // TODO: The justification for sret is a little dubious... it's relying on + // calling convention details which aren't actually defined in LangRef. + // + // TODO: It might be possible to use the dereferenceable attribute plus + // some way of checking for exceptional control flow to verify a destination. + // + // TODO: Need to actually check for setjmp calls (callsFunctionThatReturnsTwice()). + // See PR27848. if (AllocaInst *A = dyn_cast<AllocaInst>(cpyDest)) { // The destination is an alloca. Check it is larger than srcSize. 
ConstantInt *destArraySize = dyn_cast<ConstantInt>(A->getArraySize()); @@ -815,24 +836,22 @@ if (destSize < srcSize) return false; } else if (Argument *A = dyn_cast<Argument>(cpyDest)) { - if (A->getDereferenceableBytes() < srcSize) { - // If the destination is an sret parameter then only accesses that are - // outside of the returned struct type can trap. - if (!A->hasStructRetAttr()) - return false; - - Type *StructTy = cast<PointerType>(A->getType())->getElementType(); - if (!StructTy->isSized()) { - // The call may never return and hence the copy-instruction may never - // be executed, and therefore it's not safe to say "the destination - // has at least <cpyLen> bytes, as implied by the copy-instruction", - return false; - } + // If the destination is an sret parameter then only accesses that are + // outside of the returned struct type can trap. + if (!A->hasStructRetAttr()) + return false; - uint64_t destSize = DL.getTypeAllocSize(StructTy); - if (destSize < srcSize) - return false; + Type *StructTy = cast<PointerType>(A->getType())->getElementType(); + if (!StructTy->isSized()) { + // The call may never return and hence the copy-instruction may never + // be executed, and therefore it's not safe to say "the destination + // has at least <cpyLen> bytes, as implied by the copy-instruction", + return false; } + + uint64_t destSize = DL.getTypeAllocSize(StructTy); + if (destSize < srcSize) + return false; } else { return false; } Index: test/Transforms/MemCpyOpt/callslot_deref.ll =================================================================== --- test/Transforms/MemCpyOpt/callslot_deref.ll +++ test/Transforms/MemCpyOpt/callslot_deref.ll @@ -4,14 +4,14 @@ declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture, i8* nocapture readonly, i64, i32, i1) unnamed_addr nounwind declare void @llvm.memset.p0i8.i64(i8* nocapture, i8, i64, i32, i1) nounwind -; all bytes of %dst that are touch by the memset are dereferenceable -define void @must_remove_memcpy(i8* noalias nocapture dereferenceable(4096) %dst) { -; CHECK-LABEL: 
@must_remove_memcpy( -; CHECK: call void @llvm.memset.p0i8.i64 -; CHECK-NOT: call void @llvm.memcpy.p0i8.p0i8.i64 +; can't remove memcpy because @bar might throw an exception or longjmp. +define void @cant_remove_memcpy(i8* noalias nocapture dereferenceable(4096) %dst) { +; CHECK-LABEL: @cant_remove_memcpy( +; CHECK: call void @bar +; CHECK: call void @llvm.memcpy.p0i8.p0i8.i64 %src = alloca [4096 x i8], align 1 %p = getelementptr inbounds [4096 x i8], [4096 x i8]* %src, i64 0, i64 0 - call void @llvm.memset.p0i8.i64(i8* %p, i8 0, i64 4096, i32 1, i1 false) + call void @bar(i8* %p) call void @llvm.memcpy.p0i8.p0i8.i64(i8* %dst, i8* %p, i64 4096, i32 1, i1 false) #2 ret void } @@ -28,3 +28,5 @@ call void @llvm.memcpy.p0i8.p0i8.i64(i8* %dst, i8* %p, i64 4096, i32 1, i1 false) #2 ret void } + +declare void @bar(i8* nocapture sret)