Index: llvm/lib/Analysis/ScalarEvolutionNormalization.cpp
===================================================================
--- llvm/lib/Analysis/ScalarEvolutionNormalization.cpp
+++ llvm/lib/Analysis/ScalarEvolutionNormalization.cpp
@@ -41,9 +41,55 @@
       : SCEVRewriteVisitor<NormalizeDenormalizeRewriter>(SE), Kind(Kind),
         Pred(Pred) {}
 
   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr);
+
+  const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr);
+
+  const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr);
+
+  const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr);
+
+  const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr);
 };
 } // namespace
 
+// Don't do anything for max/min expressions. Normalizing a SCEV produces a
+// non-equivalent expression in which the max/min may become redundant, so
+// scalar evolution can simplify it away. Denormalizing such an expression
+// would not bring the missing max/min back.
+// Consider the following loop:
+//   loop:
+//     %iv = phi [ 11, %entry ], [ %iv.next, %loop ]
+//     %umax = umax(%iv, 10)
+//     %iv.next = add %iv, -4
+//     %loop.cond = %iv.next u< 7
+//     br i1 %loop.cond, %exit, %loop
+// It executes exactly two iterations. The SCEV for %umax is (10 umax
+// {11,+,-4}). Normalizing it for loop '%loop' would give us (10 umax
+// {15,+,-4}). Now, as the loop has only two iterations, the AddRec is always
+// greater than 10, so SCEV simplifies the expression to {15,+,-4}.
+// Denormalizing it back would give a wrong result: {11,+,-4}. On the 2nd
+// iteration the value of %umax is 10, but according to the denormalized SCEV
+// it's 11 - 4 = 7.
+const SCEV *
+NormalizeDenormalizeRewriter::visitSMaxExpr(const SCEVSMaxExpr *Expr) {
+  return Expr;
+}
+
+const SCEV *
+NormalizeDenormalizeRewriter::visitUMaxExpr(const SCEVUMaxExpr *Expr) {
+  return Expr;
+}
+
+const SCEV *
+NormalizeDenormalizeRewriter::visitSMinExpr(const SCEVSMinExpr *Expr) {
+  return Expr;
+}
+
+const SCEV *
+NormalizeDenormalizeRewriter::visitUMinExpr(const SCEVUMinExpr *Expr) {
+  return Expr;
+}
+
 const SCEV *
 NormalizeDenormalizeRewriter::visitAddRecExpr(const SCEVAddRecExpr *AR) {
   SmallVector<const SCEV *, 8> Operands;
Index: llvm/unittests/Analysis/ScalarEvolutionTest.cpp
===================================================================
--- llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -516,6 +516,33 @@
       "end: "
       "  ret void "
       "} "
+      " "
+      "declare i32 @llvm.umax.i32(i32, i32)"
+      " "
+      "declare i32 @llvm.smax.i32(i32, i32)"
+      " "
+      "declare i32 @llvm.umin.i32(i32, i32)"
+      " "
+      "declare i32 @llvm.smin.i32(i32, i32)"
+      " "
+      "define void @f_3() "
+      "    local_unnamed_addr { "
+      "entry: "
+      "  br label %loop "
+      " "
+      "loop: "
+      "  %iv = phi i32 [ 155, %entry ], [ %iv.next, %loop ]" // 155 ... 7
+      "  %iv.next = add i32 %iv, -4"
+      "  %umax = call i32 @llvm.umax.i32(i32 %iv, i32 10)"
+      "  %smax = call i32 @llvm.smax.i32(i32 %iv, i32 10)"
+      "  %umin = call i32 @llvm.umin.i32(i32 %iv, i32 10)"
+      "  %smin = call i32 @llvm.smin.i32(i32 %iv, i32 10)"
+      "  %loop.cond = icmp ult i32 %iv.next, 7"
+      "  br i1 %loop.cond, label %end, label %loop "
+      " "
+      "end: "
+      "  ret void "
+      "} "
       ,
       Err, C);
 
@@ -619,6 +646,37 @@
       }
     }
   });
+
+  runWithSE(*M, "f_3", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    auto &UMaxInst = GetInstByName(F, "umax");
+    auto &SMaxInst = GetInstByName(F, "smax");
+    auto &UMinInst = GetInstByName(F, "umin");
+    auto &SMinInst = GetInstByName(F, "smin");
+
+    auto *L = LI.getLoopFor(UMaxInst.getParent());
+    PostIncLoopSet Loops;
+    Loops.insert(L);
+
+    // Check that normalizing and then denormalizing (and vice versa) preserves
+    // the umax, smax, umin and smin expressions.
+    auto *UMaxS = SE.getSCEV(&UMaxInst);
+    auto *SMaxS = SE.getSCEV(&SMaxInst);
+    auto *UMinS = SE.getSCEV(&UMinInst);
+    auto *SMinS = SE.getSCEV(&SMinInst);
+
+    for (auto *S : {UMaxS, SMaxS, UMinS, SMinS}) {
+      {
+        auto *N = normalizeForPostIncUse(S, Loops, SE);
+        auto *D = denormalizeForPostIncUse(N, Loops, SE);
+        EXPECT_EQ(S, D) << *S << " " << *D;
+      }
+      {
+        auto *D = denormalizeForPostIncUse(S, Loops, SE);
+        auto *N = normalizeForPostIncUse(D, Loops, SE);
+        EXPECT_EQ(S, N) << *S << " " << *N;
+      }
+    }
+  });
 }
 
 // Expect the call of getZeroExtendExpr will not cost exponential time.