diff --git a/mlir/lib/Analysis/Presburger/PresburgerSet.cpp b/mlir/lib/Analysis/Presburger/PresburgerSet.cpp
--- a/mlir/lib/Analysis/Presburger/PresburgerSet.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerSet.cpp
@@ -379,30 +379,130 @@
   return result;
 }
 
-/// Types the inequalities of `p` according to their `IneqType` for `simp` into
+/// Checks whether the facet of `p` on which `ineq` is tight (holds with
+/// equality) satisfies every inequality in `cuttingIneqs`, i.e. whether each of
+/// them is redundant on that facet. `simp` must be the Simplex constructed from
+/// `p`; it is temporarily extended by `ineq` as an equality and rolled back to
+/// the snapshot taken on entry before returning. `p` itself is not read here.
+static bool isFacetContained(ArrayRef ineq, Simplex &simp,
+                             IntegerPolyhedron &p,
+                             ArrayRef> cuttingIneqs) {
+  unsigned snapshot = simp.getSnapshot();
+  simp.addEquality(ineq);
+  // Restricted to the facet, every cutting inequality has to be redundant.
+  bool contained = llvm::all_of(cuttingIneqs, [&simp](ArrayRef curr) {
+    return simp.isRedundantInequality(curr);
+  });
+  // Undo the temporary equality on both the success and the failure path.
+  simp.rollback(snapshot);
+  return contained;
+}
+
+/// Removes the polyhedra at positions `i` and `j` from `polyhedrons` and adds
+/// `poly` as the new last element, updating `simplices` to match. `i` and `j`
+/// must be distinct.
+static void
+addCoalescedPolyhedron(SmallVectorImpl &polyhedrons,
+                       unsigned i, unsigned j, const IntegerPolyhedron &poly,
+                       SmallVectorImpl &simplices) {
+  assert(i != j && "The indices must refer to different polyhedra");
+  unsigned n = polyhedrons.size();
+  // Slot `n - 1` is popped below; if `j` pointed there, `n - 2` would be lost.
+  if (j == n - 1)
+    std::swap(i, j);
+  polyhedrons[i] = polyhedrons[n - 1];
+  polyhedrons[j] = polyhedrons[n - 2];
+  polyhedrons.pop_back();
+  polyhedrons[n - 2] = poly;
+
+  simplices[i] = simplices[n - 1];
+  simplices[j] = simplices[n - 2];
+  simplices.pop_back();
+  simplices[n - 2] = Simplex(poly);
+}
+
+/// Given two polyhedra `a` and `b` at positions `i` and `j` in `polyhedrons`
+/// and `redundantIneqsA` being the inequalities of `a` that are redundant for
+/// `b` (similarly for `cuttingIneqsA`, `redundantIneqsB`, and `cuttingIneqsB`),
+/// checks whether the facets of all cutting inequalities of `a` are contained
+/// in `b`. If so, `a` and `b` are replaced by a new polyhedron consisting of
+/// all redundant inequalities of `a` and `b`, and success is returned.
+///
+/// An example of this case:
+///    ___________        ___________
+///   /   /  |   /       /          /
+///   \   \  |  /  ==>   \         /
+///    \   \ | /          \       /
+///     \___\|/            \_____/
+static LogicalResult
+coalescePairCutCase(SmallVectorImpl &polyhedrons,
+                    SmallVectorImpl &simplices, unsigned i, unsigned j,
+                    ArrayRef> redundantIneqsA,
+                    ArrayRef> cuttingIneqsA,
+                    ArrayRef> redundantIneqsB,
+                    ArrayRef> cuttingIneqsB) {
+  // Each cutting facet of `a` must lie in `b`. The redundant inequalities of
+  // `b` hold on all of `a`, so only its cutting ones need to be checked there.
+  Simplex &simp = simplices[i];
+  IntegerPolyhedron &poly = polyhedrons[i];
+  if (llvm::any_of(cuttingIneqsA,
+                   [&simp, &poly, &cuttingIneqsB](ArrayRef curr) {
+                     return !isFacetContained(curr, simp, poly, cuttingIneqsB);
+                   }))
+    return failure();
+  IntegerPolyhedron newSet(poly.getNumDimIds(), poly.getNumSymbolIds(),
+                           poly.getNumLocalIds());
+
+  for (ArrayRef curr : redundantIneqsA)
+    newSet.addInequality(curr);
+
+  for (ArrayRef curr : redundantIneqsB)
+    newSet.addInequality(curr);
+
+  addCoalescedPolyhedron(polyhedrons, i, j, newSet, simplices);
+  return success();
+}
+
+/// Types the inequality `ineq` according to its `IneqType` for `simp` into
 /// `redundantIneqs` and `cuttingIneqs`. Returns success, if no separate
 /// inequalities were encountered. Otherwise, returns failure.
-LogicalResult
-typeInequalities(const IntegerPolyhedron &p, Simplex &simp,
-                 SmallVectorImpl> &redundantIneqs,
-                 SmallVectorImpl> &cuttingIneqs) {
-  for (unsigned i = 0, e = p.getNumInequalities(); i < e; ++i) {
-    Simplex::IneqType type = simp.findIneqType(p.getInequality(i));
-    if (type == Simplex::IneqType::Redundant)
-      redundantIneqs.push_back(p.getInequality(i));
-    else if (type == Simplex::IneqType::Cut)
-      cuttingIneqs.push_back(p.getInequality(i));
-    else
-      return failure();
-  }
+static LogicalResult
+typeInequality(ArrayRef ineq, Simplex &simp,
+               SmallVectorImpl> &redundantIneqs,
+               SmallVectorImpl> &cuttingIneqs) {
+  Simplex::IneqType type = simp.findIneqType(ineq);
+  if (type == Simplex::IneqType::Redundant)
+    redundantIneqs.push_back(ineq);
+  else if (type == Simplex::IneqType::Cut)
+    cuttingIneqs.push_back(ineq);
+  else
+    return failure();
+  return success();
+}
+
+/// Types the equality `eq` == 0 by typing both `eq` >= 0 and -`eq` >= 0
+/// according to their `IneqType` for `simp` into `redundantIneqs` and
+/// `cuttingIneqs`. Returns success, if no separate inequalities were
+/// encountered. Otherwise, returns failure.
+static LogicalResult
+typeEquality(ArrayRef eq, Simplex &simp,
+             SmallVectorImpl> &redundantIneqs,
+             SmallVectorImpl> &cuttingIneqs,
+             SmallVectorImpl> &negEqs) {
+  if (typeInequality(eq, simp, redundantIneqs, cuttingIneqs).failed())
+    return failure();
+  negEqs.push_back(getNegatedCoeffs(eq));
+  ArrayRef inv(negEqs.back());
+  if (typeInequality(inv, simp, redundantIneqs, cuttingIneqs).failed())
+    return failure();
   return success();
 }
 
 /// Replaces the element at position `i` with the last element and erases the
 /// last element for both `polyhedrons` and `simplices`.
-void erasePolyhedron(unsigned i,
-                     SmallVectorImpl &polyhedrons,
-                     SmallVectorImpl &simplices) {
+static void erasePolyhedron(unsigned i,
+                            SmallVectorImpl &polyhedrons,
+                            SmallVectorImpl &simplices) {
   assert(simplices.size() == polyhedrons.size() &&
          "simplices and polyhedrons must be equally as long");
   polyhedrons[i] = polyhedrons.back();
@@ -416,9 +516,10 @@
 /// coalesced. The simplices in `simplices` need to be the ones constructed from
 /// `polyhedrons`. At this point, there are no empty polyhedrons in
 /// `polyhedrons` left.
-LogicalResult coalescePair(unsigned i, unsigned j,
-                           SmallVectorImpl &polyhedrons,
-                           SmallVectorImpl &simplices) {
+static LogicalResult
+coalescePair(unsigned i, unsigned j,
+             SmallVectorImpl &polyhedrons,
+             SmallVectorImpl &simplices) {
 
   IntegerPolyhedron &a = polyhedrons[i];
   IntegerPolyhedron &b = polyhedrons[j];
@@ -427,54 +528,67 @@
   Simplex &simpA = simplices[i];
   Simplex &simpB = simplices[j];
 
-  // Check that all equalities are redundant in a (and in b).
-  bool onlyRedundantEqsA = true;
-  for (unsigned k = 0, e = a.getNumEqualities(); k < e; ++k)
-    if (!simpB.isRedundantEquality(a.getEquality(k))) {
-      onlyRedundantEqsA = false;
-      break;
-    }
-
-  bool onlyRedundantEqsB = true;
-  for (unsigned k = 0, e = b.getNumEqualities(); k < e; ++k)
-    if (!simpA.isRedundantEquality(b.getEquality(k))) {
-      onlyRedundantEqsB = false;
-      break;
-    }
-
-  // If there are non-redundant equalities for both, exit early.
-  if (!onlyRedundantEqsB && !onlyRedundantEqsA)
-    return failure();
-
   SmallVector, 2> redundantIneqsA;
   SmallVector, 2> cuttingIneqsA;
+  SmallVector, 2> negEqs;
+  // `typeEquality` appends to `negEqs` and keeps `ArrayRef`s into the appended
+  // elements, so `negEqs` must never reallocate: reserve room for all of them.
+  negEqs.reserve(a.getNumEqualities() + b.getNumEqualities());
+
+  // Organize all inequalities and equalities of `a` according to their type for
+  // `b` into `redundantIneqsA` and `cuttingIneqsA` (and vice versa for all
+  // inequalities of `b` according to their type in `a`). If a separate
+  // inequality is encountered during typing, the two IntegerPolyhedrons cannot
+  // be coalesced.
+  for (unsigned k = 0, e = a.getNumInequalities(); k < e; ++k)
+    if (typeInequality(a.getInequality(k), simpB, redundantIneqsA,
+                       cuttingIneqsA)
+            .failed())
+      return failure();
 
-  // Organize all inequalities of `a` according to their type for `b` into
-  // `redundantIneqsA` and `cuttingIneqsA` (and vice versa for all inequalities
-  // of `b` according to their type in `a`). If a separate inequality is
-  // encountered during typing, the two IntegerPolyhedrons cannot be coalesced.
-  if (typeInequalities(a, simpB, redundantIneqsA, cuttingIneqsA).failed())
-    return failure();
+  for (unsigned k = 0, e = a.getNumEqualities(); k < e; ++k)
+    if (typeEquality(a.getEquality(k), simpB, redundantIneqsA, cuttingIneqsA,
+                     negEqs)
+            .failed())
+      return failure();
 
   SmallVector, 2> redundantIneqsB;
   SmallVector, 2> cuttingIneqsB;
+  for (unsigned k = 0, e = b.getNumInequalities(); k < e; ++k)
+    if (typeInequality(b.getInequality(k), simpA, redundantIneqsB,
+                       cuttingIneqsB)
+            .failed())
+      return failure();
 
-  if (typeInequalities(b, simpA, redundantIneqsB, cuttingIneqsB).failed())
-    return failure();
+  for (unsigned k = 0, e = b.getNumEqualities(); k < e; ++k)
+    if (typeEquality(b.getEquality(k), simpA, redundantIneqsB, cuttingIneqsB,
+                     negEqs)
+            .failed())
+      return failure();
 
-  // If there are no cutting inequalities of `a` and all equalities of `a` are
-  // redundant, then all constraints of `a` are redundant making `b` contained
-  // within a (and vice versa for `b`).
-  if (cuttingIneqsA.empty() && onlyRedundantEqsA) {
+  // If there are no cutting inequalities of `a`, `b` is contained
+  // within `a` (and vice versa for `b`).
+  if (cuttingIneqsA.empty()) {
     erasePolyhedron(j, polyhedrons, simplices);
     return success();
   }
 
-  if (cuttingIneqsB.empty() && onlyRedundantEqsB) {
+  if (cuttingIneqsB.empty()) {
     erasePolyhedron(i, polyhedrons, simplices);
     return success();
   }
 
+  // Try to apply the cut case, with `a` and `b` in either role.
+  if (coalescePairCutCase(polyhedrons, simplices, i, j, redundantIneqsA,
+                          cuttingIneqsA, redundantIneqsB, cuttingIneqsB)
+          .succeeded())
+    return success();
+
+  if (coalescePairCutCase(polyhedrons, simplices, j, i, redundantIneqsB,
+                          cuttingIneqsB, redundantIneqsA, cuttingIneqsA)
+          .succeeded())
+    return success();
+
   return failure();
 }
 
diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
--- a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
@@ -528,7 +528,7 @@
              "(x) : ( x >= 0, -x + 3 >= 0)",
              "(x) : ( x - 2 >= 0, -x + 4 >= 0)",
          });
-  expectCoalesce(2, set);
+  expectCoalesce(1, set);
 }
 
 TEST(SetTest, coalesceSeparateOneDim) {
@@ -537,6 +537,12 @@
   expectCoalesce(2, set);
 }
 
+TEST(SetTest, coalesceAdjEq) {
+  PresburgerSet set = parsePresburgerSetFromPolyStrings(
+      1, {"(x) : ( x == 0)", "(x) : ( x - 1 == 0)"});
+  expectCoalesce(2, set);
+}
+
 TEST(SetTest, coalesceContainedTwoDim) {
   PresburgerSet set = parsePresburgerSetFromPolyStrings(
       2, {
@@ -552,6 +558,15 @@
              "(x,y) : (x >= 0, -x + 3 >= 0, y >= 0, -y + 2 >= 0)",
              "(x,y) : (x >= 0, -x + 3 >= 0, y - 1 >= 0, -y + 3 >= 0)",
          });
+  expectCoalesce(1, set);
+}
+
+TEST(SetTest, coalesceEqStickingOut) {
+  PresburgerSet set = parsePresburgerSetFromPolyStrings(
+      2, {
+             "(x,y) : (x >= 0, -x + 2 >= 0, y >= 0, -y + 2 >= 0)",
+             "(x,y) : (y - 1 == 0, x >= 0, -x + 3 >= 0)",
+         });
   expectCoalesce(2, set);
 }
 
@@ -576,10 +591,10 @@
 TEST(SetTest, coalesceCuttingEq) {
   PresburgerSet set = parsePresburgerSetFromPolyStrings(
       2, {
-             "(x,y) : (x - 1 >= 0, -x + 3 >= 0, x - y == 0)",
+             "(x,y) : (x + 1 >= 0, -x + 1 >= 0, x - y == 0)",
              "(x,y) : (x >= 0, -x + 2 >= 0, x - y == 0)",
          });
-  expectCoalesce(2, set);
+  expectCoalesce(1, set);
 }
 
 TEST(SetTest, coalesceSeparateEq) {