diff --git a/mlir/lib/Analysis/Presburger/PresburgerSet.cpp b/mlir/lib/Analysis/Presburger/PresburgerSet.cpp
--- a/mlir/lib/Analysis/Presburger/PresburgerSet.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerSet.cpp
@@ -401,6 +401,105 @@
   return result;
 }
 
+/// Given `p` and one of its inequalities `ineq`, checks that all inequalities
+/// of `cut` are redundant for the facet of `p` where `ineq` holds as an
+/// equality. Returns success if they all are, failure otherwise.
+static LogicalResult containedFacet(ArrayRef<int64_t> ineq,
+                                    const IntegerPolyhedron &p,
+                                    ArrayRef<ArrayRef<int64_t>> cut) {
+  Simplex simp(p);
+  simp.addEquality(ineq);
+  for (ArrayRef<int64_t> curr : cut)
+    if (simp.findIneqType(curr) != Simplex::IneqType::Redundant)
+      return failure();
+  return success();
+}
+
+/// Removes the polyhedra at positions `i` and `j` of `polyhedrons` and adds
+/// `poly` instead, keeping `simplices` in sync (the simplex for `poly` is
+/// constructed here). Afterwards both vectors are one element shorter and
+/// `poly` is their last element. `polyhedrons` and `simplices` must have the
+/// same length, which must be at least 2, and `i` and `j` must be distinct,
+/// valid indices.
+static void
+addCoalescedPolyhedron(SmallVectorImpl<IntegerPolyhedron> &polyhedrons,
+                       unsigned i, unsigned j, const IntegerPolyhedron &poly,
+                       SmallVectorImpl<Simplex> &simplices) {
+  assert(i != j && "the indices must refer to different polyhedra");
+  assert(polyhedrons.size() == simplices.size() &&
+         "polyhedrons and simplices must have the same length");
+
+  unsigned n = polyhedrons.size();
+  if (j == n - 1) {
+    // The general scheme below would be wrong here: it would copy the
+    // polyhedron at `j` (the last element) into slot `i`, and element `n - 2`
+    // would be lost (its copy in slot `j` is popped and slot `n - 2` is
+    // overwritten by `poly`). Instead, move element `n - 2` into slot `i`,
+    // which is a no-op when `i == n - 2` since that element is removed anyway.
+    polyhedrons[i] = polyhedrons[n - 2];
+    simplices[i] = simplices[n - 2];
+  } else {
+    // Move the last two elements into slots `i` and `j`. If `i == n - 1` the
+    // first move is a no-op, and if `i` or `j` equals `n - 2` the element
+    // there is one of the two being removed, so losing it is intended.
+    polyhedrons[i] = polyhedrons[n - 1];
+    polyhedrons[j] = polyhedrons[n - 2];
+    simplices[i] = simplices[n - 1];
+    simplices[j] = simplices[n - 2];
+  }
+  // Drop the last slot, then store `poly` in what is now the last slot.
+  polyhedrons.pop_back();
+  simplices.pop_back();
+  polyhedrons[n - 2] = poly;
+  simplices[n - 2] = Simplex(poly);
+}
+
+/// Given two polyhedra a and b at positions `i` and `j` in `polyhedrons` and
+/// with `redundantIneqsA` being the inequalities of a that are redundant for b
+/// (similarly for `cuttingIneqsA`, `redundantIneqsB`, and `cuttingIneqsB`),
+/// checks whether the facets of all cutting inequalities of a are contained in
+/// b. If so, a and b are replaced (in `polyhedrons` and `simplices`) by a new
+/// polyhedron consisting of all redundant inequalities of a and b and all
+/// equalities of both, and success is returned. Otherwise, failure is returned
+/// and `polyhedrons` and `simplices` are left unchanged.
+///
+/// An example of this case:
+///    ___________        ___________
+///   /   /  |   /       /          /
+///   \   \  |  /   ==>  \         /
+///    \   \ | /          \       /
+///     \___\|/            \_____/
+static LogicalResult
+cutCase(SmallVectorImpl<IntegerPolyhedron> &polyhedrons,
+        SmallVectorImpl<Simplex> &simplices, unsigned i, unsigned j,
+        ArrayRef<ArrayRef<int64_t>> redundantIneqsA,
+        ArrayRef<ArrayRef<int64_t>> cuttingIneqsA,
+        ArrayRef<ArrayRef<int64_t>> redundantIneqsB,
+        ArrayRef<ArrayRef<int64_t>> cuttingIneqsB) {
+  // All inequalities of b need to be redundant on each of these facets. We
+  // already know that the redundant ones are, so only the cutting ones remain
+  // to be checked.
+  for (ArrayRef<int64_t> curr : cuttingIneqsA)
+    if (failed(containedFacet(curr, polyhedrons[i], cuttingIneqsB)))
+      return failure();
+
+  IntegerPolyhedron newSet(polyhedrons[i].getNumDimIds(),
+                           polyhedrons[i].getNumSymbolIds());
+
+  for (ArrayRef<int64_t> curr : redundantIneqsA)
+    newSet.addInequality(curr);
+  for (ArrayRef<int64_t> curr : redundantIneqsB)
+    newSet.addInequality(curr);
+
+  for (unsigned k = 0, e = polyhedrons[i].getNumEqualities(); k < e; ++k)
+    newSet.addEquality(polyhedrons[i].getEquality(k));
+  for (unsigned k = 0, e = polyhedrons[j].getNumEqualities(); k < e; ++k)
+    newSet.addEquality(polyhedrons[j].getEquality(k));
+
+  addCoalescedPolyhedron(polyhedrons, i, j, newSet, simplices);
+  return success();
+}
+
 /// Types the inequalities of `p` according to their `IneqType` for `simp` into
 /// `redundantIneqs` and `cuttingIneqs`. Returns success, if no separate
 /// inequalities were encountered. Otherwise, returns failure.
@@ -495,6 +577,18 @@
     return success();
   }
 
+  // Try the cut case in both directions: first with the cutting facets of the
+  // polyhedron at `i` checked against the one at `j`, then with roles swapped.
+  if (onlyRedundantEqsA &&
+      succeeded(cutCase(polyhedrons, simplices, i, j, redundantIneqsA,
+                        cuttingIneqsA, redundantIneqsB, cuttingIneqsB)))
+    return success();
+
+  if (onlyRedundantEqsB &&
+      succeeded(cutCase(polyhedrons, simplices, j, i, redundantIneqsB,
+                        cuttingIneqsB, redundantIneqsA, cuttingIneqsA)))
+    return success();
+
   return failure();
 }
 
diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
--- a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
@@ -528,7 +528,7 @@
       "(x) : ( x >= 0, -x + 3 >= 0)",
       "(x) : ( x - 2 >= 0, -x + 4 >= 0)",
   });
-  expectCoalesce(2, set);
+  expectCoalesce(1, set);
 }
 
 TEST(SetTest, coalesceSeparateOneDim) {
@@ -552,7 +552,7 @@
       "(x,y) : (x >= 0, -x + 3 >= 0, y >= 0, -y + 2 >= 0)",
       "(x,y) : (x >= 0, -x + 3 >= 0, y - 1 >= 0, -y + 3 >= 0)",
   });
-  expectCoalesce(2, set);
+  expectCoalesce(1, set);
 }
 
 TEST(SetTest, coalesceSeparateTwoDim) {
@@ -579,7 +579,7 @@
       "(x,y) : (x - 1 >= 0, -x + 3 >= 0, x - y == 0)",
       "(x,y) : (x >= 0, -x + 2 >= 0, x - y == 0)",
   });
-  expectCoalesce(2, set);
+  expectCoalesce(1, set);
 }
 
 TEST(SetTest, coalesceSeparateEq) {