Index: include/llvm/Analysis/LoopAccessAnalysis.h
===================================================================
--- include/llvm/Analysis/LoopAccessAnalysis.h
+++ include/llvm/Analysis/LoopAccessAnalysis.h
@@ -345,6 +345,10 @@
     /// to needsChecking.
     bool needsAnyChecking(const SmallVectorImpl<int> *PtrPartition) const;
 
+    /// \brief Returns the number of run-time checks required according to
+    /// needsChecking.
+    unsigned getNumberOfChecks(const SmallVectorImpl<int> *PtrPartition) const;
+
     /// \brief Print the list run-time memory checks necessary.
     ///
     /// If \p PtrPartition is set, it contains the partition number for
@@ -385,7 +389,10 @@
 
   /// \brief Number of memchecks required to prove independence of otherwise
   /// may-alias pointers.
-  unsigned getNumRuntimePointerChecks() const { return NumComparisons; }
+  unsigned getNumRuntimePointerChecks(
+      const SmallVectorImpl<int> *PtrPartition = nullptr) const {
+    return PtrRtCheck.getNumberOfChecks(PtrPartition);
+  }
 
   /// Return true if the block BB needs to be predicated in order for the loop
   /// to be vectorized.
@@ -460,10 +467,6 @@
   /// loop-independent and loop-carried dependences between memory accesses.
   MemoryDepChecker DepChecker;
 
-  /// \brief Number of memchecks required to prove independence of otherwise
-  /// may-alias pointers
-  unsigned NumComparisons;
-
   Loop *TheLoop;
   ScalarEvolution *SE;
   const DataLayout &DL;
Index: lib/Analysis/LoopAccessAnalysis.cpp
===================================================================
--- lib/Analysis/LoopAccessAnalysis.cpp
+++ lib/Analysis/LoopAccessAnalysis.cpp
@@ -177,15 +177,29 @@
       }
 }
 
 bool LoopAccessInfo::RuntimePointerCheck::needsAnyChecking(
     const SmallVectorImpl<int> *PtrPartition) const {
   unsigned NumPointers = Pointers.size();
 
   for (unsigned I = 0; I < NumPointers; ++I)
     for (unsigned J = I + 1; J < NumPointers; ++J)
       if (needsChecking(I, J, PtrPartition))
         return true;
   return false;
 }
 
+// Counts every pair that needsChecking() reports. needsAnyChecking() keeps its
+// own loop so that it can stop at the first such pair.
+unsigned LoopAccessInfo::RuntimePointerCheck::getNumberOfChecks(
+    const SmallVectorImpl<int> *PtrPartition) const {
+  unsigned NumPointers = Pointers.size();
+  unsigned CheckCount = 0;
+
+  for (unsigned I = 0; I < NumPointers; ++I)
+    for (unsigned J = I + 1; J < NumPointers; ++J)
+      if (needsChecking(I, J, PtrPartition))
+        ++CheckCount;
+  return CheckCount;
+}
+
 namespace {
@@ -220,10 +234,11 @@
   }
 
   /// \brief Check whether we can check the pointers at runtime for
-  /// non-intersection.
+  /// non-intersection. Sets \p NeedRTCheck when a run-time check is required
+  /// (or could not be built). Returns true trivially if no check is needed.
   bool canCheckPtrAtRT(LoopAccessInfo::RuntimePointerCheck &RtCheck,
-                       unsigned &NumComparisons, ScalarEvolution *SE,
-                       Loop *TheLoop, const ValueToValueMap &Strides,
+                       bool &NeedRTCheck, ScalarEvolution *SE, Loop *TheLoop,
+                       const ValueToValueMap &Strides,
                        bool ShouldCheckStride = false);
 
   /// \brief Goes over all memory accesses, checks whether a RT check is needed
@@ -290,23 +305,22 @@
 }
 
 bool AccessAnalysis::canCheckPtrAtRT(
-    LoopAccessInfo::RuntimePointerCheck &RtCheck, unsigned &NumComparisons,
+    LoopAccessInfo::RuntimePointerCheck &RtCheck, bool &NeedRTCheck,
     ScalarEvolution *SE, Loop *TheLoop, const ValueToValueMap &StridesMap,
     bool ShouldCheckStride) {
   // Find pointers with computable bounds. We are going to use this information
   // to place a runtime bound check.
   bool CanDoRT = true;
 
+  NeedRTCheck = false;
+  if (!IsRTCheckNeeded) return true;
+
   bool IsDepCheckNeeded = isDependencyCheckNeeded();
-  NumComparisons = 0;
 
   // We assign a consecutive id to access from different alias sets.
   // Accesses between different groups doesn't need to be checked.
   unsigned ASId = 1;
   for (auto &AS : AST) {
-    unsigned NumReadPtrChecks = 0;
-    unsigned NumWritePtrChecks = 0;
-
     // We assign consecutive id to access from different dependence sets.
     // Accesses within the same set don't need a runtime check.
     unsigned RunningDepId = 1;
@@ -317,11 +331,6 @@
       bool IsWrite = Accesses.count(MemAccessInfo(Ptr, true));
       MemAccessInfo Access(Ptr, IsWrite);
 
-      if (IsWrite)
-        ++NumWritePtrChecks;
-      else
-        ++NumReadPtrChecks;
-
       if (hasComputableBounds(SE, StridesMap, Ptr) &&
           // When we run after a failing dependency check we have to make sure
           // we don't have wrapping pointers.
@@ -349,16 +358,15 @@
       }
     }
 
-    if (IsDepCheckNeeded && CanDoRT && RunningDepId == 2)
-      NumComparisons += 0; // Only one dependence set.
-    else {
-      NumComparisons += (NumWritePtrChecks * (NumReadPtrChecks +
-                                              NumWritePtrChecks - 1));
-    }
-
     ++ASId;
   }
 
+  // We need a runtime check if there are any accesses that need checking.
+  // However, some accesses cannot be checked (for example because we
+  // can't determine their bounds). In these cases we would need a check
+  // but wouldn't be able to add it.
+  NeedRTCheck = !CanDoRT || RtCheck.needsAnyChecking(nullptr);
+
   // If the pointers that we would use for the bounds comparison have different
   // address spaces, assume the values aren't directly comparable, so we can't
   // use them for the runtime check. We also have to assume they could
@@ -1207,22 +1215,17 @@
   // Build dependence sets and check whether we need a runtime pointer bounds
   // check.
   Accesses.buildDependenceSets();
-  bool NeedRTCheck = Accesses.isRTCheckNeeded();
 
   // Find pointers with computable bounds. We are going to use this information
   // to place a runtime bound check.
-  bool CanDoRT = false;
-  if (NeedRTCheck)
-    CanDoRT = Accesses.canCheckPtrAtRT(PtrRtCheck, NumComparisons, SE, TheLoop,
-                                       Strides);
-
-  DEBUG(dbgs() << "LAA: We need to do " << NumComparisons <<
-        " pointer comparisons.\n");
-
-  // If we only have one set of dependences to check pointers among we don't
-  // need a runtime check.
-  if (NumComparisons == 0 && NeedRTCheck)
-    NeedRTCheck = false;
+  // canCheckPtrAtRT also reports whether any run-time check is needed at all.
+  bool NeedRTCheck = false;
+  bool CanDoRT =
+      Accesses.canCheckPtrAtRT(PtrRtCheck, NeedRTCheck, SE, TheLoop, Strides);
+
+  DEBUG(dbgs() << "LAA: We need to do "
+               << PtrRtCheck.getNumberOfChecks(nullptr)
+               << " pointer comparisons.\n");
 
   // Check that we found the bounds for the pointer.
   if (CanDoRT)
@@ -1255,10 +1258,11 @@
     PtrRtCheck.reset();
     PtrRtCheck.Need = true;
 
-    CanDoRT = Accesses.canCheckPtrAtRT(PtrRtCheck, NumComparisons, SE,
+    CanDoRT = Accesses.canCheckPtrAtRT(PtrRtCheck, NeedRTCheck, SE,
                                        TheLoop, Strides, true);
+
     // Check that we found the bounds for the pointer.
-    if (!CanDoRT && NumComparisons > 0) {
+    if (NeedRTCheck && !CanDoRT) {
       emitAnalysis(LoopAccessReport()
                    << "cannot check memory dependencies at runtime");
       DEBUG(dbgs() << "LAA: Can't vectorize with memory checks\n");
@@ -1403,7 +1407,7 @@
                                const TargetLibraryInfo *TLI, AliasAnalysis *AA,
                                DominatorTree *DT, LoopInfo *LI,
                                const ValueToValueMap &Strides)
-    : DepChecker(SE, L), NumComparisons(0), TheLoop(L), SE(SE), DL(DL),
+    : DepChecker(SE, L), TheLoop(L), SE(SE), DL(DL),
       TLI(TLI), AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0),
       MaxSafeDepDistBytes(-1U), CanVecMem(false),
       StoreToLoopInvariantAddress(false) {
Index: test/Analysis/LoopAccessAnalysis/number-of-memchecks.ll
===================================================================
--- /dev/null
+++ test/Analysis/LoopAccessAnalysis/number-of-memchecks.ll
@@ -0,0 +1,58 @@
+; RUN: opt -loop-accesses -analyze < %s | FileCheck %s
+
+; 3 reads and 3 writes need 12 memchecks (3 write/write + 9 write/read pairs).
+
+target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64--linux-gnueabi"
+
+; CHECK: Memory dependences are safe with run-time checks
+; Run-time checks are numbered starting from 0, so in
+; order to verify that we have n checks, we look for
+; (n-1): and not n:.
+
+; CHECK: Run-time memory checks:
+; CHECK-NEXT: 0:
+; CHECK: 11:
+; CHECK-NOT: 12:
+
+define void @testf(i16* %a,
+                   i16* %b,
+                   i16* %c,
+                   i16* %d,
+                   i16* %e,
+                   i16* %f) {
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %for.body, %entry
+  %ind = phi i64 [ 0, %entry ], [ %add, %for.body ]
+
+  %add = add nuw nsw i64 %ind, 1
+
+  %arrayidxA = getelementptr inbounds i16, i16* %a, i64 %ind
+  %loadA = load i16, i16* %arrayidxA, align 2
+
+  %arrayidxB = getelementptr inbounds i16, i16* %b, i64 %ind
+  %loadB = load i16, i16* %arrayidxB, align 2
+
+  %arrayidxC = getelementptr inbounds i16, i16* %c, i64 %ind
+  %loadC = load i16, i16* %arrayidxC, align 2
+
+  %mul = mul i16 %loadB, %loadA
+  %mul1 = mul i16 %mul, %loadC
+
+  %arrayidxD = getelementptr inbounds i16, i16* %d, i64 %ind
+  store i16 %mul1, i16* %arrayidxD, align 2
+
+  %arrayidxE = getelementptr inbounds i16, i16* %e, i64 %ind
+  store i16 %mul, i16* %arrayidxE, align 2
+
+  %arrayidxF = getelementptr inbounds i16, i16* %f, i64 %ind
+  store i16 %mul1, i16* %arrayidxF, align 2
+
+  %exitcond = icmp eq i64 %add, 20
+  br i1 %exitcond, label %for.end, label %for.body
+
+for.end:                                          ; preds = %for.body
+  ret void
+}