diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -15023,6 +15023,93 @@
     if (MatchRangeCheckIdiom())
       return;
 
+    // Return true if \p Expr is a two-operand MinMax SCEV whose first operand
+    // is a non-negative constant. If so, return the SCEV type in \p SCTy, the
+    // constant operand in \p LHS and the other operand in \p RHS.
+    auto IsMinMaxSCEVWithNonNegativeConstant =
+        [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
+            const SCEV *&RHS) {
+          if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
+            if (MinMax->getNumOperands() != 2)
+              return false;
+            if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
+              if (C->getAPInt().isNegative())
+                return false;
+              SCTy = MinMax->getSCEVType();
+              LHS = MinMax->getOperand(0);
+              RHS = MinMax->getOperand(1);
+              return true;
+            }
+          }
+          return false;
+        };
+
+    // Return true if \p Expr is a non-negative and \p Divisor a positive
+    // constant. Constant values are stored in \p ExprVal and \p DivisorVal.
+    auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
+                                          APInt &ExprVal, APInt &DivisorVal) {
+      auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
+      auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
+      if (!ConstExpr || !ConstDivisor)
+        return false;
+      ExprVal = ConstExpr->getAPInt();
+      DivisorVal = ConstDivisor->getAPInt();
+      return ExprVal.isNonNegative() && DivisorVal.isStrictlyPositive();
+    };
+
+    // Return a new SCEV that rounds \p Expr up to the closest value that is
+    // divisible by \p Divisor and greater than or equal to \p Expr.
+    // For now, only handle constant Expr and Divisor.
+    auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
+                                           const SCEV *Divisor) {
+      APInt ExprVal;
+      APInt DivisorVal;
+      if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
+        return Expr;
+      APInt Rem = ExprVal.urem(DivisorVal);
+      if (!Rem.isZero())
+        // Return the SCEV: Expr + Divisor - Expr % Divisor
+        return getConstant(ExprVal + DivisorVal - Rem);
+      return Expr;
+    };
+
+    // Return a new SCEV that rounds \p Expr down to the closest value that is
+    // divisible by \p Divisor and less than or equal to \p Expr.
+    // For now, only handle constant Expr and Divisor.
+    auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
+                                               const SCEV *Divisor) {
+      APInt ExprVal;
+      APInt DivisorVal;
+      if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
+        return Expr;
+      APInt Rem = ExprVal.urem(DivisorVal);
+      // Return the SCEV: Expr - Expr % Divisor
+      return getConstant(ExprVal - Rem);
+    };
+
+    // Apply divisibility by \p Divisor on MinMaxExpr with constant values,
+    // recursively. This is done by aligning the constant operand down (for
+    // min) or up (for max) to a multiple of the Divisor.
+    std::function<const SCEV *(const SCEV *, const SCEV *)>
+        ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
+                                           const SCEV *Divisor) {
+          const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
+          SCEVTypes SCTy;
+          if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
+                                                   MinMaxRHS))
+            return MinMaxExpr;
+          bool IsMin =
+              isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
+          assert(isKnownNonNegative(MinMaxLHS) &&
+                 "Expected non-negative operand!");
+          auto *DivisibleExpr =
+              IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
+                    : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
+          SmallVector<const SCEV *> Ops = {
+              ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
+          return getMinMaxExpr(SCTy, Ops);
+        };
+
     // If we have LHS == 0, check if LHS is computing a property of some unknown
     // SCEV %v which we can rewrite %v to express explicitly.
const SCEVConstant *RHSC = dyn_cast(RHS); @@ -15034,7 +15121,12 @@ const SCEV *URemRHS = nullptr; if (matchURem(LHS, URemLHS, URemRHS)) { if (const SCEVUnknown *LHSUnknown = dyn_cast(URemLHS)) { - const auto *Multiple = getMulExpr(getUDivExpr(URemLHS, URemRHS), URemRHS); + auto I = RewriteMap.find(LHSUnknown); + const SCEV *RewrittenLHS = + I != RewriteMap.end() ? I->second : LHSUnknown; + RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS); + const auto *Multiple = + getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS); RewriteMap[LHSUnknown] = Multiple; ExprsToRewrite.push_back(LHSUnknown); return; @@ -15071,6 +15163,52 @@ return I != RewriteMap.end() ? I->second : S; }; + // Check for the SCEV expression (A /u B) * B while B is a constant, inside + // \p Expr. The check is done recuresively on \p Expr, which is assumed to + // be a composition of Min/Max SCEVs. Return whether the SCEV expression (A + // /u B) * B was found, and return the divisor B in \p DividesBy. For + // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since + // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p + // DividesBy. + std::function HasDivisibiltyInfo = + [&](const SCEV *Expr, const SCEV *&DividesBy) { + if (auto *Mul = dyn_cast(Expr)) { + if (Mul->getNumOperands() != 2) + return false; + auto *MulLHS = Mul->getOperand(0); + auto *MulRHS = Mul->getOperand(1); + if (isa(MulLHS)) + std::swap(MulLHS, MulRHS); + if (auto *Div = dyn_cast(MulLHS)) + if (Div->getOperand(1) == MulRHS) { + DividesBy = MulRHS; + return true; + } + } + if (auto *MinMax = dyn_cast(Expr)) + return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) || + HasDivisibiltyInfo(MinMax->getOperand(1), DividesBy); + return false; + }; + + // Return true if Expr known to divide by \p DividesBy. 
+    std::function<bool(const SCEV *, const SCEV *)> IsKnownToDivideBy =
+        [&](const SCEV *Expr, const SCEV *DividesBy) {
+          if (getURemExpr(Expr, DividesBy)->isZero())
+            return true;
+          if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
+            return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) &&
+                   IsKnownToDivideBy(MinMax->getOperand(1), DividesBy);
+          return false;
+        };
+
+    const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
+    const SCEV *DividesBy = nullptr;
+    if (HasDivisibiltyInfo(RewrittenLHS, DividesBy))
+      // Keep DividesBy only if the whole expression is divisible by it.
+      DividesBy =
+          IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr;
+
     // Collect rewrites for LHS and its transitive operands based on the
     // condition.
     // For min/max expressions, also apply the guard to its operands:
@@ -15091,11 +15229,21 @@
         LLVM_FALLTHROUGH;
       case CmpInst::ICMP_SLT: {
         RHS = getMinusSCEV(RHS, One);
+        RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
         break;
       }
       case CmpInst::ICMP_UGT:
       case CmpInst::ICMP_SGT:
         RHS = getAddExpr(RHS, One);
+        RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
+        break;
+      case CmpInst::ICMP_ULE:
+      case CmpInst::ICMP_SLE:
+        RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
+        break;
+      case CmpInst::ICMP_UGE:
+      case CmpInst::ICMP_SGE:
+        RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
         break;
       default:
         break;
@@ -15148,8 +15296,11 @@
         break;
       case CmpInst::ICMP_NE:
         if (isa<SCEVConstant>(RHS) &&
-            cast<SCEVConstant>(RHS)->getValue()->isNullValue())
-          To = getUMaxExpr(FromRewritten, One);
+            cast<SCEVConstant>(RHS)->getValue()->isNullValue()) {
+          const SCEV *OneAlignedUp =
+              DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
+          To = getUMaxExpr(FromRewritten, OneAlignedUp);
+        }
         break;
       default:
         break;
diff --git a/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll b/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
--- a/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
+++ b/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
@@ -125,7 +125,7 @@
 ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
 ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
 ; CHECK-NEXT: Predicates:
-; CHECK: Loop %for.body: Trip multiple is 2
+; CHECK: Loop %for.body: Trip multiple is 4
 ;
 entry:
   %u = urem i32 %num, 4
@@ -196,7 +196,7 @@
 ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
 ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
 ; CHECK-NEXT: Predicates:
-; CHECK: Loop %for.body: Trip multiple is 2
+; CHECK: Loop %for.body: Trip multiple is 4
 ;
 entry:
   %u = urem i32 %num, 4
@@ -267,7 +267,7 @@
 ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
 ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
 ; CHECK-NEXT: Predicates:
-; CHECK: Loop %for.body: Trip multiple is 1
+; CHECK: Loop %for.body: Trip multiple is 4
 ;
 entry:
   %u = urem i32 %num, 4
@@ -338,7 +338,7 @@
 ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
 ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
 ; CHECK-NEXT: Predicates:
-; CHECK: Loop %for.body: Trip multiple is 1
+; CHECK: Loop %for.body: Trip multiple is 4
 ;
 entry:
   %u = urem i32 %num, 4
@@ -409,7 +409,7 @@
 ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
 ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
 ; CHECK-NEXT: Predicates:
-; CHECK: Loop %for.body: Trip multiple is 1
+; CHECK: Loop %for.body: Trip multiple is 4
 ;
 entry:
%cmp.1 = icmp uge i32 %num, 5 @@ -446,7 +446,7 @@ ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num) ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num) ; CHECK-NEXT: Predicates: -; CHECK: Loop %for.body: Trip multiple is 1 +; CHECK: Loop %for.body: Trip multiple is 4 ; entry: %cmp.1 = icmp uge i32 %num, 5 @@ -483,7 +483,7 @@ ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num) ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num) ; CHECK-NEXT: Predicates: -; CHECK: Loop %for.body: Trip multiple is 1 +; CHECK: Loop %for.body: Trip multiple is 4 ; entry: %cmp.1 = icmp uge i32 %num, 5 diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp --- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp +++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp @@ -1760,4 +1760,42 @@ ->equalsInt(1ULL << i)); } +TEST_F(ScalarEvolutionsTest, ApplyLoopGuards) { + LLVMContext C; + SMDiagnostic Err; + std::unique_ptr M = parseAssemblyString( + "declare void @llvm.assume(i1)\n" + "define void @test(i32 %num) {\n" + "entry:\n" + " %u = urem i32 %num, 4\n" + " %cmp = icmp eq i32 %u, 0\n" + " tail call void @llvm.assume(i1 %cmp)\n" + " %cmp.1 = icmp ugt i32 %num, 0\n" + " tail call void @llvm.assume(i1 %cmp.1)\n" + " br label %for.body\n" + "for.body:\n" + " %i.010 = phi i32 [ 0, %entry ], [ %inc, %for.body ]\n" + " %inc = add nuw nsw i32 %i.010, 1\n" + " %cmp2 = icmp ult i32 %inc, %num\n" + " br i1 %cmp2, label %for.body, label %exit\n" + "exit:\n" + " ret void\n" + "}\n", + Err, C); + + ASSERT_TRUE(M && "Could not parse module?"); + ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!"); + + runWithSE(*M, "test", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) { + auto *TCScev = SE.getSCEV(getArgByName(F, "num")); + auto *ApplyLoopGuardsTC = SE.applyLoopGuards(TCScev, *LI.begin()); + // Assert that the new TC is (4 * ((4 
umax %num) /u 4)) + APInt Four(32, 4); + auto *Constant4 = SE.getConstant(Four); + auto *Max = SE.getUMaxExpr(TCScev, Constant4); + auto *Mul = SE.getMulExpr(SE.getUDivExpr(Max, Constant4), Constant4); + ASSERT_TRUE(Mul == ApplyLoopGuardsTC); + }); +} + } // end namespace llvm