Index: include/llvm/Analysis/LoopAccessAnalysis.h =================================================================== --- include/llvm/Analysis/LoopAccessAnalysis.h +++ include/llvm/Analysis/LoopAccessAnalysis.h @@ -163,7 +163,7 @@ }; MemoryDepChecker(PredicatedScalarEvolution &PSE, const Loop *L) - : PSE(PSE), InnermostLoop(L), AccessIdx(0), + : PSE(PSE), InnermostLoop(L), AccessIdx(0), MaxSafeRegisterWidth(-1U), ShouldRetryWithRuntimeCheck(false), SafeForVectorization(true), RecordDependences(true) {} @@ -199,6 +199,10 @@ /// the accesses safely with. uint64_t getMaxSafeDepDistBytes() { return MaxSafeDepDistBytes; } + /// \brief Return the number of elements that are safe to operate on + /// simultaneously, multiplied by the size of the element in bits. + uint64_t getMaxSafeRegisterWidth() const { return MaxSafeRegisterWidth; } + /// \brief In same cases when the dependency check fails we can still /// vectorize the loop with a dynamic array access check. bool shouldRetryWithRuntimeCheck() { return ShouldRetryWithRuntimeCheck; } @@ -255,6 +259,12 @@ // We can access this many bytes in parallel safely. uint64_t MaxSafeDepDistBytes; + /// \brief Number of elements (from consecutive iterations) that are safe to + /// operate on simultaneously, multiplied by the size of the element in bits. + /// The size of the element is taken from the memory access that is most + /// restrictive. + uint64_t MaxSafeRegisterWidth; + /// \brief If we see a non-constant dependence distance we can still try to /// vectorize this loop with runtime checks. 
bool ShouldRetryWithRuntimeCheck; Index: lib/Analysis/LoopAccessAnalysis.cpp =================================================================== --- lib/Analysis/LoopAccessAnalysis.cpp +++ lib/Analysis/LoopAccessAnalysis.cpp @@ -1471,10 +1471,11 @@ couldPreventStoreLoadForward(Distance, TypeByteSize)) return Dependence::BackwardVectorizableButPreventsForwarding; + uint64_t MaxVF = MaxSafeDepDistBytes / (TypeByteSize * Stride); DEBUG(dbgs() << "LAA: Positive distance " << Val.getSExtValue() - << " with max VF = " - << MaxSafeDepDistBytes / (TypeByteSize * Stride) << '\n'); - + << " with max VF = " << MaxVF << '\n'); + uint64_t MaxVFInBits = MaxVF * TypeByteSize * 8; + MaxSafeRegisterWidth = std::min(MaxSafeRegisterWidth, MaxVFInBits); return Dependence::BackwardVectorizable; } Index: lib/Transforms/Vectorize/LoopVectorize.cpp =================================================================== --- lib/Transforms/Vectorize/LoopVectorize.cpp +++ lib/Transforms/Vectorize/LoopVectorize.cpp @@ -963,14 +963,6 @@ return InterleaveGroupMap.count(Instr); } - /// \brief Return the maximum interleave factor of all interleaved groups. - unsigned getMaxInterleaveFactor() const { - unsigned MaxFactor = 1; - for (auto &Entry : InterleaveGroupMap) - MaxFactor = std::max(MaxFactor, Entry.second->getFactor()); - return MaxFactor; - } - /// \brief Get the interleave group that \p Instr belongs to. /// /// \returns nullptr if doesn't have such group. @@ -1553,11 +1545,6 @@ return InterleaveInfo.isInterleaved(Instr); } - /// \brief Return the maximum interleave factor of all interleaved groups. - unsigned getMaxInterleaveFactor() const { - return InterleaveInfo.getMaxInterleaveFactor(); - } - /// \brief Get the interleaved access group that \p Instr belongs to. 
const InterleaveGroup *getInterleavedAccessGroup(Instruction *Instr) { return InterleaveInfo.getInterleaveGroup(Instr); @@ -1571,6 +1558,10 @@ unsigned getMaxSafeDepDistBytes() { return LAI->getMaxSafeDepDistBytes(); } + uint64_t getMaxSafeRegisterWidth() const { + return LAI->getDepChecker().getMaxSafeRegisterWidth(); + } + bool hasStride(Value *V) { return LAI->hasStride(V); } /// Returns true if the target machine supports masked store operation @@ -6077,9 +6068,11 @@ // Remove interleaved store groups with gaps. for (InterleaveGroup *Group : StoreGroups) - if (Group->getNumMembers() != Group->getFactor()) + if (Group->getNumMembers() != Group->getFactor()) { + DEBUG(dbgs() << "LV: Invalidate candidate interleaved store group due " + "to gaps.\n"); releaseGroup(Group); - + } // Remove interleaved groups with gaps (currently only loads) whose memory // accesses may wrap around. We have to revisit the getPtrStride analysis, // this time with ShouldCheckWrap=true, since collectConstStrideAccesses does @@ -6132,6 +6125,8 @@ // to look for a member at index factor - 1, since every group must have // a member at index zero. if (Group->isReverse()) { + DEBUG(dbgs() << "LV: Invalidate candidate interleaved group due to " + "a reverse access with gaps.\n"); releaseGroup(Group); continue; } @@ -6215,25 +6210,21 @@ unsigned SmallestType, WidestType; std::tie(SmallestType, WidestType) = getSmallestAndWidestTypes(); unsigned WidestRegister = TTI.getRegisterBitWidth(true); - unsigned MaxSafeDepDist = -1U; - // Get the maximum safe dependence distance in bits computed by LAA. If the - // loop contains any interleaved accesses, we divide the dependence distance - // by the maximum interleave factor of all interleaved groups. Note that - // although the division ensures correctness, this is a fairly conservative - // computation because the maximum distance computed by LAA may not involve - // any of the interleaved accesses. 
- if (Legal->getMaxSafeDepDistBytes() != -1U) - MaxSafeDepDist = - Legal->getMaxSafeDepDistBytes() * 8 / Legal->getMaxInterleaveFactor(); + // Get the maximum safe dependence distance in bits computed by LAA. + // It is computed by MaxVF * sizeOf(type) * 8, where type is taken from + // the memory accesses that is most restrictive (involved in the smallest + // dependence distance). + unsigned MaxSafeRegisterWidth = Legal->getMaxSafeRegisterWidth(); WidestRegister = - ((WidestRegister < MaxSafeDepDist) ? WidestRegister : MaxSafeDepDist); + ((WidestRegister < MaxSafeRegisterWidth) ? WidestRegister + : MaxSafeRegisterWidth); unsigned MaxVectorSize = WidestRegister / WidestType; DEBUG(dbgs() << "LV: The Smallest and Widest types: " << SmallestType << " / " << WidestType << " bits.\n"); - DEBUG(dbgs() << "LV: The Widest register is: " << WidestRegister + DEBUG(dbgs() << "LV: The Widest register safe to use is: " << WidestRegister << " bits.\n"); if (MaxVectorSize == 0) { Index: test/Transforms/LoopVectorize/memdep.ll =================================================================== --- test/Transforms/LoopVectorize/memdep.ll +++ test/Transforms/LoopVectorize/memdep.ll @@ -1,5 +1,7 @@ ; RUN: opt < %s -loop-vectorize -force-vector-width=2 -force-vector-interleave=1 -S | FileCheck %s ; RUN: opt < %s -loop-vectorize -force-vector-width=4 -force-vector-interleave=1 -S | FileCheck %s -check-prefix=WIDTH +; RUN: opt -S -loop-vectorize -force-vector-width=4 < %s | FileCheck %s -check-prefix=RIGHTVF +; RUN: opt -S -loop-vectorize -force-vector-width=8 < %s | FileCheck %s -check-prefix=WRONGVF target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64-S128" @@ -220,3 +222,52 @@ for.end: ret void } + + +;Check the new calculation of the maximum safe distance in bits which can be vectorized. +;The previous behavior did not take account that the stride was 2. 
+;Therefore the maxVF was computed as 8 instead of 4.
+
+;#define M 32
+;#define N 2 * M
+;unsigned int a [N];
+;void pr34283(){
+;  unsigned int j=0;
+;  for (j = 0; j < M - 6; ++j)
+;  {
+;    a[N - 2 * j] = 69;
+;    a[N - 12 - 2 * j] = 7;
+;  }
+;
+;}
+
+; RIGHTVF-LABEL: @pr34283
+; RIGHTVF: <4 x i64>
+
+; WRONGVF-LABEL: @pr34283
+; WRONGVF-NOT: <8 x i64>
+
+@a = common local_unnamed_addr global [64 x i32] zeroinitializer, align 16
+
+; Function Attrs: norecurse nounwind uwtable
+define void @pr34283() local_unnamed_addr {
+entry:
+  br label %for.body
+
+for.body:
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %0 = shl i64 %indvars.iv, 1
+  %1 = sub nuw nsw i64 64, %0
+  %arrayidx = getelementptr inbounds [64 x i32], [64 x i32]* @a, i64 0, i64 %1
+  store i32 69, i32* %arrayidx, align 8
+  %2 = sub nuw nsw i64 52, %0
+  %arrayidx4 = getelementptr inbounds [64 x i32], [64 x i32]* @a, i64 0, i64 %2
+  store i32 7, i32* %arrayidx4, align 8
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp eq i64 %indvars.iv.next, 26
+  br i1 %exitcond, label %for.end, label %for.body
+
+for.end:
+  ret void
+}
+