diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -147,6 +147,132 @@
 namespace {
+/// A helper class to perform the following SCEV expression conversions:
+/// 1) "sext %val" to "zext %val"
+/// 2) "SOME_CONSTANT_VALUE smax %val" to "%val"
+/// The converter uses ScalarEvolution::isLoopEntryGuardedByCond to check
+/// whether an expression can be folded.
+class SCEVExprConverter {
+public:
+  Loop *CurLoop = nullptr;
+  ScalarEvolution *SE;
+
+  explicit SCEVExprConverter(ScalarEvolution *SE) : SE(SE) {
+    assert(SE != nullptr && "expect SE provided");
+  }
+
+  const SCEV *convertSCEV(const SCEV *Expr);
+};
+
+/// Try to fold the SCEV with regard to the loop guards of CurLoop.
+const SCEV *SCEVExprConverter::convertSCEV(const SCEV *Expr) {
+  switch (Expr->getSCEVType()) {
+  case scConstant:
+  case scUnknown:
+  case scCouldNotCompute:
+    return Expr;
+  case scTruncate: {
+    const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(Expr);
+    Type *Ty = Trunc->getType();
+    const SCEV *NewTrunc = convertSCEV(Trunc->getOperand());
+    return SE->getTruncateExpr(NewTrunc, Ty);
+  }
+  case scZeroExtend: {
+    const SCEVZeroExtendExpr *Zext = cast<SCEVZeroExtendExpr>(Expr);
+    Type *Ty = Zext->getType();
+    const SCEV *NewZext = convertSCEV(Zext->getOperand());
+    return SE->getZeroExtendExpr(NewZext, Ty);
+  }
+  case scSignExtend: {
+    // If the loop entry guards the operand to be non-negative, fold the sext
+    // into a zext; otherwise return the original expression.
+    const SCEVSignExtendExpr *Sext = cast<SCEVSignExtendExpr>(Expr);
+    if (!SE->isLoopEntryGuardedByCond(CurLoop, ICmpInst::ICMP_SGE, Sext,
+                                      SE->getZero(Sext->getType())))
+      return Sext;
+    const SCEV *NewZext = convertSCEV(Sext->getOperand());
+    return SE->getZeroExtendExpr(NewZext, Sext->getType());
+  }
+  case scAddExpr: {
+    const SCEVAddExpr *Add = cast<SCEVAddExpr>(Expr);
+    const SCEV *NewAdd = convertSCEV(Add->getOperand(0));
+    for (unsigned I = 1, E = Add->getNumOperands(); I != E; ++I)
+      NewAdd = SE->getAddExpr(NewAdd, convertSCEV(Add->getOperand(I)));
+    return NewAdd;
+  }
+  case scMulExpr: {
+    const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Expr);
+    const SCEV *NewMul = convertSCEV(Mul->getOperand(0));
+    for (unsigned I = 1, E = Mul->getNumOperands(); I != E; ++I)
+      NewMul = SE->getMulExpr(NewMul, convertSCEV(Mul->getOperand(I)));
+    return NewMul;
+  }
+  case scUDivExpr: {
+    const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(Expr);
+    const SCEV *NewLHS = convertSCEV(UDiv->getLHS());
+    const SCEV *NewRHS = convertSCEV(UDiv->getRHS());
+    return SE->getUDivExpr(NewLHS, NewRHS);
+  }
+  case scAddRecExpr:
+    llvm_unreachable("Do not expect AddRec here!");
+  case scUMaxExpr: {
+    const SCEVUMaxExpr *UMax = cast<SCEVUMaxExpr>(Expr);
+    const SCEV *NewUMax = convertSCEV(UMax->getOperand(0));
+    for (unsigned I = 1, E = UMax->getNumOperands(); I != E; ++I)
+      NewUMax = SE->getUMaxExpr(NewUMax, convertSCEV(UMax->getOperand(I)));
+    return NewUMax;
+  }
+  case scSMaxExpr: {
+    // Fold "SOME_CONSTANT_VALUE smax %val" to "%val" when the loop entry
+    // guards every other operand to be at least that constant; otherwise
+    // return the original expression.
+    const SCEVSMaxExpr *SMax = cast<SCEVSMaxExpr>(Expr);
+    const unsigned NumOfOps = SMax->getNumOperands();
+    bool Fold = false;
+    // If an operand is constant, it will be the first operand.
+    const SCEV *SMaxOp0 = SMax->getOperand(0);
+    const SCEVConstant *Cst = dyn_cast<SCEVConstant>(SMaxOp0);
+
+    if (Cst) {
+      // Check that every other operand is guarded to be at least the
+      // constant; if any is not, return the original expression.
+      Fold = true;
+      for (unsigned I = 1, E = NumOfOps; I != E; ++I) {
+        const SCEV *Ev = SMax->getOperand(I);
+        if (!SE->isLoopEntryGuardedByCond(CurLoop, ICmpInst::ICMP_SGE, Ev,
+                                          Cst))
+          return SMax;
+      }
+    }
+
+    // When folding, skip the constant first operand.
+    const unsigned StartIdx = Fold ? 1 : 0;
+    const SCEV *NewSMax = convertSCEV(SMax->getOperand(StartIdx));
+    for (unsigned I = StartIdx + 1, E = NumOfOps; I != E; ++I)
+      NewSMax = SE->getSMaxExpr(NewSMax, convertSCEV(SMax->getOperand(I)));
+    return NewSMax;
+  }
+  case scUMinExpr: {
+    const SCEVUMinExpr *UMin = cast<SCEVUMinExpr>(Expr);
+    const SCEV *NewUMin = convertSCEV(UMin->getOperand(0));
+    for (unsigned I = 1, E = UMin->getNumOperands(); I != E; ++I)
+      NewUMin = SE->getUMinExpr(NewUMin, convertSCEV(UMin->getOperand(I)));
+    return NewUMin;
+  }
+  case scSMinExpr: {
+    const SCEVSMinExpr *SMin = cast<SCEVSMinExpr>(Expr);
+    const SCEV *NewSMin = convertSCEV(SMin->getOperand(0));
+    for (unsigned I = 1, E = SMin->getNumOperands(); I != E; ++I)
+      NewSMin = SE->getSMinExpr(NewSMin, convertSCEV(SMin->getOperand(I)));
+    return NewSMin;
+  }
+  default:
+    llvm_unreachable("Unexpected SCEV expression!");
+  }
+}
+
 class LoopIdiomRecognize {
   Loop *CurLoop = nullptr;
   AliasAnalysis *AA;
@@ -159,6 +285,7 @@
   OptimizationRemarkEmitter &ORE;
   bool ApplyCodeSizeHeuristics;
   std::unique_ptr<MemorySSAUpdater> MSSAU;
+  SCEVExprConverter Converter;

 public:
   explicit LoopIdiomRecognize(AliasAnalysis *AA, DominatorTree *DT,
@@ -167,7 +294,8 @@
                               const TargetTransformInfo *TTI, MemorySSA *MSSA,
                               const DataLayout *DL,
                               OptimizationRemarkEmitter &ORE)
-      : AA(AA), DT(DT), LI(LI), SE(SE), TLI(TLI), TTI(TTI), DL(DL), ORE(ORE) {
+      : AA(AA), DT(DT), LI(LI), SE(SE), TLI(TLI), TTI(TTI), DL(DL), ORE(ORE),
+        Converter(SE) {
     if (MSSA)
       MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);
   }
@@ -217,7 +345,7 @@
   bool processLoopMemCpy(MemCpyInst *MCI, const SCEV *BECount);
   bool processLoopMemSet(MemSetInst *MSI, const SCEV *BECount);

-  bool processLoopStridedStore(Value *DestPtr, unsigned StoreSize,
+  bool processLoopStridedStore(Value *DestPtr, const SCEV *StoreSizeSCEV,
                                MaybeAlign StoreAlignment, Value *StoredVal,
                                Instruction *TheStore,
                                SmallPtrSetImpl<Instruction *> &Stores,
@@ -509,10 +637,6 @@
   if (!StoreEv || StoreEv->getLoop() != CurLoop || !StoreEv->isAffine())
     return LegalStoreKind::None;

-  // Check to see if we have a constant stride.
-  if (!isa<SCEVConstant>(StoreEv->getOperand(1)))
-    return LegalStoreKind::None;
-
   // See if the store can be turned into a memset.

   // If the stored value is a byte-wise value (like i32 -1), then it may be
@@ -786,7 +910,8 @@

   bool NegStride = StoreSize == -Stride;

-  if (processLoopStridedStore(StorePtr, StoreSize,
+  const SCEV *StoreSizeSCEV = SE->getConstant(BECount->getType(), StoreSize);
+  if (processLoopStridedStore(StorePtr, StoreSizeSCEV,
                               MaybeAlign(HeadStore->getAlignment()), StoredVal,
                               HeadStore, AdjacentStores, StoreEv, BECount,
                               NegStride)) {
@@ -896,7 +1021,7 @@
 bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
                                            const SCEV *BECount) {
-  // We can only handle non-volatile memsets with a constant size.
-  if (MSI->isVolatile() || !isa<ConstantInt>(MSI->getLength()))
+  // We can only handle non-volatile memsets.
+  if (MSI->isVolatile())
     return false;

   // If we're not allowed to hack on memset, we fail.
@@ -912,20 +1037,79 @@
   if (!Ev || Ev->getLoop() != CurLoop || !Ev->isAffine())
     return false;

-  // Reject memsets that are so large that they overflow an unsigned.
-  uint64_t SizeInBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue();
-  if ((SizeInBytes >> 32) != 0)
+  const SCEV *StrideSCEV = Ev->getOperand(1);
+  const SCEV *MemsetSizeSCEV = SE->getSCEV(MSI->getLength());
+  if (!StrideSCEV || !MemsetSizeSCEV)
     return false;

-  // Check to see if the stride matches the size of the memset. If so, then we
-  // know that every byte is touched in the loop.
-  const SCEVConstant *ConstStride = dyn_cast<SCEVConstant>(Ev->getOperand(1));
-  if (!ConstStride)
-    return false;
+  bool NegStride;
+  const bool IsConstantSize = isa<ConstantInt>(MSI->getLength());

-  APInt Stride = ConstStride->getAPInt();
-  if (SizeInBytes != Stride && SizeInBytes != -Stride)
-    return false;
+  if (IsConstantSize) {
+    // Memset size is constant.
+    // Reject memsets that are so large that they overflow an unsigned.
+    LLVM_DEBUG(dbgs() << "  memset size is constant\n");
+    uint64_t SizeInBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue();
+    if ((SizeInBytes >> 32) != 0)
+      return false;
+
+    // Check to see if the stride matches the size of the memset. If so, then
+    // we know that every byte is touched in the loop.
+    const SCEVConstant *ConstStride = dyn_cast<SCEVConstant>(Ev->getOperand(1));
+    if (!ConstStride)
+      return false;
+
+    APInt Stride = ConstStride->getAPInt();
+    if (SizeInBytes != Stride && SizeInBytes != -Stride)
+      return false;
+
+    NegStride = SizeInBytes == -Stride;
+  } else {
+    // Memset size is non-constant.
+    // Check whether the pointer stride matches the memset size. The pass only
+    // handles memset lengths and strides that are invariant for the top-level
+    // loop and, to be conservative, does not promote pointers outside address
+    // space zero. If StrideSCEV and MemsetSizeSCEV do not match directly,
+    // fold both expressions under the conditions known to hold at loop entry
+    // (the loop guards), compare again, and proceed only if they are equal.
+    LLVM_DEBUG(dbgs() << "  memset size is non-constant\n");
+    if (Pointer->getType()->getPointerAddressSpace() != 0) {
+      LLVM_DEBUG(dbgs() << "  pointer is not in address space zero\n");
+      return false;
+    }
+    if (!SE->isLoopInvariant(MemsetSizeSCEV, CurLoop) ||
+        !SE->isLoopInvariant(StrideSCEV, CurLoop)) {
+      LLVM_DEBUG(dbgs() << "  memset size or stride is not loop-invariant, "
+                        << "abort\n");
+      return false;
+    }
+
+    // Compare the positive-direction stride with the memset size.
+    NegStride = StrideSCEV->isNonConstantNegative();
+    const SCEV *PositiveStrideSCEV =
+        NegStride ? SE->getNegativeSCEV(StrideSCEV) : StrideSCEV;
+    LLVM_DEBUG(dbgs() << "  MemsetSizeSCEV: " << *MemsetSizeSCEV << "\n"
+                      << "  PositiveStrideSCEV: " << *PositiveStrideSCEV
+                      << "\n");
+
+    if (PositiveStrideSCEV != MemsetSizeSCEV) {
+      Converter.CurLoop = CurLoop;
+      const SCEV *FoldedPositiveStride =
+          Converter.convertSCEV(PositiveStrideSCEV);
+      const SCEV *FoldedMemsetSize = Converter.convertSCEV(MemsetSizeSCEV);
+      LLVM_DEBUG(
+          dbgs() << "  Try to fold SCEV expressions covered by loop guards\n"
+                 << "    FoldedMemsetSCEV: " << *FoldedMemsetSize << "\n"
+                 << "    FoldedPositiveStrideSCEV: " << *FoldedPositiveStride
+                 << "\n");
+
+      if (FoldedPositiveStride != FoldedMemsetSize) {
+        LLVM_DEBUG(dbgs() << "  SCEVs still mismatch after folding, abort\n");
+        return false;
+      }
+    }
+  }

   // Verify that the memset value is loop invariant.  If not, we can't promote
   // the memset.
@@ -935,10 +1119,9 @@

   SmallPtrSet<Instruction *, 1> MSIs;
   MSIs.insert(MSI);
-  bool NegStride = SizeInBytes == -Stride;
   return processLoopStridedStore(
-      Pointer, (unsigned)SizeInBytes, MaybeAlign(MSI->getDestAlignment()),
-      SplatValue, MSI, MSIs, Ev, BECount, NegStride, /*IsLoopMemset=*/true);
+      Pointer, MemsetSizeSCEV, MaybeAlign(MSI->getDestAlignment()), SplatValue,
+      MSI, MSIs, Ev, BECount, NegStride, /*IsLoopMemset=*/true);
 }

 /// mayLoopAccessLocation - Return true if the specified loop might access the
@@ -946,7 +1129,7 @@
 /// argument specifies what the verboten forms of access are (read or write).
 static bool mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
-                                  const SCEV *BECount, unsigned StoreSize,
+                                  const SCEV *BECount, const SCEV *StoreSizeSCEV,
                                   AliasAnalysis &AA,
                                   SmallPtrSetImpl<Instruction *> &IgnoredStores) {
   // Get the location that may be stored across the loop.  Since the access is
@@ -956,9 +1139,11 @@

   // If the loop iterates a fixed number of times, we can refine the access size
   // to be exactly the size of the memset, which is (BECount+1)*StoreSize
-  if (const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount))
+  const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount);
+  const SCEVConstant *ConstSize = dyn_cast<SCEVConstant>(StoreSizeSCEV);
+  if (BECst && ConstSize)
     AccessSize = LocationSize::precise((BECst->getValue()->getZExtValue() + 1) *
-                                       StoreSize);
+                                       ConstSize->getValue()->getZExtValue());

   // TODO: For this to be really effective, we have to dive into the pointer
   // operand in the store.  Store to &A[i] of 100 will always return may alias
@@ -977,19 +1162,71 @@
   return false;
 }

+// Forwards an unsigned StoreSize as a constant SCEV; the goal is to migrate
+// callers to the SCEV-based overload above so that LIR can deal with
+// runtime-determined store sizes.
+static bool mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
+                                  const SCEV *BECount, unsigned StoreSize,
+                                  AliasAnalysis &AA,
+                                  SmallPtrSetImpl<Instruction *> &IgnoredStores,
+                                  ScalarEvolution *SE) {
+  const SCEV *StoreSizeSCEV = SE->getConstant(BECount->getType(), StoreSize);
+  return mayLoopAccessLocation(Ptr, Access, L, BECount, StoreSizeSCEV, AA,
+                               IgnoredStores);
+}
+
 // If we have a negative stride, Start refers to the end of the memory location
 // we're trying to memset.  Therefore, we need to recompute the base pointer,
 // which is just Start - BECount*Size.
 static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount,
-                                        Type *IntPtr, unsigned StoreSize,
+                                        Type *IntPtr, const SCEV *StoreSizeSCEV,
                                         ScalarEvolution *SE) {
   const SCEV *Index = SE->getTruncateOrZeroExtend(BECount, IntPtr);
-  if (StoreSize != 1)
-    Index = SE->getMulExpr(Index, SE->getConstant(IntPtr, StoreSize),
+  const SCEVConstant *ConstSize = dyn_cast<SCEVConstant>(StoreSizeSCEV);
+  if (!ConstSize || (ConstSize->getAPInt() != 1)) {
+    // index = back-edge count * store size
+    Index = SE->getMulExpr(Index,
+                           SE->getTruncateOrZeroExtend(StoreSizeSCEV, IntPtr),
                            SCEV::FlagNUW);
+  }
+  // base pointer = start - index
   return SE->getMinusSCEV(Start, Index);
 }

+// Forwards an unsigned StoreSize as a constant SCEV; the goal is to migrate
+// callers to the SCEV-based overload above so that LIR can deal with
+// runtime-determined store sizes.
+static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount,
+                                        Type *IntPtr, unsigned StoreSize,
+                                        ScalarEvolution *SE) {
+  const SCEV *StoreSizeSCEV = SE->getConstant(BECount->getType(), StoreSize);
+  return getStartForNegStride(Start, BECount, IntPtr, StoreSizeSCEV, SE);
+}
+
+static const SCEV *getTripCount(const SCEV *BECount, Type *IntPtr,
+                                Loop *CurLoop, const DataLayout *DL,
+                                ScalarEvolution *SE) {
+  const SCEV *TripCountS = nullptr;
+  // The # stored bytes is (BECount+1).  Expand the trip count out to
+  // pointer size if it isn't already.
+  //
+  // If we're going to need to zero extend the BE count, check if we can add
+  // one to it prior to zero extending without overflow. Provided this is safe,
+  // it allows better simplification of the +1.
+  if (DL->getTypeSizeInBits(BECount->getType()) <
+          DL->getTypeSizeInBits(IntPtr) &&
+      SE->isLoopEntryGuardedByCond(
+          CurLoop, ICmpInst::ICMP_NE, BECount,
+          SE->getNegativeSCEV(SE->getOne(BECount->getType())))) {
+    TripCountS = SE->getZeroExtendExpr(
+        SE->getAddExpr(BECount, SE->getOne(BECount->getType()), SCEV::FlagNUW),
+        IntPtr);
+  } else {
+    TripCountS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr),
+                                SE->getOne(IntPtr), SCEV::FlagNUW);
+  }
+
+  return TripCountS;
+}
+
 /// Compute the number of bytes as a SCEV from the backedge taken count.
 ///
 /// This also maps the SCEV into the provided type and tries to handle the
@@ -1028,7 +1265,7 @@
 /// processLoopStridedStore - We see a strided store of some value.  If we can
 /// transform this into a memset or memset_pattern in the loop preheader, do so.
 bool LoopIdiomRecognize::processLoopStridedStore(
-    Value *DestPtr, unsigned StoreSize, MaybeAlign StoreAlignment,
+    Value *DestPtr, const SCEV *StoreSizeSCEV, MaybeAlign StoreAlignment,
     Value *StoredVal, Instruction *TheStore,
     SmallPtrSetImpl<Instruction *> &Stores, const SCEVAddRecExpr *Ev,
     const SCEV *BECount, bool NegStride, bool IsLoopMemset) {
@@ -1057,7 +1294,7 @@
   const SCEV *Start = Ev->getStart();
   // Handle negative strided loops.
   if (NegStride)
-    Start = getStartForNegStride(Start, BECount, IntIdxTy, StoreSize, SE);
+    Start = getStartForNegStride(Start, BECount, IntIdxTy, StoreSizeSCEV, SE);

   // TODO: ideally we should still be able to generate memset if SCEV expander
   // is taught to generate the dependencies at the latest point.
@@ -1082,7 +1319,7 @@
   Changed = true;

   if (mayLoopAccessLocation(BasePtr, ModRefInfo::ModRef, CurLoop, BECount,
-                            StoreSize, *AA, Stores))
+                            StoreSizeSCEV, *AA, Stores))
     return Changed;

   if (avoidLIRForMultiBlockLoop(/*IsMemset=*/true, IsLoopMemset))
@@ -1090,8 +1327,21 @@

   // Okay, everything looks good, insert the memset.

-  const SCEV *NumBytesS =
-      getNumBytes(BECount, IntIdxTy, StoreSize, CurLoop, DL, SE);
+  const SCEV *TripCountS = getTripCount(BECount, IntIdxTy, CurLoop, DL, SE);
+  const SCEVConstant *ConstSize = dyn_cast<SCEVConstant>(StoreSizeSCEV);
+  const SCEV *NumBytesS;
+  if (ConstSize && ConstSize->getAPInt() == 1) {
+    NumBytesS = TripCountS;
+    LLVM_DEBUG(dbgs() << "StoreSize = 1, NumBytesS: " << *NumBytesS << "\n");
+  } else {
+    NumBytesS = SE->getMulExpr(
+        TripCountS, SE->getTruncateOrZeroExtend(StoreSizeSCEV, IntIdxTy),
+        SCEV::FlagNUW);
+    LLVM_DEBUG(dbgs() << "  Calculate NumBytesS = TripCountS * StoreSizeSCEV\n"
+                      << "    TripCountS: " << *TripCountS << "\n"
+                      << "    StoreSizeSCEV: " << *StoreSizeSCEV << "\n"
+                      << "    NumBytesS: " << *NumBytesS << "\n");
+  }

   // TODO: ideally we should still be able to generate memset if SCEV expander
   // is taught to generate the dependencies at the latest point.
@@ -1245,11 +1495,11 @@

   bool UseMemMove =
       mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop, BECount,
-                            StoreSize, *AA, Stores);
+                            StoreSize, *AA, Stores, SE);
   if (UseMemMove) {
     Stores.insert(TheLoad);
     if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop,
-                              BECount, StoreSize, *AA, Stores)) {
+                              BECount, StoreSize, *AA, Stores, SE)) {
       ORE.emit([&]() {
         return OptimizationRemarkMissed(DEBUG_TYPE, "LoopMayAccessStore",
                                         TheStore)
@@ -1280,7 +1530,7 @@
   if (IsMemCpy)
     Stores.erase(TheStore);
   if (mayLoopAccessLocation(LoadBasePtr, ModRefInfo::Mod, CurLoop, BECount,
-                            StoreSize, *AA, Stores)) {
+                            StoreSize, *AA, Stores, SE)) {
     ORE.emit([&]() {
       return OptimizationRemarkMissed(DEBUG_TYPE, "LoopMayAccessLoad", TheLoad)
              << ore::NV("Inst", InstRemark) << " in "
diff --git a/llvm/test/Transforms/LoopIdiom/memset-runtime.ll b/llvm/test/Transforms/LoopIdiom/memset-runtime.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/LoopIdiom/memset-runtime.ll
@@ -0,0 +1,159 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes="function(loop(loop-idiom,loop-deletion),simplifycfg)" < %s -S | FileCheck %s
+; The C code to generate this testcase:
+; void test(int n, int m, int *ar)
+; {
+;   long i;
+;   for (i=0; i<n; ++i) {
+;     int *arr = ar + i * m;
+;     memset(arr, 0, m * sizeof(int));
+;   }
+; }
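For reference, the C routine above lowers to inner-loop IR of roughly the following shape. This is a hand-written sketch, assuming typed (non-opaque) pointers and illustrative value names; it is not the autogenerated test body, and the FileCheck assertions are omitted:

define void @test(i32 %n, i32 %m, i32* %ar) {
entry:
  %m.ext = sext i32 %m to i64          ; row length as i64
  %n.ext = sext i32 %n to i64          ; trip count as i64
  %size = mul i64 %m.ext, 4            ; memset length: m * sizeof(int)
  %guard = icmp sgt i64 %n.ext, 0
  br i1 %guard, label %loop, label %exit

loop:                                  ; memsets row i of the 2-D region
  %i = phi i64 [ 0, %entry ], [ %i.next, %loop ]
  %off = mul i64 %i, %m.ext            ; element offset i * m
  %row = getelementptr inbounds i32, i32* %ar, i64 %off
  %dst = bitcast i32* %row to i8*
  call void @llvm.memset.p0i8.i64(i8* align 4 %dst, i8 0, i64 %size, i1 false)
  %i.next = add nuw nsw i64 %i, 1
  %done = icmp eq i64 %i.next, %n.ext
  br i1 %done, label %exit, label %loop

exit:
  ret void
}

declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1 immarg)

Each iteration memsets one row of m ints while the destination AddRec advances by the same 4 * m bytes, so the memset length SCEV matches the positive stride SCEV (with the converter folding sext into zext under the loop guards when the two are spelled differently), and processLoopMemSet can rewrite the whole loop into a single memset of n * m * 4 bytes.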