diff --git a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h --- a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h +++ b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h @@ -339,12 +339,8 @@ /// to a checking group if we will still be able to get /// the upper and lower bounds of the check. Returns true in case /// of success, false otherwise. - bool addPointer(unsigned Index); + bool addPointer(unsigned Index, RuntimePointerChecking &RtCheck); - /// Constitutes the context of this pointer checking group. For each - /// pointer that is a member of this group we will retain the index - /// at which it appears in RtCheck. - RuntimePointerChecking &RtCheck; /// The SCEV expression which represents the upper bound of all the /// pointers in this group. const SCEV *High; @@ -469,6 +465,10 @@ ScalarEvolution *getSE() const { return SE; } + /// Remove all runtime checks for which \p Predicate returns true. + void + removeCheckIf(function_ref Predicate); + private: /// Groups pointers such that a single memcheck is required /// between two different groups. This will clear the CheckingGroups vector diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h --- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h +++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h @@ -45,6 +45,7 @@ typedef std::pair RuntimePointerCheck; +class RuntimePointerChecking; template class Optional; template class SmallSetVector; @@ -440,8 +441,7 @@ /// second value is the final comparator value or NULL if no check is needed.
std::pair addRuntimeChecks(Instruction *Loc, Loop *TheLoop, - const SmallVectorImpl &PointerChecks, - ScalarEvolution *SE); + const RuntimePointerChecking &RtPtrChecking); } // end namespace llvm diff --git a/llvm/include/llvm/Transforms/Utils/LoopVersioning.h b/llvm/include/llvm/Transforms/Utils/LoopVersioning.h --- a/llvm/include/llvm/Transforms/Utils/LoopVersioning.h +++ b/llvm/include/llvm/Transforms/Utils/LoopVersioning.h @@ -29,6 +29,7 @@ typedef std::pair RuntimePointerCheck; +class RuntimePointerChecking; template class ArrayRef; @@ -43,8 +44,8 @@ - /// It uses runtime check provided by the user. If \p UseLAIChecks is true, - /// we will retain the default checks made by LAI. Otherwise, construct an - /// object having no checks and we expect the user to add them. - LoopVersioning(const LoopAccessInfo &LAI, - ArrayRef Checks, Loop *L, LoopInfo *LI, + /// The memchecks to version for are taken from \p RtPtrChecking and the + /// SCEV checks from \p Preds. Both are stored by reference and must + /// outlive this object. + LoopVersioning(const RuntimePointerChecking &RtPtrChecking, + const SCEVUnionPredicate &Preds, Loop *L, LoopInfo *LI, DominatorTree *DT, ScalarEvolution *SE); /// Performs the CFG manipulation part of versioning the loop including @@ -80,7 +81,7 @@ /// /// This is just wrapper that calls prepareNoAliasMetadata and /// annotateInstWithNoAlias on the instructions of the versioned loop. - void annotateLoopWithNoAlias(); + void annotateLoopWithNoAlias(const LoopAccessInfo &LAI); /// Set up the aliasing scopes based on the memchecks. This needs to /// be called before the first call to annotateInstWithNoAlias. @@ -119,9 +120,6 @@ /// in NonVersionedLoop. ValueToValueMapTy VMap; - /// The set of alias checks that we are versioning for. - SmallVector AliasChecks; - /// The set of SCEV checks that we are versioning for. const SCEVUnionPredicate &Preds; @@ -137,7 +135,7 @@ GroupToNonAliasingScopeList; /// Analyses used.
- const LoopAccessInfo &LAI; + const RuntimePointerChecking &RtPtrChecking; LoopInfo *LI; DominatorTree *DT; ScalarEvolution *SE; diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp --- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp +++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp @@ -176,8 +176,7 @@ RuntimeCheckingPtrGroup::RuntimeCheckingPtrGroup( unsigned Index, RuntimePointerChecking &RtCheck) - : RtCheck(RtCheck), High(RtCheck.Pointers[Index].End), - Low(RtCheck.Pointers[Index].Start) { + : High(RtCheck.Pointers[Index].End), Low(RtCheck.Pointers[Index].Start) { Members.push_back(Index); } @@ -284,7 +283,8 @@ return I; } -bool RuntimeCheckingPtrGroup::addPointer(unsigned Index) { +bool RuntimeCheckingPtrGroup::addPointer(unsigned Index, + RuntimePointerChecking &RtCheck) { const SCEV *Start = RtCheck.Pointers[Index].Start; const SCEV *End = RtCheck.Pointers[Index].End; @@ -413,7 +413,7 @@ TotalComparisons++; - if (Group->addPointer(Pointer)) { + if (Group->addPointer(Pointer, *this)) { Merged = true; break; } @@ -459,6 +459,11 @@ return true; } +void RuntimePointerChecking::removeCheckIf( + function_ref Predicate) { + Checks.erase(remove_if(Checks, Predicate), Checks.end()); +} + void RuntimePointerChecking::printChecks( raw_ostream &OS, const SmallVectorImpl &Checks, unsigned Depth) const { diff --git a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp --- a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp +++ b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp @@ -791,10 +791,9 @@ // If we need run-time checks, version the loop now.
auto PtrToPartition = Partitions.computePartitionSetForPointers(*LAI); - const auto *RtPtrChecking = LAI->getRuntimePointerChecking(); - const auto &AllChecks = RtPtrChecking->getChecks(); - auto Checks = includeOnlyCrossPartitionChecks(AllChecks, PtrToPartition, - RtPtrChecking); + RuntimePointerChecking RtPtrChecking = *LAI->getRuntimePointerChecking(); + const auto &Checks = + includeOnlyCrossPartitionChecks(PtrToPartition, RtPtrChecking); if (LAI->hasConvergentOp() && !Checks.empty()) { return fail("RuntimeCheckWithConvergent", @@ -814,9 +813,9 @@ LLVM_DEBUG(dbgs() << "\nPointers:\n"); LLVM_DEBUG(LAI->getRuntimePointerChecking()->printChecks(dbgs(), Checks)); - LoopVersioning LVer(*LAI, Checks, L, LI, DT, SE); + LoopVersioning LVer(RtPtrChecking, Pred, L, LI, DT, SE); LVer.versionLoop(DefsUsedOutside); - LVer.annotateLoopWithNoAlias(); + LVer.annotateLoopWithNoAlias(*LAI); // The unversioned loop will not be changed, so we inherit all attributes // from the original loop, but remove the loop distribution metadata to @@ -901,37 +900,34 @@ /// \p PtrToPartition contains the partition number for pointers. Partition /// number -1 means that the pointer is used in multiple partitions. In this - /// case we can't safely omit the check. - SmallVector includeOnlyCrossPartitionChecks( - const SmallVectorImpl &AllChecks, - const SmallVectorImpl &PtrToPartition, - const RuntimePointerChecking *RtPtrChecking) { + /// case we can't safely omit the check. Filters \p RtPtrChecking in place + /// and returns a reference to its remaining checks. + const SmallVector & + includeOnlyCrossPartitionChecks(const SmallVectorImpl &PtrToPartition, + RuntimePointerChecking &RtPtrChecking) { - SmallVector Checks; - copy_if(AllChecks, std::back_inserter(Checks), - [&](const RuntimePointerCheck &Check) { - for (unsigned PtrIdx1 : Check.first->Members) - for (unsigned PtrIdx2 : Check.second->Members) - // Only include this check if there is a pair of pointers - // that require checking and the pointers fall into - // separate partitions.
- // - // (Note that we already know at this point that the two - // pointer groups need checking but it doesn't follow - // that each pair of pointers within the two groups need - // checking as well. - // - // In other words we don't want to include a check just - // because there is a pair of pointers between the two - // pointer groups that require checks and a different - // pair whose pointers fall into different partitions.) - if (RtPtrChecking->needsChecking(PtrIdx1, PtrIdx2) && - !RuntimePointerChecking::arePointersInSamePartition( - PtrToPartition, PtrIdx1, PtrIdx2)) - return true; - return false; - }); - - return Checks; + RtPtrChecking.removeCheckIf([&](const RuntimePointerCheck &Check) { + for (unsigned PtrIdx1 : Check.first->Members) + for (unsigned PtrIdx2 : Check.second->Members) + // Keep this check only if there is a pair of pointers + // that require checking and the pointers fall into + // separate partitions. + // + // (Note that we already know at this point that the two + // pointer groups need checking but it doesn't follow + // that each pair of pointers within the two groups need + // checking as well. + // + // In other words we don't want to include a check just + // because there is a pair of pointers between the two + // pointer groups that require checks and a different + // pair whose pointers fall into different partitions.)
+ if (RtPtrChecking.needsChecking(PtrIdx1, PtrIdx2) && + !RuntimePointerChecking::arePointersInSamePartition( + PtrToPartition, PtrIdx1, PtrIdx2)) + return false; + return true; + }); + return RtPtrChecking.getChecks(); } /// Check whether the loop metadata is forcing distribution to be diff --git a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp --- a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp +++ b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp @@ -377,8 +377,9 @@ - /// Determine the pointer alias checks to prove that there are no - /// intervening stores. - SmallVector collectMemchecks( - const SmallVectorImpl &Candidates) { + /// Remove from \p RtPointerChecking the alias checks that are not needed to + /// prove that there are no intervening stores; returns the remaining checks. + const SmallVector &filterMemchecks( + const SmallVectorImpl &Candidates, + RuntimePointerChecking &RtPointerChecking) { SmallPtrSet PtrsWrittenOnFwdingPath = findPointersWrittenOnForwardingPath(Candidates); @@ -388,23 +389,21 @@ for (const auto &Candidate : Candidates) CandLoadPtrs.insert(Candidate.getLoadPtr()); - const auto &AllChecks = LAI.getRuntimePointerChecking()->getChecks(); - SmallVector Checks; - - copy_if(AllChecks, std::back_inserter(Checks), - [&](const RuntimePointerCheck &Check) { - for (auto PtrIdx1 : Check.first->Members) - for (auto PtrIdx2 : Check.second->Members) - if (needsChecking(PtrIdx1, PtrIdx2, PtrsWrittenOnFwdingPath, - CandLoadPtrs)) - return true; - return false; - }); - + RtPointerChecking.removeCheckIf( + [this, &PtrsWrittenOnFwdingPath, + &CandLoadPtrs](const RuntimePointerCheck &Check) { + for (auto PtrIdx1 : Check.first->Members) + for (auto PtrIdx2 : Check.second->Members) + if (needsChecking(PtrIdx1, PtrIdx2, PtrsWrittenOnFwdingPath, + CandLoadPtrs)) + return false; + return true; + }); + + auto &Checks = RtPointerChecking.getChecks(); LLVM_DEBUG(dbgs() << "\nPointer Checks (count: " << Checks.size() << "):\n"); - LLVM_DEBUG(LAI.getRuntimePointerChecking()->printChecks(dbgs(), Checks)); - + LLVM_DEBUG(RtPointerChecking.printChecks(dbgs(),
Checks)); return Checks; } @@ -518,7 +517,9 @@ // Check intervening may-alias stores. These need runtime checks for alias // disambiguation. - SmallVector Checks = collectMemchecks(Candidates); + RuntimePointerChecking RtPointerChecking = *LAI.getRuntimePointerChecking(); // Local copy: filtering below must not touch LAI's own checks. + const SmallVector &Checks = + filterMemchecks(Candidates, RtPointerChecking); // Too many checks are likely to outweigh the benefits of forwarding. if (Checks.size() > Candidates.size() * CheckPerElim) { @@ -559,7 +560,8 @@ // Point of no-return, start the transformation. First, version the loop // if necessary. - LoopVersioning LV(LAI, Checks, L, LI, DT, PSE.getSE()); + LoopVersioning LV(RtPointerChecking, PSE.getUnionPredicate(), L, LI, DT, + PSE.getSE()); LV.versionLoop(); } diff --git a/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp b/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp --- a/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp +++ b/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp @@ -603,8 +603,8 @@ // Create memcheck for memory accessed inside loop. // Clone original loop, and set blocks properly. DominatorTree *DT = &getAnalysis().getDomTree(); - LoopVersioning LVer(*LAI, LAI->getRuntimePointerChecking()->getChecks(), - CurLoop, LI, DT, SE); + LoopVersioning LVer(*LAI->getRuntimePointerChecking(), + LAI->getPSE().getUnionPredicate(), CurLoop, LI, DT, SE); LVer.versionLoop(); // Set Loop Versioning metaData for original loop. addStringMetadataToLoop(LVer.getNonVersionedLoop(), LICMVersioningMetaData); diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -1573,10 +1573,11 @@ /// Expand code for the lower and upper bound of the pointer group \p CG /// in \p TheLoop. \return the values for the bounds.
static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG, + const RuntimePointerChecking &RtPtrChecking, Loop *TheLoop, Instruction *Loc, SCEVExpander &Exp, ScalarEvolution *SE) { // TODO: Add helper to retrieve pointers to CG. - Value *Ptr = CG->RtCheck.Pointers[CG->Members[0]].PointerValue; + Value *Ptr = RtPtrChecking.Pointers[CG->Members[0]].PointerValue; const SCEV *Sc = SE->getSCEV(Ptr); unsigned AS = Ptr->getType()->getPointerAddressSpace(); @@ -1612,32 +1613,32 @@ /// Turns a collection of checks into a collection of expanded upper and /// lower bounds for both pointers in the check. static SmallVector, 4> -expandBounds(const SmallVectorImpl &PointerChecks, Loop *L, +expandBounds(const RuntimePointerChecking &RtPtrChecking, Loop *L, Instruction *Loc, ScalarEvolution *SE, SCEVExpander &Exp) { SmallVector, 4> ChecksWithBounds; // Here we're relying on the SCEV Expander's cache to only emit code for the // same bounds once. - transform(PointerChecks, std::back_inserter(ChecksWithBounds), + transform(RtPtrChecking.getChecks(), std::back_inserter(ChecksWithBounds), [&](const RuntimePointerCheck &Check) { - PointerBounds First = expandBounds(Check.first, L, Loc, Exp, SE), - Second = - expandBounds(Check.second, L, Loc, Exp, SE); + PointerBounds First = expandBounds(Check.first, RtPtrChecking, L, + Loc, Exp, SE), + Second = expandBounds(Check.second, RtPtrChecking, + L, Loc, Exp, SE); return std::make_pair(First, Second); }); return ChecksWithBounds; } -std::pair llvm::addRuntimeChecks( - Instruction *Loc, Loop *TheLoop, - const SmallVectorImpl &PointerChecks, - ScalarEvolution *SE) { +std::pair +llvm::addRuntimeChecks(Instruction *Loc, Loop *TheLoop, + const RuntimePointerChecking &RtPtrChecking) { // TODO: Move noalias annotation code from LoopVersioning here and share with LV if possible.
- // TODO: Pass RtPtrChecking instead of PointerChecks and SE separately, if possible const DataLayout &DL = TheLoop->getHeader()->getModule()->getDataLayout(); + ScalarEvolution *SE = RtPtrChecking.getSE(); SCEVExpander Exp(*SE, DL, "induction"); - auto ExpandedChecks = expandBounds(PointerChecks, TheLoop, Loc, SE, Exp); + auto ExpandedChecks = expandBounds(RtPtrChecking, TheLoop, Loc, SE, Exp); LLVMContext &Ctx = Loc->getContext(); Instruction *FirstInst = nullptr; diff --git a/llvm/lib/Transforms/Utils/LoopVersioning.cpp b/llvm/lib/Transforms/Utils/LoopVersioning.cpp --- a/llvm/lib/Transforms/Utils/LoopVersioning.cpp +++ b/llvm/lib/Transforms/Utils/LoopVersioning.cpp @@ -32,14 +32,12 @@ cl::desc("Add no-alias annotation for instructions that " "are disambiguated by memchecks")); -LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI, - ArrayRef Checks, Loop *L, +LoopVersioning::LoopVersioning(const RuntimePointerChecking &RtPtrChecking, + const SCEVUnionPredicate &Preds, Loop *L, LoopInfo *LI, DominatorTree *DT, ScalarEvolution *SE) - : VersionedLoop(L), NonVersionedLoop(nullptr), - AliasChecks(Checks.begin(), Checks.end()), - Preds(LAI.getPSE().getUnionPredicate()), LAI(LAI), LI(LI), DT(DT), - SE(SE) { + : VersionedLoop(L), NonVersionedLoop(nullptr), Preds(Preds), + RtPtrChecking(RtPtrChecking), LI(LI), DT(DT), SE(SE) { assert(L->getExitBlock() && "No single exit block"); assert(L->isLoopSimplifyForm() && "Loop is not in loop-simplify form"); } @@ -53,10 +51,8 @@ // Add the memcheck in the original preheader (this is empty initially).
BasicBlock *RuntimeCheckBB = VersionedLoop->getLoopPreheader(); - const auto &RtPtrChecking = *LAI.getRuntimePointerChecking(); - std::tie(FirstCheckInst, MemRuntimeCheck) = - addRuntimeChecks(RuntimeCheckBB->getTerminator(), VersionedLoop, - AliasChecks, RtPtrChecking.getSE()); + std::tie(FirstCheckInst, MemRuntimeCheck) = addRuntimeChecks( + RuntimeCheckBB->getTerminator(), VersionedLoop, RtPtrChecking); SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(), "scev.check"); @@ -166,7 +162,6 @@ // pointers memchecked together) to an alias scope and then also mapping each // group to the list of scopes it can't alias. - const RuntimePointerChecking *RtPtrChecking = LAI.getRuntimePointerChecking(); LLVMContext &Context = VersionedLoop->getHeader()->getContext(); // First allocate an aliasing scope for each pointer checking group. @@ -177,11 +172,11 @@ MDBuilder MDB(Context); MDNode *Domain = MDB.createAnonymousAliasScopeDomain("LVerDomain"); - for (const auto &Group : RtPtrChecking->CheckingGroups) { + for (const auto &Group : RtPtrChecking.CheckingGroups) { GroupToScope[&*Group] = MDB.createAnonymousAliasScope(Domain); for (unsigned PtrIdx : Group->Members) - PtrToGroup[RtPtrChecking->getPointerInfo(PtrIdx).PointerValue] = &*Group; + PtrToGroup[RtPtrChecking.getPointerInfo(PtrIdx).PointerValue] = &*Group; } // Go through the checks and for each pointer group, collect the scopes for @@ -189,7 +184,7 @@ DenseMap> GroupToNonAliasingScopes; - for (const auto &Check : AliasChecks) + for (const auto &Check : RtPtrChecking.getChecks()) // NOTE(review): assumes Check.first/second are the same objects keyed into GroupToScope above; confirm this still holds when RtPtrChecking is a by-value copy (e.g. LoopDistribute). GroupToNonAliasingScopes[Check.first].push_back(GroupToScope[Check.second]); // Finally, transform the above to actually map to scope list which is what @@ -199,7 +194,7 @@ GroupToNonAliasingScopeList[Pair.first] = MDNode::get(Context, Pair.second); } -void LoopVersioning::annotateLoopWithNoAlias() { +void LoopVersioning::annotateLoopWithNoAlias(const LoopAccessInfo &LAI) { if (!AnnotateNoAlias) return; @@ -276,10 +271,10 @@ if
(L->isLoopSimplifyForm() && !LAI.hasConvergentOp() && (LAI.getNumRuntimePointerChecks() || !LAI.getPSE().getUnionPredicate().isAlwaysTrue())) { - LoopVersioning LVer(LAI, LAI.getRuntimePointerChecking()->getChecks(), - L, LI, DT, SE); + LoopVersioning LVer(*LAI.getRuntimePointerChecking(), + LAI.getPSE().getUnionPredicate(), L, LI, DT, SE); LVer.versionLoop(); - LVer.annotateLoopWithNoAlias(); + LVer.annotateLoopWithNoAlias(LAI); Changed = true; } } diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -2813,8 +2813,7 @@ Instruction *FirstCheckInst; Instruction *MemRuntimeCheck; std::tie(FirstCheckInst, MemRuntimeCheck) = - addRuntimeChecks(MemCheckBlock->getTerminator(), OrigLoop, - RtPtrChecking.getChecks(), RtPtrChecking.getSE()); + addRuntimeChecks(MemCheckBlock->getTerminator(), OrigLoop, RtPtrChecking); assert(MemRuntimeCheck && "no RT checks generated although RtPtrChecking " "claimed checks are required"); @@ -2852,10 +2851,9 @@ // We currently don't use LoopVersioning for the actual loop cloning but we // still use it to add the noalias metadata. - LVer = std::make_unique( - *Legal->getLAI(), - Legal->getLAI()->getRuntimePointerChecking()->getChecks(), OrigLoop, LI, - DT, PSE.getSE()); + LVer = std::make_unique(*LAI->getRuntimePointerChecking(), // NOTE(review): assumes a local LAI (== Legal->getLAI()) is in scope here — confirm. + LAI->getPSE().getUnionPredicate(), + OrigLoop, LI, DT, PSE.getSE()); LVer->prepareNoAliasMetadata(); }