Index: lib/Transforms/Scalar/DeadStoreElimination.cpp
===================================================================
--- lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -16,6 +16,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Scalar/DeadStoreElimination.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/Statistic.h"
@@ -34,10 +35,12 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/Local.h"
+#include <map>
 using namespace llvm;
 
 #define DEBUG_TYPE "dse"
@@ -45,6 +48,12 @@
 STATISTIC(NumRedundantStores, "Number of redundant stores deleted");
 STATISTIC(NumFastStores, "Number of stores deleted");
 STATISTIC(NumFastOther , "Number of other instrs removed");
+STATISTIC(NumCompletePartials, "Number of stores dead by later partials");
+
+static cl::opt<bool>
+EnablePartialOverwriteTracking("enable-dse-partial-overwrite-tracking",
+  cl::init(true), cl::Hidden,
+  cl::desc("Enable partial-overwrite tracking in DSE"));
 
 
 //===----------------------------------------------------------------------===//
@@ -279,6 +288,9 @@
 };
 }
 
+typedef DenseMap<Instruction *,
+                 std::map<int64_t, int64_t>> InstOverlapIntervalsTy;
+
 /// Return 'OverwriteComplete' if a store to the 'Later' location completely
 /// overwrites a store to the 'Earlier' location, 'OverwriteEnd' if the end of
 /// the 'Earlier' location is completely overwritten by 'Later',
@@ -288,7 +300,9 @@
                                    const MemoryLocation &Earlier,
                                    const DataLayout &DL,
                                    const TargetLibraryInfo &TLI,
-                                   int64_t &EarlierOff, int64_t &LaterOff) {
+                                   int64_t &EarlierOff, int64_t &LaterOff,
+                                   Instruction *DepWrite,
+                                   InstOverlapIntervalsTy &IOL) {
   const Value *P1 = Earlier.Ptr->stripPointerCasts();
   const Value *P2 = Later.Ptr->stripPointerCasts();
 
@@ -361,6 +375,60 @@
       uint64_t(EarlierOff - LaterOff) + Earlier.Size <= Later.Size)
     return OverwriteComplete;
 
+  // We may now overlap, although the overlap is not complete. There might also
+  // be other incomplete overlaps, and together, they might cover the complete
+  // earlier write.
+  // Note: The correctness of this logic depends on the fact that this function
+  // is not even called providing DepWrite when there are any intervening reads.
+  if (EnablePartialOverwriteTracking &&
+      LaterOff < int64_t(EarlierOff + Earlier.Size) &&
+      int64_t(LaterOff + Later.Size) >= EarlierOff) {
+
+    // Insert our part of the overlap into the map.
+    auto &IM = IOL[DepWrite];
+    DEBUG(dbgs() << "DSE: Partial overwrite: Earlier [" << EarlierOff << ", " <<
+                    int64_t(EarlierOff + Earlier.Size) << ") Later [" <<
+                    LaterOff << ", " << int64_t(LaterOff + Later.Size) << ")\n");
+
+    // Make sure that we only insert non-overlapping intervals and combine
+    // adjacent intervals. The intervals are stored in the map with the ending
+    // offset as the key (in the half-open sense) and the starting offset as
+    // the value.
+    int64_t LaterIntStart = LaterOff, LaterIntEnd = LaterOff + Later.Size;
+
+    // Find any intervals ending at, or after, LaterIntStart which start
+    // before LaterIntEnd.
+    auto ILI = IM.lower_bound(LaterIntStart);
+    if (ILI != IM.end() && ILI->second < LaterIntEnd) {
+      // This existing interval ends in the middle of
+      // [LaterIntStart, LaterIntEnd), erase it adjusting our start.
+      LaterIntStart = std::min(LaterIntStart, ILI->second);
+      ILI = IM.erase(ILI);
+
+      while (ILI != IM.end() && ILI->first <= LaterIntEnd)
+        ILI = IM.erase(ILI);
+
+      if (ILI != IM.end() && ILI->second < LaterIntEnd) {
+        LaterIntEnd = std::max(LaterIntEnd, ILI->first);
+        ILI = IM.erase(ILI);
+      }
+    }
+
+    IM[LaterIntEnd] = LaterIntStart;
+
+    ILI = IM.begin();
+    if (ILI->second <= EarlierOff &&
+        ILI->first >= int64_t(EarlierOff + Earlier.Size)) {
+      DEBUG(dbgs() << "DSE: Full overwrite from partials: Earlier [" <<
+                      EarlierOff << ", " <<
+                      int64_t(EarlierOff + Earlier.Size) <<
+                      ") Composite Later [" <<
+                      ILI->second << ", " << ILI->first << ")\n");
+      ++NumCompletePartials;
+      return OverwriteComplete;
+    }
+  }
+
   // Another interesting case is if the later store overwrites the end of the
   // earlier store.
   //
@@ -751,6 +819,9 @@
   const DataLayout &DL = BB.getModule()->getDataLayout();
   bool MadeChange = false;
 
+  // A map of interval maps representing partially-overwritten value parts.
+  InstOverlapIntervalsTy IOL;
+
   // Do a top-down walk on the BB.
   for (BasicBlock::iterator BBI = BB.begin(), BBE = BB.end(); BBI != BBE; ) {
     Instruction *Inst = &*BBI++;
@@ -852,7 +923,8 @@
           !isPossibleSelfRead(Inst, Loc, DepWrite, *TLI, *AA)) {
         int64_t InstWriteOffset, DepWriteOffset;
         OverwriteResult OR =
-            isOverwrite(Loc, DepLoc, DL, *TLI, DepWriteOffset, InstWriteOffset);
+            isOverwrite(Loc, DepLoc, DL, *TLI, DepWriteOffset, InstWriteOffset,
+                        DepWrite, IOL);
         if (OR == OverwriteComplete) {
           DEBUG(dbgs() << "DSE: Remove Dead Store:\n  DEAD: " << *DepWrite
                        << "\n  KILLER: " << *Inst << '\n');
Index: test/Transforms/DeadStoreElimination/combined-partial-overwrites.ll
===================================================================
--- /dev/null
+++ test/Transforms/DeadStoreElimination/combined-partial-overwrites.ll
@@ -0,0 +1,141 @@
+; RUN: opt -S -dse < %s | FileCheck %s
+target datalayout = "E-m:e-i64:64-n32:64"
+target triple = "powerpc64-bgq-linux"
+
+%"struct.std::complex" = type { { float, float } }
+
+define void @_Z4testSt7complexIfE(%"struct.std::complex"* noalias nocapture sret %agg.result, i64 %c.coerce) {
+entry:
+; CHECK-LABEL: @_Z4testSt7complexIfE
+
+  %ref.tmp = alloca i64, align 8
+  %tmpcast = bitcast i64* %ref.tmp to %"struct.std::complex"*
+  %c.sroa.0.0.extract.shift = lshr i64 %c.coerce, 32
+  %c.sroa.0.0.extract.trunc = trunc i64 %c.sroa.0.0.extract.shift to i32
+  %0 = bitcast i32 %c.sroa.0.0.extract.trunc to float
+  %c.sroa.2.0.extract.trunc = trunc i64 %c.coerce to i32
+  %1 = bitcast i32 %c.sroa.2.0.extract.trunc to float
+  call void @_Z3barSt7complexIfE(%"struct.std::complex"* nonnull sret %tmpcast, i64 %c.coerce)
+  %2 = bitcast %"struct.std::complex"* %agg.result to i64*
+  %3 = load i64, i64* %ref.tmp, align 8
+  store i64 %3, i64* %2, align 4
+; CHECK-NOT: store i64
+
+  %_M_value.realp.i.i = getelementptr inbounds %"struct.std::complex", %"struct.std::complex"* %agg.result, i64 0, i32 0, i32 0
+  %4 = lshr i64 %3, 32
+  %5 = trunc i64 %4 to i32
+  %6 = bitcast i32 %5 to float
+  %_M_value.imagp.i.i = getelementptr inbounds %"struct.std::complex", %"struct.std::complex"* %agg.result, i64 0, i32 0, i32 1
+  %7 = trunc i64 %3 to i32
+  %8 = bitcast i32 %7 to float
+  %mul_ad.i.i = fmul fast float %6, %1
+  %mul_bc.i.i = fmul fast float %8, %0
+  %mul_i.i.i = fadd fast float %mul_ad.i.i, %mul_bc.i.i
+  %mul_ac.i.i = fmul fast float %6, %0
+  %mul_bd.i.i = fmul fast float %8, %1
+  %mul_r.i.i = fsub fast float %mul_ac.i.i, %mul_bd.i.i
+  store float %mul_r.i.i, float* %_M_value.realp.i.i, align 4
+  store float %mul_i.i.i, float* %_M_value.imagp.i.i, align 4
+  ret void
+; CHECK: ret void
+}
+
+declare void @_Z3barSt7complexIfE(%"struct.std::complex"* sret, i64)
+
+define void @test1(i32 *%ptr) {
+entry:
+; CHECK-LABEL: @test1
+
+  store i32 5, i32* %ptr
+  %bptr = bitcast i32* %ptr to i8*
+  store i8 7, i8* %bptr
+  %wptr = bitcast i32* %ptr to i16*
+  store i16 -30062, i16* %wptr
+  %bptr2 = getelementptr inbounds i8, i8* %bptr, i64 2
+  store i8 25, i8* %bptr2
+  %bptr3 = getelementptr inbounds i8, i8* %bptr, i64 3
+  store i8 47, i8* %bptr3
+  %bptr1 = getelementptr inbounds i8, i8* %bptr, i64 1
+  %wptrp = bitcast i8* %bptr1 to i16*
+  store i16 2020, i16* %wptrp, align 1
+  ret void
+
+; CHECK-NOT: store i32 5, i32* %ptr
+; CHECK-NOT: store i8 7, i8* %bptr
+; CHECK: store i16 -30062, i16* %wptr
+; CHECK-NOT: store i8 25, i8* %bptr2
+; CHECK: store i8 47, i8* %bptr3
+; CHECK: store i16 2020, i16* %wptrp, align 1
+
+; CHECK: ret void
+}
+
+define void @test2(i32 *%ptr) {
+entry:
+; CHECK-LABEL: @test2
+
+  store i32 5, i32* %ptr
+
+  %bptr = bitcast i32* %ptr to i8*
+  %bptrm1 = getelementptr inbounds i8, i8* %bptr, i64 -1
+  %bptr1 = getelementptr inbounds i8, i8* %bptr, i64 1
+  %bptr2 = getelementptr inbounds i8, i8* %bptr, i64 2
+  %bptr3 = getelementptr inbounds i8, i8* %bptr, i64 3
+
+  %wptr = bitcast i8* %bptr to i16*
+  %wptrm1 = bitcast i8* %bptrm1 to i16*
+  %wptr1 = bitcast i8* %bptr1 to i16*
+  %wptr2 = bitcast i8* %bptr2 to i16*
+  %wptr3 = bitcast i8* %bptr3 to i16*
+
+  store i16 1456, i16* %wptrm1, align 1
+  store i16 1346, i16* %wptr, align 1
+  store i16 1756, i16* %wptr1, align 1
+  store i16 1126, i16* %wptr2, align 1
+  store i16 5656, i16* %wptr3, align 1
+
+; CHECK-NOT: store i32 5, i32* %ptr
+
+; CHECK: store i16 1456, i16* %wptrm1, align 1
+; CHECK: store i16 1346, i16* %wptr, align 1
+; CHECK: store i16 1756, i16* %wptr1, align 1
+; CHECK: store i16 1126, i16* %wptr2, align 1
+; CHECK: store i16 5656, i16* %wptr3, align 1
+
+  ret void
+
+; CHECK: ret void
+}
+
+define signext i8 @test3(i32 *%ptr) {
+entry:
+; CHECK-LABEL: @test3
+
+  store i32 5, i32* %ptr
+
+  %bptr = bitcast i32* %ptr to i8*
+  %bptrm1 = getelementptr inbounds i8, i8* %bptr, i64 -1
+  %bptr1 = getelementptr inbounds i8, i8* %bptr, i64 1
+  %bptr2 = getelementptr inbounds i8, i8* %bptr, i64 2
+  %bptr3 = getelementptr inbounds i8, i8* %bptr, i64 3
+
+  %wptr = bitcast i8* %bptr to i16*
+  %wptrm1 = bitcast i8* %bptrm1 to i16*
+  %wptr1 = bitcast i8* %bptr1 to i16*
+  %wptr2 = bitcast i8* %bptr2 to i16*
+  %wptr3 = bitcast i8* %bptr3 to i16*
+
+  %v = load i8, i8* %bptr, align 1
+  store i16 1456, i16* %wptrm1, align 1
+  store i16 1346, i16* %wptr, align 1
+  store i16 1756, i16* %wptr1, align 1
+  store i16 1126, i16* %wptr2, align 1
+  store i16 5656, i16* %wptr3, align 1
+
+; CHECK: store i32 5, i32* %ptr
+
+  ret i8 %v
+
+; CHECK: ret i8 %v
+}
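
Note (not part of the patch): the core of the change is the per-DepWrite interval map (IOL). Every later store that partially covers an earlier store records the half-open byte range it overwrites. Ranges are kept in a std::map keyed by the interval's end offset, with the start offset as the mapped value, so that map::lower_bound(Start) lands on the first recorded interval that could overlap or abut a new one. Once a single coalesced interval spans the whole earlier store, that store is dead. The sketch below is a standalone illustration of that bookkeeping, not code from the patch: addInterval and coversRange are invented names, and the merge is written as one loop rather than the three steps used in isOverwrite above.

// Standalone sketch of the interval-map bookkeeping (illustrative only).
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <map>

// Half-open intervals [start, end): key = end offset, value = start offset.
using IntervalMap = std::map<std::int64_t, std::int64_t>;

// Record that bytes [Start, End) of the earlier store were overwritten,
// coalescing with every previously recorded overlapping or adjacent interval.
static void addInterval(IntervalMap &IM, std::int64_t Start,
                        std::int64_t End) {
  // The first interval whose end is >= Start; everything before it lies
  // strictly to our left and cannot overlap or touch [Start, End).
  auto ILI = IM.lower_bound(Start);
  while (ILI != IM.end() && ILI->second <= End) {
    // Overlapping or adjacent: absorb it and keep scanning rightwards.
    Start = std::min(Start, ILI->second);
    End = std::max(End, ILI->first);
    ILI = IM.erase(ILI);
  }
  IM[End] = Start;
}

// True when a single coalesced interval covers all of [Off, Off + Size),
// mirroring the "full overwrite from partials" check in isOverwrite.
static bool coversRange(const IntervalMap &IM, std::int64_t Off,
                        std::int64_t Size) {
  auto ILI = IM.begin();
  return ILI != IM.end() && ILI->second <= Off && ILI->first >= Off + Size;
}

int main() {
  // Byte ranges from @test1: the earlier 'store i32' covers [0, 4); three of
  // the later, narrower stores cover [0, 2), [1, 3), and [3, 4) between them.
  IntervalMap IM;
  addInterval(IM, 0, 2);
  addInterval(IM, 1, 3); // overlaps [0, 2) -> coalesces to [0, 3)
  addInterval(IM, 3, 4); // adjacent to [0, 3) -> coalesces to [0, 4)
  std::cout << (coversRange(IM, 0, 4) ? "earlier store is dead\n"
                                      : "earlier store is live\n");
  return 0;
}

As the comment in isOverwrite stresses, accumulating intervals this way is only sound because the function is never handed DepWrite when there is an intervening read of the location; @test3 exercises exactly that case, where the load keeps the 'store i32' alive.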