Index: include/llvm/Analysis/ScalarEvolution.h =================================================================== --- include/llvm/Analysis/ScalarEvolution.h +++ include/llvm/Analysis/ScalarEvolution.h @@ -978,6 +978,15 @@ /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by Pred, FoundLHS, and FoundRHS is + /// true. Here LHS is an operation that includes FoundLHS as one of its + /// arguments. + bool isImpliedViaOperations(ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS, + const SCEV *FoundLHS, + const SCEV *FoundRHS); + + /// Test whether the condition described by Pred, LHS, and RHS is true + /// whenever the condition described by Pred, FoundLHS, and FoundRHS is /// true. bool isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, Index: lib/Analysis/ScalarEvolution.cpp =================================================================== --- lib/Analysis/ScalarEvolution.cpp +++ lib/Analysis/ScalarEvolution.cpp @@ -8491,6 +8491,114 @@ llvm_unreachable("covered switch fell through?!"); } +bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS, + const SCEV *FoundLHS, + const SCEV *FoundRHS) { + // We only want to work with ICMP_SGT comparison so far. + // TODO: Extend to ICMP_UGT? + if (Pred == ICmpInst::ICMP_SLT) + return isImpliedViaOperations(ICmpInst::ICMP_SGT, RHS, LHS, FoundRHS, + FoundLHS); + if (Pred != ICmpInst::ICMP_SGT) + return false; + + auto GetOpFromExt = [&](const SCEV *S) { + if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S)) + return Ext->getOperand(); + return S; + }; + + auto ExtendType = [&](const SCEV *S, Type *T) { + if (S->getType() == T) + return S; + return getSignExtendExpr(S, T); + }; + + auto MaxType = [&](const SCEV *S1, const SCEV *S2) { + return S1->getType()->getIntegerBitWidth() > + S2->getType()->getIntegerBitWidth() + ? 
S1->getType() + : S2->getType(); + }; + + // Acquire values from extensions. + auto *OrigLHS = LHS; + auto *OrigFoundLHS = FoundLHS; + LHS = GetOpFromExt(LHS); + FoundLHS = GetOpFromExt(FoundLHS); + + if (auto *Operation = dyn_cast<SCEVAddExpr>(LHS)) { + // Should not overflow. + if (!Operation->hasNoSignedWrap()) + return false; + // The rule below looks at two operands only; reject wider n-ary sums. + if (Operation->getNumOperands() != 2) + return false; + auto *Op1 = Operation->getOperand(0); + auto *Op2 = Operation->getOperand(1); + // Try to prove the following rule: + // (LHS = Op1 + Op2) && (Op1 >= 0 || Op2 >= 0) && (Op1 > RHS) && + // (Op2 > RHS) => (LHS > RHS). + auto *ZeroRHS = getConstant(RHS->getType(), 0u, true); + auto *Op1Ext = ExtendType(Op1, OrigLHS->getType()); + auto *Op2Ext = ExtendType(Op2, OrigLHS->getType()); + // At least one of the operands should be proved to be non-negative. + // We try to prove it trivially or using the known predicate. + if (isKnownNonNegative(Op1) || isKnownNonNegative(Op2) || + isImpliedCondOperandsHelper(ICmpInst::ICMP_SGE, Op1Ext, ZeroRHS, + OrigFoundLHS, FoundRHS) || + isImpliedCondOperandsHelper(ICmpInst::ICMP_SGE, Op2Ext, ZeroRHS, + OrigFoundLHS, FoundRHS)) + // Now we know that at least one of the operands is non-negative. It + // implies that Op1 + Op2 >= min(Op1, Op2). If we prove that + // (Op1 > RHS) && (Op2 > RHS), then (Op1 + Op2) >= min(Op1, Op2) > RHS. + return (isKnownPredicate(Pred, Op1Ext, RHS) || + isImpliedCondOperandsHelper(Pred, Op1Ext, RHS, OrigFoundLHS, + FoundRHS)) && + (isKnownPredicate(Pred, Op2Ext, RHS) || + isImpliedCondOperandsHelper(Pred, Op2Ext, RHS, OrigFoundLHS, + FoundRHS)); + } else if (auto *Operation = dyn_cast<SCEVUnknown>(LHS)) { + Value *Op1, *Op2; + // FIXME: Once we have SDiv implemented, we can get rid of this matching. + using namespace llvm::PatternMatch; + if (match(Operation->getValue(), m_SDiv(m_Value(Op1), m_Value(Op2)))) { + // Rules for division. 
+ auto *Num = getSCEV(Op1); + auto *Denum = getSCEV(Op2); + auto *Ty1 = MaxType(Num, FoundLHS); + auto *NumExt = ExtendType(Num, Ty1); + auto *FoundLHSExt = ExtendType(FoundLHS, Ty1); + + if (!HasSameValue(NumExt, FoundLHSExt) || !isKnownPositive(Denum)) + return false; + + // Given that FoundLHS > FoundRHS, LHS = FoundLHS / Denum, Denum > 0. + auto *Ty2 = MaxType(Denum, FoundRHS); + auto *DenumExt = ExtendType(Denum, Ty2); + auto *FoundRHSExt = ExtendType(FoundRHS, Ty2); + + // Try to prove the following rule: + // (Denum <= FoundRHS + 1) && (RHS <= 0) => (LHS > RHS). + auto *One = getConstant(Ty2, 1u, true); + auto *Next = getAddExpr(FoundRHSExt, One); + if (isKnownNonPositive(RHS) && + isKnownPredicate(ICmpInst::ICMP_SLE, DenumExt, Next)) + return true; + + // Try to prove the following rule: + // (-Denum <= FoundRHS) && (RHS < 0) => (LHS > RHS). + auto *NegDenum = getNegativeSCEV(DenumExt); + if (isKnownNegative(RHS) && + isKnownPredicate(ICmpInst::ICMP_SLE, NegDenum, FoundRHSExt)) + return true; + } + } + + return false; +} + bool ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, @@ -8537,6 +8645,10 @@ break; } + // Maybe it can be proved via operations? + if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS)) + return true; + return false; } Index: test/Analysis/ScalarEvolution/scev-division.ll =================================================================== --- test/Analysis/ScalarEvolution/scev-division.ll +++ test/Analysis/ScalarEvolution/scev-division.ll @@ -0,0 +1,333 @@ +; RUN: opt < %s -analyze -scalar-evolution | FileCheck %s + +declare void @llvm.experimental.guard(i1, ...) + +define void @test01(i32 %a, i32 %n) nounwind { +; Prove that (n > 1) ===> (n / 2 > 0). +; CHECK: Determining loop execution counts for: @test01 +; CHECK: Loop %header: Predicated backedge-taken count is (-1 + %n.div.2) +entry: + %cmp1 = icmp sgt i32 %n, 1 + %n.div.2 = sdiv i32 %n, 2 + call void(i1, ...) 
@llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i32 %indvar, 1 + %exitcond = icmp sgt i32 %n.div.2, %indvar.next + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @test01neg(i32 %a, i32 %n) nounwind { +; Prove that (n > 0) =\=> (n / 2 > 0). +; CHECK: Determining loop execution counts for: @test01neg +; CHECK: Loop %header: Predicated backedge-taken count is (-1 + (1 smax %n.div.2)) +entry: + %cmp1 = icmp sgt i32 %n, 0 + %n.div.2 = sdiv i32 %n, 2 + call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i32 %indvar, 1 + %exitcond = icmp sgt i32 %n.div.2, %indvar.next + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @test02(i32 %a, i32 %n) nounwind { +; Prove that (n >= 2) ===> (n / 2 > 0). +; CHECK: Determining loop execution counts for: @test02 +; CHECK: Loop %header: Predicated backedge-taken count is (-1 + %n.div.2) +entry: + %cmp1 = icmp sge i32 %n, 2 + %n.div.2 = sdiv i32 %n, 2 + call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i32 %indvar, 1 + %exitcond = icmp sgt i32 %n.div.2, %indvar.next + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @test02neg(i32 %a, i32 %n) nounwind { +; Prove that (n >= 1) =\=> (n / 2 > 0). +; CHECK: Determining loop execution counts for: @test02neg +; CHECK: Loop %header: Predicated backedge-taken count is (-1 + (1 smax %n.div.2)) +entry: + %cmp1 = icmp sge i32 %n, 1 + %n.div.2 = sdiv i32 %n, 2 + call void(i1, ...) 
@llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i32 %indvar, 1 + %exitcond = icmp sgt i32 %n.div.2, %indvar.next + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @test03(i32 %a, i32 %n) nounwind { +; Prove that (n > -2) ===> (n / 2 >= 0). +; TODO: We should be able to prove that (n > -2) ===> (n / 2 >= 0). +; CHECK: Determining loop execution counts for: @test03 +; CHECK: Loop %header: Predicated backedge-taken count is (1 + %n.div.2) +entry: + %cmp1 = icmp sgt i32 %n, -2 + %n.div.2 = sdiv i32 %n, 2 + call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i32 %indvar, 1 + %exitcond = icmp sge i32 %n.div.2, %indvar + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @test03neg(i32 %a, i32 %n) nounwind { +; Prove that (n > -3) =\=> (n / 2 >= 0). +; CHECK: Determining loop execution counts for: @test03neg +; CHECK: Loop %header: Predicated backedge-taken count is (0 smax (1 + %n.div.2)) +entry: + %cmp1 = icmp sgt i32 %n, -3 + %n.div.2 = sdiv i32 %n, 2 + call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i32 %indvar, 1 + %exitcond = icmp sge i32 %n.div.2, %indvar + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @test04(i32 %a, i32 %n) nounwind { +; Prove that (n >= -1) ===> (n / 2 >= 0). +; CHECK: Determining loop execution counts for: @test04 +; CHECK: Loop %header: Predicated backedge-taken count is (1 + %n.div.2) +entry: + %cmp1 = icmp sge i32 %n, -1 + %n.div.2 = sdiv i32 %n, 2 + call void(i1, ...) 
@llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i32 %indvar, 1 + %exitcond = icmp sge i32 %n.div.2, %indvar + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @test04neg(i32 %a, i32 %n) nounwind { +; Prove that (n >= -2) =\=> (n / 2 >= 0). +; CHECK: Determining loop execution counts for: @test04neg +; CHECK: Loop %header: Predicated backedge-taken count is (0 smax (1 + %n.div.2)) +entry: + %cmp1 = icmp sge i32 %n, -2 + %n.div.2 = sdiv i32 %n, 2 + call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i32 %indvar, 1 + %exitcond = icmp sge i32 %n.div.2, %indvar + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @testext01(i32 %a, i32 %n) nounwind { +; Prove that (n > 1) ===> (n / 2 > 0). +; CHECK: Determining loop execution counts for: @testext01 +; CHECK: Loop %header: Predicated backedge-taken count is (-1 + (sext i32 %n.div.2 to i64)) +entry: + %cmp1 = icmp sgt i32 %n, 1 + %n.div.2 = sdiv i32 %n, 2 + %n.div.2.ext = sext i32 %n.div.2 to i64 + call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i64 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i64 %indvar, 1 + %exitcond = icmp sgt i64 %n.div.2.ext, %indvar.next + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @testext01neg(i32 %a, i32 %n) nounwind { +; Prove that (n > 0) =\=> (n / 2 > 0). +; CHECK: Determining loop execution counts for: @testext01neg +; CHECK: Loop %header: Predicated backedge-taken count is (-1 + (1 smax (sext i32 %n.div.2 to i64))) +entry: + %cmp1 = icmp sgt i32 %n, 0 + %n.div.2 = sdiv i32 %n, 2 + %n.div.2.ext = sext i32 %n.div.2 to i64 + call void(i1, ...) 
@llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i64 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i64 %indvar, 1 + %exitcond = icmp sgt i64 %n.div.2.ext, %indvar.next + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @testext02(i32 %a, i32 %n) nounwind { +; Prove that (n >= 2) ===> (n / 2 > 0). +; CHECK: Determining loop execution counts for: @testext02 +; CHECK: Loop %header: Predicated backedge-taken count is (-1 + (sext i32 %n.div.2 to i64)) +entry: + %cmp1 = icmp sge i32 %n, 2 + %n.div.2 = sdiv i32 %n, 2 + %n.div.2.ext = sext i32 %n.div.2 to i64 + call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i64 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i64 %indvar, 1 + %exitcond = icmp sgt i64 %n.div.2.ext, %indvar.next + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @testext02neg(i32 %a, i32 %n) nounwind { +; Prove that (n >= 1) =\=> (n / 2 > 0). +; CHECK: Determining loop execution counts for: @testext02neg +; CHECK: Loop %header: Predicated backedge-taken count is (-1 + (1 smax (sext i32 %n.div.2 to i64))) +entry: + %cmp1 = icmp sge i32 %n, 1 + %n.div.2 = sdiv i32 %n, 2 + %n.div.2.ext = sext i32 %n.div.2 to i64 + call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i64 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i64 %indvar, 1 + %exitcond = icmp sgt i64 %n.div.2.ext, %indvar.next + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @testext03(i32 %a, i32 %n) nounwind { +; Prove that (n > -2) ===> (n / 2 >= 0). +; TODO: We should be able to prove that (n > -2) ===> (n / 2 >= 0). 
+; CHECK: Determining loop execution counts for: @testext03 +; CHECK: Loop %header: Predicated backedge-taken count is (1 + (sext i32 %n.div.2 to i64)) +entry: + %cmp1 = icmp sgt i32 %n, -2 + %n.div.2 = sdiv i32 %n, 2 + %n.div.2.ext = sext i32 %n.div.2 to i64 + call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i64 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i64 %indvar, 1 + %exitcond = icmp sge i64 %n.div.2.ext, %indvar + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @testext03neg(i32 %a, i32 %n) nounwind { +; Prove that (n > -3) =\=> (n / 2 >= 0). +; CHECK: Determining loop execution counts for: @testext03neg +; CHECK: Loop %header: Predicated backedge-taken count is (0 smax (1 + (sext i32 %n.div.2 to i64))) +entry: + %cmp1 = icmp sgt i32 %n, -3 + %n.div.2 = sdiv i32 %n, 2 + %n.div.2.ext = sext i32 %n.div.2 to i64 + call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i64 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i64 %indvar, 1 + %exitcond = icmp sge i64 %n.div.2.ext, %indvar + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @testext04(i32 %a, i32 %n) nounwind { +; Prove that (n >= -1) ===> (n / 2 >= 0). +; CHECK: Determining loop execution counts for: @testext04 +; CHECK: Loop %header: Predicated backedge-taken count is (1 + (sext i32 %n.div.2 to i64)) +entry: + %cmp1 = icmp sge i32 %n, -1 + %n.div.2 = sdiv i32 %n, 2 + %n.div.2.ext = sext i32 %n.div.2 to i64 + call void(i1, ...) 
@llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i64 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i64 %indvar, 1 + %exitcond = icmp sge i64 %n.div.2.ext, %indvar + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} + +define void @testext04neg(i32 %a, i32 %n) nounwind { +; Prove that (n >= -2) =\=> (n / 2 >= 0). +; CHECK: Determining loop execution counts for: @testext04neg +; CHECK: Loop %header: Predicated backedge-taken count is (0 smax (1 + (sext i32 %n.div.2 to i64))) +entry: + %cmp1 = icmp sge i32 %n, -2 + %n.div.2 = sdiv i32 %n, 2 + %n.div.2.ext = sext i32 %n.div.2 to i64 + call void(i1, ...) @llvm.experimental.guard(i1 %cmp1) [ "deopt"() ] + br label %header + +header: + %indvar = phi i64 [ %indvar.next, %header ], [ 0, %entry ] + %indvar.next = add i64 %indvar, 1 + %exitcond = icmp sge i64 %n.div.2.ext, %indvar + br i1 %exitcond, label %header, label %exit + +exit: + ret void +} \ No newline at end of file