diff --git a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
--- a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
+++ b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
@@ -388,12 +388,15 @@
     /// SCEV for the access.
     const SCEV *Expr;
 
+    /// Translated base SCEV for the access.
+    const SCEV *PtrExpr;
+
     PointerInfo(Value *PointerValue, const SCEV *Start, const SCEV *End,
                 bool IsWritePtr, unsigned DependencySetId, unsigned AliasSetId,
-                const SCEV *Expr)
+                const SCEV *Expr, const SCEV *PtrExpr)
         : PointerValue(PointerValue), Start(Start), End(End),
           IsWritePtr(IsWritePtr), DependencySetId(DependencySetId),
-          AliasSetId(AliasSetId), Expr(Expr) {}
+          AliasSetId(AliasSetId), Expr(Expr), PtrExpr(PtrExpr) {}
   };
 
   RuntimePointerChecking(ScalarEvolution *SE) : Need(false), SE(SE) {}
@@ -410,9 +413,8 @@
   /// according to the assumptions that we've made during the analysis.
   /// The method might also version the pointer stride according to \p Strides,
   /// and add new predicates to \p PSE.
-  void insert(Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId,
-              unsigned ASId, const ValueToValueMap &Strides,
-              PredicatedScalarEvolution &PSE);
+  void insert(Loop *Lp, Value *Ptr, const SCEV *PtrExpr, bool WritePtr,
+              unsigned DepSetId, unsigned ASId, PredicatedScalarEvolution &PSE);
 
   /// No run-time memory checking is necessary.
   bool empty() const { return Pointers.empty(); }
@@ -420,7 +422,8 @@
   /// Generate the checks and store it. This also performs the grouping
   /// of pointers to reduce the number of memchecks necessary.
   void generateChecks(MemoryDepChecker::DepCandidates &DepCands,
-                      bool UseDependencies);
+                      bool UseDependencies, PredicatedScalarEvolution &PSE,
+                      Loop &L, const ValueToValueMap &SymbolicStrides);
 
   /// Returns the checks that generateChecks created.
   const SmallVectorImpl<RuntimePointerCheck> &getChecks() const {
@@ -478,7 +481,8 @@
   /// and re-compute it. We will only group dependecies if \p UseDependencies
   /// is true, otherwise we will create a separate group for each pointer.
   void groupChecks(MemoryDepChecker::DepCandidates &DepCands,
-                   bool UseDependencies);
+                   bool UseDependencies, PredicatedScalarEvolution &PSE,
+                   Loop &L, const ValueToValueMap &SymbolicStrides);
 
   /// Generate the checks and return them.
   SmallVector<RuntimePointerCheck, 4> generateChecks() const;
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -189,13 +189,13 @@
 ///
 /// There is no conflict when the intervals are disjoint:
 /// NoConflict = (P2.Start >= P1.End) || (P1.Start >= P2.End)
-void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, bool WritePtr,
-                                    unsigned DepSetId, unsigned ASId,
-                                    const ValueToValueMap &Strides,
+void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, const SCEV *PtrExpr,
+                                    bool WritePtr, unsigned DepSetId,
+                                    unsigned ASId,
                                     PredicatedScalarEvolution &PSE) {
   // Get the stride replaced scev.
-  const SCEV *Sc = replaceSymbolicStrideSCEV(PSE, Strides, Ptr);
   ScalarEvolution *SE = PSE.getSE();
+  const SCEV *Sc = PtrExpr;
 
   const SCEV *ScStart;
   const SCEV *ScEnd;
@@ -231,7 +231,8 @@
     SE->getStoreSizeOfExpr(IdxTy, Ptr->getType()->getPointerElementType());
   ScEnd = SE->getAddExpr(ScEnd, EltSizeSCEV);
 
-  Pointers.emplace_back(Ptr, ScStart, ScEnd, WritePtr, DepSetId, ASId, Sc);
+  Pointers.emplace_back(Ptr, ScStart, ScEnd, WritePtr, DepSetId, ASId, Sc,
+                        PtrExpr);
 }
 
 SmallVector<RuntimePointerCheck, 4>
@@ -251,9 +252,11 @@
 }
 
 void RuntimePointerChecking::generateChecks(
-    MemoryDepChecker::DepCandidates &DepCands, bool UseDependencies) {
+    MemoryDepChecker::DepCandidates &DepCands, bool UseDependencies,
+    PredicatedScalarEvolution &PSE, Loop &L,
+    const ValueToValueMap &SymbolicStrides) {
   assert(Checks.empty() && "Checks is not empty");
-  groupChecks(DepCands, UseDependencies);
+  groupChecks(DepCands, UseDependencies, PSE, L, SymbolicStrides);
   Checks = generateChecks();
 }
 
@@ -317,8 +320,45 @@
   return true;
 }
 
+static void
+visitPointers2(Value *StartPtr, const Loop &InnermostLoop,
+               PredicatedScalarEvolution &PSE,
+               const ValueToValueMap &SymbolicStrides,
+               function_ref<void(Value *, const SCEV *)> AddPointer) {
+  SmallPtrSet<Value *, 8> Visited;
+  SmallVector<Value *> WorkList;
+  WorkList.push_back(StartPtr);
+
+  ScalarEvolution &SE = *PSE.getSE();
+  while (!WorkList.empty()) {
+    Value *Ptr = WorkList.pop_back_val();
+    if (!Visited.insert(Ptr).second)
+      continue;
+    auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
+    if (GEP && GEP->getNumOperands() == 2) {
+      if (auto *SI = dyn_cast<SelectInst>(GEP->getOperand(0))) {
+        const SCEV *BaseA = SE.getSCEV(SI->getOperand(1));
+        const SCEV *BaseB = SE.getSCEV(SI->getOperand(2));
+        const SCEV *Offset = SE.getSCEV(GEP->getOperand(1));
+        if (SE.getTypeSizeInBits(Offset->getType()) <
+            SE.getTypeSizeInBits(BaseA->getType()))
+          Offset = SE.getSignExtendExpr(
+              Offset, SE.getEffectiveSCEVType(BaseA->getType()));
+        auto *PtrA = SE.getAddExpr(BaseA, Offset, SCEV::FlagNUW);
+        auto *PtrB = SE.getAddExpr(BaseB, Offset, SCEV::FlagNUW);
+        AddPointer(Ptr, PtrA);
+        AddPointer(Ptr, PtrB);
+        continue;
+      }
+    }
+    AddPointer(Ptr, replaceSymbolicStrideSCEV(PSE, SymbolicStrides, Ptr));
+  }
+}
+
 void RuntimePointerChecking::groupChecks(
-    MemoryDepChecker::DepCandidates &DepCands, bool UseDependencies) {
+    MemoryDepChecker::DepCandidates &DepCands, bool UseDependencies,
+    PredicatedScalarEvolution &PSE, Loop &L,
+    const ValueToValueMap &SymbolicStrides) {
   // We build the groups from dependency candidates equivalence classes
   // because:
   //    - We know that pointers in the same equivalence class share
@@ -371,9 +411,9 @@
 
   unsigned TotalComparisons = 0;
 
-  DenseMap<Value *, unsigned> PositionMap;
+  DenseMap<const SCEV *, unsigned> PositionMap;
   for (unsigned Index = 0; Index < Pointers.size(); ++Index)
-    PositionMap[Pointers[Index].PointerValue] = Index;
+    PositionMap[Pointers[Index].PtrExpr] = Index;
 
   // We need to keep track of what pointers we've already seen so we
   // don't process them twice.
@@ -401,37 +441,45 @@
     // equivalence class, the iteration order is deterministic.
    for (auto MI = DepCands.member_begin(LeaderI), ME = DepCands.member_end();
          MI != ME; ++MI) {
-      auto PointerI = PositionMap.find(MI->getPointer());
-      assert(PointerI != PositionMap.end() &&
-             "pointer in equivalence class not found in PositionMap");
-      unsigned Pointer = PointerI->second;
-      bool Merged = false;
-      // Mark this pointer as seen.
-      Seen.insert(Pointer);
-
-      // Go through all the existing sets and see if we can find one
-      // which can include this pointer.
-      for (RuntimeCheckingPtrGroup &Group : Groups) {
-        // Don't perform more than a certain amount of comparisons.
-        // This should limit the cost of grouping the pointers to something
-        // reasonable. If we do end up hitting this threshold, the algorithm
-        // will create separate groups for all remaining pointers.
-        if (TotalComparisons > MemoryCheckMergeThreshold)
-          break;
-
-        TotalComparisons++;
-
-        if (Group.addPointer(Pointer, *this)) {
-          Merged = true;
-          break;
+      SmallVector<const SCEV *> TranslatedPtrs;
+      visitPointers2(MI->getPointer(), L, PSE, SymbolicStrides,
+                     [&TranslatedPtrs](Value *Ptr, const SCEV *PtrExpr) {
+                       TranslatedPtrs.push_back(PtrExpr);
+                     });
+
+      for (const SCEV *PtrScev : TranslatedPtrs) {
+        auto PointerI = PositionMap.find(PtrScev);
+        assert(PointerI != PositionMap.end() &&
+               "pointer in equivalence class not found in PositionMap");
+        unsigned Pointer = PointerI->second;
+        bool Merged = false;
+        // Mark this pointer as seen.
+        Seen.insert(Pointer);
+
+        // Go through all the existing sets and see if we can find one
+        // which can include this pointer.
+        for (RuntimeCheckingPtrGroup &Group : Groups) {
+          // Don't perform more than a certain amount of comparisons.
+          // This should limit the cost of grouping the pointers to something
+          // reasonable. If we do end up hitting this threshold, the algorithm
+          // will create separate groups for all remaining pointers.
+          if (TotalComparisons > MemoryCheckMergeThreshold)
+            break;
+
+          TotalComparisons++;
+
+          if (Group.addPointer(Pointer, *this)) {
+            Merged = true;
+            break;
+          }
         }
-      }
 
-      if (!Merged)
-        // We couldn't add this pointer to any existing set or the threshold
-        // for the number of comparisons has been reached. Create a new group
-        // to hold the current pointer.
-        Groups.push_back(RuntimeCheckingPtrGroup(Pointer, *this));
+        if (!Merged)
+          // We couldn't add this pointer to any existing set or the threshold
+          // for the number of comparisons has been reached. Create a new group
+          // to hold the current pointer.
+          Groups.push_back(RuntimeCheckingPtrGroup(Pointer, *this));
+      }
     }
 
     // We've computed the grouped checks for this partition.
@@ -631,11 +679,8 @@
 /// Check whether a pointer can participate in a runtime bounds check.
 /// If \p Assume, try harder to prove that we can compute the bounds of \p Ptr
 /// by adding run-time checks (overflow checks) if necessary.
-static bool hasComputableBounds(PredicatedScalarEvolution &PSE,
-                                const ValueToValueMap &Strides, Value *Ptr,
-                                Loop *L, bool Assume) {
-  const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr);
-
+static bool hasComputableBounds(PredicatedScalarEvolution &PSE, Value *Ptr,
+                                const SCEV *PtrScev, Loop *L, bool Assume) {
   // The bounds for loop-invariant pointer is trivial.
   if (PSE.getSE()->isLoopInvariant(PtrScev, L))
     return true;
@@ -698,34 +743,49 @@
                                           bool Assume) {
   Value *Ptr = Access.getPointer();
 
-  if (!hasComputableBounds(PSE, StridesMap, Ptr, TheLoop, Assume))
-    return false;
+  SmallVector<const SCEV *> TranslatedPtrs;
+  visitPointers2(Ptr, *TheLoop, PSE, StridesMap,
+                 [&TranslatedPtrs](Value *Ptr, const SCEV *PtrExpr) {
+                   TranslatedPtrs.push_back(PtrExpr);
+                 });
 
-  // When we run after a failing dependency check we have to make sure
-  // we don't have wrapping pointers.
-  if (ShouldCheckWrap && !isNoWrap(PSE, StridesMap, Ptr, TheLoop)) {
-    auto *Expr = PSE.getSCEV(Ptr);
-    if (!Assume || !isa<SCEVAddRecExpr>(Expr))
+  for (const SCEV *PtrExpr : TranslatedPtrs) {
+    if (!hasComputableBounds(PSE, Ptr, PtrExpr, TheLoop, Assume))
       return false;
-    PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW);
+
+    // When we run after a failing dependency check we have to make sure
+    // we don't have wrapping pointers.
+    if (ShouldCheckWrap && !isNoWrap(PSE, StridesMap, Ptr, TheLoop)) {
+      if (TranslatedPtrs.size() > 1)
+        return false;
+      auto *Expr = PSE.getSCEV(Ptr);
+      if (!Assume || !isa<SCEVAddRecExpr>(Expr))
+        return false;
+      PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW);
+    }
   }
 
-  // The id of the dependence set.
-  unsigned DepId;
+  if (TranslatedPtrs.size() == 1)
+    TranslatedPtrs[0] = replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr);
 
-  if (isDependencyCheckNeeded()) {
-    Value *Leader = DepCands.getLeaderValue(Access).getPointer();
-    unsigned &LeaderId = DepSetId[Leader];
-    if (!LeaderId)
-      LeaderId = RunningDepId++;
-    DepId = LeaderId;
-  } else
-    // Each access has its own dependence set.
-    DepId = RunningDepId++;
+  for (const SCEV *PtrExpr : TranslatedPtrs) {
+    // The id of the dependence set.
+    unsigned DepId;
+
+    if (isDependencyCheckNeeded()) {
+      Value *Leader = DepCands.getLeaderValue(Access).getPointer();
+      unsigned &LeaderId = DepSetId[Leader];
+      if (!LeaderId)
+        LeaderId = RunningDepId++;
+      DepId = LeaderId;
+    } else
+      // Each access has its own dependence set.
+      DepId = RunningDepId++;
 
-  bool IsWrite = Access.getInt();
-  RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap, PSE);
-  LLVM_DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n');
+    bool IsWrite = Access.getInt();
+    RtCheck.insert(TheLoop, Ptr, PtrExpr, IsWrite, DepId, ASId, PSE);
+    LLVM_DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n');
+  }
 
   return true;
 }
@@ -861,7 +921,8 @@
   }
 
   if (MayNeedRTCheck && CanDoRT)
-    RtCheck.generateChecks(DepCands, IsDepCheckNeeded);
+    RtCheck.generateChecks(DepCands, IsDepCheckNeeded, PSE, *TheLoop,
+                           StridesMap);
 
   LLVM_DEBUG(dbgs() << "LAA: We need to do " << RtCheck.getNumberOfChecks()
                     << " pointer comparisons.\n");
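For context, a minimal sketch of the kind of loop this patch is aimed at (illustrative only; the function below and its names are invented and not part of the patch). The stored-through pointer is a GEP whose base is a select between two bases, so no single SCEV describes the access. With the change above, visitPointers2 translates such an access into one candidate SCEV per select arm, and RtCheck.insert records a runtime-check entry for each arm:

// Illustrative example only -- not part of the patch.
void forked_base(int *A, int *B, const int *C, int N) {
  for (int i = 0; i < N; i++) {
    int *Base = C[i] < 0 ? A : B; // base pointer forks on a select
    Base[i] = C[i];               // address is: getelementptr (select A, B), i
  }
}

Before this change, hasComputableBounds would reject the select-based pointer (its SCEV is neither loop-invariant nor an add recurrence), so no runtime checks could be generated for loops like this one.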