Index: include/llvm/Analysis/ScalarEvolution.h =================================================================== --- include/llvm/Analysis/ScalarEvolution.h +++ include/llvm/Analysis/ScalarEvolution.h @@ -978,6 +978,20 @@ /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by Pred, FoundLHS, and FoundRHS is + /// true. Here LHS is an operation that includes FoundLHS as one of its + /// arguments. + bool isImpliedViaOperations(ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS, + const SCEV *FoundLHS, + const SCEV *FoundRHS); + + /// Test whether the condition described by Pred, LHS, and RHS is true. + /// Use only simple non-recursive types of checks, such as range analysis. + bool isKnownViaSimpleReasoning(ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS); + + /// Test whether the condition described by Pred, LHS, and RHS is true + /// whenever the condition described by Pred, FoundLHS, and FoundRHS is /// true. 
bool isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, Index: lib/Analysis/ScalarEvolution.cpp =================================================================== --- lib/Analysis/ScalarEvolution.cpp +++ lib/Analysis/ScalarEvolution.cpp @@ -8491,19 +8491,115 @@ llvm_unreachable("covered switch fell through?!"); } -bool -ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, +bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, const SCEV *FoundRHS) { - auto IsKnownPredicateFull = - [this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { - return isKnownPredicateViaConstantRanges(Pred, LHS, RHS) || - IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) || - IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) || - isKnownPredicateViaNoOverflow(Pred, LHS, RHS); + // We only want to work with ICMP_SGT comparison so far. + // TODO: Extend to ICMP_UGT? + if (Pred == ICmpInst::ICMP_SLT) { + Pred = ICmpInst::ICMP_SGT; + std::swap(LHS, RHS); + std::swap(FoundLHS, FoundRHS); + } + if (Pred != ICmpInst::ICMP_SGT) + return false; + + auto GetOpFromExt = [&](const SCEV *S) { + if (auto *Ext = dyn_cast(S)) + return Ext->getOperand(); + return S; }; + // Acquire values from extensions. + auto *OrigLHS = LHS; + auto *OrigFoundLHS = FoundLHS; + LHS = GetOpFromExt(LHS); + FoundLHS = GetOpFromExt(FoundLHS); + + if (auto *LHSAddExpr = dyn_cast(LHS)) { + // Should not overflow. 
+ if (!LHSAddExpr->hasNoSignedWrap()) + return false; + auto *LL = LHSAddExpr->getOperand(0); + auto *LR = LHSAddExpr->getOperand(1); + auto *Zero = getZero(RHS->getType()); + auto *LLExt = getNoopOrSignExtend(LL, OrigLHS->getType()); + auto *LRExt = getNoopOrSignExtend(LR, OrigLHS->getType()); + + auto AdditionRule = [&](const SCEV *S1, const SCEV *S2) { + if (isKnownNonNegative(S1) || + isImpliedCondOperandsHelper(ICmpInst::ICMP_SGE, S1, Zero, + OrigFoundLHS, FoundRHS)) + if (isKnownViaSimpleReasoning(Pred, S2, RHS) || + isImpliedCondOperandsHelper(Pred, S2, RHS, OrigFoundLHS, FoundRHS)) + return true; + return false; + }; + // Try to prove the following rule: + // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS). + // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS). + if (AdditionRule(LLExt, LRExt) || AdditionRule(LRExt, LLExt)) + return true; + } else if (auto *LHSUnknownExpr = dyn_cast(LHS)) { + Value *LL, *LR; + // FIXME: Once we have SDiv implemented, we can get rid of this matching. + using namespace llvm::PatternMatch; + if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) { + auto MaxType = [&](const SCEV *S1, const SCEV *S2) { + auto Size1 = getTypeSizeInBits(S1->getType()); + auto Size2 = getTypeSizeInBits(S2->getType()); + return Size1 > Size2 ? S1->getType() : S2->getType(); + }; + // Rules for division. + auto *Num = getSCEV(LL); + auto *Denum = getSCEV(LR); + auto *Ty1 = MaxType(Num, FoundLHS); + auto *NumExt = getNoopOrSignExtend(Num, Ty1); + auto *FoundLHSExt = getNoopOrSignExtend(FoundLHS, Ty1); + + if (!HasSameValue(NumExt, FoundLHSExt) || !isKnownPositive(Denum)) + return false; + + // Given that FoundLHS > FoundRHS, LHS = FoundLHS / Denum, Denum > 0. + auto *Ty2 = MaxType(Denum, FoundRHS); + auto *DenumExt = getNoopOrSignExtend(Denum, Ty2); + auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, Ty2); + + // Try to prove the following rule: + // (Denum <= FoundRHS + 1) && (RHS <= 0) => (LHS > RHS). 
+ auto *One = getOne(Ty2); + auto *Next = getAddExpr(FoundRHSExt, One); + if (isKnownNonPositive(RHS) && + isKnownViaSimpleReasoning(ICmpInst::ICMP_SLE, DenumExt, Next)) + return true; + + // Try to prove the following rule: + // (-Denum <= FoundRHS) && (RHS < 0) => (LHS > RHS). + auto *NegDenum = getNegativeSCEV(DenumExt); + if (isKnownNegative(RHS) && + isKnownViaSimpleReasoning(ICmpInst::ICMP_SLE, NegDenum, FoundRHSExt)) + return true; + } + } + + return false; +} + +bool +ScalarEvolution::isKnownViaSimpleReasoning(ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS) { + return isKnownPredicateViaConstantRanges(Pred, LHS, RHS) || + IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) || + IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) || + isKnownPredicateViaNoOverflow(Pred, LHS, RHS); +} + +bool +ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS, + const SCEV *FoundLHS, + const SCEV *FoundRHS) { switch (Pred) { default: llvm_unreachable("Unexpected ICmpInst::Predicate value!"); case ICmpInst::ICMP_EQ: @@ -8513,30 +8609,34 @@ break; case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: - if (IsKnownPredicateFull(ICmpInst::ICMP_SLE, LHS, FoundLHS) && - IsKnownPredicateFull(ICmpInst::ICMP_SGE, RHS, FoundRHS)) + if (isKnownViaSimpleReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) && + isKnownViaSimpleReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: - if (IsKnownPredicateFull(ICmpInst::ICMP_SGE, LHS, FoundLHS) && - IsKnownPredicateFull(ICmpInst::ICMP_SLE, RHS, FoundRHS)) + if (isKnownViaSimpleReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) && + isKnownViaSimpleReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: - if (IsKnownPredicateFull(ICmpInst::ICMP_ULE, LHS, FoundLHS) && - IsKnownPredicateFull(ICmpInst::ICMP_UGE, RHS, FoundRHS)) + if 
(isKnownViaSimpleReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) && + isKnownViaSimpleReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGE: - if (IsKnownPredicateFull(ICmpInst::ICMP_UGE, LHS, FoundLHS) && - IsKnownPredicateFull(ICmpInst::ICMP_ULE, RHS, FoundRHS)) + if (isKnownViaSimpleReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) && + isKnownViaSimpleReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS)) return true; break; } + // Maybe it can be proved via operations? + if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS)) + return true; + return false; } Index: test/Analysis/ScalarEvolution/scev-division.ll =================================================================== --- /dev/null +++ test/Analysis/ScalarEvolution/scev-division.ll @@ -0,0 +1,334 @@ +; RUN: opt < %s -analyze -scalar-evolution | FileCheck %s + +declare void @llvm.experimental.guard(i1, ...) + +define void @test01(i32 %a, i32 %n) nounwind { +; Prove that (n > 1) ===> (n / 2 > 0). +; CHECK: Determining loop execution counts for: @test01 +; CHECK: Loop %header: Predicated backedge-taken count is (-1 + %n.div.2) +entry: + %cmp1 = icmp sgt i32 %n, 1 + %n.div.2 = sdiv i32 %n, 2 + call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i32 %indvar, 1 + %exitcond = icmp sgt i32 %n.div.2, %indvar.next + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @test01neg(i32 %a, i32 %n) nounwind { +; Prove that (n > 0) =\=> (n / 2 > 0). +; CHECK: Determining loop execution counts for: @test01neg +; CHECK: Loop %header: Predicated backedge-taken count is (-1 + (1 smax %n.div.2)) +entry: + %cmp1 = icmp sgt i32 %n, 0 + %n.div.2 = sdiv i32 %n, 2 + call void(i1, ...) 
@llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i32 %indvar, 1 + %exitcond = icmp sgt i32 %n.div.2, %indvar.next + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @test02(i32 %a, i32 %n) nounwind { +; Prove that (n >= 2) ===> (n / 2 > 0). +; CHECK: Determining loop execution counts for: @test02 +; CHECK: Loop %header: Predicated backedge-taken count is (-1 + %n.div.2) +entry: + %cmp1 = icmp sge i32 %n, 2 + %n.div.2 = sdiv i32 %n, 2 + call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i32 %indvar, 1 + %exitcond = icmp sgt i32 %n.div.2, %indvar.next + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @test02neg(i32 %a, i32 %n) nounwind { +; Prove that (n >= 1) =\=> (n / 2 > 0). +; CHECK: Determining loop execution counts for: @test02neg +; CHECK: Loop %header: Predicated backedge-taken count is (-1 + (1 smax %n.div.2)) +entry: + %cmp1 = icmp sge i32 %n, 1 + %n.div.2 = sdiv i32 %n, 2 + call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i32 %indvar, 1 + %exitcond = icmp sgt i32 %n.div.2, %indvar.next + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @test03(i32 %a, i32 %n) nounwind { +; Prove that (n > -2) ===> (n / 2 >= 0). +; TODO: We should be able to prove that (n > -2) ===> (n / 2 >= 0). +; CHECK: Determining loop execution counts for: @test03 +; CHECK: Loop %header: Predicated backedge-taken count is (1 + %n.div.2) +entry: + %cmp1 = icmp sgt i32 %n, -2 + %n.div.2 = sdiv i32 %n, 2 + call void(i1, ...) 
@llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i32 %indvar, 1 + %exitcond = icmp sge i32 %n.div.2, %indvar + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @test03neg(i32 %a, i32 %n) nounwind { +; Prove that (n > -3) =\=> (n / 2 >= 0). +; CHECK: Determining loop execution counts for: @test03neg +; CHECK: Loop %header: Predicated backedge-taken count is (0 smax (1 + %n.div.2)) +entry: + %cmp1 = icmp sgt i32 %n, -3 + %n.div.2 = sdiv i32 %n, 2 + call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i32 %indvar, 1 + %exitcond = icmp sge i32 %n.div.2, %indvar + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @test04(i32 %a, i32 %n) nounwind { +; Prove that (n >= -1) ===> (n / 2 >= 0). +; CHECK: Determining loop execution counts for: @test04 +; CHECK: Loop %header: Predicated backedge-taken count is (1 + %n.div.2) +entry: + %cmp1 = icmp sge i32 %n, -1 + %n.div.2 = sdiv i32 %n, 2 + call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i32 %indvar, 1 + %exitcond = icmp sge i32 %n.div.2, %indvar + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @test04neg(i32 %a, i32 %n) nounwind { +; Prove that (n >= -2) =\=> (n / 2 >= 0). +; CHECK: Determining loop execution counts for: @test04neg +; CHECK: Loop %header: Predicated backedge-taken count is (0 smax (1 + %n.div.2)) +entry: + %cmp1 = icmp sge i32 %n, -2 + %n.div.2 = sdiv i32 %n, 2 + call void(i1, ...) 
@llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i32 %indvar, 1 + %exitcond = icmp sge i32 %n.div.2, %indvar + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @testext01(i32 %a, i32 %n) nounwind { +; Prove that (n > 1) ===> (n / 2 > 0). +; CHECK: Determining loop execution counts for: @testext01 +; CHECK: Loop %header: Predicated backedge-taken count is (-1 + (sext i32 %n.div.2 to i64)) +entry: + %cmp1 = icmp sgt i32 %n, 1 + %n.div.2 = sdiv i32 %n, 2 + %n.div.2.ext = sext i32 %n.div.2 to i64 + call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i64 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i64 %indvar, 1 + %exitcond = icmp sgt i64 %n.div.2.ext, %indvar.next + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @testext01neg(i32 %a, i32 %n) nounwind { +; Prove that (n > 0) =\=> (n / 2 > 0). +; CHECK: Determining loop execution counts for: @testext01neg +; CHECK: Loop %header: Predicated backedge-taken count is (-1 + (1 smax (sext i32 %n.div.2 to i64))) +entry: + %cmp1 = icmp sgt i32 %n, 0 + %n.div.2 = sdiv i32 %n, 2 + %n.div.2.ext = sext i32 %n.div.2 to i64 + call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i64 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i64 %indvar, 1 + %exitcond = icmp sgt i64 %n.div.2.ext, %indvar.next + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @testext02(i32 %a, i32 %n) nounwind { +; Prove that (n >= 2) ===> (n / 2 > 0). 
+; CHECK: Determining loop execution counts for: @testext02 +; CHECK: Loop %header: Predicated backedge-taken count is (-1 + (sext i32 %n.div.2 to i64)) +entry: + %cmp1 = icmp sge i32 %n, 2 + %n.div.2 = sdiv i32 %n, 2 + %n.div.2.ext = sext i32 %n.div.2 to i64 + call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i64 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i64 %indvar, 1 + %exitcond = icmp sgt i64 %n.div.2.ext, %indvar.next + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @testext02neg(i32 %a, i32 %n) nounwind { +; Prove that (n >= 1) =\=> (n / 2 > 0). +; CHECK: Determining loop execution counts for: @testext02neg +; CHECK: Loop %header: Predicated backedge-taken count is (-1 + (1 smax (sext i32 %n.div.2 to i64))) +entry: + %cmp1 = icmp sge i32 %n, 1 + %n.div.2 = sdiv i32 %n, 2 + %n.div.2.ext = sext i32 %n.div.2 to i64 + call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i64 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i64 %indvar, 1 + %exitcond = icmp sgt i64 %n.div.2.ext, %indvar.next + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @testext03(i32 %a, i32 %n) nounwind { +; Prove that (n > -2) ===> (n / 2 >= 0). +; TODO: We should be able to prove that (n > -2) ===> (n / 2 >= 0). +; CHECK: Determining loop execution counts for: @testext03 +; CHECK: Loop %header: Predicated backedge-taken count is (1 + (sext i32 %n.div.2 to i64)) +entry: + %cmp1 = icmp sgt i32 %n, -2 + %n.div.2 = sdiv i32 %n, 2 + %n.div.2.ext = sext i32 %n.div.2 to i64 + call void(i1, ...) 
@llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i64 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i64 %indvar, 1 + %exitcond = icmp sge i64 %n.div.2.ext, %indvar + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @testext03neg(i32 %a, i32 %n) nounwind { +; Prove that (n > -3) =\=> (n / 2 >= 0). +; CHECK: Determining loop execution counts for: @testext03neg +; CHECK: Loop %header: Predicated backedge-taken count is (0 smax (1 + (sext i32 %n.div.2 to i64))) +entry: + %cmp1 = icmp sgt i32 %n, -3 + %n.div.2 = sdiv i32 %n, 2 + %n.div.2.ext = sext i32 %n.div.2 to i64 + call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i64 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i64 %indvar, 1 + %exitcond = icmp sge i64 %n.div.2.ext, %indvar + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @testext04(i32 %a, i32 %n) nounwind { +; Prove that (n >= -1) ===> (n / 2 >= 0). +; CHECK: Determining loop execution counts for: @testext04 +; CHECK: Loop %header: Predicated backedge-taken count is (1 + (sext i32 %n.div.2 to i64)) +entry: + %cmp1 = icmp sge i32 %n, -1 + %n.div.2 = sdiv i32 %n, 2 + %n.div.2.ext = sext i32 %n.div.2 to i64 + call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i64 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i64 %indvar, 1 + %exitcond = icmp sge i64 %n.div.2.ext, %indvar + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @testext04neg(i32 %a, i32 %n) nounwind { +; Prove that (n >= -2) =\=> (n / 2 >= 0). 
+; CHECK: Determining loop execution counts for: @testext04neg +; CHECK: Loop %header: Predicated backedge-taken count is (0 smax (1 + (sext i32 %n.div.2 to i64))) +entry: + %cmp1 = icmp sge i32 %n, -2 + %n.div.2 = sdiv i32 %n, 2 + %n.div.2.ext = sext i32 %n.div.2 to i64 + call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i64 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i64 %indvar, 1 + %exitcond = icmp sge i64 %n.div.2.ext, %indvar + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} +