Index: llvm/lib/Analysis/ScalarEvolution.cpp =================================================================== --- llvm/lib/Analysis/ScalarEvolution.cpp +++ llvm/lib/Analysis/ScalarEvolution.cpp @@ -14422,6 +14422,38 @@ if (MatchRangeCheckIdiom()) return; + // Return true if \p Expr is a Max SCEV expression with a constant + // operand. If so, return in \p SCTy the SCEV type and in \p LHS the + // non-constant operand and in \p RHS the constant operand. + auto IsMaxSCEVWithConstant = [&](const SCEV *Expr, SCEVTypes &SCTy, + const SCEV *&LHS, const SCEV *&RHS) { + auto Max = dyn_cast(Expr); + if (!Max) + return false; + SCTy = cast(Expr)->getSCEVType(); + if (SCTy != scUMaxExpr && SCTy != scSMaxExpr) + return false; + LHS = Max->getOperand(0); + RHS = Max->getOperand(1); + if (isa(LHS)) + std::swap(LHS, RHS); + return isa(RHS); + }; + + // Return a new SCEV that aligns \p Expr up to a multiple of \p + // Divisor. + auto AlignUpToMultipleOfDivisor = [&](const SCEV *Expr, + const SCEV *Divisor) { + if (!isKnownNonNegative(Expr) || !isKnownPositive(Divisor)) + return Expr; + auto Rem = getURemExpr(Expr, Divisor); + if (!Rem->isZero()) { + // return the SCEV: Expr + Divisor - Expr % Divisor + return getAddExpr(getMinusSCEV(Divisor, Rem), Expr); + } + return Expr; + }; + // If we have LHS == 0, check if LHS is computing a property of some unknown // SCEV %v which we can rewrite %v to express explicitly. const SCEVConstant *RHSC = dyn_cast(RHS); @@ -14433,9 +14465,24 @@ const SCEV *URemRHS = nullptr; if (matchURem(LHS, URemLHS, URemRHS)) { if (const SCEVUnknown *LHSUnknown = dyn_cast(URemLHS)) { - auto Multiple = getMulExpr(getUDivExpr(URemLHS, URemRHS), URemRHS); + // Check whether LHSUnknown has already been rewritten. + // In that case we want to maintain both the divisibility property, + // and the maximum property. + auto I = RewriteMap.find(LHSUnknown); + const SCEV *RewrittenLHS = + I != RewriteMap.end() ? 
I->second : LHSUnknown; + const SCEV *MaxLHS, *MaxRHS; + SCEVTypes SCTy; + if (IsMaxSCEVWithConstant(RewrittenLHS, SCTy, MaxLHS, MaxRHS)) { + auto AlignedUp = AlignUpToMultipleOfDivisor(MaxRHS, URemRHS); + SmallVector Ops = {MaxLHS, AlignedUp}; + RewrittenLHS = getMinMaxExpr(SCTy, Ops); + } + auto Multiple = + getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS); RewriteMap[LHSUnknown] = Multiple; - ExprsToRewrite.push_back(LHSUnknown); + if (RewrittenLHS == LHSUnknown) + ExprsToRewrite.push_back(LHSUnknown); return; } } @@ -14460,6 +14507,27 @@ auto I = RewriteMap.find(LHS); const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHS; + // Return true if Expr is (A /u B) * B where B is a constant, and return B + // in DividesBy. + auto IsKnownToDivideBy = [&](const SCEV *Expr, const SCEV *&DividesBy) { + if (auto Mul = dyn_cast(Expr)) { + auto MulLHS = Mul->getOperand(0); + auto MulRHS = Mul->getOperand(1); + if (isa(MulLHS)) + std::swap(MulLHS, MulRHS); + if (auto Div = dyn_cast(MulLHS)) { + if (Div->getOperand(1) == MulRHS) { + DividesBy = MulRHS; + return true; + } + } + } + return false; + }; + + const SCEV *DividesBy = nullptr; + IsKnownToDivideBy(RewrittenLHS, DividesBy); + const SCEV *RewrittenRHS = nullptr; switch (Predicate) { case CmpInst::ICMP_ULT: @@ -14476,20 +14544,34 @@ case CmpInst::ICMP_SLE: RewrittenRHS = getSMinExpr(RewrittenLHS, RHS); break; - case CmpInst::ICMP_UGT: - RewrittenRHS = - getUMaxExpr(RewrittenLHS, getAddExpr(RHS, getOne(RHS->getType()))); + case CmpInst::ICMP_UGT: { + auto ModifiedRHS = getAddExpr(RHS, getOne(RHS->getType())); + ModifiedRHS = DividesBy + ? AlignUpToMultipleOfDivisor(ModifiedRHS, DividesBy) + : ModifiedRHS; + RewrittenRHS = getUMaxExpr(RewrittenLHS, ModifiedRHS); break; - case CmpInst::ICMP_SGT: - RewrittenRHS = - getSMaxExpr(RewrittenLHS, getAddExpr(RHS, getOne(RHS->getType()))); + case CmpInst::ICMP_SGT: { + auto ModifiedRHS = getAddExpr(RHS, getOne(RHS->getType())); + ModifiedRHS = DividesBy + ? 
AlignUpToMultipleOfDivisor(ModifiedRHS, DividesBy) + : ModifiedRHS; + RewrittenRHS = getSMaxExpr(RewrittenLHS, ModifiedRHS); break; - case CmpInst::ICMP_UGE: - RewrittenRHS = getUMaxExpr(RewrittenLHS, RHS); + } + case CmpInst::ICMP_UGE: { + auto ModifiedRHS = + DividesBy ? AlignUpToMultipleOfDivisor(RHS, DividesBy) : RHS; + RewrittenRHS = getUMaxExpr(RewrittenLHS, ModifiedRHS); break; - case CmpInst::ICMP_SGE: - RewrittenRHS = getSMaxExpr(RewrittenLHS, RHS); + } + case CmpInst::ICMP_SGE: { + auto ModifiedRHS = + DividesBy ? AlignUpToMultipleOfDivisor(RHS, DividesBy) : RHS; + RewrittenRHS = getSMaxExpr(RewrittenLHS, ModifiedRHS); break; + } case CmpInst::ICMP_EQ: if (isa(RHS)) RewrittenRHS = RHS; Index: llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll =================================================================== --- llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll +++ llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll @@ -55,12 +55,11 @@ } define void @test_trip_multiple_4_ugt_5_order_swapped(i32 %num) { -; TODO: Trip multiple can be 4, it is missed due to the processing order of the assumes. ; CHECK-LABEL: @test_trip_multiple_4_ugt_5_order_swapped ; CHECK: Loop %for.body: backedge-taken count is (-1 + %num) ; CHECK-NEXT: Loop %for.body: max backedge-taken count is -2 ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num) -; CHECK: Loop %for.body: Trip multiple is 2 +; CHECK: Loop %for.body: Trip multiple is 4 ; entry: %u = urem i32 %num, 4 @@ -106,12 +105,11 @@ } define void @test_trip_multiple_4_sgt_5_order_swapped(i32 %num) { -; TODO: Trip multiple can be 4, it is missed due to the processing order of the assumes. 
; CHECK-LABEL: @test_trip_multiple_4_sgt_5_order_swapped ; CHECK: Loop %for.body: backedge-taken count is (-1 + %num) ; CHECK-NEXT: Loop %for.body: max backedge-taken count is 2147483646 ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num) -; CHECK: Loop %for.body: Trip multiple is 2 +; CHECK: Loop %for.body: Trip multiple is 4 ; entry: %u = urem i32 %num, 4 @@ -157,12 +155,11 @@ } define void @test_trip_multiple_4_uge_5_order_swapped(i32 %num) { -; TODO: Trip multiple can be 4, it is missed due to the processing order of the assumes. ; CHECK-LABEL: @test_trip_multiple_4_uge_5_order_swapped ; CHECK: Loop %for.body: backedge-taken count is (-1 + %num) ; CHECK-NEXT: Loop %for.body: max backedge-taken count is -2 ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num) -; CHECK: Loop %for.body: Trip multiple is 1 +; CHECK: Loop %for.body: Trip multiple is 4 ; entry: %u = urem i32 %num, 4 @@ -183,7 +180,6 @@ } define void @test_trip_multiple_4_sge_5(i32 %num) { -; TODO: Trip multiple can be 4, it is missed due to the processing order of the assumes. ; CHECK-LABEL: @test_trip_multiple_4_sge_5 ; CHECK: Loop %for.body: backedge-taken count is (-1 + %num) ; CHECK-NEXT: Loop %for.body: max backedge-taken count is 2147483646 @@ -213,7 +209,7 @@ ; CHECK: Loop %for.body: backedge-taken count is (-1 + %num) ; CHECK-NEXT: Loop %for.body: max backedge-taken count is 2147483646 ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num) -; CHECK: Loop %for.body: Trip multiple is 1 +; CHECK: Loop %for.body: Trip multiple is 4 ; entry: %u = urem i32 %num, 4