diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -9699,6 +9699,32 @@
   // Balance the types.
   if (getTypeSizeInBits(LHS->getType()) <
       getTypeSizeInBits(FoundLHS->getType())) {
+    // For unsigned and equality predicates, try to prove that both found
+    // operands fit into narrow unsigned range. If so, try to prove facts in
+    // narrow types.
+    // The fit check deliberately uses non-recursive reasoning only:
+    // isKnownPredicate can recurse back into implication queries, which is
+    // expensive and can make cached results depend on query order. Pointer
+    // operands are conservatively left to the widening path below.
+    if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
+        !FoundRHS->getType()->isPointerTy()) {
+      auto *NarrowType = LHS->getType();
+      auto *WideType = FoundLHS->getType();
+      auto BitWidth = getTypeSizeInBits(NarrowType);
+      const SCEV *MaxValue = getZeroExtendExpr(
+          getConstant(APInt::getMaxValue(BitWidth)), WideType);
+      if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
+                                          MaxValue) &&
+          isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
+                                          MaxValue)) {
+        const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
+        const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
+        if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, TruncFoundLHS,
+                                       TruncFoundRHS, Context))
+          return true;
+      }
+    }
+
     if (CmpInst::isSigned(Pred)) {
       LHS = getSignExtendExpr(LHS, FoundLHS->getType());
       RHS = getSignExtendExpr(RHS, FoundLHS->getType());
diff --git a/llvm/test/Analysis/ScalarEvolution/srem.ll b/llvm/test/Analysis/ScalarEvolution/srem.ll
--- a/llvm/test/Analysis/ScalarEvolution/srem.ll
+++ b/llvm/test/Analysis/ScalarEvolution/srem.ll
@@ -29,7 +29,7 @@
 ; CHECK-NEXT: %add = add nsw i32 %2, %call
 ; CHECK-NEXT: --> (%2 + %call) U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.cond: Variant }
 ; CHECK-NEXT: %inc = add nsw i32 %i.0, 1
-; CHECK-NEXT: --> {1,+,1}<%for.cond> U: [1,0) S: [1,0) Exits: (1 + %width) LoopDispositions: { %for.cond: Computable }
+; CHECK-NEXT: --> {1,+,1}<%for.cond> U: full-set S: full-set Exits: (1 + %width) LoopDispositions: { %for.cond: Computable }
 ; CHECK-NEXT: Determining loop execution counts for: @_Z4loopi
 ; CHECK-NEXT: Loop %for.cond: backedge-taken count is %width
 ; CHECK-NEXT: Loop %for.cond: max backedge-taken count is -1
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1316,4 +1316,45 @@
   });
 }
 
+TEST_F(ScalarEvolutionsTest, ProveImplicationViaNarrowing) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M = parseAssemblyString(
+      "define i32 @foo(i32 %start, i32* %q) { "
+      "entry: "
+      " %wide.start = zext i32 %start to i64 "
+      " br label %loop "
+      "loop: "
+      " %wide.iv = phi i64 [%wide.start, %entry], [%wide.iv.next, %backedge] "
+      " %iv = phi i32 [%start, %entry], [%iv.next, %backedge] "
+      " %cond = icmp eq i64 %wide.iv, 0 "
+      " br i1 %cond, label %exit, label %backedge "
+      "backedge: "
+      " %iv.next = add i32 %iv, -1 "
+      " %index = zext i32 %iv.next to i64 "
+      " %load.addr = getelementptr i32, i32* %q, i64 %index "
+      " %stop = load i32, i32* %load.addr "
+      " %loop.cond = icmp eq i32 %stop, 0 "
+      " %wide.iv.next = add nsw i64 %wide.iv, -1 "
+      " br i1 %loop.cond, label %loop, label %failure "
+      "exit: "
+      " ret i32 0 "
+      "failure: "
+      " unreachable "
+      "} ",
+      Err, C);
+
+  ASSERT_TRUE(M && "Could not parse module?");
+  ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
+
+  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    auto *IV = SE.getSCEV(getInstructionByName(F, "iv"));
+    auto *Zero = SE.getZero(IV->getType());
+    auto *Backedge = getInstructionByName(F, "iv.next")->getParent();
+    ASSERT_TRUE(Backedge);
+    EXPECT_TRUE(SE.isBasicBlockEntryGuardedByCond(Backedge, ICmpInst::ICMP_UGT,
+                                                  IV, Zero));
+  });
+}
+
 } // end namespace llvm