diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -331,13 +331,33 @@ // Add identifiers of the specified kind - specified positions are relative to // the kind of identifier. The coefficient column corresponding to the added - // identifier is initialized to zero. 'id' is the Value corresponding to the - // identifier that can optionally be provided. + // identifier is initialized to zero. 'ids' are the Values corresponding to + // the identifiers which can optionally be provided. void addDimId(unsigned pos, Value id = nullptr); void addSymbolId(unsigned pos, Value id = nullptr); void addLocalId(unsigned pos); void addId(IdKind kind, unsigned pos, Value id = nullptr); + void addDimIds(unsigned pos, unsigned count, ArrayRef newIds = {}); + void addSymbolIds(unsigned pos, unsigned count, ArrayRef newIds = {}); + void addLocalIds(unsigned pos, unsigned count); + void addIds(IdKind kind, unsigned pos, unsigned count, + ArrayRef newIds = {}); + + // Removes identifiers of the specified kind with the specified pos (or within + // the specified range) from the system. The specified location is relative to + // the first identifier of the specified kind. + void removeId(IdKind kind, unsigned pos); + void removeIdRange(IdKind kind, unsigned idStart, unsigned idLimit); + + /// Removes the specified identifier from the system. + void removeId(unsigned pos); + + /// Removes identifiers in the column range [idStart, idLimit), and copies any + /// remaining valid data into place, updates member variables, and resizes + /// arrays as needed. + void removeIdRange(unsigned idStart, unsigned idLimit); + /// Add the specified values as a dim or symbol id depending on its nature, if /// it already doesn't exist in the system. 
`id' has to be either a terminal /// symbol or a loop IV, i.e., it cannot be the result affine.apply of any @@ -348,7 +368,7 @@ /// ones) is allowed. void addInductionVarOrTerminalSymbol(Value id, bool allowNonTerminal = false); - /// Composes the affine value map with this FlatAffineConstrains, adding the + /// Composes the affine value map with this FlatAffineConstraints, adding the /// results of the map as dimensions at the front [0, vMap->getNumResults()) /// and with the dimensions set to the equalities specified by the value map. /// Returns failure if the composition fails (when vMap is a semi-affine map). @@ -376,12 +396,12 @@ /// Projects out the identifier that is associate with Value . void projectOut(Value id); - /// Removes the specified identifier from the system. - void removeId(unsigned pos); - void removeEquality(unsigned pos); void removeInequality(unsigned pos); + void removeEqualityRange(unsigned start, unsigned end); + void removeInequalityRange(unsigned start, unsigned end); + /// Changes the partition between dimensions and symbols. Depending on the new /// symbol count, either a chunk of trailing dimensional identifiers becomes /// symbols, or some of the leading symbols become dimensions. @@ -586,6 +606,12 @@ void dump() const; private: + /// Return the index at which the specified kind of id starts. + unsigned getIdOffset(IdKind kind) const; + + /// Assert that `value` is at most the number of ids of the specified kind. + void assertAtMostNumKind(unsigned value, IdKind kind) const; + /// Returns false if the fields corresponding to various identifier counts, or /// equality/inequality buffer sizes aren't consistent; true otherwise. This /// is meant to be used within an assert internally. @@ -635,11 +661,6 @@ /// Normalized each constraints by the GCD of its coefficients. 
void normalizeConstraintsByGCD(); - /// Removes identifiers in the column range [idStart, idLimit), and copies any - /// remaining valid data into place, updates member variables, and resizes - /// arrays as needed. - void removeIdRange(unsigned idStart, unsigned idLimit); - /// Total number of identifiers. unsigned numIds; diff --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h --- a/mlir/include/mlir/Analysis/Presburger/Matrix.h +++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h @@ -118,6 +118,7 @@ /// Resize the matrix to the specified dimensions. If a dimension is smaller, /// the values are truncated; if it is bigger, the new values are default /// initialized. + void resize(unsigned newNRows, unsigned newNColumns); void resizeVertically(unsigned newNRows); /// Add an extra row at the bottom of the matrix and return its position. diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h --- a/mlir/include/mlir/Analysis/Presburger/Simplex.h +++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h @@ -162,6 +162,10 @@ /// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0. void addEquality(ArrayRef coeffs); + // Add new variables to the end of the list of variables. + void appendVariable(); + void appendVariables(unsigned count); + /// Mark the tableau as being empty. void markEmpty(); @@ -301,8 +305,9 @@ /// and the denominator. void normalizeRow(unsigned row); - /// Swap the two rows in the tableau and associated data structures. + /// Swap the two rows/columns in the tableau and associated data structures. void swapRows(unsigned i, unsigned j); + void swapColumns(unsigned i, unsigned j); /// Restore the unknown to a non-negative sample value. /// @@ -327,6 +332,7 @@ /// Enum to denote operations that need to be undone during rollback. 
enum class UndoLogEntry {
     RemoveLastConstraint,
+    RemoveLastVariable,
     UnmarkEmpty,
     UnmarkLastRedundant
   };
diff --git a/mlir/include/mlir/Analysis/PresburgerSet.h b/mlir/include/mlir/Analysis/PresburgerSet.h
--- a/mlir/include/mlir/Analysis/PresburgerSet.h
+++ b/mlir/include/mlir/Analysis/PresburgerSet.h
@@ -72,8 +72,8 @@
   PresburgerSet complement() const;
 
   /// Return the set difference of this set and the given set, i.e.,
-  /// return `this \ set`. Subtracting when either set contains divisions is not
-  /// yet supported.
+  /// return `this \ set`. All local variables in `set` must correspond
+  /// to floor divisions, otherwise this function's behaviour is undefined.
   PresburgerSet subtract(const PresburgerSet &set) const;
 
   /// Return true if this set is equal to the given set, and false otherwise.
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -235,39 +235,95 @@
   addId(IdKind::Symbol, pos, id);
 }
 
-/// Adds a dimensional identifier. The added column is initialized to
-/// zero.
 void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value id) {
+  // Forward the caller's `kind`. A null `id` means that no Value was provided,
+  // so pass no ids at all and let addIds() record None for the new column.
+  if (id)
+    addIds(kind, pos, 1, id);
+  else
+    addIds(kind, pos, 1);
+}
+
+/// Adds `count` local identifiers at position `pos` among the locals.
+void FlatAffineConstraints::addLocalIds(unsigned pos, unsigned count) {
+  addIds(IdKind::Local, pos, count);
+}
+
+/// Adds `count` dimensional identifiers at position `pos` among the dims.
+void FlatAffineConstraints::addDimIds(unsigned pos, unsigned count,
+                                      ArrayRef newIds) {
+  addIds(IdKind::Dimension, pos, count, newIds);
+}
+
+/// Adds `count` symbolic identifiers at position `pos` among the symbols.
+void FlatAffineConstraints::addSymbolIds(unsigned pos, unsigned count,
+                                         ArrayRef newIds) {
+  addIds(IdKind::Symbol, pos, count, newIds);
+}
+
+/// Returns the column at which identifiers of the given kind start: dimensions
+/// come first, then symbols, then locals.
+unsigned FlatAffineConstraints::getIdOffset(IdKind kind) const {
+  if (kind == IdKind::Dimension)
+    return 0;
+  if (kind == IdKind::Symbol)
+    return getNumDimIds();
+  if (kind == IdKind::Local)
+    return getNumDimIds() + getNumSymbolIds();
+  llvm_unreachable("IdKind expected to be Dimension, Symbol or Local!");
+}
+
+/// Asserts that `val` does not exceed the number of identifiers of `kind`.
+void FlatAffineConstraints::assertAtMostNumKind(unsigned val,
+                                                IdKind kind) const {
   if (kind == IdKind::Dimension)
-    assert(pos <= getNumDimIds());
+    assert(val <= getNumDimIds());
   else if (kind == IdKind::Symbol)
-    assert(pos <= getNumSymbolIds());
+    assert(val <= getNumSymbolIds());
+  else if (kind == IdKind::Local)
+    assert(val <= getNumLocalIds());
   else
-    assert(pos <= getNumLocalIds());
-
-  int absolutePos;
-  if (kind == IdKind::Dimension) {
-    absolutePos = pos;
-    numDims++;
-  } else if (kind == IdKind::Symbol) {
-    absolutePos = pos + getNumDimIds();
-    numSymbols++;
-  } else {
-    absolutePos = pos + getNumDimIds() + getNumSymbolIds();
-  }
-  numIds++;
+    llvm_unreachable("IdKind expected to be Dimension, Symbol or Local!");
+}
 
-  inequalities.insertColumn(absolutePos);
-  equalities.insertColumn(absolutePos);
+/// Adds `count` identifiers of the specified kind. The added columns are
+/// initialized to zero.
+void FlatAffineConstraints::addIds(IdKind kind, unsigned pos, unsigned count,
+                                   ArrayRef newIds) {
+  assertAtMostNumKind(pos, kind);
+  assert(newIds.empty() || newIds.size() == count);
+  unsigned absolutePos = pos + getIdOffset(kind);
+  if (kind == IdKind::Dimension)
+    numDims += count;
+  else if (kind == IdKind::Symbol)
+    numSymbols += count;
+  numIds += count;
 
-  // If an 'id' is provided, insert it; otherwise use None.
-  if (id)
-    ids.insert(ids.begin() + absolutePos, id);
+  inequalities.insertColumns(absolutePos, count);
+  equalities.insertColumns(absolutePos, count);
+
+  // If ids are provided, insert them; otherwise use None.
+  if (newIds.empty())
+    ids.insert(ids.begin() + absolutePos, count, None);
   else
-    ids.insert(ids.begin() + absolutePos, None);
+    ids.insert(ids.begin() + absolutePos, newIds.begin(), newIds.end());
   assert(ids.size() == getNumIds());
 }
 
+void FlatAffineConstraints::removeId(IdKind kind, unsigned pos) {
+  removeIdRange(kind, pos, pos + 1);
+}
+
+void FlatAffineConstraints::removeIdRange(IdKind kind, unsigned idStart,
+                                          unsigned idLimit) {
+  assert(idStart <= idLimit);
+  assertAtMostNumKind(idLimit, kind);
+  if (idStart == idLimit)
+    return;
+
+  removeIdRange(getIdOffset(kind) + idStart, getIdOffset(kind) + idLimit);
+}
+
 /// Checks if two constraint systems are in the same space, i.e., if they are
 /// associated with the same set of identifiers, appearing in the same order.
 static bool areIdsAligned(const FlatAffineConstraints &a,
@@ -2310,6 +2355,23 @@
   inequalities.removeRow(pos);
 }
 
+/// Removes `end - begin` equality rows starting at row `begin`. An empty or
+/// inverted range is a no-op, which keeps `end - begin` from wrapping around.
+void FlatAffineConstraints::removeEqualityRange(unsigned begin, unsigned end) {
+  if (begin >= end)
+    return;
+  equalities.removeRows(begin, end - begin);
+}
+
+/// Removes `end - begin` inequality rows starting at row `begin`. An empty or
+/// inverted range is a no-op, which keeps `end - begin` from wrapping around.
+void FlatAffineConstraints::removeInequalityRange(unsigned begin,
+                                                  unsigned end) {
+  if (begin >= end)
+    return;
+  inequalities.removeRows(begin, end - begin);
+}
+
 /// Finds an equality that equates the specified identifier to a constant.
 /// Returns the position of the equality row.
If 'symbolic' is set to true,
 /// symbols are also treated like a constant, i.e., an affine function of the
diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp
--- a/mlir/lib/Analysis/Presburger/Matrix.cpp
+++ b/mlir/lib/Analysis/Presburger/Matrix.cpp
@@ -65,6 +65,15 @@
   return nRows - 1;
 }
 
+void Matrix::resize(unsigned newNRows, unsigned newNColumns) {
+  // Adjust the column count first, then let resizeVertically() set the rows.
+  if (newNColumns < nColumns)
+    removeColumns(newNColumns, nColumns - newNColumns);
+  if (newNColumns > nColumns)
+    insertColumns(nColumns, newNColumns - nColumns);
+  resizeVertically(newNRows);
+}
+
 void Matrix::resizeVertically(unsigned newNRows) {
   nRows = newNRows;
   data.resize(nRows * nReservedColumns);
diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -345,6 +345,17 @@
   unknownFromRow(j).pos = j;
 }
 
+/// Swap columns i and j of the tableau, and keep colUnknown and the `pos` of
+/// the two affected unknowns in sync with the new layout.
+void Simplex::swapColumns(unsigned i, unsigned j) {
+  if (i == j)
+    return;
+  tableau.swapColumns(i, j);
+  std::swap(colUnknown[i], colUnknown[j]);
+  unknownFromColumn(i).pos = i;
+  unknownFromColumn(j).pos = j;
+}
+
 /// Mark this tableau empty and push an entry to the undo stack.
 void Simplex::markEmpty() {
   undoLog.push_back(UndoLogEntry::UnmarkEmpty);
@@ -434,6 +445,16 @@
       nRow--;
       rowUnknown.pop_back();
       con.pop_back();
+    } else if (entry == UndoLogEntry::RemoveLastVariable) {
+      assert(var.back().orientation == Orientation::Column &&
+             "Variable to be removed must be in column orientation!");
+      // Move the variable's column to the end so that shrinking the tableau
+      // by one column drops exactly that column.
+      swapColumns(var.back().pos, nCol - 1);
+      tableau.resize(nRow, nCol - 1);
+      var.pop_back();
+      colUnknown.pop_back();
+      nCol--;
     } else if (entry == UndoLogEntry::UnmarkEmpty) {
       empty = false;
     } else if (entry == UndoLogEntry::UnmarkLastRedundant) {
@@ -452,6 +473,22 @@
   }
 }
 
+/// Append one unrestricted variable in column orientation as the last column,
+/// logging RemoveLastVariable so that rollback can drop it again.
+void Simplex::appendVariable() {
+  undoLog.emplace_back(UndoLogEntry::RemoveLastVariable);
+  nCol++;
+  tableau.resize(nRow, nCol);
+  var.emplace_back(Orientation::Column, /*restricted=*/false, /*pos=*/nCol - 1);
+  colUnknown.push_back(var.size() - 1);
+}
+
+/// Append `count` variables one at a time; each gets its own undo log entry.
+void Simplex::appendVariables(unsigned count) {
+  for (unsigned i = 0; i < count; ++i)
+    appendVariable();
+}
+
 /// Add all the constraints from the given FlatAffineConstraints.
 void Simplex::intersectFlatAffineConstraints(const FlatAffineConstraints &fac) {
   assert(fac.getNumIds() == numVariables() &&
diff --git a/mlir/lib/Analysis/PresburgerSet.cpp b/mlir/lib/Analysis/PresburgerSet.cpp
--- a/mlir/lib/Analysis/PresburgerSet.cpp
+++ b/mlir/lib/Analysis/PresburgerSet.cpp
@@ -102,20 +102,35 @@
   return PresburgerSet(nDim, nSym);
 }
 
+// Let a and b have n and m locals respectively.
+// Add m locals to a, after a's locals.
+// Also add n locals to b, before b's locals.
+static void addLocalsToEachOther(FlatAffineConstraints &a,
+                                 FlatAffineConstraints &b) {
+  unsigned initANumLocals = a.getNumLocalIds();
+  a.addLocalIds(a.getNumLocalIds(), b.getNumLocalIds());
+  // `a` has grown by now, so give `b` the count `a` had initially.
+  b.addLocalIds(0, initANumLocals);
+}
+
 // Return the intersection of this set with the given set.
 //
 // We directly compute (S_1 or S_2 ...) and (T_1 or T_2 ...)
// as (S_1 and T_1) or (S_1 and T_2) or ... +// +// If S_i or T_j have local variables, then S_i and T_j contains the local +// variables of both. PresburgerSet PresburgerSet::intersect(const PresburgerSet &set) const { assertDimensionsCompatible(set, *this); PresburgerSet result(nDim, nSym); for (const FlatAffineConstraints &csA : flatAffineConstraints) { for (const FlatAffineConstraints &csB : set.flatAffineConstraints) { - FlatAffineConstraints intersection(csA); - intersection.append(csB); - if (!intersection.isEmpty()) - result.unionFACInPlace(std::move(intersection)); + FlatAffineConstraints csACopy = csA, csBCopy = csB; + addLocalsToEachOther(csACopy, csBCopy); + csACopy.append(std::move(csBCopy)); + if (!csACopy.isEmpty()) + result.unionFACInPlace(std::move(csACopy)); } } return result; @@ -133,7 +148,8 @@ /// Return the complement of the given inequality. /// /// The complement of a_1 x_1 + ... + a_n x_ + c >= 0 is -/// a_1 x_1 + ... + a_n x_ + c < 0, i.e., -a_1 x_1 - ... - a_n x_ - c - 1 >= 0. +/// a_1 x_1 + ... + a_n x_ + c < 0, i.e., -a_1 x_1 - ... - a_n x_ - c - 1 >= 0, +/// since all the variables are constrained to be integers. static SmallVector getComplementIneq(ArrayRef ineq) { SmallVector coeffs; coeffs.reserve(ineq.size()); @@ -146,22 +162,37 @@ /// Return the set difference b \ s and accumulate the result into `result`. /// `simplex` must correspond to b. /// -/// In the following, V denotes union, ^ denotes intersection, \ denotes set +/// In the following, U denotes union, ^ denotes intersection, \ denotes set /// difference and ~ denotes complement. -/// Let b be the FlatAffineConstraints and s = (V_i s_i) be the set. We want -/// b \ (V_i s_i). +/// Let b be the FlatAffineConstraints and s = (U_i s_i) be the set. We want +/// b \ (U_i s_i). /// /// Let s_i = ^_j s_ij, where each s_ij is a single inequality. 
To compute /// b \ s_i = b ^ ~s_i, we partition s_i based on the first violated inequality: -/// ~s_i = (~s_i1) V (s_i1 ^ ~s_i2) V (s_i1 ^ s_i2 ^ ~s_i3) V ... -/// And the required result is (b ^ ~s_i1) V (b ^ s_i1 ^ ~s_i2) V ... -/// We recurse by subtracting V_{j > i} S_j from each of these parts and +/// ~s_i = (~s_i1) U (s_i1 ^ ~s_i2) U (s_i1 ^ s_i2 ^ ~s_i3) U ... +/// And the required result is (b ^ ~s_i1) U (b ^ s_i1 ^ ~s_i2) U ... +/// We recurse by subtracting U_{j > i} s_j from each of these parts and /// returning the union of the results. Each equality is handled as a /// conjunction of two inequalities. /// +/// Note that the same approach works even if an inequality involves a floor +/// division. For example, the complement of x <= 7*floor(x/7) is still +/// x > 7*floor(x/7). Since b \ s_i contains the inequalities of both b and s_i +/// (or the complements of those inequalities), b \ s_i may contain the +/// divisions present in both b and s_i. Therefore, we need to add the local +/// division variables of both b and s_i to each part in the result. This means +/// adding the local variables of both b and s_i, as well as the corresponding +/// division inequalities to each part. Since the division inequalities are +/// added to each part, we can skip the parts where the complement of any +/// division inequality is added, as these parts will become empty anyway. +/// /// As a heuristic, we try adding all the constraints and check if simplex -/// says that the intersection is empty. Also, in the process we find out that -/// some constraints are redundant. These redundant constraints are ignored. +/// says that the intersection is empty. If it is, then subtracting this FAC is +/// a no-op and we just skip it. Also, in the process we find out that some +/// constraints are redundant. These redundant constraints are ignored.
+///
+/// b and simplex are callee saved, i.e., their values on return are
+/// semantically equivalent to their values when the function is called.
 static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
                                 const PresburgerSet &s, unsigned i,
                                 PresburgerSet &result) {
@@ -169,27 +200,75 @@
     result.unionFACInPlace(b);
     return;
   }
-  const FlatAffineConstraints &sI = s.getFlatAffineConstraints(i);
-  assert(sI.getNumLocalIds() == 0 &&
-         "Subtracting sets with divisions is not yet supported!");
+  FlatAffineConstraints sI = s.getFlatAffineConstraints(i);
+
+  // Locals and constraints are appended to b below. Record b's initial counts
+  // before anything is added, so that every exit path can restore b exactly;
+  // the recursion relies on b being callee saved.
+  const unsigned bInitNumLocals = b.getNumLocalIds();
+  const unsigned bInitNumIneqs = b.getNumInequalities();
+  const unsigned bInitNumEqs = b.getNumEqualities();
+
+  // Find out which inequalities of sI correspond to division inequalities for
+  // the local variables of sI.
+  std::vector>> repr(
+      sI.getNumLocalIds());
+  sI.getLocalReprLbUbPairs(repr);
+
+  // Add sI's locals to b, after b's locals. Also add b's locals to sI, before
+  // sI's locals.
+  addLocalsToEachOther(b, sI);
+
+  // Mark which inequalities of sI are division inequalities and add all such
+  // inequalities to b.
+  llvm::SmallBitVector isDivInequality(sI.getNumInequalities());
+  for (Optional> &maybePair : repr) {
+    if (maybePair) {
+      b.addInequality(sI.getInequality(maybePair->first));
+      b.addInequality(sI.getInequality(maybePair->second));
+
+      assert(maybePair->first != maybePair->second &&
+             "Upper and lower bounds must be different inequalities!");
+      isDivInequality[maybePair->first] = true;
+      isDivInequality[maybePair->second] = true;
+    } else {
+      llvm_unreachable(
+          "Subtraction is not supported when a representation of "
+          "the local variables of the subtrahend cannot be found!");
+    }
+  }
+
   unsigned initialSnapshot = simplex.getSnapshot();
   unsigned offset = simplex.numConstraints();
+  unsigned numLocalsAdded = b.getNumLocalIds() - bInitNumLocals;
+  simplex.appendVariables(numLocalsAdded);
+
+  unsigned snapshotBeforeIntersect = simplex.getSnapshot();
   simplex.intersectFlatAffineConstraints(sI);
 
   if (simplex.isEmpty()) {
     /// b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1.
     simplex.rollback(initialSnapshot);
+    // Also undo everything appended to b above: the merged locals and the
+    // division inequalities.
+    b.removeIdRange(FlatAffineConstraints::IdKind::Local, bInitNumLocals,
+                    b.getNumLocalIds());
+    b.removeInequalityRange(bInitNumIneqs, b.getNumInequalities());
+    b.removeEqualityRange(bInitNumEqs, b.getNumEqualities());
     subtractRecursively(b, simplex, s, i + 1, result);
     return;
   }
 
   simplex.detectRedundant();
-  llvm::SmallBitVector isMarkedRedundant;
-  for (unsigned j = 0; j < 2 * sI.getNumEqualities() + sI.getNumInequalities();
-       j++)
-    isMarkedRedundant.push_back(simplex.isMarkedRedundant(offset + j));
 
-  simplex.rollback(initialSnapshot);
+  // Equalities are added to simplex as a pair of inequalities.
+  unsigned totalNewSimplexInequalities =
+      2 * sI.getNumEqualities() + sI.getNumInequalities();
+  llvm::SmallBitVector isMarkedRedundant(totalNewSimplexInequalities);
+  for (unsigned j = 0; j < totalNewSimplexInequalities; j++)
+    isMarkedRedundant[j] = simplex.isMarkedRedundant(offset + j);
+
+  simplex.rollback(snapshotBeforeIntersect);
 
   // Recurse with the part b ^ ~ineq. Note that b is modified throughout
   // subtractRecursively. At the time this function is called, the current b is
@@ -218,20 +297,27 @@
   // rollback b to its initial state before returning, which we will do by
   // removing all constraints beyond the original number of inequalities
   // and equalities, so we store these counts first.
-  unsigned originalNumIneqs = b.getNumInequalities();
-  unsigned originalNumEqs = b.getNumEqualities();
+  // (bInitNumIneqs and bInitNumEqs are recorded at the top of this function,
+  // before the division inequalities are added to b.)
 
+  // Process all the inequalities, ignoring redundant inequalities and division
+  // inequalities. The result is correct whether or not we ignore these, but
+  // ignoring them makes the result simpler.
   for (unsigned j = 0, e = sI.getNumInequalities(); j < e; j++) {
     if (isMarkedRedundant[j])
       continue;
+    if (isDivInequality[j])
+      continue;
     processInequality(sI.getInequality(j));
   }
 
   offset = sI.getNumInequalities();
   for (unsigned j = 0, e = sI.getNumEqualities(); j < e; ++j) {
     const ArrayRef &coeffs = sI.getEquality(j);
-    // Same as the above loop for inequalities, done once each for the positive
-    // and negative inequalities that make up this equality.
+    // Similar to the above loop for inequalities, done once each for the
+    // positive and negative inequalities that make up this equality. Divisions
+    // are always represented in terms of inequalities and not equalities, so we
+    // do not check that here.
     if (!isMarkedRedundant[offset + 2 * j])
       processInequality(coeffs);
     if (!isMarkedRedundant[offset + 2 * j + 1])
@@ -239,11 +325,10 @@
   }
 
   // Rollback b and simplex to their initial states.
-  for (unsigned i = b.getNumInequalities(); i > originalNumIneqs; --i)
-    b.removeInequality(i - 1);
-
-  for (unsigned i = b.getNumEqualities(); i > originalNumEqs; --i)
-    b.removeEquality(i - 1);
+  b.removeIdRange(FlatAffineConstraints::IdKind::Local, bInitNumLocals,
+                  b.getNumLocalIds());
+  b.removeInequalityRange(bInitNumIneqs, b.getNumInequalities());
+  b.removeEqualityRange(bInitNumEqs, b.getNumEqualities());
 
   simplex.rollback(initialSnapshot);
 }
@@ -256,8 +341,6 @@
 PresburgerSet PresburgerSet::getSetDifference(FlatAffineConstraints fac,
                                               const PresburgerSet &set) {
   assertDimensionsCompatible(fac, set);
-  assert(fac.getNumLocalIds() == 0 &&
-         "Subtracting sets with divisions is not yet supported!");
 
   if (fac.isEmptyByGCDTest())
     return PresburgerSet::getEmptySet(fac.getNumDimIds(), fac.getNumSymbolIds());
diff --git a/mlir/unittests/Analysis/PresburgerSetTest.cpp b/mlir/unittests/Analysis/PresburgerSetTest.cpp
--- a/mlir/unittests/Analysis/PresburgerSetTest.cpp
+++ b/mlir/unittests/Analysis/PresburgerSetTest.cpp
@@ -82,9 +82,11 @@
 /// Construct a FlatAffineConstraints from a set of
inequality and
 /// equality constraints.
 static FlatAffineConstraints
-makeFACFromConstraints(unsigned dims, ArrayRef> ineqs,
-                       ArrayRef> eqs) {
-  FlatAffineConstraints fac(ineqs.size(), eqs.size(), dims + 1, dims);
+makeFACFromConstraints(unsigned ids, ArrayRef> ineqs,
+                       ArrayRef> eqs,
+                       unsigned locals = 0) {
+  FlatAffineConstraints fac(ineqs.size(), eqs.size(), ids + 1, ids - locals, 0,
+                            locals);
   for (const SmallVector &eq : eqs)
     fac.addEquality(eq);
   for (const SmallVector &ineq : ineqs)
@@ -591,4 +593,38 @@
   EXPECT_FALSE(rect.complement().isEqual(square.complement()));
 }
 
+/// Expect that `s.isEqual(t)` holds. Takes the sets by reference to avoid
+/// copying them for every check.
+void expectEqual(const PresburgerSet &s, const PresburgerSet &t) {
+  EXPECT_TRUE(s.isEqual(t));
+}
+
+void expectEmpty(const PresburgerSet &s) { EXPECT_TRUE(s.isIntegerEmpty()); }
+
+TEST(SetTest, divisions) {
+  // Note: we currently need to add the equalities as inequalities to the FAC
+  // since detecting divisions based on equalities is not yet supported.
+  //
+  // evens = {x : exists q, x = 2q}.
+  PresburgerSet evens{
+      makeFACFromConstraints(2, {{1, -2, 0}, {-1, 2, 1}}, {{1, -2, 0}}, 1)};
+  // odds = {x : exists q, x = 2q + 1}.
+  PresburgerSet odds{
+      makeFACFromConstraints(2, {{1, -2, 0}, {-1, 2, 1}}, {{1, -2, -1}}, 1)};
+  // multiples3 = {x : exists q, x = 3q}.
+  PresburgerSet multiples3{
+      makeFACFromConstraints(2, {{1, -3, 0}, {-1, 3, 2}}, {{1, -3, 0}}, 1)};
+  // multiples6 = {x : exists q, x = 6q}.
+  PresburgerSet multiples6{
+      makeFACFromConstraints(2, {{1, -6, 0}, {-1, 6, 5}}, {{1, -6, 0}}, 1)};
+
+  // NOTE(review): the next two expectations are disabled in this patch;
+  // confirm whether they are expected to hold before enabling them.
+  // expectEmpty(PresburgerSet(evens).intersect(PresburgerSet(odds)));
+  // expectEqual(evens.unionSet(odds), PresburgerSet::getUniverse(1));
+  expectEqual(evens.complement(), odds);
+  expectEqual(odds.complement(), evens);
+  expectEqual(multiples3.intersect(evens), multiples6);
+}
+
 } // namespace mlir