[MemCpyOpt] Check for throwing instructions between call and copy

Call slot optimization makes the call C write directly into the copy
destination instead of into a temporary alloca. If that destination is
not a local alloca, the store to it may never happen if an exception is
thrown between the call and the copy, so writing it early at the call is
only valid if nothing in [C, copy store) can throw.

Previously this was checked only in the load/store path of processStore
(plus a C->mayThrow() check limited to Argument destinations), so the
memcpy path could still move a write past a throwing call. Move the check
into performCallSlotOptzn so both paths share it. To do so,
performCallSlotOptzn now receives both the load-side and the store-side
instruction of the copy; for a memcpy both are the memcpy itself.

The callslot.ll test @throw_between_call_and_mempy is updated: the memset
is no longer forwarded to %dest across @may_throw.

Index: llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h
===================================================================
--- llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h
+++ llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h
@@ -61,8 +61,9 @@
   bool processMemSet(MemSetInst *SI, BasicBlock::iterator &BBI);
   bool processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI);
   bool processMemMove(MemMoveInst *M);
-  bool performCallSlotOptzn(Instruction *cpy, Value *cpyDst, Value *cpySrc,
-                            uint64_t cpyLen, Align cpyAlign, CallInst *C);
+  bool performCallSlotOptzn(Instruction *cpyLoad, Instruction *cpyStore,
+                            Value *cpyDst, Value *cpySrc, uint64_t cpyLen,
+                            Align cpyAlign, CallInst *C);
   bool processMemCpyMemCpyDependence(MemCpyInst *M, MemCpyInst *MDep);
   bool processMemSetMemCpyDependence(MemCpyInst *MemCpy, MemSetInst *MemSet);
   bool performMemCpyToMemSetOptzn(MemCpyInst *MemCpy, MemSetInst *MemSet);
Index: llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
+++ llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
@@ -658,8 +658,6 @@
       if (C) {
         // Check that nothing touches the dest of the "copy" between
         // the call and the store.
-        Value *CpyDest = SI->getPointerOperand()->stripPointerCasts();
-        bool CpyDestIsLocal = isa<AllocaInst>(CpyDest);
         MemoryLocation StoreLoc = MemoryLocation::get(SI);
         for (BasicBlock::iterator I = --SI->getIterator(), E = C->getIterator();
              I != E; --I) {
@@ -667,18 +665,12 @@
             C = nullptr;
             break;
           }
-          // The store to dest may never happen if an exception can be thrown
-          // between the load and the store.
-          if (I->mayThrow() && !CpyDestIsLocal) {
-            C = nullptr;
-            break;
-          }
         }
       }

       if (C) {
         bool changed = performCallSlotOptzn(
-            LI, SI->getPointerOperand()->stripPointerCasts(),
+            LI, SI, SI->getPointerOperand()->stripPointerCasts(),
             LI->getPointerOperand()->stripPointerCasts(),
             DL.getTypeStoreSize(SI->getOperand(0)->getType()),
             commonAlignment(SI->getAlign(), LI->getAlign()), C);
@@ -753,7 +745,8 @@
 /// Takes a memcpy and a call that it depends on,
 /// and checks for the possibility of a call slot optimization by having
 /// the call write its result directly into the destination of the memcpy.
-bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest,
+bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad,
+                                         Instruction *cpyStore, Value *cpyDest,
                                          Value *cpySrc, uint64_t cpyLen,
                                          Align cpyAlign, CallInst *C) {
   // The general transformation to keep in mind is
@@ -784,7 +777,7 @@
   if (!srcArraySize)
     return false;

-  const DataLayout &DL = cpy->getModule()->getDataLayout();
+  const DataLayout &DL = cpyLoad->getModule()->getDataLayout();
   uint64_t srcSize = DL.getTypeAllocSize(srcAlloca->getAllocatedType()) *
                      srcArraySize->getZExtValue();

@@ -794,6 +787,7 @@
   // Check that accessing the first srcSize bytes of dest will not cause a
   // trap.  Otherwise the transform is invalid since it might cause a trap
   // to occur earlier than it otherwise would.
+  // TODO: Use isDereferenceablePointer() API instead.
   if (AllocaInst *A = dyn_cast<AllocaInst>(cpyDest)) {
     // The destination is an alloca.  Check it is larger than srcSize.
     ConstantInt *destArraySize = dyn_cast<ConstantInt>(A->getArraySize());
@@ -806,10 +800,6 @@
     if (destSize < srcSize)
       return false;
   } else if (Argument *A = dyn_cast<Argument>(cpyDest)) {
-    // The store to dest may never happen if the call can throw.
-    if (C->mayThrow())
-      return false;
-
     if (A->getDereferenceableBytes() < srcSize) {
       // If the destination is an sret parameter then only accesses that are
       // outside of the returned struct type can trap.
@@ -832,6 +822,18 @@
     return false;
   }

+  // If the destination is not local, check that nothing between the call and
+  // the copy (including the call itself) can throw.
+  if (!isa<AllocaInst>(cpyDest)) {
+    assert(C->getParent() == cpyStore->getParent() &&
+           "call and copy must be in the same block");
+    for (const Instruction &I : make_range(C->getIterator(),
+                                           cpyStore->getIterator())) {
+      if (I.mayThrow())
+        return false;
+    }
+  }
+
   // Check that dest points to memory that is at least as aligned as src.
   Align srcAlign = srcAlloca->getAlign();
   bool isDestSufficientlyAligned = srcAlign <= cpyAlign;
@@ -866,7 +868,7 @@
       if (IT->isLifetimeStartOrEnd())
         continue;

-    if (U != C && U != cpy)
+    if (U != C && U != cpyLoad)
       return false;
   }

@@ -940,7 +942,7 @@
                          LLVMContext::MD_noalias,
                          LLVMContext::MD_invariant_group,
                          LLVMContext::MD_access_group};
-  combineMetadata(C, cpy, KnownIDs, true);
+  combineMetadata(C, cpyLoad, KnownIDs, true);

   return true;
 }
@@ -1240,7 +1242,7 @@
       // of conservatively taking the minimum?
       Align Alignment = std::min(M->getDestAlign().valueOrOne(),
                                  M->getSourceAlign().valueOrOne());
-      if (performCallSlotOptzn(M, M->getDest(), M->getSource(),
+      if (performCallSlotOptzn(M, M, M->getDest(), M->getSource(),
                                CopySize->getZExtValue(), Alignment, C)) {
         eraseInstruction(M);
         ++NumMemCpyInstr;
Index: llvm/test/Transforms/MemCpyOpt/callslot.ll
===================================================================
--- llvm/test/Transforms/MemCpyOpt/callslot.ll
+++ llvm/test/Transforms/MemCpyOpt/callslot.ll
@@ -75,8 +75,9 @@
 define void @throw_between_call_and_mempy(i8* dereferenceable(16) %dest) {
 ; CHECK-LABEL: @throw_between_call_and_mempy(
 ; CHECK-NEXT:    [[SRC:%.*]] = alloca i8, i64 16, align 1
-; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* [[DEST:%.*]], i8 0, i64 16, i1 false)
+; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* [[SRC]], i8 0, i64 16, i1 false)
 ; CHECK-NEXT:    call void @may_throw() [[ATTR2:#.*]]
+; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* [[DEST:%.*]], i8 0, i64 16, i1 false)
 ; CHECK-NEXT:    ret void
 ;
   %src = alloca i8, i64 16