diff --git a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h --- a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h +++ b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h @@ -465,6 +465,8 @@ return Pointers[PtrIdx]; } + ScalarEvolution *getSE() const { return SE; } + private: /// Groups pointers such that a single memcheck is required /// between two different groups. This will clear the CheckingGroups vector @@ -541,16 +543,6 @@ unsigned getNumStores() const { return NumStores; } unsigned getNumLoads() const { return NumLoads;} - /// Add code that checks at runtime if the accessed arrays \in p PointerChecks - /// overlap. - /// - /// Returns a pair of instructions where the first element is the first - /// instruction generated in possibly a sequence of instructions and the - /// second value is the final comparator value or NULL if no check is needed. - std::pair addRuntimeChecks( - Instruction *Loc, - const SmallVectorImpl &PointerChecks) const; - /// The diagnostics report generated for the analysis. E.g. why we /// couldn't analyze the loop. const OptimizationRemarkAnalysis *getReport() const { return Report.get(); } diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h --- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h +++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h @@ -41,6 +41,10 @@ class TargetTransformInfo; class LPPassManager; class Instruction; +struct RuntimeCheckingPtrGroup; +typedef std::pair + RuntimePointerCheck; template class Optional; template class SmallSetVector; @@ -428,6 +432,17 @@ Loop *cloneLoop(Loop *L, Loop *PL, ValueToValueMapTy &VM, LoopInfo *LI, LPPassManager *LPM); +/// Add code that checks at runtime if the accessed arrays in \p PointerChecks +/// overlap. 
+/// +/// Returns a pair of instructions where the first element is the first +/// instruction generated in possibly a sequence of instructions and the +/// second value is the final comparator value or NULL if no check is needed. +std::pair +addRuntimeChecks(Instruction *Loc, Loop *TheLoop, + const SmallVectorImpl &PointerChecks, + ScalarEvolution *SE); + } // end namespace llvm #endif // LLVM_TRANSFORMS_UTILS_LOOPUTILS_H diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp --- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp +++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp @@ -43,7 +43,6 @@ #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" -#include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" @@ -2123,158 +2122,6 @@ return (SE->isLoopInvariant(SE->getSCEV(V), TheLoop)); } -// FIXME: this function is currently a duplicate of the one in -// LoopVectorize.cpp. -static Instruction *getFirstInst(Instruction *FirstInst, Value *V, - Instruction *Loc) { - if (FirstInst) - return FirstInst; - if (Instruction *I = dyn_cast(V)) - return I->getParent() == Loc->getParent() ? I : nullptr; - return nullptr; -} - -namespace { - -/// IR Values for the lower and upper bounds of a pointer evolution. We -/// need to use value-handles because SCEV expansion can invalidate previously -/// expanded values. Thus expansion of a pointer can invalidate the bounds for -/// a previous one. -struct PointerBounds { - TrackingVH Start; - TrackingVH End; -}; - -} // end anonymous namespace - -/// Expand code for the lower and upper bound of the pointer group \p CG -/// in \p TheLoop. \return the values for the bounds. 
-static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG, - Loop *TheLoop, Instruction *Loc, - SCEVExpander &Exp, ScalarEvolution *SE) { - Value *Ptr = CG->RtCheck.Pointers[CG->Members[0]].PointerValue; - - const SCEV *Sc = SE->getSCEV(Ptr); - - unsigned AS = Ptr->getType()->getPointerAddressSpace(); - LLVMContext &Ctx = Loc->getContext(); - - // Use this type for pointer arithmetic. - Type *PtrArithTy = Type::getInt8PtrTy(Ctx, AS); - - if (SE->isLoopInvariant(Sc, TheLoop)) { - LLVM_DEBUG(dbgs() << "LAA: Adding RT check for a loop invariant ptr:" - << *Ptr << "\n"); - // Ptr could be in the loop body. If so, expand a new one at the correct - // location. - Instruction *Inst = dyn_cast(Ptr); - Value *NewPtr = (Inst && TheLoop->contains(Inst)) - ? Exp.expandCodeFor(Sc, PtrArithTy, Loc) - : Ptr; - // We must return a half-open range, which means incrementing Sc. - const SCEV *ScPlusOne = SE->getAddExpr(Sc, SE->getOne(PtrArithTy)); - Value *NewPtrPlusOne = Exp.expandCodeFor(ScPlusOne, PtrArithTy, Loc); - return {NewPtr, NewPtrPlusOne}; - } else { - Value *Start = nullptr, *End = nullptr; - LLVM_DEBUG(dbgs() << "LAA: Adding RT check for range:\n"); - Start = Exp.expandCodeFor(CG->Low, PtrArithTy, Loc); - End = Exp.expandCodeFor(CG->High, PtrArithTy, Loc); - LLVM_DEBUG(dbgs() << "Start: " << *CG->Low << " End: " << *CG->High - << "\n"); - return {Start, End}; - } -} - -/// Turns a collection of checks into a collection of expanded upper and -/// lower bounds for both pointers in the check. -static SmallVector, 4> -expandBounds(const SmallVectorImpl &PointerChecks, Loop *L, - Instruction *Loc, ScalarEvolution *SE, SCEVExpander &Exp) { - SmallVector, 4> ChecksWithBounds; - - // Here we're relying on the SCEV Expander's cache to only emit code for the - // same bounds once. 
- transform( - PointerChecks, std::back_inserter(ChecksWithBounds), - [&](const RuntimePointerCheck &Check) { - PointerBounds - First = expandBounds(Check.first, L, Loc, Exp, SE), - Second = expandBounds(Check.second, L, Loc, Exp, SE); - return std::make_pair(First, Second); - }); - - return ChecksWithBounds; -} - -std::pair LoopAccessInfo::addRuntimeChecks( - Instruction *Loc, - const SmallVectorImpl &PointerChecks) const { - const DataLayout &DL = TheLoop->getHeader()->getModule()->getDataLayout(); - auto *SE = PSE->getSE(); - SCEVExpander Exp(*SE, DL, "induction"); - auto ExpandedChecks = expandBounds(PointerChecks, TheLoop, Loc, SE, Exp); - - LLVMContext &Ctx = Loc->getContext(); - Instruction *FirstInst = nullptr; - IRBuilder<> ChkBuilder(Loc); - // Our instructions might fold to a constant. - Value *MemoryRuntimeCheck = nullptr; - - for (const auto &Check : ExpandedChecks) { - const PointerBounds &A = Check.first, &B = Check.second; - // Check if two pointers (A and B) conflict where conflict is computed as: - // start(A) <= end(B) && start(B) <= end(A) - unsigned AS0 = A.Start->getType()->getPointerAddressSpace(); - unsigned AS1 = B.Start->getType()->getPointerAddressSpace(); - - assert((AS0 == B.End->getType()->getPointerAddressSpace()) && - (AS1 == A.End->getType()->getPointerAddressSpace()) && - "Trying to bounds check pointers with different address spaces"); - - Type *PtrArithTy0 = Type::getInt8PtrTy(Ctx, AS0); - Type *PtrArithTy1 = Type::getInt8PtrTy(Ctx, AS1); - - Value *Start0 = ChkBuilder.CreateBitCast(A.Start, PtrArithTy0, "bc"); - Value *Start1 = ChkBuilder.CreateBitCast(B.Start, PtrArithTy1, "bc"); - Value *End0 = ChkBuilder.CreateBitCast(A.End, PtrArithTy1, "bc"); - Value *End1 = ChkBuilder.CreateBitCast(B.End, PtrArithTy0, "bc"); - - // [A|B].Start points to the first accessed byte under base [A|B]. - // [A|B].End points to the last accessed byte, plus one. 
- // There is no conflict when the intervals are disjoint: - // NoConflict = (B.Start >= A.End) || (A.Start >= B.End) - // - // bound0 = (B.Start < A.End) - // bound1 = (A.Start < B.End) - // IsConflict = bound0 & bound1 - Value *Cmp0 = ChkBuilder.CreateICmpULT(Start0, End1, "bound0"); - FirstInst = getFirstInst(FirstInst, Cmp0, Loc); - Value *Cmp1 = ChkBuilder.CreateICmpULT(Start1, End0, "bound1"); - FirstInst = getFirstInst(FirstInst, Cmp1, Loc); - Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict"); - FirstInst = getFirstInst(FirstInst, IsConflict, Loc); - if (MemoryRuntimeCheck) { - IsConflict = - ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx"); - FirstInst = getFirstInst(FirstInst, IsConflict, Loc); - } - MemoryRuntimeCheck = IsConflict; - } - - if (!MemoryRuntimeCheck) - return std::make_pair(nullptr, nullptr); - - // We have to do this trickery because the IRBuilder might fold the check to a - // constant expression in which case there is no Instruction anchored in a - // the block. - Instruction *Check = BinaryOperator::CreateAnd(MemoryRuntimeCheck, - ConstantInt::getTrue(Ctx)); - ChkBuilder.Insert(Check, "memcheck.conflict"); - FirstInst = getFirstInst(FirstInst, Check, Loc); - return std::make_pair(FirstInst, Check); -} - void LoopAccessInfo::collectStridedAccess(Value *MemAccess) { Value *Ptr = nullptr; if (LoadInst *LI = dyn_cast(MemAccess)) diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -1534,3 +1534,152 @@ return &New; } + +/// IR Values for the lower and upper bounds of a pointer evolution. We +/// need to use value-handles because SCEV expansion can invalidate previously +/// expanded values. Thus expansion of a pointer can invalidate the bounds for +/// a previous one. 
+struct PointerBounds { + TrackingVH Start; + TrackingVH End; +}; + +/// Expand code for the lower and upper bound of the pointer group \p CG +/// in \p TheLoop. \return the values for the bounds. +static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG, + Loop *TheLoop, Instruction *Loc, + SCEVExpander &Exp, ScalarEvolution *SE) { + // TODO: Add helper to retrieve pointers to CG. + Value *Ptr = CG->RtCheck.Pointers[CG->Members[0]].PointerValue; + const SCEV *Sc = SE->getSCEV(Ptr); + + unsigned AS = Ptr->getType()->getPointerAddressSpace(); + LLVMContext &Ctx = Loc->getContext(); + + // Use this type for pointer arithmetic. + Type *PtrArithTy = Type::getInt8PtrTy(Ctx, AS); + + if (SE->isLoopInvariant(Sc, TheLoop)) { + LLVM_DEBUG(dbgs() << "LAA: Adding RT check for a loop invariant ptr:" + << *Ptr << "\n"); + // Ptr could be in the loop body. If so, expand a new one at the correct + // location. + Instruction *Inst = dyn_cast(Ptr); + Value *NewPtr = (Inst && TheLoop->contains(Inst)) + ? Exp.expandCodeFor(Sc, PtrArithTy, Loc) + : Ptr; + // We must return a half-open range, which means incrementing Sc. + const SCEV *ScPlusOne = SE->getAddExpr(Sc, SE->getOne(PtrArithTy)); + Value *NewPtrPlusOne = Exp.expandCodeFor(ScPlusOne, PtrArithTy, Loc); + return {NewPtr, NewPtrPlusOne}; + } else { + Value *Start = nullptr, *End = nullptr; + LLVM_DEBUG(dbgs() << "LAA: Adding RT check for range:\n"); + Start = Exp.expandCodeFor(CG->Low, PtrArithTy, Loc); + End = Exp.expandCodeFor(CG->High, PtrArithTy, Loc); + LLVM_DEBUG(dbgs() << "Start: " << *CG->Low << " End: " << *CG->High + << "\n"); + return {Start, End}; + } +} + +/// Turns a collection of checks into a collection of expanded upper and +/// lower bounds for both pointers in the check. 
+static SmallVector, 4> +expandBounds(const SmallVectorImpl &PointerChecks, Loop *L, + Instruction *Loc, ScalarEvolution *SE, SCEVExpander &Exp) { + SmallVector, 4> ChecksWithBounds; + + // Here we're relying on the SCEV Expander's cache to only emit code for the + // same bounds once. + transform(PointerChecks, std::back_inserter(ChecksWithBounds), + [&](const RuntimePointerCheck &Check) { + PointerBounds First = expandBounds(Check.first, L, Loc, Exp, SE), + Second = + expandBounds(Check.second, L, Loc, Exp, SE); + return std::make_pair(First, Second); + }); + + return ChecksWithBounds; +} + +std::pair llvm::addRuntimeChecks( + Instruction *Loc, Loop *TheLoop, + const SmallVectorImpl &PointerChecks, + ScalarEvolution *SE) { + // TODO: Move noalias annotation code from LoopVersioning here and share with LV if possible. + // TODO: Pass RtPtrChecking instead of PointerChecks and SE separately, if possible + const DataLayout &DL = TheLoop->getHeader()->getModule()->getDataLayout(); + SCEVExpander Exp(*SE, DL, "induction"); + auto ExpandedChecks = expandBounds(PointerChecks, TheLoop, Loc, SE, Exp); + + LLVMContext &Ctx = Loc->getContext(); + Instruction *FirstInst = nullptr; + IRBuilder<> ChkBuilder(Loc); + // Our instructions might fold to a constant. + Value *MemoryRuntimeCheck = nullptr; + + // FIXME: this helper is currently a duplicate of the one in + // LoopVectorize.cpp. + auto GetFirstInst = [](Instruction *FirstInst, Value *V, + Instruction *Loc) -> Instruction * { + if (FirstInst) + return FirstInst; + if (Instruction *I = dyn_cast(V)) + return I->getParent() == Loc->getParent() ? 
I : nullptr; + return nullptr; + }; + + for (const auto &Check : ExpandedChecks) { + const PointerBounds &A = Check.first, &B = Check.second; + // Check if two pointers (A and B) conflict where conflict is computed as: + // start(A) < end(B) && start(B) < end(A) + unsigned AS0 = A.Start->getType()->getPointerAddressSpace(); + unsigned AS1 = B.Start->getType()->getPointerAddressSpace(); + + assert((AS0 == B.End->getType()->getPointerAddressSpace()) && + (AS1 == A.End->getType()->getPointerAddressSpace()) && + "Trying to bounds check pointers with different address spaces"); + + Type *PtrArithTy0 = Type::getInt8PtrTy(Ctx, AS0); + Type *PtrArithTy1 = Type::getInt8PtrTy(Ctx, AS1); + + Value *Start0 = ChkBuilder.CreateBitCast(A.Start, PtrArithTy0, "bc"); + Value *Start1 = ChkBuilder.CreateBitCast(B.Start, PtrArithTy1, "bc"); + Value *End0 = ChkBuilder.CreateBitCast(A.End, PtrArithTy1, "bc"); + Value *End1 = ChkBuilder.CreateBitCast(B.End, PtrArithTy0, "bc"); + + // [A|B].Start points to the first accessed byte under base [A|B]. + // [A|B].End points to the last accessed byte, plus one. 
+ // There is no conflict when the intervals are disjoint: + // NoConflict = (B.Start >= A.End) || (A.Start >= B.End) + // + // bound0 = (A.Start < B.End) + // bound1 = (B.Start < A.End) + // IsConflict = bound0 & bound1 + Value *Cmp0 = ChkBuilder.CreateICmpULT(Start0, End1, "bound0"); + FirstInst = GetFirstInst(FirstInst, Cmp0, Loc); + Value *Cmp1 = ChkBuilder.CreateICmpULT(Start1, End0, "bound1"); + FirstInst = GetFirstInst(FirstInst, Cmp1, Loc); + Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict"); + FirstInst = GetFirstInst(FirstInst, IsConflict, Loc); + if (MemoryRuntimeCheck) { + IsConflict = + ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx"); + FirstInst = GetFirstInst(FirstInst, IsConflict, Loc); + } + MemoryRuntimeCheck = IsConflict; + } + + if (!MemoryRuntimeCheck) + return std::make_pair(nullptr, nullptr); + + // We have to do this trickery because the IRBuilder might fold the check to a + // constant expression in which case there is no Instruction anchored in + // the block. + Instruction *Check = + BinaryOperator::CreateAnd(MemoryRuntimeCheck, ConstantInt::getTrue(Ctx)); + ChkBuilder.Insert(Check, "memcheck.conflict"); + FirstInst = GetFirstInst(FirstInst, Check, Loc); + return std::make_pair(FirstInst, Check); +} diff --git a/llvm/lib/Transforms/Utils/LoopVersioning.cpp b/llvm/lib/Transforms/Utils/LoopVersioning.cpp --- a/llvm/lib/Transforms/Utils/LoopVersioning.cpp +++ b/llvm/lib/Transforms/Utils/LoopVersioning.cpp @@ -62,8 +62,10 @@ // Add the memcheck in the original preheader (this is empty initially). 
BasicBlock *RuntimeCheckBB = VersionedLoop->getLoopPreheader(); + const auto &RtPtrChecking = *LAI.getRuntimePointerChecking(); std::tie(FirstCheckInst, MemRuntimeCheck) = - LAI.addRuntimeChecks(RuntimeCheckBB->getTerminator(), AliasChecks); + addRuntimeChecks(RuntimeCheckBB->getTerminator(), VersionedLoop, + AliasChecks, RtPtrChecking.getSE()); const SCEVUnionPredicate &Pred = LAI.getPSE().getUnionPredicate(); SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(), diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -2791,8 +2791,9 @@ return; Instruction *FirstCheckInst; Instruction *MemRuntimeCheck; - std::tie(FirstCheckInst, MemRuntimeCheck) = LAI->addRuntimeChecks( - MemCheckBlock->getTerminator(), RtPtrChecking.getChecks()); + std::tie(FirstCheckInst, MemRuntimeCheck) = + addRuntimeChecks(MemCheckBlock->getTerminator(), OrigLoop, + RtPtrChecking.getChecks(), RtPtrChecking.getSE()); if (!MemRuntimeCheck) return;