Index: include/llvm/Analysis/ScalarEvolution.h
===================================================================
--- include/llvm/Analysis/ScalarEvolution.h
+++ include/llvm/Analysis/ScalarEvolution.h
@@ -529,6 +529,17 @@
                                         const SCEV *FoundLHS,
                                         const SCEV *FoundRHS);
 
+    /// Test whether the condition described by Pred, LHS, and RHS is true
+    /// whenever the condition described by Pred, FoundLHS, and FoundRHS is
+    /// true.
+    ///
+    /// This routine tries to rule out certain kinds of integer overflow, and
+    /// then tries to reason about arithmetic properties of the predicates.
+    bool isImpliedCondOperandsViaNoOverflow(ICmpInst::Predicate Pred,
+                                            const SCEV *LHS, const SCEV *RHS,
+                                            const SCEV *FoundLHS,
+                                            const SCEV *FoundRHS);
+
     /// If we know that the specified Phi is in the header of its containing
     /// loop, we know the loop executes a constant number of times, and the PHI
     /// node is just a recurrence involving constants, fold it.
Index: lib/Analysis/ScalarEvolution.cpp
===================================================================
--- lib/Analysis/ScalarEvolution.cpp
+++ lib/Analysis/ScalarEvolution.cpp
@@ -7280,6 +7280,153 @@
   return false;
 }
 
+// Return true if More == (Less + C), where C is a constant.
+//
+// Only a few shapes are matched syntactically: two affine add recurrences on
+// the same loop with identical steps (compared through their start values),
+// two constants, and a two-operand add whose first operand is a constant and
+// whose second operand is the other expression.  C is written only on success.
+static bool IsConstDiff(ScalarEvolution &SE, const SCEV *Less, const SCEV *More,
+                        APInt &C) {
+  // We avoid subtracting expressions here because this function is usually
+  // fairly deep in the call stack (i.e. is called many times).
+
+  auto SplitBinaryAdd = [](const SCEV *Expr, const SCEV *&L, const SCEV *&R) {
+    const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
+    if (!AE || AE->getNumOperands() != 2)
+      return false;
+
+    L = AE->getOperand(0);
+    R = AE->getOperand(1);
+    return true;
+  };
+
+  if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
+    const auto *LAR = cast<SCEVAddRecExpr>(Less);
+    const auto *MAR = cast<SCEVAddRecExpr>(More);
+
+    if (LAR->getLoop() != MAR->getLoop())
+      return false;
+
+    // We look at affine expressions only; not for correctness but to keep
+    // getStepRecurrence cheap.
+    if (!LAR->isAffine() || !MAR->isAffine())
+      return false;
+
+    if (LAR->getStepRecurrence(SE) != MAR->getStepRecurrence(SE))
+      return false;
+
+    Less = LAR->getStart();
+    More = MAR->getStart();
+
+    // fall through
+  }
+
+  if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {
+    const auto &M = cast<SCEVConstant>(More)->getValue()->getValue();
+    const auto &L = cast<SCEVConstant>(Less)->getValue()->getValue();
+    C = M - L;
+    return true;
+  }
+
+  const SCEV *L, *R;
+  if (SplitBinaryAdd(Less, L, R))
+    if (const auto *LC = dyn_cast<SCEVConstant>(L))
+      if (R == More) {
+        C = -(LC->getValue()->getValue());
+        return true;
+      }
+
+  if (SplitBinaryAdd(More, L, R))
+    if (const auto *LC = dyn_cast<SCEVConstant>(L))
+      if (R == Less) {
+        C = LC->getValue()->getValue();
+        return true;
+      }
+
+  return false;
+}
+
+bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
+    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
+    const SCEV *FoundLHS, const SCEV *FoundRHS) {
+  if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
+    return false;
+
+  const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
+  if (!AddRecLHS)
+    return false;
+
+  const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
+  if (!AddRecFoundLHS)
+    return false;
+
+  // We'd like to let SCEV reason about control dependencies, so we constrain
+  // both the inequalities to be about add recurrences on the same loop.  This
+  // way we can use isLoopEntryGuardedByCond later.
+
+  const Loop *L = AddRecFoundLHS->getLoop();
+  if (L != AddRecLHS->getLoop())
+    return false;
+
+  //  FoundLHS u< FoundRHS u< -C =>  (FoundLHS + C) u< (FoundRHS + C) ... (1)
+  //
+  //  FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
+  //                                         ... (2)
+  //
+  // Informal proof for (2), assuming (1) [*]:
+  //
+  // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
+  //
+  // Then
+  //
+  //       FoundLHS s< FoundRHS s< INT_MIN - C
+  // <=>  (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C   [ using (3) ]
+  // <=>  (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
+  // <=>  (FoundLHS + INT_MIN + C + INT_MIN) s<
+  //                        (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
+  // <=>  FoundLHS + C s< FoundRHS + C
+  //
+  // [*]: (1) can be proved by ruling out overflow.
+  //
+  // [**]: This can be proved by analyzing all the four possibilities:
+  //    (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
+  //    (A s>= 0, B s>= 0).
+  //
+  // Note:
+  // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
+  // will not sign underflow.  For instance, say FoundLHS = (i8 -128), FoundRHS
+  // = (i8 -127) and C = (i8 -100).  Then INT_MIN - C = (i8 -28), and FoundRHS
+  // s< (INT_MIN - C).  Lack of sign overflow / underflow in "FoundRHS + C" is
+  // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
+  // C)".
+
+  APInt LDiff, RDiff;
+  if (!IsConstDiff(*this, FoundLHS, LHS, LDiff) ||
+      !IsConstDiff(*this, FoundRHS, RHS, RDiff) ||
+      LDiff != RDiff)
+    return false;
+
+  if (LDiff == 0)
+    return true;
+
+  // RHS is not guaranteed to be integer-typed (SCEV also models pointers), so
+  // query SCEV for the bit width instead of casting its type to IntegerType.
+  unsigned Width = getTypeSizeInBits(RHS->getType());
+  APInt FoundRHSLimit;
+
+  if (Pred == CmpInst::ICMP_ULT) {
+    FoundRHSLimit = -RDiff;
+  } else {
+    assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
+    FoundRHSLimit = APInt::getSignedMinValue(Width) - RDiff;
+  }
+
+  // Try to prove (1) or (2), as needed.
+  return isLoopEntryGuardedByCond(L, Pred, FoundRHS,
+                                  getConstant(FoundRHSLimit));
+}
+
 /// isImpliedCondOperands - Test whether the condition described by Pred,
 /// LHS, and RHS is true whenever the condition described by Pred, FoundLHS,
 /// and FoundRHS is true.
@@ -7290,6 +7437,9 @@
   if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS))
     return true;
 
+  if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
+    return true;
+
   return isImpliedCondOperandsHelper(Pred, LHS, RHS,
                                      FoundLHS, FoundRHS) ||
          // ~x < ~y --> x > y
Index: test/Transforms/IndVarSimplify/eliminate-comparison.ll
===================================================================
--- test/Transforms/IndVarSimplify/eliminate-comparison.ll
+++ test/Transforms/IndVarSimplify/eliminate-comparison.ll
@@ -209,3 +209,152 @@
 unrolledend:                                      ; preds = %forcond38
   ret i32 0
 }
+
+declare void @side_effect()
+
+define void @func_13(i32* %len.ptr) {
+; CHECK-LABEL: @func_13(
+ entry:
+  %len = load i32, i32* %len.ptr, !range !0
+  %len.sub.1 = add i32 %len, -1
+  %len.is.zero = icmp eq i32 %len, 0
+  br i1 %len.is.zero, label %leave, label %loop
+
+ loop:
+; CHECK: loop:
+  %iv = phi i32 [ 0, %entry ], [ %iv.inc, %be ]
+  call void @side_effect()
+  %iv.inc = add i32 %iv, 1
+  %iv.cmp = icmp ult i32 %iv, %len
+  br i1 %iv.cmp, label %be, label %leave
+; CHECK: br i1 true, label %be, label %leave
+
+ be:
+  call void @side_effect()
+  %be.cond = icmp ult i32 %iv, %len.sub.1
+  br i1 %be.cond, label %loop, label %leave
+
+ leave:
+  ret void
+}
+
+define void @func_14(i32* %len.ptr) {
+; CHECK-LABEL: @func_14(
+ entry:
+  %len = load i32, i32* %len.ptr, !range !0
+  %len.sub.1 = add i32 %len, -1
+  %len.is.zero = icmp eq i32 %len, 0
+  %len.is.int_min = icmp eq i32 %len, 2147483648
+  %no.entry = or i1 %len.is.zero, %len.is.int_min
+  br i1 %no.entry, label %leave, label %loop
+
+ loop:
+; CHECK: loop:
+  %iv = phi i32 [ 0, %entry ], [ %iv.inc, %be ]
+  call void @side_effect()
+  %iv.inc = add i32 %iv, 1
+  %iv.cmp = icmp slt i32 %iv, %len
+  br i1 %iv.cmp, label %be, label %leave
+; CHECK: br i1 true, label %be, label %leave
+
+ be:
+  call void @side_effect()
+  %be.cond = icmp slt i32 %iv, %len.sub.1
+  br i1 %be.cond, label %loop, label %leave
+
+ leave:
+  ret void
+}
+
+define void @func_15(i32* %len.ptr) {
+; CHECK-LABEL: @func_15(
+ entry:
+  %len = load i32, i32* %len.ptr, !range !0
+  %len.add.1 = add i32 %len, 1
+  %len.add.1.is.zero = icmp eq i32 %len.add.1, 0
+  br i1 %len.add.1.is.zero, label %leave, label %loop
+
+ loop:
+; CHECK: loop:
+  %iv = phi i32 [ 0, %entry ], [ %iv.inc, %be ]
+  call void @side_effect()
+  %iv.inc = add i32 %iv, 1
+  %iv.cmp = icmp ult i32 %iv, %len.add.1
+  br i1 %iv.cmp, label %be, label %leave
+; CHECK: br i1 true, label %be, label %leave
+
+ be:
+  call void @side_effect()
+  %be.cond = icmp ult i32 %iv, %len
+  br i1 %be.cond, label %loop, label %leave
+
+ leave:
+  ret void
+}
+
+define void @func_16(i32* %len.ptr) {
+; CHECK-LABEL: @func_16(
+ entry:
+  %len = load i32, i32* %len.ptr, !range !0
+  %len.add.5 = add i32 %len, 5
+  %entry.cond.0 = icmp slt i32 %len, 2147483643
+  %entry.cond.1 = icmp slt i32 4, %len.add.5
+  %entry.cond = and i1 %entry.cond.0, %entry.cond.1
+  br i1 %entry.cond, label %loop, label %leave
+
+ loop:
+; CHECK: loop:
+  %iv = phi i32 [ 0, %entry ], [ %iv.inc, %be ]
+  call void @side_effect()
+  %iv.inc = add i32 %iv, 1
+  %iv.add.4 = add i32 %iv, 4
+  %iv.cmp = icmp slt i32 %iv.add.4, %len.add.5
+  br i1 %iv.cmp, label %be, label %leave
+; CHECK: br i1 true, label %be, label %leave
+
+ be:
+  call void @side_effect()
+  %be.cond = icmp slt i32 %iv, %len
+  br i1 %be.cond, label %loop, label %leave
+
+ leave:
+  ret void
+}
+
+define void @func_17(i32* %len.ptr) {
+; CHECK-LABEL: @func_17(
+ entry:
+  %len = load i32, i32* %len.ptr
+  %len.add.5 = add i32 %len, -5
+  %entry.cond.0 = icmp slt i32 %len, 2147483653 ;; 2147483653 == INT_MIN - (-5)
+  %entry.cond.1 = icmp slt i32 -6, %len.add.5
+  %entry.cond = and i1 %entry.cond.0, %entry.cond.1
+  br i1 %entry.cond, label %loop, label %leave
+
+ loop:
+; CHECK: loop:
+  %iv.2 = phi i32 [ 0, %entry ], [ %iv.2.inc, %be ]
+  %iv = phi i32 [ -6, %entry ], [ %iv.inc, %be ]
+  call void @side_effect()
+  %iv.inc = add i32 %iv, 1
+  %iv.2.inc = add i32 %iv.2, 1
+  %iv.cmp = icmp slt i32 %iv, %len.add.5
+
+; Deduces {-5,+,1} s< (-5 + %len) from {0,+,1} s< %len
+; since %len s< INT_MIN - (-5) from the entry condition
+
+; CHECK: br i1 true, label %be, label %leave
+  br i1 %iv.cmp, label %be, label %leave
+
+ be:
+; CHECK: be:
+  call void @side_effect()
+  %be.cond = icmp slt i32 %iv.2, %len
+  br i1 %be.cond, label %loop, label %leave
+
+ leave:
+  ret void
+}
+
+!0 = !{i32 0, i32 2147483647}
+