diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -219,7 +219,7 @@
   bool processLoopMemCpy(MemCpyInst *MCI, const SCEV *BECount);
   bool processLoopMemSet(MemSetInst *MSI, const SCEV *BECount);
-  bool processLoopStridedStore(Value *DestPtr, unsigned StoreSize,
+  bool processLoopStridedStore(Value *DestPtr, const SCEV *StoreSizeSCEV,
                                MaybeAlign StoreAlignment, Value *StoredVal,
                                Instruction *TheStore,
                                SmallPtrSetImpl<Instruction *> &Stores,
@@ -833,7 +833,8 @@
   bool NegStride = StoreSize == -Stride;
 
-  if (processLoopStridedStore(StorePtr, StoreSize,
+  const SCEV *StoreSizeSCEV = SE->getConstant(BECount->getType(), StoreSize);
+  if (processLoopStridedStore(StorePtr, StoreSizeSCEV,
                               MaybeAlign(HeadStore->getAlignment()), StoredVal,
                               HeadStore, AdjacentStores, StoreEv, BECount,
                               NegStride)) {
@@ -959,6 +960,10 @@
   if (!Ev || Ev->getLoop() != CurLoop || !Ev->isAffine())
     return false;
 
+  const SCEV *MemsetSizeSCEV = SE->getSCEV(MSI->getLength());
+  if (!MemsetSizeSCEV)
+    return false;
+
   // Reject memsets that are so large that they overflow an unsigned.
   uint64_t SizeInBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue();
   if ((SizeInBytes >> 32) != 0)
@@ -984,7 +989,7 @@
   MSIs.insert(MSI);
   bool NegStride = SizeInBytes == -Stride;
   return processLoopStridedStore(
-      Pointer, (unsigned)SizeInBytes, MaybeAlign(MSI->getDestAlignment()),
+      Pointer, MemsetSizeSCEV, MaybeAlign(MSI->getDestAlignment()),
       SplatValue, MSI, MSIs, Ev, BECount, NegStride, /*IsLoopMemset=*/true);
 }
 
@@ -1024,6 +1029,33 @@
   return false;
 }
 
+// This is a temporary version that takes StoreSize as a SCEV. The goal is to
+// make all LoopIdiom optimizations take StoreSize as a SCEV, at which point
+// the version above will be fully replaced by this one.
+static bool
+mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
+                      const SCEV *BECount, const SCEV *StoreSizeSCEV,
+                      AliasAnalysis &AA,
+                      SmallPtrSetImpl<Instruction *> &IgnoredStores) {
+  LocationSize AccessSize = LocationSize::afterPointer();
+  const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount);
+  const SCEVConstant *ConstSize = dyn_cast<SCEVConstant>(StoreSizeSCEV);
+  if (BECst && ConstSize)
+    AccessSize = LocationSize::precise((BECst->getValue()->getZExtValue() + 1) *
+                                       ConstSize->getValue()->getZExtValue());
+
+  // The range of locations the store may write to.
+  MemoryLocation StoreLoc(Ptr, AccessSize);
+
+  for (Loop::block_iterator BI = L->block_begin(), E = L->block_end(); BI != E;
+       ++BI)
+    for (Instruction &I : **BI)
+      if (IgnoredStores.count(&I) == 0 &&
+          isModOrRefSet(
+              intersectModRef(AA.getModRefInfo(&I, StoreLoc), Access)))
+        return true;
+  return false;
+}
+
 // If we have a negative stride, Start refers to the end of the memory location
 // we're trying to memset.  Therefore, we need to recompute the base pointer,
 // which is just Start - BECount*Size.
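Aside on the new mayLoopAccessLocation overload above: it only tightens the queried memory range when both the backedge-taken count and the store size fold to constants, in which case the loop touches exactly (BECount + 1) * StoreSize bytes; otherwise it conservatively checks everything from the pointer onward. A minimal standalone sketch of that size computation (plain C++; the helper name and the std::optional stand-in for LocationSize are illustrative, not part of the patch):

#include <cstdint>
#include <optional>

// std::nullopt plays the role of LocationSize::afterPointer(): the size of
// the accessed region is unknown, so assume everything past the pointer.
std::optional<uint64_t> accessSizeBytes(std::optional<uint64_t> BECount,
                                        std::optional<uint64_t> StoreSize) {
  if (BECount && StoreSize)
    return (*BECount + 1) * *StoreSize; // trip count times bytes per trip
  return std::nullopt;
}

// Example: BECount = 9 (ten iterations) and StoreSize = 4 give a precise
// 40-byte location, the exact range the strided stores cover.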
@@ -1037,6 +1069,25 @@
   return SE->getMinusSCEV(Start, Index);
 }
 
+// This is a temporary version that takes StoreSize as a SCEV. The goal is to
+// make all LoopIdiom optimizations take StoreSize as a SCEV, at which point
+// the version above will be fully replaced by this one.
+static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount,
+                                        Type *IntPtr,
+                                        const SCEV *StoreSizeSCEV,
+                                        ScalarEvolution *SE) {
+  // Index = backedge-taken count.
+  const SCEV *Index = SE->getTruncateOrZeroExtend(BECount, IntPtr);
+  const SCEVConstant *ConstSize = dyn_cast<SCEVConstant>(StoreSizeSCEV);
+  if (!ConstSize || (ConstSize->getAPInt() != 1)) {
+    // Index = Index * StoreSize = backedge-taken count * store size.
+    Index = SE->getMulExpr(Index,
+                           SE->getTruncateOrZeroExtend(StoreSizeSCEV, IntPtr),
+                           SCEV::FlagNUW);
+  }
+  // Start - Index is the base pointer of the stored-to region.
+  return SE->getMinusSCEV(Start, Index);
+}
+
 /// Compute the number of bytes as a SCEV from the backedge taken count.
 ///
 /// This also maps the SCEV into the provided type and tries to handle the
@@ -1072,10 +1123,36 @@
   return NumBytesS;
 }
 
+static const SCEV *getTripCount(const SCEV *BECount, Type *IntPtr,
+                                Loop *CurLoop, const DataLayout *DL,
+                                ScalarEvolution *SE) {
+  const SCEV *TripCountS;
+  // The trip count is BECount + 1. Expand it out to pointer size if it isn't
+  // already.
+  //
+  // If we're going to need to zero extend the BE count, check if we can add
+  // one to it prior to zero extending without overflow. Provided this is safe,
+  // it allows better simplification of the +1.
+  if (DL->getTypeSizeInBits(BECount->getType()) <
+          DL->getTypeSizeInBits(IntPtr) &&
+      SE->isLoopEntryGuardedByCond(
+          CurLoop, ICmpInst::ICMP_NE, BECount,
+          SE->getNegativeSCEV(SE->getOne(BECount->getType())))) {
+    TripCountS = SE->getZeroExtendExpr(
+        SE->getAddExpr(BECount, SE->getOne(BECount->getType()), SCEV::FlagNUW),
+        IntPtr);
+  } else {
+    TripCountS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr),
+                                SE->getOne(IntPtr), SCEV::FlagNUW);
+  }
+
+  return TripCountS;
+}
+
 /// processLoopStridedStore - We see a strided store of some value.  If we can
 /// transform this into a memset or memset_pattern in the loop preheader, do so.
 bool LoopIdiomRecognize::processLoopStridedStore(
-    Value *DestPtr, unsigned StoreSize, MaybeAlign StoreAlignment,
+    Value *DestPtr, const SCEV *StoreSizeSCEV, MaybeAlign StoreAlignment,
     Value *StoredVal, Instruction *TheStore,
     SmallPtrSetImpl<Instruction *> &Stores, const SCEVAddRecExpr *Ev,
     const SCEV *BECount, bool NegStride, bool IsLoopMemset) {
@@ -1104,7 +1181,7 @@
   const SCEV *Start = Ev->getStart();
   // Handle negative strided loops.
   if (NegStride)
-    Start = getStartForNegStride(Start, BECount, IntIdxTy, StoreSize, SE);
+    Start = getStartForNegStride(Start, BECount, IntIdxTy, StoreSizeSCEV, SE);
 
   // TODO: ideally we should still be able to generate memset if SCEV expander
   // is taught to generate the dependencies at the latest point.
@@ -1129,7 +1206,7 @@
   Changed = true;
 
   if (mayLoopAccessLocation(BasePtr, ModRefInfo::ModRef, CurLoop, BECount,
-                            StoreSize, *AA, Stores))
+                            StoreSizeSCEV, *AA, Stores))
     return Changed;
 
   if (avoidLIRForMultiBlockLoop(/*IsMemset=*/true, IsLoopMemset))
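For intuition on the SCEV-based getStartForNegStride above: in a negative-stride loop the recurrence starts at the highest address stored, so the lowest byte written, which is where the memset must begin, is Start - BECount * StoreSize (the multiply is skipped when the size is the constant 1, matching the special case in the code). A hedged arithmetic sketch with plain integers in place of SCEVs (illustrative only, not part of the patch):

#include <cstdint>

// Mirrors the base-pointer computation with concrete integers.
uint64_t startForNegStride(uint64_t Start, uint64_t BECount,
                           uint64_t StoreSize) {
  uint64_t Index = BECount;
  if (StoreSize != 1) // same size-1 special case the patch keeps
    Index *= StoreSize;
  return Start - Index; // lowest byte the loop writes
}

// Example: Start = 0x1024, BECount = 9, StoreSize = 4 yields base 0x1000;
// the memset then covers the same 40 bytes [0x1000, 0x1028) as the loop.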
@@ -1137,8 +1214,17 @@
 
   // Okay, everything looks good, insert the memset.
-  const SCEV *NumBytesS =
-      getNumBytes(BECount, IntIdxTy, StoreSize, CurLoop, DL, SE);
+  // NumBytes = TripCount * StoreSize.
+  const SCEV *TripCountS = getTripCount(BECount, IntIdxTy, CurLoop, DL, SE);
+  const SCEVConstant *ConstSize = dyn_cast<SCEVConstant>(StoreSizeSCEV);
+  const SCEV *NumBytesS;
+  if (ConstSize && ConstSize->getAPInt() == 1)
+    NumBytesS = TripCountS;
+  else
+    NumBytesS =
+        SE->getMulExpr(TripCountS,
+                       SE->getTruncateOrZeroExtend(StoreSizeSCEV, IntIdxTy),
+                       SCEV::FlagNUW);
 
   // TODO: ideally we should still be able to generate memset if SCEV expander
   // is taught to generate the dependencies at the latest point.
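Taken together, NumBytes is now built as TripCount * StoreSize at the SCEV level, so the emitted memset length no longer has to be a compile-time constant. This is groundwork: processLoopMemSet above still rejects non-constant lengths via cast<ConstantInt>, but once that restriction is lifted, a loop like the following (hypothetical example, assuming the stride and the store size both equal n and the rows are proven non-aliasing) becomes expressible as a single call:

#include <cstring>

// Each iteration memsets one row of length n at stride n; stride == store
// size, so the rows tile a contiguous n * n block.
void zero_matrix(char *p, unsigned long n) {
  for (unsigned long i = 0; i < n; ++i)
    std::memset(p + i * n, 0, n);
}

// With StoreSizeSCEV = n and TripCount = n, the idiom collapses to
// std::memset(p, 0, n * n); even though n is only known at runtime.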