Index: llvm/include/llvm/Analysis/ScalarEvolution.h =================================================================== --- llvm/include/llvm/Analysis/ScalarEvolution.h +++ llvm/include/llvm/Analysis/ScalarEvolution.h @@ -1099,6 +1099,15 @@ const SCEV *S, const Loop *L, SmallPtrSetImpl<const SCEVPredicate *> &Preds); + /// Compute \p LHS - \p RHS and returns the result as an APInt if it is a + /// constant, and None if it isn't. + /// + /// This is intended to be a cheaper version of getMinusSCEV. We can be + /// frugal here since we just bail out of actually constructing and + /// canonicalizing an expression in the cases where the result isn't going + /// to be a constant. + Optional<APInt> computeConstantDifference(const SCEV *LHS, const SCEV *RHS); + private: /// A CallbackVH to arrange for ScalarEvolution to be notified whenever a /// Value is deleted. @@ -1752,15 +1761,6 @@ bool splitBinaryAdd(const SCEV *Expr, const SCEV *&L, const SCEV *&R, SCEV::NoWrapFlags &Flags); - /// Compute \p LHS - \p RHS and returns the result as an APInt if it is a - /// constant, and None if it isn't. - /// - /// This is intended to be a cheaper version of getMinusSCEV. We can be - /// frugal here since we just bail out of actually constructing and - /// canonicalizing an expression in the cases where the result isn't going - /// to be a constant. - Optional<APInt> computeConstantDifference(const SCEV *LHS, const SCEV *RHS); - /// Drop memoized information computed for S. void forgetMemoizedResults(const SCEV *S); Index: llvm/lib/Analysis/ScalarEvolution.cpp =================================================================== --- llvm/lib/Analysis/ScalarEvolution.cpp +++ llvm/lib/Analysis/ScalarEvolution.cpp @@ -9833,6 +9833,10 @@ // We avoid subtracting expressions here because this function is usually // fairly deep in the call stack (i.e. is called many times). + // X - X = 0. 
+ if (More == Less) + return APInt(getTypeSizeInBits(More->getType()), 0); + if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) { const auto *LAR = cast<SCEVAddRecExpr>(Less); const auto *MAR = cast<SCEVAddRecExpr>(More); Index: llvm/unittests/Analysis/ScalarEvolutionTest.cpp =================================================================== --- llvm/unittests/Analysis/ScalarEvolutionTest.cpp +++ llvm/unittests/Analysis/ScalarEvolutionTest.cpp @@ -1678,5 +1678,67 @@ "} "); } +TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) { + LLVMContext C; + SMDiagnostic Err; + std::unique_ptr<Module> M = parseAssemblyString( + "define void @foo(i32 %sz, i32 %pp) { " + "entry: " + " %v0 = add i32 %pp, 0 " + " %v3 = add i32 %pp, 3 " + " br label %loop.body " + "loop.body: " + " %iv = phi i32 [ %iv.next, %loop.body ], [ 0, %entry ] " + " %xa = add nsw i32 %iv, %v0 " + " %yy = add nsw i32 %iv, %v3 " + " %xb = sub nsw i32 %yy, 3 " + " %iv.next = add nsw i32 %iv, 1 " + " %cmp = icmp sle i32 %iv.next, %sz " + " br i1 %cmp, label %loop.body, label %exit " + "exit: " + " ret void " + "} ", + Err, C); + + ASSERT_TRUE(M && "Could not parse module?"); + ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!"); + + runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) { + auto *ScevV0 = SE.getSCEV(getInstructionByName(F, "v0")); // %pp + auto *ScevV3 = SE.getSCEV(getInstructionByName(F, "v3")); // (3 + %pp) + auto *ScevIV = SE.getSCEV(getInstructionByName(F, "iv")); // {0,+,1} + auto *ScevXA = SE.getSCEV(getInstructionByName(F, "xa")); // {%pp,+,1} + auto *ScevYY = SE.getSCEV(getInstructionByName(F, "yy")); // {(3 + %pp),+,1} + auto *ScevXB = SE.getSCEV(getInstructionByName(F, "xb")); // {%pp,+,1} + auto *ScevIVNext = SE.getSCEV(getInstructionByName(F, "iv.next")); // {1,+,1} + + auto diff = [&SE](const SCEV *LHS, const SCEV *RHS) -> Optional<int> { + auto ConstantDiffOrNone = SE.computeConstantDifference(LHS, RHS); + if (!ConstantDiffOrNone) + return None; + + auto ExtDiff = ConstantDiffOrNone->getSExtValue(); + int 
Diff = ExtDiff; + assert(Diff == ExtDiff && "Integer overflow"); + return Diff; + }; + + EXPECT_EQ(diff(ScevV3, ScevV0), 3); + EXPECT_EQ(diff(ScevV0, ScevV3), -3); + EXPECT_EQ(diff(ScevV0, ScevV0), 0); + EXPECT_EQ(diff(ScevV3, ScevV3), 0); + EXPECT_EQ(diff(ScevIV, ScevIV), 0); + EXPECT_EQ(diff(ScevXA, ScevXB), 0); + EXPECT_EQ(diff(ScevXA, ScevYY), -3); + EXPECT_EQ(diff(ScevYY, ScevXB), 3); + EXPECT_EQ(diff(ScevIV, ScevIVNext), -1); + EXPECT_EQ(diff(ScevIVNext, ScevIV), 1); + EXPECT_EQ(diff(ScevIVNext, ScevIVNext), 0); + EXPECT_EQ(diff(ScevV0, ScevIV), None); + EXPECT_EQ(diff(ScevIVNext, ScevV3), None); + EXPECT_EQ(diff(ScevYY, ScevV3), None); + }); +} + } // end anonymous namespace } // end namespace llvm