diff --git a/mlir/lib/Analysis/Presburger/PresburgerSet.cpp b/mlir/lib/Analysis/Presburger/PresburgerSet.cpp
--- a/mlir/lib/Analysis/Presburger/PresburgerSet.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerSet.cpp
@@ -379,6 +379,99 @@
   return result;
 }
 
+/// Given a Simplex `simp` constructed from a polyhedron `p` and one of the
+/// inequalities `ineq` of `p`, checks that all inequalities of `cut` are
+/// redundant for the facet of `p` where `ineq` holds as an equality. An empty
+/// facet is trivially contained. `simp` is rolled back to its original state
+/// before returning.
+static bool containedFacet(ArrayRef<int64_t> ineq, Simplex &simp,
+                           ArrayRef<ArrayRef<int64_t>> cut) {
+  unsigned snapshot = simp.getSnapshot();
+  simp.addEquality(ineq);
+  // Do not ask redundancy questions of an empty Simplex; an empty facet is
+  // contained in anything.
+  bool contained = simp.isEmpty() ||
+                   llvm::all_of(cut, [&simp](ArrayRef<int64_t> curr) {
+                     return simp.isRedundantInequality(curr);
+                   });
+  simp.rollback(snapshot);
+  return contained;
+}
+
+/// Adds `poly` to `polyhedrons` and removes the polyhedrons at positions `i`
+/// and `j`. Updates `simplices` to reflect the changes. `i` and `j` must be
+/// distinct. The order of the remaining polyhedrons may change.
+static void
+addCoalescedPolyhedron(SmallVectorImpl<IntegerPolyhedron> &polyhedrons,
+                       unsigned i, unsigned j, IntegerPolyhedron poly,
+                       SmallVectorImpl<Simplex> &simplices) {
+  assert(i != j && "The indices must refer to different polyhedra");
+  unsigned n = polyhedrons.size();
+  assert(i < n && j < n && simplices.size() == n &&
+         "Indices must be valid and simplices must match polyhedrons");
+  if (j == n - 1) {
+    // `pop_back` below drops position `n` - 1, i.e. `j`. Only `i` has to be
+    // refilled, with the element at `n` - 2 that is overwritten afterwards.
+    polyhedrons[i] = polyhedrons[n - 2];
+    simplices[i] = simplices[n - 2];
+  } else {
+    // Move the last two elements into the slots `i` and `j`. If `i` is
+    // `n` - 1, the first assignment is a no-op; if `i` or `j` is `n` - 2, the
+    // element there is being removed anyway and may be overwritten.
+    polyhedrons[i] = polyhedrons[n - 1];
+    polyhedrons[j] = polyhedrons[n - 2];
+    simplices[i] = simplices[n - 1];
+    simplices[j] = simplices[n - 2];
+  }
+  polyhedrons.pop_back();
+  simplices.pop_back();
+  simplices[n - 2] = Simplex(poly);
+  polyhedrons[n - 2] = std::move(poly);
+}
+
+/// Given two polyhedra `a` and `b` at positions `i` and `j` in `polyhedrons`
+/// and `redundantIneqsA` being the inequalities of `a` that are redundant for
+/// `b` (similarly for `cuttingIneqsA`, `redundantIneqsB`, and `cuttingIneqsB`),
+/// checks whether the facets of all cutting inequalities of `a` are contained
+/// in `b`. If so, `a` and `b` are replaced by a new polyhedron consisting of
+/// all redundant inequalities of `a` and `b`. The equalities of both are
+/// expected to be present in these lists as pairs of opposite inequalities
+/// (see `typeEqualities`), so they need no separate handling.
+///
+/// An example of this case:
+///    ___________        ___________
+///   /   /  |   /       /          /
+///   \   \  |  /   ==>  \         /
+///    \   \ | /          \       /
+///     \___\|/            \_____/
+///
+static LogicalResult
+cutCase(SmallVectorImpl<IntegerPolyhedron> &polyhedrons,
+        SmallVectorImpl<Simplex> &simplices, unsigned i, unsigned j,
+        ArrayRef<ArrayRef<int64_t>> redundantIneqsA,
+        ArrayRef<ArrayRef<int64_t>> cuttingIneqsA,
+        ArrayRef<ArrayRef<int64_t>> redundantIneqsB,
+        ArrayRef<ArrayRef<int64_t>> cuttingIneqsB) {
+  // All inequalities of `b` need to be redundant on the facets. The redundant
+  // ones already are, so only the cutting ones remain to be checked.
+  Simplex &simp = simplices[i];
+  if (!llvm::all_of(cuttingIneqsA, [&](ArrayRef<int64_t> curr) {
+        return containedFacet(curr, simp, cuttingIneqsB);
+      }))
+    return failure();
+
+  IntegerPolyhedron newSet(polyhedrons[i].getNumDimIds(),
+                           polyhedrons[i].getNumSymbolIds(),
+                           polyhedrons[i].getNumLocalIds());
+  for (ArrayRef<int64_t> curr : redundantIneqsA)
+    newSet.addInequality(curr);
+  for (ArrayRef<int64_t> curr : redundantIneqsB)
+    newSet.addInequality(curr);
+
+  addCoalescedPolyhedron(polyhedrons, i, j, std::move(newSet), simplices);
+  return success();
+}
+
 /// Types the inequalities of `p` according to their `IneqType` for `simp` into
 /// `redundantIneqs` and `cuttingIneqs`. Returns success, if no separate
 /// inequalities were encountered. Otherwise, returns failure.
@@ -398,6 +491,46 @@
   return success();
 }
 
+/// Types the equalities of `p`, i.e. for an equality `coeffs` == 0, assigns
+/// both `coeffs` >= 0 and -`coeffs` >= 0 according to their `IneqType` for
+/// `simp` into `redundantIneqs` and `cuttingIneqs`. Returns success, if no
+/// separate inequalities were encountered. Otherwise, returns failure.
+///
+/// The negated equalities are materialized in `negEqs`, which must be empty on
+/// entry. The ArrayRefs appended to `redundantIneqs` and `cuttingIneqs` point
+/// into `p` and `negEqs`, so both must outlive these lists and must not be
+/// modified while the lists are in use.
+static LogicalResult
+typeEqualities(const IntegerPolyhedron &p, Simplex &simp,
+               SmallVectorImpl<SmallVector<int64_t, 8>> &negEqs,
+               SmallVectorImpl<ArrayRef<int64_t>> &redundantIneqs,
+               SmallVectorImpl<ArrayRef<int64_t>> &cuttingIneqs) {
+  assert(negEqs.empty() && "negEqs must be empty on entry");
+  unsigned numEqs = p.getNumEqualities();
+  // Build all negations before taking any ArrayRef into `negEqs`, so that no
+  // later growth of `negEqs` can invalidate those references.
+  negEqs.resize(numEqs);
+  for (unsigned k = 0; k < numEqs; ++k)
+    for (int64_t coeff : p.getEquality(k))
+      negEqs[k].push_back(-coeff);
+
+  auto typeIneq = [&](ArrayRef<int64_t> ineq) -> LogicalResult {
+    Simplex::IneqType type = simp.findIneqType(ineq);
+    if (type == Simplex::IneqType::Redundant)
+      redundantIneqs.push_back(ineq);
+    else if (type == Simplex::IneqType::Cut)
+      cuttingIneqs.push_back(ineq);
+    else
+      return failure();
+    return success();
+  };
+
+  for (unsigned k = 0; k < numEqs; ++k)
+    if (typeIneq(p.getEquality(k)).failed() || typeIneq(negEqs[k]).failed())
+      return failure();
+  return success();
+}
+
 /// Replaces the element at position `i` with the last element and erases the
 /// last element for both `polyhedrons` and `simplices`.
 void erasePolyhedron(unsigned i,
@@ -427,25 +560,6 @@
   Simplex &simpA = simplices[i];
   Simplex &simpB = simplices[j];
 
-  // Check that all equalities are redundant in a (and in b).
-  bool onlyRedundantEqsA = true;
-  for (unsigned k = 0, e = a.getNumEqualities(); k < e; ++k)
-    if (!simpB.isRedundantEquality(a.getEquality(k))) {
-      onlyRedundantEqsA = false;
-      break;
-    }
-
-  bool onlyRedundantEqsB = true;
-  for (unsigned k = 0, e = b.getNumEqualities(); k < e; ++k)
-    if (!simpA.isRedundantEquality(b.getEquality(k))) {
-      onlyRedundantEqsB = false;
-      break;
-    }
-
-  // If there are non-redundant equalities for both, exit early.
-  if (!onlyRedundantEqsB && !onlyRedundantEqsA)
-    return failure();
-
   SmallVector<ArrayRef<int64_t>, 2> redundantIneqsA;
   SmallVector<ArrayRef<int64_t>, 2> cuttingIneqsA;
 
@@ -456,25 +570,49 @@
   if (typeInequalities(a, simpB, redundantIneqsA, cuttingIneqsA).failed())
     return failure();
 
+  // Owns the negated equalities of `a`; the typed lists refer into it, so it
+  // must stay alive (and unmodified) until they are no longer used.
+  SmallVector<SmallVector<int64_t, 8>, 2> negEqsA;
+  if (typeEqualities(a, simpB, negEqsA, redundantIneqsA, cuttingIneqsA)
+          .failed())
+    return failure();
+
   SmallVector<ArrayRef<int64_t>, 2> redundantIneqsB;
   SmallVector<ArrayRef<int64_t>, 2> cuttingIneqsB;
 
   if (typeInequalities(b, simpA, redundantIneqsB, cuttingIneqsB).failed())
     return failure();
 
-  // If there are no cutting inequalities of `a` and all equalities of `a` are
-  // redundant, then all constraints of `a` are redundant making `b` contained
-  // within a (and vice versa for `b`).
-  if (cuttingIneqsA.empty() && onlyRedundantEqsA) {
+  SmallVector<SmallVector<int64_t, 8>, 2> negEqsB;
+  if (typeEqualities(b, simpA, negEqsB, redundantIneqsB, cuttingIneqsB)
+          .failed())
+    return failure();
+
+  // If there are no cutting inequalities of `a` (equalities count as two
+  // inequalities each), then all constraints of `a` are redundant for `b`,
+  // making `b` contained within `a` (and vice versa for `b`).
+  if (cuttingIneqsA.empty()) {
     erasePolyhedron(j, polyhedrons, simplices);
     return success();
   }
 
-  if (cuttingIneqsB.empty() && onlyRedundantEqsB) {
+  if (cuttingIneqsB.empty()) {
     erasePolyhedron(i, polyhedrons, simplices);
     return success();
   }
 
+  // Try to apply the cut case, first checking the facets of the cutting
+  // inequalities of `a` and then those of `b`.
+  if (cutCase(polyhedrons, simplices, i, j, redundantIneqsA, cuttingIneqsA,
+              redundantIneqsB, cuttingIneqsB)
+          .succeeded())
+    return success();
+
+  if (cutCase(polyhedrons, simplices, j, i, redundantIneqsB, cuttingIneqsB,
+              redundantIneqsA, cuttingIneqsA)
+          .succeeded())
+    return success();
+
   return failure();
 }
 
diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
--- a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
@@ -528,7 +528,7 @@
              "(x) : ( x >= 0, -x + 3 >= 0)",
              "(x) : ( x - 2 >= 0, -x + 4 >= 0)",
          });
-  expectCoalesce(2, set);
+  expectCoalesce(1, set);
 }
 
 TEST(SetTest, coalesceSeparateOneDim) {
@@ -537,6 +537,12 @@
   expectCoalesce(2, set);
 }
 
+TEST(SetTest, coalesceAdjEq) {
+  PresburgerSet set = parsePresburgerSetFromPolyStrings(
+      1, {"(x) : ( x == 0)", "(x) : ( x - 1 == 0)"});
+  expectCoalesce(2, set);
+}
+
 TEST(SetTest, coalesceContainedTwoDim) {
   PresburgerSet set = parsePresburgerSetFromPolyStrings(
       2, {
@@ -552,6 +558,27 @@
              "(x,y) : (x >= 0, -x + 3 >= 0, y >= 0, -y + 2 >= 0)",
              "(x,y) : (x >= 0, -x + 3 >= 0, y - 1 >= 0, -y + 3 >= 0)",
          });
+  expectCoalesce(1, set);
+}
+
+TEST(SetTest, coalesceEqStickingOut) {
+  PresburgerSet set = parsePresburgerSetFromPolyStrings(
+      2, {
+             "(x,y) : (x >= 0, -x + 2 >= 0, y >= 0, -y + 2 >= 0)",
+             "(x,y) : (y - 1 == 0, x >= 0, -x + 3 >= 0)",
+         });
+  expectCoalesce(2, set);
+}
+
+// The cut case coalesces the polyhedra at positions 0 and 2 (the last one);
+// the separate polyhedron at position 1 must survive the reordering.
+TEST(SetTest, coalesceCutCaseLastIndex) {
+  PresburgerSet set = parsePresburgerSetFromPolyStrings(
+      2, {
+             "(x,y) : (x >= 0, -x + 3 >= 0, y >= 0, -y + 2 >= 0)",
+             "(x,y) : (x - 10 >= 0, -x + 11 >= 0, y >= 0, -y + 1 >= 0)",
+             "(x,y) : (x >= 0, -x + 3 >= 0, y - 1 >= 0, -y + 3 >= 0)",
+         });
   expectCoalesce(2, set);
 }
 
@@ -576,10 +603,10 @@
 TEST(SetTest, coalesceCuttingEq) {
   PresburgerSet set = parsePresburgerSetFromPolyStrings(
       2, {
-             "(x,y) : (x - 1 >= 0, -x + 3 >= 0, x - y == 0)",
+             "(x,y) : (x + 1 >= 0, -x + 3 >= 0, x - y == 0)",
              "(x,y) : (x >= 0, -x + 2 >= 0, x - y == 0)",
          });
-  expectCoalesce(2, set);
+  expectCoalesce(1, set);
 }
 
 TEST(SetTest, coalesceSeparateEq) {