Index: include/llvm/Analysis/TargetTransformInfo.h
===================================================================
--- include/llvm/Analysis/TargetTransformInfo.h
+++ include/llvm/Analysis/TargetTransformInfo.h
@@ -459,6 +459,11 @@
   /// and the number of execution units in the CPU.
   unsigned getMaxInterleaveFactor(unsigned VF) const;
 
+  /// \return The maximum number of store operations permitted to replace a
+  /// call to llvm.memset. If OptSize is true, return the limit for functions
+  /// that have OptSize attribute.
+  unsigned getMaxStoresPerMemset(bool OptSize) const;
+
   /// \return The expected cost of arithmetic ops, such as mul, xor, fsub, etc.
   int getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, OperandValueKind Opd1Info = OK_AnyValue,
@@ -659,6 +664,7 @@
   virtual unsigned getMinPrefetchStride() = 0;
   virtual unsigned getMaxPrefetchIterationsAhead() = 0;
   virtual unsigned getMaxInterleaveFactor(unsigned VF) = 0;
+  virtual unsigned getMaxStoresPerMemset(bool OptSize) = 0;
   virtual unsigned getArithmeticInstrCost(unsigned Opcode, Type *Ty,
                                           OperandValueKind Opd1Info,
                                           OperandValueKind Opd2Info,
@@ -847,6 +853,9 @@
   unsigned getMaxInterleaveFactor(unsigned VF) override {
     return Impl.getMaxInterleaveFactor(VF);
   }
+  unsigned getMaxStoresPerMemset(bool OptSize) override {
+    return Impl.getMaxStoresPerMemset(OptSize);
+  }
   unsigned getArithmeticInstrCost(unsigned Opcode, Type *Ty,
                                   OperandValueKind Opd1Info,
                                   OperandValueKind Opd2Info,
Index: include/llvm/Analysis/TargetTransformInfoImpl.h
===================================================================
--- include/llvm/Analysis/TargetTransformInfoImpl.h
+++ include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -278,6 +278,8 @@
 
   unsigned getMaxInterleaveFactor(unsigned VF) { return 1; }
 
+  unsigned getMaxStoresPerMemset(bool OptSize) { return 0; }
+
   unsigned getArithmeticInstrCost(unsigned Opcode, Type *Ty,
                                   TTI::OperandValueKind Opd1Info,
                                   TTI::OperandValueKind Opd2Info,
Index: include/llvm/CodeGen/BasicTTIImpl.h
===================================================================
--- include/llvm/CodeGen/BasicTTIImpl.h
+++ include/llvm/CodeGen/BasicTTIImpl.h
@@ -284,6 +284,10 @@
 
   unsigned getMaxInterleaveFactor(unsigned VF) { return 1; }
 
+  unsigned getMaxStoresPerMemset(bool OptSize) {
+    return getTLI()->getMaxStoresPerMemset(OptSize);
+  }
+
   unsigned getArithmeticInstrCost(
       unsigned Opcode, Type *Ty,
       TTI::OperandValueKind Opd1Info = TTI::OK_AnyValue,
Index: lib/Analysis/TargetTransformInfo.cpp
===================================================================
--- lib/Analysis/TargetTransformInfo.cpp
+++ lib/Analysis/TargetTransformInfo.cpp
@@ -243,6 +243,10 @@
   return TTIImpl->getMaxInterleaveFactor(VF);
 }
 
+unsigned TargetTransformInfo::getMaxStoresPerMemset(bool OptSize) const {
+  return TTIImpl->getMaxStoresPerMemset(OptSize);
+}
+
 int TargetTransformInfo::getArithmeticInstrCost(
     unsigned Opcode, Type *Ty, OperandValueKind Opd1Info,
     OperandValueKind Opd2Info, OperandValueProperties Opd1PropInfo,
Index: lib/Transforms/Scalar/DeadStoreElimination.cpp
===================================================================
--- lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -25,6 +25,7 @@
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/Analysis/MemoryDependenceAnalysis.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
@@ -33,6 +34,7 @@
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
@@ -51,6 +53,7 @@
   MemoryDependenceResults *MD;
   DominatorTree *DT;
   const TargetLibraryInfo *TLI;
+  const TargetTransformInfo *TTI;
 
   static char ID; // Pass identification, replacement for typeid
   DSE() : FunctionPass(ID), AA(nullptr), MD(nullptr), DT(nullptr) {
@@ -65,6 +68,7 @@
     MD = &getAnalysis<MemoryDependenceWrapperPass>().getMemDep();
     DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
     TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+    TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
 
     bool Changed = false;
     for (BasicBlock &I : F)
@@ -91,6 +95,7 @@
     AU.addRequired<AAResultsWrapperPass>();
     AU.addRequired<MemoryDependenceWrapperPass>();
     AU.addRequired<TargetLibraryInfoWrapperPass>();
+    AU.addRequired<TargetTransformInfoWrapperPass>();
     AU.addPreserved<GlobalsAAWrapperPass>();
     AU.addPreserved<DominatorTreeWrapperPass>();
     AU.addPreserved<MemoryDependenceWrapperPass>();
@@ -105,6 +110,7 @@
 INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
 INITIALIZE_PASS_END(DSE, "dse", "Dead Store Elimination", false, false)
 
 FunctionPass *llvm::createDeadStoreEliminationPass() { return new DSE(); }
@@ -274,6 +280,59 @@
   return false;
 }
 
+/// Estimate the number of stores needed to write \p SizeInBytes bytes when the
+/// widest single store covers \p MaxIntSizeInBytes (non-zero) bytes.
+static uint64_t getExpectedNumStores(uint64_t SizeInBytes,
+                                     uint64_t MaxIntSizeInBytes) {
+  // Since we don't have perfect knowledge here, assume that the maximum GPR
+  // width is the same size as the largest legal integer size.
+  uint64_t NumWideStores = SizeInBytes / MaxIntSizeInBytes;
+  // Conservatively assume the remaining bytes are written a byte at a time.
+  uint64_t NumNarrowStores = SizeInBytes % MaxIntSizeInBytes;
+  return NumWideStores + NumNarrowStores;
+}
+
+/// Return true if splitting memset into two parts is profitable in terms of
+/// the number of stores lowered from the memset.
+///
+/// \p EarlierStore is the write being considered; anything other than a memset
+/// with a constant length is rejected. \p NewLengthPart1 and \p NewLengthPart2
+/// are the byte lengths of the two memsets that would result from the split.
+static bool isSplittingProfitable(Instruction *EarlierStore,
+                                  int64_t NewLengthPart1,
+                                  int64_t NewLengthPart2, const DataLayout &DL,
+                                  const TargetTransformInfo *TTI) {
+  // FIXME: Split only memset for now. Supporting memcpy/memmove is also
+  // possible.
+  MemSetInst *MSI = dyn_cast<MemSetInst>(EarlierStore);
+  if (!MSI)
+    return false;
+
+  // A variable-length memset cannot be costed, so never split it.
+  ConstantInt *LengthCI = dyn_cast<ConstantInt>(MSI->getLength());
+  if (!LengthCI)
+    return false;
+
+  // First make sure that the existing memset could be lowered to stores later.
+  // Lengths are kept 64-bit so that very large memsets are not truncated.
+  uint64_t OrigLength = LengthCI->getZExtValue();
+  uint64_t MaxIntSize = DL.getLargestLegalIntTypeSizeInBits() / 8;
+
+  if (MaxIntSize == 0)
+    MaxIntSize = 1;
+
+  if (OrigLength <= MaxIntSize)
+    return false;
+
+  uint64_t NumStoresInOrigMemset = getExpectedNumStores(OrigLength, MaxIntSize);
+  bool OptSize = MSI->getParent()->getParent()->optForSize();
+  if (TTI->getMaxStoresPerMemset(OptSize) < NumStoresInOrigMemset)
+    return false;
+
+  // Each remaining piece must be big enough for at least one full-width store.
+  // The comparison is done signed so that negative lengths are rejected too.
+  if (NewLengthPart1 < int64_t(MaxIntSize) ||
+      NewLengthPart2 < int64_t(MaxIntSize))
+    return false;
+
+  return NumStoresInOrigMemset >
+         (getExpectedNumStores(NewLengthPart1, MaxIntSize) +
+          getExpectedNumStores(NewLengthPart2, MaxIntSize));
+}
 
 /// Returns true if the end of this instruction can be safely shortened in
 /// length.
@@ -341,6 +406,7 @@
   OverwriteBegin,
   OverwriteComplete,
   OverwriteEnd,
+  OverwritePartial,
   OverwriteUnknown
 };
 }
@@ -454,6 +520,14 @@
            && "Expect to be handled as OverwriteComplete" );
     return OverwriteBegin;
   }
+
+  // The later write lies strictly inside the earlier one, leaving bytes of the
+  // earlier write on both sides of it.
+  if (LaterOff > EarlierOff &&
+      int64_t(LaterOff + Later.Size) < int64_t(EarlierOff + Earlier.Size)) {
+    return OverwritePartial;
+  }
+
   // Otherwise, they don't completely overlap.
   return OverwriteUnknown;
 }
@@ -673,6 +747,41 @@
             }
             MadeChange = true;
           }
+        } else if (OR == OverwritePartial) {
+          // Only a memset can be split. DepWrite is not guaranteed to be a
+          // memory intrinsic here, so test for it instead of casting; any
+          // other partially overwritten write is left untouched.
+          if (MemSetInst *MSI = dyn_cast<MemSetInst>(DepWrite)) {
+            unsigned DepWriteAlign = MSI->getAlignment();
+            // Offset of the first byte past the overwriting store; the second
+            // memset produced by the split starts here.
+            int64_t NewWriteOffset = int64_t(InstWriteOffset + Loc.Size);
+            int64_t NewLengthPart1 = InstWriteOffset - DepWriteOffset;
+            int64_t NewLengthPart2 =
+                int64_t(DepWriteOffset + DepLoc.Size) - NewWriteOffset;
+
+            bool IsSplittable = isSplittingProfitable(MSI, NewLengthPart1,
+                                                      NewLengthPart2, DL, TTI);
+
+            if (IsSplittable &&
+                ((llvm::isPowerOf2_64(NewWriteOffset) &&
+                  DepWriteAlign <= NewWriteOffset) ||
+                 ((DepWriteAlign != 0) &&
+                  NewWriteOffset % DepWriteAlign == 0))) {
+              IRBuilder<> Builder(MSI);
+              Type *LengthTy = MSI->getLength()->getType();
+
+              // Shrink the existing memset so that it covers only the bytes
+              // in front of the overwriting store.
+              MSI->setLength(ConstantInt::get(LengthTy, NewLengthPart1));
+
+              // Emit a second memset for the bytes behind the overwriting store.
+              int64_t OffsetMoved = NewWriteOffset - DepWriteOffset;
+              Value *Indices[1] = {ConstantInt::get(LengthTy, OffsetMoved)};
+              GetElementPtrInst *NewDestGEP = GetElementPtrInst::CreateInBounds(
+                  MSI->getRawDest(), Indices, "", MSI);
+              Instruction *NewMemSetPart2 = Builder.CreateMemSet(
+                  NewDestGEP, MSI->getValue(),
+                  ConstantInt::get(LengthTy, NewLengthPart2), DepWriteAlign);
+              NewMemSetPart2->setDebugLoc(MSI->getDebugLoc());
+
+              MadeChange = true;
+            }
+          }
         }
       }
 