Index: include/llvm/Analysis/LoopAccessAnalysis.h
===================================================================
--- include/llvm/Analysis/LoopAccessAnalysis.h
+++ include/llvm/Analysis/LoopAccessAnalysis.h
@@ -163,6 +163,14 @@
   unsigned getNumStores() const { return NumStores; }
   unsigned getNumLoads() const { return NumLoads;}
 
+  /// \brief Add code that checks at runtime if the accessed arrays overlap.
+  ///
+  /// Returns a pair of instructions where the first element is the first
+  /// instruction generated in possibly a sequence of instructions and the
+  /// second value is the final comparator value. Both elements are nullptr
+  /// if no check is needed.
+  std::pair<Instruction *, Instruction *> addRuntimeCheck(Instruction *Loc);
+
 private:
   void emitAnalysis(VectorizationReport &Message);
 
Index: lib/Analysis/LoopAccessAnalysis.cpp
===================================================================
--- lib/Analysis/LoopAccessAnalysis.cpp
+++ lib/Analysis/LoopAccessAnalysis.cpp
@@ -14,9 +14,11 @@
 
 #include "llvm/Analysis/LoopAccessAnalysis.h"
 #include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ScalarEvolutionExpander.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Transforms/Utils/VectorUtils.h"
 using namespace llvm;
@@ -1082,3 +1084,119 @@
 bool LoopAccessAnalysis::isUniform(Value *V) {
   return (SE->isLoopInvariant(SE->getSCEV(V), TheLoop));
 }
+
+// FIXME: this function is currently a duplicate of the one in
+// LoopVectorize.cpp.
+//
+// Returns FirstInst if it is already set. Otherwise returns V if it is an
+// Instruction in the same basic block as Loc, and nullptr in all other cases.
+static Instruction *getFirstInst(Instruction *FirstInst, Value *V,
+                                 Instruction *Loc) {
+  if (FirstInst)
+    return FirstInst;
+  if (Instruction *I = dyn_cast<Instruction>(V))
+    return I->getParent() == Loc->getParent() ? I : nullptr;
+  return nullptr;
+}
+
+std::pair<Instruction *, Instruction *>
+LoopAccessAnalysis::addRuntimeCheck(Instruction *Loc) {
+  Instruction *tnullptr = nullptr;
+  if (!PtrRtCheck.Need)
+    return std::pair<Instruction *, Instruction *>(tnullptr, tnullptr);
+
+  unsigned NumPointers = PtrRtCheck.Pointers.size();
+  SmallVector<TrackingVH<Value>, 2> Starts;
+  SmallVector<TrackingVH<Value>, 2> Ends;
+
+  LLVMContext &Ctx = Loc->getContext();
+  SCEVExpander Exp(*SE, "induction");
+  Instruction *FirstInst = nullptr;
+
+  // Materialize the start and end bounds of each pointer, expanding any needed
+  // code at Loc. A loop-invariant pointer serves as its own start and end.
+  for (unsigned i = 0; i < NumPointers; ++i) {
+    Value *Ptr = PtrRtCheck.Pointers[i];
+    const SCEV *Sc = SE->getSCEV(Ptr);
+
+    if (SE->isLoopInvariant(Sc, TheLoop)) {
+      DEBUG(dbgs() << "LV: Adding RT check for a loop invariant ptr:" <<
+            *Ptr <<"\n");
+      Starts.push_back(Ptr);
+      Ends.push_back(Ptr);
+    } else {
+      DEBUG(dbgs() << "LV: Adding RT check for range:" << *Ptr << '\n');
+      unsigned AS = Ptr->getType()->getPointerAddressSpace();
+
+      // Use this type for pointer arithmetic.
+      Type *PtrArithTy = Type::getInt8PtrTy(Ctx, AS);
+
+      Value *Start = Exp.expandCodeFor(PtrRtCheck.Starts[i], PtrArithTy, Loc);
+      Value *End = Exp.expandCodeFor(PtrRtCheck.Ends[i], PtrArithTy, Loc);
+      Starts.push_back(Start);
+      Ends.push_back(End);
+    }
+  }
+
+  IRBuilder<> ChkBuilder(Loc);
+  // Our instructions might fold to a constant.
+  Value *MemoryRuntimeCheck = nullptr;
+  for (unsigned i = 0; i < NumPointers; ++i) {
+    for (unsigned j = i+1; j < NumPointers; ++j) {
+      // No need to check if two readonly pointers intersect.
+      if (!PtrRtCheck.IsWritePtr[i] && !PtrRtCheck.IsWritePtr[j])
+        continue;
+
+      // Only need to check pointers between two different dependency sets.
+      if (PtrRtCheck.DependencySetId[i] == PtrRtCheck.DependencySetId[j])
+        continue;
+      // Only need to check pointers in the same alias set.
+      if (PtrRtCheck.AliasSetId[i] != PtrRtCheck.AliasSetId[j])
+        continue;
+
+      unsigned AS0 = Starts[i]->getType()->getPointerAddressSpace();
+      unsigned AS1 = Starts[j]->getType()->getPointerAddressSpace();
+
+      assert((AS0 == Ends[j]->getType()->getPointerAddressSpace()) &&
+             (AS1 == Ends[i]->getType()->getPointerAddressSpace()) &&
+             "Trying to bounds check pointers with different address spaces");
+
+      Type *PtrArithTy0 = Type::getInt8PtrTy(Ctx, AS0);
+      Type *PtrArithTy1 = Type::getInt8PtrTy(Ctx, AS1);
+
+      Value *Start0 = ChkBuilder.CreateBitCast(Starts[i], PtrArithTy0, "bc");
+      Value *Start1 = ChkBuilder.CreateBitCast(Starts[j], PtrArithTy1, "bc");
+      Value *End0 = ChkBuilder.CreateBitCast(Ends[i], PtrArithTy1, "bc");
+      Value *End1 = ChkBuilder.CreateBitCast(Ends[j], PtrArithTy0, "bc");
+
+      // The ranges [Start0, End0] and [Start1, End1] overlap iff
+      // Start0 <= End1 && Start1 <= End0 (unsigned comparison).
+      Value *Cmp0 = ChkBuilder.CreateICmpULE(Start0, End1, "bound0");
+      FirstInst = getFirstInst(FirstInst, Cmp0, Loc);
+      Value *Cmp1 = ChkBuilder.CreateICmpULE(Start1, End0, "bound1");
+      FirstInst = getFirstInst(FirstInst, Cmp1, Loc);
+      Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict");
+      FirstInst = getFirstInst(FirstInst, IsConflict, Loc);
+      if (MemoryRuntimeCheck) {
+        IsConflict = ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict,
+                                         "conflict.rdx");
+        FirstInst = getFirstInst(FirstInst, IsConflict, Loc);
+      }
+      MemoryRuntimeCheck = IsConflict;
+    }
+  }
+
+  // All pointer pairs may have been skipped above, so no comparison exists;
+  // report that no check is needed instead of building an 'and' of null.
+  if (!MemoryRuntimeCheck)
+    return std::pair<Instruction *, Instruction *>(tnullptr, tnullptr);
+
+  // We have to do this trickery because the IRBuilder might fold the check to a
+  // constant expression in which case there is no Instruction anchored in
+  // the block.
+  Instruction *Check = BinaryOperator::CreateAnd(MemoryRuntimeCheck,
+                                                 ConstantInt::getTrue(Ctx));
+  ChkBuilder.Insert(Check, "memcheck.conflict");
+  FirstInst = getFirstInst(FirstInst, Check, Loc);
+  return std::make_pair(FirstInst, Check);
+}
Index: lib/Transforms/Vectorize/LoopVectorize.cpp
===================================================================
--- lib/Transforms/Vectorize/LoopVectorize.cpp
+++ lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -271,13 +271,6 @@
   typedef DenseMap<std::pair<BasicBlock*, BasicBlock*>,
                    VectorParts> EdgeMaskCache;
 
-  /// \brief Add code that checks at runtime if the accessed arrays overlap.
-  ///
-  /// Returns a pair of instructions where the first element is the first
-  /// instruction generated in possibly a sequence of instructions and the
-  /// second value is the final comparator value or NULL if no check is needed.
-  std::pair<Instruction *, Instruction *> addRuntimeCheck(Instruction *Loc);
-
   /// \brief Add checks for strides that where assumed to be 1.
   ///
   /// Returns the last check instruction and the first check instruction in the
@@ -751,6 +744,10 @@
     return LAA.getRuntimePointerCheck();
   }
 
+  LoopAccessAnalysis *getLAA() {
+    return &LAA;
+  }
+
   /// This function returns the identity element (or neutral element) for
   /// the operation K.
static Constant *getReductionIdentity(ReductionKind K, Type *Tp); @@ -2009,102 +2006,6 @@ return std::make_pair(FirstInst, TheCheck); } -std::pair -InnerLoopVectorizer::addRuntimeCheck(Instruction *Loc) { - LoopAccessAnalysis::RuntimePointerCheck *PtrRtCheck = - Legal->getRuntimePointerCheck(); - - Instruction *tnullptr = nullptr; - if (!PtrRtCheck->Need) - return std::pair(tnullptr, tnullptr); - - unsigned NumPointers = PtrRtCheck->Pointers.size(); - SmallVector , 2> Starts; - SmallVector , 2> Ends; - - LLVMContext &Ctx = Loc->getContext(); - SCEVExpander Exp(*SE, "induction"); - Instruction *FirstInst = nullptr; - - for (unsigned i = 0; i < NumPointers; ++i) { - Value *Ptr = PtrRtCheck->Pointers[i]; - const SCEV *Sc = SE->getSCEV(Ptr); - - if (SE->isLoopInvariant(Sc, OrigLoop)) { - DEBUG(dbgs() << "LV: Adding RT check for a loop invariant ptr:" << - *Ptr <<"\n"); - Starts.push_back(Ptr); - Ends.push_back(Ptr); - } else { - DEBUG(dbgs() << "LV: Adding RT check for range:" << *Ptr << '\n'); - unsigned AS = Ptr->getType()->getPointerAddressSpace(); - - // Use this type for pointer arithmetic. - Type *PtrArithTy = Type::getInt8PtrTy(Ctx, AS); - - Value *Start = Exp.expandCodeFor(PtrRtCheck->Starts[i], PtrArithTy, Loc); - Value *End = Exp.expandCodeFor(PtrRtCheck->Ends[i], PtrArithTy, Loc); - Starts.push_back(Start); - Ends.push_back(End); - } - } - - IRBuilder<> ChkBuilder(Loc); - // Our instructions might fold to a constant. - Value *MemoryRuntimeCheck = nullptr; - for (unsigned i = 0; i < NumPointers; ++i) { - for (unsigned j = i+1; j < NumPointers; ++j) { - // No need to check if two readonly pointers intersect. - if (!PtrRtCheck->IsWritePtr[i] && !PtrRtCheck->IsWritePtr[j]) - continue; - - // Only need to check pointers between two different dependency sets. - if (PtrRtCheck->DependencySetId[i] == PtrRtCheck->DependencySetId[j]) - continue; - // Only need to check pointers in the same alias set. 
- if (PtrRtCheck->AliasSetId[i] != PtrRtCheck->AliasSetId[j]) - continue; - - unsigned AS0 = Starts[i]->getType()->getPointerAddressSpace(); - unsigned AS1 = Starts[j]->getType()->getPointerAddressSpace(); - - assert((AS0 == Ends[j]->getType()->getPointerAddressSpace()) && - (AS1 == Ends[i]->getType()->getPointerAddressSpace()) && - "Trying to bounds check pointers with different address spaces"); - - Type *PtrArithTy0 = Type::getInt8PtrTy(Ctx, AS0); - Type *PtrArithTy1 = Type::getInt8PtrTy(Ctx, AS1); - - Value *Start0 = ChkBuilder.CreateBitCast(Starts[i], PtrArithTy0, "bc"); - Value *Start1 = ChkBuilder.CreateBitCast(Starts[j], PtrArithTy1, "bc"); - Value *End0 = ChkBuilder.CreateBitCast(Ends[i], PtrArithTy1, "bc"); - Value *End1 = ChkBuilder.CreateBitCast(Ends[j], PtrArithTy0, "bc"); - - Value *Cmp0 = ChkBuilder.CreateICmpULE(Start0, End1, "bound0"); - FirstInst = getFirstInst(FirstInst, Cmp0, Loc); - Value *Cmp1 = ChkBuilder.CreateICmpULE(Start1, End0, "bound1"); - FirstInst = getFirstInst(FirstInst, Cmp1, Loc); - Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict"); - FirstInst = getFirstInst(FirstInst, IsConflict, Loc); - if (MemoryRuntimeCheck) { - IsConflict = ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, - "conflict.rdx"); - FirstInst = getFirstInst(FirstInst, IsConflict, Loc); - } - MemoryRuntimeCheck = IsConflict; - } - } - - // We have to do this trickery because the IRBuilder might fold the check to a - // constant expression in which case there is no Instruction anchored in a - // the block. - Instruction *Check = BinaryOperator::CreateAnd(MemoryRuntimeCheck, - ConstantInt::getTrue(Ctx)); - ChkBuilder.Insert(Check, "memcheck.conflict"); - FirstInst = getFirstInst(FirstInst, Check, Loc); - return std::make_pair(FirstInst, Check); -} - void InnerLoopVectorizer::createEmptyLoop() { /* In this function we generate a new loop. The new loop will contain @@ -2329,7 +2230,7 @@ // faster. 
Instruction *MemRuntimeCheck; std::tie(FirstCheckInst, MemRuntimeCheck) = - addRuntimeCheck(LastBypassBlock->getTerminator()); + Legal->getLAA()->addRuntimeCheck(LastBypassBlock->getTerminator()); if (MemRuntimeCheck) { // Create a new block containing the memory check. BasicBlock *CheckBlock =