Index: include/llvm/Analysis/ScalarEvolution.h
===================================================================
--- include/llvm/Analysis/ScalarEvolution.h
+++ include/llvm/Analysis/ScalarEvolution.h
@@ -978,6 +978,20 @@
 
   /// Test whether the condition described by Pred, LHS, and RHS is true
   /// whenever the condition described by Pred, FoundLHS, and FoundRHS is
+  /// true. Here LHS is an operation that includes FoundLHS as one of its
+  /// arguments. Only ICMP_SGT and ICMP_SLT are currently handled.
+  bool isImpliedViaOperations(ICmpInst::Predicate Pred,
+                              const SCEV *LHS, const SCEV *RHS,
+                              const SCEV *FoundLHS,
+                              const SCEV *FoundRHS);
+
+  /// Test whether the condition described by Pred, LHS, and RHS is true.
+  /// Use only simple non-recursive types of checks, such as range analysis etc.
+  bool isKnownViaSimpleReasoning(ICmpInst::Predicate Pred,
+                                 const SCEV *LHS, const SCEV *RHS);
+
+  /// Test whether the condition described by Pred, LHS, and RHS is true
+  /// whenever the condition described by Pred, FoundLHS, and FoundRHS is
   /// true.
   bool isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, const SCEV *LHS,
                                    const SCEV *RHS, const SCEV *FoundLHS,
@@ -1123,6 +1137,9 @@
   /// return true. For pointer types, this is the pointer-sized integer type.
   Type *getEffectiveSCEVType(Type *Ty) const;
 
+  /// Returns the wider of the types Ty1 and Ty2 (Ty1 if equally wide).
+  Type *getWiderType(Type *Ty1, Type *Ty2) const;
+
   /// Return true if the SCEV is a scAddRecExpr or it contains
   /// scAddRecExpr. The result will be cached in HasRecMap.
   ///
Index: lib/Analysis/ScalarEvolution.cpp
===================================================================
--- lib/Analysis/ScalarEvolution.cpp
+++ lib/Analysis/ScalarEvolution.cpp
@@ -3418,6 +3418,10 @@
   return getDataLayout().getIntPtrType(Ty);
 }
 
+Type *ScalarEvolution::getWiderType(Type *Ty1, Type *Ty2) const {
+  return getTypeSizeInBits(Ty1) >= getTypeSizeInBits(Ty2) ? Ty1 : Ty2;
+}
+
 const SCEV *ScalarEvolution::getCouldNotCompute() {
   return CouldNotCompute.get();
 }
@@ -8491,19 +8495,151 @@
   llvm_unreachable("covered switch fell through?!");
 }
 
-bool
-ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
+bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
                                              const SCEV *LHS, const SCEV *RHS,
                                              const SCEV *FoundLHS,
                                              const SCEV *FoundRHS) {
-  auto IsKnownPredicateFull =
-      [this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
-    return isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
-           IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
-           IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
-           isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
+  // We only want to work with ICMP_SGT comparison so far.
+  // TODO: Extend to ICMP_UGT?
+  if (Pred == ICmpInst::ICMP_SLT) {
+    Pred = ICmpInst::ICMP_SGT;
+    std::swap(LHS, RHS);
+    std::swap(FoundLHS, FoundRHS);
+  }
+  if (Pred != ICmpInst::ICMP_SGT)
+    return false;
+
+  // Look through a sign extension. Only sext is stripped because it keeps
+  // the signed order that the ICMP_SGT rules below reason about.
+  auto GetOpFromSExt = [&](const SCEV *S) {
+    if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
+      return Ext->getOperand();
+    return S;
+  };
+
+  // Acquire values from extensions. The unstripped FoundLHS is kept so that
+  // context queries are made with operands of the original type.
+  auto *OrigFoundLHS = FoundLHS;
+  LHS = GetOpFromSExt(LHS);
+  FoundLHS = GetOpFromSExt(FoundLHS);
+
+  // Returns true if S1 > S2 can be proved trivially or using the found
+  // context (OrigFoundLHS > FoundRHS). Only ICMP_SGT is passed to
+  // isImpliedCondOperandsHelper: it assumes that the found operands satisfy
+  // the predicate it is given, and ICMP_SGT is the only one known to hold.
+  auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
+    return isKnownViaSimpleReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
+           isImpliedCondOperandsHelper(ICmpInst::ICMP_SGT, S1, S2,
+                                       OrigFoundLHS, FoundRHS);
   };
 
+  if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
+    // The operands of LHS are compared with RHS below, so bail if stripping
+    // a sext left LHS with a different width than RHS.
+    if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType()))
+      return false;
+
+    // The rules below reason about LHS = LL + LR only; any further operand
+    // has an unknown sign and would make them unsound.
+    if (LHSAddExpr->getNumOperands() != 2)
+      return false;
+
+    // Should not overflow.
+    if (!LHSAddExpr->hasNoSignedWrap())
+      return false;
+
+    auto *LL = LHSAddExpr->getOperand(0);
+    auto *LR = LHSAddExpr->getOperand(1);
+    auto *MinusOne = getNegativeSCEV(getOne(RHS->getType()));
+
+    // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
+    // S1 >= 0 is asked as S1 > -1 so that only ICMP_SGT queries are made.
+    auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
+      return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
+    };
+    // Try to prove the following rule:
+    // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
+    // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
+    if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
+      return true;
+  } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
+    Value *LL, *LR;
+    // FIXME: Once we have SDiv implemented, we can get rid of this matching.
+    using namespace llvm::PatternMatch;
+    if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
+      // Rules for division.
+      // We are going to perform some comparisons with Denominator and its
+      // derivative expressions. In general case, creating a SCEV for it may
+      // lead to a complex analysis of the entire graph, and in particular it
+      // can request trip count recalculation for the same loop. This would
+      // cache as SCEVCouldNotCompute to avoid the infinite recursion. To
+      // avoid this, we only want to create SCEVs that are constants in this
+      // section. So we bail if Denominator is not a constant.
+      if (!isa<ConstantInt>(LR))
+        return false;
+
+      auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
+
+      // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
+      // then a SCEV for the numerator already exists and matches FoundLHS.
+      auto *Numerator = getExistingSCEV(LL);
+
+      // Make sure that it exists and has the same type.
+      if (!Numerator || Numerator->getType() != FoundLHS->getType())
+        return false;
+
+      // Make sure that the numerator matches FoundLHS and the denominator is
+      // positive.
+      if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
+        return false;
+
+      // Given that:
+      // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
+      // The found context only bounds FoundRHS from above, so it cannot help
+      // prove the lower bounds on FoundRHS needed below; consulting it could
+      // also recurse back into this rule. Only simple reasoning is used.
+      auto *WTy = getWiderType(Denominator->getType(), FoundRHS->getType());
+      auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
+      auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
+
+      // Try to prove the following rule:
+      // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
+      // Then FoundLHS >= Denominator, hence LHS >= 1.
+      auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
+      if (isKnownNonPositive(RHS) &&
+          isKnownViaSimpleReasoning(ICmpInst::ICMP_SGT, FoundRHSExt,
+                                    DenomMinusTwo))
+        return true;
+
+      // Try to prove the following rule:
+      // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
+      // Then FoundLHS > -Denominator, and as sdiv truncates, LHS >= 0.
+      auto *MinusOne = getNegativeSCEV(getOne(WTy));
+      auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
+      if (isKnownNegative(RHS) &&
+          isKnownViaSimpleReasoning(ICmpInst::ICMP_SGT, FoundRHSExt,
+                                    NegDenomMinusOne))
+        return true;
+    }
+  }
+
+  return false;
+}
+
+bool
+ScalarEvolution::isKnownViaSimpleReasoning(ICmpInst::Predicate Pred,
+                                           const SCEV *LHS, const SCEV *RHS) {
+  return isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
+         IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
+         IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
+         isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
+}
+
+bool
+ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
+                                             const SCEV *LHS, const SCEV *RHS,
+                                             const SCEV *FoundLHS,
+                                             const SCEV *FoundRHS) {
   switch (Pred) {
   default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
   case ICmpInst::ICMP_EQ:
@@ -8513,30 +8649,34 @@
     break;
   case ICmpInst::ICMP_SLT:
   case ICmpInst::ICMP_SLE:
-    if (IsKnownPredicateFull(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
-        IsKnownPredicateFull(ICmpInst::ICMP_SGE, RHS, FoundRHS))
+    if (isKnownViaSimpleReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
+        isKnownViaSimpleReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
       return true;
     break;
   case ICmpInst::ICMP_SGT:
   case ICmpInst::ICMP_SGE:
-    if (IsKnownPredicateFull(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
-        IsKnownPredicateFull(ICmpInst::ICMP_SLE, RHS, FoundRHS))
+    if (isKnownViaSimpleReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
+        isKnownViaSimpleReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
       return true;
     break;
   case ICmpInst::ICMP_ULT:
   case ICmpInst::ICMP_ULE:
-    if (IsKnownPredicateFull(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
-        IsKnownPredicateFull(ICmpInst::ICMP_UGE, RHS, FoundRHS))
+    if (isKnownViaSimpleReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
+        isKnownViaSimpleReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
       return true;
     break;
   case ICmpInst::ICMP_UGT:
   case ICmpInst::ICMP_UGE:
-    if (IsKnownPredicateFull(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
-        IsKnownPredicateFull(ICmpInst::ICMP_ULE, RHS, FoundRHS))
+    if (isKnownViaSimpleReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
+        isKnownViaSimpleReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
       return true;
     break;
   }
 
+  // Maybe it can be proved via operations?
+  if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
+    return true;
+
   return false;
 }
 
Index: test/Analysis/ScalarEvolution/scev-division.ll
===================================================================
--- /dev/null
+++ test/Analysis/ScalarEvolution/scev-division.ll
@@ -0,0 +1,334 @@
+; RUN: opt < %s -analyze -scalar-evolution | FileCheck %s
+
+declare void @llvm.experimental.guard(i1, ...)
+
+define void @test01(i32 %a, i32 %n) nounwind {
+; Prove that (n > 1) ===> (n / 2 > 0).
+; CHECK: Determining loop execution counts for: @test01
+; CHECK: Loop %header: Predicated backedge-taken count is (-1 + %n.div.2)
+entry:
+  %cmp1 = icmp sgt i32 %n, 1
+  %n.div.2 = sdiv i32 %n, 2
+  call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ]
+  br label %header
+
+header:
+  %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ]
+  %indvar.next = add i32 %indvar, 1
+  %exitcond = icmp sgt i32 %n.div.2, %indvar.next
+  br i1 %exitcond, label %header, label %exit
+
+exit:
+  ret void
+}
+
+define void @test01neg(i32 %a, i32 %n) nounwind {
+; Prove that (n > 0) =\=> (n / 2 > 0).
+; CHECK: Determining loop execution counts for: @test01neg
+; CHECK: Loop %header: Predicated backedge-taken count is (-1 + (1 smax %n.div.2))
+entry:
+  %cmp1 = icmp sgt i32 %n, 0
+  %n.div.2 = sdiv i32 %n, 2
+  call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ]
+  br label %header
+
+header:
+  %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ]
+  %indvar.next = add i32 %indvar, 1
+  %exitcond = icmp sgt i32 %n.div.2, %indvar.next
+  br i1 %exitcond, label %header, label %exit
+
+exit:
+  ret void
+}
+
+define void @test02(i32 %a, i32 %n) nounwind {
+; Prove that (n >= 2) ===> (n / 2 > 0).
+; CHECK: Determining loop execution counts for: @test02
+; CHECK: Loop %header: Predicated backedge-taken count is (-1 + %n.div.2)
+entry:
+  %cmp1 = icmp sge i32 %n, 2
+  %n.div.2 = sdiv i32 %n, 2
+  call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ]
+  br label %header
+
+header:
+  %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ]
+  %indvar.next = add i32 %indvar, 1
+  %exitcond = icmp sgt i32 %n.div.2, %indvar.next
+  br i1 %exitcond, label %header, label %exit
+
+exit:
+  ret void
+}
+
+define void @test02neg(i32 %a, i32 %n) nounwind {
+; Prove that (n >= 1) =\=> (n / 2 > 0).
+; CHECK: Determining loop execution counts for: @test02neg
+; CHECK: Loop %header: Predicated backedge-taken count is (-1 + (1 smax %n.div.2))
+entry:
+  %cmp1 = icmp sge i32 %n, 1
+  %n.div.2 = sdiv i32 %n, 2
+  call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ]
+  br label %header
+
+header:
+  %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ]
+  %indvar.next = add i32 %indvar, 1
+  %exitcond = icmp sgt i32 %n.div.2, %indvar.next
+  br i1 %exitcond, label %header, label %exit
+
+exit:
+  ret void
+}
+
+define void @test03(i32 %a, i32 %n) nounwind {
+; Prove that (n > -2) ===> (n / 2 >= 0).
+; (n > -2) means n >= -1, so n / 2 is either 0 or positive.
+; CHECK: Determining loop execution counts for: @test03
+; CHECK: Loop %header: Predicated backedge-taken count is (1 + %n.div.2)
+entry:
+  %cmp1 = icmp sgt i32 %n, -2
+  %n.div.2 = sdiv i32 %n, 2
+  call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ]
+  br label %header
+
+header:
+  %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ]
+  %indvar.next = add i32 %indvar, 1
+  %exitcond = icmp sge i32 %n.div.2, %indvar
+  br i1 %exitcond, label %header, label %exit
+
+exit:
+  ret void
+}
+
+define void @test03neg(i32 %a, i32 %n) nounwind {
+; Prove that (n > -3) =\=> (n / 2 >= 0).
+; CHECK: Determining loop execution counts for: @test03neg
+; CHECK: Loop %header: Predicated backedge-taken count is (0 smax (1 + %n.div.2))
+entry:
+  %cmp1 = icmp sgt i32 %n, -3
+  %n.div.2 = sdiv i32 %n, 2
+  call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ]
+  br label %header
+
+header:
+  %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ]
+  %indvar.next = add i32 %indvar, 1
+  %exitcond = icmp sge i32 %n.div.2, %indvar
+  br i1 %exitcond, label %header, label %exit
+
+exit:
+  ret void
+}
+
+define void @test04(i32 %a, i32 %n) nounwind {
+; Prove that (n >= -1) ===> (n / 2 >= 0).
+; CHECK: Determining loop execution counts for: @test04
+; CHECK: Loop %header: Predicated backedge-taken count is (1 + %n.div.2)
+entry:
+  %cmp1 = icmp sge i32 %n, -1
+  %n.div.2 = sdiv i32 %n, 2
+  call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ]
+  br label %header
+
+header:
+  %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ]
+  %indvar.next = add i32 %indvar, 1
+  %exitcond = icmp sge i32 %n.div.2, %indvar
+  br i1 %exitcond, label %header, label %exit
+
+exit:
+  ret void
+}
+
+define void @test04neg(i32 %a, i32 %n) nounwind {
+; Prove that (n >= -2) =\=> (n / 2 >= 0).
+; CHECK: Determining loop execution counts for: @test04neg
+; CHECK: Loop %header: Predicated backedge-taken count is (0 smax (1 + %n.div.2))
+entry:
+  %cmp1 = icmp sge i32 %n, -2
+  %n.div.2 = sdiv i32 %n, 2
+  call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ]
+  br label %header
+
+header:
+  %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ]
+  %indvar.next = add i32 %indvar, 1
+  %exitcond = icmp sge i32 %n.div.2, %indvar
+  br i1 %exitcond, label %header, label %exit
+
+exit:
+  ret void
+}
+
+define void @testext01(i32 %a, i32 %n) nounwind {
+; Prove that (n > 1) ===> (n / 2 > 0).
+; CHECK: Determining loop execution counts for: @testext01
+; CHECK: Loop %header: Predicated backedge-taken count is (-1 + (sext i32 %n.div.2 to i64))
+entry:
+  %cmp1 = icmp sgt i32 %n, 1
+  %n.div.2 = sdiv i32 %n, 2
+  %n.div.2.ext = sext i32 %n.div.2 to i64
+  call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ]
+  br label %header
+
+header:
+  %indvar = phi i64 [ %indvar.next, %header ], [ 0, %entry ]
+  %indvar.next = add i64 %indvar, 1
+  %exitcond = icmp sgt i64 %n.div.2.ext, %indvar.next
+  br i1 %exitcond, label %header, label %exit
+
+exit:
+  ret void
+}
+
+define void @testext01neg(i32 %a, i32 %n) nounwind {
+; Prove that (n > 0) =\=> (n / 2 > 0).
+; CHECK: Determining loop execution counts for: @testext01neg
+; CHECK: Loop %header: Predicated backedge-taken count is (-1 + (1 smax (sext i32 %n.div.2 to i64)))
+entry:
+  %cmp1 = icmp sgt i32 %n, 0
+  %n.div.2 = sdiv i32 %n, 2
+  %n.div.2.ext = sext i32 %n.div.2 to i64
+  call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ]
+  br label %header
+
+header:
+  %indvar = phi i64 [ %indvar.next, %header ], [ 0, %entry ]
+  %indvar.next = add i64 %indvar, 1
+  %exitcond = icmp sgt i64 %n.div.2.ext, %indvar.next
+  br i1 %exitcond, label %header, label %exit
+
+exit:
+  ret void
+}
+
+define void @testext02(i32 %a, i32 %n) nounwind {
+; Prove that (n >= 2) ===> (n / 2 > 0).
+; CHECK: Determining loop execution counts for: @testext02
+; CHECK: Loop %header: Predicated backedge-taken count is (-1 + (sext i32 %n.div.2 to i64))
+entry:
+  %cmp1 = icmp sge i32 %n, 2
+  %n.div.2 = sdiv i32 %n, 2
+  %n.div.2.ext = sext i32 %n.div.2 to i64
+  call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ]
+  br label %header
+
+header:
+  %indvar = phi i64 [ %indvar.next, %header ], [ 0, %entry ]
+  %indvar.next = add i64 %indvar, 1
+  %exitcond = icmp sgt i64 %n.div.2.ext, %indvar.next
+  br i1 %exitcond, label %header, label %exit
+
+exit:
+  ret void
+}
+
+define void @testext02neg(i32 %a, i32 %n) nounwind {
+; Prove that (n >= 1) =\=> (n / 2 > 0).
+; CHECK: Determining loop execution counts for: @testext02neg
+; CHECK: Loop %header: Predicated backedge-taken count is (-1 + (1 smax (sext i32 %n.div.2 to i64)))
+entry:
+  %cmp1 = icmp sge i32 %n, 1
+  %n.div.2 = sdiv i32 %n, 2
+  %n.div.2.ext = sext i32 %n.div.2 to i64
+  call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ]
+  br label %header
+
+header:
+  %indvar = phi i64 [ %indvar.next, %header ], [ 0, %entry ]
+  %indvar.next = add i64 %indvar, 1
+  %exitcond = icmp sgt i64 %n.div.2.ext, %indvar.next
+  br i1 %exitcond, label %header, label %exit
+
+exit:
+  ret void
+}
+
+define void @testext03(i32 %a, i32 %n) nounwind {
+; Prove that (n > -2) ===> (n / 2 >= 0).
+; (n > -2) means n >= -1, so n / 2 is either 0 or positive.
+; CHECK: Determining loop execution counts for: @testext03
+; CHECK: Loop %header: Predicated backedge-taken count is (1 + (sext i32 %n.div.2 to i64))
+entry:
+  %cmp1 = icmp sgt i32 %n, -2
+  %n.div.2 = sdiv i32 %n, 2
+  %n.div.2.ext = sext i32 %n.div.2 to i64
+  call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ]
+  br label %header
+
+header:
+  %indvar = phi i64 [ %indvar.next, %header ], [ 0, %entry ]
+  %indvar.next = add i64 %indvar, 1
+  %exitcond = icmp sge i64 %n.div.2.ext, %indvar
+  br i1 %exitcond, label %header, label %exit
+
+exit:
+  ret void
+}
+
+define void @testext03neg(i32 %a, i32 %n) nounwind {
+; Prove that (n > -3) =\=> (n / 2 >= 0).
+; CHECK: Determining loop execution counts for: @testext03neg
+; CHECK: Loop %header: Predicated backedge-taken count is (0 smax (1 + (sext i32 %n.div.2 to i64)))
+entry:
+  %cmp1 = icmp sgt i32 %n, -3
+  %n.div.2 = sdiv i32 %n, 2
+  %n.div.2.ext = sext i32 %n.div.2 to i64
+  call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ]
+  br label %header
+
+header:
+  %indvar = phi i64 [ %indvar.next, %header ], [ 0, %entry ]
+  %indvar.next = add i64 %indvar, 1
+  %exitcond = icmp sge i64 %n.div.2.ext, %indvar
+  br i1 %exitcond, label %header, label %exit
+
+exit:
+  ret void
+}
+
+define void @testext04(i32 %a, i32 %n) nounwind {
+; Prove that (n >= -1) ===> (n / 2 >= 0).
+; CHECK: Determining loop execution counts for: @testext04
+; CHECK: Loop %header: Predicated backedge-taken count is (1 + (sext i32 %n.div.2 to i64))
+entry:
+  %cmp1 = icmp sge i32 %n, -1
+  %n.div.2 = sdiv i32 %n, 2
+  %n.div.2.ext = sext i32 %n.div.2 to i64
+  call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ]
+  br label %header
+
+header:
+  %indvar = phi i64 [ %indvar.next, %header ], [ 0, %entry ]
+  %indvar.next = add i64 %indvar, 1
+  %exitcond = icmp sge i64 %n.div.2.ext, %indvar
+  br i1 %exitcond, label %header, label %exit
+
+exit:
+  ret void
+}
+
+define void @testext04neg(i32 %a, i32 %n) nounwind {
+; Prove that (n >= -2) =\=> (n / 2 >= 0).
+; CHECK: Determining loop execution counts for: @testext04neg
+; CHECK: Loop %header: Predicated backedge-taken count is (0 smax (1 + (sext i32 %n.div.2 to i64)))
+entry:
+  %cmp1 = icmp sge i32 %n, -2
+  %n.div.2 = sdiv i32 %n, 2
+  %n.div.2.ext = sext i32 %n.div.2 to i64
+  call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ]
+  br label %header
+
+header:
+  %indvar = phi i64 [ %indvar.next, %header ], [ 0, %entry ]
+  %indvar.next = add i64 %indvar, 1
+  %exitcond = icmp sge i64 %n.div.2.ext, %indvar
+  br i1 %exitcond, label %header, label %exit
+
+exit:
+  ret void
+}
+