diff --git a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
--- a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
+++ b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
@@ -50,6 +50,50 @@
   static unsigned RuntimeMemoryCheckThreshold;
 };
 
+struct MemAccessInfo {
+  PointerIntPair<Value *, 1, bool> ValueAndBool;
+  MemAccessInfo(Value *V, bool B, const SCEV *PtrExpr)
+      : ValueAndBool(V, B), PtrExpr(PtrExpr) {}
+  MemAccessInfo() : ValueAndBool(nullptr) {}
+  MemAccessInfo(PointerIntPair<Value *, 1, bool> V) : ValueAndBool(V) {}
+
+  const SCEV *PtrExpr = nullptr;
+
+  Value *getPointer() { return ValueAndBool.getPointer(); }
+  bool getInt() { return ValueAndBool.getInt(); }
+  Value *getPointer() const { return ValueAndBool.getPointer(); }
+  bool getInt() const { return ValueAndBool.getInt(); }
+  const SCEV *getPtrExpr() const { return PtrExpr; }
+
+  bool operator<(const MemAccessInfo &RHS) const {
+    return ValueAndBool < RHS.ValueAndBool;
+  }
+  bool operator==(const MemAccessInfo &RHS) const {
+    return ValueAndBool == RHS.ValueAndBool && PtrExpr == RHS.PtrExpr;
+  }
+};
+
+template <> struct DenseMapInfo<MemAccessInfo> {
+  using Ty = DenseMapInfo<PointerIntPair<Value *, 1, bool>>;
+
+  static MemAccessInfo getEmptyKey() {
+    return MemAccessInfo(Ty::getEmptyKey());
+  }
+
+  static MemAccessInfo getTombstoneKey() {
+    return MemAccessInfo(Ty::getTombstoneKey());
+  }
+
+  static unsigned getHashValue(MemAccessInfo V) {
+    uintptr_t IV = reinterpret_cast<uintptr_t>(V.ValueAndBool.getOpaqueValue());
+    return unsigned(IV) ^ unsigned(IV >> 9);
+  }
+
+  static bool isEqual(const MemAccessInfo &LHS, const MemAccessInfo &RHS) {
+    return LHS == RHS;
+  }
+};
+
 /// Checks memory dependences among accesses to the same underlying
 /// object to determine whether there vectorization is legal or not (and at
 /// which vectorization factor).
@@ -86,7 +130,6 @@
 ///
 class MemoryDepChecker {
 public:
-  typedef PointerIntPair<Value *, 1, bool> MemAccessInfo;
   typedef SmallVector<MemAccessInfo, 8> MemAccessInfoList;
   /// Set of potential dependent memory accesses.
   typedef EquivalenceClasses<MemAccessInfo> DepCandidates;
@@ -177,11 +220,11 @@
 
   /// Register the location (instructions are given increasing numbers)
   /// of a write access.
-  void addAccess(StoreInst *SI);
+  void addAccess(StoreInst *SI, const ValueToValueMap &SymbolicStrides);
 
   /// Register the location (instructions are given increasing numbers)
   /// of a write access.
-  void addAccess(LoadInst *LI);
+  void addAccess(LoadInst *LI, const ValueToValueMap &SymbolicStrides);
 
   /// Check whether the dependencies between the accesses are safe.
   ///
@@ -245,8 +288,9 @@
   }
 
   /// Find the set of instructions that read or write via \p Ptr.
-  SmallVector<Instruction *, 4> getInstructionsForAccess(Value *Ptr,
-                                                         bool isWrite) const;
+  SmallVector<Instruction *, 4>
+  getInstructionsForAccess(Value *Ptr, bool isWrite,
+                           const ValueToValueMap &SymbolicStrides) const;
 
 private:
   /// A wrapper around ScalarEvolution, used to add runtime SCEV checks, and
@@ -388,12 +432,15 @@
     /// SCEV for the access.
     const SCEV *Expr;
 
+    /// SCEV for the pointer expression of the access.
+    const SCEV *PtrExpr;
+
     PointerInfo(Value *PointerValue, const SCEV *Start, const SCEV *End,
                 bool IsWritePtr, unsigned DependencySetId, unsigned AliasSetId,
-                const SCEV *Expr)
+                const SCEV *Expr, const SCEV *PtrExpr)
         : PointerValue(PointerValue), Start(Start), End(End),
           IsWritePtr(IsWritePtr), DependencySetId(DependencySetId),
-          AliasSetId(AliasSetId), Expr(Expr) {}
+          AliasSetId(AliasSetId), Expr(Expr), PtrExpr(PtrExpr) {}
   };
 
   RuntimePointerChecking(ScalarEvolution *SE) : Need(false), SE(SE) {}
@@ -410,8 +457,8 @@
   /// according to the assumptions that we've made during the analysis.
  /// The method might also version the pointer stride according to \p Strides,
  /// and add new predicates to \p PSE.
-  void insert(Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId,
-              unsigned ASId, const ValueToValueMap &Strides,
+  void insert(Loop *Lp, Value *Ptr, const SCEV *PtrExpr, const SCEV *Sc,
+              bool WritePtr, unsigned DepSetId, unsigned ASId,
               PredicatedScalarEvolution &PSE);
 
   /// No run-time memory checking is necessary.
@@ -558,9 +605,10 @@
 
   /// Return the list of instructions that use \p Ptr to read or write
   /// memory.
-  SmallVector<Instruction *, 4> getInstructionsForAccess(Value *Ptr,
-                                                         bool isWrite) const {
-    return DepChecker->getInstructionsForAccess(Ptr, isWrite);
+  SmallVector<Instruction *, 4>
+  getInstructionsForAccess(Value *Ptr, bool isWrite,
+                           const ValueToValueMap &SymbolicStrides) const {
+    return DepChecker->getInstructionsForAccess(Ptr, isWrite, SymbolicStrides);
   }
 
   /// If an access has a symbolic strides, this maps the pointer value to
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -189,12 +189,11 @@
 ///
 /// There is no conflict when the intervals are disjoint:
 /// NoConflict = (P2.Start >= P1.End) || (P1.Start >= P2.End)
-void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, bool WritePtr,
+void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, const SCEV *PtrExpr,
+                                    const SCEV *Sc, bool WritePtr,
                                     unsigned DepSetId, unsigned ASId,
-                                    const ValueToValueMap &Strides,
                                     PredicatedScalarEvolution &PSE) {
   // Get the stride replaced scev.
-  const SCEV *Sc = replaceSymbolicStrideSCEV(PSE, Strides, Ptr);
   ScalarEvolution *SE = PSE.getSE();
 
   const SCEV *ScStart;
@@ -231,7 +230,8 @@
       SE->getStoreSizeOfExpr(IdxTy, Ptr->getType()->getPointerElementType());
   ScEnd = SE->getAddExpr(ScEnd, EltSizeSCEV);
 
-  Pointers.emplace_back(Ptr, ScStart, ScEnd, WritePtr, DepSetId, ASId, Sc);
+  Pointers.emplace_back(Ptr, ScStart, ScEnd, WritePtr, DepSetId, ASId, Sc,
+                        PtrExpr);
 }
 
 SmallVector<RuntimePointerCheck, 4>
@@ -371,9 +371,9 @@
 
   unsigned TotalComparisons = 0;
 
-  DenseMap<Value *, unsigned> PositionMap;
+  DenseMap<const SCEV *, unsigned> PositionMap;
   for (unsigned Index = 0; Index < Pointers.size(); ++Index)
-    PositionMap[Pointers[Index].PointerValue] = Index;
+    PositionMap[Pointers[Index].PtrExpr] = Index;
 
   // We need to keep track of what pointers we've already seen so we
   // don't process them twice.
@@ -388,8 +388,8 @@
     if (Seen.count(I))
       continue;
 
-    MemoryDepChecker::MemAccessInfo Access(Pointers[I].PointerValue,
-                                           Pointers[I].IsWritePtr);
+    MemAccessInfo Access(Pointers[I].PointerValue, Pointers[I].IsWritePtr,
+                         Pointers[I].PtrExpr);
 
     SmallVector<RuntimeCheckingPtrGroup, 2> Groups;
     auto LeaderI = DepCands.findValue(DepCands.getLeaderValue(Access));
@@ -401,7 +401,7 @@
     // equivalence class, the iteration order is deterministic.
     for (auto MI = DepCands.member_begin(LeaderI), ME = DepCands.member_end();
          MI != ME; ++MI) {
-      auto PointerI = PositionMap.find(MI->getPointer());
+      auto PointerI = PositionMap.find(MI->getPtrExpr());
       assert(PointerI != PositionMap.end() &&
              "pointer in equivalence class not found in PositionMap");
       unsigned Pointer = PointerI->second;
@@ -513,7 +513,6 @@
 class AccessAnalysis {
 public:
   /// Read or write access location.
-  typedef PointerIntPair<Value *, 1, bool> MemAccessInfo;
   typedef SmallVector<MemAccessInfo, 8> MemAccessInfoList;
 
   AccessAnalysis(Loop *TheLoop, AAResults *AA, LoopInfo *LI,
@@ -523,19 +522,19 @@
         IsRTCheckAnalysisNeeded(false), PSE(PSE) {}
 
   /// Register a load and whether it is only read from.
-  void addLoad(MemoryLocation &Loc, bool IsReadOnly) {
+  void addLoad(MemoryLocation &Loc, bool IsReadOnly, const SCEV *PtrScev) {
     Value *Ptr = const_cast<Value *>(Loc.Ptr);
     AST.add(Ptr, LocationSize::beforeOrAfterPointer(), Loc.AATags);
-    Accesses.insert(MemAccessInfo(Ptr, false));
+    Accesses.insert(MemAccessInfo(Ptr, false, PtrScev));
     if (IsReadOnly)
       ReadOnlyPtr.insert(Ptr);
   }
 
   /// Register a store.
-  void addStore(MemoryLocation &Loc) {
+  void addStore(MemoryLocation &Loc, const SCEV *PtrScev) {
     Value *Ptr = const_cast<Value *>(Loc.Ptr);
     AST.add(Ptr, LocationSize::beforeOrAfterPointer(), Loc.AATags);
-    Accesses.insert(MemAccessInfo(Ptr, true));
+    Accesses.insert(MemAccessInfo(Ptr, true, PtrScev));
   }
 
   /// Check if we can emit a run-time no-alias check for \p Access.
@@ -547,11 +546,9 @@
   /// the bounds of the pointer.
   bool createCheckForAccess(RuntimePointerChecking &RtCheck,
                             MemAccessInfo Access,
-                            const ValueToValueMap &Strides,
-                            DenseMap<Value *, unsigned> &DepSetId,
+                            DenseMap<const SCEV *, unsigned> &DepSetId,
                             Loop *TheLoop, unsigned &RunningDepId,
-                            unsigned ASId, bool ShouldCheckStride,
-                            bool Assume);
+                            unsigned ASId, bool ShouldCheckStride, bool Assume);
 
   /// Check whether we can check the pointers at runtime for
   /// non-intersection.
@@ -564,8 +561,8 @@
 
   /// Goes over all memory accesses, checks whether a RT check is needed
   /// and builds sets of dependent accesses.
-  void buildDependenceSets() {
-    processMemAccesses();
+  void buildDependenceSets(ValueToValueMap &SymbolicStrides) {
+    processMemAccesses(SymbolicStrides);
   }
 
   /// Initial processing of memory accesses determined that we need to
@@ -588,7 +585,7 @@
 
   /// Go over all memory access and check whether runtime pointer checks
   /// are needed and build sets of dependency check candidates.
-  void processMemAccesses();
+  void processMemAccesses(ValueToValueMap &SymbolicStrides);
 
   /// Set of all accesses.
   PtrAccessSet Accesses;
@@ -631,43 +628,45 @@
 
 /// Check whether a pointer can participate in a runtime bounds check.
 /// If \p Assume, try harder to prove that we can compute the bounds of \p Ptr
 /// by adding run-time checks (overflow checks) if necessary.
-static bool hasComputableBounds(PredicatedScalarEvolution &PSE,
-                                const ValueToValueMap &Strides, Value *Ptr,
-                                Loop *L, bool Assume) {
-  const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr);
-
+static const SCEV *hasComputableBounds(PredicatedScalarEvolution &PSE,
                                       Value *Ptr, const SCEV *PtrScev, Loop *L,
+                                       bool Assume) {
   // The bounds for loop-invariant pointer is trivial.
   if (PSE.getSE()->isLoopInvariant(PtrScev, L))
-    return true;
+    return PtrScev;
 
   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
 
   if (!AR && Assume)
     AR = PSE.getAsAddRec(Ptr);
-
-  if (!AR)
-    return false;
-
-  return AR->isAffine();
+  if (AR && AR->isAffine())
+    return AR;
+  return nullptr;
 }
 
+static int64_t getPtrStride(PredicatedScalarEvolution &PSE, Type *AccessTy,
+                            Value *Ptr, const SCEV *PtrScev, const Loop *Lp,
+                            bool Assume = false, bool ShouldCheckWrap = true);
+
 /// Check whether a pointer address cannot wrap.
-static bool isNoWrap(PredicatedScalarEvolution &PSE,
-                     const ValueToValueMap &Strides, Value *Ptr, Loop *L) {
-  const SCEV *PtrScev = PSE.getSCEV(Ptr);
+static bool isNoWrap(PredicatedScalarEvolution &PSE, Value *Ptr,
+                     const SCEV *PtrScev, Loop *L) {
   if (PSE.getSE()->isLoopInvariant(PtrScev, L))
     return true;
 
   Type *AccessTy = Ptr->getType()->getPointerElementType();
-  int64_t Stride = getPtrStride(PSE, AccessTy, Ptr, L, Strides);
+  int64_t Stride = getPtrStride(PSE, AccessTy, Ptr, PtrScev, L);
   if (Stride == 1 || PSE.hasNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW))
     return true;
 
   return false;
 }
 
-static void visitPointers(Value *StartPtr, const Loop &InnermostLoop,
-                          function_ref<void(Value *)> AddPointer) {
+static void
+visitPointers(Value *StartPtr, const Loop &InnermostLoop,
+              PredicatedScalarEvolution &PSE,
+              const ValueToValueMap &SymbolicStrides,
+              function_ref<void(Value *, const SCEV *)> AddPointer) {
   SmallPtrSet<Value *, 8> Visited;
   SmallVector<Value *, 8> WorkList;
   WorkList.push_back(StartPtr);
@@ -685,28 +684,27 @@
       for (const Use &Inc : PN->incoming_values())
         WorkList.push_back(Inc);
     } else
-      AddPointer(Ptr);
+      AddPointer(Ptr, replaceSymbolicStrideSCEV(PSE, SymbolicStrides, Ptr));
   }
 }
 
-bool AccessAnalysis::createCheckForAccess(RuntimePointerChecking &RtCheck,
-                                          MemAccessInfo Access,
-                                          const ValueToValueMap &StridesMap,
-                                          DenseMap<Value *, unsigned> &DepSetId,
-                                          Loop *TheLoop, unsigned &RunningDepId,
-                                          unsigned ASId, bool ShouldCheckWrap,
-                                          bool Assume) {
+bool AccessAnalysis::createCheckForAccess(
+    RuntimePointerChecking &RtCheck, MemAccessInfo Access,
+    DenseMap<const SCEV *, unsigned> &DepSetId, Loop *TheLoop,
+    unsigned &RunningDepId, unsigned ASId, bool ShouldCheckWrap, bool Assume) {
   Value *Ptr = Access.getPointer();
-
-  if (!hasComputableBounds(PSE, StridesMap, Ptr, TheLoop, Assume))
+  const SCEV *PtrScev =
+      hasComputableBounds(PSE, Ptr, Access.getPtrExpr(), TheLoop, Assume);
+  if (!PtrScev)
     return false;
 
   // When we run after a failing dependency check we have to make sure
   // we don't have wrapping pointers.
-  if (ShouldCheckWrap && !isNoWrap(PSE, StridesMap, Ptr, TheLoop)) {
-    auto *Expr = PSE.getSCEV(Ptr);
-    if (!Assume || !isa<SCEVAddRecExpr>(Expr))
+  if (ShouldCheckWrap && !isNoWrap(PSE, Ptr, PtrScev, TheLoop)) {
+    if (!Assume || !isa<SCEVAddRecExpr>(PtrScev))
       return false;
+
+    // FIXME: not handled properly.
     PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW);
   }
 
@@ -714,7 +712,7 @@
 
   unsigned DepId;
   if (isDependencyCheckNeeded()) {
-    Value *Leader = DepCands.getLeaderValue(Access).getPointer();
+    const SCEV *Leader = DepCands.getLeaderValue(Access).getPtrExpr();
     unsigned &LeaderId = DepSetId[Leader];
     if (!LeaderId)
       LeaderId = RunningDepId++;
@@ -724,11 +722,12 @@
     DepId = RunningDepId++;
 
   bool IsWrite = Access.getInt();
-  RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap, PSE);
+  RtCheck.insert(TheLoop, Ptr, Access.getPtrExpr(), PtrScev, IsWrite, DepId,
+                 ASId, PSE);
   LLVM_DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n');
 
   return true;
-  }
+}
 
 bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck,
                                      ScalarEvolution *SE, Loop *TheLoop,
@@ -755,7 +754,7 @@
     // We assign consecutive id to access from different dependence sets.
     // Accesses within the same set don't need a runtime check.
    unsigned RunningDepId = 1;
-    DenseMap<Value *, unsigned> DepSetId;
+    DenseMap<const SCEV *, unsigned> DepSetId;
 
    SmallVector<MemAccessInfo, 4> Retries;
 
@@ -764,13 +763,14 @@
     SmallVector<MemAccessInfo, 2> AccessInfos;
     for (const auto &A : AS) {
       Value *Ptr = A.getValue();
-      bool IsWrite = Accesses.count(MemAccessInfo(Ptr, true));
+      const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr);
+      bool IsWrite = Accesses.count(MemAccessInfo(Ptr, true, PtrScev));
 
       if (IsWrite)
         ++NumWritePtrChecks;
       else
         ++NumReadPtrChecks;
-      AccessInfos.emplace_back(Ptr, IsWrite);
+      AccessInfos.emplace_back(Ptr, IsWrite, PtrScev);
     }
 
     // We do not need runtime checks for this alias set, if there are no writes
@@ -779,8 +779,11 @@
         (NumWritePtrChecks == 1 && NumReadPtrChecks == 0)) {
       assert((AS.size() <= 1 ||
               all_of(AS,
-                     [this](auto AC) {
-                       MemAccessInfo AccessWrite(AC.getValue(), true);
+                     [this, &StridesMap](auto AC) {
+                       MemAccessInfo AccessWrite(
+                           AC.getValue(), true,
+                           replaceSymbolicStrideSCEV(PSE, StridesMap,
+                                                     AC.getValue()));
                        return DepCands.findValue(AccessWrite) == DepCands.end();
                      })) &&
              "Can only skip updating CanDoRT below, if all entries in AS "
@@ -789,7 +792,7 @@
     }
 
     for (auto &Access : AccessInfos) {
-      if (!createCheckForAccess(RtCheck, Access, StridesMap, DepSetId, TheLoop,
+      if (!createCheckForAccess(RtCheck, Access, DepSetId, TheLoop,
                                 RunningDepId, ASId, ShouldCheckWrap, false)) {
         LLVM_DEBUG(dbgs() << "LAA: Can't find bounds for ptr:"
                           << *Access.getPointer() << '\n');
@@ -817,9 +820,9 @@
       // and add further checks if required (overflow checks).
       CanDoAliasSetRT = true;
       for (auto Access : Retries)
-        if (!createCheckForAccess(RtCheck, Access, StridesMap, DepSetId,
-                                  TheLoop, RunningDepId, ASId,
-                                  ShouldCheckWrap, /*Assume=*/true)) {
+        if (!createCheckForAccess(RtCheck, Access, DepSetId, TheLoop,
+                                  RunningDepId, ASId, ShouldCheckWrap,
+                                  /*Assume=*/true)) {
          CanDoAliasSetRT = false;
          break;
        }
@@ -877,7 +880,7 @@
   return CanDoRTIfNeeded;
 }
 
-void AccessAnalysis::processMemAccesses() {
+void AccessAnalysis::processMemAccesses(ValueToValueMap &SymbolicStrides) {
   // We process the set twice: first we process read-write pointers, last we
   // process read-only pointers. This allows us to skip dependence tests for
   // read-only pointers.
@@ -932,13 +935,16 @@
         bool IsReadOnlyPtr = ReadOnlyPtr.count(Ptr) && !IsWrite;
         if (UseDeferred && !IsReadOnlyPtr)
           continue;
+
+        const SCEV *PtrExpr =
+            replaceSymbolicStrideSCEV(PSE, SymbolicStrides, Ptr);
         // Otherwise, the pointer must be in the PtrAccessSet, either as a
         // read or a write.
         assert(((IsReadOnlyPtr && UseDeferred) || IsWrite ||
-                S.count(MemAccessInfo(Ptr, false))) &&
+                S.count(MemAccessInfo(Ptr, false, PtrExpr))) &&
                "Alias-set pointer not in the access set?");
 
-        MemAccessInfo Access(Ptr, IsWrite);
+        MemAccessInfo Access(Ptr, IsWrite, PtrExpr);
         DepCands.insert(Access);
 
         // Memorize read-only pointers for later processing and skip them in
@@ -1049,16 +1055,13 @@
 }
 
 /// Check whether the access through \p Ptr has a constant stride.
-int64_t llvm::getPtrStride(PredicatedScalarEvolution &PSE, Type *AccessTy,
-                           Value *Ptr, const Loop *Lp,
-                           const ValueToValueMap &StridesMap, bool Assume,
-                           bool ShouldCheckWrap) {
+static int64_t getPtrStride(PredicatedScalarEvolution &PSE, Type *AccessTy,
+                            Value *Ptr, const SCEV *PtrScev, const Loop *Lp,
+                            bool Assume, bool ShouldCheckWrap) {
   Type *Ty = Ptr->getType();
   assert(Ty->isPointerTy() && "Unexpected non-ptr");
   assert(!AccessTy->isAggregateType() &&
          "Bad stride - Not a pointer to a scalar type");
 
-  const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr);
-
   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
   if (Assume && !AR)
     AR = PSE.getAsAddRec(Ptr);
@@ -1085,9 +1088,10 @@
   // space 0, therefore we can also vectorize this case.
   unsigned AddrSpace = Ty->getPointerAddressSpace();
   bool IsInBoundsGEP = isInBoundsGep(Ptr);
-  bool IsNoWrapAddRec = !ShouldCheckWrap ||
-    PSE.hasNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW) ||
-    isNoWrapAddRec(Ptr, AR, PSE, Lp);
+  bool IsNoWrapAddRec =
+      !ShouldCheckWrap ||
+      PSE.hasNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW) ||
+      isNoWrapAddRec(Ptr, AR, PSE, Lp);
   if (!IsNoWrapAddRec && !IsInBoundsGEP &&
       NullPointerIsDefined(Lp->getHeader()->getParent(), AddrSpace)) {
     if (Assume) {
@@ -1136,8 +1140,8 @@
   // know we can't "wrap around the address space". In case of address space
   // zero we know that this won't happen without triggering undefined behavior.
   if (!IsNoWrapAddRec && Stride != 1 && Stride != -1 &&
-      (IsInBoundsGEP || !NullPointerIsDefined(Lp->getHeader()->getParent(),
-                                              AddrSpace))) {
+      (IsInBoundsGEP ||
+       !NullPointerIsDefined(Lp->getHeader()->getParent(), AddrSpace))) {
     if (Assume) {
       // We can avoid this case by adding a run-time check.
       LLVM_DEBUG(dbgs() << "LAA: Non unit strided pointer which is not either "
@@ -1153,6 +1157,28 @@
   return Stride;
 }
 
+/// Check whether the access through \p Ptr has a constant stride.
+int64_t llvm::getPtrStride(PredicatedScalarEvolution &PSE, Type *AccessTy,
+                           Value *Ptr, const Loop *Lp,
+                           const ValueToValueMap &StridesMap, bool Assume,
+                           bool ShouldCheckWrap) {
+  Type *Ty = Ptr->getType();
+  assert(Ty->isPointerTy() && "Unexpected non-ptr");
+  unsigned AddrSpace = Ty->getPointerAddressSpace();
+
+  // Make sure we're not accessing an aggregate type.
+  // TODO: Why? This doesn't make any sense.
+  if (AccessTy->isAggregateType()) {
+    LLVM_DEBUG(dbgs() << "LAA: Bad stride - Not a pointer to a scalar type"
+                      << *Ptr << "\n");
+    return 0;
+  }
+
+  const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr);
+  return ::getPtrStride(PSE, AccessTy, Ptr, PtrScev, Lp, Assume,
+                        ShouldCheckWrap);
+}
+
 Optional<int> llvm::getPointersDiff(Type *ElemTyA, Value *PtrA, Type *ElemTyB,
                                     Value *PtrB, const DataLayout &DL,
                                     ScalarEvolution &SE, bool StrictCheck,
@@ -1279,19 +1305,23 @@
   return Diff && *Diff == 1;
 }
 
-void MemoryDepChecker::addAccess(StoreInst *SI) {
-  visitPointers(SI->getPointerOperand(), *InnermostLoop,
-                [this, SI](Value *Ptr) {
-                  Accesses[MemAccessInfo(Ptr, true)].push_back(AccessIdx);
+void MemoryDepChecker::addAccess(StoreInst *SI,
+                                 const ValueToValueMap &SymbolicStrides) {
+  visitPointers(SI->getPointerOperand(), *InnermostLoop, PSE, SymbolicStrides,
+                [this, SI](Value *Ptr, const SCEV *PtrScev) {
+                  Accesses[MemAccessInfo(Ptr, true, PtrScev)].push_back(
+                      AccessIdx);
                   InstMap.push_back(SI);
                   ++AccessIdx;
                 });
 }
 
-void MemoryDepChecker::addAccess(LoadInst *LI) {
-  visitPointers(LI->getPointerOperand(), *InnermostLoop,
-                [this, LI](Value *Ptr) {
-                  Accesses[MemAccessInfo(Ptr, false)].push_back(AccessIdx);
+void MemoryDepChecker::addAccess(LoadInst *LI,
+                                 const ValueToValueMap &SymbolicStrides) {
+  visitPointers(LI->getPointerOperand(), *InnermostLoop, PSE, SymbolicStrides,
+                [this, LI](Value *Ptr, const SCEV *PtrScev) {
+                  Accesses[MemAccessInfo(Ptr, false, PtrScev)].push_back(
+                      AccessIdx);
                   InstMap.push_back(LI);
                   ++AccessIdx;
                 });
@@ -1705,9 +1735,9 @@
                                    const ValueToValueMap &Strides) {
 
   MaxSafeDepDistBytes = -1;
-  SmallPtrSet<MemAccessInfo, 8> Visited;
+  SmallPtrSet<const MemAccessInfo *, 8> Visited;
   for (MemAccessInfo CurAccess : CheckDeps) {
-    if (Visited.count(CurAccess))
+    if (Visited.count(&CurAccess))
       continue;
 
     // Get the relevant memory access set.
@@ -1722,7 +1752,7 @@
 
     // Check every access pair.
     while (AI != AE) {
-      Visited.insert(*AI);
+      Visited.insert(&*AI);
       bool AIIsWrite = AI->getInt();
       // Check loads only against next equivalent class, but stores also against
       // other stores in the same equivalence class - to the same address.
@@ -1777,15 +1807,17 @@
   return isSafeForVectorization();
 }
 
-SmallVector<Instruction *, 4>
-MemoryDepChecker::getInstructionsForAccess(Value *Ptr, bool isWrite) const {
-  MemAccessInfo Access(Ptr, isWrite);
-  auto &IndexVector = Accesses.find(Access)->second;
-
+SmallVector<Instruction *, 4> MemoryDepChecker::getInstructionsForAccess(
+    Value *Ptr, bool isWrite, const ValueToValueMap &SymbolicStrides) const {
   SmallVector<Instruction *, 4> Insts;
-  transform(IndexVector,
-            std::back_inserter(Insts),
-            [&](unsigned Idx) { return this->InstMap[Idx]; });
+  for (auto &KV : Accesses) {
+    if (KV.first.getPointer() == Ptr) {
+      auto &IndexVector = KV.second;
+
+      transform(IndexVector, std::back_inserter(Insts),
+                [&](unsigned Idx) { return this->InstMap[Idx]; });
+    }
+  }
   return Insts;
 }
 
@@ -1916,9 +1948,10 @@
       }
       NumLoads++;
      Loads.push_back(Ld);
-      DepChecker->addAccess(Ld);
       if (EnableMemAccessVersioningOfLoop)
         collectStridedAccess(Ld);
+
+      DepChecker->addAccess(Ld, SymbolicStrides);
       continue;
     }
 
@@ -1940,9 +1973,11 @@
      }
      NumStores++;
      Stores.push_back(St);
-      DepChecker->addAccess(St);
+
      if (EnableMemAccessVersioningOfLoop)
        collectStridedAccess(St);
+
+      DepChecker->addAccess(St, SymbolicStrides);
    }
  } // Next instr.
 } // Next block.
@@ -1996,10 +2031,11 @@
     if (blockNeedsPredication(ST->getParent(), TheLoop, DT))
       Loc.AATags.TBAA = nullptr;
 
-    visitPointers(const_cast<Value *>(Loc.Ptr), *TheLoop,
-                  [&Accesses, Loc](Value *Ptr) {
+    visitPointers(const_cast<Value *>(Loc.Ptr), *TheLoop, *PSE,
+                  SymbolicStrides,
+                  [&Accesses, Loc](Value *Ptr, const SCEV *PtrScev) {
                     MemoryLocation NewLoc = Loc.getWithNewPtr(Ptr);
-                    Accesses.addStore(NewLoc);
+                    Accesses.addStore(NewLoc, PtrScev);
                   });
   }
 }
@@ -2044,11 +2080,12 @@
     if (blockNeedsPredication(LD->getParent(), TheLoop, DT))
       Loc.AATags.TBAA = nullptr;
 
-    visitPointers(const_cast<Value *>(Loc.Ptr), *TheLoop,
-                  [&Accesses, Loc, IsReadOnlyPtr](Value *Ptr) {
-                    MemoryLocation NewLoc = Loc.getWithNewPtr(Ptr);
-                    Accesses.addLoad(NewLoc, IsReadOnlyPtr);
-                  });
+    visitPointers(
+        const_cast<Value *>(Loc.Ptr), *TheLoop, *PSE, SymbolicStrides,
+        [&Accesses, Loc, IsReadOnlyPtr](Value *Ptr, const SCEV *PtrScev) {
+          MemoryLocation NewLoc = Loc.getWithNewPtr(Ptr);
+          Accesses.addLoad(NewLoc, IsReadOnlyPtr, PtrScev);
+        });
   }
 
   // If we write (or read-write) to a single destination and there are no
@@ -2061,7 +2098,7 @@
 
   // Build dependence sets and check whether we need a runtime pointer bounds
   // check.
-  Accesses.buildDependenceSets();
+  Accesses.buildDependenceSets(SymbolicStrides);
 
   // Find pointers with computable bounds. We are going to use this information
   // to place a runtime bound check.
diff --git a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp
--- a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp
@@ -510,8 +510,8 @@
   SmallVector<int, 8> PtrToPartitions(N);
   for (unsigned I = 0; I < N; ++I) {
     Value *Ptr = RtPtrCheck->Pointers[I].PointerValue;
-    auto Instructions =
-        LAI.getInstructionsForAccess(Ptr, RtPtrCheck->Pointers[I].IsWritePtr);
+    auto Instructions = LAI.getInstructionsForAccess(
+        Ptr, RtPtrCheck->Pointers[I].IsWritePtr, LAI.getSymbolicStrides());
 
     int &Partition = PtrToPartitions[I];
     // First set it to uninitialized.
diff --git a/llvm/test/Analysis/LoopAccessAnalysis/pointer-phis.ll b/llvm/test/Analysis/LoopAccessAnalysis/pointer-phis.ll --- a/llvm/test/Analysis/LoopAccessAnalysis/pointer-phis.ll +++ b/llvm/test/Analysis/LoopAccessAnalysis/pointer-phis.ll @@ -206,6 +206,10 @@ ; CHECK-NEXT: %v8 = load double, double* %arrayidx, align 8 -> ; CHECK-NEXT: store double %mul16, double* %ptr.2, align 8 ; CHECK-EMPTY: +; CHECK-NEXT: Unknown: +; CHECK-NEXT: %v8 = load double, double* %arrayidx, align 8 -> +; CHECK-NEXT: store double %mul16, double* %ptr.2, align 8 +; CHECK-EMPTY: ; CHECK-NEXT: Run-time memory checks: ; CHECK-NEXT: Check 0: ; CHECK-NEXT: Comparing group ([[GROUP_C:.+]]): @@ -282,6 +286,10 @@ ; CHECK-NEXT: %v8 = load double, double* %arrayidx, align 8 -> ; CHECK-NEXT: store double %mul16, double* %ptr.3, align 8 ; CHECK-EMPTY: +; CHECK-NEXT: Unknown: +; CHECK-NEXT: %v8 = load double, double* %arrayidx, align 8 -> +; CHECK-NEXT: store double %mul16, double* %ptr.3, align 8 +; CHECK-EMPTY: ; CHECK-NEXT: Run-time memory checks: ; CHECK-NEXT: Check 0: ; CHECK-NEXT: Comparing group ([[GROUP_C:.+]]): @@ -430,6 +438,14 @@ ; CHECK-NEXT: store i16 %lv, i16* %A, align 1 -> ; CHECK-NEXT: %lv2 = load i16, i16* %A, align 1 ; CHECK-EMPTY: +; CHECK-NEXT: Unknown: +; CHECK-NEXT: %lv3 = load i16, i16* %c.sink, align 2 -> +; CHECK-NEXT: store i16 %add, i16* %c.sink, align 1 +; CHECK-EMPTY: +; CHECK-NEXT: Unknown: +; CHECK-NEXT: %lv3 = load i16, i16* %c.sink, align 2 -> +; CHECK-NEXT: store i16 %add, i16* %c.sink, align 1 +; CHECK-EMPTY: ; CHECK-NEXT: Run-time memory checks: ; CHECK-NEXT: Check 0: ; CHECK-NEXT: Comparing group ([[GROUP_A:.+]]): diff --git a/llvm/test/Transforms/LoopLoadElim/symbolic-stride.ll b/llvm/test/Transforms/LoopLoadElim/symbolic-stride.ll --- a/llvm/test/Transforms/LoopLoadElim/symbolic-stride.ll +++ b/llvm/test/Transforms/LoopLoadElim/symbolic-stride.ll @@ -59,8 +59,8 @@ define void @two_strides(i32* noalias nocapture %A, i32* noalias nocapture readonly %B, i64 %N, i64 %stride.1, i64 %stride.2) { -; TWO_STRIDE_SPEC: %ident.check = icmp ne i64 %stride.2, 1 -; TWO_STRIDE_SPEC: %ident.check1 = icmp ne i64 %stride.1, 1 +; TWO_STRIDE_SPEC: %ident.check = icmp ne i64 %stride.1, 1 +; TWO_STRIDE_SPEC: %ident.check1 = icmp ne i64 %stride.2, 1 ; NO_TWO_STRIDE_SPEC-NOT: %ident.check{{.*}} = icmp ne i64 %stride{{.*}}, 1 entry:
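
Background for reviewers on the DenseMapInfo<MemAccessInfo> specialization the
header now needs: the old MemAccessInfo was a bare PointerIntPair, which
DenseMap can hash out of the box; once the key also carries the
stride-specialized SCEV, a custom traits specialization must supply the
reserved empty/tombstone keys, the hash, and equality. Below is a minimal
standalone sketch of that pattern, assuming only that LLVM's ADT headers are
on the include path. KeyT and the int* members are hypothetical stand-ins for
MemAccessInfo and Value*/SCEV*; they are not part of the patch.

// keyt_densemap_sketch.cpp — illustrative only, mirrors the patch's pattern.
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/PointerIntPair.h"
#include <cstdint>
#include <cstdio>

struct KeyT {
  llvm::PointerIntPair<int *, 1, bool> ValueAndBool; // pointer + is-write bit
  const int *Expr = nullptr; // stand-in for the stride-specialized SCEV

  KeyT(int *V, bool B, const int *E) : ValueAndBool(V, B), Expr(E) {}
  KeyT(llvm::PointerIntPair<int *, 1, bool> V) : ValueAndBool(V) {}

  bool operator==(const KeyT &RHS) const {
    return ValueAndBool == RHS.ValueAndBool && Expr == RHS.Expr;
  }
};

namespace llvm {
template <> struct DenseMapInfo<KeyT> {
  // Derive the reserved keys from the inner PointerIntPair's reserved keys,
  // exactly as the patch does for MemAccessInfo.
  using Inner = DenseMapInfo<PointerIntPair<int *, 1, bool>>;
  static KeyT getEmptyKey() { return KeyT(Inner::getEmptyKey()); }
  static KeyT getTombstoneKey() { return KeyT(Inner::getTombstoneKey()); }
  // Hashing only ValueAndBool is legal (equal keys still hash equally); keys
  // differing only in Expr simply collide into the same bucket.
  static unsigned getHashValue(const KeyT &K) {
    uintptr_t IV = reinterpret_cast<uintptr_t>(K.ValueAndBool.getOpaqueValue());
    return unsigned(IV) ^ unsigned(IV >> 9);
  }
  static bool isEqual(const KeyT &L, const KeyT &R) { return L == R; }
};
} // namespace llvm

int main() {
  int P = 0, E1 = 0, E2 = 0;
  llvm::DenseMap<KeyT, unsigned> Ids;
  // Same pointer and access kind but different "SCEV": two distinct keys.
  // This is the behavioral change the patch introduces for pointers reached
  // through different phi incoming values (see the new "Unknown" test deps).
  Ids[KeyT(&P, /*IsWrite=*/true, &E1)] = 1;
  Ids[KeyT(&P, /*IsWrite=*/true, &E2)] = 2;
  std::printf("%u entries\n", Ids.size()); // prints: 2 entries
  return 0;
}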