Index: include/llvm/Analysis/TargetTransformInfo.h
===================================================================
--- include/llvm/Analysis/TargetTransformInfo.h
+++ include/llvm/Analysis/TargetTransformInfo.h
@@ -459,6 +459,11 @@
   /// and the number of execution units in the CPU.
   unsigned getMaxInterleaveFactor(unsigned VF) const;

+  /// \return The maximum number of store operations permitted to replace a
+  /// call to llvm.memset. If \p OptSize is true, return the limit used for
+  /// functions that are optimized for size.
+  unsigned getMaxStoresPerMemset(bool OptSize) const;
+
   /// \return The expected cost of arithmetic ops, such as mul, xor, fsub, etc.
   int getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, OperandValueKind Opd1Info = OK_AnyValue,
@@ -659,6 +664,7 @@
   virtual unsigned getMinPrefetchStride() = 0;
   virtual unsigned getMaxPrefetchIterationsAhead() = 0;
   virtual unsigned getMaxInterleaveFactor(unsigned VF) = 0;
+  virtual unsigned getMaxStoresPerMemset(bool OptSize) = 0;
   virtual unsigned getArithmeticInstrCost(unsigned Opcode, Type *Ty,
                                           OperandValueKind Opd1Info,
                                           OperandValueKind Opd2Info,
@@ -847,6 +853,9 @@
   unsigned getMaxInterleaveFactor(unsigned VF) override {
     return Impl.getMaxInterleaveFactor(VF);
   }
+  unsigned getMaxStoresPerMemset(bool OptSize) override {
+    return Impl.getMaxStoresPerMemset(OptSize);
+  }
   unsigned getArithmeticInstrCost(unsigned Opcode, Type *Ty,
                                   OperandValueKind Opd1Info,
                                   OperandValueKind Opd2Info,
Index: include/llvm/Analysis/TargetTransformInfoImpl.h
===================================================================
--- include/llvm/Analysis/TargetTransformInfoImpl.h
+++ include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -278,6 +278,8 @@

   unsigned getMaxInterleaveFactor(unsigned VF) { return 1; }

+  unsigned getMaxStoresPerMemset(bool OptSize) { return 0; }
+
   unsigned getArithmeticInstrCost(unsigned Opcode, Type *Ty,
                                   TTI::OperandValueKind Opd1Info,
                                   TTI::OperandValueKind Opd2Info,
Index: include/llvm/CodeGen/BasicTTIImpl.h
===================================================================
--- include/llvm/CodeGen/BasicTTIImpl.h
+++ include/llvm/CodeGen/BasicTTIImpl.h
@@ -284,6 +284,10 @@

   unsigned getMaxInterleaveFactor(unsigned VF) { return 1; }

+  unsigned getMaxStoresPerMemset(bool OptSize) {
+    return getTLI()->getMaxStoresPerMemset(OptSize);
+  }
+
   unsigned getArithmeticInstrCost(
       unsigned Opcode, Type *Ty,
       TTI::OperandValueKind Opd1Info = TTI::OK_AnyValue,
Index: lib/Analysis/TargetTransformInfo.cpp
===================================================================
--- lib/Analysis/TargetTransformInfo.cpp
+++ lib/Analysis/TargetTransformInfo.cpp
@@ -243,6 +243,10 @@
   return TTIImpl->getMaxInterleaveFactor(VF);
 }

+unsigned TargetTransformInfo::getMaxStoresPerMemset(bool OptSize) const {
+  return TTIImpl->getMaxStoresPerMemset(OptSize);
+}
+
 int TargetTransformInfo::getArithmeticInstrCost(
     unsigned Opcode, Type *Ty, OperandValueKind Opd1Info,
     OperandValueKind Opd2Info, OperandValueProperties Opd1PropInfo,
Index: lib/Transforms/Scalar/DeadStoreElimination.cpp
===================================================================
--- lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -25,6 +25,7 @@
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/Analysis/MemoryDependenceAnalysis.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
@@ -51,6 +52,7 @@
     MemoryDependenceResults *MD;
     DominatorTree *DT;
     const TargetLibraryInfo *TLI;
+    const TargetTransformInfo *TTI;

     static char ID; // Pass identification, replacement for typeid
     DSE() : FunctionPass(ID), AA(nullptr), MD(nullptr), DT(nullptr) {
@@ -65,6 +67,7 @@
       MD = &getAnalysis<MemoryDependenceWrapperPass>().getMemDep();
       DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
       TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+      TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

       bool Changed = false;
       for (BasicBlock &I : F)
@@ -91,6 +94,7 @@
       AU.addRequired<AAResultsWrapperPass>();
       AU.addRequired<MemoryDependenceWrapperPass>();
       AU.addRequired<TargetLibraryInfoWrapperPass>();
+      AU.addRequired<TargetTransformInfoWrapperPass>();
       AU.addPreserved<DominatorTreeWrapperPass>();
       AU.addPreserved<GlobalsAAWrapperPass>();
       AU.addPreserved<MemoryDependenceWrapperPass>();
@@ -105,6 +109,7 @@
 INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
 INITIALIZE_PASS_END(DSE, "dse", "Dead Store Elimination", false, false)

 FunctionPass *llvm::createDeadStoreEliminationPass() { return new DSE(); }
@@ -274,6 +279,60 @@
   return false;
 }

+/// Return the expected number of stores needed to lower a memset of
+/// \p SizeInBytes bytes when the widest store is \p MaxIntSizeInBytes bytes.
+static uint64_t getExpectedNumStores(uint64_t SizeInBytes,
+                                     uint64_t MaxIntSizeInBytes) {
+  // Since we don't have perfect knowledge here, assume that the maximum GPR
+  // width is the same size as the largest legal integer size.
+  uint64_t NumWideStores = SizeInBytes / MaxIntSizeInBytes;
+  // Conservatively assume the remaining bytes are stored a byte at a time.
+  uint64_t NumNarrowStores = SizeInBytes % MaxIntSizeInBytes;
+  return NumWideStores + NumNarrowStores;
+}
+
+/// Return true if splitting memset into two parts is profitable in terms of
+/// the number of stores lowered from the memset.
+static bool isSplittingProfitable(Instruction *EarlierStore,
+                                  int64_t NewLengthPart1,
+                                  int64_t NewLengthPart2, const DataLayout &DL,
+                                  const TargetTransformInfo *TTI) {
+  // FIXME: Split only memset for now. Supporting memcpy/memmove is also
+  // possible.
+  MemSetInst *MSI = dyn_cast<MemSetInst>(EarlierStore);
+  if (!MSI)
+    return false;
+
+  // Only a memset with a constant length can be lowered to inline stores.
+  ConstantInt *LengthC = dyn_cast<ConstantInt>(MSI->getLength());
+  if (!LengthC)
+    return false;
+
+  // First make sure that the existing memset could be lowered to stores later.
+  uint64_t OrigLength = LengthC->getZExtValue();
+  uint64_t MaxIntSize = DL.getLargestLegalIntTypeSizeInBits() / 8;
+
+  if (MaxIntSize == 0)
+    MaxIntSize = 1;
+
+  if (OrigLength <= MaxIntSize)
+    return false;
+
+  uint64_t NumStoresInOrigMemset = getExpectedNumStores(OrigLength, MaxIntSize);
+  bool OptSize = MSI->getParent()->getParent()->optForSize();
+
+  if (TTI->getMaxStoresPerMemset(OptSize) < NumStoresInOrigMemset)
+    return false;
+
+  // Each part must still be at least as wide as the widest store.
+  if (NewLengthPart1 < int64_t(MaxIntSize) ||
+      NewLengthPart2 < int64_t(MaxIntSize))
+    return false;
+
+  return NumStoresInOrigMemset >
+         getExpectedNumStores(NewLengthPart1, MaxIntSize) +
+             getExpectedNumStores(NewLengthPart2, MaxIntSize);
+}

 /// Returns true if the end of this instruction can be safely shortened in
 /// length.
@@ -341,6 +400,7 @@
   OverwriteBegin,
   OverwriteComplete,
   OverwriteEnd,
+  OverwritePartial,
   OverwriteUnknown
 };
 }
@@ -454,6 +514,13 @@
            && "Expect to be handled as OverwriteComplete" );
     return OverwriteBegin;
   }
+
+  // The later store writes a range strictly inside the earlier one, so bytes
+  // of the earlier store stay live on both sides of it.
+  if (LaterOff > EarlierOff &&
+      int64_t(LaterOff + Later.Size) < int64_t(EarlierOff + Earlier.Size))
+    return OverwritePartial;
+
   // Otherwise, they don't completely overlap.
   return OverwriteUnknown;
 }
@@ -673,6 +740,53 @@
             }
             MadeChange = true;
           }
+        } else if (OR == OverwritePartial) {
+          // The later store overwrites a range strictly inside the earlier
+          // memset. Split the memset into the part before that range and the
+          // part after it when that lowers to fewer stores.
+          int64_t NewLengthPart1 = InstWriteOffset - DepWriteOffset;
+          int64_t NewLengthPart2 = int64_t(DepWriteOffset + DepLoc.Size) -
+                                   int64_t(InstWriteOffset + Loc.Size);
+          if (isSplittingProfitable(DepWrite, NewLengthPart1, NewLengthPart2,
+                                    DL, TTI)) {
+            // isSplittingProfitable only accepts memsets.
+            MemSetInst *DepMemSet = cast<MemSetInst>(DepWrite);
+            unsigned DepWriteAlign = DepMemSet->getAlignment();
+            int64_t NewWriteOffset = int64_t(InstWriteOffset + Loc.Size);
+            // Distance from the original destination to the second part.
+            int64_t OffsetMoved = NewWriteOffset - DepWriteOffset;
+
+            // The cloned memset keeps the original alignment, so the second
+            // part must start at a multiple of it relative to the original
+            // destination (0 and 1 mean no alignment requirement).
+            if (DepWriteAlign <= 1 || OffsetMoved % DepWriteAlign == 0) {
+              DEBUG(dbgs() << "DSE: Split MemIntrinsic :\n  " << *DepWrite
+                           << "\n");
+              DEBUG(dbgs() << "into :\n");
+
+              Value *DepWriteLength = DepMemSet->getLength();
+              Value *NewLengthPart1Val =
+                  ConstantInt::get(DepWriteLength->getType(), NewLengthPart1);
+              Value *NewLengthPart2Val =
+                  ConstantInt::get(DepWriteLength->getType(), NewLengthPart2);
+
+              // Shrink the original memset to the first part.
+              DepMemSet->setLength(NewLengthPart1Val);
+              DEBUG(dbgs() << "  " << *DepMemSet << "\n");
+
+              // Clone it to cover the second part, right after the original.
+              Value *Indices[1] = {
+                  ConstantInt::get(DepWriteLength->getType(), OffsetMoved)};
+              GetElementPtrInst *NewDestGEP = GetElementPtrInst::CreateInBounds(
+                  DepMemSet->getRawDest(), Indices, "", DepWrite);
+              MemSetInst *SecondMemSet = cast<MemSetInst>(DepMemSet->clone());
+              SecondMemSet->setDest(NewDestGEP);
+              SecondMemSet->setLength(NewLengthPart2Val);
+              SecondMemSet->insertAfter(DepMemSet);
+              DEBUG(dbgs() << "  " << *SecondMemSet << "\n");
+              MadeChange = true;
+            }
+          }
         }
       }

Index: test/Transforms/DeadStoreElimination/SplitMemintrinsic.ll
===================================================================
--- /dev/null
+++ test/Transforms/DeadStoreElimination/SplitMemintrinsic.ll
@@ -0,0 +1,119 @@
+; RUN: opt < %s -basicaa -dse -mtriple=arm64-linux-gnu -S | FileCheck %s
+; REQUIRES: aarch64-registered-target
+
+target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128"
+
+define void @test_32_8to15(i64* nocapture %P, i64 %n64) {
+; CHECK-LABEL: @test_32_8to15(
+; CHECK: %1 = getelementptr inbounds i8, i8* %0, i64 16
+; CHECK: call void @llvm.memset.p0i8.i64(i8* %0, i8 0, i64 8, i32 8, i1 false)
+; CHECK: call void @llvm.memset.p0i8.i64(i8* %1, i8 0, i64 16, i32 8, i1 false)
+entry:
+  %0 = bitcast i64* %P to i8*
+  call void @llvm.memset.p0i8.i64(i8* %0, i8 0, i64 32, i32 8, i1 false)
+  %arrayidx1 = getelementptr inbounds i64, i64* %P, i64 1
+  store i64 %n64, i64* %arrayidx1
+  ret void
+}
+
+define void @test_32_8to23(i64* nocapture %P, i64 %n64) {
+; CHECK-LABEL: @test_32_8to23(
+; CHECK: %1 = getelementptr inbounds i8, i8* %0, i64 24
+; CHECK: call void @llvm.memset.p0i8.i64(i8* %0, i8 0, i64 8, i32 8, i1 false)
+; CHECK: call void @llvm.memset.p0i8.i64(i8* %1, i8 0, i64 8, i32 8, i1 false)
+entry:
+  %0 = bitcast i64* %P to i8*
+  call void @llvm.memset.p0i8.i64(i8* %0, i8 0, i64 32, i32 8, i1 false)
+  %arrayidx2 = getelementptr inbounds i8, i8* %0, i64 8
+  call void @llvm.memset.p0i8.i64(i8* %arrayidx2, i8 1, i64 16, i32 8, i1 false)
+  ret void
+}
+
+define void @test_32_4x8(i64* nocapture %P, i64 %n64) {
+; CHECK-LABEL: @test_32_4x8(
+; CHECK-NOT: call void @llvm.memset.p0i8.i64
+entry:
+  %0 = bitcast i64* %P to i8*
+  call void @llvm.memset.p0i8.i64(i8* %0, i8 0, i64 32, i32 8, i1 false)
+  ; P[2]
+  %arrayidx2 = getelementptr inbounds i64, i64* %P, i64 2
+  store i64 %n64, i64* %arrayidx2
+  ; P[3]
+  %arrayidx3 = getelementptr inbounds i64, i64* %P, i64 3
+  store i64 %n64, i64* %arrayidx3
+  ; P[1]
+  %arrayidx1 = getelementptr inbounds i64, i64* %P, i64 1
+  store i64 %n64, i64* %arrayidx1
+  ; P[0]
+  %arrayidx0 = getelementptr inbounds i64, i64* %P, i64 0
+  store i64 %n64, i64* %arrayidx0
+  ret void
+}
+
+define void @test_32_8to11(i64* nocapture %P, i32 %n32) {
+; CHECK-LABEL: @test_32_8to11(
+; CHECK: call void @llvm.memset.p0i8.i64(i8* %0, i8 0, i64 32, i32 8, i1 false)
+entry:
+  %0 = bitcast i64* %P to i8*
+  call void @llvm.memset.p0i8.i64(i8* %0, i8 0, i64 32, i32 8, i1 false)
+  %arrayidx1 = getelementptr inbounds i64, i64* %P, i64 1
+  %arrayidx1_32 = bitcast i64* %arrayidx1 to i32*
+  store i32 %n32, i32* %arrayidx1_32
+  ret void
+}
+
+define void @test_32_8to9(i64* nocapture %P, i16 %n16) {
+; CHECK-LABEL: @test_32_8to9(
+; CHECK: call void @llvm.memset.p0i8.i64(i8* %0, i8 0, i64 32, i32 8, i1 false)
+entry:
+  %0 = bitcast i64* %P to i8*
+  call void @llvm.memset.p0i8.i64(i8* %0, i8 0, i64 32, i32 8, i1 false)
+  %arrayidx1 = getelementptr inbounds i64, i64* %P, i64 1
+  %arrayidx1_16 = bitcast i64* %arrayidx1 to i16*
+  store i16 %n16, i16* %arrayidx1_16
+  ret void
+}
+
+define void @test_32_8to8(i64* nocapture %P, i8 %n8) {
+; CHECK-LABEL: @test_32_8to8(
+; CHECK: call void @llvm.memset.p0i8.i64(i8* %0, i8 0, i64 32, i32 8, i1 false)
+entry:
+  %0 = bitcast i64* %P to i8*
+  call void @llvm.memset.p0i8.i64(i8* %0, i8 0, i64 32, i32 8, i1 false)
+  %arrayidx1 = getelementptr inbounds i64, i64* %P, i64 1
+  %arrayidx1_8 = bitcast i64* %arrayidx1 to i8*
+  store i8 %n8, i8* %arrayidx1_8
+  ret void
+}
+
+; This should not be split: the second part would start at offset 10, which
+; is not a multiple of the memset's 8-byte alignment.
+define void @test_34_8to9(i64* nocapture %P, i16 %n16) {
+; CHECK-LABEL: @test_34_8to9(
+; CHECK: call void @llvm.memset.p0i8.i64(i8* %0, i8 0, i64 34, i32 8, i1 false)
+entry:
+  %0 = bitcast i64* %P to i8*
+  call void @llvm.memset.p0i8.i64(i8* %0, i8 0, i64 34, i32 8, i1 false)
+  %arrayidx1 = getelementptr inbounds i64, i64* %P, i64 1
+  %arrayidx1_16 = bitcast i64* %arrayidx1 to i16*
+  store i16 %n16, i16* %arrayidx1_16
+  ret void
+}
+
+; The memset starts 8 bytes past %P and is 16-byte aligned. Its second part
+; would start at %P + 32, which is only 24 bytes into the memset and hence not
+; 16-byte aligned, so this memset must not be split.
+define void @test_40_16to23_offset8(i8* nocapture %P, i64 %n64) {
+; CHECK-LABEL: @test_40_16to23_offset8(
+; CHECK: call void @llvm.memset.p0i8.i64(i8* %arrayidx8, i8 0, i64 40, i32 16, i1 false)
+entry:
+  %arrayidx8 = getelementptr inbounds i8, i8* %P, i64 8
+  call void @llvm.memset.p0i8.i64(i8* %arrayidx8, i8 0, i64 40, i32 16, i1 false)
+  %arrayidx24 = getelementptr inbounds i8, i8* %P, i64 24
+  %arrayidx24_64 = bitcast i8* %arrayidx24 to i64*
+  store i64 %n64, i64* %arrayidx24_64
+  ret void
+}
+
+declare void @llvm.memset.p0i8.i64(i8* nocapture, i8, i64, i32, i1)
+