diff --git a/llvm/include/llvm/Transforms/Scalar/LoopIdiomRecognize.h b/llvm/include/llvm/Transforms/Scalar/LoopIdiomRecognize.h
--- a/llvm/include/llvm/Transforms/Scalar/LoopIdiomRecognize.h
+++ b/llvm/include/llvm/Transforms/Scalar/LoopIdiomRecognize.h
@@ -44,7 +44,11 @@
                         LoopStandardAnalysisResults &AR, LPMUpdater &U);
 };
 
-// NFC LoopNestPass with regards to the current LoopPass-LoopIdiomRecognize
+// The LoopNestIdiomRecognize is a LoopNestPass that feeds LoopNest into
+// LoopIdiomRecognize. The main difference from LoopIdiomRecognize is that it
+// allows runtime-determined store size optimization by versioning and creates
+// runtime checks on the top-level loop. The reason to only version on the
+// top-level loop is to avoid the exponential growth of versioning.
 class LoopNestIdiomRecognizePass
     : public PassInfoMixin<LoopNestIdiomRecognizePass> {
 public:
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -96,6 +96,7 @@
 #include "llvm/Transforms/Utils/BuildLibCalls.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
+#include "llvm/Transforms/Utils/LoopVersioning.h"
 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
 #include <algorithm>
 #include <cassert>
@@ -108,6 +109,8 @@
 #define DEBUG_TYPE "loop-idiom"
 
 STATISTIC(NumMemSet, "Number of memset's formed from loop stores");
+STATISTIC(NumMemSetRuntimeLength,
+          "Number of memset's formed from memset with runtime length");
 STATISTIC(NumMemCpy, "Number of memcpy's formed from loop load+stores");
 STATISTIC(
     NumShiftUntilBitTest,
@@ -144,11 +147,24 @@
              "with -Os/-Oz"),
     cl::init(true), cl::Hidden);
 
+static cl::opt<bool> ForceNoLoopVersion(
+    DEBUG_TYPE "-no-loop-version",
+    cl::desc("Force not to create loop versions if the user guarantees that "
+             "the length of each array dimension is a positive value, and "
+             "the product of the lengths of all array dimensions does not "
+             "exceed the range of type size_t"),
+    cl::init(false), cl::Hidden);
+
 namespace {
 
+using SCEVExprPairList = SmallVector<std::pair<const SCEV *, unsigned>, 4>;
+
 class LoopIdiomRecognize {
   Loop *CurLoop = nullptr;
   LoopNest *LN;
+  Loop *TopLoop = nullptr;
+  Loop *FallBackLoop = nullptr;
+  BasicBlock *RuntimeCheckBB = nullptr;
   AliasAnalysis *AA;
   DominatorTree *DT;
   LoopInfo *LI;
@@ -159,6 +175,7 @@
   OptimizationRemarkEmitter &ORE;
   bool ApplyCodeSizeHeuristics;
   std::unique_ptr<MemorySSAUpdater> MSSAU;
+  SCEVExprPairList *SizeAddrSpacePairList;
 
 public:
   explicit LoopIdiomRecognize(AliasAnalysis *AA, DominatorTree *DT,
@@ -166,11 +183,14 @@
                               TargetLibraryInfo *TLI,
                               const TargetTransformInfo *TTI, MemorySSA *MSSA,
                               const DataLayout *DL,
-                              OptimizationRemarkEmitter &ORE)
+                              OptimizationRemarkEmitter &ORE,
+                              SCEVExprPairList *SizeAddrSpacePairList)
       : LN(LN), AA(AA), DT(DT), LI(LI), SE(SE), TLI(TLI), TTI(TTI), DL(DL),
-        ORE(ORE) {
+        ORE(ORE), SizeAddrSpacePairList(SizeAddrSpacePairList) {
     if (MSSA)
       MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);
+    if (LN)
+      TopLoop = &LN->getOutermostLoop();
   }
 
   bool runOnLoopNest();
@@ -235,6 +255,8 @@
                                   const SCEV *BECount);
   bool avoidLIRForMultiBlockLoop(bool IsMemset = false,
                                  bool IsLoopMemset = false);
+  bool isTopLoopVersioned() const { return RuntimeCheckBB != nullptr; }
+  void versionTopLoop();
 
   /// @}
   /// \name Noncountable Loop Idiom Handling
@@ -295,7 +317,8 @@
     // but ORE cannot be preserved (see comment before the pass definition).
     OptimizationRemarkEmitter ORE(L->getHeader()->getParent());
 
-    LoopIdiomRecognize LIR(AA, DT, LI, nullptr, SE, TLI, TTI, MSSA, DL, ORE);
+    LoopIdiomRecognize LIR(AA, DT, LI, nullptr, SE, TLI, TTI, MSSA, DL, ORE,
+                           nullptr);
     return LIR.runOnLoop(L);
   }
 
@@ -311,6 +334,39 @@
 
 } // end anonymous namespace
 
+/// Helper to generate "(X >> (PtrBits / 2)) != 0", e.g. "(X >> 32) != 0".
+static Value *generateOverflowPredicate(const SCEV *Ev, unsigned AddrSpace,
+                                        BranchInst *BI, const DataLayout *DL,
+                                        ScalarEvolution *SE,
+                                        IRBuilder<> &Builder) {
+  SCEVExpander Expander(*SE, *DL, "loop-idiom-overflow");
+  Value *StoreSize = Expander.expandCodeFor(Ev, Ev->getType(), BI);
+
+  // An lshr by >= the bit width is poison: widen values that are too narrow.
+  Type *IntPtrTy = Builder.getIntPtrTy(*DL, AddrSpace);
+  uint64_t SizeInBits = DL->getTypeSizeInBits(IntPtrTy);
+  if (StoreSize->getType()->getIntegerBitWidth() <= (SizeInBits >> 1))
+    StoreSize = Builder.CreateZExt(StoreSize, IntPtrTy);
+  Type *Ty = StoreSize->getType();
+  Value *StoreSizeShift =
+      Builder.CreateLShr(StoreSize, ConstantInt::get(Ty, SizeInBits >> 1));
+
+  return Builder.CreateICmpNE(StoreSizeShift, ConstantInt::get(Ty, 0));
+}
+
+/// Helper function to generate predicate "X < 0".
+static Value *generateSltZeroPredicate(const SCEV *Ev, BranchInst *BI,
+                                       const DataLayout *DL,
+                                       ScalarEvolution *SE,
+                                       IRBuilder<> &Builder) {
+  SCEVExpander Expander(*SE, *DL, "loop-idiom-non-negative");
+  Type *Ty = Ev->getType();
+  Value *Val = Expander.expandCodeFor(Ev, Ty, BI);
+  Constant *Zero = ConstantInt::get(Ty, 0);
+
+  return Builder.CreateICmpSLT(Val, Zero);
+}
+
 char LoopIdiomRecognizeLegacyPass::ID = 0;
 
 PreservedAnalyses LoopIdiomRecognizePass::run(Loop &L, LoopAnalysisManager &AM,
@@ -327,7 +383,7 @@
   OptimizationRemarkEmitter ORE(L.getHeader()->getParent());
 
   LoopIdiomRecognize LIR(&AR.AA, &AR.DT, &AR.LI, nullptr, &AR.SE, &AR.TLI,
-                         &AR.TTI, AR.MSSA, DL, ORE);
+                         &AR.TTI, AR.MSSA, DL, ORE, nullptr);
   if (!LIR.runOnLoop(&L))
     return PreservedAnalyses::all();
 
@@ -348,8 +404,9 @@
       &LN.getOutermostLoop().getHeader()->getModule()->getDataLayout();
 
   OptimizationRemarkEmitter ORE(LN.getOutermostLoop().getHeader()->getParent());
+  SCEVExprPairList SizeAddrSpacePairList;
   LoopIdiomRecognize LIR(&AR.AA, &AR.DT, &AR.LI, &LN, &AR.SE, &AR.TLI, &AR.TTI,
-                         AR.MSSA, DL, ORE);
+                         AR.MSSA, DL, ORE, &SizeAddrSpacePairList);
   if (!LIR.runOnLoopNest())
     return PreservedAnalyses::all();
 
@@ -400,6 +457,37 @@
     Changed |= runOnLoop(L);
   }
 
+  // After processing all the loops, we now add the stored conditions into
+  // the RuntimeCheckBB. Conditions are stored when:
+  //  - detect runtime store size in StridedStore (SizeAddrSpacePairList)
+  if (Changed && isTopLoopVersioned()) {
+    // Get the branch instruction in the runtime check basic block.
+    BranchInst *BI = dyn_cast<BranchInst>(RuntimeCheckBB->getTerminator());
+    assert(BI && "Expects a BranchInst");
+
+    // Create a conditional branch on the recorded expressions: if any
+    // expression X is negative or needs more than half of the pointer width
+    // (X >> (PtrBits / 2) != 0), the fallback loop is taken.
+    // Otherwise, the optimized loop is taken.
+    LLVMContext &Context = TopLoop->getHeader()->getContext();
+    Value *Cond = ConstantInt::getFalse(Context);
+
+    IRBuilder<> Builder(BI);
+    for (const auto &Pair : *SizeAddrSpacePairList) {
+      const SCEV *Ev = Pair.first;
+      unsigned AddrSpace = Pair.second;
+      Value *NewCond0 =
+          generateOverflowPredicate(Ev, AddrSpace, BI, DL, SE, Builder);
+      Value *NewCond1 = generateSltZeroPredicate(Ev, BI, DL, SE, Builder);
+      Cond = Builder.CreateOr(Cond, NewCond0);
+      Cond = Builder.CreateOr(Cond, NewCond1);
+    }
+
+    BranchInst::Create(FallBackLoop->getLoopPreheader(),
+                       LN->getOutermostLoop().getLoopPreheader(), Cond, BI);
+    deleteDeadInstruction(BI);
+  }
+
   return Changed;
 }
 
@@ -944,7 +1032,7 @@
 bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
                                            const SCEV *BECount) {
-  // We can only handle non-volatile memsets with a constant size.
-  if (MSI->isVolatile() || !isa<ConstantInt>(MSI->getLength()))
+  // We can only handle non-volatile memsets.
+  if (MSI->isVolatile())
     return false;
 
   // If we're not allowed to hack on memset, we fail.
@@ -960,24 +1048,65 @@
   if (!Ev || Ev->getLoop() != CurLoop || !Ev->isAffine())
     return false;
 
+  const SCEV *StrideSCEV = Ev->getOperand(1);
   const SCEV *MemsetSizeSCEV = SE->getSCEV(MSI->getLength());
-  if (!MemsetSizeSCEV)
+  if (!StrideSCEV || !MemsetSizeSCEV)
     return false;
 
-  // Reject memsets that are so large that they overflow an unsigned.
-  uint64_t SizeInBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue();
-  if ((SizeInBytes >> 32) != 0)
-    return false;
+  bool NegStride;
+  const bool IsConstantSize = isa<ConstantInt>(MSI->getLength());
+  if (IsConstantSize) {
+    // Memset size is constant.
+    // Reject memsets that are so large that they overflow an unsigned.
+    LLVM_DEBUG(dbgs() << " memset size is constant\n");
+    uint64_t SizeInBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue();
+    if ((SizeInBytes >> 32) != 0)
+      return false;
 
-  // Check to see if the stride matches the size of the memset.  If so, then we
-  // know that every byte is touched in the loop.
-  const SCEVConstant *ConstStride = dyn_cast<SCEVConstant>(Ev->getOperand(1));
-  if (!ConstStride)
-    return false;
+    // Check to see if the stride matches the size of the memset.  If so, then
+    // we know that every byte is touched in the loop.
+    const SCEVConstant *ConstStride = dyn_cast<SCEVConstant>(StrideSCEV);
+    if (!ConstStride)
+      return false;
 
-  APInt Stride = ConstStride->getAPInt();
-  if (SizeInBytes != Stride && SizeInBytes != -Stride)
-    return false;
+    APInt Stride = ConstStride->getAPInt();
+    if (SizeInBytes != Stride && SizeInBytes != -Stride)
+      return false;
+
+    NegStride = SizeInBytes == -Stride;
+  } else {
+    // Memset size is non-constant.
+    // Check if the stride matches the memset size, by comparing the SCEV
+    // expressions of the stride and memset size. The two expressions match
+    // if they are equal. If they match, then we know that every byte is
+    // touched in the loop. We only handle memset length and stride that
+    // are invariant for the top level loop.
+    LLVM_DEBUG(dbgs() << " memset size is non-constant\n");
+    if (LN == nullptr) {
+      LLVM_DEBUG(dbgs() << " need to call LNIR for non-constant memset "
+                        << "optimization\n");
+      return false;
+    }
+    if (!SE->isLoopInvariant(MemsetSizeSCEV, TopLoop) ||
+        !SE->isLoopInvariant(StrideSCEV, TopLoop)) {
+      LLVM_DEBUG(dbgs() << " memset size or stride is not a loop-invariant, "
+                        << "abort\n");
+      return false;
+    }
+
+    // Compare the positive-direction StrideSCEV with MemsetSizeSCEV.
+    NegStride = StrideSCEV->isNonConstantNegative();
+    const SCEV *PositiveStrideSCEV =
+        NegStride ? SE->getNegativeSCEV(StrideSCEV) : StrideSCEV;
+    LLVM_DEBUG(dbgs() << " MemsetSizeSCEV: " << *MemsetSizeSCEV << "\n"
+                      << " PositiveStrideSCEV: " << *PositiveStrideSCEV
+                      << "\n");
+
+    if (PositiveStrideSCEV != MemsetSizeSCEV) {
+      // TODO: add converter on SCEV (can be expanded)
+      return false;
+    }
+  }
 
   // Verify that the memset value is loop invariant.  If not, we can't promote
   // the memset.
@@ -987,7 +1116,6 @@
   SmallPtrSet<Instruction *, 1> MSIs;
   MSIs.insert(MSI);
 
-  bool NegStride = SizeInBytes == -Stride;
   return processLoopStridedStore(
       Pointer, MemsetSizeSCEV, MaybeAlign(MSI->getDestAlignment()), SplatValue,
       MSI, MSIs, Ev, BECount, NegStride, /*IsLoopMemset=*/true);
@@ -1216,15 +1344,37 @@
 
   // NumBytes = TripCount * StoreSize
   const SCEV *TripCountS = getTripCount(BECount, IntIdxTy, CurLoop, DL, SE);
+
+  // Only LoopNestIdiomRecognize (LN != nullptr) handles runtime store sizes.
+  // Give up if the store size is not constant and the trip count is not
+  // invariant in the top-level loop: the versioning checks are expanded in
+  // front of the top-level loop, where such a trip count is unavailable.
+  // This also applies when ForceNoLoopVersion suppresses the versioning.
+  if (LN != nullptr && !SE->isLoopInvariant(TripCountS, TopLoop)) {
+    const bool IsConstantSize = isa<SCEVConstant>(StoreSizeSCEV);
+    if (IsLoopMemset && !IsConstantSize) {
+      LLVM_DEBUG(dbgs() << "trip count is variant in the top-level loop, "
+                        << "abort\n");
+      return Changed;
+    }
+  }
+
   const SCEVConstant *ConstSize = dyn_cast<SCEVConstant>(StoreSizeSCEV);
   const SCEV *NumBytesS;
-  if (!ConstSize && ConstSize->getAPInt() == 1)
+
+  if (ConstSize && ConstSize->getAPInt() == 1) {
     NumBytesS = TripCountS;
-  else
+    LLVM_DEBUG(dbgs() << "StoreSize = 1, NumBytesS: " << *NumBytesS << "\n");
+  } else {
     NumBytesS =
         SE->getMulExpr(TripCountS,
                        SE->getTruncateOrZeroExtend(StoreSizeSCEV, IntIdxTy),
                        SCEV::FlagNUW);
+    LLVM_DEBUG(dbgs() << " Calculate NumBytesS = TripCountS * StoreSizeSCEV\n"
+                      << " TripCountS: " << *TripCountS << "\n"
+                      << " StoreSizeSCEV: " << *StoreSizeSCEV << "\n"
+                      << " NumBytesS: " << *NumBytesS << "\n");
+  }
 
   // TODO: ideally we should still be able to generate memset if SCEV expander
   // is taught to generate the dependencies at the latest point.
@@ -1234,6 +1384,34 @@
   Value *NumBytes =
       Expander.expandCodeFor(NumBytesS, IntIdxTy, Preheader->getTerminator());
 
+  // If the memset size is not a constant, we will need to version the top
+  // level loop in the current loop nest with runtime checks. We are going
+  // to version on only the top level loop once to avoid exponential growth
+  // of versioning.
+  // Here we check whether the top level clone has been created yet, and create
+  // it if it hasn't. The initial runtime check is set to false and the
+  // conditions are filled in after all the loops have been processed.
+  const bool IsConstantSize = isa<SCEVConstant>(StoreSizeSCEV);
+  if (LN != nullptr && IsLoopMemset && !IsConstantSize && !ForceNoLoopVersion) {
+    if (!isTopLoopVersioned()) {
+      LLVM_DEBUG(dbgs() << " Create versioning for top loop\n");
+      versionTopLoop();
+
+      // If current loop is the top loop, versioning would change the loop's
+      // preheader to RuntimeCheckBB, so we need to reset the insert point.
+      if (CurLoop == TopLoop) {
+        Preheader = CurLoop->getLoopPreheader();
+        Builder.SetInsertPoint(Preheader->getTerminator());
+      }
+    }
+
+    // Record the SCEV expressions of the store size and the trip count. They
+    // are later used to generate the runtime check conditions for the
+    // top-level loop versioning.
+    SizeAddrSpacePairList->emplace_back(StoreSizeSCEV, DestAS);
+    SizeAddrSpacePairList->emplace_back(TripCountS, DestAS);
+  }
+
   CallInst *NewCall;
   if (SplatValue) {
     NewCall = Builder.CreateMemSet(BasePtr, SplatValue, NumBytes,
@@ -1291,6 +1469,8 @@
     MSSAU->getMemorySSA()->verifyMemorySSA();
   ++NumMemSet;
   ExpCleaner.markResultUsed();
+  if (IsLoopMemset && !IsConstantSize)
+    ++NumMemSetRuntimeLength;
   return true;
 }
 
@@ -1503,6 +1683,18 @@
   return false;
 }
 
+/// Version TopLoop, recording the runtime-check block and the fallback loop.
+void LoopIdiomRecognize::versionTopLoop() {
+  const LoopAccessInfo LAI(TopLoop, SE, TLI, AA, DT, LI);
+  LoopVersioning LV(LAI, LAI.getRuntimePointerChecking()->getChecks(), TopLoop,
+                    LI, DT, SE);
+
+  LV.versionLoopWithPlainRuntimeCheck();
+
+  RuntimeCheckBB = LV.getRuntimeCheckBB();
+  FallBackLoop = LV.getNonVersionedLoop();
+}
+
 bool LoopIdiomRecognize::runOnNoncountableLoop() {
   LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F["
                     << CurLoop->getHeader()->getParent()->getName()
diff --git a/llvm/test/Transforms/LoopIdiom/memset-runtime.ll b/llvm/test/Transforms/LoopIdiom/memset-runtime.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/LoopIdiom/memset-runtime.ll
@@ -0,0 +1,47 @@
+; Roughly equivalent C code for this testcase (reconstructed from the IR):
+; void test(int ar[][m], long n, long m)
+; {
+;   long i;
+;   for (i=0; i<n; ++i)
+;     memset(ar[i], 0, m * sizeof(int));
+; }
+;
+; The expected result is conceptually:
+; void test(int ar[][m], long n, long m)
+; {
+;   if ((n >> 32) != 0 || (m >> 32) != 0)
+;     /* optimization result identical to LoopIdiomRecognize */
+;   else
+;     /* hoists memset to loop-preheader */
+; }
+
+; RUN: opt -S -passes="loop-nest-idiom" < %s
+; TODO - auto-generate CHECKs
+
+define void @test(i32* nocapture %ar, i64 %n, i64 %m) {
+entry:
+  %0 = shl nuw i64 %m, 2
+  br label %for.cond1.preheader
+
+for.cond1.preheader:                              ; preds = %for.inc4, %entry
+  %i.017 = phi i64 [ 0, %entry ], [ %inc5, %for.inc4 ]
+  %1 = mul i64 %m, %i.017
+  %scevgep = getelementptr i32, i32* %ar, i64 %1
+  %scevgep1 = bitcast i32* %scevgep to i8*
+  %mul = mul nsw i64 %i.017, %m
+  call void @llvm.memset.p0i8.i64(i8* align 4 %scevgep1, i8 0, i64 %0, i1 false)
+  br label %for.inc4
+
+for.inc4:                                         ; preds = %for.cond1.preheader
+  %inc5 = add nuw nsw i64 %i.017, 1
+  %exitcond18.not = icmp eq i64 %inc5, %n
+  br i1 %exitcond18.not, label %for.end6, label %for.cond1.preheader
+
+for.end6:                                         ; preds = %for.inc4
+  ret void
+}
+
+; Function Attrs: argmemonly nofree nounwind willreturn writeonly
+declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1 immarg) #0
+
+attributes #0 = { argmemonly nofree nounwind willreturn writeonly }