diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -153,6 +153,10 @@
   /// Returns such a point if one exists, or an empty Optional otherwise.
   Optional<SmallVector<int64_t, 8>> findIntegerSample() const;
 
+  /// Returns true if the given point satisfies the constraints, or false
+  /// otherwise.
+  bool containsPoint(ArrayRef<int64_t> point) const;
+
   // Clones this object.
   std::unique_ptr<FlatAffineConstraints> clone() const;
 
diff --git a/mlir/include/mlir/Analysis/Presburger/Set.h b/mlir/include/mlir/Analysis/Presburger/Set.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Analysis/Presburger/Set.h
@@ -0,0 +1,102 @@
+//===- Set.h - MLIR PresburgerSet Class -------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// A class to represent unions of FlatAffineConstraints.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_PRESBURGER_SET_H
+#define MLIR_ANALYSIS_PRESBURGER_SET_H
+
+#include "mlir/Analysis/AffineStructures.h"
+
+namespace mlir {
+
+/// This class can represent a union of FlatAffineConstraints, with support for
+/// union, intersection, subtraction and complement operations, as well as
+/// sampling.
+///
+/// The FlatAffineConstraints (FACs) are stored in a vector, and the set
+/// represents the union of these FACs.
+class PresburgerSet {
+public:
+  PresburgerSet(unsigned nDim = 0, unsigned nSym = 0)
+      : nDim(nDim), nSym(nSym) {}
+  PresburgerSet(FlatAffineConstraints cs);
+
+  /// Return the number of FACs in the union.
+  unsigned getNumFACs() const;
+
+  /// Return the number of real dimensions.
+  unsigned getNumDims() const;
+
+  /// Return the number of symbolic dimensions.
+  unsigned getNumSyms() const;
+
+  /// Returns a reference to the list of FlatAffineConstraints.
+  ArrayRef<FlatAffineConstraints> getFlatAffineConstraints() const;
+
+  /// Returns the FlatAffineConstraints at the specified index.
+  const FlatAffineConstraints &getFlatAffineConstraints(unsigned index) const;
+
+  /// Add the given FlatAffineConstraints to the union.
+  void addFlatAffineConstraints(FlatAffineConstraints cs);
+
+  /// Union the given set with the current set.
+  void unionSet(const PresburgerSet &set);
+
+  /// Intersect the given set with the current set.
+  void intersectSet(const PresburgerSet &set);
+
+  /// Returns true if the set contains the given point, or false otherwise.
+  bool containsPoint(ArrayRef<int64_t> point) const;
+
+  /// Print the set's internal state.
+  void print(raw_ostream &os) const;
+  void dump() const;
+
+  /// Returns the complement of the given set.
+  static PresburgerSet complement(const PresburgerSet &set);
+
+  /// Subtract the given set from the current set.
+  void subtract(const PresburgerSet &set);
+
+  /// Return the set difference fac - set; `fac` is restored before returning.
+  static PresburgerSet subtract(FlatAffineConstraints &fac,
+                                const PresburgerSet &set);
+
+  /// Return the set difference fac - set.
+  static PresburgerSet subtract(FlatAffineConstraints &&fac,
+                                const PresburgerSet &set);
+
+  /// Return a universe set of the specified type that contains all points.
+  static PresburgerSet makeUniverse(unsigned nDim = 0, unsigned nSym = 0);
+
+  /// Returns true if all the sets in the union are known to be integer empty,
+  /// false otherwise.
+  bool isIntegerEmpty() const;
+
+  /// Find an integer sample from the given set. This should not be called if
+  /// any of the FACs in the union are unbounded.
+  llvm::Optional<SmallVector<int64_t, 8>> findIntegerSample();
+
+private:
+  /// Number of identifiers corresponding to real dimensions.
+  unsigned nDim;
+
+  /// Number of symbolic dimensions, unknown but constant for analysis, as in
+  /// FlatAffineConstraints.
+  unsigned nSym;
+
+  /// The list of flatAffineConstraints that this set is the union of.
+  SmallVector<FlatAffineConstraints, 2> flatAffineConstraints;
+};
+
+} // namespace mlir
+
+#endif // MLIR_ANALYSIS_PRESBURGER_SET_H
diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h
--- a/mlir/include/mlir/Analysis/Presburger/Simplex.h
+++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h
@@ -169,6 +169,9 @@
   /// Rollback to a snapshot. This invalidates all later snapshots.
   void rollback(unsigned snapshot);
 
+  /// Add all the constraints from the given FlatAffineConstraints.
+  void addFlatAffineConstraints(const FlatAffineConstraints &fac);
+
   /// Compute the maximum or minimum value of the given row, depending on
   /// direction. The specified row is never pivoted.
   ///
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -1058,6 +1058,33 @@
   return Simplex(*this).findIntegerSample();
 }
 
+/// Helper to evaluate the value of an affine expression at a point.
+/// The expression is a list of coefficients for the dimensions followed by the
+/// constant term.
+static int64_t valueAt(ArrayRef<int64_t> expr, ArrayRef<int64_t> point) {
+  assert(expr.size() == 1 + point.size() &&
+         "Dimensionalities of point and expresion don't match!");
+  int64_t value = expr.back();
+  for (unsigned i = 0; i < point.size(); ++i)
+    value += expr[i] * point[i];
+  return value;
+}
+
+/// A point satisfies an equality iff the value of the equality at the
+/// point is zero, and it satisfies an inequality iff the value of the
+/// inequality at that point is non-negative.
+bool FlatAffineConstraints::containsPoint(ArrayRef<int64_t> point) const {
+  for (unsigned i = 0; i < getNumEqualities(); ++i) {
+    if (valueAt(getEquality(i), point) != 0)
+      return false;
+  }
+  for (unsigned i = 0; i < getNumInequalities(); ++i) {
+    if (valueAt(getInequality(i), point) < 0)
+      return false;
+  }
+  return true;
+}
+
 /// Tightens inequalities given that we are dealing with integer spaces. This is
 /// analogous to the GCD test but applied to inequalities. The constant term can
 /// be reduced to the preceding multiple of the GCD of the coefficients, i.e.,
diff --git a/mlir/lib/Analysis/Presburger/CMakeLists.txt b/mlir/lib/Analysis/Presburger/CMakeLists.txt
--- a/mlir/lib/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/lib/Analysis/Presburger/CMakeLists.txt
@@ -1,4 +1,6 @@
 add_mlir_library(MLIRPresburger
+  ../AffineStructures.cpp
   Simplex.cpp
   Matrix.cpp
+  Set.cpp
   )
diff --git a/mlir/lib/Analysis/Presburger/Set.cpp b/mlir/lib/Analysis/Presburger/Set.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Analysis/Presburger/Set.cpp
@@ -0,0 +1,284 @@
+//===- Set.cpp - MLIR PresburgerSet Class ---------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Presburger/Set.h"
+#include "mlir/Analysis/Presburger/Simplex.h"
+
+using namespace mlir;
+
+PresburgerSet::PresburgerSet(FlatAffineConstraints fac)
+    : nDim(fac.getNumDimIds()), nSym(fac.getNumSymbolIds()) {
+  addFlatAffineConstraints(std::move(fac));
+}
+
+unsigned PresburgerSet::getNumFACs() const {
+  return flatAffineConstraints.size();
+}
+
+unsigned PresburgerSet::getNumDims() const { return nDim; }
+
+unsigned PresburgerSet::getNumSyms() const { return nSym; }
+
+ArrayRef<FlatAffineConstraints>
+PresburgerSet::getFlatAffineConstraints() const {
+  return flatAffineConstraints;
+}
+
+const FlatAffineConstraints &
+PresburgerSet::getFlatAffineConstraints(unsigned index) const {
+  assert(index < flatAffineConstraints.size() && "index out of bounds!");
+  return flatAffineConstraints[index];
+}
+
+/// Assert that the FlatAffineConstraints and PresburgerSet live in
+/// compatible spaces.
+static void assertDimensionsCompatible(const FlatAffineConstraints &fac,
+                                       const PresburgerSet &set) {
+  assert(fac.getNumDimIds() == set.getNumDims() &&
+         "Number of dimensions of the FlatAffineConstraints and PresburgerSet"
+         " do not match!");
+  assert(fac.getNumSymbolIds() == set.getNumSyms() &&
+         "Number of symbols of the FlatAffineConstraints and PresburgerSet"
+         " do not match!");
+}
+
+/// Assert that the two PresburgerSets live in compatible spaces.
+static void assertDimensionsCompatible(const PresburgerSet &set1,
+                                       const PresburgerSet &set2) {
+  assert(set1.getNumDims() == set2.getNumDims() &&
+         "Number of dimensions of the PresburgerSets do not match!");
+  assert(set1.getNumSyms() == set2.getNumSyms() &&
+         "Number of symbols of the PresburgerSets do not match!");
+}
+
+/// Add an FAC to the union.
+void PresburgerSet::addFlatAffineConstraints(FlatAffineConstraints fac) {
+  assertDimensionsCompatible(fac, *this);
+  flatAffineConstraints.push_back(std::move(fac));
+}
+
+/// Union the current set with the given set.
+///
+/// This is accomplished by simply adding all the FACs of the given set to the
+/// current set. Iterating by index lets `set` alias `*this`.
+void PresburgerSet::unionSet(const PresburgerSet &set) {
+  assertDimensionsCompatible(set, *this);
+  for (unsigned i = 0, e = set.getNumFACs(); i < e; ++i)
+    addFlatAffineConstraints(set.getFlatAffineConstraints(i));
+}
+
+/// A point is contained in the union iff any of the parts contain the point.
+bool PresburgerSet::containsPoint(ArrayRef<int64_t> point) const {
+  for (const FlatAffineConstraints &fac : flatAffineConstraints) {
+    if (fac.containsPoint(point))
+      return true;
+  }
+  return false;
+}
+
+PresburgerSet PresburgerSet::makeUniverse(unsigned nDim, unsigned nSym) {
+  PresburgerSet result(nDim, nSym);
+  result.addFlatAffineConstraints(FlatAffineConstraints(nDim, nSym));
+  return result;
+}
+
+/// Compute the intersection of the two sets.
+///
+/// We directly compute (S_1 or S_2 ...) and (T_1 or T_2 ...)
+/// as (S_1 and T_1) or (S_1 and T_2) or ...
+void PresburgerSet::intersectSet(const PresburgerSet &set) {
+  assertDimensionsCompatible(set, *this);
+
+  PresburgerSet result(nDim, nSym);
+  for (const FlatAffineConstraints &csA : flatAffineConstraints) {
+    for (const FlatAffineConstraints &csB : set.flatAffineConstraints) {
+      FlatAffineConstraints intersection(csA);
+      intersection.append(csB);
+      if (!intersection.isEmpty())
+        result.addFlatAffineConstraints(std::move(intersection));
+    }
+  }
+  *this = std::move(result);
+}
+
+/// Return the coefficients negated. An equality e = 0 is decomposed into the
+/// two inequalities e >= 0 and -e >= 0; this computes the coefficients of -e.
+static SmallVector<int64_t, 8> negatedCoeffs(ArrayRef<int64_t> coeffs) {
+  SmallVector<int64_t, 8> negatedCoeffs;
+  for (int64_t coeff : coeffs)
+    negatedCoeffs.emplace_back(-coeff);
+  return negatedCoeffs;
+}
+
+/// Return the complement of the given inequality.
+/// Over the integers, the complement of a_1 x_1 + ... + a_n x_n + c >= 0 is
+/// a_1 x_1 + ... + a_n x_n + c < 0, i.e.,
+/// -a_1 x_1 - ... - a_n x_n - c - 1 >= 0.
+static SmallVector<int64_t, 8> complementIneq(ArrayRef<int64_t> ineq) {
+  assert(!ineq.empty() && "inequality must have a constant term");
+  SmallVector<int64_t, 8> coeffs = negatedCoeffs(ineq);
+  // e < 0 is equivalent to -e - 1 >= 0 since e takes integer values.
+  --coeffs.back();
+  return coeffs;
+}
+
+/// Return the set difference b - s and accumulate the result into `result`.
+/// `simplex` must correspond to b.
+///
+/// In the following, U denotes union, /\ denotes intersection, - denotes set
+/// subtraction and ~ denotes complement.
+/// Let b be the basic set and s = (U_i s_i) be the set. We want b - (U_i s_i).
+///
+/// Let s_i = /\_j s_ij. To compute b - s_i = b /\ ~s_i, we partition ~s_i based
+/// on the first violated constraint:
+/// ~s_i = (~s_i1) U (s_i1 /\ ~s_i2) U (s_i1 /\ s_i2 /\ ~s_i3) U ...
+/// And the required result is (b /\ ~s_i1) U (b /\ s_i1 /\ ~s_i2) U ...
+/// We recurse by subtracting U_{j > i} s_j from each of these parts and
+/// returning the union of the results.
+///
+/// As a heuristic, we first add all of s_i to the simplex and check whether
+/// the intersection is empty. If not, we ask the simplex which constraints of
+/// s_i are redundant (each equality counts as two inequalities) and skip them.
+static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
+                                const PresburgerSet &s, unsigned i,
+                                PresburgerSet &result) {
+  if (i == s.getNumFACs()) {
+    result.addFlatAffineConstraints(b);
+    return;
+  }
+  const FlatAffineConstraints &sI = s.getFlatAffineConstraints(i);
+  unsigned initialSnap = simplex.getSnapshot();
+  unsigned offset = simplex.numConstraints();
+  simplex.addFlatAffineConstraints(sI);
+
+  if (simplex.isEmpty()) {
+    // b /\ s_i is empty, so b - s_i = b. We move directly to i + 1.
+    simplex.rollback(initialSnap);
+    subtractRecursively(b, simplex, s, i + 1, result);
+    return;
+  }
+
+  simplex.detectRedundant();
+  unsigned numSIRows = sI.getNumInequalities() + 2 * sI.getNumEqualities();
+  SmallVector<bool, 8> isMarkedRedundant;
+  for (unsigned j = 0; j < numSIRows; ++j)
+    isMarkedRedundant.push_back(simplex.isMarkedRedundant(offset + j));
+
+  simplex.rollback(initialSnap);
+
+  auto recurseWithInequality = [&, i](ArrayRef<int64_t> ineq) {
+    unsigned snap = simplex.getSnapshot();
+    b.addInequality(ineq);
+    simplex.addInequality(ineq);
+    subtractRecursively(b, simplex, s, i + 1, result);
+    b.removeInequality(b.getNumInequalities() - 1);
+    simplex.rollback(snap);
+  };
+
+  auto processInequality = [&](ArrayRef<int64_t> ineq) {
+    recurseWithInequality(complementIneq(ineq));
+    b.addInequality(ineq);
+    simplex.addInequality(ineq);
+  };
+
+  unsigned originalNumIneqs = b.getNumInequalities();
+  unsigned originalNumEqs = b.getNumEqualities();
+
+  for (unsigned j = 0, e = sI.getNumInequalities(); j < e; ++j) {
+    if (isMarkedRedundant[j])
+      continue;
+    processInequality(sI.getInequality(j));
+  }
+
+  offset = sI.getNumInequalities();
+  for (unsigned j = 0, e = sI.getNumEqualities(); j < e; ++j) {
+    ArrayRef<int64_t> coeffs = sI.getEquality(j);
+    // Same as the above loop for inequalities, done once each for the positive
+    // and negative inequalities.
+    if (!isMarkedRedundant[offset + 2 * j])
+      processInequality(coeffs);
+    if (!isMarkedRedundant[offset + 2 * j + 1])
+      processInequality(negatedCoeffs(coeffs));
+  }
+
+  for (unsigned k = b.getNumInequalities(); k > originalNumIneqs; --k)
+    b.removeInequality(k - 1);
+
+  for (unsigned k = b.getNumEqualities(); k > originalNumEqs; --k)
+    b.removeEquality(k - 1);
+
+  simplex.rollback(initialSnap);
+}
+
+/// Returns the set difference fac - set.
+PresburgerSet PresburgerSet::subtract(FlatAffineConstraints &fac,
+                                      const PresburgerSet &set) {
+  assertDimensionsCompatible(fac, set);
+  if (fac.isEmptyByGCDTest())
+    return PresburgerSet(fac.getNumDimIds(), fac.getNumSymbolIds());
+
+  PresburgerSet result(fac.getNumDimIds(), fac.getNumSymbolIds());
+  Simplex simplex(fac);
+  subtractRecursively(fac, simplex, set, 0, result);
+  return result;
+}
+
+PresburgerSet PresburgerSet::subtract(FlatAffineConstraints &&fac,
+                                      const PresburgerSet &set) {
+  FlatAffineConstraints lvalue(std::move(fac));
+  return subtract(lvalue, set);
+}
+
+PresburgerSet PresburgerSet::complement(const PresburgerSet &set) {
+  // The complement of S is the universe of all points, minus S.
+  return subtract(FlatAffineConstraints(set.getNumDims(), set.getNumSyms()),
+                  set);
+}
+
+/// Subtracts the set from the current set. Each t_i is copied first since
+/// `set` may alias `*this`. We compute (U_i t_i) - (U_i set_i) as
+/// U_i (t_i - U_i set_i).
+void PresburgerSet::subtract(const PresburgerSet &set) {
+  assertDimensionsCompatible(set, *this);
+  PresburgerSet result(nDim, nSym);
+  for (const FlatAffineConstraints &c : flatAffineConstraints)
+    result.unionSet(subtract(FlatAffineConstraints(c), set));
+  *this = std::move(result);
+}
+
+/// Return true if all the sets in the union are known to be integer empty,
+/// false otherwise.
+bool PresburgerSet::isIntegerEmpty() const {
+  assert(nSym == 0 && "isIntegerEmpty is intended for non-symbolic sets");
+  // The set is empty iff all of the disjuncts are empty.
+  for (const FlatAffineConstraints &fac : flatAffineConstraints) {
+    if (!fac.isIntegerEmpty())
+      return false;
+  }
+  return true;
+}
+
+Optional<SmallVector<int64_t, 8>> PresburgerSet::findIntegerSample() {
+  assert(nSym == 0 && "findIntegerSample is intended for non-symbolic sets");
+  // A sample exists iff any of the disjuncts contains a sample.
+  for (const FlatAffineConstraints &fac : flatAffineConstraints) {
+    if (Optional<SmallVector<int64_t, 8>> opt = fac.findIntegerSample())
+      return *opt;
+  }
+  return {};
+}
+
+void PresburgerSet::print(raw_ostream &os) const {
+  os << getNumFACs() << " FlatAffineConstraints:\n";
+  for (const FlatAffineConstraints &fac : flatAffineConstraints) {
+    fac.print(os);
+    os << '\n';
+  }
+}
+
+void PresburgerSet::dump() const { print(llvm::errs()); }
diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -451,6 +451,16 @@
   }
 }
 
+/// Add all the constraints from the given FlatAffineConstraints.
+void Simplex::addFlatAffineConstraints(const FlatAffineConstraints &fac) {
+  assert(fac.getNumIds() == numVariables() &&
+         "FlatAffineConstraints must have same dimensionality as simplex");
+  for (unsigned i = 0, e = fac.getNumInequalities(); i < e; ++i)
+    addInequality(fac.getInequality(i));
+  for (unsigned i = 0, e = fac.getNumEqualities(); i < e; ++i)
+    addEquality(fac.getEquality(i));
+}
+
 Optional<Fraction> Simplex::computeRowOptimum(Direction direction,
                                               unsigned row) {
   // Keep trying to find a pivot for the row in the specified direction.
diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp
--- a/mlir/unittests/Analysis/AffineStructuresTest.cpp
+++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp
@@ -15,17 +15,6 @@
 
 namespace mlir {
 
-/// Evaluate the value of the given affine expression at the specified point.
-/// The expression is a list of coefficients for the dimensions followed by the
-/// constant term.
-int64_t valueAt(ArrayRef<int64_t> expr, ArrayRef<int64_t> point) {
-  assert(expr.size() == 1 + point.size());
-  int64_t value = expr.back();
-  for (unsigned i = 0; i < point.size(); ++i)
-    value += expr[i] * point[i];
-  return value;
-}
-
 /// If 'hasValue' is true, check that findIntegerSample returns a valid sample
 /// for the FlatAffineConstraints fac.
 ///
@@ -41,10 +30,7 @@
     }
   } else {
     ASSERT_TRUE(maybeSample.hasValue());
-    for (unsigned i = 0; i < fac.getNumEqualities(); ++i)
-      EXPECT_EQ(valueAt(fac.getEquality(i), *maybeSample), 0);
-    for (unsigned i = 0; i < fac.getNumInequalities(); ++i)
-      EXPECT_GE(valueAt(fac.getInequality(i), *maybeSample), 0);
+    EXPECT_TRUE(fac.containsPoint(*maybeSample));
   }
 }
 
diff --git a/mlir/unittests/Analysis/Presburger/CMakeLists.txt b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
--- a/mlir/unittests/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
@@ -1,7 +1,8 @@
 add_mlir_unittest(MLIRPresburgerTests
   MatrixTest.cpp
   SimplexTest.cpp
+  SetTest.cpp
 )
 
 target_link_libraries(MLIRPresburgerTests
-  PRIVATE MLIRPresburger)
+  PRIVATE MLIRPresburger MLIRAnalysis)
diff --git a/mlir/unittests/Analysis/Presburger/SetTest.cpp b/mlir/unittests/Analysis/Presburger/SetTest.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Analysis/Presburger/SetTest.cpp
@@ -0,0 +1,510 @@
+//===- SetTest.cpp - Tests for PresburgerSet-------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Presburger/Set.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace mlir {
+
+/// Check at each of `points` that s U t contains the point iff s or t does.
+static void testUnionAtPoints(const PresburgerSet &s, const PresburgerSet &t,
+                              ArrayRef<SmallVector<int64_t, 4>> points) {
+  PresburgerSet unionSet = s;
+  unionSet.unionSet(t);
+  for (const auto &point : points) {
+    bool inS = s.containsPoint(point);
+    bool inT = t.containsPoint(point);
+    bool inUnion = unionSet.containsPoint(point);
+    EXPECT_EQ(inUnion, inS || inT);
+  }
+}
+
+/// Check at each of `points` that s /\ t contains the point iff both s and t
+/// do.
+static void testIntersectAtPoints(const PresburgerSet &s,
+                                  const PresburgerSet &t,
+                                  ArrayRef<SmallVector<int64_t, 4>> points) {
+  PresburgerSet intersection = s;
+  intersection.intersectSet(t);
+  for (const auto &point : points) {
+    bool inS = s.containsPoint(point);
+    bool inT = t.containsPoint(point);
+    bool inIntersection = intersection.containsPoint(point);
+    EXPECT_EQ(inIntersection, inS && inT);
+  }
+}
+
+/// Check at each of `points` that s - t contains the point iff s does and t
+/// does not.
+static void testSubtractAtPoints(const PresburgerSet &s,
+                                 const PresburgerSet &t,
+                                 ArrayRef<SmallVector<int64_t, 4>> points) {
+  PresburgerSet diff = s;
+  diff.subtract(t);
+  for (const auto &point : points) {
+    bool inS = s.containsPoint(point);
+    bool inT = t.containsPoint(point);
+    bool inDiff = diff.containsPoint(point);
+    EXPECT_EQ(inDiff, inS && !inT);
+  }
+}
+
+/// Check at each of `points` that the complement of s contains the point iff
+/// s does not.
+static void testComplementAtPoints(const PresburgerSet &s,
+                                   ArrayRef<SmallVector<int64_t, 4>> points) {
+  PresburgerSet complement = PresburgerSet::complement(s);
+  for (const auto &point : points) {
+    bool inS = s.containsPoint(point);
+    bool inComplement = complement.containsPoint(point);
+    EXPECT_EQ(inComplement, !inS);
+  }
+}
+
+/// Construct a FlatAffineConstraints from a set of inequality and
+/// equality constraints.
+static FlatAffineConstraints
+makeFACFromConstraints(unsigned dims, ArrayRef<SmallVector<int64_t, 4>> ineqs,
+                       ArrayRef<SmallVector<int64_t, 4>> eqs) {
+  FlatAffineConstraints fac(ineqs.size(), eqs.size(), dims + 1, dims);
+  for (const auto &eq : eqs)
+    fac.addEquality(eq);
+  for (const auto &ineq : ineqs)
+    fac.addInequality(ineq);
+  return fac;
+}
+
+/// Construct a FlatAffineConstraints from inequality constraints only.
+static FlatAffineConstraints
+makeFACFromIneq(unsigned dims, ArrayRef<SmallVector<int64_t, 4>> ineqs) {
+  return makeFACFromConstraints(dims, ineqs, {});
+}
+
+/// Construct a PresburgerSet that is the union of the given FACs.
+static PresburgerSet makeSetFromFACs(unsigned dims,
+                                     ArrayRef<FlatAffineConstraints> facs) {
+  PresburgerSet set(dims);
+  for (const FlatAffineConstraints &fac : facs)
+    set.addFlatAffineConstraints(fac);
+  return set;
+}
+
+TEST(SetTest, containsPoint) {
+  PresburgerSet set1 =
+      makeSetFromFACs(1, {
+                             makeFACFromIneq(1, {{1, -2},   // x >= 2.
+                                                 {-1, 8}}), // x <= 8.
+                             makeFACFromIneq(1, {{1, -10},   // x >= 10.
+                                                 {-1, 20}}), // x <= 20.
+                         });
+  for (int64_t x = 0; x <= 21; ++x) {
+    if ((2 <= x && x <= 8) || (10 <= x && x <= 20))
+      EXPECT_TRUE(set1.containsPoint({x}));
+    else
+      EXPECT_FALSE(set1.containsPoint({x}));
+  }
+
+  // A parallelogram with vertices {(3, 1), (10, -6), (24, 8), (17, 15)} union
+  // a square with opposite corners (2, 2) and (10, 10).
+  PresburgerSet set2 =
+      makeSetFromFACs(2, {makeFACFromIneq(2,
+                                          {
+                                              {1, 1, -4},   // x + y >= 4.
+                                              {-1, -1, 32}, // x + y <= 32.
+                                              {1, -1, -2},  // x - y >= 2.
+                                              {-1, 1, 16},  // x - y <= 16.
+                                          }),
+                          makeFACFromIneq(2, {
+                                                 {1, 0, -2},  // x >= 2.
+                                                 {0, 1, -2},  // y >= 2.
+                                                 {-1, 0, 10}, // x <= 10.
+                                                 {0, -1, 10}  // y <= 10.
+                                             })});
+
+  // Signed loop variables: y starts below zero.
+  for (int64_t x = 1; x <= 25; ++x) {
+    for (int64_t y = -6; y <= 16; ++y) {
+      if (4 <= x + y && x + y <= 32 && 2 <= x - y && x - y <= 16)
+        EXPECT_TRUE(set2.containsPoint({x, y}));
+      else if (2 <= x && x <= 10 && 2 <= y && y <= 10)
+        EXPECT_TRUE(set2.containsPoint({x, y}));
+      else
+        EXPECT_FALSE(set2.containsPoint({x, y}));
+    }
+  }
+}
+
+TEST(SetTest, Union) {
+  PresburgerSet set =
+      makeSetFromFACs(1, {
+                             makeFACFromIneq(1, {{1, -2},   // x >= 2.
+                                                 {-1, 8}}), // x <= 8.
+                             makeFACFromIneq(1, {{1, -10},   // x >= 10.
+                                                 {-1, 20}}), // x <= 20.
+                         });
+
+  // Universe union set.
+  testUnionAtPoints(PresburgerSet::makeUniverse(1), set,
+                    {{1}, {2}, {8}, {9}, {10}, {20}, {21}});
+
+  // Null union set.
+  testUnionAtPoints(PresburgerSet(1), set,
+                    {{1}, {2}, {8}, {9}, {10}, {20}, {21}});
+
+  // Null union Universe.
+  testUnionAtPoints(PresburgerSet(1), PresburgerSet::makeUniverse(1),
+                    {{1}, {2}, {0}, {-1}});
+
+  // Universe union Null.
+  testUnionAtPoints(PresburgerSet::makeUniverse(1), PresburgerSet(1),
+                    {{1}, {2}, {0}, {-1}});
+
+  // Null union Null.
+  testUnionAtPoints(PresburgerSet(1), PresburgerSet(1), {{1}, {2}, {0}, {-1}});
+}
+
+TEST(SetTest, Intersect) {
+  PresburgerSet set =
+      makeSetFromFACs(1, {
+                             makeFACFromIneq(1, {{1, -2},   // x >= 2.
+                                                 {-1, 8}}), // x <= 8.
+                             makeFACFromIneq(1, {{1, -10},   // x >= 10.
+                                                 {-1, 20}}), // x <= 20.
+                         });
+
+  // Universe intersection set.
+  testIntersectAtPoints(PresburgerSet::makeUniverse(1), set,
+                        {{1}, {2}, {8}, {9}, {10}, {20}, {21}});
+
+  // Null intersection set.
+  testIntersectAtPoints(PresburgerSet(1), set,
+                        {{1}, {2}, {8}, {9}, {10}, {20}, {21}});
+
+  // Null intersection Universe.
+  testIntersectAtPoints(PresburgerSet(1), PresburgerSet::makeUniverse(1),
+                        {{1}, {2}, {0}, {-1}});
+
+  // Universe intersection Null.
+  testIntersectAtPoints(PresburgerSet::makeUniverse(1), PresburgerSet(1),
+                        {{1}, {2}, {0}, {-1}});
+
+  // Universe intersection Universe.
+  testIntersectAtPoints(PresburgerSet::makeUniverse(1),
+                        PresburgerSet::makeUniverse(1), {{1}, {2}, {0}, {-1}});
+}
+
+TEST(SetTest, Subtract) {
+  // The universe minus the union of the intervals [2, 8] and [10, 20].
+  testSubtractAtPoints(
+      makeSetFromFACs(1, {makeFACFromIneq(1, {})}),
+      makeSetFromFACs(1,
+                      {
+                          makeFACFromIneq(1, {{1, -2},   // x >= 2.
+                                              {-1, 8}}), // x <= 8.
+                          makeFACFromIneq(1, {{1, -10},   // x >= 10.
+                                              {-1, 20}}), // x <= 20.
+                      }),
+      {{1}, {2}, {8}, {9}, {10}, {20}, {21}});
+
+  // ((-infinity, 0] U [3, 4] U [6, 7]) - ([2, 3] U [5, 6])
+  testSubtractAtPoints(
+      makeSetFromFACs(1,
+                      {
+                          makeFACFromIneq(1,
+                                          {
+                                              {-1, 0} // x <= 0.
+                                          }),
+                          makeFACFromIneq(1,
+                                          {
+                                              {1, -3}, // x >= 3.
+                                              {-1, 4}  // x <= 4.
+                                          }),
+                          makeFACFromIneq(1,
+                                          {
+                                              {1, -6}, // x >= 6.
+                                              {-1, 7}  // x <= 7.
+                                          }),
+                      }),
+      makeSetFromFACs(1, {makeFACFromIneq(1,
+                                          {
+                                              {1, -2}, // x >= 2.
+                                              {-1, 3}, // x <= 3.
+                                          }),
+                          makeFACFromIneq(1,
+                                          {
+                                              {1, -5}, // x >= 5.
+                                              {-1, 6}  // x <= 6.
+                                          })}),
+      {{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}});
+
+  // {[x, y] : x >= y} minus {[x, y] : x >= -y}. The expected result is
+  // {[x, y] : x >= y and x + y <= -1}.
+  testSubtractAtPoints(
+      makeSetFromFACs(2, {makeFACFromIneq(2,
+                                          {
+                                              {1, -1, 0} // x >= y.
+                                          })}),
+      makeSetFromFACs(2, {makeFACFromIneq(2,
+                                          {
+                                              {1, 1, 0} // x >= -y.
+                                          })}),
+      {{0, 1}, {1, 1}, {1, 0}, {1, -1}, {0, -1}});
+
+  // A rectangle with corners at (2, 2) and (10, 10), minus
+  // a rectangle with corners at (5, -10) and (7, 100).
+  // This splits the former rectangle into two halves, (2, 2) to (4, 10) and
+  // (8, 2) to (10, 10).
+  testSubtractAtPoints(
+      makeSetFromFACs(2, {makeFACFromIneq(2,
+                                          {
+                                              {1, 0, -2},  // x >= 2.
+                                              {0, 1, -2},  // y >= 2.
+                                              {-1, 0, 10}, // x <= 10.
+                                              {0, -1, 10}  // y <= 10.
+                                          })}),
+      makeSetFromFACs(2, {makeFACFromIneq(2,
+                                          {
+                                              {1, 0, -5},   // x >= 5.
+                                              {0, 1, 10},   // y >= -10.
+                                              {-1, 0, 7},   // x <= 7.
+                                              {0, -1, 100}, // y <= 100.
+                                          })}),
+      {{1, 2},  {2, 2},  {4, 2},  {5, 2},  {7, 2},  {8, 2},  {11, 2},
+       {1, 1},  {2, 1},  {4, 1},  {5, 1},  {7, 1},  {8, 1},  {11, 1},
+       {1, 10}, {2, 10}, {4, 10}, {5, 10}, {7, 10}, {8, 10}, {11, 10},
+       {1, 11}, {2, 11}, {4, 11}, {5, 11}, {7, 11}, {8, 11}, {11, 11}});
+
+  // A rectangle with corners at (2, 2) and (10, 10), minus
+  // a rectangle with corners at (5, 4) and (7, 8).
+  // This creates a hole in the middle of the former rectangle, and the
+  // resulting set can be represented as a union of four rectangles.
+  testSubtractAtPoints(
+      makeSetFromFACs(2, {makeFACFromIneq(2,
+                                          {
+                                              {1, 0, -2},  // x >= 2.
+                                              {0, 1, -2},  // y >= 2.
+                                              {-1, 0, 10}, // x <= 10.
+                                              {0, -1, 10}  // y <= 10.
+                                          })}),
+      makeSetFromFACs(2, {makeFACFromIneq(2,
+                                          {
+                                              {1, 0, -5}, // x >= 5.
+                                              {0, 1, -4}, // y >= 4.
+                                              {-1, 0, 7}, // x <= 7.
+                                              {0, -1, 8}, // y <= 8.
+                                          })}),
+      {{1, 1},
+       {2, 2},
+       {10, 10},
+       {11, 11},
+       {5, 4},
+       {7, 4},
+       {5, 8},
+       {7, 8},
+       {4, 4},
+       {8, 4},
+       {4, 8},
+       {8, 8}});
+
+  // The second set is a superset of the first one, since on the line x + y = 0,
+  // y <= 1 is equivalent to x >= -1. So the result is empty.
+  testSubtractAtPoints(
+      makeSetFromFACs(2, {makeFACFromConstraints(2,
+                                                 {
+                                                     {1, 0, 0} // x >= 0.
+                                                 },
+                                                 {
+                                                     {1, 1, 0} // x + y = 0.
+                                                 })}),
+      makeSetFromFACs(2, {makeFACFromConstraints(2,
+                                                 {
+                                                     {0, -1, 1} // y <= 1.
+                                                 },
+                                                 {
+                                                     {1, 1, 0} // x + y = 0.
+                                                 })}),
+      {{0, 0},
+       {1, -1},
+       {2, -2},
+       {-1, 1},
+       {-2, 2},
+       {1, 1},
+       {-1, -1},
+       {-1, 1},
+       {1, -1}});
+
+  // The result should be {0} U {2}.
+  testSubtractAtPoints(
+      makeSetFromFACs(1,
+                      {
+                          makeFACFromIneq(1, {{1, 0},    // x >= 0.
+                                              {-1, 2}}), // x <= 2.
+                      }),
+      makeSetFromFACs(1,
+                      {
+                          makeFACFromConstraints(1, {},
+                                                 {
+                                                     {1, -1} // x = 1.
+                                                 }),
+                      }),
+      {{-1}, {0}, {1}, {2}, {3}});
+
+  // Sets with lots of redundant inequalities to test the redundancy heuristic.
+  // (the heuristic is for the subtrahend, the second set which is the one being
+  // subtracted)
+
+  // A parallelogram with vertices {(3, 1), (10, -6), (24, 8), (17, 15)}.
+  PresburgerSet parallelogram =
+      makeSetFromFACs(2, {makeFACFromIneq(2,
+                                          {
+                                              {1, 1, -4},   // x + y >= 4.
+                                              {-1, -1, 32}, // x + y <= 32.
+                                              {1, -1, -2},  // x - y >= 2.
+                                              {-1, 1, 16},  // x - y <= 16.
+                                          })});
+  // A triangle with vertices {(2, 2), (10, 2), (10, 10)}.
+  PresburgerSet triangle = makeSetFromFACs(
+      2, {makeFACFromIneq(2,
+                          {
+                              {1, 0, -2},   // x >= 2. [redundant]
+                              {0, 1, -2},   // y >= 2.
+                              {-1, 0, 10},  // x <= 10.
+                              {0, -1, 10},  // y <= 10. [redundant]
+                              {1, 1, -2},   // x + y >= 2. [redundant]
+                              {-1, -1, 30}, // x + y <= 30. [redundant]
+                              {1, -1, 0},   // x - y >= 0.
+                              {-1, 1, 10},  // x - y <= 10. [redundant]
+                          })});
+  SmallVector<SmallVector<int64_t, 4>, 25> points = {
+      {1, 2},  {2, 2},   {3, 2},   {4, 2},  {1, 1},   {2, 1},   {3, 1},
+      {4, 1},  {2, 0},   {3, 0},   {4, 0},  {5, 0},   {10, 2},  {11, 2},
+      {10, 1}, {10, 10}, {10, 11}, {10, 9}, {11, 10}, {10, -6}, {11, -6},
+      {24, 8}, {24, 7},  {17, 15}, {16, 15}};
+
+  // The parallelogram minus the triangle.
+  testSubtractAtPoints(parallelogram, triangle, points);
+  // The triangle minus the parallelogram.
+  testSubtractAtPoints(triangle, parallelogram, points);
+
+  // ((-infinity, -5] U {3} U {4} U {5}) - ([-10, -2] U [3, 4] U [6, 7])
+  testSubtractAtPoints(
+      makeSetFromFACs(1,
+                      {
+                          makeFACFromIneq(1,
+                                          {
+                                              {-1, -5}, // x <= -5.
+                                          }),
+                          makeFACFromConstraints(1, {},
+                                                 {
+                                                     {1, -3} // x = 3.
+                                                 }),
+                          makeFACFromConstraints(1, {},
+                                                 {
+                                                     {1, -4} // x = 4.
+                                                 }),
+                          makeFACFromConstraints(1, {},
+                                                 {
+                                                     {1, -5} // x = 5.
+                                                 }),
+                      }),
+      makeSetFromFACs(
+          1,
+          {
+              makeFACFromIneq(1,
+                              {
+                                  {-1, -2}, // x <= -2.
+                                  {1, 10},  // x >= -10.
+                                  {-1, 0},  // x <= 0. [redundant]
+                                  {-1, 10}, // x <= 10. [redundant]
+                                  {1, 100}, // x >= -100. [redundant]
+                                  {1, 50}   // x >= -50. [redundant]
+                              }),
+              makeFACFromIneq(1,
+                              {
+                                  {1, -3}, // x >= 3.
+                                  {-1, 4}, // x <= 4.
+                                  {1, 1},  // x >= -1. [redundant]
+                                  {1, 7},  // x >= -7. [redundant]
+                                  {-1, 10} // x <= 10. [redundant]
+                              }),
+              makeFACFromIneq(1,
+                              {
+                                  {1, -6}, // x >= 6.
+                                  {-1, 7}, // x <= 7.
+                                  {1, 1},  // x >= -1. [redundant]
+                                  {1, 3},  // x >= -3. [redundant]
+                                  {-1, 8}  // x <= 8. [redundant]
+                              }),
+          }),
+      {{-6},
+       {-5},
+       {-4},
+       {-9},
+       {-10},
+       {-11},
+       {0},
+       {1},
+       {2},
+       {3},
+       {4},
+       {5},
+       {6},
+       {7},
+       {8}});
+}
+
+TEST(SetTest, Complement) {
+  // Complement of universe.
+  testComplementAtPoints(
+      PresburgerSet::makeUniverse(1),
+      {{-1}, {-2}, {-8}, {1}, {2}, {8}, {9}, {10}, {20}, {21}});
+
+  // Complement of null set.
+  testComplementAtPoints(
+      PresburgerSet(1),
+      {{-1}, {-2}, {-8}, {1}, {2}, {8}, {9}, {10}, {20}, {21}});
+
+  testComplementAtPoints(
+      makeSetFromFACs(2, {makeFACFromIneq(2,
+                                          {
+                                              {1, 0, -2},  // x >= 2.
+                                              {0, 1, -2},  // y >= 2.
+                                              {-1, 0, 10}, // x <= 10.
+                                              {0, -1, 10}  // y <= 10.
+                                          })}),
+      {{1, 1},
+       {2, 1},
+       {1, 2},
+       {2, 2},
+       {2, 3},
+       {3, 2},
+       {10, 10},
+       {10, 11},
+       {11, 10},
+       {2, 10},
+       {2, 11},
+       {1, 10}});
+}
+
+} // namespace mlir