diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -15034,6 +15034,91 @@ if (MatchRangeCheckIdiom()) return; + // Return true if \p Expr is a two-operand MinMax SCEV expression whose + // first operand is a constant. If so, return the SCEV type in \p SCTy, + // the constant operand in \p LHS and the other operand in \p RHS. + auto IsMinMaxSCEVWithConstant = [&](const SCEV *Expr, SCEVTypes &SCTy, + const SCEV *&LHS, const SCEV *&RHS) { + if (auto *MinMax = dyn_cast(Expr)) { + if (MinMax->getNumOperands() != 2) + return false; + SCTy = MinMax->getSCEVType(); + if (!isa(MinMax->getOperand(0))) + return false; + LHS = MinMax->getOperand(0); + RHS = MinMax->getOperand(1); + return true; + } + return false; + }; + + // Return true if \p Expr is a non-negative constant and \p Divisor is a + // positive constant, storing their APInts in \p ExprVal and \p DivisorVal. + auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor, + APInt &ExprVal, APInt &DivisorVal) { + if (!isKnownNonNegative(Expr) || !isKnownPositive(Divisor)) + return false; + auto *ConstExpr = dyn_cast(Expr); + auto *ConstDivisor = dyn_cast(Divisor); + if (!ConstExpr || !ConstDivisor) + return false; + ExprVal = ConstExpr->getAPInt(); + DivisorVal = ConstDivisor->getAPInt(); + return true; + }; + + // Return a SCEV for the smallest multiple of \p Divisor that is greater + // than or equal to \p Expr. + // Only constant operands are handled; otherwise \p Expr is returned as-is. 
+ auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr, + const SCEV *Divisor) { + APInt ExprVal; + APInt DivisorVal; + if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal)) + return Expr; + APInt Rem = ExprVal.urem(DivisorVal); + if (!Rem.isZero()) + // return the SCEV: Expr + Divisor - Expr % Divisor + return getConstant(ExprVal + DivisorVal - Rem); + return Expr; + }; + + // Return a SCEV for the largest multiple of \p Divisor that is less than + // or equal to \p Expr. + // Only constant operands are handled; otherwise \p Expr is returned as-is. + auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr, + const SCEV *Divisor) { + APInt ExprVal; + APInt DivisorVal; + if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal)) + return Expr; + APInt Rem = ExprVal.urem(DivisorVal); + // return the SCEV: Expr - Expr % Divisor + return getConstant(ExprVal - Rem); + }; + + // Apply divisibility by \p Divisor on MinMaxExpr with constant values, + // recursively. This is done by aligning up/down the constant value to the + // Divisor. NOTE(review): asserts a non-negative constant; confirm callers. + std::function + ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr, + const SCEV *Divisor) { + const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr; + SCEVTypes SCTy; + if (!IsMinMaxSCEVWithConstant(MinMaxExpr, SCTy, MinMaxLHS, MinMaxRHS)) + return MinMaxExpr; + auto IsMin = + isa(MinMaxExpr) || isa(MinMaxExpr); + assert(isKnownNonNegative(MinMaxLHS) && + "Expected non-negative operand!"); + auto *DivisibleExpr = + IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor) + : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor); + SmallVector Ops = { + ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr}; + return getMinMaxExpr(SCTy, Ops); + }; + // If we have LHS == 0, check if LHS is computing a property of some unknown // SCEV %v which we can rewrite %v to express explicitly. 
const SCEVConstant *RHSC = dyn_cast(RHS); @@ -15045,7 +15130,12 @@ const SCEV *URemRHS = nullptr; if (matchURem(LHS, URemLHS, URemRHS)) { if (const SCEVUnknown *LHSUnknown = dyn_cast(URemLHS)) { - const auto *Multiple = getMulExpr(getUDivExpr(URemLHS, URemRHS), URemRHS); + auto I = RewriteMap.find(LHSUnknown); + const SCEV *RewrittenLHS = + I != RewriteMap.end() ? I->second : LHSUnknown; + RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS); + const auto *Multiple = + getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS); RewriteMap[LHSUnknown] = Multiple; ExprsToRewrite.push_back(LHSUnknown); return; @@ -15068,48 +15158,128 @@ auto I = RewriteMap.find(LHS); const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHS; + // Check for the SCEV expression (A /u B) * B where B is a constant, inside + // \p Expr. The check is done recursively on \p Expr, which is assumed to + // be a composition of Min/Max SCEVs. Return whether the SCEV expression (A + // /u B) * B was found, and return the divisor B in \p DividesBy. For + // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since + // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p + // DividesBy. + std::function HasDivisibiltyInfo = + [&](const SCEV *Expr, const SCEV *&DividesBy) { + if (auto *Mul = dyn_cast(Expr)) { + if (Mul->getNumOperands() != 2) + return false; + auto *MulLHS = Mul->getOperand(0); + auto *MulRHS = Mul->getOperand(1); + if (isa(MulLHS)) + std::swap(MulLHS, MulRHS); + if (auto *Div = dyn_cast(MulLHS)) { + if (Div->getOperand(1) == MulRHS) { + DividesBy = MulRHS; + return true; + } + } + } + if (auto *MinMax = dyn_cast(Expr)) { + return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) || + HasDivisibiltyInfo(MinMax->getOperand(1), DividesBy); + } + return false; + }; + + // Return true if \p Expr is known to be divisible by \p DividesBy. 
+ std::function IsKnownToDivideBy = + [&](const SCEV *Expr, const SCEV *DividesBy) { + if (getURemExpr(Expr, DividesBy)->isZero()) + return true; + if (auto *MinMax = dyn_cast(Expr)) { + return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) && + IsKnownToDivideBy(MinMax->getOperand(1), DividesBy); + } + return false; + }; + + const SCEV *DividesBy = nullptr; + if (HasDivisibiltyInfo(RewrittenLHS, DividesBy)) + // Keep DividesBy only if the whole expression is divisible by it. + DividesBy = + IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr; + const SCEV *RewrittenRHS = nullptr; switch (Predicate) { case CmpInst::ICMP_ULT: { if (RHS->getType()->isPointerTy()) break; const SCEV *One = getOne(RHS->getType()); - RewrittenRHS = - getUMinExpr(RewrittenLHS, getMinusSCEV(getUMaxExpr(RHS, One), One)); + auto *ModifiedRHS = getMinusSCEV(getUMaxExpr(RHS, One), One); + ModifiedRHS = + DividesBy ? GetPreviousSCEVDividesByDivisor(ModifiedRHS, DividesBy) + : ModifiedRHS; + RewrittenRHS = getUMinExpr(RewrittenLHS, ModifiedRHS); break; } - case CmpInst::ICMP_SLT: - RewrittenRHS = - getSMinExpr(RewrittenLHS, getMinusSCEV(RHS, getOne(RHS->getType()))); + case CmpInst::ICMP_SLT: { + auto *ModifiedRHS = getMinusSCEV(RHS, getOne(RHS->getType())); + ModifiedRHS = + DividesBy ? GetPreviousSCEVDividesByDivisor(ModifiedRHS, DividesBy) + : ModifiedRHS; + RewrittenRHS = getSMinExpr(RewrittenLHS, ModifiedRHS); break; - case CmpInst::ICMP_ULE: - RewrittenRHS = getUMinExpr(RewrittenLHS, RHS); + } + case CmpInst::ICMP_ULE: { + auto *ModifiedRHS = + DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS; + RewrittenRHS = getUMinExpr(RewrittenLHS, ModifiedRHS); break; - case CmpInst::ICMP_SLE: - RewrittenRHS = getSMinExpr(RewrittenLHS, RHS); + } + case CmpInst::ICMP_SLE: { + auto *ModifiedRHS = + DividesBy ? 
GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS; + RewrittenRHS = getSMinExpr(RewrittenLHS, ModifiedRHS); break; - case CmpInst::ICMP_UGT: - RewrittenRHS = - getUMaxExpr(RewrittenLHS, getAddExpr(RHS, getOne(RHS->getType()))); + } + case CmpInst::ICMP_UGT: { + auto *ModifiedRHS = getAddExpr(RHS, getOne(RHS->getType())); + ModifiedRHS = DividesBy + ? GetNextSCEVDividesByDivisor(ModifiedRHS, DividesBy) + : ModifiedRHS; + RewrittenRHS = getUMaxExpr(RewrittenLHS, ModifiedRHS); break; - case CmpInst::ICMP_SGT: - RewrittenRHS = - getSMaxExpr(RewrittenLHS, getAddExpr(RHS, getOne(RHS->getType()))); + } + case CmpInst::ICMP_SGT: { + auto *ModifiedRHS = getAddExpr(RHS, getOne(RHS->getType())); + ModifiedRHS = DividesBy + ? GetNextSCEVDividesByDivisor(ModifiedRHS, DividesBy) + : ModifiedRHS; + RewrittenRHS = getSMaxExpr(RewrittenLHS, ModifiedRHS); break; - case CmpInst::ICMP_UGE: - RewrittenRHS = getUMaxExpr(RewrittenLHS, RHS); + } + case CmpInst::ICMP_UGE: { + auto *ModifiedRHS = + DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS; + RewrittenRHS = getUMaxExpr(RewrittenLHS, ModifiedRHS); break; - case CmpInst::ICMP_SGE: - RewrittenRHS = getSMaxExpr(RewrittenLHS, RHS); + } + case CmpInst::ICMP_SGE: { + auto *ModifiedRHS = + DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS; + RewrittenRHS = getSMaxExpr(RewrittenLHS, ModifiedRHS); break; + } case CmpInst::ICMP_EQ: if (isa(RHS)) RewrittenRHS = RHS; break; case CmpInst::ICMP_NE: if (isa(RHS) && - cast(RHS)->getValue()->isNullValue()) - RewrittenRHS = getUMaxExpr(RewrittenLHS, getOne(RHS->getType())); + cast(RHS)->getValue()->isNullValue()) { + auto *ModifiedRHS = getOne(RHS->getType()); + ModifiedRHS = DividesBy + ? 
GetNextSCEVDividesByDivisor(ModifiedRHS, DividesBy) + : ModifiedRHS; + RewrittenRHS = getUMaxExpr(RewrittenLHS, ModifiedRHS); + } break; default: break; diff --git a/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll b/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll --- a/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll +++ b/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll @@ -125,7 +125,7 @@ ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num) ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num) ; CHECK-NEXT: Predicates: -; CHECK: Loop %for.body: Trip multiple is 2 +; CHECK: Loop %for.body: Trip multiple is 4 ; entry: %u = urem i32 %num, 4 @@ -196,7 +196,7 @@ ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num) ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num) ; CHECK-NEXT: Predicates: -; CHECK: Loop %for.body: Trip multiple is 2 +; CHECK: Loop %for.body: Trip multiple is 4 ; entry: %u = urem i32 %num, 4 @@ -267,7 +267,7 @@ ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num) ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num) ; CHECK-NEXT: Predicates: -; CHECK: Loop %for.body: Trip multiple is 1 +; CHECK: Loop %for.body: Trip multiple is 4 ; entry: %u = urem i32 %num, 4 @@ -338,7 +338,7 @@ ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num) ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num) ; CHECK-NEXT: Predicates: -; CHECK: Loop %for.body: Trip multiple is 1 +; CHECK: Loop %for.body: Trip multiple is 4 ; entry: %u = urem i32 %num, 4 @@ -409,7 +409,7 @@ ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num) ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num) ; CHECK-NEXT: Predicates: -; CHECK: Loop %for.body: Trip multiple is 1 +; CHECK: Loop %for.body: 
Trip multiple is 4 ; entry: %cmp.1 = icmp uge i32 %num, 5 @@ -446,7 +446,7 @@ ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num) ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num) ; CHECK-NEXT: Predicates: -; CHECK: Loop %for.body: Trip multiple is 1 +; CHECK: Loop %for.body: Trip multiple is 4 ; entry: %cmp.1 = icmp uge i32 %num, 5 @@ -483,7 +483,7 @@ ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num) ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num) ; CHECK-NEXT: Predicates: -; CHECK: Loop %for.body: Trip multiple is 1 +; CHECK: Loop %for.body: Trip multiple is 4 ; entry: %cmp.1 = icmp uge i32 %num, 5 diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp --- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp +++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp @@ -1744,4 +1744,42 @@ }); } +TEST_F(ScalarEvolutionsTest, ApplyLoopGuards) { + LLVMContext C; + SMDiagnostic Err; + std::unique_ptr M = parseAssemblyString( + "declare void @llvm.assume(i1)\n" + "define void @test(i32 %num) {\n" + "entry:\n" + " %u = urem i32 %num, 4\n" + " %cmp = icmp eq i32 %u, 0\n" + " tail call void @llvm.assume(i1 %cmp)\n" + " %cmp.1 = icmp ugt i32 %num, 0\n" + " tail call void @llvm.assume(i1 %cmp.1)\n" + " br label %for.body\n" + "for.body:\n" + " %i.010 = phi i32 [ 0, %entry ], [ %inc, %for.body ]\n" + " %inc = add nuw nsw i32 %i.010, 1\n" + " %cmp2 = icmp ult i32 %inc, %num\n" + " br i1 %cmp2, label %for.body, label %exit\n" + "exit:\n" + " ret void\n" + "}\n", + Err, C); + + ASSERT_TRUE(M && "Could not parse module?"); + ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!"); + + runWithSE(*M, "test", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) { + auto *TCScev = SE.getSCEV(getArgByName(F, "num")); + auto *ApplyLoopGuardsTC = SE.applyLoopGuards(TCScev, *LI.begin()); + // Assert that applyLoopGuards rewrites %num to (4 * 
((4 umax %num) /u 4)) + APInt Four(32, 4); + auto *Constant4 = SE.getConstant(Four); + auto *Max = SE.getUMaxExpr(TCScev, Constant4); + auto *Mul = SE.getMulExpr(SE.getUDivExpr(Max, Constant4), Constant4); + ASSERT_TRUE(Mul == ApplyLoopGuardsTC); + }); +} + } // end namespace llvm