Index: include/llvm/Analysis/TargetTransformInfo.h
===================================================================
--- include/llvm/Analysis/TargetTransformInfo.h
+++ include/llvm/Analysis/TargetTransformInfo.h
@@ -459,6 +459,10 @@
   /// and the number of execution units in the CPU.
   unsigned getMaxInterleaveFactor(unsigned VF) const;
 
+  /// \return The maximum number of store operations permitted to replace a
+  /// call to llvm.memset in function \p F.
+  unsigned getMaxStoresPerMemset(const Function &F) const;
+
   /// \return The expected cost of arithmetic ops, such as mul, xor, fsub, etc.
   int getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, OperandValueKind Opd1Info = OK_AnyValue,
@@ -659,6 +663,7 @@
   virtual unsigned getMinPrefetchStride() = 0;
   virtual unsigned getMaxPrefetchIterationsAhead() = 0;
   virtual unsigned getMaxInterleaveFactor(unsigned VF) = 0;
+  virtual unsigned getMaxStoresPerMemset(const Function &F) = 0;
   virtual unsigned
   getArithmeticInstrCost(unsigned Opcode, Type *Ty, OperandValueKind Opd1Info,
                          OperandValueKind Opd2Info,
@@ -847,6 +852,9 @@
   unsigned getMaxInterleaveFactor(unsigned VF) override {
     return Impl.getMaxInterleaveFactor(VF);
   }
+  unsigned getMaxStoresPerMemset(const Function &F) override {
+    return Impl.getMaxStoresPerMemset(F);
+  }
   unsigned
   getArithmeticInstrCost(unsigned Opcode, Type *Ty, OperandValueKind Opd1Info,
                          OperandValueKind Opd2Info,
Index: include/llvm/Analysis/TargetTransformInfoImpl.h
===================================================================
--- include/llvm/Analysis/TargetTransformInfoImpl.h
+++ include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -278,6 +278,8 @@
 
   unsigned getMaxInterleaveFactor(unsigned VF) { return 1; }
 
+  unsigned getMaxStoresPerMemset(const Function &F) { return 0; }
+
   unsigned getArithmeticInstrCost(unsigned Opcode, Type *Ty,
                                   TTI::OperandValueKind Opd1Info,
                                   TTI::OperandValueKind Opd2Info,
Index: include/llvm/CodeGen/BasicTTIImpl.h
===================================================================
--- include/llvm/CodeGen/BasicTTIImpl.h
+++ include/llvm/CodeGen/BasicTTIImpl.h
@@ -284,6 +284,11 @@
 
   unsigned getMaxInterleaveFactor(unsigned VF) { return 1; }
 
+  unsigned getMaxStoresPerMemset(const Function &F) {
+    bool OptSize = F.optForSize() || F.optForMinSize();
+    return getTLI()->getMaxStoresPerMemset(OptSize);
+  }
+
   unsigned getArithmeticInstrCost(
       unsigned Opcode, Type *Ty,
       TTI::OperandValueKind Opd1Info = TTI::OK_AnyValue,
Index: lib/Analysis/TargetTransformInfo.cpp
===================================================================
--- lib/Analysis/TargetTransformInfo.cpp
+++ lib/Analysis/TargetTransformInfo.cpp
@@ -243,6 +243,10 @@
   return TTIImpl->getMaxInterleaveFactor(VF);
 }
 
+unsigned TargetTransformInfo::getMaxStoresPerMemset(const Function &F) const {
+  return TTIImpl->getMaxStoresPerMemset(F);
+}
+
 int TargetTransformInfo::getArithmeticInstrCost(
     unsigned Opcode, Type *Ty, OperandValueKind Opd1Info,
     OperandValueKind Opd2Info, OperandValueProperties Opd1PropInfo,
Index: lib/Transforms/Scalar/DeadStoreElimination.cpp
===================================================================
--- lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -25,6 +25,7 @@
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/Analysis/MemoryDependenceAnalysis.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
@@ -208,6 +209,104 @@
   return false;
 }
 
+static unsigned getExpectedNumStores(unsigned SizeInBytes,
+                                     unsigned MaxIntSizeInBytes) {
+  // Since we don't have perfect knowledge here, assume that the maximum GPR
+  // width is the same size as the largest legal integer size.
+  unsigned NumWideStores = (SizeInBytes / MaxIntSizeInBytes);
+  // Conservatively assume the remaining bytes are stored one byte at a time.
+  unsigned NumNarrowStores = (SizeInBytes % MaxIntSizeInBytes);
+  return NumWideStores + NumNarrowStores;
+}
+
+/// Return true if splitting memset into two parts is profitable in terms of
+/// the number of stores when lowered from the memset.
+static bool isSplittingProfitable(Instruction *EarlierStore,
+                                  int64_t NewLengthPart1,
+                                  int64_t NewLengthPart2, const DataLayout &DL,
+                                  const TargetTransformInfo *TTI) {
+  // FIXME: Split only memset for now. Supporting memcpy/memmove is also
+  // possible.
+  MemSetInst *MSI = dyn_cast<MemSetInst>(EarlierStore);
+  if (!MSI)
+    return false;
+
+  unsigned MaxIntSize = DL.getLargestLegalIntTypeSizeInBits() / 8;
+  // Return false if we don't have information about the legal integer size.
+  if (MaxIntSize == 0)
+    return false;
+
+  unsigned OrigLength = cast<ConstantInt>(MSI->getLength())->getZExtValue();
+  // Return false if the memset size is equal to the largest integer size
+  // because it's going to be only one store, or if the memset size is
+  // smaller than the largest integer size because in this case, the total
+  // number of narrow stores could be different depending on backend.
+  if (OrigLength <= MaxIntSize)
+    return false;
+
+  // Check if the existing memset is small enough to be lowered to stores later.
+  unsigned NumStoresInOrigMemset = getExpectedNumStores(OrigLength, MaxIntSize);
+  if (TTI->getMaxStoresPerMemset(*(MSI->getParent()->getParent())) <
+      NumStoresInOrigMemset)
+    return false;
+
+  // Make sure that neither part is smaller than the largest integer size.
+  if (NewLengthPart1 < MaxIntSize || NewLengthPart2 < MaxIntSize)
+    return false;
+
+  // The expected number of stores after splitting should be less than the
+  // expected number of stores from the original memset.
+  return NumStoresInOrigMemset >
+         (getExpectedNumStores(NewLengthPart1, MaxIntSize) +
+          getExpectedNumStores(NewLengthPart2, MaxIntSize));
+}
+
+static bool tryToSplitStore(Instruction *EarlierStore, int64_t LaterOffset,
+                            int64_t EarlierOffset, MemoryLocation &EarlierLoc,
+                            MemoryLocation &LaterLoc, const DataLayout &DL,
+                            const TargetTransformInfo *TTI) {
+  // Lengths of the parts of the earlier store before and after the later one.
+  int64_t NewLengthPart1 = LaterOffset - EarlierOffset;
+  int64_t NewLengthPart2 = int64_t(EarlierOffset + EarlierLoc.Size) -
+                           int64_t(LaterOffset + LaterLoc.Size);
+
+  if (!isSplittingProfitable(EarlierStore, NewLengthPart1, NewLengthPart2, DL,
+                             TTI))
+    return false;
+
+  MemIntrinsic *DepIntrinsic = cast<MemIntrinsic>(EarlierStore);
+  unsigned DepWriteAlign = DepIntrinsic->getAlignment();
+  int64_t NewWriteOffset = int64_t(LaterOffset + LaterLoc.Size);
+  int64_t OffsetMoved = NewWriteOffset - EarlierOffset;
+  // The tail part keeps the original alignment, so it must start a multiple
+  // of that alignment past the original destination (not past the base).
+  if (DepWriteAlign != 0 && OffsetMoved % DepWriteAlign != 0)
+    return false;
+
+  DEBUG(dbgs() << "DSE: Split MemIntrinsic :\n  " << *EarlierStore << "\n");
+  DEBUG(dbgs() << "into :\n");
+
+  Value *DepWriteLength = DepIntrinsic->getLength();
+  Value *NewLengthPart1Val =
+      ConstantInt::get(DepWriteLength->getType(), NewLengthPart1);
+  Value *NewLengthPart2Val =
+      ConstantInt::get(DepWriteLength->getType(), NewLengthPart2);
+
+  DepIntrinsic->setLength(NewLengthPart1Val);
+  DEBUG(dbgs() << "  " << *DepIntrinsic << "\n");
+
+  // Address the tail part OffsetMoved bytes past the original destination.
+  Value *Indices[1] = {
+      ConstantInt::get(DepWriteLength->getType(), OffsetMoved)};
+  GetElementPtrInst *NewDestGEP = GetElementPtrInst::CreateInBounds(
+      DepIntrinsic->getRawDest(), Indices, "", EarlierStore);
+  MemSetInst *DepIntrinsic2 = cast<MemSetInst>(DepIntrinsic->clone());
+  DepIntrinsic2->setDest(NewDestGEP);
+  DepIntrinsic2->setLength(NewLengthPart2Val);
+  DepIntrinsic2->insertAfter(DepIntrinsic);
+  DEBUG(dbgs() << "  " << *DepIntrinsic2 << "\n");
+  return true;
+}
 /// Returns true if the end of this instruction can be safely shortened in
 /// length.
@@ -275,6 +374,7 @@
     OverwriteBegin,
     OverwriteComplete,
     OverwriteEnd,
+    OverwritePartial,
     OverwriteUnknown
   };
 }
@@ -374,8 +474,8 @@
       int64_t(LaterOff + Later.Size) >= int64_t(EarlierOff + Earlier.Size))
     return OverwriteEnd;
 
-  // Finally, we also need to check if the later store overwrites the beginning
-  // of the earlier store.
+  // We also need to check if the later store overwrites the beginning of the
+  // earlier store.
   //
   //                |--earlier--|
   //      |--   later   --|
@@ -388,6 +488,19 @@
             && "Expect to be handled as OverwriteComplete" );
     return OverwriteBegin;
   }
+
+  // Finally, we need to check if the later store partially overwrites the
+  // earlier store in the middle:
+  //      |------earlier------|
+  //          |-- later --|
+  //
+  // In this case we may want to split the earlier store so that it only
+  // writes the addresses that are not covered by the later store.
+  if (LaterOff > EarlierOff &&
+      int64_t(LaterOff + Later.Size) < int64_t(EarlierOff + Earlier.Size)) {
+    return OverwritePartial;
+  }
+
   // Otherwise, they don't completely overlap.
   return OverwriteUnknown;
 }
@@ -747,7 +860,8 @@
 
 static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,
                                 MemoryDependenceResults *MD, DominatorTree *DT,
-                                const TargetLibraryInfo *TLI) {
+                                const TargetLibraryInfo *TLI,
+                                const TargetTransformInfo *TTI) {
   const DataLayout &DL = BB.getModule()->getDataLayout();
   bool MadeChange = false;
 
@@ -913,6 +1027,10 @@
             }
             MadeChange = true;
           }
+        } else if (OR == OverwritePartial) {
+          if (tryToSplitStore(DepWrite, InstWriteOffset, DepWriteOffset, DepLoc,
+                              Loc, DL, TTI))
+            MadeChange = true;
         }
       }
 
@@ -945,13 +1063,14 @@
 
 static bool eliminateDeadStores(Function &F, AliasAnalysis *AA,
                                 MemoryDependenceResults *MD, DominatorTree *DT,
-                                const TargetLibraryInfo *TLI) {
+                                const TargetLibraryInfo *TLI,
+                                const TargetTransformInfo *TTI) {
   bool MadeChange = false;
   for (BasicBlock &BB : F)
     // Only check non-dead blocks.  Dead blocks may have strange pointer
     // cycles that will confuse alias analysis.
     if (DT->isReachableFromEntry(&BB))
-      MadeChange |= eliminateDeadStores(BB, AA, MD, DT, TLI);
+      MadeChange |= eliminateDeadStores(BB, AA, MD, DT, TLI, TTI);
   return MadeChange;
 }
 
@@ -963,8 +1082,9 @@
   DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);
   MemoryDependenceResults *MD = &AM.getResult<MemoryDependenceAnalysis>(F);
   const TargetLibraryInfo *TLI = &AM.getResult<TargetLibraryAnalysis>(F);
+  const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(F);
 
-  if (!eliminateDeadStores(F, AA, MD, DT, TLI))
+  if (!eliminateDeadStores(F, AA, MD, DT, TLI, TTI))
     return PreservedAnalyses::all();
   PreservedAnalyses PA;
   PA.preserve<DominatorTreeAnalysis>();
@@ -990,8 +1110,10 @@
         &getAnalysis<MemoryDependenceWrapperPass>().getMemDep();
     const TargetLibraryInfo *TLI =
         &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+    const TargetTransformInfo *TTI =
+        &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
 
-    return eliminateDeadStores(F, AA, MD, DT, TLI);
+    return eliminateDeadStores(F, AA, MD, DT, TLI, TTI);
   }
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
@@ -1000,6 +1122,7 @@
     AU.addRequired<AAResultsWrapperPass>();
     AU.addRequired<MemoryDependenceWrapperPass>();
     AU.addRequired<TargetLibraryInfoWrapperPass>();
+    AU.addRequired<TargetTransformInfoWrapperPass>();
     AU.addPreserved<DominatorTreeWrapperPass>();
     AU.addPreserved<GlobalsAAWrapperPass>();
     AU.addPreserved<MemoryDependenceWrapperPass>();
@@ -1016,6 +1139,7 @@
 INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
 INITIALIZE_PASS_END(DSELegacyPass, "dse", "Dead Store Elimination", false,
                     false)
 
Index: test/CodeGen/AArch64/aarch64-small-memset-lowering.ll
===================================================================
--- /dev/null
+++ test/CodeGen/AArch64/aarch64-small-memset-lowering.ll
@@ -0,0 +1,34 @@
+; RUN: llc -mtriple=arm64-linux-gnu < %s | FileCheck %s
+
+target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64--linux-gnu"
+
+define void @test_32_8to15(i64* nocapture %P, i64 %n64) {
+; CHECK-LABEL: test_32_8to15
+; CHECK: stp xzr, xzr, [x0, #16]
+; CHECK: stp xzr, x1, [x0]
+entry:
+  %Base = bitcast i64* %P to i8*
+  %Base2 = getelementptr inbounds i8, i8* %Base, i64 16
+  call void @llvm.memset.p0i8.i64(i8* %Base, i8 0, i64 8, i32 8, i1 false)
+  call void @llvm.memset.p0i8.i64(i8* %Base2, i8 0, i64 16, i32 8, i1 false)
+  %arrayidx1 = getelementptr inbounds i64, i64* %P, i64 1
+  store i64 %n64, i64* %arrayidx1
+  ret void
+}
+
+define void @test_32_8to23(i64* nocapture %P, i64 %n64) {
+; CHECK-LABEL: test_32_8to23
+; CHECK: stp [[REG:x[0-9]+]], xzr, [x0, #16]
+; CHECK: stp xzr, [[REG]], [x0]
+entry:
+  %Base = bitcast i64* %P to i8*
+  %Base2 = getelementptr inbounds i8, i8* %Base, i64 24
+  call void @llvm.memset.p0i8.i64(i8* %Base, i8 0, i64 8, i32 8, i1 false)
+  call void @llvm.memset.p0i8.i64(i8* %Base2, i8 0, i64 8, i32 8, i1 false)
+  %arrayidx2 = getelementptr inbounds i8, i8* %Base, i64 8
+  call void @llvm.memset.p0i8.i64(i8* %arrayidx2, i8 1, i64 16, i32 8, i1 false)
+  ret void
+}
+
+declare void @llvm.memset.p0i8.i64(i8* nocapture, i8, i64, i32, i1)
Index: test/Transforms/DeadStoreElimination/SplitMemintrinsic.ll
===================================================================
--- /dev/null
+++ test/Transforms/DeadStoreElimination/SplitMemintrinsic.ll
@@ -0,0 +1,95 @@
+; RUN: opt < %s -basicaa -dse -S | FileCheck %s
+
+target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64--linux-gnu"
+
+define void @test_32_8to15(i64* nocapture %P, i64 %n64) {
+; CHECK-LABEL: test_32_8to15
+; CHECK: [[BASE2:%[0-9]+]] = getelementptr inbounds i8, i8* %Base, i64 16
+; CHECK: call void @llvm.memset.p0i8.i64(i8* %Base, i8 0, i64 8, i32 8, i1 false)
+; CHECK: call void @llvm.memset.p0i8.i64(i8* [[BASE2]], i8 0, i64 16, i32 8, i1 false)
+entry:
+  %Base = bitcast i64* %P to i8*
+  call void @llvm.memset.p0i8.i64(i8* %Base, i8 0, i64 32, i32 8, i1 false)
+  %arrayidx1 = getelementptr inbounds i64, i64* %P, i64 1
+  store i64 %n64, i64* %arrayidx1
+  ret void
+}
+
+define void @test_32_8to23(i64* nocapture %P, i64 %n64) {
+; CHECK-LABEL: test_32_8to23
+; CHECK: [[BASE2:%[0-9]+]] = getelementptr inbounds i8, i8* %Base, i64 24
+; CHECK: call void @llvm.memset.p0i8.i64(i8* %Base, i8 0, i64 8, i32 8, i1 false)
+; CHECK: call void @llvm.memset.p0i8.i64(i8* [[BASE2]], i8 0, i64 8, i32 8, i1 false)
+entry:
+  %Base = bitcast i64* %P to i8*
+  call void @llvm.memset.p0i8.i64(i8* %Base, i8 0, i64 32, i32 8, i1 false)
+  %arrayidx2 = getelementptr inbounds i8, i8* %Base, i64 8
+  call void @llvm.memset.p0i8.i64(i8* %arrayidx2, i8 1, i64 16, i32 8, i1 false)
+  ret void
+}
+
+define void @test_32_4x8(i64* nocapture %P, i64 %n64) {
+; The memset should be completely removed by later overlapped stores.
+; CHECK-LABEL: @test_32_4x8(
+; CHECK-NOT: call void @llvm.memset.p0i8.i64
+entry:
+  %Base = bitcast i64* %P to i8*
+  call void @llvm.memset.p0i8.i64(i8* %Base, i8 0, i64 32, i32 8, i1 false)
+  ; P[2]
+  %arrayidx2 = getelementptr inbounds i64, i64* %P, i64 2
+  store i64 %n64, i64* %arrayidx2
+  ; P[3]
+  %arrayidx3 = getelementptr inbounds i64, i64* %P, i64 3
+  store i64 %n64, i64* %arrayidx3
+  ; P[1]
+  %arrayidx1 = getelementptr inbounds i64, i64* %P, i64 1
+  store i64 %n64, i64* %arrayidx1
+  ; P[0]
+  %arrayidx0 = getelementptr inbounds i64, i64* %P, i64 0
+  store i64 %n64, i64* %arrayidx0
+  ret void
+}
+
+define void @test_32_8to11(i64* nocapture %P, i32 %n32) {
+; Splitting is unprofitable here: the 4-byte store would leave a 20-byte
+; tail that needs more narrow stores than the original memset.
+; CHECK-LABEL: test_32_8to11
+; CHECK: call void @llvm.memset.p0i8.i64(i8* %Base, i8 0, i64 32, i32 8, i1 false)
+entry:
+  %Base = bitcast i64* %P to i8*
+  call void @llvm.memset.p0i8.i64(i8* %Base, i8 0, i64 32, i32 8, i1 false)
+  %arrayidx1 = getelementptr inbounds i64, i64* %P, i64 1
+  %arrayidx1_32 = bitcast i64* %arrayidx1 to i32*
+  store i32 %n32, i32* %arrayidx1_32
+  ret void
+}
+
+define void @test_34_8to9(i64* nocapture %P, i16 %n16) {
+; Not split: the tail part would start at offset 10, which breaks align 8.
+; CHECK-LABEL: test_34_8to9
+; CHECK: call void @llvm.memset.p0i8.i64(i8* %Base, i8 0, i64 34, i32 8, i1 false)
+entry:
+  %Base = bitcast i64* %P to i8*
+  call void @llvm.memset.p0i8.i64(i8* %Base, i8 0, i64 34, i32 8, i1 false)
+  %arrayidx1 = getelementptr inbounds i64, i64* %P, i64 1
+  %arrayidx1_16 = bitcast i64* %arrayidx1 to i16*
+  store i16 %n16, i16* %arrayidx1_16
+  ret void
+}
+
+define void @test_36_12to23_misaligned_tail(i8* nocapture %P) {
+; The tail part would start 20 bytes past the 8-byte aligned destination,
+; so it could not keep align 8 and the memset must not be split.
+; CHECK-LABEL: @test_36_12to23_misaligned_tail(
+; CHECK: call void @llvm.memset.p0i8.i64(i8* %Dst, i8 0, i64 36, i32 8, i1 false)
+entry:
+  %Dst = getelementptr inbounds i8, i8* %P, i64 4
+  call void @llvm.memset.p0i8.i64(i8* %Dst, i8 0, i64 36, i32 8, i1 false)
+  %Later = getelementptr inbounds i8, i8* %P, i64 12
+  call void @llvm.memset.p0i8.i64(i8* %Later, i8 1, i64 12, i32 4, i1 false)
+  ret void
+}
+
+declare void @llvm.memset.p0i8.i64(i8* nocapture, i8, i64, i32, i1)
+