diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h --- a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h @@ -120,6 +120,8 @@ /// The list of disjuncts that this set is the union of. SmallVector integerRelations; + + friend class SetCoalescer; }; class PresburgerSet : public PresburgerRelation { diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp --- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp @@ -393,16 +393,162 @@ return result; } -/// Given an IntegerRelation `p` and one of its inequalities `ineq`, check -/// that all inequalities of `cuttingIneqs` are redundant for the facet of `p` -/// where `ineq` holds as an equality. `simp` must be the Simplex constructed -/// from `p`. -static bool isFacetContained(ArrayRef ineq, Simplex &simp, - IntegerRelation &p, - ArrayRef> cuttingIneqs) { +/// The SetCoalescer class contains all functionality concerning the coalesce +/// heuristic. It is built from a `PresburgerRelation` and has the `coalesce()` +/// function as its main API. The coalesce heuristic simplifies the +/// representation of a PresburgerRelation. In particular, it removes all +/// disjuncts which are subsets of other disjuncts in the union and it combines +/// sets that overlap and can be combined in a convex way. +class presburger::SetCoalescer { + +public: + /// Simplifies the representation of a PresburgerSet. + PresburgerRelation coalesce(); + + /// Construct a SetCoalescer from a PresburgerSet. + SetCoalescer(const PresburgerRelation &s); + +private: + /// The dimensionality of the set the SetCoalescer is coalescing. + unsigned numDomainIds; + unsigned numRangeIds; + unsigned numSymbolIds; + + /// The current list of `IntegerRelation`s that the currently coalesced set is + /// the union of. + SmallVector disjuncts; + /// The list of `Simplex`s constructed from the elements of `disjuncts`. + SmallVector simplices; + + /// The list of all inversed equalities during typing. This ensures that + /// the constraints exist even after the typing function has concluded. + SmallVector, 2> negEqs; + + /// `redundantIneqsA` is the inequalities of `a` that are redundant for `b` + /// (similarly for `cuttingIneqsA`, `redundantIneqsB`, and `cuttingIneqsB`). + SmallVector, 2> redundantIneqsA; + SmallVector, 2> cuttingIneqsA; + + SmallVector, 2> redundantIneqsB; + SmallVector, 2> cuttingIneqsB; + + /// Given a Simplex `simp` and one of its inequalities `ineq`, check + /// that the facet of `simp` where `ineq` holds as an equality is contained + /// within `a`. + bool isFacetContained(ArrayRef ineq, Simplex &simp); + + /// Adds `disjunct` to `disjuncts` and removes the disjuncts at position `i` + /// and `j`. Updates `simplices` to reflect the changes. `i` and `j` cannot + /// be equal. + void addCoalescedDisjunct(unsigned i, unsigned j, + const IntegerRelation &disjunct); + + /// Checks whether `a` and `b` can be combined in a convex sense, if there + /// exist cutting inequalities. + /// + /// An example of this case: + /// ___________ ___________ + /// / / | / / / + /// \ \ | / ==> \ / + /// \ \ | / \ / + /// \___\|/ \_____/ + /// + /// + LogicalResult coalescePairCutCase(unsigned i, unsigned j); + + /// Types the inequality `ineq` according to its `IneqType` for `simp` into + /// `redundantIneqsB` and `cuttingIneqsB`. Returns success, if no separate + /// inequalities were encountered. Otherwise, returns failure. + LogicalResult typeInequality(ArrayRef ineq, Simplex &simp); + + /// Types the equality `eq`, i.e. for `eq` == 0, types both `eq` >= 0 and + /// -`eq` >= 0 according to their `IneqType` for `simp` into + /// `redundantIneqsB` and `cuttingIneqsB`. Returns success, if no separate + /// inequalities were encountered. Otherwise, returns failure. + LogicalResult typeEquality(ArrayRef eq, Simplex &simp); + + /// Replaces the element at position `i` with the last element and erases + /// the last element for both `disjuncts` and `simplices`. + void eraseDisjunct(unsigned i); + + /// Attempts to coalesce the two IntegerRelations at position `i` and `j` + /// in `disjuncts` in-place. Returns whether the disjuncts were + /// successfully coalesced. The simplices in `simplices` need to be the ones + /// constructed from `disjuncts`. At this point, there are no empty + /// disjuncts in `disjuncts` left. + LogicalResult coalescePair(unsigned i, unsigned j); +}; + +/// Constructs a `SetCoalescer` from a `PresburgerRelation`. Only adds non-empty +/// `IntegerRelation`s to the `disjuncts` vector. +SetCoalescer::SetCoalescer(const PresburgerRelation &s) { + + disjuncts = s.integerRelations; + + simplices.reserve(s.getNumDisjuncts()); + // Note that disjuncts.size() changes during the loop. + for (unsigned i = 0; i < disjuncts.size();) { + Simplex simp(disjuncts[i]); + if (simp.isEmpty()) { + disjuncts[i] = disjuncts[disjuncts.size() - 1]; + disjuncts.pop_back(); + continue; + } + ++i; + simplices.push_back(simp); + } + numDomainIds = s.getNumDomainIds(); + numRangeIds = s.getNumRangeIds(); + numSymbolIds = s.getNumSymbolIds(); +} + +/// Simplifies the representation of a PresburgerSet. +PresburgerRelation SetCoalescer::coalesce() { + // For all tuples of IntegerRelations, check whether they can be + // coalesced. When coalescing is successful, the contained IntegerRelation + // is swapped with the last element of `disjuncts` and subsequently erased + // and similarly for simplices. + for (unsigned i = 0; i < disjuncts.size();) { + + // TODO: This does some comparisons two times (index 0 with 1 and index 1 + // with 0). + bool broken = false; + for (unsigned j = 0, e = disjuncts.size(); j < e; ++j) { + negEqs.clear(); + redundantIneqsA.clear(); + redundantIneqsB.clear(); + cuttingIneqsA.clear(); + cuttingIneqsB.clear(); + if (i == j) + continue; + if (coalescePair(i, j).succeeded()) { + broken = true; + break; + } + } + + // Only if the inner loop was not broken, i is incremented. This is + // required as otherwise, if a coalescing occurs, the IntegerRelation + // now at position i is not compared. + if (!broken) + ++i; + } + + PresburgerRelation newSet = + PresburgerRelation::getEmpty(numDomainIds, numRangeIds, numSymbolIds); + for (unsigned i = 0, e = disjuncts.size(); i < e; ++i) + newSet.unionInPlace(disjuncts[i]); + + return newSet; +} + +/// Given a Simplex `simp` and one of its inequalities `ineq`, check +/// that all inequalities of `cuttingIneqsB` are redundant for the facet of +/// `simp` where `ineq` holds as an equality is contained within `a`. +bool SetCoalescer::isFacetContained(ArrayRef ineq, Simplex &simp) { unsigned snapshot = simp.getSnapshot(); simp.addEquality(ineq); - if (llvm::any_of(cuttingIneqs, [&simp](ArrayRef curr) { + if (llvm::any_of(cuttingIneqsB, [&simp](ArrayRef curr) { return !simp.isRedundantInequality(curr); })) { simp.rollback(snapshot); @@ -412,20 +558,14 @@ return true; } -/// Adds `disjunct` to `disjuncts` and removes the disjuncts at position `i` and -/// `j`. Updates `simplices` to reflect the changes. `i` and `j` cannot be -/// equal. -static void addCoalescedDisjunct(SmallVectorImpl &disjuncts, - unsigned i, unsigned j, - const IntegerRelation &disjunct, - SmallVectorImpl &simplices) { +void SetCoalescer::addCoalescedDisjunct(unsigned i, unsigned j, + const IntegerRelation &disjunct) { assert(i != j && "The indices must refer to different disjuncts"); - unsigned n = disjuncts.size(); if (j == n - 1) { - // This case needs special handling since position `n` - 1 is removed from - // the vector, hence the `IntegerRelation` at position `n` - 2 is lost - // otherwise. + // This case needs special handling since position `n` - 1 is removed + // from the vector, hence the `IntegerRelation` at position `n` - 2 is + // lost otherwise. disjuncts[i] = disjuncts[n - 2]; disjuncts.pop_back(); disjuncts[n - 2] = disjunct; @@ -435,10 +575,11 @@ simplices[n - 2] = Simplex(disjunct); } else { - // Other possible edge cases are correct since for `j` or `i` == `n` - 2, - // the `IntegerRelation` at position `n` - 2 should be lost. The case - // `i` == `n` - 1 makes the first following statement a noop. Hence, in this - // case the same thing is done as above, but with `j` rather than `i`. + // Other possible edge cases are correct since for `j` or `i` == `n` - + // 2, the `IntegerRelation` at position `n` - 2 should be lost. The + // case `i` == `n` - 1 makes the first following statement a noop. + // Hence, in this case the same thing is done as above, but with `j` + // rather than `i`. disjuncts[i] = disjuncts[n - 1]; disjuncts[j] = disjuncts[n - 2]; disjuncts.pop_back(); @@ -451,12 +592,13 @@ } } -/// Given two disjuncts `a` and `b` at positions `i` and `j` in `disjuncts` -/// and `redundantIneqsA` being the inequalities of `a` that are redundant for -/// `b` (similarly for `cuttingIneqsA`, `redundantIneqsB`, and `cuttingIneqsB`), -/// checks whether the facets of all cutting inequalites of `a` are contained in -/// `b`. If so, a new disjunct consisting of all redundant inequalites of `a` -/// and `b` and all equalities of both is created. +/// Given two polyhedra `a` and `b` at positions `i` and `j` in +/// `disjuncts` and `redundantIneqsA` being the inequalities of `a` that +/// are redundant for `b` (similarly for `cuttingIneqsA`, `redundantIneqsB`, +/// and `cuttingIneqsB`), Checks whether the facets of all cutting +/// inequalites of `a` are contained in `b`. If so, a new polyhedron +/// consisting of all redundant inequalites of `a` and `b` and all +/// equalities of both is created. /// /// An example of this case: /// ___________ ___________ @@ -466,20 +608,13 @@ /// \___\|/ \_____/ /// /// -static LogicalResult -coalescePairCutCase(SmallVectorImpl &disjuncts, - SmallVectorImpl &simplices, unsigned i, unsigned j, - ArrayRef> redundantIneqsA, - ArrayRef> cuttingIneqsA, - ArrayRef> redundantIneqsB, - ArrayRef> cuttingIneqsB) { +LogicalResult SetCoalescer::coalescePairCutCase(unsigned i, unsigned j) { /// All inequalities of `b` need to be redundant. We already know that the /// redundant ones are, so only the cutting ones remain to be checked. Simplex &simp = simplices[i]; IntegerRelation &disjunct = disjuncts[i]; - if (llvm::any_of(cuttingIneqsA, [&simp, &disjunct, - &cuttingIneqsB](ArrayRef curr) { - return !isFacetContained(curr, simp, disjunct, cuttingIneqsB); + if (llvm::any_of(cuttingIneqsA, [this, &simp](ArrayRef curr) { + return !isFacetContained(curr, simp); })) return failure(); IntegerRelation newSet(disjunct.getNumDomainIds(), disjunct.getNumRangeIds(), @@ -491,50 +626,33 @@ for (ArrayRef curr : redundantIneqsB) newSet.addInequality(curr); - addCoalescedDisjunct(disjuncts, i, j, newSet, simplices); + addCoalescedDisjunct(i, j, newSet); return success(); } -/// Types the inequality `ineq` according to its `IneqType` for `simp` into -/// `redundantIneqs` and `cuttingIneqs`. Returns success, if no separate -/// inequalities were encountered. Otherwise, returns failure. -static LogicalResult -typeInequality(ArrayRef ineq, Simplex &simp, - SmallVectorImpl> &redundantIneqs, - SmallVectorImpl> &cuttingIneqs) { +LogicalResult SetCoalescer::typeInequality(ArrayRef ineq, + Simplex &simp) { Simplex::IneqType type = simp.findIneqType(ineq); if (type == Simplex::IneqType::Redundant) - redundantIneqs.push_back(ineq); + redundantIneqsB.push_back(ineq); else if (type == Simplex::IneqType::Cut) - cuttingIneqs.push_back(ineq); + cuttingIneqsB.push_back(ineq); else return failure(); return success(); } -/// Types the equality `eq`, i.e. for `eq` == 0, types both `eq` >= 0 and -`eq` -/// >= 0 according to their `IneqType` for `simp` into `redundantIneqs` and -/// `cuttingIneqs`. Returns success, if no separate inequalities were -/// encountered. Otherwise, returns failure. -static LogicalResult -typeEquality(ArrayRef eq, Simplex &simp, - SmallVectorImpl> &redundantIneqs, - SmallVectorImpl> &cuttingIneqs, - SmallVectorImpl> &negEqs) { - if (typeInequality(eq, simp, redundantIneqs, cuttingIneqs).failed()) +LogicalResult SetCoalescer::typeEquality(ArrayRef eq, Simplex &simp) { + if (typeInequality(eq, simp).failed()) return failure(); negEqs.push_back(getNegatedCoeffs(eq)); ArrayRef inv(negEqs.back()); - if (typeInequality(inv, simp, redundantIneqs, cuttingIneqs).failed()) + if (typeInequality(inv, simp).failed()) return failure(); return success(); } -/// Replaces the element at position `i` with the last element and erases the -/// last element for both `disjuncts` and `simplices`. -static void eraseDisjunct(unsigned i, - SmallVectorImpl &disjuncts, - SmallVectorImpl &simplices) { +void SetCoalescer::eraseDisjunct(unsigned i) { assert(simplices.size() == disjuncts.size() && "simplices and disjuncts must be equally as long"); disjuncts[i] = disjuncts.back(); @@ -543,133 +661,73 @@ simplices.pop_back(); } -/// Attempts to coalesce the two IntegerRelations at position `i` and `j` in -/// `disjuncts` in-place. Returns whether the disjuncts were successfully -/// coalesced. The simplices in `simplices` need to be the ones constructed from -/// `disjuncts`. At this point, there are no empty disjuncts in -/// `disjuncts` left. -static LogicalResult coalescePair(unsigned i, unsigned j, - SmallVectorImpl &disjuncts, - SmallVectorImpl &simplices) { +LogicalResult SetCoalescer::coalescePair(unsigned i, unsigned j) { IntegerRelation &a = disjuncts[i]; IntegerRelation &b = disjuncts[j]; - /// Handling of local ids is not yet implemented, so these cases are skipped. + /// Handling of local ids is not yet implemented, so these cases are + /// skipped. /// TODO: implement local id support. if (a.getNumLocalIds() != 0 || b.getNumLocalIds() != 0) return failure(); Simplex &simpA = simplices[i]; Simplex &simpB = simplices[j]; - SmallVector, 2> redundantIneqsA; - SmallVector, 2> cuttingIneqsA; - SmallVector, 2> negEqs; - - // Organize all inequalities and equalities of `a` according to their type for - // `b` into `redundantIneqsA` and `cuttingIneqsA` (and vice versa for all - // inequalities of `b` according to their type in `a`). If a separate - // inequality is encountered during typing, the two IntegerRelations cannot - // be coalesced. + // Organize all inequalities and equalities of `a` according to their type + // for `b` into `redundantIneqsA` and `cuttingIneqsA` (and vice versa for + // all inequalities of `b` according to their type in `a`). If a separate + // inequality is encountered during typing, the two IntegerRelations + // cannot be coalesced. for (int k = 0, e = a.getNumInequalities(); k < e; ++k) - if (typeInequality(a.getInequality(k), simpB, redundantIneqsA, - cuttingIneqsA) - .failed()) + if (typeInequality(a.getInequality(k), simpB).failed()) return failure(); for (int k = 0, e = a.getNumEqualities(); k < e; ++k) - if (typeEquality(a.getEquality(k), simpB, redundantIneqsA, cuttingIneqsA, - negEqs) - .failed()) + if (typeEquality(a.getEquality(k), simpB).failed()) return failure(); - SmallVector, 2> redundantIneqsB; - SmallVector, 2> cuttingIneqsB; + std::swap(redundantIneqsA, redundantIneqsB); + std::swap(cuttingIneqsA, cuttingIneqsB); + for (int k = 0, e = b.getNumInequalities(); k < e; ++k) - if (typeInequality(b.getInequality(k), simpA, redundantIneqsB, - cuttingIneqsB) - .failed()) + if (typeInequality(b.getInequality(k), simpA).failed()) return failure(); for (int k = 0, e = b.getNumEqualities(); k < e; ++k) - if (typeEquality(b.getEquality(k), simpA, redundantIneqsB, cuttingIneqsB, - negEqs) - .failed()) + if (typeEquality(b.getEquality(k), simpA).failed()) return failure(); // If there are no cutting inequalities of `a`, `b` is contained - // within `a` (and vice versa for `b`). + // within `a`. if (cuttingIneqsA.empty()) { - eraseDisjunct(j, disjuncts, simplices); + eraseDisjunct(j); return success(); } - if (cuttingIneqsB.empty()) { - eraseDisjunct(i, disjuncts, simplices); + // Try to apply the cut case + if (coalescePairCutCase(i, j).succeeded()) return success(); - } - // Try to apply the cut case - if (coalescePairCutCase(disjuncts, simplices, i, j, redundantIneqsA, - cuttingIneqsA, redundantIneqsB, cuttingIneqsB) - .succeeded()) + // Swap the vectors to compare the pair (j,i) instead of (i,j). + std::swap(redundantIneqsA, redundantIneqsB); + std::swap(cuttingIneqsA, cuttingIneqsB); + + // If there are no cutting inequalities of `a`, `b` is contained + // within `a`. + if (cuttingIneqsA.empty()) { + eraseDisjunct(i); return success(); + } - if (coalescePairCutCase(disjuncts, simplices, j, i, redundantIneqsB, - cuttingIneqsB, redundantIneqsA, cuttingIneqsA) - .succeeded()) + // Try to apply the cut case + if (coalescePairCutCase(j, i).succeeded()) return success(); return failure(); } PresburgerRelation PresburgerRelation::coalesce() const { - PresburgerRelation newSet = PresburgerRelation::getEmpty( - getNumDomainIds(), getNumRangeIds(), getNumSymbolIds()); - SmallVector disjuncts = integerRelations; - SmallVector simplices; - - simplices.reserve(getNumDisjuncts()); - // Note that disjuncts.size() changes during the loop. - for (unsigned i = 0; i < disjuncts.size();) { - Simplex simp(disjuncts[i]); - if (simp.isEmpty()) { - disjuncts[i] = disjuncts[disjuncts.size() - 1]; - disjuncts.pop_back(); - continue; - } - ++i; - simplices.push_back(simp); - } - - // For all tuples of IntegerRelations, check whether they can be coalesced. - // When coalescing is successful, the contained IntegerRelation is swapped - // with the last element of `disjuncts` and subsequently erased and - // similarly for simplices. - for (unsigned i = 0; i < disjuncts.size();) { - - // TODO: This does some comparisons two times (index 0 with 1 and index 1 - // with 0). - bool broken = false; - for (unsigned j = 0, e = disjuncts.size(); j < e; ++j) { - if (i == j) - continue; - if (coalescePair(i, j, disjuncts, simplices).succeeded()) { - broken = true; - break; - } - } - - // Only if the inner loop was not broken, i is incremented. This is - // required as otherwise, if a coalescing occurs, the IntegerRelation - // now at position i is not compared. - if (!broken) - ++i; - } - - for (unsigned i = 0, e = disjuncts.size(); i < e; ++i) - newSet.unionInPlace(disjuncts[i]); - - return newSet; + return SetCoalescer(*this).coalesce(); } void PresburgerRelation::print(raw_ostream &os) const {