Index: llvm/include/llvm/Analysis/ScalarEvolution.h
===================================================================
--- llvm/include/llvm/Analysis/ScalarEvolution.h
+++ llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -844,7 +844,8 @@
   /// (at every loop iteration). It is, at the same time, the minimum number
   /// of times S is divisible by 2. For example, given {4,+,8} it returns 2.
   /// If S is guaranteed to be 0, it returns the bitwidth of S.
-  uint32_t GetMinTrailingZeros(const SCEV *S);
+  uint32_t GetMinTrailingZeros(const SCEV *S,
+                               const Instruction *CxtI = nullptr);

   /// Determine the unsigned range for a particular SCEV.
   /// NOTE: This returns a copy of the reference returned by getRangeRef.
@@ -1280,7 +1281,7 @@
   SetVector<ValueOffsetPair> *getSCEVValues(const SCEV *S);

   /// Private helper method for the GetMinTrailingZeros method
-  uint32_t GetMinTrailingZerosImpl(const SCEV *S);
+  uint32_t GetMinTrailingZerosImpl(const SCEV *S, const Instruction *CxtI);

   /// Information about the number of loop iterations for which a loop exit's
   /// branch condition evaluates to the not-taken path. This is a temporary
Index: llvm/lib/Analysis/ScalarEvolution.cpp
===================================================================
--- llvm/lib/Analysis/ScalarEvolution.cpp
+++ llvm/lib/Analysis/ScalarEvolution.cpp
@@ -5522,26 +5522,27 @@
   return getGEPExpr(GEP, IndexExprs);
 }

-uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) {
+uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S,
+                                                  const Instruction *CxtI) {
   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
     return C->getAPInt().countTrailingZeros();

   if (const SCEVPtrToIntExpr *I = dyn_cast<SCEVPtrToIntExpr>(S))
-    return GetMinTrailingZeros(I->getOperand());
+    return GetMinTrailingZeros(I->getOperand(), CxtI);

   if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
-    return std::min(GetMinTrailingZeros(T->getOperand()),
+    return std::min(GetMinTrailingZeros(T->getOperand(), CxtI),
                     (uint32_t)getTypeSizeInBits(T->getType()));

   if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
-    uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
+    uint32_t OpRes = GetMinTrailingZeros(E->getOperand(), CxtI);
     return OpRes == getTypeSizeInBits(E->getOperand()->getType())
                ? getTypeSizeInBits(E->getType())
                : OpRes;
   }

   if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
-    uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
+    uint32_t OpRes = GetMinTrailingZeros(E->getOperand(), CxtI);
     return OpRes == getTypeSizeInBits(E->getOperand()->getType())
                ? getTypeSizeInBits(E->getType())
                : OpRes;
@@ -5549,9 +5550,10 @@

   if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
     // The result is the min of all operands results.
-    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
+    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0), CxtI);
     for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
-      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
+      MinOpRes =
+          std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i), CxtI));
     return MinOpRes;
   }

@@ -5561,8 +5563,8 @@
     uint32_t BitWidth = getTypeSizeInBits(M->getType());
     for (unsigned i = 1, e = M->getNumOperands();
          SumOpRes != BitWidth && i != e; ++i)
-      SumOpRes =
-          std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), BitWidth);
+      SumOpRes = std::min(
+          SumOpRes + GetMinTrailingZeros(M->getOperand(i), CxtI), BitWidth);
     return SumOpRes;
   }

@@ -5570,15 +5572,17 @@
     // The result is the min of all operands results.
     uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
     for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
-      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
+      MinOpRes =
+          std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i), CxtI));
     return MinOpRes;
   }

   if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
     // The result is the min of all operands results.
-    uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
+    uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0), CxtI);
     for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
-      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
+      MinOpRes =
+          std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i), CxtI));
     return MinOpRes;
   }

@@ -5586,13 +5590,15 @@
     // The result is the min of all operands results.
     uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
     for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
-      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
+      MinOpRes =
+          std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i), CxtI));
     return MinOpRes;
   }

   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
     // For a SCEVUnknown, ask ValueTracking.
-    KnownBits Known = computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT);
+    KnownBits Known =
+        computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, CxtI, &DT);
     return Known.countMinTrailingZeros();
   }

@@ -5600,12 +5606,18 @@
   return 0;
 }

-uint32_t ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
-  auto I = MinTrailingZerosCache.find(S);
-  if (I != MinTrailingZerosCache.end())
-    return I->second;
+uint32_t ScalarEvolution::GetMinTrailingZeros(const SCEV *S,
+                                              const Instruction *CxtI) {
+  if (!CxtI) {
+    auto I = MinTrailingZerosCache.find(S);
+    if (I != MinTrailingZerosCache.end())
+      return I->second;
+  }
+
+  uint32_t Result = GetMinTrailingZerosImpl(S, CxtI);
+  if (CxtI)
+    return Result;

-  uint32_t Result = GetMinTrailingZerosImpl(S);
   auto InsertPair = MinTrailingZerosCache.insert({S, Result});
   assert(InsertPair.second && "Should insert a new key");
   return InsertPair.first->second;
@@ -6876,7 +6888,9 @@
     // Attempt to factor more general cases. Returns the greatest power of
     // two divisor. If overflow happens, the trip count expression is still
    // divisible by the greatest power of 2 divisor returned.
-    return 1U << std::min((uint32_t)31, GetMinTrailingZeros(TCExpr));
+    return 1U << std::min(
+               (uint32_t)31,
+               GetMinTrailingZeros(TCExpr, L->getHeader()->getTerminator()));

   ConstantInt *Result = TC->getValue();
Index: llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
===================================================================
--- llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -5544,6 +5544,19 @@
     // Accept MaxVF if we do not have a tail.
     LLVM_DEBUG(dbgs() << "LV: No tail will remain for any chosen VF.\n");
     return MaxVF;
+  } else if (isPowerOf2_32(MaxVFtimesIC)) {
+    // Since getURemExpr() doesn't take an assumption context, it may have
+    // ignored relevant assumptions. For powers of two, also try proving
+    // divisibility using GetMinTrailingZeros() with the header's terminator as
+    // context.
+    unsigned TCisMultipleOf =
+        1 << SE->GetMinTrailingZeros(ExitCount,
+                                     TheLoop->getHeader()->getTerminator());
+    if (TCisMultipleOf % MaxVFtimesIC == 0) {
+      // Accept MaxVF if we do not have a tail.
+      LLVM_DEBUG(dbgs() << "LV: No tail will remain for any chosen VF.\n");
+      return MaxVF;
+    }
   }

   // If we don't know the precise trip count, or if the trip count that we
Index: llvm/test/Transforms/LoopUnroll/runtime-unroll-assume-no-remainder.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/LoopUnroll/runtime-unroll-assume-no-remainder.ll
@@ -0,0 +1,96 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -S -loop-unroll -unroll-runtime=true -unroll-runtime-epilog=true | FileCheck %s
+
+; Make sure the loop is unrolled without a remainder loop based on an assumption
+; that the lower bits are known to be zero.
+
+define dso_local void @assumeDivisibleTC(i8* noalias nocapture %a, i8* noalias nocapture readonly %b, i32 %n, i1 %c) local_unnamed_addr {
+; CHECK-LABEL: @assumeDivisibleTC(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C:%.*]], label [[BAIL_OUT:%.*]], label [[DO_WORK:%.*]]
+; CHECK:       bail.out:
+; CHECK-NEXT:    ret void
+; CHECK:       do.work:
+; CHECK-NEXT:    [[AND:%.*]] = and i32 [[N:%.*]], 3
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[AND]], 0
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP]])
+; CHECK-NEXT:    [[CMP110:%.*]] = icmp sgt i32 [[N]], 0
+; CHECK-NEXT:    br i1 [[CMP110]], label [[FOR_BODY_PREHEADER:%.*]], label [[FOR_COND_CLEANUP:%.*]]
+; CHECK:       for.body.preheader:
+; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
+; CHECK:       for.cond.cleanup.loopexit:
+; CHECK-NEXT:    br label [[FOR_COND_CLEANUP]]
+; CHECK:       for.cond.cleanup:
+; CHECK-NEXT:    ret void
+; CHECK:       for.body:
+; CHECK-NEXT:    [[I_011:%.*]] = phi i32 [ 0, [[FOR_BODY_PREHEADER]] ], [ [[INC_3:%.*]], [[FOR_BODY]] ]
+; CHECK-NEXT:    [[IDXPROM:%.*]] = zext i32 [[I_011]] to i64
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i8, i8* [[B:%.*]], i64 [[IDXPROM]]
+; CHECK-NEXT:    [[TMP0:%.*]] = load i8, i8* [[ARRAYIDX]], align 1
+; CHECK-NEXT:    [[ADD:%.*]] = add i8 [[TMP0]], 3
+; CHECK-NEXT:    [[ARRAYIDX4:%.*]] = getelementptr inbounds i8, i8* [[A:%.*]], i64 [[IDXPROM]]
+; CHECK-NEXT:    store i8 [[ADD]], i8* [[ARRAYIDX4]], align 1
+; CHECK-NEXT:    [[INC:%.*]] = add nuw nsw i32 [[I_011]], 1
+; CHECK-NEXT:    [[IDXPROM_1:%.*]] = zext i32 [[INC]] to i64
+; CHECK-NEXT:    [[ARRAYIDX_1:%.*]] = getelementptr inbounds i8, i8* [[B]], i64 [[IDXPROM_1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = load i8, i8* [[ARRAYIDX_1]], align 1
+; CHECK-NEXT:    [[ADD_1:%.*]] = add i8 [[TMP1]], 3
+; CHECK-NEXT:    [[ARRAYIDX4_1:%.*]] = getelementptr inbounds i8, i8* [[A]], i64 [[IDXPROM_1]]
+; CHECK-NEXT:    store i8 [[ADD_1]], i8* [[ARRAYIDX4_1]], align 1
+; CHECK-NEXT:    [[INC_1:%.*]] = add nuw nsw i32 [[INC]], 1
+; CHECK-NEXT:    [[IDXPROM_2:%.*]] = zext i32 [[INC_1]] to i64
+; CHECK-NEXT:    [[ARRAYIDX_2:%.*]] = getelementptr inbounds i8, i8* [[B]], i64 [[IDXPROM_2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = load i8, i8* [[ARRAYIDX_2]], align 1
+; CHECK-NEXT:    [[ADD_2:%.*]] = add i8 [[TMP2]], 3
+; CHECK-NEXT:    [[ARRAYIDX4_2:%.*]] = getelementptr inbounds i8, i8* [[A]], i64 [[IDXPROM_2]]
+; CHECK-NEXT:    store i8 [[ADD_2]], i8* [[ARRAYIDX4_2]], align 1
+; CHECK-NEXT:    [[INC_2:%.*]] = add nuw nsw i32 [[INC_1]], 1
+; CHECK-NEXT:    [[IDXPROM_3:%.*]] = zext i32 [[INC_2]] to i64
+; CHECK-NEXT:    [[ARRAYIDX_3:%.*]] = getelementptr inbounds i8, i8* [[B]], i64 [[IDXPROM_3]]
+; CHECK-NEXT:    [[TMP3:%.*]] = load i8, i8* [[ARRAYIDX_3]], align 1
+; CHECK-NEXT:    [[ADD_3:%.*]] = add i8 [[TMP3]], 3
+; CHECK-NEXT:    [[ARRAYIDX4_3:%.*]] = getelementptr inbounds i8, i8* [[A]], i64 [[IDXPROM_3]]
+; CHECK-NEXT:    store i8 [[ADD_3]], i8* [[ARRAYIDX4_3]], align 1
+; CHECK-NEXT:    [[INC_3]] = add nuw nsw i32 [[INC_2]], 1
+; CHECK-NEXT:    [[CMP1_3:%.*]] = icmp slt i32 [[INC_3]], [[N]]
+; CHECK-NEXT:    br i1 [[CMP1_3]], label [[FOR_BODY]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], [[LOOP0:!llvm.loop !.*]]
+;
+entry:
+  br i1 %c, label %bail.out, label %do.work
+
+bail.out:
+  ret void
+
+do.work:
+  %and = and i32 %n, 3
+  %cmp = icmp eq i32 %and, 0
+  tail call void @llvm.assume(i1 %cmp)
+  %cmp110 = icmp sgt i32 %n, 0
+  br i1 %cmp110, label %for.body.preheader, label %for.cond.cleanup
+
+for.body.preheader:                               ; preds = %do.work
+  br label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  br label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %for.cond.cleanup.loopexit, %entry
+  ret void
+
+for.body:                                         ; preds = %for.body.preheader, %for.body
+  %i.011 = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
+  %idxprom = zext i32 %i.011 to i64
+  %arrayidx = getelementptr inbounds i8, i8* %b, i64 %idxprom
+  %0 = load i8, i8* %arrayidx, align 1
+  %add = add i8 %0, 3
+  %arrayidx4 = getelementptr inbounds i8, i8* %a, i64 %idxprom
+  store i8 %add, i8* %arrayidx4, align 1
+  %inc = add nuw nsw i32 %i.011, 1
+  %cmp1 = icmp slt i32 %inc, %n
+  br i1 %cmp1, label %for.body, label %for.cond.cleanup.loopexit, !llvm.loop !0
+}
+
+declare void @llvm.assume(i1 noundef) nofree nosync nounwind willreturn
+!0 = distinct !{!0, !1, !2}
+!1 = !{!"llvm.loop.mustprogress"}
+!2 = !{!"llvm.loop.unroll.count", i32 4}
Index: llvm/test/Transforms/LoopVectorize/dont-fold-tail-for-assumed-divisible-TC.ll
===================================================================
--- llvm/test/Transforms/LoopVectorize/dont-fold-tail-for-assumed-divisible-TC.ll
+++ llvm/test/Transforms/LoopVectorize/dont-fold-tail-for-assumed-divisible-TC.ll
@@ -6,13 +6,21 @@
 ; Make sure the loop is vectorized under -Os without folding its tail based on
 ; its trip-count's lower bits assumed to be zero.

-define dso_local void @assumeAlignedTC(i32* noalias nocapture %A, i32* %p) optsize {
+define dso_local void @assumeAlignedTC(i32* noalias nocapture %A, i32 %p, i32 %q, i1 %c) optsize {
 ; CHECK-LABEL: @assumeAlignedTC(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[N:%.*]] = load i32, i32* [[P:%.*]], align 4
-; CHECK-NEXT:    [[AND:%.*]] = and i32 [[N]], 3
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[AND]], 0
-; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP]])
+; CHECK-NEXT:    br i1 [[C:%.*]], label [[BAIL_OUT:%.*]], label [[DO_WORK:%.*]]
+; CHECK:       bail.out:
+; CHECK-NEXT:    ret void
+; CHECK:       do.work:
+; CHECK-NEXT:    [[AND1:%.*]] = and i32 [[P:%.*]], 3
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i32 [[AND1]], 0
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP1]])
+; CHECK-NEXT:    [[AND2:%.*]] = and i32 [[Q:%.*]], 7
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i32 [[AND2]], 0
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP2]])
+; CHECK-NEXT:    [[GT:%.*]] = icmp sgt i32 [[P]], [[Q]]
+; CHECK-NEXT:    [[N:%.*]] = select i1 [[GT]], i32 [[P]], i32 [[Q]]
 ; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 4
 ; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
@@ -36,7 +44,7 @@
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i32 [[N]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH]]
 ; CHECK:       scalar.ph:
-; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i32 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i32 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[DO_WORK]] ]
 ; CHECK-NEXT:    br label [[LOOP:%.*]]
 ; CHECK:       loop:
 ; CHECK-NEXT:    [[RIV:%.*]] = phi i32 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[RIVPLUS1:%.*]], [[LOOP]] ]
@@ -49,14 +57,24 @@
 ; CHECK-NEXT:    ret void
 ;
 entry:
-  %n = load i32, i32* %p
-  %and = and i32 %n, 3
-  %cmp = icmp eq i32 %and, 0
-  tail call void @llvm.assume(i1 %cmp)
+  br i1 %c, label %bail.out, label %do.work
+
+bail.out:
+  ret void
+
+do.work:
+  %and1 = and i32 %p, 3
+  %cmp1 = icmp eq i32 %and1, 0
+  tail call void @llvm.assume(i1 %cmp1)
+  %and2 = and i32 %q, 7
+  %cmp2 = icmp eq i32 %and2, 0
+  tail call void @llvm.assume(i1 %cmp2)
+  %gt = icmp sgt i32 %p, %q
+  %n = select i1 %gt, i32 %p, i32 %q
   br label %loop

 loop:
-  %riv = phi i32 [ 0, %entry ], [ %rivPlus1, %loop ]
+  %riv = phi i32 [ 0, %do.work ], [ %rivPlus1, %loop ]
   %arrayidx = getelementptr inbounds i32, i32* %A, i32 %riv
   store i32 13, i32* %arrayidx, align 1
   %rivPlus1 = add nuw nsw i32 %riv, 1