diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -10236,13 +10236,31 @@
   // We want to avoid hurting the compile time with analysis of too big trees.
   if (Depth > MaxSCEVOperationsImplicationDepth)
     return false;
-  // We only want to work with ICMP_SGT comparison so far.
-  // TODO: Extend to ICMP_UGT?
-  if (Pred == ICmpInst::ICMP_SLT) {
-    Pred = ICmpInst::ICMP_SGT;
+
+  // We only want to work with GT comparison so far.
+  if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) {
+    Pred = CmpInst::getSwappedPredicate(Pred);
     std::swap(LHS, RHS);
     std::swap(FoundLHS, FoundRHS);
   }
+
+  // For unsigned, try to reduce it to corresponding signed comparison.
+  if (Pred == ICmpInst::ICMP_UGT)
+    // We can replace unsigned predicate with its signed counterpart if all
+    // involved values are non-negative.
+    // TODO: We could have better support for unsigned.
+    if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
+      // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
+      // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
+      // use this fact to prove that LHS and RHS are non-negative.
+      const SCEV *MinusOne = getNegativeSCEV(getOne(LHS->getType()));
+      if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
+                                FoundRHS) &&
+          isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
+                                FoundRHS))
+        Pred = ICmpInst::ICMP_SGT;
+    }
+
   if (Pred != ICmpInst::ICMP_SGT)
     return false;
 
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1283,4 +1283,37 @@
   });
 }
 
+TEST_F(ScalarEvolutionsTest, UnsignedIsImpliedViaOperations) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M =
+      parseAssemblyString("define void @foo(i32* %p1, i32* %p2) { "
+                          "entry: "
+                          "  %x = load i32, i32* %p1, !range !0 "
+                          "  %cond = icmp ne i32 %x, 0 "
+                          "  br i1 %cond, label %guarded, label %exit "
+                          "guarded: "
+                          "  %y = add i32 %x, -1 "
+                          "  ret void "
+                          "exit: "
+                          "  ret void "
+                          "} "
+                          "!0 = !{i32 0, i32 2147483647}",
+                          Err, C);
+
+  ASSERT_TRUE(M && "Could not parse module?");
+  ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
+
+  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    auto *X = SE.getSCEV(getInstructionByName(F, "x"));
+    auto *Y = SE.getSCEV(getInstructionByName(F, "y"));
+    auto *Guarded = getInstructionByName(F, "y")->getParent();
+    ASSERT_TRUE(Guarded);
+    EXPECT_TRUE(
+        SE.isBasicBlockEntryGuardedByCond(Guarded, ICmpInst::ICMP_ULT, Y, X));
+    EXPECT_TRUE(
+        SE.isBasicBlockEntryGuardedByCond(Guarded, ICmpInst::ICMP_UGT, X, Y));
+  });
+}
+
 } // end namespace llvm
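
Reviewer note, not part of the patch: the rewrite above relies on the fact that signed and unsigned greater-than agree whenever both operands are non-negative, which is what lets a proven ICMP_UGT fact feed the existing ICMP_SGT-only implication logic. A minimal standalone sketch of that invariant (plain C++, no LLVM APIs, illustrative only):

// Standalone illustration (assumed example, not LLVM code): for 32-bit values,
// "greater than" gives the same answer under signed and unsigned
// interpretation as long as both operands are non-negative (sign bit clear).
#include <cassert>
#include <cstdint>
#include <initializer_list>

static bool SignedGT(int32_t A, int32_t B) { return A > B; }

static bool UnsignedGT(int32_t A, int32_t B) {
  return static_cast<uint32_t>(A) > static_cast<uint32_t>(B);
}

int main() {
  // Non-negative operands: the two predicates always agree, so a UGT fact can
  // be rewritten as an SGT fact (this is what the patch does for FoundLHS and
  // FoundRHS after checking isKnownNonNegative on both).
  for (int32_t A : {0, 1, 5, INT32_MAX})
    for (int32_t B : {0, 1, 5, INT32_MAX})
      assert(SignedGT(A, B) == UnsignedGT(A, B));

  // With a negative operand the predicates can disagree, which is why the
  // patch first proves LHS > -1 and RHS > -1 before switching Pred to SGT.
  assert(UnsignedGT(-1, 0) && !SignedGT(-1, 0));
  return 0;
}

The added unit test exercises the same situation: the !range metadata bounds %x to [0, 2147483647), and %x is known non-zero in %guarded, so both %x and %y = %x - 1 are provably non-negative and the unsigned queries can be answered through the signed implication path.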