Index: include/llvm/Analysis/LoopAccessAnalysis.h =================================================================== --- include/llvm/Analysis/LoopAccessAnalysis.h +++ include/llvm/Analysis/LoopAccessAnalysis.h @@ -311,7 +311,7 @@ /// This struct holds information about the memory runtime legality check that /// a group of pointers do not overlap. struct RuntimePointerCheck { - RuntimePointerCheck() : Need(false) {} + RuntimePointerCheck(ScalarEvolution *SE) : Need(false), SE(SE) {} /// Reset the state of the pointer runtime information. void reset() { @@ -322,16 +322,35 @@ IsWritePtr.clear(); DependencySetId.clear(); AliasSetId.clear(); + Exprs.clear(); } /// Insert a pointer and calculate the start and end SCEVs. - void insert(ScalarEvolution *SE, Loop *Lp, Value *Ptr, bool WritePtr, - unsigned DepSetId, unsigned ASId, - const ValueToValueMap &Strides); + void insert(Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId, + unsigned ASId, const ValueToValueMap &Strides); /// \brief No run-time memory checking is necessary. bool empty() const { return Pointers.empty(); } + /// A grouping of pointers. A single memcheck is required between + /// two groups. + struct CheckGroup { + CheckGroup() : High(0), Low(0) {} + /// Index of the pointer which will represent the upper bound + /// of the memcheck. + unsigned High; + /// Index of the pointer which will represent the lower bound + /// of the memcheck. + unsigned Low; + /// Indices of all the pointers that constitute this grouping. + SmallVector<unsigned, 2> Members; + }; + + /// \brief Groups pointers such that a single memcheck is required + /// between two different groups. + SmallVector<CheckGroup, 2> + groupChecks(const SmallVectorImpl<int> *PtrPartition) const; + /// \brief Decide whether we need to issue a run-time check for pointer at /// index \p I and \p J to prove their independence. 
/// @@ -341,6 +360,11 @@ bool needsChecking(unsigned I, unsigned J, const SmallVectorImpl<int> *PtrPartition) const; + /// \brief Decide if we need to add a check between two groups of pointers, + /// according to needsChecking. + bool needsChecking(struct CheckGroup &M, struct CheckGroup &N, + const SmallVectorImpl<int> *PtrPartition) const; + /// \brief Return true if any pointer requires run-time checking according /// to needsChecking. bool needsAnyChecking(const SmallVectorImpl<int> *PtrPartition) const; @@ -372,6 +396,10 @@ SmallVector<unsigned, 2> DependencySetId; /// Holds the id of the disjoint alias set to which this pointer belongs. SmallVector<unsigned, 2> AliasSetId; + /// Holds at position i the SCEV for the access i. + SmallVector<const SCEV *, 2> Exprs; + /// Holds a pointer to the ScalarEvolution analysis. + ScalarEvolution *SE; }; LoopAccessInfo(Loop *L, ScalarEvolution *SE, const DataLayout &DL, Index: lib/Analysis/LoopAccessAnalysis.cpp =================================================================== --- lib/Analysis/LoopAccessAnalysis.cpp +++ lib/Analysis/LoopAccessAnalysis.cpp @@ -113,8 +113,8 @@ } void LoopAccessInfo::RuntimePointerCheck::insert( - ScalarEvolution *SE, Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId, - unsigned ASId, const ValueToValueMap &Strides) { + Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId, unsigned ASId, + const ValueToValueMap &Strides) { // Get the stride replaced scev. 
const SCEV *Sc = replaceSymbolicStrideSCEV(SE, Strides, Ptr); const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc); @@ -127,6 +127,169 @@ IsWritePtr.push_back(WritePtr); DependencySetId.push_back(DepSetId); AliasSetId.push_back(ASId); + Exprs.push_back(Sc); +} + +bool LoopAccessInfo::RuntimePointerCheck::needsChecking( + struct CheckGroup &M, struct CheckGroup &N, + const SmallVectorImpl<int> *PtrPartition) const { + for (unsigned I = 0, EI = M.Members.size(); EI != I; ++I) + for (unsigned J = 0, EJ = N.Members.size(); EJ != J; ++J) + if (needsChecking(M.Members[I], N.Members[J], PtrPartition)) + return true; + return false; +} + +SmallVector<LoopAccessInfo::RuntimePointerCheck::CheckGroup, 2> +LoopAccessInfo::RuntimePointerCheck::groupChecks( + const SmallVectorImpl<int> *PtrPartition) const { + // We build the groups as we make a linear traversal of all the pointers. + // For each pointer, we find the first existing group that: + // - contains no other pointers which need to be checked against + // the current pointer. + // - the difference between this pointer and all other members + // of the group is a constant. To verify this property it is + // sufficient to only check against one member of the group. + // We will always compare against the same pointer (the first + // one added to the group) - this makes it easier to order the + // pointers. + // + // For each group we maintain the indices of the highest + // and lowest pointer (which we will later need for constructing + // memory checks). We will also keep the (constant) value of the + // differences between the first added pointer in the group + // and the highest and lowest pointer. + + // Each pointer is part of an equivalence class. If two + // pointers are part of the same equivalence class, they will + // be part of the same group. + EquivalenceClasses<unsigned> CheckClass; + // For each pointer, holds the index to the first element + // that was added to the pointer's equivalence class. 
We do this + // because the leader of an equivalence class can change, and + // would like to always compare against the same element. + SmallVector FirstInSetIndex; + // Holds the index of the pointer which has the highest 'End' + // value. This will form the upper bound of the groups range. + // This is only valid for an index I if SetLeaderIndex[I] == I. + SmallVector High; + // Same as the High vector. + SmallVector Low; + // Same as High, except we keep the maximum difference from the + // leader. + SmallVector HighConstant; + // Same as HighConstant, except we keep the minimum. + SmallVector LowConstant; + + for (unsigned Pointer = 0; Pointer < Pointers.size(); ++Pointer) { + const SCEV *Sc = Exprs[Pointer]; + const SCEVAddRecExpr *AR = static_cast(Sc); + + // Create a new set containing only this element. + CheckClass.getOrInsertLeaderValue(Pointer); + FirstInSetIndex.push_back(Pointer); + High.push_back(Pointer); + Low.push_back(Pointer); + + const SCEVConstant *Zero = + static_cast(SE->getConstant(AR->getType(), 0)); + HighConstant.push_back(Zero); + LowConstant.push_back(Zero); + + // If we don't need to add any check for this pointer + // then adding it to another pointer group wouldn't decrease + // the total number of memchecks that we have to do. + // Therefore it is better to leave it in its own group. + bool NeedsCheck = false; + for (unsigned I = 0; I < Pointers.size(); ++I) { + if (I != Pointer && needsChecking(Pointer, I, PtrPartition)) { + NeedsCheck = true; + break; + } + } + + if (!NeedsCheck) + continue; + + // Go through all the existing sets and see if we can find one + // which can include this pointer. 
+ for (EquivalenceClasses<unsigned>::iterator EI = CheckClass.begin(), + EE = CheckClass.end(); + EI != EE; ++EI) { + if (!EI->isLeader()) + continue; + if (Pointer == EI->getData()) + continue; + unsigned FirstInSet = FirstInSetIndex[EI->getData()]; + + const SCEV *Old = Exprs[FirstInSet]; + if (Old->getType() != AR->getType()) + continue; + const SCEV *Diff = SE->getMinusSCEV(AR, Old); + const SCEVConstant *C = dyn_cast<SCEVConstant>(Diff); + if (!C) + continue; + + EquivalenceClasses<unsigned>::member_iterator AI, AE; + AI = CheckClass.member_begin(EI), AE = CheckClass.member_end(); + + // Merging this pointer into this equivalence class means that + // we won't be able to memcheck the pointer accesses against any other + // in the class. Therefore we need to first make sure that the excluded + // checks are not required. + bool Valid = true; + while (AI != AE) { + const unsigned Index = *AI; + if (needsChecking(Index, Pointer, PtrPartition)) { + Valid = false; + break; + } + AI++; + } + if (!Valid) + continue; + + // We're adding a new element to the set. Update the maximum + // and minimum differences. + if (C->getValue()->getValue().sgt( + HighConstant[FirstInSet]->getValue()->getValue())) { + HighConstant[FirstInSet] = C; + High[FirstInSet] = Pointer; + } + if (C->getValue()->getValue().slt( + LowConstant[FirstInSet]->getValue()->getValue())) { + LowConstant[FirstInSet] = C; + Low[FirstInSet] = Pointer; + } + + // Make the new element point to the same 'first in set' 
+ FirstInSetIndex[Pointer] = FirstInSetIndex[FirstInSet]; + CheckClass.unionSets(Pointer, FirstInSet); + break; + } + } + + SmallVector GroupedChecks; + for (EquivalenceClasses::iterator I = CheckClass.begin(), + E = CheckClass.end(); + I != E; ++I) { + if (!I->isLeader()) + continue; + unsigned FirstInSet = FirstInSetIndex[I->getData()]; + // Add all members of this equivalence class to the memcheck + EquivalenceClasses::member_iterator AI, AE; + AI = CheckClass.member_begin(I), AE = CheckClass.member_end(); + + GroupedChecks.push_back(CheckGroup()); + GroupedChecks.back().High = High[FirstInSet]; + GroupedChecks.back().Low = Low[FirstInSet]; + while (AI != AE) { + GroupedChecks.back().Members.push_back(*AI); + ++AI; + } + } + return GroupedChecks; } bool LoopAccessInfo::RuntimePointerCheck::needsChecking( @@ -156,42 +319,77 @@ void LoopAccessInfo::RuntimePointerCheck::print( raw_ostream &OS, unsigned Depth, const SmallVectorImpl *PtrPartition) const { - unsigned NumPointers = Pointers.size(); - if (NumPointers == 0) - return; OS.indent(Depth) << "Run-time memory checks:\n"; unsigned N = 0; - for (unsigned I = 0; I < NumPointers; ++I) - for (unsigned J = I + 1; J < NumPointers; ++J) - if (needsChecking(I, J, PtrPartition)) { - OS.indent(Depth) << N++ << ":\n"; - OS.indent(Depth + 2) << *Pointers[I]; - if (PtrPartition) - OS << " (Partition: " << (*PtrPartition)[I] << ")"; - OS << "\n"; - OS.indent(Depth + 2) << *Pointers[J]; - if (PtrPartition) - OS << " (Partition: " << (*PtrPartition)[J] << ")"; - OS << "\n"; + + SmallVector GroupedChecks; + GroupedChecks = groupChecks(PtrPartition); + + unsigned NumGroups = GroupedChecks.size(); + if (NumGroups == 0) + return; + + for (unsigned I = 0; I < NumGroups; ++I) + for (unsigned J = I + 1; J < NumGroups; ++J) + if (needsChecking(GroupedChecks[I], GroupedChecks[J], PtrPartition)) { + OS.indent(Depth) << "Check " << N++ << ":\n"; + + for (unsigned K = 0; K < GroupedChecks[I].Members.size(); ++K) { + OS.indent(Depth + 2) 
<< *Pointers[GroupedChecks[I].Members[K]] + << "\n"; + if (PtrPartition) + OS << " (Partition: " + << (*PtrPartition)[GroupedChecks[I].Members[K]] << ")" + << "\n"; + } + + for (unsigned K = 0; K < GroupedChecks[J].Members.size(); ++K) { + OS.indent(Depth + 2) << *Pointers[GroupedChecks[J].Members[K]] + << "\n"; + if (PtrPartition) + OS << " (Partition: " + << (*PtrPartition)[GroupedChecks[J].Members[K]] << ")" + << "\n"; + } } + + OS.indent(Depth) << "Grouped accesses:\n"; + for (unsigned I = 0; I < NumGroups; ++I) { + OS.indent(Depth + 2) << "Group " << I << ":\n"; + OS.indent(Depth + 4) << "(Low: " << *Starts[GroupedChecks[I].Low] + << " High: " << *Ends[GroupedChecks[I].High] << ")\n"; + for (unsigned J = 0; J < GroupedChecks[I].Members.size(); ++J) { + OS.indent(Depth + 6) << "Member: " << *Exprs[GroupedChecks[I].Members[J]] + << "\n"; + } + } } unsigned LoopAccessInfo::RuntimePointerCheck::getNumberOfChecks( const SmallVectorImpl *PtrPartition) const { - unsigned NumPointers = Pointers.size(); + SmallVector GroupedChecks; + GroupedChecks = groupChecks(PtrPartition); + + unsigned NumPartitions = GroupedChecks.size(); unsigned CheckCount = 0; - for (unsigned I = 0; I < NumPointers; ++I) - for (unsigned J = I + 1; J < NumPointers; ++J) - if (needsChecking(I, J, PtrPartition)) + for (unsigned I = 0; I < NumPartitions; ++I) + for (unsigned J = I + 1; J < NumPartitions; ++J) + if (needsChecking(GroupedChecks[I], GroupedChecks[J], PtrPartition)) CheckCount++; return CheckCount; } bool LoopAccessInfo::RuntimePointerCheck::needsAnyChecking( const SmallVectorImpl *PtrPartition) const { - return getNumberOfChecks(PtrPartition) != 0; + unsigned NumPointers = Pointers.size(); + + for (unsigned I = 0; I < NumPointers; ++I) + for (unsigned J = I + 1; J < NumPointers; ++J) + if (needsChecking(I, J, PtrPartition)) + return true; + return false; } namespace { @@ -341,7 +539,7 @@ // Each access has its own dependence set. 
DepId = RunningDepId++; - RtCheck.insert(SE, TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap); + RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap); DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n'); } else { @@ -1312,32 +1510,40 @@ if (!PtrRtCheck.Need) return std::make_pair(nullptr, nullptr); - unsigned NumPointers = PtrRtCheck.Pointers.size(); - SmallVector , 2> Starts; - SmallVector , 2> Ends; + SmallVector, 2> Starts; + SmallVector, 2> Ends; LLVMContext &Ctx = Loc->getContext(); SCEVExpander Exp(*SE, DL, "induction"); Instruction *FirstInst = nullptr; - for (unsigned i = 0; i < NumPointers; ++i) { - Value *Ptr = PtrRtCheck.Pointers[i]; + SmallVector GroupedChecks = + PtrRtCheck.groupChecks(PtrPartition); + + for (unsigned i = 0; i < GroupedChecks.size(); ++i) { + Value *Ptr = PtrRtCheck.Pointers[GroupedChecks[i].Members[0]]; const SCEV *Sc = SE->getSCEV(Ptr); if (SE->isLoopInvariant(Sc, TheLoop)) { - DEBUG(dbgs() << "LAA: Adding RT check for a loop invariant ptr:" << - *Ptr <<"\n"); + DEBUG(dbgs() << "LAA: Adding RT check for a loop invariant ptr:" << *Ptr + << "\n"); Starts.push_back(Ptr); Ends.push_back(Ptr); } else { - DEBUG(dbgs() << "LAA: Adding RT check for range:" << *Ptr << '\n'); unsigned AS = Ptr->getType()->getPointerAddressSpace(); // Use this type for pointer arithmetic. 
Type *PtrArithTy = Type::getInt8PtrTy(Ctx, AS); + Value *Start = nullptr, *End = nullptr; - Value *Start = Exp.expandCodeFor(PtrRtCheck.Starts[i], PtrArithTy, Loc); - Value *End = Exp.expandCodeFor(PtrRtCheck.Ends[i], PtrArithTy, Loc); + DEBUG(dbgs() << "LAA: Adding RT check for range:\n"); + Start = Exp.expandCodeFor(PtrRtCheck.Starts[GroupedChecks[i].Low], + PtrArithTy, Loc); + End = Exp.expandCodeFor(PtrRtCheck.Ends[GroupedChecks[i].High], + PtrArithTy, Loc); + DEBUG(dbgs() << "Start: " << *PtrRtCheck.Starts[GroupedChecks[i].Low] + << " End: " << *PtrRtCheck.Ends[GroupedChecks[i].High] + << "\n"); Starts.push_back(Start); Ends.push_back(End); } @@ -1346,9 +1552,10 @@ IRBuilder<> ChkBuilder(Loc); // Our instructions might fold to a constant. Value *MemoryRuntimeCheck = nullptr; - for (unsigned i = 0; i < NumPointers; ++i) { - for (unsigned j = i+1; j < NumPointers; ++j) { - if (!PtrRtCheck.needsChecking(i, j, PtrPartition)) + for (unsigned i = 0; i < GroupedChecks.size(); ++i) { + for (unsigned j = i + 1; j < GroupedChecks.size(); ++j) { + if (!PtrRtCheck.needsChecking(GroupedChecks[i], GroupedChecks[j], + PtrPartition)) continue; unsigned AS0 = Starts[i]->getType()->getPointerAddressSpace(); @@ -1399,8 +1606,8 @@ const TargetLibraryInfo *TLI, AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, const ValueToValueMap &Strides) - : DepChecker(SE, L), TheLoop(L), SE(SE), DL(DL), - TLI(TLI), AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0), + : PtrRtCheck(SE), DepChecker(SE, L), TheLoop(L), SE(SE), DL(DL), TLI(TLI), + AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0), MaxSafeDepDistBytes(-1U), CanVecMem(false), StoreToLoopInvariantAddress(false) { if (canAnalyzeLoop()) Index: test/Analysis/LoopAccessAnalysis/number-of-memchecks.ll =================================================================== --- test/Analysis/LoopAccessAnalysis/number-of-memchecks.ll +++ test/Analysis/LoopAccessAnalysis/number-of-memchecks.ll @@ -1,19 +1,20 @@ ; RUN: opt -loop-accesses 
-analyze < %s | FileCheck %s -; 3 reads and 3 writes should need 12 memchecks - target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128" target triple = "aarch64--linux-gnueabi" +; 3 reads and 3 writes should need 12 memchecks +; CHECK: function 'testf': ; CHECK: Memory dependences are safe with run-time checks + ; Memory dependecies have labels starting from 0, so in ; order to verify that we have n checks, we look for -; (n-1): and not n:. +; (n-1): and not n. ; CHECK: Run-time memory checks: -; CHECK-NEXT: 0: -; CHECK: 11: -; CHECK-NOT: 12: +; CHECK-NEXT: Check 0: +; CHECK: Check 11: +; CHECK-NOT: Check 12: define void @testf(i16* %a, i16* %b, @@ -56,3 +57,128 @@ for.end: ; preds = %for.body ret void } + +; The following (testg and testh) check that we can group +; memory checks of accesses which differ by a constant value. +; Both tests are based on the following C code: +; +; void testh(short *a, short *b, short *c) { +; unsigned long ind = 0; +; for (unsigned long ind = 0; ind < 20; ++ind) { +; c[2 * ind] = a[ind] * a[ind + 1]; +; c[2 * ind + 1] = a[ind] * a[ind + 1] * b[ind]; +; } +; } +; +; It is sufficient to check the intervals +; [a, a + 21], [b, b + 20] against [c, c + 41]. + +; 3 reads and 2 writes - two of the reads can be merged, +; and the writes can be merged as well. This gives us a +; total of 2 memory checks. 
+ +; CHECK: function 'testg': + +; CHECK: Run-time memory checks: +; CHECK-NEXT: Check 0: +; CHECK: Check 1: +; CHECK-NOT: Check 2: +; CHECK: Group 0: +; CHECK: Group 1: +; CHECK: Group 2: +; CHECK-NOT: Group 3: + +define void @testg(i16* %a, + i16* %b, + i16* %c) { +entry: + br label %for.body + +for.body: ; preds = %for.body, %entry + %ind = phi i64 [ 0, %entry ], [ %add, %for.body ] + %store_ind = phi i64 [ 0, %entry ], [ %store_ind_next, %for.body ] + + %add = add nuw nsw i64 %ind, 1 + %store_ind_inc = add nuw nsw i64 %store_ind, 1 + %store_ind_next = add nuw nsw i64 %store_ind_inc, 1 + + %arrayidxA = getelementptr inbounds i16, i16* %a, i64 %ind + %loadA = load i16, i16* %arrayidxA, align 2 + + %arrayidxA1 = getelementptr inbounds i16, i16* %a, i64 %add + %loadA1 = load i16, i16* %arrayidxA1, align 2 + + %arrayidxB = getelementptr inbounds i16, i16* %b, i64 %ind + %loadB = load i16, i16* %arrayidxB, align 2 + + %mul = mul i16 %loadA, %loadA1 + %mul1 = mul i16 %mul, %loadB + + %arrayidxC = getelementptr inbounds i16, i16* %c, i64 %store_ind + store i16 %mul1, i16* %arrayidxC, align 2 + + %arrayidxC1 = getelementptr inbounds i16, i16* %c, i64 %store_ind_inc + store i16 %mul, i16* %arrayidxC1, align 2 + + %exitcond = icmp eq i64 %add, 20 + br i1 %exitcond, label %for.end, label %for.body + +for.end: ; preds = %for.body + ret void +} + +; 3 reads and 2 writes - the writes can be merged into a single +; group, but the GEPs used for the reads are not marked as inbounds. +; We can still merge them because we are using a unit stride for +; accesses, so we cannot overflow the GEPs. 
+ +; CHECK: function 'testh': + +; CHECK: Run-time memory checks: +; CHECK-NEXT: Check 0: +; CHECK: Check 1: +; CHECK-NOT: Check 2: +; CHECK: Group 0: +; CHECK: Group 1: +; CHECK: Group 2: +; CHECK-NOT: Group 3: + + +define void @testh(i16* %a, + i16* %b, + i16* %c) { +entry: + br label %for.body + +for.body: ; preds = %for.body, %entry + %ind = phi i64 [ 0, %entry ], [ %add, %for.body ] + %store_ind = phi i64 [ 0, %entry ], [ %store_ind_next, %for.body ] + + %add = add nuw nsw i64 %ind, 1 + %store_ind_inc = add nuw nsw i64 %store_ind, 1 + %store_ind_next = add nuw nsw i64 %store_ind_inc, 1 + + %arrayidxA = getelementptr i16, i16* %a, i64 %ind + %loadA = load i16, i16* %arrayidxA, align 2 + + %arrayidxA1 = getelementptr i16, i16* %a, i64 %add + %loadA1 = load i16, i16* %arrayidxA1, align 2 + + %arrayidxB = getelementptr i16, i16* %b, i64 %ind + %loadB = load i16, i16* %arrayidxB, align 2 + + %mul = mul i16 %loadA, %loadA1 + %mul1 = mul i16 %mul, %loadB + + %arrayidxC = getelementptr inbounds i16, i16* %c, i64 %store_ind + store i16 %mul1, i16* %arrayidxC, align 2 + + %arrayidxC1 = getelementptr inbounds i16, i16* %c, i64 %store_ind_inc + store i16 %mul, i16* %arrayidxC1, align 2 + + %exitcond = icmp eq i64 %add, 20 + br i1 %exitcond, label %for.end, label %for.body + +for.end: ; preds = %for.body + ret void +} Index: test/Transforms/LoopDistribute/basic-with-memchecks.ll =================================================================== --- test/Transforms/LoopDistribute/basic-with-memchecks.ll +++ test/Transforms/LoopDistribute/basic-with-memchecks.ll @@ -32,8 +32,9 @@ %e = load i32*, i32** @E, align 8 br label %for.body -; We have two compares for each array overlap check which is a total of 10 -; compares. +; We have two compares for each array overlap check. +; Since the checks to A and A + 4 get merged, this will give us a +; total of 8 compares. 
; ; CHECK: for.body.ldist.memcheck: ; CHECK: = icmp @@ -48,9 +49,6 @@ ; CHECK: = icmp ; CHECK: = icmp -; CHECK: = icmp -; CHECK: = icmp - ; CHECK-NOT: = icmp ; CHECK: br i1 %memcheck.conflict, label %for.body.ph.ldist.nondist, label %for.body.ph.ldist1