diff --git a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h --- a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h +++ b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h @@ -324,9 +324,44 @@ void mergeInStatus(VectorizationSafetyStatus S); }; +class RuntimePointerChecking; +/// A grouping of pointers. A single memcheck is required between +/// two groups. +struct CheckingPtrGroup { + /// Create a new pointer checking group containing a single + /// pointer, with index \p Index in RtCheck. + CheckingPtrGroup(unsigned Index, RuntimePointerChecking &RtCheck); + + /// Tries to add the pointer recorded in RtCheck at index + /// \p Index to this pointer checking group. We can only add a pointer + /// to a checking group if we will still be able to get + /// the upper and lower bounds of the check. Returns true in case + /// of success, false otherwise. + bool addPointer(unsigned Index); + + /// Constitutes the context of this pointer checking group. For each + /// pointer that is a member of this group we will retain the index + /// at which it appears in RtCheck. + RuntimePointerChecking &RtCheck; + /// The SCEV expression which represents the upper bound of all the + /// pointers in this group. + const SCEV *High; + /// The SCEV expression which represents the lower bound of all the + /// pointers in this group. + const SCEV *Low; + /// Indices of all the pointers that constitute this grouping. + SmallVector Members; +}; + +/// A memcheck which made up of a pair of grouped pointers. +typedef std::pair + PointerCheck; + /// Holds information about the memory runtime legality checks to verify /// that a group of pointers do not overlap. class RuntimePointerChecking { + friend struct CheckingPtrGroup; + public: struct PointerInfo { /// Holds the pointer value that we need to check. @@ -376,47 +411,6 @@ /// No run-time memory checking is necessary. bool empty() const { return Pointers.empty(); } - /// A grouping of pointers. A single memcheck is required between - /// two groups. - struct CheckingPtrGroup { - /// Create a new pointer checking group containing a single - /// pointer, with index \p Index in RtCheck. - CheckingPtrGroup(unsigned Index, RuntimePointerChecking &RtCheck) - : RtCheck(RtCheck), High(RtCheck.Pointers[Index].End), - Low(RtCheck.Pointers[Index].Start) { - Members.push_back(Index); - } - - /// Tries to add the pointer recorded in RtCheck at index - /// \p Index to this pointer checking group. We can only add a pointer - /// to a checking group if we will still be able to get - /// the upper and lower bounds of the check. Returns true in case - /// of success, false otherwise. - bool addPointer(unsigned Index); - - /// Constitutes the context of this pointer checking group. For each - /// pointer that is a member of this group we will retain the index - /// at which it appears in RtCheck. - RuntimePointerChecking &RtCheck; - /// The SCEV expression which represents the upper bound of all the - /// pointers in this group. - const SCEV *High; - /// The SCEV expression which represents the lower bound of all the - /// pointers in this group. - const SCEV *Low; - /// Indices of all the pointers that constitute this grouping. - SmallVector Members; - }; - - /// A memcheck which made up of a pair of grouped pointers. - /// - /// These *have* to be const for now, since checks are generated from - /// CheckingPtrGroups in LAI::addRuntimeChecks which is a const member - /// function. FIXME: once check-generation is moved inside this class (after - /// the PtrPartition hack is removed), we could drop const. - typedef std::pair - PointerCheck; - /// Generate the checks and store it. This also performs the grouping /// of pointers to reduce the number of memchecks necessary. void generateChecks(MemoryDepChecker::DepCandidates &DepCands, @@ -559,8 +553,7 @@ /// second value is the final comparator value or NULL if no check is needed. std::pair addRuntimeChecks(Instruction *Loc, - const SmallVectorImpl - &PointerChecks) const; + const SmallVectorImpl &PointerChecks) const; /// The diagnostics report generated for the analysis. E.g. why we /// couldn't analyze the loop. diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h --- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h +++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h @@ -49,6 +49,9 @@ template class SmallVector; template class SmallVectorImpl; template class SmallPriorityWorklist; +struct CheckingPtrGroup; +typedef std::pair + PointerCheck; BasicBlock *InsertPreheaderForLoop(Loop *L, DominatorTree *DT, LoopInfo *LI, MemorySSAUpdater *MSSAU, bool PreserveLCSSA); diff --git a/llvm/include/llvm/Transforms/Utils/LoopVersioning.h b/llvm/include/llvm/Transforms/Utils/LoopVersioning.h --- a/llvm/include/llvm/Transforms/Utils/LoopVersioning.h +++ b/llvm/include/llvm/Transforms/Utils/LoopVersioning.h @@ -15,7 +15,6 @@ #ifndef LLVM_TRANSFORMS_UTILS_LOOPVERSIONING_H #define LLVM_TRANSFORMS_UTILS_LOOPVERSIONING_H -#include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/ValueMapper.h" @@ -26,6 +25,9 @@ class LoopAccessInfo; class LoopInfo; class ScalarEvolution; +struct CheckingPtrGroup; +typedef std::pair + PointerCheck; /// This class emits a version of the loop where run-time checks ensure /// that may-alias pointers can't overlap. @@ -71,8 +73,7 @@ Loop *getNonVersionedLoop() { return NonVersionedLoop; } /// Sets the runtime alias checks for versioning the loop. - void setAliasChecks( - SmallVector Checks); + void setAliasChecks(SmallVector Checks); /// Sets the runtime SCEV checks for versioning the loop. void setSCEVChecks(SCEVUnionPredicate Check); @@ -122,23 +123,20 @@ ValueToValueMapTy VMap; /// The set of alias checks that we are versioning for. - SmallVector AliasChecks; + SmallVector AliasChecks; /// The set of SCEV checks that we are versioning for. SCEVUnionPredicate Preds; /// Maps a pointer to the pointer checking group that the pointer /// belongs to. - DenseMap - PtrToGroup; + DenseMap PtrToGroup; /// The alias scope corresponding to a pointer checking group. - DenseMap - GroupToScope; + DenseMap GroupToScope; /// The list of alias scopes that a pointer checking group can't alias. - DenseMap - GroupToNonAliasingScopeList; + DenseMap GroupToNonAliasingScopeList; /// Analyses used. const LoopAccessInfo &LAI; diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp --- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp +++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp @@ -174,6 +174,13 @@ return OrigSCEV; } +CheckingPtrGroup::CheckingPtrGroup(unsigned Index, + RuntimePointerChecking &RtCheck) + : RtCheck(RtCheck), High(RtCheck.Pointers[Index].End), + Low(RtCheck.Pointers[Index].Start) { + Members.push_back(Index); +} + /// Calculate Start and End points of memory access. /// Let's assume A is the first access and B is a memory access on N-th loop /// iteration. Then B is calculated as: @@ -231,14 +238,13 @@ Pointers.emplace_back(Ptr, ScStart, ScEnd, WritePtr, DepSetId, ASId, Sc); } -SmallVector -RuntimePointerChecking::generateChecks() const { +SmallVector RuntimePointerChecking::generateChecks() const { SmallVector Checks; for (unsigned I = 0; I < CheckingGroups.size(); ++I) { for (unsigned J = I + 1; J < CheckingGroups.size(); ++J) { - const RuntimePointerChecking::CheckingPtrGroup &CGI = CheckingGroups[I]; - const RuntimePointerChecking::CheckingPtrGroup &CGJ = CheckingGroups[J]; + const CheckingPtrGroup &CGI = CheckingGroups[I]; + const CheckingPtrGroup &CGJ = CheckingGroups[J]; if (needsChecking(CGI, CGJ)) Checks.push_back(std::make_pair(&CGI, &CGJ)); @@ -277,7 +283,7 @@ return I; } -bool RuntimePointerChecking::CheckingPtrGroup::addPointer(unsigned Index) { +bool CheckingPtrGroup::addPointer(unsigned Index) { const SCEV *Start = RtCheck.Pointers[Index].Start; const SCEV *End = RtCheck.Pointers[Index].End; diff --git a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp --- a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp +++ b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp @@ -903,15 +903,14 @@ /// \p PtrToPartition contains the partition number for pointers. Partition /// number -1 means that the pointer is used in multiple partitions. In this /// case we can't safely omit the check. - SmallVector - includeOnlyCrossPartitionChecks( - const SmallVectorImpl &AllChecks, + SmallVector includeOnlyCrossPartitionChecks( + const SmallVectorImpl &AllChecks, const SmallVectorImpl &PtrToPartition, const RuntimePointerChecking *RtPtrChecking) { - SmallVector Checks; + SmallVector Checks; copy_if(AllChecks, std::back_inserter(Checks), - [&](const RuntimePointerChecking::PointerCheck &Check) { + [&](const PointerCheck &Check) { for (unsigned PtrIdx1 : Check.first->Members) for (unsigned PtrIdx2 : Check.second->Members) // Only include this check if there is a pair of pointers diff --git a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp --- a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp +++ b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp @@ -377,7 +377,7 @@ /// Determine the pointer alias checks to prove that there are no /// intervening stores. - SmallVector collectMemchecks( + SmallVector collectMemchecks( const SmallVectorImpl &Candidates) { SmallPtrSet PtrsWrittenOnFwdingPath = @@ -391,10 +391,10 @@ std::mem_fn(&StoreToLoadForwardingCandidate::getLoadPtr)); const auto &AllChecks = LAI.getRuntimePointerChecking()->getChecks(); - SmallVector Checks; + SmallVector Checks; copy_if(AllChecks, std::back_inserter(Checks), - [&](const RuntimePointerChecking::PointerCheck &Check) { + [&](const PointerCheck &Check) { for (auto PtrIdx1 : Check.first->Members) for (auto PtrIdx2 : Check.second->Members) if (needsChecking(PtrIdx1, PtrIdx2, PtrsWrittenOnFwdingPath, @@ -520,8 +520,7 @@ // Check intervening may-alias stores. These need runtime checks for alias // disambiguation. - SmallVector Checks = - collectMemchecks(Candidates); + SmallVector Checks = collectMemchecks(Candidates); // Too many checks are likely to outweigh the benefits of forwarding. if (Checks.size() > Candidates.size() * CheckPerElim) { diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -23,6 +23,7 @@ #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/MemorySSA.h" diff --git a/llvm/lib/Transforms/Utils/LoopVersioning.cpp b/llvm/lib/Transforms/Utils/LoopVersioning.cpp --- a/llvm/lib/Transforms/Utils/LoopVersioning.cpp +++ b/llvm/lib/Transforms/Utils/LoopVersioning.cpp @@ -44,8 +44,7 @@ } } -void LoopVersioning::setAliasChecks( - SmallVector Checks) { +void LoopVersioning::setAliasChecks(SmallVector Checks) { AliasChecks = std::move(Checks); } @@ -194,8 +193,7 @@ // Go through the checks and for each pointer group, collect the scopes for // each non-aliasing pointer group. - DenseMap> + DenseMap> GroupToNonAliasingScopes; for (const auto &Check : AliasChecks)