diff --git a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h --- a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h +++ b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h @@ -341,17 +341,21 @@ /// pointer, with index \p Index in RtCheck. RuntimeCheckingPtrGroup(unsigned Index, RuntimePointerChecking &RtCheck); + RuntimeCheckingPtrGroup(unsigned Index, const SCEV *Start, const SCEV *End, + unsigned AS) + : High(End), Low(Start), AddressSpace(AS) { + Members.push_back(Index); + } + /// Tries to add the pointer recorded in RtCheck at index /// \p Index to this pointer checking group. We can only add a pointer /// to a checking group if we will still be able to get /// the upper and lower bounds of the check. Returns true in case /// of success, false otherwise. - bool addPointer(unsigned Index); + bool addPointer(unsigned Index, RuntimePointerChecking &RtCheck); + bool addPointer(unsigned Index, const SCEV *Start, const SCEV *End, + unsigned AS, ScalarEvolution &SE); - /// Constitutes the context of this pointer checking group. For each - /// pointer that is a member of this group we will retain the index - /// at which it appears in RtCheck. - RuntimePointerChecking &RtCheck; /// The SCEV expression which represents the upper bound of all the /// pointers in this group. const SCEV *High; @@ -360,6 +364,8 @@ const SCEV *Low; /// Indices of all the pointers that constitute this grouping. SmallVector Members; + /// Address space of the involved pointers. + unsigned AddressSpace; }; /// A memcheck which made up of a pair of grouped pointers. diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp --- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp +++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp @@ -170,8 +170,10 @@ RuntimeCheckingPtrGroup::RuntimeCheckingPtrGroup( unsigned Index, RuntimePointerChecking &RtCheck) - : RtCheck(RtCheck), High(RtCheck.Pointers[Index].End), - Low(RtCheck.Pointers[Index].Start) { + : High(RtCheck.Pointers[Index].End), Low(RtCheck.Pointers[Index].Start), + AddressSpace(RtCheck.Pointers[Index] + .PointerValue->getType() + ->getPointerAddressSpace()) { Members.push_back(Index); } @@ -279,18 +281,28 @@ return I; } -bool RuntimeCheckingPtrGroup::addPointer(unsigned Index) { - const SCEV *Start = RtCheck.Pointers[Index].Start; - const SCEV *End = RtCheck.Pointers[Index].End; +bool RuntimeCheckingPtrGroup::addPointer(unsigned Index, + RuntimePointerChecking &RtCheck) { + return addPointer( + Index, RtCheck.Pointers[Index].Start, RtCheck.Pointers[Index].End, + RtCheck.Pointers[Index].PointerValue->getType()->getPointerAddressSpace(), + *RtCheck.SE); +} + +bool RuntimeCheckingPtrGroup::addPointer(unsigned Index, const SCEV *Start, + const SCEV *End, unsigned AS, + ScalarEvolution &SE) { + assert(AddressSpace == AS && + "all pointers in a checking group must be in the same address space"); // Compare the starts and ends with the known minimum and maximum // of this set. We need to know how we compare against the min/max // of the set in order to be able to emit memchecks. - const SCEV *Min0 = getMinFromExprs(Start, Low, RtCheck.SE); + const SCEV *Min0 = getMinFromExprs(Start, Low, &SE); if (!Min0) return false; - const SCEV *Min1 = getMinFromExprs(End, High, RtCheck.SE); + const SCEV *Min1 = getMinFromExprs(End, High, &SE); if (!Min1) return false; @@ -410,7 +422,7 @@ TotalComparisons++; - if (Group.addPointer(Pointer)) { + if (Group.addPointer(Pointer, *this)) { Merged = true; break; } diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -1524,14 +1524,8 @@ static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG, Loop *TheLoop, Instruction *Loc, SCEVExpander &Exp) { - // TODO: Add helper to retrieve pointers to CG. - Value *Ptr = CG->RtCheck.Pointers[CG->Members[0]].PointerValue; - - unsigned AS = Ptr->getType()->getPointerAddressSpace(); LLVMContext &Ctx = Loc->getContext(); - - // Use this type for pointer arithmetic. - Type *PtrArithTy = Type::getInt8PtrTy(Ctx, AS); + Type *PtrArithTy = Type::getInt8PtrTy(Ctx, CG->AddressSpace); Value *Start = nullptr, *End = nullptr; LLVM_DEBUG(dbgs() << "LAA: Adding RT check for range:\n");