diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -142,7 +142,7 @@
   truncateIdKind(IdKind::Symbol, counts);
   truncateIdKind(IdKind::Local, counts);
   removeInequalityRange(counts.getNumIneqs(), getNumInequalities());
-  removeInequalityRange(counts.getNumEqs(), getNumEqualities());
+  removeEqualityRange(counts.getNumEqs(), getNumEqualities());
 }
 
 unsigned IntegerRelation::insertId(IdKind kind, unsigned pos, unsigned num) {
diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
--- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
@@ -1202,3 +1202,13 @@
   EXPECT_TRUE(poly3.containsPointNoLocal({0, 0}));
   EXPECT_FALSE(poly3.containsPointNoLocal({1, 0}));
 }
+
+TEST(IntegerPolyhedronTest, truncateEqualityRegressionTest) {
+  // Regression test: IntegerRelation::truncate used to pass the equality
+  // counts to removeInequalityRange instead of removeEqualityRange.
+  IntegerRelation set(1);
+  IntegerRelation::CountsSnapshot snapshot = set.getCounts();
+  set.addEquality({1, 0});
+  set.truncate(snapshot);
+  EXPECT_EQ(set.getNumEqualities(), 0u);
+}