diff --git a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h --- a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h +++ b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h @@ -324,9 +324,45 @@ void mergeInStatus(VectorizationSafetyStatus S); }; +class RuntimePointerChecking; +/// A grouping of pointers. A single memcheck is required between +/// two groups. +struct RuntimeCheckingPtrGroup { + /// Create a new pointer checking group containing a single + /// pointer, with index \p Index in RtCheck. + RuntimeCheckingPtrGroup(unsigned Index, RuntimePointerChecking &RtCheck); + + /// Tries to add the pointer recorded in RtCheck at index + /// \p Index to this pointer checking group. We can only add a pointer + /// to a checking group if we will still be able to get + /// the upper and lower bounds of the check. Returns true in case + /// of success, false otherwise. + bool addPointer(unsigned Index); + + /// Constitutes the context of this pointer checking group. For each + /// pointer that is a member of this group we will retain the index + /// at which it appears in RtCheck. + RuntimePointerChecking &RtCheck; + /// The SCEV expression which represents the upper bound of all the + /// pointers in this group. + const SCEV *High; + /// The SCEV expression which represents the lower bound of all the + /// pointers in this group. + const SCEV *Low; + /// Indices of all the pointers that constitute this grouping. + SmallVector Members; +}; + +/// A memcheck which is made up of a pair of grouped pointers. +typedef std::pair + RuntimePointerCheck; + /// Holds information about the memory runtime legality checks to verify /// that a group of pointers do not overlap. class RuntimePointerChecking { + friend struct RuntimeCheckingPtrGroup; + public: struct PointerInfo { /// Holds the pointer value that we need to check. @@ -376,59 +412,20 @@ /// No run-time memory checking is necessary.
bool empty() const { return Pointers.empty(); } - /// A grouping of pointers. A single memcheck is required between - /// two groups. - struct CheckingPtrGroup { - /// Create a new pointer checking group containing a single - /// pointer, with index \p Index in RtCheck. - CheckingPtrGroup(unsigned Index, RuntimePointerChecking &RtCheck) - : RtCheck(RtCheck), High(RtCheck.Pointers[Index].End), - Low(RtCheck.Pointers[Index].Start) { - Members.push_back(Index); - } - - /// Tries to add the pointer recorded in RtCheck at index - /// \p Index to this pointer checking group. We can only add a pointer - /// to a checking group if we will still be able to get - /// the upper and lower bounds of the check. Returns true in case - /// of success, false otherwise. - bool addPointer(unsigned Index); - - /// Constitutes the context of this pointer checking group. For each - /// pointer that is a member of this group we will retain the index - /// at which it appears in RtCheck. - RuntimePointerChecking &RtCheck; - /// The SCEV expression which represents the upper bound of all the - /// pointers in this group. - const SCEV *High; - /// The SCEV expression which represents the lower bound of all the - /// pointers in this group. - const SCEV *Low; - /// Indices of all the pointers that constitute this grouping. - SmallVector Members; - }; - - /// A memcheck which made up of a pair of grouped pointers. - /// - /// These *have* to be const for now, since checks are generated from - /// CheckingPtrGroups in LAI::addRuntimeChecks which is a const member - /// function. FIXME: once check-generation is moved inside this class (after - /// the PtrPartition hack is removed), we could drop const. - typedef std::pair - PointerCheck; - /// Generate the checks and store it. This also performs the grouping /// of pointers to reduce the number of memchecks necessary. 
void generateChecks(MemoryDepChecker::DepCandidates &DepCands, bool UseDependencies); /// Returns the checks that generateChecks created. - const SmallVector &getChecks() const { return Checks; } + const SmallVector &getChecks() const { + return Checks; + } /// Decide if we need to add a check between two groups of pointers, /// according to needsChecking. - bool needsChecking(const CheckingPtrGroup &M, - const CheckingPtrGroup &N) const; + bool needsChecking(const RuntimeCheckingPtrGroup &M, + const RuntimeCheckingPtrGroup &N) const; /// Returns the number of run-time checks required according to /// needsChecking. @@ -438,7 +435,8 @@ void print(raw_ostream &OS, unsigned Depth = 0) const; /// Print \p Checks. - void printChecks(raw_ostream &OS, const SmallVectorImpl &Checks, + void printChecks(raw_ostream &OS, + const SmallVectorImpl &Checks, unsigned Depth = 0) const; /// This flag indicates if we need to add the runtime check. @@ -448,7 +446,7 @@ SmallVector Pointers; /// Holds a partitioning of pointers into "check groups". - SmallVector CheckingGroups; + SmallVector CheckingGroups; /// Check if pointers are in the same partition /// @@ -476,15 +474,14 @@ bool UseDependencies); /// Generate the checks and return them. - SmallVector - generateChecks() const; + SmallVector generateChecks() const; /// Holds a pointer to the ScalarEvolution analysis. ScalarEvolution *SE; /// Set of run-time checks required to establish independence of /// otherwise may-aliasing pointers in the loop. - SmallVector Checks; + SmallVector Checks; }; /// Drive the analysis of memory accesses in the loop @@ -557,10 +554,9 @@ /// Returns a pair of instructions where the first element is the first /// instruction generated in possibly a sequence of instructions and the /// second value is the final comparator value or NULL if no check is needed. 
- std::pair - addRuntimeChecks(Instruction *Loc, - const SmallVectorImpl - &PointerChecks) const; + std::pair addRuntimeChecks( + Instruction *Loc, + const SmallVectorImpl &PointerChecks) const; /// The diagnostics report generated for the analysis. E.g. why we /// couldn't analyze the loop. diff --git a/llvm/include/llvm/Transforms/Utils/LoopVersioning.h b/llvm/include/llvm/Transforms/Utils/LoopVersioning.h --- a/llvm/include/llvm/Transforms/Utils/LoopVersioning.h +++ b/llvm/include/llvm/Transforms/Utils/LoopVersioning.h @@ -15,7 +15,6 @@ #ifndef LLVM_TRANSFORMS_UTILS_LOOPVERSIONING_H #define LLVM_TRANSFORMS_UTILS_LOOPVERSIONING_H -#include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/ValueMapper.h" @@ -26,6 +25,10 @@ class LoopAccessInfo; class LoopInfo; class ScalarEvolution; +struct RuntimeCheckingPtrGroup; +typedef std::pair + RuntimePointerCheck; /// This class emits a version of the loop where run-time checks ensure /// that may-alias pointers can't overlap. @@ -71,8 +74,7 @@ Loop *getNonVersionedLoop() { return NonVersionedLoop; } /// Sets the runtime alias checks for versioning the loop. - void setAliasChecks( - SmallVector Checks); + void setAliasChecks(SmallVector Checks); /// Sets the runtime SCEV checks for versioning the loop. void setSCEVChecks(SCEVUnionPredicate Check); @@ -122,22 +124,20 @@ ValueToValueMapTy VMap; /// The set of alias checks that we are versioning for. - SmallVector AliasChecks; + SmallVector AliasChecks; /// The set of SCEV checks that we are versioning for. SCEVUnionPredicate Preds; /// Maps a pointer to the pointer checking group that the pointer /// belongs to. - DenseMap - PtrToGroup; + DenseMap PtrToGroup; /// The alias scope corresponding to a pointer checking group. - DenseMap - GroupToScope; + DenseMap GroupToScope; /// The list of alias scopes that a pointer checking group can't alias. 
- DenseMap + DenseMap GroupToNonAliasingScopeList; /// Analyses used. diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp --- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp +++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp @@ -174,6 +174,13 @@ return OrigSCEV; } +RuntimeCheckingPtrGroup::RuntimeCheckingPtrGroup( + unsigned Index, RuntimePointerChecking &RtCheck) + : RtCheck(RtCheck), High(RtCheck.Pointers[Index].End), + Low(RtCheck.Pointers[Index].Start) { + Members.push_back(Index); +} + /// Calculate Start and End points of memory access. /// Let's assume A is the first access and B is a memory access on N-th loop /// iteration. Then B is calculated as: @@ -231,14 +238,14 @@ Pointers.emplace_back(Ptr, ScStart, ScEnd, WritePtr, DepSetId, ASId, Sc); } -SmallVector +SmallVector RuntimePointerChecking::generateChecks() const { - SmallVector Checks; + SmallVector Checks; for (unsigned I = 0; I < CheckingGroups.size(); ++I) { for (unsigned J = I + 1; J < CheckingGroups.size(); ++J) { - const RuntimePointerChecking::CheckingPtrGroup &CGI = CheckingGroups[I]; - const RuntimePointerChecking::CheckingPtrGroup &CGJ = CheckingGroups[J]; + const RuntimeCheckingPtrGroup &CGI = CheckingGroups[I]; + const RuntimeCheckingPtrGroup &CGJ = CheckingGroups[J]; if (needsChecking(CGI, CGJ)) Checks.push_back(std::make_pair(&CGI, &CGJ)); @@ -254,8 +261,8 @@ Checks = generateChecks(); } -bool RuntimePointerChecking::needsChecking(const CheckingPtrGroup &M, - const CheckingPtrGroup &N) const { +bool RuntimePointerChecking::needsChecking( + const RuntimeCheckingPtrGroup &M, const RuntimeCheckingPtrGroup &N) const { for (unsigned I = 0, EI = M.Members.size(); EI != I; ++I) for (unsigned J = 0, EJ = N.Members.size(); EJ != J; ++J) if (needsChecking(M.Members[I], N.Members[J])) @@ -277,7 +284,7 @@ return I; } -bool RuntimePointerChecking::CheckingPtrGroup::addPointer(unsigned Index) { +bool RuntimeCheckingPtrGroup::addPointer(unsigned Index) { const SCEV 
*Start = RtCheck.Pointers[Index].Start; const SCEV *End = RtCheck.Pointers[Index].End; @@ -352,7 +359,7 @@ // pointers to the same underlying object. if (!UseDependencies) { for (unsigned I = 0; I < Pointers.size(); ++I) - CheckingGroups.push_back(CheckingPtrGroup(I, *this)); + CheckingGroups.push_back(RuntimeCheckingPtrGroup(I, *this)); return; } @@ -378,7 +385,7 @@ MemoryDepChecker::MemAccessInfo Access(Pointers[I].PointerValue, Pointers[I].IsWritePtr); - SmallVector Groups; + SmallVector Groups; auto LeaderI = DepCands.findValue(DepCands.getLeaderValue(Access)); // Because DepCands is constructed by visiting accesses in the order in @@ -395,7 +402,7 @@ // Go through all the existing sets and see if we can find one // which can include this pointer. - for (CheckingPtrGroup &Group : Groups) { + for (RuntimeCheckingPtrGroup &Group : Groups) { // Don't perform more than a certain amount of comparisons. // This should limit the cost of grouping the pointers to something // reasonable. If we do end up hitting this threshold, the algorithm @@ -415,7 +422,7 @@ // We couldn't add this pointer to any existing set or the threshold // for the number of comparisons has been reached. Create a new group // to hold the current pointer. - Groups.push_back(CheckingPtrGroup(Pointer, *this)); + Groups.push_back(RuntimeCheckingPtrGroup(Pointer, *this)); } // We've computed the grouped checks for this partition. @@ -451,7 +458,7 @@ } void RuntimePointerChecking::printChecks( - raw_ostream &OS, const SmallVectorImpl &Checks, + raw_ostream &OS, const SmallVectorImpl &Checks, unsigned Depth) const { unsigned N = 0; for (const auto &Check : Checks) { @@ -2142,10 +2149,10 @@ /// Expand code for the lower and upper bound of the pointer group \p CG /// in \p TheLoop. \return the values for the bounds. 
-static PointerBounds -expandBounds(const RuntimePointerChecking::CheckingPtrGroup *CG, Loop *TheLoop, - Instruction *Loc, SCEVExpander &Exp, ScalarEvolution *SE, - const RuntimePointerChecking &PtrRtChecking) { +static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG, + Loop *TheLoop, Instruction *Loc, + SCEVExpander &Exp, ScalarEvolution *SE, + const RuntimePointerChecking &PtrRtChecking) { Value *Ptr = PtrRtChecking.Pointers[CG->Members[0]].PointerValue; const SCEV *Sc = SE->getSCEV(Ptr); @@ -2181,17 +2188,17 @@ /// Turns a collection of checks into a collection of expanded upper and /// lower bounds for both pointers in the check. -static SmallVector, 4> expandBounds( - const SmallVectorImpl &PointerChecks, - Loop *L, Instruction *Loc, ScalarEvolution *SE, SCEVExpander &Exp, - const RuntimePointerChecking &PtrRtChecking) { +static SmallVector, 4> +expandBounds(const SmallVectorImpl &PointerChecks, Loop *L, + Instruction *Loc, ScalarEvolution *SE, SCEVExpander &Exp, + const RuntimePointerChecking &PtrRtChecking) { SmallVector, 4> ChecksWithBounds; // Here we're relying on the SCEV Expander's cache to only emit code for the // same bounds once. 
transform( PointerChecks, std::back_inserter(ChecksWithBounds), - [&](const RuntimePointerChecking::PointerCheck &Check) { + [&](const RuntimePointerCheck &Check) { PointerBounds First = expandBounds(Check.first, L, Loc, Exp, SE, PtrRtChecking), Second = expandBounds(Check.second, L, Loc, Exp, SE, PtrRtChecking); @@ -2203,8 +2210,7 @@ std::pair LoopAccessInfo::addRuntimeChecks( Instruction *Loc, - const SmallVectorImpl &PointerChecks) - const { + const SmallVectorImpl &PointerChecks) const { const DataLayout &DL = TheLoop->getHeader()->getModule()->getDataLayout(); auto *SE = PSE->getSE(); SCEVExpander Exp(*SE, DL, "induction"); diff --git a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp --- a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp +++ b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp @@ -903,15 +903,14 @@ /// \p PtrToPartition contains the partition number for pointers. Partition /// number -1 means that the pointer is used in multiple partitions. In this /// case we can't safely omit the check. 
- SmallVector - includeOnlyCrossPartitionChecks( - const SmallVectorImpl &AllChecks, + SmallVector includeOnlyCrossPartitionChecks( + const SmallVectorImpl &AllChecks, const SmallVectorImpl &PtrToPartition, const RuntimePointerChecking *RtPtrChecking) { - SmallVector Checks; + SmallVector Checks; copy_if(AllChecks, std::back_inserter(Checks), - [&](const RuntimePointerChecking::PointerCheck &Check) { + [&](const RuntimePointerCheck &Check) { for (unsigned PtrIdx1 : Check.first->Members) for (unsigned PtrIdx2 : Check.second->Members) // Only include this check if there is a pair of pointers diff --git a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp --- a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp +++ b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp @@ -377,7 +377,7 @@ /// Determine the pointer alias checks to prove that there are no /// intervening stores. - SmallVector collectMemchecks( + SmallVector collectMemchecks( const SmallVectorImpl &Candidates) { SmallPtrSet PtrsWrittenOnFwdingPath = @@ -391,10 +391,10 @@ std::mem_fn(&StoreToLoadForwardingCandidate::getLoadPtr)); const auto &AllChecks = LAI.getRuntimePointerChecking()->getChecks(); - SmallVector Checks; + SmallVector Checks; copy_if(AllChecks, std::back_inserter(Checks), - [&](const RuntimePointerChecking::PointerCheck &Check) { + [&](const RuntimePointerCheck &Check) { for (auto PtrIdx1 : Check.first->Members) for (auto PtrIdx2 : Check.second->Members) if (needsChecking(PtrIdx1, PtrIdx2, PtrsWrittenOnFwdingPath, @@ -520,8 +520,7 @@ // Check intervening may-alias stores. These need runtime checks for alias // disambiguation. - SmallVector Checks = - collectMemchecks(Candidates); + SmallVector Checks = collectMemchecks(Candidates); // Too many checks are likely to outweigh the benefits of forwarding. 
if (Checks.size() > Candidates.size() * CheckPerElim) { diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -23,6 +23,7 @@ #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/MemorySSA.h" diff --git a/llvm/lib/Transforms/Utils/LoopVersioning.cpp b/llvm/lib/Transforms/Utils/LoopVersioning.cpp --- a/llvm/lib/Transforms/Utils/LoopVersioning.cpp +++ b/llvm/lib/Transforms/Utils/LoopVersioning.cpp @@ -45,7 +45,7 @@ } void LoopVersioning::setAliasChecks( - SmallVector Checks) { + SmallVector Checks) { AliasChecks = std::move(Checks); } @@ -194,8 +194,7 @@ // Go through the checks and for each pointer group, collect the scopes for // each non-aliasing pointer group. - DenseMap> + DenseMap> GroupToNonAliasingScopes; for (const auto &Check : AliasChecks)