Index: lib/Transforms/Scalar/DeadStoreElimination.cpp
===================================================================
--- lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -275,9 +275,9 @@
 }
 
 
-/// isShortenable - Returns true if this instruction can be safely shortened in
-/// length.
-static bool isShortenable(Instruction *I) {
+/// isShortenableAtTheEnd - Returns true if the end of this instruction can be
+/// safely shortened in length.
+static bool isShortenableAtTheEnd(Instruction *I) {
   // Don't shorten stores for now
   if (isa<StoreInst>(I))
     return false;
@@ -297,6 +297,15 @@
   return false;
 }
 
+/// isShortenableAtTheBeginning - Returns true if the beginning of this
+/// instruction can be safely shortened in length.
+static bool isShortenableAtTheBeginning(Instruction *I) {
+  // FIXME: Handle only memset for now. Supporting memcpy should be easily done
+  // by offsetting the source address.
+  IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
+  return II && II->getIntrinsicID() == Intrinsic::memset;
+}
+
 /// getStoredPointerOperand - Return the pointer that is being written to.
 static Value *getStoredPointerOperand(Instruction *I) {
   if (StoreInst *SI = dyn_cast<StoreInst>(I))
@@ -327,18 +336,19 @@
 }
 
 namespace {
-  enum OverwriteResult
-  {
-    OverwriteComplete,
-    OverwriteEnd,
-    OverwriteUnknown
-  };
+enum OverwriteResult {
+  OverwriteComplete,
+  OverwriteBegin,
+  OverwriteEnd,
+  OverwriteUnknown
+};
 }
 
 /// isOverwrite - Return 'OverwriteComplete' if a store to the 'Later' location
-/// completely overwrites a store to the 'Earlier' location.
-/// 'OverwriteEnd' if the end of the 'Earlier' location is completely
-/// overwritten by 'Later', or 'OverwriteUnknown' if nothing can be determined
+/// completely overwrites a store to the 'Earlier' location. 'OverwriteEnd' if
+/// the end of the 'Earlier' location is completely overwritten by 'Later',
+/// 'OverwriteBegin' if the beginning of the 'Earlier' location is overwritten
+/// by 'Later', or 'OverwriteUnknown' if nothing can be determined.
 static OverwriteResult isOverwrite(const MemoryLocation &Later,
                                    const MemoryLocation &Earlier,
                                    const DataLayout &DL,
@@ -429,7 +439,19 @@
       int64_t(LaterOff + Later.Size) >= int64_t(EarlierOff + Earlier.Size))
     return OverwriteEnd;
 
-  // Otherwise, they don't completely overlap.
+  // We also need to check if the later store overwrites the beginning of
+  // the earlier store.
+  //
+  //                |--earlier--|
+  //  |--   later   --|
+  //
+  // In this case we may want to move the destination address and trim the size
+  // of earlier to avoid generating writes to addresses which will definitely
+  // be overwritten later.
+  if (LaterOff <= EarlierOff && int64_t(LaterOff + Later.Size) > EarlierOff &&
+      int64_t(LaterOff + Later.Size) < int64_t(EarlierOff + Earlier.Size))
+    return OverwriteBegin;
+
   return OverwriteUnknown;
 }
 
@@ -603,29 +625,60 @@
           if (BBI != BB.begin())
             --BBI;
           break;
-        } else if (OR == OverwriteEnd && isShortenable(DepWrite)) {
+        } else if ((OR == OverwriteEnd && isShortenableAtTheEnd(DepWrite)) ||
+                   (OR == OverwriteBegin &&
+                    isShortenableAtTheBeginning(DepWrite))) {
           // TODO: base this on the target vector size so that if the earlier
           // store was too small to get vector writes anyway then its likely
           // a good idea to shorten it
           // Power of 2 vector writes are probably always a bad idea to optimize
           // as any store/memset/memcpy is likely using vector instructions so
           // shortening it to not vector size is likely to be slower
-          MemIntrinsic* DepIntrinsic = cast<MemIntrinsic>(DepWrite);
+          MemIntrinsic *DepIntrinsic = cast<MemIntrinsic>(DepWrite);
           unsigned DepWriteAlign = DepIntrinsic->getAlignment();
-          if (llvm::isPowerOf2_64(InstWriteOffset) ||
-              ((DepWriteAlign != 0) && InstWriteOffset % DepWriteAlign == 0)) {
+          bool IsOverwriteEnd = OR == OverwriteEnd;
+          if (!IsOverwriteEnd)
+            InstWriteOffset = int64_t(InstWriteOffset + Loc.Size);
+
+          // Trimming the beginning advances the destination by
+          // (InstWriteOffset - DepWriteOffset) bytes but keeps the intrinsic's
+          // alignment argument, so that distance must itself be a multiple of
+          // DepWriteAlign; InstWriteOffset being aligned is not sufficient
+          // because DepWriteOffset need not be a multiple of DepWriteAlign.
+          bool PreservesAlignment =
+              IsOverwriteEnd || DepWriteAlign == 0 ||
+              (InstWriteOffset - DepWriteOffset) % DepWriteAlign == 0;
+
+          if (PreservesAlignment &&
+              ((llvm::isPowerOf2_64(InstWriteOffset) &&
+                DepWriteAlign <= InstWriteOffset) ||
+               ((DepWriteAlign != 0) &&
+                InstWriteOffset % DepWriteAlign == 0))) {
 
-            DEBUG(dbgs() << "DSE: Remove Dead Store:\n  OW END: "
-                  << *DepWrite << "\n  KILLER (offset "
-                  << InstWriteOffset << ", "
-                  << DepLoc.Size << ")"
-                  << *Inst << '\n');
+            DEBUG(dbgs() << "DSE: Remove Dead Store:\n  OW "
+                         << (IsOverwriteEnd ? "END" : "BEGIN") << ": "
+                         << *DepWrite << "\n  KILLER (offset "
+                         << InstWriteOffset << ", " << DepLoc.Size << ")"
+                         << *Inst << '\n');
 
-            Value* DepWriteLength = DepIntrinsic->getLength();
-            Value* TrimmedLength = ConstantInt::get(DepWriteLength->getType(),
-                                                    InstWriteOffset -
-                                                    DepWriteOffset);
+            int64_t NewLength =
+                IsOverwriteEnd
+                    ? InstWriteOffset - DepWriteOffset
+                    : DepLoc.Size - (InstWriteOffset - DepWriteOffset);
+
+            Value *DepWriteLength = DepIntrinsic->getLength();
+            Value *TrimmedLength =
+                ConstantInt::get(DepWriteLength->getType(), NewLength);
             DepIntrinsic->setLength(TrimmedLength);
+
+            if (!IsOverwriteEnd) {
+              int64_t OffsetMoved = (InstWriteOffset - DepWriteOffset);
+              Value *Indices[1] = {
+                  ConstantInt::get(DepWriteLength->getType(), OffsetMoved)};
+              GetElementPtrInst *NewDestGEP = GetElementPtrInst::CreateInBounds(
+                  DepIntrinsic->getRawDest(), Indices, "", DepWrite);
+              DepIntrinsic->setDest(NewDestGEP);
+            }
             MadeChange = true;
           }
         }
Index: test/Transforms/DeadStoreElimination/OverwriteStoreBegin.ll
===================================================================
--- /dev/null
+++ test/Transforms/DeadStoreElimination/OverwriteStoreBegin.ll
@@ -0,0 +1,100 @@
+; RUN: opt < %s -basicaa -dse -S | FileCheck %s
+
+define void @write4to8(i32* nocapture %p) {
+; CHECK-LABEL: @write4to8(
+entry:
+  %arrayidx0 = getelementptr inbounds i32, i32* %p, i64 1
+  %p3 = bitcast i32* %arrayidx0 to i8*
+; CHECK: [[GEP:%[0-9]+]] = getelementptr inbounds i8, i8* %p3, i64 4
+; CHECK: call void @llvm.memset.p0i8.i64(i8* [[GEP]], i8 0, i64 24, i32 4, i1 false)
+  call void @llvm.memset.p0i8.i64(i8* %p3, i8 0, i64 28, i32 4, i1 false)
+  %arrayidx1 = getelementptr inbounds i32, i32* %p, i64 1
+  store i32 1, i32* %arrayidx1, align 4
+  ret void
+}
+
+define void @write0to4(i32* nocapture %p) {
+; CHECK-LABEL: @write0to4(
+entry:
+  %p3 = bitcast i32* %p to i8*
+; CHECK: [[GEP:%[0-9]+]] = getelementptr inbounds i8, i8* %p3, i64 4
+; CHECK: call void @llvm.memset.p0i8.i64(i8* [[GEP]], i8 0, i64 24, i32 4, i1 false)
+  call void @llvm.memset.p0i8.i64(i8* %p3, i8 0, i64 28, i32 4, i1 false)
+  store i32 1, i32* %p, align 4
+  ret void
+}
+
+define void @write0to8(i32* nocapture %p) {
+; CHECK-LABEL: @write0to8(
+entry:
+  %p3 = bitcast i32* %p to i8*
+; CHECK: [[GEP:%[0-9]+]] = getelementptr inbounds i8, i8* %p3, i64 8
+; CHECK: call void @llvm.memset.p0i8.i64(i8* [[GEP]], i8 0, i64 24, i32 4, i1 false)
+  call void @llvm.memset.p0i8.i64(i8* %p3, i8 0, i64 32, i32 4, i1 false)
+  %p4 = bitcast i32* %p to i64*
+  store i64 1, i64* %p4, align 8
+  ret void
+}
+
+define void @write0to8_2(i32* nocapture %p) {
+; CHECK-LABEL: @write0to8_2(
+entry:
+  %arrayidx0 = getelementptr inbounds i32, i32* %p, i64 1
+  %p3 = bitcast i32* %arrayidx0 to i8*
+; CHECK: [[GEP:%[0-9]+]] = getelementptr inbounds i8, i8* %p3, i64 4
+; CHECK: call void @llvm.memset.p0i8.i64(i8* [[GEP]], i8 0, i64 24, i32 4, i1 false)
+  call void @llvm.memset.p0i8.i64(i8* %p3, i8 0, i64 28, i32 4, i1 false)
+  %p4 = bitcast i32* %p to i64*
+  store i64 1, i64* %p4, align 8
+  ret void
+}
+
+define void @dontwrite0to4_align8(i32* nocapture %p) {
+; CHECK-LABEL: @dontwrite0to4_align8(
+entry:
+  %p3 = bitcast i32* %p to i8*
+; CHECK: call void @llvm.memset.p0i8.i64(i8* %p3, i8 0, i64 32, i32 8, i1 false)
+  call void @llvm.memset.p0i8.i64(i8* %p3, i8 0, i64 32, i32 8, i1 false)
+  store i32 1, i32* %p, align 4
+  ret void
+}
+
+define void @dontwrite0to2(i32* nocapture %p) {
+; CHECK-LABEL: @dontwrite0to2(
+entry:
+  %p3 = bitcast i32* %p to i8*
+; CHECK: call void @llvm.memset.p0i8.i64(i8* %p3, i8 0, i64 32, i32 4, i1 false)
+  call void @llvm.memset.p0i8.i64(i8* %p3, i8 0, i64 32, i32 4, i1 false)
+  %p4 = bitcast i32* %p to i16*
+  store i16 1, i16* %p4, align 4
+  ret void
+}
+
+define void @dontwrite2to10(i32* nocapture %p) {
+; CHECK-LABEL: @dontwrite2to10(
+entry:
+  %arrayidx0 = getelementptr inbounds i32, i32* %p, i64 1
+  %p3 = bitcast i32* %arrayidx0 to i8*
+; CHECK: call void @llvm.memset.p0i8.i64(i8* %p3, i8 0, i64 32, i32 4, i1 false)
+  call void @llvm.memset.p0i8.i64(i8* %p3, i8 0, i64 32, i32 4, i1 false)
+  %p4 = bitcast i32* %p to i16*
+  %arrayidx2 = getelementptr inbounds i16, i16* %p4, i64 1
+  %p5 = bitcast i16* %arrayidx2 to i64*
+  store i64 1, i64* %p5, align 8
+  ret void
+}
+
+define void @dontwrite0to8_misaligned(i8* nocapture %p) {
+; CHECK-LABEL: @dontwrite0to8_misaligned(
+entry:
+  %p2 = getelementptr inbounds i8, i8* %p, i64 2
+; Trimming the first 6 bytes would leave the memset only 2-byte aligned.
+; CHECK: call void @llvm.memset.p0i8.i64(i8* %p2, i8 0, i64 30, i32 4, i1 false)
+  call void @llvm.memset.p0i8.i64(i8* %p2, i8 0, i64 30, i32 4, i1 false)
+  %p8 = bitcast i8* %p to i64*
+  store i64 1, i64* %p8, align 2
+  ret void
+}
+
+declare void @llvm.memset.p0i8.i64(i8* nocapture, i8, i64, i32, i1) nounwind
+