Index: llvm/lib/Analysis/ScalarEvolution.cpp =================================================================== --- llvm/lib/Analysis/ScalarEvolution.cpp +++ llvm/lib/Analysis/ScalarEvolution.cpp @@ -15017,6 +15017,81 @@ if (MatchRangeCheckIdiom()) return; + // Return true if \p Expr is a MinMax SCEV expression with a constant + // operand. If so, return in \p SCTy the SCEV type and in \p LHS the + // non-constant operand and in \p RHS the constant operand. + auto IsMinMaxSCEVWithConstant = [&](const SCEV *Expr, SCEVTypes &SCTy, + const SCEV *&LHS, const SCEV *&RHS) { + if (auto *MinMax = dyn_cast(Expr)) { + if (MinMax->getNumOperands() != 2) + return false; + SCTy = MinMax->getSCEVType(); + LHS = MinMax->getOperand(0); + RHS = MinMax->getOperand(1); + return isa(RHS); + } + return false; + }; + + // Return a new SCEV that modifies \p Expr to the closest number that is + // divisible by \p Divisor and greater than or equal to Expr. + // For now, only handle constant Expr and Divisor. + auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr, + const SCEV *Divisor) { + if (!isKnownNonNegative(Expr) || !isKnownPositive(Divisor)) + return Expr; + auto *ConstExpr = dyn_cast(Expr); + auto *ConstDivisor = dyn_cast(Divisor); + if (!ConstExpr || !ConstDivisor) + return Expr; + APInt ExprVal = ConstExpr->getAPInt(); + APInt DivisorVal = ConstDivisor->getAPInt(); + APInt Rem = ExprVal.urem(DivisorVal); + if (!Rem.isZero()) { + // return the SCEV: Expr + Divisor - Expr % Divisor + return getConstant(ExprVal + DivisorVal - Rem); + } + return Expr; + }; + + // Return a new SCEV that modifies \p Expr to the closest number that is + // divisible by \p Divisor and less than or equal to Expr. + // For now, only handle constant Expr and Divisor.
+ auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr, + const SCEV *Divisor) { + if (!isKnownNonNegative(Expr) || !isKnownPositive(Divisor)) + return Expr; + auto *ConstExpr = dyn_cast(Expr); + auto *ConstDivisor = dyn_cast(Divisor); + if (!ConstExpr || !ConstDivisor) + return Expr; + APInt ExprVal = ConstExpr->getAPInt(); + APInt DivisorVal = ConstDivisor->getAPInt(); + APInt Rem = ExprVal.urem(DivisorVal); + // return the SCEV: Expr - Expr % Divisor + return getConstant(ExprVal - Rem); + }; + + // Apply divisibility by \p Divisor on MinMaxExpr with constant values, + // recursively. This is done by aligning up/down the constant value to the + // Divisor. + std::function + ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr, + const SCEV *Divisor) { + const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr; + SCEVTypes SCTy; + if (!IsMinMaxSCEVWithConstant(MinMaxExpr, SCTy, MinMaxLHS, MinMaxRHS)) + return MinMaxExpr; + auto IsMin = + isa(MinMaxExpr) || isa(MinMaxExpr); + auto DivisibleExpr = + IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxRHS, Divisor) + : GetNextSCEVDividesByDivisor(MinMaxRHS, Divisor); + SmallVector Ops = { + ApplyDivisibiltyOnMinMaxExpr(MinMaxLHS, Divisor), DivisibleExpr}; + return getMinMaxExpr(SCTy, Ops); + }; + // If we have LHS == 0, check if LHS is computing a property of some unknown // SCEV %v which we can rewrite %v to express explicitly. const SCEVConstant *RHSC = dyn_cast(RHS); @@ -15028,7 +15103,12 @@ const SCEV *URemRHS = nullptr; if (matchURem(LHS, URemLHS, URemRHS)) { if (const SCEVUnknown *LHSUnknown = dyn_cast(URemLHS)) { - const auto *Multiple = getMulExpr(getUDivExpr(URemLHS, URemRHS), URemRHS); + auto I = RewriteMap.find(LHSUnknown); + const SCEV *RewrittenLHS = + I != RewriteMap.end() ?
I->second : LHSUnknown; + RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS); + const auto *Multiple = + getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS); RewriteMap[LHSUnknown] = Multiple; ExprsToRewrite.push_back(LHSUnknown); return; @@ -15051,48 +15131,129 @@ auto I = RewriteMap.find(LHS); const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHS; + // Check for the SCEV expression (A /u B) * B while B is a constant, inside + // \p Expr. The check is done recursively on \p Expr, which is assumed to + // be a composition of Min/Max SCEVs. Return whether the SCEV expression (A + // /u B) * B was found, and return the divisor B in \p DividesBy. For + // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since + // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p + // DividesBy. + std::function HasDivisibiltyInfo = + [&](const SCEV *Expr, const SCEV *&DividesBy) { + if (auto *Mul = dyn_cast(Expr)) { + if (Mul->getNumOperands() != 2) + return false; + auto *MulLHS = Mul->getOperand(0); + auto *MulRHS = Mul->getOperand(1); + if (isa(MulLHS)) + std::swap(MulLHS, MulRHS); + if (auto *Div = dyn_cast(MulLHS)) { + if (Div->getOperand(1) == MulRHS) { + DividesBy = MulRHS; + return true; + } + } + } + if (auto *MinMax = dyn_cast(Expr)) { + return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) || + HasDivisibiltyInfo(MinMax->getOperand(1), DividesBy); + } + return false; + }; + + // Return true if Expr is known to be divisible by \p DividesBy.
+ std::function IsKnownToDivideBy = + [&](const SCEV *Expr, const SCEV *DividesBy) { + if (getURemExpr(Expr, DividesBy)->isZero()) + return true; + if (isa(Expr)) { + auto MinMax = cast(Expr); + return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) && + IsKnownToDivideBy(MinMax->getOperand(1), DividesBy); + } + return false; + }; + + const SCEV *DividesBy = nullptr; + if (HasDivisibiltyInfo(RewrittenLHS, DividesBy)) + // Check that the whole expression is divided by DividesBy + DividesBy = + IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr; + const SCEV *RewrittenRHS = nullptr; switch (Predicate) { case CmpInst::ICMP_ULT: { if (RHS->getType()->isPointerTy()) break; const SCEV *One = getOne(RHS->getType()); - RewrittenRHS = - getUMinExpr(RewrittenLHS, getMinusSCEV(getUMaxExpr(RHS, One), One)); + auto *ModifiedRHS = getMinusSCEV(getUMaxExpr(RHS, One), One); + ModifiedRHS = + DividesBy ? GetPreviousSCEVDividesByDivisor(ModifiedRHS, DividesBy) + : ModifiedRHS; + RewrittenRHS = getUMinExpr(RewrittenLHS, ModifiedRHS); break; } - case CmpInst::ICMP_SLT: - RewrittenRHS = - getSMinExpr(RewrittenLHS, getMinusSCEV(RHS, getOne(RHS->getType()))); + case CmpInst::ICMP_SLT: { + auto *ModifiedRHS = getMinusSCEV(RHS, getOne(RHS->getType())); + ModifiedRHS = + DividesBy ? GetPreviousSCEVDividesByDivisor(ModifiedRHS, DividesBy) + : ModifiedRHS; + RewrittenRHS = getSMinExpr(RewrittenLHS, ModifiedRHS); break; - case CmpInst::ICMP_ULE: - RewrittenRHS = getUMinExpr(RewrittenLHS, RHS); + } + case CmpInst::ICMP_ULE: { + auto *ModifiedRHS = + DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS; + RewrittenRHS = getUMinExpr(RewrittenLHS, ModifiedRHS); break; - case CmpInst::ICMP_SLE: - RewrittenRHS = getSMinExpr(RewrittenLHS, RHS); + } + case CmpInst::ICMP_SLE: { + auto *ModifiedRHS = + DividesBy ? 
GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS; + RewrittenRHS = getSMinExpr(RewrittenLHS, ModifiedRHS); break; - case CmpInst::ICMP_UGT: - RewrittenRHS = - getUMaxExpr(RewrittenLHS, getAddExpr(RHS, getOne(RHS->getType()))); + } + case CmpInst::ICMP_UGT: { + auto *ModifiedRHS = getAddExpr(RHS, getOne(RHS->getType())); + ModifiedRHS = DividesBy + ? GetNextSCEVDividesByDivisor(ModifiedRHS, DividesBy) + : ModifiedRHS; + RewrittenRHS = getUMaxExpr(RewrittenLHS, ModifiedRHS); break; - case CmpInst::ICMP_SGT: - RewrittenRHS = - getSMaxExpr(RewrittenLHS, getAddExpr(RHS, getOne(RHS->getType()))); + } + case CmpInst::ICMP_SGT: { + auto *ModifiedRHS = getAddExpr(RHS, getOne(RHS->getType())); + ModifiedRHS = DividesBy + ? GetNextSCEVDividesByDivisor(ModifiedRHS, DividesBy) + : ModifiedRHS; + RewrittenRHS = getSMaxExpr(RewrittenLHS, ModifiedRHS); break; - case CmpInst::ICMP_UGE: - RewrittenRHS = getUMaxExpr(RewrittenLHS, RHS); + } + case CmpInst::ICMP_UGE: { + auto *ModifiedRHS = + DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS; + RewrittenRHS = getUMaxExpr(RewrittenLHS, ModifiedRHS); break; - case CmpInst::ICMP_SGE: - RewrittenRHS = getSMaxExpr(RewrittenLHS, RHS); + } + case CmpInst::ICMP_SGE: { + auto *ModifiedRHS = + DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS; + RewrittenRHS = getSMaxExpr(RewrittenLHS, ModifiedRHS); break; + } case CmpInst::ICMP_EQ: if (isa(RHS)) RewrittenRHS = RHS; break; case CmpInst::ICMP_NE: if (isa(RHS) && - cast(RHS)->getValue()->isNullValue()) - RewrittenRHS = getUMaxExpr(RewrittenLHS, getOne(RHS->getType())); + cast(RHS)->getValue()->isNullValue()) { + auto *ModifiedRHS = getOne(RHS->getType()); + ModifiedRHS = DividesBy + ? 
GetNextSCEVDividesByDivisor(ModifiedRHS, DividesBy) + : ModifiedRHS; + RewrittenRHS = getUMaxExpr(RewrittenLHS, ModifiedRHS); + } break; default: break; Index: llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll =================================================================== --- llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll +++ llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll @@ -40,14 +40,14 @@ define void @test_trip_multiple_4_guard(i32 %num) { ; CHECK-LABEL: 'test_trip_multiple_4_guard' -; CHECK-NEXT: Classifying expressions for: @test_trip_multiple_4 +; CHECK-NEXT: Classifying expressions for: @test_trip_multiple_4_guard ; CHECK-NEXT: %u = urem i32 %num, 4 ; CHECK-NEXT: --> (zext i2 (trunc i32 %num to i2) to i32) U: [0,4) S: [0,4) ; CHECK-NEXT: %i.010 = phi i32 [ 0, %entry ], [ %inc, %for.body ] ; CHECK-NEXT: --> {0,+,1}<%for.body> U: [0,-2147483648) S: [0,-2147483648) Exits: (-1 + %num) LoopDispositions: { %for.body: Computable } ; CHECK-NEXT: %inc = add nuw nsw i32 %i.010, 1 ; CHECK-NEXT: --> {1,+,1}<%for.body> U: [1,-2147483648) S: [1,-2147483648) Exits: %num LoopDispositions: { %for.body: Computable } -; CHECK-NEXT: Determining loop execution counts for: @test_trip_multiple_4 +; CHECK-NEXT: Determining loop execution counts for: @test_trip_multiple_4_guard ; CHECK-NEXT: Loop %for.body: backedge-taken count is (-1 + %num) ; CHECK-NEXT: Loop %for.body: constant max backedge-taken count is -2 ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num) @@ -125,7 +125,7 @@ ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num) ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num) ; CHECK-NEXT: Predicates: -; CHECK: Loop %for.body: Trip multiple is 2 +; CHECK: Loop %for.body: Trip multiple is 4 ; entry: %u = urem i32 %num, 4 @@ -196,7 +196,7 @@ ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num) ; CHECK-NEXT: Loop %for.body: 
Predicated backedge-taken count is (-1 + %num) ; CHECK-NEXT: Predicates: -; CHECK: Loop %for.body: Trip multiple is 2 +; CHECK: Loop %for.body: Trip multiple is 4 ; entry: %u = urem i32 %num, 4 @@ -267,7 +267,7 @@ ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num) ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num) ; CHECK-NEXT: Predicates: -; CHECK: Loop %for.body: Trip multiple is 1 +; CHECK: Loop %for.body: Trip multiple is 4 ; entry: %u = urem i32 %num, 4 @@ -338,7 +338,7 @@ ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num) ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num) ; CHECK-NEXT: Predicates: -; CHECK: Loop %for.body: Trip multiple is 1 +; CHECK: Loop %for.body: Trip multiple is 4 ; entry: %u = urem i32 %num, 4 @@ -394,6 +394,119 @@ ret void } +define void @test_trip_multiple_4_upper_lower_bounds(i32 %num) { +; CHECK-LABEL: 'test_trip_multiple_4_upper_lower_bounds' +; CHECK-NEXT: Classifying expressions for: @test_trip_multiple_4_upper_lower_bounds +; CHECK-NEXT: %u = urem i32 %num, 4 +; CHECK-NEXT: --> (zext i2 (trunc i32 %num to i2) to i32) U: [0,4) S: [0,4) +; CHECK-NEXT: %i.010 = phi i32 [ 0, %entry ], [ %inc, %for.body ] +; CHECK-NEXT: --> {0,+,1}<%for.body> U: [0,-2147483648) S: [0,-2147483648) Exits: (-1 + %num) LoopDispositions: { %for.body: Computable } +; CHECK-NEXT: %inc = add nuw nsw i32 %i.010, 1 +; CHECK-NEXT: --> {1,+,1}<%for.body> U: [1,-2147483648) S: [1,-2147483648) Exits: %num LoopDispositions: { %for.body: Computable } +; CHECK-NEXT: Determining loop execution counts for: @test_trip_multiple_4_upper_lower_bounds +; CHECK-NEXT: Loop %for.body: backedge-taken count is (-1 + %num) +; CHECK-NEXT: Loop %for.body: constant max backedge-taken count is -2 +; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num) +; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num) +; CHECK-NEXT: Predicates: 
+; CHECK: Loop %for.body: Trip multiple is 4 +; +entry: + %u = urem i32 %num, 4 + %cmp = icmp eq i32 %u, 0 + tail call void @llvm.assume(i1 %cmp) + %cmp.1 = icmp uge i32 %num, 5 + tail call void @llvm.assume(i1 %cmp.1) + %cmp.2 = icmp ult i32 %num, 59000 + tail call void @llvm.assume(i1 %cmp.2) + br label %for.body + +for.body: + %i.010 = phi i32 [ 0, %entry ], [ %inc, %for.body ] + %inc = add nuw nsw i32 %i.010, 1 + %cmp2 = icmp ult i32 %inc, %num + br i1 %cmp2, label %for.body, label %exit + +exit: + ret void +} + +define void @test_trip_multiple_4_upper_lower_bounds_swapped1(i32 %num) { +; CHECK-LABEL: 'test_trip_multiple_4_upper_lower_bounds_swapped1' +; CHECK-NEXT: Classifying expressions for: @test_trip_multiple_4_upper_lower_bounds_swapped1 +; CHECK-NEXT: %u = urem i32 %num, 4 +; CHECK-NEXT: --> (zext i2 (trunc i32 %num to i2) to i32) U: [0,4) S: [0,4) +; CHECK-NEXT: %i.010 = phi i32 [ 0, %entry ], [ %inc, %for.body ] +; CHECK-NEXT: --> {0,+,1}<%for.body> U: [0,-2147483648) S: [0,-2147483648) Exits: (-1 + %num) LoopDispositions: { %for.body: Computable } +; CHECK-NEXT: %inc = add nuw nsw i32 %i.010, 1 +; CHECK-NEXT: --> {1,+,1}<%for.body> U: [1,-2147483648) S: [1,-2147483648) Exits: %num LoopDispositions: { %for.body: Computable } +; CHECK-NEXT: Determining loop execution counts for: @test_trip_multiple_4_upper_lower_bounds_swapped1 +; CHECK-NEXT: Loop %for.body: backedge-taken count is (-1 + %num) +; CHECK-NEXT: Loop %for.body: constant max backedge-taken count is -2 +; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num) +; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num) +; CHECK-NEXT: Predicates: +; CHECK: Loop %for.body: Trip multiple is 4 +; +entry: + %cmp.1 = icmp uge i32 %num, 5 + tail call void @llvm.assume(i1 %cmp.1) + %u = urem i32 %num, 4 + %cmp = icmp eq i32 %u, 0 + tail call void @llvm.assume(i1 %cmp) + %cmp.2 = icmp ult i32 %num, 59000 + tail call void @llvm.assume(i1 %cmp.2) + br label 
%for.body + +for.body: + %i.010 = phi i32 [ 0, %entry ], [ %inc, %for.body ] + %inc = add nuw nsw i32 %i.010, 1 + %cmp2 = icmp ult i32 %inc, %num + br i1 %cmp2, label %for.body, label %exit + +exit: + ret void +} + +; Even though the SCEV should be a multiple of 4, it isn't concluded. +; See PR60114 +define void @test_trip_multiple_4_upper_lower_bounds_swapped2(i32 %num) { +; CHECK-LABEL: 'test_trip_multiple_4_upper_lower_bounds_swapped2' +; CHECK-NEXT: Classifying expressions for: @test_trip_multiple_4_upper_lower_bounds_swapped2 +; CHECK-NEXT: %u = urem i32 %num, 4 +; CHECK-NEXT: --> (zext i2 (trunc i32 %num to i2) to i32) U: [0,4) S: [0,4) +; CHECK-NEXT: %i.010 = phi i32 [ 0, %entry ], [ %inc, %for.body ] +; CHECK-NEXT: --> {0,+,1}<%for.body> U: [0,-2147483648) S: [0,-2147483648) Exits: (-1 + %num) LoopDispositions: { %for.body: Computable } +; CHECK-NEXT: %inc = add nuw nsw i32 %i.010, 1 +; CHECK-NEXT: --> {1,+,1}<%for.body> U: [1,-2147483648) S: [1,-2147483648) Exits: %num LoopDispositions: { %for.body: Computable } +; CHECK-NEXT: Determining loop execution counts for: @test_trip_multiple_4_upper_lower_bounds_swapped2 +; CHECK-NEXT: Loop %for.body: backedge-taken count is (-1 + %num) +; CHECK-NEXT: Loop %for.body: constant max backedge-taken count is -2 +; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num) +; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num) +; CHECK-NEXT: Predicates: +; CHECK: Loop %for.body: Trip multiple is 4 +; +entry: + %cmp.1 = icmp uge i32 %num, 5 + tail call void @llvm.assume(i1 %cmp.1) + %cmp.2 = icmp ult i32 %num, 59000 + tail call void @llvm.assume(i1 %cmp.2) + %u = urem i32 %num, 4 + %cmp = icmp eq i32 %u, 0 + tail call void @llvm.assume(i1 %cmp) + br label %for.body + +for.body: + %i.010 = phi i32 [ 0, %entry ], [ %inc, %for.body ] + %inc = add nuw nsw i32 %i.010, 1 + %cmp2 = icmp ult i32 %inc, %num + br i1 %cmp2, label %for.body, label %exit + +exit: + ret void +} + define void 
@test_trip_multiple_5(i32 %num) { ; CHECK-LABEL: 'test_trip_multiple_5' ; CHECK-NEXT: Classifying expressions for: @test_trip_multiple_5