diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -97,6 +97,13 @@ ids.append(idArgs.begin(), idArgs.end()); } + /// Return a system with numDims dimensions, numSymbols symbols and no + /// constraints, i.e., one which is satisfied by all points. + static FlatAffineConstraints universe(unsigned numDims = 0, + unsigned numSymbols = 0) { + return FlatAffineConstraints(numDims, numSymbols); + } + /// Create a flat affine constraint system from an AffineValueMap or a list of /// these. The constructed system will only include equalities. explicit FlatAffineConstraints(const AffineValueMap &avm); @@ -153,6 +160,10 @@ /// Returns such a point if one exists, or an empty Optional otherwise. Optional> findIntegerSample() const; + /// Returns true if the given point (one value per identifier) satisfies all + /// the constraints, or false otherwise. + bool containsPoint(ArrayRef point) const; + // Clones this object. std::unique_ptr clone() const; diff --git a/mlir/include/mlir/Analysis/Presburger/Set.h b/mlir/include/mlir/Analysis/Presburger/Set.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Analysis/Presburger/Set.h @@ -0,0 +1,103 @@ +//===- Set.h - MLIR PresburgerSet Class -------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// A class to represent unions of FlatAffineConstraints. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_PRESBURGER_SET_H +#define MLIR_ANALYSIS_PRESBURGER_SET_H + +#include "mlir/Analysis/AffineStructures.h" + +namespace mlir { + +/// This class can represent a union of FlatAffineConstraints, with support for +/// union, intersection, subtraction and complement operations, as well as +/// sampling. +/// +/// The FlatAffineConstraints (FACs) are stored in a vector, and the set +/// represents the union of these FACs. +class PresburgerSet { +public: + /// Construct an empty PresburgerSet with nDim dimensions and nSym symbols. + explicit PresburgerSet(unsigned nDim = 0, unsigned nSym = 0) + : nDim(nDim), nSym(nSym) {} + /// Construct a set containing exactly the points of `fac`. + explicit PresburgerSet(const FlatAffineConstraints &fac); + + /// Return the number of FACs in the union. + unsigned getNumFACs() const; + + /// Return the number of real dimensions. + unsigned getNumDims() const; + + /// Return the number of symbolic dimensions. + unsigned getNumSyms() const; + + /// Return a reference to the list of FlatAffineConstraints. + ArrayRef getAllFlatAffineConstraints() const; + + /// Return the FlatAffineConstraints at the specified index. + const FlatAffineConstraints &getFlatAffineConstraints(unsigned index) const; + + /// Add the given FlatAffineConstraints to the union. + void addFlatAffineConstraints(const FlatAffineConstraints &fac); + + /// Return the union of this set and the given set. + PresburgerSet unionSet(const PresburgerSet &set) const; + + /// Return the intersection of this set and the given set. + PresburgerSet intersect(const PresburgerSet &set) const; + + /// Return true if the set contains the given point, or false otherwise. + bool containsPoint(ArrayRef point) const; + + /// Print the set's internal state. + void print(raw_ostream &os) const; + void dump() const; + + /// Return the complement of this set. + PresburgerSet complement() const; + + /// Return the set difference of this set and the given set, i.e., + /// return `this \ set`. 
+ PresburgerSet subtract(const PresburgerSet &set) const; + + /// Return a universe set of the specified type that contains all points. + static PresburgerSet universe(unsigned nDim = 0, unsigned nSym = 0); + /// Return an empty set of the specified type that contains no points. + static PresburgerSet emptySet(unsigned nDim = 0, unsigned nSym = 0); + + /// Return true if all the sets in the union are known to be integer empty, + /// false otherwise. + bool isIntegerEmpty() const; + + /// Find an integer sample of this set and store it in `sample`. Returns + /// false if none exists. Do not call if any FAC in the union is unbounded. + bool findIntegerSample(SmallVectorImpl &sample); + +private: + /// Return the set difference fac \ set. + static PresburgerSet getSetDifference(FlatAffineConstraints fac, + const PresburgerSet &set); + + /// Number of identifiers corresponding to real dimensions. + unsigned nDim; + + /// Number of symbolic dimensions, unknown but constant for analysis, as in + /// FlatAffineConstraints. + unsigned nSym; + + /// The list of flatAffineConstraints that this set is the union of. + SmallVector flatAffineConstraints; +}; + +} // namespace mlir + +#endif // MLIR_ANALYSIS_PRESBURGER_SET_H diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h --- a/mlir/include/mlir/Analysis/Presburger/Simplex.h +++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h @@ -169,6 +169,9 @@ /// Rollback to a snapshot. This invalidates all later snapshots. void rollback(unsigned snapshot); + /// Add all the constraints from the given FlatAffineConstraints. + void addFlatAffineConstraints(const FlatAffineConstraints &fac); + /// Compute the maximum or minimum value of the given row, depending on /// direction. The specified row is never pivoted. 
/// diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1073,6 +1073,33 @@ return Simplex(*this).findIntegerSample(); } +/// Helper to evaluate the value of an affine expression at a point. +/// The expression is a list of coefficients, one per coordinate of the point, +/// followed by the constant term. int64_t overflow is not checked. +static int64_t valueAt(ArrayRef expr, ArrayRef point) { + assert(expr.size() == 1 + point.size() && + "Dimensionalities of point and expresion don't match!"); + int64_t value = expr.back(); + for (unsigned i = 0; i < point.size(); ++i) + value += expr[i] * point[i]; + return value; +} + +/// A point satisfies an equality iff the value of the equality at the +/// point is zero, and it satisfies an inequality iff the value of the +/// inequality at that point is non-negative. +bool FlatAffineConstraints::containsPoint(ArrayRef point) const { + for (unsigned i = 0; i < getNumEqualities(); ++i) { + if (valueAt(getEquality(i), point) != 0) + return false; + } + for (unsigned i = 0; i < getNumInequalities(); ++i) { + if (valueAt(getInequality(i), point) < 0) + return false; + } + return true; +} + /// Tightens inequalities given that we are dealing with integer spaces. This is /// analogous to the GCD test but applied to inequalities. 
The constant term can /// be reduced to the preceding multiple of the GCD of the coefficients, i.e., diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt --- a/mlir/lib/Analysis/CMakeLists.txt +++ b/mlir/lib/Analysis/CMakeLists.txt @@ -25,7 +25,6 @@ MLIRCallInterfaces MLIRControlFlowInterfaces MLIRInferTypeOpInterface - MLIRPresburger MLIRSCF ) @@ -34,6 +33,9 @@ AffineStructures.cpp LoopAnalysis.cpp NestedMatcher.cpp + Presburger/Simplex.cpp + Presburger/Matrix.cpp + Presburger/Set.cpp Utils.cpp ADDITIONAL_HEADER_DIRS @@ -47,8 +49,6 @@ MLIRCallInterfaces MLIRControlFlowInterfaces MLIRInferTypeOpInterface - MLIRPresburger MLIRSCF ) -add_subdirectory(Presburger) diff --git a/mlir/lib/Analysis/Presburger/CMakeLists.txt b/mlir/lib/Analysis/Presburger/CMakeLists.txt deleted file mode 100644 --- a/mlir/lib/Analysis/Presburger/CMakeLists.txt +++ /dev/null @@ -1,4 +0,0 @@ -add_mlir_library(MLIRPresburger - Simplex.cpp - Matrix.cpp - ) diff --git a/mlir/lib/Analysis/Presburger/Set.cpp b/mlir/lib/Analysis/Presburger/Set.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Analysis/Presburger/Set.cpp @@ -0,0 +1,311 @@ +//===- Set.cpp - MLIR PresburgerSet Class ---------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/Presburger/Set.h" +#include "mlir/Analysis/Presburger/Simplex.h" +#include "llvm/ADT/SmallBitVector.h" + +using namespace mlir; + +PresburgerSet::PresburgerSet(const FlatAffineConstraints &fac) + : nDim(fac.getNumDimIds()), nSym(fac.getNumSymbolIds()) { + addFlatAffineConstraints(fac); +} + +unsigned PresburgerSet::getNumFACs() const { + return flatAffineConstraints.size(); +} + +unsigned PresburgerSet::getNumDims() const { return nDim; } + +unsigned PresburgerSet::getNumSyms() const { return nSym; } + +ArrayRef +PresburgerSet::getAllFlatAffineConstraints() const { + return flatAffineConstraints; +} + +const FlatAffineConstraints & +PresburgerSet::getFlatAffineConstraints(unsigned index) const { + assert(index < flatAffineConstraints.size() && "index out of bounds!"); + return flatAffineConstraints[index]; +} + +/// Assert that the FlatAffineConstraints and PresburgerSet live in +/// compatible spaces, i.e., have the same numbers of dimensions and symbols. +static void assertDimensionsCompatible(const FlatAffineConstraints &fac, + const PresburgerSet &set) { + assert(fac.getNumDimIds() == set.getNumDims() && + "Number of dimensions of the FlatAffineConstraints and PresburgerSet " + "do not match!"); + assert(fac.getNumSymbolIds() == set.getNumSyms() && + "Number of symbols of the FlatAffineConstraints and PresburgerSet " + "do not match!"); +} + +/// Assert that the two PresburgerSets live in compatible spaces. +static void assertDimensionsCompatible(const PresburgerSet &setA, + const PresburgerSet &setB) { + assert(setA.getNumDims() == setB.getNumDims() && + "Number of dimensions of the PresburgerSets do not match!"); + assert(setA.getNumSyms() == setB.getNumSyms() && + "Number of symbols of the PresburgerSets do not match!"); +} + +/// Add a FAC to the union. 
+void PresburgerSet::addFlatAffineConstraints(const FlatAffineConstraints &fac) { + assertDimensionsCompatible(fac, *this); + flatAffineConstraints.push_back(fac); +} + +/// Return the union of this set and the given set. +/// +/// This is accomplished by simply adding all the FACs of the given set to a +/// copy of the current set. +PresburgerSet PresburgerSet::unionSet(const PresburgerSet &set) const { + assertDimensionsCompatible(set, *this); + // This copy is not strictly necessary; the result could be constructed + // in-place. However, to keep the API uniform with intersect, subtract and + // complement which return the result of their operations, we need to make a + // copy here. + PresburgerSet result = *this; + for (const FlatAffineConstraints &fac : set.flatAffineConstraints) + result.addFlatAffineConstraints(fac); + return result; +} + +/// A point is contained in the union iff any of the parts contains the point. +bool PresburgerSet::containsPoint(ArrayRef point) const { + for (const FlatAffineConstraints &fac : flatAffineConstraints) { + if (fac.containsPoint(point)) + return true; + } + return false; +} + +PresburgerSet PresburgerSet::universe(unsigned nDim, unsigned nSym) { + PresburgerSet result(nDim, nSym); + result.addFlatAffineConstraints(FlatAffineConstraints::universe(nDim, nSym)); + return result; +} + +PresburgerSet PresburgerSet::emptySet(unsigned nDim, unsigned nSym) { + return PresburgerSet(nDim, nSym); +} + +/// Return the intersection of this set with the given set. +/// +/// We directly compute (S_1 or S_2 ...) and (T_1 or T_2 ...) +/// as (S_1 and T_1) or (S_1 and T_2) or ... 
+PresburgerSet PresburgerSet::intersect(const PresburgerSet &set) const { + assertDimensionsCompatible(set, *this); + // Pairwise intersections that isEmpty() reports as empty are dropped. + PresburgerSet result(nDim, nSym); + for (const FlatAffineConstraints &csA : flatAffineConstraints) { + for (const FlatAffineConstraints &csB : set.flatAffineConstraints) { + FlatAffineConstraints intersection(csA); + intersection.append(csB); + if (!intersection.isEmpty()) + result.addFlatAffineConstraints(intersection); + } + } + return result; +} + +/// Return `coeffs` with all the elements negated. +static SmallVector getNegatedCoeffs(ArrayRef coeffs) { + SmallVector negatedCoeffs; + negatedCoeffs.reserve(coeffs.size()); + for (int64_t coeff : coeffs) + negatedCoeffs.emplace_back(-coeff); + return negatedCoeffs; +} + +/// Return the complement of the given inequality. +/// +/// The complement of a_1 x_1 + ... + a_n x_n + c >= 0 is +/// a_1 x_1 + ... + a_n x_n + c < 0, i.e., over the integers, +/// -a_1 x_1 - ... - a_n x_n - c - 1 >= 0. +static SmallVector getComplementIneq(ArrayRef ineq) { + // Negate every coefficient, then subtract one from the constant term. + auto coeffs = getNegatedCoeffs(ineq); + assert(!coeffs.empty() && "an inequality always has a constant term"); + --coeffs.back(); + return coeffs; +} + +/// Return the set difference b \ s and accumulate the result into `result`. +/// `simplex` must correspond to b. +/// +/// In the following, V denotes union, ^ denotes intersection, \ denotes set +/// difference and ~ denotes complement. +/// Let b be the FlatAffineConstraints and s = (V_i s_i) be the set. We want +/// b \ (V_i s_i). +/// +/// Let s_i = ^_j s_ij, where each s_ij is a single inequality. To compute +/// b \ s_i = b ^ ~s_i, we partition ~s_i by the first violated inequality: +/// ~s_i = (~s_i1) V (s_i1 ^ ~s_i2) V (s_i1 ^ s_i2 ^ ~s_i3) V ... +/// And the required result is (b ^ ~s_i1) V (b ^ s_i1 ^ ~s_i2) V ... +/// We recurse by subtracting V_{j > i} s_j from each of these parts and +/// returning the union of the results. Each equality is handled as a +/// conjunction of two inequalities. 
+/// +/// As a heuristic, we try adding all the constraints and check if simplex +/// says that the intersection is empty. Also, in the process we find out that +/// some constraints are redundant. These redundant constraints are ignored. +static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex, + const PresburgerSet &s, unsigned i, + PresburgerSet &result) { + if (i == s.getNumFACs()) { + result.addFlatAffineConstraints(b); + return; + } + const FlatAffineConstraints &sI = s.getFlatAffineConstraints(i); + unsigned initialSnapshot = simplex.getSnapshot(); + unsigned offset = simplex.numConstraints(); + simplex.addFlatAffineConstraints(sI); + + if (simplex.isEmpty()) { + // b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1. + simplex.rollback(initialSnapshot); + subtractRecursively(b, simplex, s, i + 1, result); + return; + } + + simplex.detectRedundant(); + llvm::SmallBitVector isMarkedRedundant; + for (unsigned j = 0; j < 2 * sI.getNumEqualities() + sI.getNumInequalities(); + ++j) + isMarkedRedundant.push_back(simplex.isMarkedRedundant(offset + j)); + + simplex.rollback(initialSnapshot); + + // Recurse with the part b ^ ~ineq. Note that b is modified throughout + // this function. At the time this function is called, the current b is + // actually equal to b ^ s_i1 ^ s_i2 ^ ... ^ s_ij, and ineq is the next + // inequality, s_{i,j+1}. This function recurses into the next level i + 1 + // with the part b ^ s_i1 ^ s_i2 ^ ... ^ s_ij ^ ~s_{i,j+1}. + auto recurseWithInequality = [&, i](ArrayRef ineq) { + unsigned snapshot = simplex.getSnapshot(); + b.addInequality(ineq); + simplex.addInequality(ineq); + subtractRecursively(b, simplex, s, i + 1, result); + b.removeInequality(b.getNumInequalities() - 1); + simplex.rollback(snapshot); + }; + + // For each inequality ineq, we first recurse with the part where ineq + // is not satisfied, and then add the ineq to b and simplex because + // ineq must be satisfied by all later parts. 
+ auto processInequality = [&](ArrayRef ineq) { + recurseWithInequality(getComplementIneq(ineq)); + b.addInequality(ineq); + simplex.addInequality(ineq); + }; + + // processInequality appends some additional constraints to b. We want to + // rollback b to its initial state before returning, which we will do by + // removing all constraints beyond the original number of inequalities + // and equalities, so we store these counts first. + unsigned originalNumIneqs = b.getNumInequalities(); + unsigned originalNumEqs = b.getNumEqualities(); + + for (unsigned j = 0, e = sI.getNumInequalities(); j < e; ++j) { + if (isMarkedRedundant[j]) + continue; + processInequality(sI.getInequality(j)); + } + + unsigned eqOffset = sI.getNumInequalities(); + for (unsigned j = 0, e = sI.getNumEqualities(); j < e; ++j) { + const ArrayRef &coeffs = sI.getEquality(j); + // Same as the above loop for inequalities, done once each for the positive + // and negative inequalities that make up this equality. + if (!isMarkedRedundant[eqOffset + 2 * j]) + processInequality(coeffs); + if (!isMarkedRedundant[eqOffset + 2 * j + 1]) + processInequality(getNegatedCoeffs(coeffs)); + } + + // Rollback b and simplex to their initial states. + for (unsigned k = b.getNumInequalities(); k > originalNumIneqs; --k) + b.removeInequality(k - 1); + + for (unsigned k = b.getNumEqualities(); k > originalNumEqs; --k) + b.removeEquality(k - 1); + + simplex.rollback(initialSnapshot); +} + +/// Return the set difference fac \ set. +/// +/// The FAC here is modified in subtractRecursively, so it cannot be a const +/// reference even though it is restored to its original state before returning +/// from that function. 
+PresburgerSet PresburgerSet::getSetDifference(FlatAffineConstraints fac, + const PresburgerSet &set) { + assertDimensionsCompatible(fac, set); + if (fac.isEmptyByGCDTest()) + return emptySet(fac.getNumDimIds(), fac.getNumSymbolIds()); + + PresburgerSet result(fac.getNumDimIds(), fac.getNumSymbolIds()); + Simplex simplex(fac); + subtractRecursively(fac, simplex, set, 0, result); + return result; +} + +/// Return the complement of this set. +PresburgerSet PresburgerSet::complement() const { + return getSetDifference( + FlatAffineConstraints::universe(getNumDims(), getNumSyms()), *this); +} + +/// Return the result of subtracting the given set from this set, i.e., +/// return `this \ set`. +PresburgerSet PresburgerSet::subtract(const PresburgerSet &set) const { + assertDimensionsCompatible(set, *this); + PresburgerSet result(nDim, nSym); + // We compute (V_i t_i) \ set as V_i (t_i \ set). + for (const FlatAffineConstraints &fac : flatAffineConstraints) + result = result.unionSet(getSetDifference(fac, set)); + return result; +} + +/// Return true if all the sets in the union are known to be integer empty, +/// false otherwise. +bool PresburgerSet::isIntegerEmpty() const { + assert(nSym == 0 && "isIntegerEmpty is intended for non-symbolic sets"); + // The set is empty iff all of the disjuncts are empty. + for (const FlatAffineConstraints &fac : flatAffineConstraints) { + if (!fac.isIntegerEmpty()) + return false; + } + return true; +} + +bool PresburgerSet::findIntegerSample(SmallVectorImpl &sample) { + assert(nSym == 0 && "findIntegerSample is intended for non-symbolic sets"); + // A sample exists iff any of the disjuncts contains a sample. 
+ for (const FlatAffineConstraints &fac : flatAffineConstraints) { + if (Optional> opt = fac.findIntegerSample()) { + sample = std::move(*opt); + return true; + } + } + return false; +} + +void PresburgerSet::print(raw_ostream &os) const { + os << getNumFACs() << " FlatAffineConstraints:\n"; + for (const FlatAffineConstraints &fac : flatAffineConstraints) { + fac.print(os); + os << '\n'; + } +} + +void PresburgerSet::dump() const { print(llvm::errs()); } diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp --- a/mlir/lib/Analysis/Presburger/Simplex.cpp +++ b/mlir/lib/Analysis/Presburger/Simplex.cpp @@ -451,6 +451,16 @@ } } +/// Add all the constraints from the given FlatAffineConstraints. +void Simplex::addFlatAffineConstraints(const FlatAffineConstraints &fac) { + assert(fac.getNumIds() == numVariables() && + "FlatAffineConstraints must have same dimensionality as simplex"); + for (unsigned i = 0, e = fac.getNumInequalities(); i < e; ++i) + addInequality(fac.getInequality(i)); + for (unsigned i = 0, e = fac.getNumEqualities(); i < e; ++i) + addEquality(fac.getEquality(i)); +} + Optional Simplex::computeRowOptimum(Direction direction, unsigned row) { // Keep trying to find a pivot for the row in the specified direction. diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp --- a/mlir/unittests/Analysis/AffineStructuresTest.cpp +++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp @@ -15,17 +15,6 @@ namespace mlir { -/// Evaluate the value of the given affine expression at the specified point. -/// The expression is a list of coefficients for the dimensions followed by the -/// constant term. 
-int64_t valueAt(ArrayRef expr, ArrayRef point) { - assert(expr.size() == 1 + point.size()); - int64_t value = expr.back(); - for (unsigned i = 0; i < point.size(); ++i) - value += expr[i] * point[i]; - return value; -} - /// If 'hasValue' is true, check that findIntegerSample returns a valid sample /// for the FlatAffineConstraints fac. /// @@ -41,10 +30,7 @@ } } else { ASSERT_TRUE(maybeSample.hasValue()); - for (unsigned i = 0; i < fac.getNumEqualities(); ++i) - EXPECT_EQ(valueAt(fac.getEquality(i), *maybeSample), 0); - for (unsigned i = 0; i < fac.getNumInequalities(); ++i) - EXPECT_GE(valueAt(fac.getInequality(i), *maybeSample), 0); + EXPECT_TRUE(fac.containsPoint(*maybeSample)); } } diff --git a/mlir/unittests/Analysis/Presburger/CMakeLists.txt b/mlir/unittests/Analysis/Presburger/CMakeLists.txt --- a/mlir/unittests/Analysis/Presburger/CMakeLists.txt +++ b/mlir/unittests/Analysis/Presburger/CMakeLists.txt @@ -1,7 +1,9 @@ add_mlir_unittest(MLIRPresburgerTests MatrixTest.cpp SimplexTest.cpp + SetTest.cpp ) target_link_libraries(MLIRPresburgerTests - PRIVATE MLIRPresburger) + PRIVATE MLIRLoopAnalysis) + diff --git a/mlir/unittests/Analysis/Presburger/SetTest.cpp b/mlir/unittests/Analysis/Presburger/SetTest.cpp new file mode 100644 --- /dev/null +++ b/mlir/unittests/Analysis/Presburger/SetTest.cpp @@ -0,0 +1,509 @@ +//===- SetTest.cpp - Tests for PresburgerSet ------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/Presburger/Set.h" + +#include +#include + +namespace mlir { + +void testUnionAtPoints(const PresburgerSet &s, const PresburgerSet &t, + ArrayRef> points) { + PresburgerSet unionSet = s.unionSet(t); + for (const auto &point : points) { + bool inS = s.containsPoint(point); + bool inT = t.containsPoint(point); + bool inUnion = unionSet.containsPoint(point); + EXPECT_EQ(inUnion, inS || inT); + } +} + +void testIntersectAtPoints(const PresburgerSet &s, const PresburgerSet &t, + ArrayRef> points) { + PresburgerSet intersection = s.intersect(t); + for (const auto &point : points) { + bool inS = s.containsPoint(point); + bool inT = t.containsPoint(point); + bool inIntersection = intersection.containsPoint(point); + EXPECT_EQ(inIntersection, inS && inT); + } +} + +void testSubtractAtPoints(const PresburgerSet &s, const PresburgerSet &t, + ArrayRef> points) { + PresburgerSet diff = s.subtract(t); + for (const auto &point : points) { + bool inS = s.containsPoint(point); + bool inT = t.containsPoint(point); + bool inDiff = diff.containsPoint(point); + if (inT) + EXPECT_FALSE(inDiff); + else + EXPECT_EQ(inDiff, inS); + } +} + +void testComplementAtPoints(const PresburgerSet &s, + ArrayRef> points) { + PresburgerSet complement = s.complement(); + PresburgerSet doubleComplement = complement.complement(); + for (const auto &point : points) { + bool inS = s.containsPoint(point); + bool inComplement = complement.containsPoint(point); + // The complement holds exactly the points not in s, and complementing + // twice must give back s. + EXPECT_EQ(inComplement, !inS); + EXPECT_EQ(doubleComplement.containsPoint(point), inS); + } +} + +/// Construct a FlatAffineConstraints from a set of inequality and +/// equality constraints. 
+FlatAffineConstraints +makeFACFromConstraints(unsigned dims, ArrayRef> ineqs, + ArrayRef> eqs) { + FlatAffineConstraints fac(ineqs.size(), eqs.size(), dims + 1, dims); + for (const auto &eq : eqs) + fac.addEquality(eq); + for (const auto &ineq : ineqs) + fac.addInequality(ineq); + return fac; +} + +FlatAffineConstraints makeFACFromIneq(unsigned dims, + ArrayRef> ineqs) { + return makeFACFromConstraints(dims, ineqs, {}); +} + +PresburgerSet makeSetFromFACs(unsigned dims, + ArrayRef facs) { + PresburgerSet set(dims); + for (const FlatAffineConstraints &fac : facs) + set.addFlatAffineConstraints(fac); + return set; +} + +TEST(SetTest, containsPoint) { + PresburgerSet setA = + makeSetFromFACs(1, { + makeFACFromIneq(1, {{1, -2}, // x >= 2. + {-1, 8}}), // x <= 8. + makeFACFromIneq(1, {{1, -10}, // x >= 10. + {-1, 20}}), // x <= 20. + }); + for (int x = 0; x <= 21; ++x) { + if ((2 <= x && x <= 8) || (10 <= x && x <= 20)) + EXPECT_TRUE(setA.containsPoint({x})); + else + EXPECT_FALSE(setA.containsPoint({x})); + } + + // A parallelogram with vertices {(3, 1), (10, -6), (24, 8), (17, 15)} union + // a square with opposite corners (2, 2) and (10, 10). + PresburgerSet setB = + makeSetFromFACs(2, {makeFACFromIneq(2, + { + {1, 1, -4}, // x + y >= 4. + {-1, -1, 32}, // x + y <= 32. + {1, -1, -2}, // x - y >= 2. + {-1, 1, 16}, // x - y <= 16. + }), + makeFACFromIneq(2, { + {1, 0, -2}, // x >= 2. + {0, 1, -2}, // y >= 2. + {-1, 0, 10}, // x <= 10. + {0, -1, 10} // y <= 10. + })}); + // x and y must be signed: y starts at -6 and x - y must not wrap around. + for (int x = 1; x <= 25; ++x) { + for (int y = -6; y <= 16; ++y) { + if (4 <= x + y && x + y <= 32 && 2 <= x - y && x - y <= 16) + EXPECT_TRUE(setB.containsPoint({x, y})); + else if (2 <= x && x <= 10 && 2 <= y && y <= 10) + EXPECT_TRUE(setB.containsPoint({x, y})); + else + EXPECT_FALSE(setB.containsPoint({x, y})); + } + } +} + +TEST(SetTest, Union) { + PresburgerSet set = + makeSetFromFACs(1, { + makeFACFromIneq(1, {{1, -2}, // x >= 2. + {-1, 8}}), // x <= 8. 
+ makeFACFromIneq(1, {{1, -10}, // x >= 10. + {-1, 20}}), // x <= 20. + }); + + // Universe union set. + testUnionAtPoints(PresburgerSet::universe(1), set, + {{1}, {2}, {8}, {9}, {10}, {20}, {21}}); + + // empty set union set. + testUnionAtPoints(PresburgerSet::emptySet(1), set, + {{1}, {2}, {8}, {9}, {10}, {20}, {21}}); + + // empty set union Universe. + testUnionAtPoints(PresburgerSet::emptySet(1), PresburgerSet::universe(1), + {{1}, {2}, {0}, {-1}}); + + // Universe union empty set. + testUnionAtPoints(PresburgerSet::universe(1), PresburgerSet::emptySet(1), + {{1}, {2}, {0}, {-1}}); + + // empty set union empty set. + testUnionAtPoints(PresburgerSet::emptySet(1), PresburgerSet::emptySet(1), + {{1}, {2}, {0}, {-1}}); +} + +TEST(SetTest, Intersect) { + PresburgerSet set = + makeSetFromFACs(1, { + makeFACFromIneq(1, {{1, -2}, // x >= 2. + {-1, 8}}), // x <= 8. + makeFACFromIneq(1, {{1, -10}, // x >= 10. + {-1, 20}}), // x <= 20. + }); + + // Universe intersection set. + testIntersectAtPoints(PresburgerSet::universe(1), set, + {{1}, {2}, {8}, {9}, {10}, {20}, {21}}); + + // empty set intersection set. + testIntersectAtPoints(PresburgerSet::emptySet(1), set, + {{1}, {2}, {8}, {9}, {10}, {20}, {21}}); + + // empty set intersection Universe. + testIntersectAtPoints(PresburgerSet::emptySet(1), PresburgerSet::universe(1), + {{1}, {2}, {0}, {-1}}); + + // Universe intersection empty set. + testIntersectAtPoints(PresburgerSet::universe(1), PresburgerSet::emptySet(1), + {{1}, {2}, {0}, {-1}}); + + // Universe intersection Universe. + testIntersectAtPoints(PresburgerSet::universe(1), PresburgerSet::universe(1), + {{1}, {2}, {0}, {-1}}); +} + +TEST(SetTest, Subtract) { + // The universe minus ([2, 8] U [10, 20]), i.e., + // (-infinity, 1] U [9, 9] U [21, infinity). + testSubtractAtPoints( + makeSetFromFACs(1, {makeFACFromIneq(1, {})}), + makeSetFromFACs(1, + { + makeFACFromIneq(1, {{1, -2}, // x >= 2. + {-1, 8}}), // x <= 8. + makeFACFromIneq(1, {{1, -10}, // x >= 10. + {-1, 20}}), // x <= 20. 
+ }), + {{1}, {2}, {8}, {9}, {10}, {20}, {21}}); + + // ((-infinity, 0] U [3, 4] U [6, 7]) - ([2, 3] U [5, 6]) + testSubtractAtPoints( + makeSetFromFACs(1, + { + makeFACFromIneq(1, + { + {-1, 0} // x <= 0. + }), + makeFACFromIneq(1, + { + {1, -3}, // x >= 3. + {-1, 4} // x <= 4. + }), + makeFACFromIneq(1, + { + {1, -6}, // x >= 6. + {-1, 7} // x <= 7. + }), + }), + makeSetFromFACs(1, {makeFACFromIneq(1, + { + {1, -2}, // x >= 2. + {-1, 3}, // x <= 3. + }), + makeFACFromIneq(1, + { + {1, -5}, // x >= 5. + {-1, 6} // x <= 6. + })}), + {{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}}); + + // Expected result is {[x, y] : x >= y and x + y <= -1}. + testSubtractAtPoints( + makeSetFromFACs(2, {makeFACFromIneq(2, + { + {1, -1, 0} // x >= y. + })}), + makeSetFromFACs(2, {makeFACFromIneq(2, + { + {1, 1, 0} // x >= -y. + })}), + {{0, 1}, {1, 1}, {1, 0}, {1, -1}, {0, -1}}); + + // A rectangle with corners at (2, 2) and (10, 10), minus + // a rectangle with corners at (5, -10) and (7, 100). + // This splits the former rectangle into two halves, (2, 2) to (4, 10) and + // (8, 2) to (10, 10). + testSubtractAtPoints( + makeSetFromFACs(2, {makeFACFromIneq(2, + { + {1, 0, -2}, // x >= 2. + {0, 1, -2}, // y >= 2. + {-1, 0, 10}, // x <= 10. + {0, -1, 10} // y <= 10. + })}), + makeSetFromFACs(2, {makeFACFromIneq(2, + { + {1, 0, -5}, // x >= 5. + {0, 1, 10}, // y >= -10. + {-1, 0, 7}, // x <= 7. + {0, -1, 100}, // y <= 100. + })}), + {{1, 2}, {2, 2}, {4, 2}, {5, 2}, {7, 2}, {8, 2}, {11, 2}, + {1, 1}, {2, 1}, {4, 1}, {5, 1}, {7, 1}, {8, 1}, {11, 1}, + {1, 10}, {2, 10}, {4, 10}, {5, 10}, {7, 10}, {8, 10}, {11, 10}, + {1, 11}, {2, 11}, {4, 11}, {5, 11}, {7, 11}, {8, 11}, {11, 11}}); + + // A rectangle with corners at (2, 2) and (10, 10), minus + // a rectangle with corners at (5, 4) and (7, 8). + // This creates a hole in the middle of the former rectangle, and the + // resulting set can be represented as a union of four rectangles. 
+ testSubtractAtPoints( + makeSetFromFACs(2, {makeFACFromIneq(2, + { + {1, 0, -2}, // x >= 2. + {0, 1, -2}, // y >= 2. + {-1, 0, 10}, // x <= 10. + {0, -1, 10} // y <= 10. + })}), + makeSetFromFACs(2, {makeFACFromIneq(2, + { + {1, 0, -5}, // x >= 5. + {0, 1, -4}, // y >= 4. + {-1, 0, 7}, // x <= 7. + {0, -1, 8}, // y <= 8. + })}), + {{1, 1}, + {2, 2}, + {10, 10}, + {11, 11}, + {5, 4}, + {7, 4}, + {5, 8}, + {7, 8}, + {4, 4}, + {8, 4}, + {4, 8}, + {8, 8}}); + + // The second set is a superset of the first one, since on the line x + y = 0, + // y <= 1 is equivalent to x >= -1. So the result is empty. + testSubtractAtPoints( + makeSetFromFACs(2, {makeFACFromConstraints(2, + { + {1, 0, 0} // x >= 0. + }, + { + {1, 1, 0} // x + y = 0. + })}), + makeSetFromFACs(2, {makeFACFromConstraints(2, + { + {0, -1, 1} // y <= 1. + }, + { + {1, 1, 0} // x + y = 0. + })}), + {{0, 0}, + {1, -1}, + {2, -2}, + {-1, 1}, + {-2, 2}, + {1, 1}, + {-1, -1}, + {-1, 1}, + {1, -1}}); + + // The result should be {0} U {2}. + testSubtractAtPoints( + makeSetFromFACs(1, + { + makeFACFromIneq(1, {{1, 0}, // x >= 0. + {-1, 2}}), // x <= 2. + }), + makeSetFromFACs(1, + { + makeFACFromConstraints(1, {}, + { + {1, -1} // x = 1. + }), + }), + {{-1}, {0}, {1}, {2}, {3}}); + + // Sets with lots of redundant inequalities to test the redundancy heuristic. + // (the heuristic is for the subtrahend, the second set which is the one being + // subtracted) + + // A parallelogram with vertices {(3, 1), (10, -6), (24, 8), (17, 15)} minus + // a triangle with vertices {(2, 2), (10, 2), (10, 10)}. + testSubtractAtPoints( + makeSetFromFACs(2, {makeFACFromIneq(2, + { + {1, 1, -4}, // x + y >= 4. + {-1, -1, 32}, // x + y <= 32. + {1, -1, -2}, // x - y >= 2. + {-1, 1, 16}, // x - y <= 16. + })}), + makeSetFromFACs( + 2, {makeFACFromIneq(2, + { + {1, 0, -2}, // x >= 2. [redundant] + {0, 1, -2}, // y >= 2. + {-1, 0, 10}, // x <= 10. + {0, -1, 10}, // y <= 10. [redundant] + {1, 1, -2}, // x + y >= 2. 
+[redundant] + {-1, -1, 30}, // x + y <= 30. [redundant] + {1, -1, 0}, // x - y >= 0. + {-1, 1, 10}, // x - y <= 10. [redundant] + })}), + {{1, 2}, {2, 2}, {3, 2}, {4, 2}, {1, 1}, {2, 1}, {3, 1}, + {4, 1}, {2, 0}, {3, 0}, {4, 0}, {5, 0}, {10, 2}, {11, 2}, + {10, 1}, {10, 10}, {10, 11}, {10, 9}, {11, 10}, {10, -6}, {11, -6}, + {24, 8}, {24, 7}, {17, 15}, {16, 15}}); + // Same as above, but with the subtrahend's constraints in reverse order. + testSubtractAtPoints( + makeSetFromFACs(2, {makeFACFromIneq(2, + { + {1, 1, -4}, // x + y >= 4. + {-1, -1, 32}, // x + y <= 32. + {1, -1, -2}, // x - y >= 2. + {-1, 1, 16}, // x - y <= 16. + })}), + makeSetFromFACs( + 2, {makeFACFromIneq(2, + { + {-1, 1, 10}, // x - y <= 10. [redundant] + {1, -1, 0}, // x - y >= 0. + {-1, -1, 30}, // x + y <= 30. [redundant] + {1, 1, -2}, // x + y >= 2. [redundant] + {0, -1, 10}, // y <= 10. [redundant] + {-1, 0, 10}, // x <= 10. + {0, 1, -2}, // y >= 2. + {1, 0, -2}, // x >= 2. [redundant] + })}), + {{1, 2}, {2, 2}, {3, 2}, {4, 2}, {1, 1}, {2, 1}, {3, 1}, + {4, 1}, {2, 0}, {3, 0}, {4, 0}, {5, 0}, {10, 2}, {11, 2}, + {10, 1}, {10, 10}, {10, 11}, {10, 9}, {11, 10}, {10, -6}, {11, -6}, + {24, 8}, {24, 7}, {17, 15}, {16, 15}}); + + // ((-infinity, -5] U [3, 3] U [4, 4] U [5, 5]) - ([-10, -2] U [3, 4] U [6, + // 7]) + testSubtractAtPoints( + makeSetFromFACs(1, + { + makeFACFromIneq(1, + { + {-1, -5}, // x <= -5. + }), + makeFACFromConstraints(1, {}, + { + {1, -3} // x = 3. + }), + makeFACFromConstraints(1, {}, + { + {1, -4} // x = 4. + }), + makeFACFromConstraints(1, {}, + { + {1, -5} // x = 5. + }), + }), + makeSetFromFACs( + 1, + { + makeFACFromIneq(1, + { + {-1, -2}, // x <= -2. + {1, 10}, // x >= -10. + {-1, 0}, // x <= 0. [redundant] + {-1, 10}, // x <= 10. [redundant] + {1, 100}, // x >= -100. [redundant] + {1, 50} // x >= -50. [redundant] + }), + makeFACFromIneq(1, + { + {1, -3}, // x >= 3. + {-1, 4}, // x <= 4. + {1, 1}, // x >= -1. [redundant] + {1, 7}, // x >= -7. [redundant] + {-1, 10} // x <= 10. 
[redundant] + }), + makeFACFromIneq(1, + { + {1, -6}, // x >= 6. + {-1, 7}, // x <= 7. 
+ {1, 1}, // x >= -1. [redundant] + {1, 3}, // x >= -3. [redundant] + {-1, 10} // x <= 10. [redundant] + }), + }), + {{-6}, + {-5}, + {-4}, + {-9}, + {-10}, + {-11}, + {0}, + {1}, + {2}, + {3}, + {4}, + {5}, + {6}, + {7}, + {8}}); +} + +TEST(SetTest, Complement) { + // Complement of universe. + testComplementAtPoints( + PresburgerSet::universe(1), + {{-1}, {-2}, {-8}, {1}, {2}, {8}, {9}, {10}, {20}, {21}}); + + // Complement of empty set. + testComplementAtPoints( + PresburgerSet::emptySet(1), + {{-1}, {-2}, {-8}, {1}, {2}, {8}, {9}, {10}, {20}, {21}}); + + testComplementAtPoints( + makeSetFromFACs(2, {makeFACFromIneq(2, + { + {1, 0, -2}, // x >= 2. + {0, 1, -2}, // y >= 2. + {-1, 0, 10}, // x <= 10. + {0, -1, 10} // y <= 10. + })}), + {{1, 1}, + {2, 1}, + {1, 2}, + {2, 2}, + {2, 3}, + {3, 2}, + {10, 10}, + {10, 11}, + {11, 10}, + {2, 10}, + {2, 11}, + {1, 10}}); +} + +} // namespace mlir