Index: lib/Analysis/ScalarEvolution.cpp =================================================================== --- lib/Analysis/ScalarEvolution.cpp +++ lib/Analysis/ScalarEvolution.cpp @@ -4280,9 +4280,10 @@ case ICmpInst::ICMP_SGE: // a >s b ? a+x : b+x -> smax(a, b)+x // a >s b ? b+x : a+x -> smin(a, b)+x - if (LHS->getType() == U->getType()) { - const SCEV *LS = getSCEV(LHS); - const SCEV *RS = getSCEV(RHS); + if (getTypeSizeInBits(LHS->getType()) <= + getTypeSizeInBits(U->getType())) { + const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), U->getType()); + const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), U->getType()); const SCEV *LA = getSCEV(U->getOperand(1)); const SCEV *RA = getSCEV(U->getOperand(2)); const SCEV *LDiff = getMinusSCEV(LA, LS); @@ -4303,9 +4304,10 @@ case ICmpInst::ICMP_UGE: // a >u b ? a+x : b+x -> umax(a, b)+x // a >u b ? b+x : a+x -> umin(a, b)+x - if (LHS->getType() == U->getType()) { - const SCEV *LS = getSCEV(LHS); - const SCEV *RS = getSCEV(RHS); + if (getTypeSizeInBits(LHS->getType()) <= + getTypeSizeInBits(U->getType())) { + const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType()); + const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), U->getType()); const SCEV *LA = getSCEV(U->getOperand(1)); const SCEV *RA = getSCEV(U->getOperand(2)); const SCEV *LDiff = getMinusSCEV(LA, LS); @@ -4320,11 +4322,11 @@ break; case ICmpInst::ICMP_NE: // n != 0 ? 
n+x : 1+x -> umax(n, 1)+x - if (LHS->getType() == U->getType() && - isa<ConstantInt>(RHS) && - cast<ConstantInt>(RHS)->isZero()) { - const SCEV *One = getConstant(LHS->getType(), 1); - const SCEV *LS = getSCEV(LHS); + if (getTypeSizeInBits(LHS->getType()) <= + getTypeSizeInBits(U->getType()) && + isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) { + const SCEV *One = getConstant(U->getType(), 1); + const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType()); const SCEV *LA = getSCEV(U->getOperand(1)); const SCEV *RA = getSCEV(U->getOperand(2)); const SCEV *LDiff = getMinusSCEV(LA, LS); @@ -4335,11 +4337,11 @@ break; case ICmpInst::ICMP_EQ: // n == 0 ? 1+x : n+x -> umax(n, 1)+x - if (LHS->getType() == U->getType() && - isa<ConstantInt>(RHS) && - cast<ConstantInt>(RHS)->isZero()) { - const SCEV *One = getConstant(LHS->getType(), 1); - const SCEV *LS = getSCEV(LHS); + if (getTypeSizeInBits(LHS->getType()) <= + getTypeSizeInBits(U->getType()) && + isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) { + const SCEV *One = getConstant(U->getType(), 1); + const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType()); const SCEV *LA = getSCEV(U->getOperand(1)); const SCEV *RA = getSCEV(U->getOperand(2)); const SCEV *LDiff = getMinusSCEV(LA, One); @@ -7011,8 +7013,8 @@ return false; } -// Verify if an linear IV with positive stride can overflow when in a -// less-than comparison, knowing the invariant term of the comparison, the +// Verify if an linear IV with positive stride can overflow when in a +// less-than comparison, knowing the invariant term of the comparison, the // stride and the knowledge of NSW/NUW flags on the recurrence. 
bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, bool IsSigned, bool NoWrap) { @@ -7040,7 +7042,7 @@ return (MaxValue - MaxStrideMinusOne).ult(MaxRHS); } -// Verify if an linear IV with negative stride can overflow when in a +// Verify if an linear IV with negative stride can overflow when in a // greater-than comparison, knowing the invariant term of the comparison, // the stride and the knowledge of NSW/NUW flags on the recurrence. bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, @@ -7071,7 +7073,7 @@ // Compute the backedge taken count knowing the interval difference, the // stride and presence of the equality in the comparison. -const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step, +const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step, bool Equality) { const SCEV *One = getConstant(Step->getType(), 1); Delta = Equality ? getAddExpr(Delta, Step) @@ -7111,7 +7113,7 @@ // Avoid proven overflow cases: this will ensure that the backedge taken count // will not generate any unsigned overflow. Relaxed no-overflow conditions - // exploit NoWrapFlags, allowing to optimize in presence of undefined + // exploit NoWrapFlags, allowing to optimize in presence of undefined // behaviors like the case of C language. if (!Stride->isOne() && doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap)) return getCouldNotCompute(); @@ -7191,7 +7193,7 @@ // Avoid proven overflow cases: this will ensure that the backedge taken count // will not generate any unsigned overflow. Relaxed no-overflow conditions - // exploit NoWrapFlags, allowing to optimize in presence of undefined + // exploit NoWrapFlags, allowing to optimize in presence of undefined // behaviors like the case of C language. 
if (!Stride->isOne() && doesIVOverflowOnGT(RHS, Stride, IsSigned, NoWrap)) return getCouldNotCompute(); @@ -7239,7 +7241,7 @@ if (isa<SCEVConstant>(BECount)) MaxBECount = BECount; else - MaxBECount = computeBECount(getConstant(MaxStart - MinEnd), + MaxBECount = computeBECount(getConstant(MaxStart - MinEnd), getConstant(MinStride), false); if (isa<SCEVCouldNotCompute>(MaxBECount)) Index: test/Analysis/ScalarEvolution/min-max-exprs.ll =================================================================== --- /dev/null +++ test/Analysis/ScalarEvolution/min-max-exprs.ll @@ -0,0 +1,53 @@ +; RUN: opt -scalar-evolution -analyze < %s | FileCheck %s +; +; This checks if the min and max expressions are properly recognized by +; ScalarEvolution even though the ICmpInst and SelectInst have different +; types. +; +; #define max(a, b) (a > b ? a : b) +; #define min(a, b) (a < b ? a : b) +; +; void f(int *A, int N) { +; for (int i = 0; i < N; i++) { +; A[max(0, i - 3)] = A[min(N, i + 3)] * 2; +; } +; } +; +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" + +define void @f(i32* %A, i32 %N) { +bb: + br label %bb1 + +bb1: ; preds = %bb2, %bb + %i.0 = phi i32 [ 0, %bb ], [ %tmp23, %bb2 ] + %i.0.1 = sext i32 %i.0 to i64 + %tmp = icmp slt i32 %i.0, %N + br i1 %tmp, label %bb2, label %bb24 + +bb2: ; preds = %bb1 + %tmp3 = add nuw nsw i32 %i.0, 3 + %tmp4 = icmp slt i32 %tmp3, %N + %tmp5 = sext i32 %tmp3 to i64 + %tmp6 = sext i32 %N to i64 + %tmp9 = select i1 %tmp4, i64 %tmp5, i64 %tmp6 +; min(N, i+3) +; CHECK: select i1 %tmp4, i64 %tmp5, i64 %tmp6 +; CHECK-NEXT: --> (-1 + (-1 * ((-1 + (-1 * (sext i32 {3,+,1}<%bb1> to i64))) smax (-1 + (-1 * (sext i32 %N to i64)))))) + %tmp11 = getelementptr inbounds i32* %A, i64 %tmp9 + %tmp12 = load i32* %tmp11, align 4 + %tmp13 = shl nsw i32 %tmp12, 1 + %tmp14 = icmp sge i32 3, %i.0 + %tmp17 = add nsw i64 %i.0.1, -3 + %tmp19 = select i1 %tmp14, i64 0, i64 %tmp17 +; max(0, i - 3) +; CHECK: select i1 %tmp14, i64 0, i64 %tmp17 +; CHECK-NEXT: --> (-3 + (3 smax {0,+,1}<%bb1>)) + 
%tmp21 = getelementptr inbounds i32* %A, i64 %tmp19 + store i32 %tmp13, i32* %tmp21, align 4 + %tmp23 = add nuw nsw i32 %i.0, 1 + br label %bb1 + +bb24: ; preds = %bb1 + ret void +}