diff --git a/mlir/include/mlir/Analysis/PresburgerSet.h b/mlir/include/mlir/Analysis/PresburgerSet.h
--- a/mlir/include/mlir/Analysis/PresburgerSet.h
+++ b/mlir/include/mlir/Analysis/PresburgerSet.h
@@ -60,7 +60,7 @@
   /// Return the intersection of this set and the given set.
   PresburgerSet intersect(const PresburgerSet &set) const;
 
-  /// Return true if the set contains the given point, or false otherwise.
+  /// Return true if the set contains the given point, and false otherwise.
   bool containsPoint(ArrayRef<int64_t> point) const;
 
   /// Print the set's internal state.
@@ -74,6 +74,10 @@
   /// return `this \ set`.
   PresburgerSet subtract(const PresburgerSet &set) const;
 
+  /// Return true if this set is equal to the given set, and false otherwise.
+  /// The implementation asserts that the two sets have compatible dimensions.
+  bool isEqual(const PresburgerSet &set) const;
+
   /// Return a universe set of the specified type that contains all points.
   static PresburgerSet getUniverse(unsigned nDim = 0, unsigned nSym = 0);
   /// Return an empty set of the specified type that contains no points.
diff --git a/mlir/lib/Analysis/PresburgerSet.cpp b/mlir/lib/Analysis/PresburgerSet.cpp
--- a/mlir/lib/Analysis/PresburgerSet.cpp
+++ b/mlir/lib/Analysis/PresburgerSet.cpp
@@ -281,6 +281,23 @@
   return result;
 }
 
+/// Two sets S and T are equal iff S contains T and T contains S.
+/// By "S contains T", we mean that S is a superset of or equal to T.
+///
+/// S contains T iff T \ S is empty, since if T \ S contains a
+/// point then this is a point that is contained in T but not S.
+///
+/// Therefore, S is equal to T iff S \ T and T \ S are both empty.
+bool PresburgerSet::isEqual(const PresburgerSet &set) const {
+  assertDimensionsCompatible(set, *this);
+  // Check this \ set first so a mismatch skips the second subtraction.
+  PresburgerSet thisMinusSet = subtract(set);
+  if (!thisMinusSet.isIntegerEmpty())
+    return false;
+  PresburgerSet setMinusThis = set.subtract(*this);
+  return setMinusThis.isIntegerEmpty();
+}
+
 /// Return true if all the sets in the union are known to be integer empty,
 /// false otherwise.
 bool PresburgerSet::isIntegerEmpty() const {
diff --git a/mlir/unittests/Analysis/PresburgerSetTest.cpp b/mlir/unittests/Analysis/PresburgerSetTest.cpp
--- a/mlir/unittests/Analysis/PresburgerSetTest.cpp
+++ b/mlir/unittests/Analysis/PresburgerSetTest.cpp
@@ -6,10 +6,11 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file contains tests for PresburgerSet. Each test works by computing
-// an operation (union, intersection, subtract, or complement) on two sets
-// and checking, for a set of points, that the resulting set contains the point
-// iff the result is supposed to contain it.
+// This file contains tests for PresburgerSet. The tests for union,
+// intersection, subtract, and complement work by computing the operation on
+// two sets and checking, for a set of points, that the resulting set contains
+// the point iff the result is supposed to contain it. The test for isEqual just
+// checks if the result for two sets matches the expected result.
 //
 //===----------------------------------------------------------------------===//
 
@@ -34,7 +35,7 @@
 }
 
 /// Compute the intersection of s and t, and check that each of the given points
-/// belongs to the intersection iff it belongs to both of s and t.
+/// belongs to the intersection iff it belongs to both s and t.
 static void testIntersectAtPoints(PresburgerSet s, PresburgerSet t,
                                   ArrayRef<SmallVector<int64_t, 4>> points) {
   PresburgerSet intersection = s.intersect(t);
@@ -521,4 +522,76 @@
                       {1, 10}});
 }
 
+TEST(SetTest, isEqual) {
+  // set = [2, 8] U [10, 20].
+  PresburgerSet universe = PresburgerSet::getUniverse(1);
+  PresburgerSet emptySet = PresburgerSet::getEmptySet(1);
+  PresburgerSet set =
+      makeSetFromFACs(1, {
+                             makeFACFromIneqs(1, {{1, -2},    // x >= 2.
+                                                  {-1, 8}}),  // x <= 8.
+                             makeFACFromIneqs(1, {{1, -10},   // x >= 10.
+                                                  {-1, 20}}), // x <= 20.
+                         });
+
+  // universe != emptySet.
+  EXPECT_FALSE(universe.isEqual(emptySet));
+  // emptySet != universe.
+  EXPECT_FALSE(emptySet.isEqual(universe));
+  // emptySet == emptySet.
+  EXPECT_TRUE(emptySet.isEqual(emptySet));
+  // universe == universe.
+  EXPECT_TRUE(universe.isEqual(universe));
+
+  // universe U emptySet == universe.
+  EXPECT_TRUE(universe.unionSet(emptySet).isEqual(universe));
+  // universe U universe == universe.
+  EXPECT_TRUE(universe.unionSet(universe).isEqual(universe));
+  // emptySet U emptySet == emptySet.
+  EXPECT_TRUE(emptySet.unionSet(emptySet).isEqual(emptySet));
+  // universe U emptySet != emptySet.
+  EXPECT_FALSE(universe.unionSet(emptySet).isEqual(emptySet));
+  // universe U universe != emptySet.
+  EXPECT_FALSE(universe.unionSet(universe).isEqual(emptySet));
+  // emptySet U emptySet != universe.
+  EXPECT_FALSE(emptySet.unionSet(emptySet).isEqual(universe));
+
+  // set \ set == emptySet.
+  EXPECT_TRUE(set.subtract(set).isEqual(emptySet));
+  // set == set.
+  EXPECT_TRUE(set.isEqual(set));
+  // set U (universe \ set) == universe.
+  EXPECT_TRUE(set.unionSet(set.complement()).isEqual(universe));
+  // set U (universe \ set) != set.
+  EXPECT_FALSE(set.unionSet(set.complement()).isEqual(set));
+  // set != set U (universe \ set).
+  EXPECT_FALSE(set.isEqual(set.unionSet(set.complement())));
+
+  // square is one unit taller than rect.
+  PresburgerSet square =
+      makeSetFromFACs(2, {makeFACFromIneqs(2, {
+                                                  {1, 0, -2}, // x >= 2.
+                                                  {0, 1, -2}, // y >= 2.
+                                                  {-1, 0, 9}, // x <= 9.
+                                                  {0, -1, 9}  // y <= 9.
+                                              })});
+  PresburgerSet rect =
+      makeSetFromFACs(2, {makeFACFromIneqs(2, {
+                                                  {1, 0, -2}, // x >= 2.
+                                                  {0, 1, -2}, // y >= 2.
+                                                  {-1, 0, 9}, // x <= 9.
+                                                  {0, -1, 8}  // y <= 8.
+                                              })});
+  // square contains the points with y = 9, which rect does not.
+  EXPECT_FALSE(square.isEqual(rect));
+  // A set united with its complement is the universe, whatever the set.
+  PresburgerSet universeSquare = square.unionSet(square.complement());
+  PresburgerSet universeRect = rect.unionSet(rect.complement());
+  EXPECT_TRUE(universeRect.isEqual(universeSquare));
+  EXPECT_FALSE(universeRect.isEqual(rect));
+  EXPECT_FALSE(universeSquare.isEqual(square));
+  // Different sets have different complements.
+  EXPECT_FALSE(rect.complement().isEqual(square.complement()));
+}
+
 } // namespace mlir