diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -537,6 +537,12 @@
   /// removeTrivialRedundancy.
   void removeRedundantInequalities();
 
+  /// Removes redundant constraints, i.e., constraints whose removal does not
+  /// change the set of solutions, using Simplex::detectRedundant. Redundancy
+  /// is checked over the rationals. An equality is removed only if both
+  /// inequalities it corresponds to are redundant.
+  void removeRedundantConstraints();
+
   // Removes all equalities and inequalities.
   void clearConstraints();
 
diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h
--- a/mlir/include/mlir/Analysis/Presburger/Simplex.h
+++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h
@@ -112,6 +112,17 @@
 /// set of constraints is mutually contradictory and the tableau is marked
 /// _empty_, which means the set of constraints has no solution.
 ///
+/// The Simplex class supports redundancy checking via detectRedundant and
+/// isMarkedRedundant. A redundant constraint is one which is never violated as
+/// long as the other constraints are not violated, i.e., removing a redundant
+/// constraint does not change the set of solutions to the constraints. As a
+/// heuristic, constraints that have been marked redundant are ignored when
+/// choosing pivot rows; they are kept in rows 0 to nRedundant - 1, where
+/// nRedundant is a member variable that tracks the number of constraints that
+/// have been marked redundant. Pivots still update these rows so that they
+/// remain consistent with the rest of the tableau if the marking is rolled
+/// back.
+///
 /// This Simplex class also supports taking snapshots of the current state
 /// and rolling back to prior snapshots. This works by maintaing an undo log
 /// of operations. Snapshots are just pointers to a particular location in the
@@ -158,7 +169,7 @@
   void rollback(unsigned snapshot);
 
   /// Compute the maximum or minimum value of the given row, depending on
-  /// direction.
+  /// direction. The specified row is never pivoted.
   ///
   /// Returns a (num, den) pair denoting the optimum, or None if no
   /// optimum exists, i.e., if the expression is unbounded in this direction.
@@ -172,6 +183,23 @@
   Optional<Fraction> computeOptimum(Direction direction,
                                     ArrayRef<int64_t> coeffs);
 
+  /// Returns whether the specified constraint has been marked as redundant.
+  /// Constraints are numbered from 0 in the order in which they were added.
+  /// Equalities are added as a pair of inequalities and so correspond to two
+  /// inequalities with successive indices.
+  bool isMarkedRedundant(unsigned constraintIndex) const;
+
+  /// Finds a maximal subset of constraints that is redundant, i.e., such that
+  /// the set of solutions does not change if these constraints are removed.
+  /// Marks these constraints as redundant. Whether a specific constraint has
+  /// been marked redundant can be queried using isMarkedRedundant.
+  ///
+  /// Constraints are checked in the order in which they were added, and each
+  /// one is checked against the constraints not marked redundant so far. Of
+  /// several identical constraints, only the last one is left unmarked. Does
+  /// nothing if the tableau is empty.
+  void detectRedundant();
+
   /// Returns a (min, max) pair denoting the minimum and maximum integer values
   /// of the given expression.
   std::pair<int64_t, int64_t> computeIntegerBounds(ArrayRef<int64_t> coeffs);
@@ -272,7 +300,15 @@
   /// sample value, false otherwise.
   LogicalResult restoreRow(Unknown &u);
 
-  enum class UndoLogEntry { RemoveLastConstraint, UnmarkEmpty };
+  /// Mark the row unknown `u` as redundant by moving it to row nRedundant and
+  /// incrementing nRedundant. The marking is recorded in the undo log.
+  void markRowRedundant(Unknown &u);
+
+  enum class UndoLogEntry {
+    RemoveLastConstraint,
+    UnmarkEmpty,
+    UnmarkLastRedundant
+  };
 
   /// Undo the operation represented by the log entry.
   void undo(UndoLogEntry entry);
@@ -299,6 +335,10 @@
   /// and the constant column.
   unsigned nCol;
 
+  /// The number of redundant rows in the tableau. These are the first
+  /// nRedundant rows.
+  unsigned nRedundant;
+
   /// The matrix representing the tableau.
   Matrix tableau;
 
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -1410,6 +1410,46 @@
   inequalities.resize(numReservedCols * pos);
 }
 
+// A more complex check to eliminate redundant inequalities and equalities.
+// Uses Simplex to check if a constraint is redundant. The indexing below
+// relies on Simplex(*this) adding all inequalities first, followed by each
+// equality as a pair of successive inequalities. An equality is removed only
+// if both of its inequalities are redundant.
+void FlatAffineConstraints::removeRedundantConstraints() {
+  Simplex simplex(*this);
+  simplex.detectRedundant();
+
+  // Scan to get rid of all inequalities marked redundant, in-place.
+  auto copyInequality = [&](unsigned src, unsigned dest) {
+    if (src == dest)
+      return;
+    for (unsigned c = 0, e = getNumCols(); c < e; c++)
+      atIneq(dest, c) = atIneq(src, c);
+  };
+  unsigned pos = 0;
+  unsigned numIneqs = getNumInequalities();
+  for (unsigned r = 0; r < numIneqs; r++) {
+    if (!simplex.isMarkedRedundant(r))
+      copyInequality(r, pos++);
+  }
+  inequalities.resize(numReservedCols * pos);
+
+  // Scan to get rid of all equalities marked redundant, in-place.
+  auto copyEquality = [&](unsigned src, unsigned dest) {
+    if (src == dest)
+      return;
+    for (unsigned c = 0, e = getNumCols(); c < e; c++)
+      atEq(dest, c) = atEq(src, c);
+  };
+  pos = 0;
+  for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
+    if (!simplex.isMarkedRedundant(numIneqs + 2 * r) ||
+        !simplex.isMarkedRedundant(numIneqs + 2 * r + 1))
+      copyEquality(r, pos++);
+  }
+  equalities.resize(numReservedCols * pos);
+}
+
 std::pair<AffineMap, AffineMap> FlatAffineConstraints::getLowerAndUpperBound(
     unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
     ArrayRef<AffineExpr> localExprs, MLIRContext *context) const {
diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -17,7 +17,7 @@
 
 /// Construct a Simplex object with `nVar` variables.
 Simplex::Simplex(unsigned nVar)
-    : nRow(0), nCol(2), tableau(0, 2 + nVar), empty(false) {
+    : nRow(0), nCol(2), nRedundant(0), tableau(0, 2 + nVar), empty(false) {
   colUnknown.push_back(nullIndex);
   colUnknown.push_back(nullIndex);
   for (unsigned i = 0; i < nVar; ++i) {
@@ -303,7 +303,7 @@
                                          unsigned col) const {
   Optional<unsigned> retRow;
   int64_t retElem, retConst;
-  for (unsigned row = 0; row < nRow; ++row) {
+  for (unsigned row = nRedundant; row < nRow; ++row) {
     if (skipRow && row == *skipRow)
       continue;
     int64_t elem = tableau(row, col);
@@ -413,7 +413,7 @@
         // coefficients for every row. But the unknown is a constraint,
         // so it was added initially as a row. Such a row could never have been
         // pivoted to a column. So a pivot row will always be found.
-        for (unsigned i = 0; i < nRow; ++i) {
+        for (unsigned i = nRedundant; i < nRow; ++i) {
           if (tableau(i, column) != 0) {
             row = i;
             break;
@@ -435,6 +435,8 @@
     con.pop_back();
   } else if (entry == UndoLogEntry::UnmarkEmpty) {
     empty = false;
+  } else if (entry == UndoLogEntry::UnmarkLastRedundant) {
+    --nRedundant;
   }
 }
 
@@ -480,6 +482,60 @@
   return optimum;
 }
 
+/// Redundant constraints are those that are in row orientation and lie in
+/// rows 0 to nRedundant - 1.
+bool Simplex::isMarkedRedundant(unsigned constraintIndex) const {
+  const Unknown &u = con[constraintIndex];
+  return u.orientation == Orientation::Row && u.pos < nRedundant;
+}
+
+void Simplex::markRowRedundant(Unknown &u) {
+  assert(u.orientation == Orientation::Row &&
+         "Unknown should be in row position!");
+  assert(u.pos >= nRedundant && "Unknown is already marked redundant!");
+  swapRows(u.pos, nRedundant);
+  ++nRedundant;
+  undoLog.emplace_back(UndoLogEntry::UnmarkLastRedundant);
+}
+
+/// Check each constraint in the order in which it was added and mark it
+/// redundant if it is implied by the constraints that are still unmarked.
+///
+/// Constraints marked redundant in one iteration are not respected in later
+/// iterations, since findPivotRow skips redundant rows. This prevents, for
+/// example, two identical constraints from both being marked redundant: only
+/// the first one to be processed is marked.
+void Simplex::detectRedundant() {
+  // It is not meaningful to talk about redundancy for empty sets.
+  if (empty)
+    return;
+
+  for (Unknown &u : con) {
+    if (u.orientation == Orientation::Column) {
+      unsigned column = u.pos;
+      Optional<unsigned> pivotRow = findPivotRow({}, Direction::Down, column);
+      // If no downward pivot is returned, the constraint is unbounded below
+      // and hence not redundant.
+      if (!pivotRow)
+        continue;
+      pivot(*pivotRow, column);
+    }
+
+    unsigned row = u.pos;
+    Optional<Fraction> minimum = computeRowOptimum(Direction::Down, row);
+    if (!minimum || *minimum < Fraction(0, 1)) {
+      // Constraint is unbounded below or can attain negative sample values and
+      // hence is not redundant. The tableau is not empty, so the constraint is
+      // satisfiable and restoring its row is expected to succeed.
+      if (failed(restoreRow(u)))
+        llvm_unreachable("Could not restore non-redundant row!");
+      continue;
+    }
+
+    markRowRedundant(u);
+  }
+}
+
 bool Simplex::isUnbounded() {
   if (empty)
     return false;
@@ -506,7 +562,7 @@
 /// The product constraints and variables are stored as: first A's, then B's.
 ///
 /// The product tableau has row layout:
-///   A's rows, B's rows.
+///   A's redundant rows, B's redundant rows, A's other rows, B's other rows.
 ///
 /// It has column layout:
 ///   denominator, constant, A's columns, B's columns.
@@ -569,9 +611,14 @@
     result.nRow++;
   };
 
-  for (unsigned row = 0; row < a.nRow; ++row)
+  result.nRedundant = a.nRedundant + b.nRedundant;
+  for (unsigned row = 0; row < a.nRedundant; ++row)
+    appendRowFromA(row);
+  for (unsigned row = 0; row < b.nRedundant; ++row)
+    appendRowFromB(row);
+  for (unsigned row = a.nRedundant; row < a.nRow; ++row)
     appendRowFromA(row);
-  for (unsigned row = 0; row < b.nRow; ++row)
+  for (unsigned row = b.nRedundant; row < b.nRow; ++row)
     appendRowFromB(row);
 
   return result;
diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp
--- a/mlir/unittests/Analysis/AffineStructuresTest.cpp
+++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp
@@ -274,4 +274,46 @@
           .isIntegerEmpty());
 }
 
+TEST(FlatAffineConstraintsTest, removeRedundantTest) {
+  FlatAffineConstraints fac = makeFACFromConstraints(1,
+                                                     {
+                                                         {1, -2}, // x >= 2.
+                                                         {-1, 2}  // x <= 2.
+                                                     },
+                                                     {{1, -2}}); // x == 2.
+  fac.removeRedundantConstraints();
+
+  // Both inequalities are redundant given the equality. Both have been removed.
+  EXPECT_EQ(fac.getNumInequalities(), 0u);
+  EXPECT_EQ(fac.getNumEqualities(), 1u);
+
+  FlatAffineConstraints fac2 =
+      makeFACFromConstraints(2,
+                             {
+                                 {1, 0, -3}, // x >= 3.
+                                 {0, 1, -2}  // y >= 2 (redundant).
+                             },
+                             {{1, -1, 0}}); // x == y.
+  fac2.removeRedundantConstraints();
+
+  // The second inequality is redundant and should have been removed. The
+  // remaining inequality should be the first one.
+  EXPECT_EQ(fac2.getNumInequalities(), 1u);
+  EXPECT_THAT(fac2.getInequality(0), testing::ElementsAre(1, 0, -3));
+  EXPECT_EQ(fac2.getNumEqualities(), 1u);
+
+  FlatAffineConstraints fac3 =
+      makeFACFromConstraints(3, {},
+                             {{1, -1, 0, 0},   // x == y.
+                              {1, 0, -1, 0},   // x == z.
+                              {0, 1, -1, 0}}); // y == z.
+  fac3.removeRedundantConstraints();
+
+  // The first equality is implied by the other two and is removed.
+  EXPECT_EQ(fac3.getNumInequalities(), 0u);
+  EXPECT_EQ(fac3.getNumEqualities(), 2u);
+  EXPECT_THAT(fac3.getEquality(0), testing::ElementsAre(1, 0, -1, 0));
+  EXPECT_THAT(fac3.getEquality(1), testing::ElementsAre(0, 1, -1, 0));
+}
+
 } // namespace mlir
diff --git a/mlir/unittests/Analysis/Presburger/SimplexTest.cpp b/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
--- a/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
@@ -216,4 +216,194 @@
                   .hasValue());
 }
 
+TEST(SimplexTest, isMarkedRedundant_no_var_ge_zero) {
+  Simplex tab(0);
+  tab.addInequality({0}); // 0 >= 0
+
+  tab.detectRedundant();
+  ASSERT_FALSE(tab.isEmpty());
+  EXPECT_TRUE(tab.isMarkedRedundant(0));
+}
+
+TEST(SimplexTest, isMarkedRedundant_no_var_eq) {
+  Simplex tab(0);
+  tab.addEquality({0}); // 0 == 0
+  tab.detectRedundant();
+  ASSERT_FALSE(tab.isEmpty());
+  EXPECT_TRUE(tab.isMarkedRedundant(0));
+  EXPECT_TRUE(tab.isMarkedRedundant(1));
+}
+
+TEST(SimplexTest, isMarkedRedundant_pos_var_eq) {
+  Simplex tab(1);
+  tab.addEquality({1, 0}); // x == 0
+
+  tab.detectRedundant();
+  ASSERT_FALSE(tab.isEmpty());
+  EXPECT_FALSE(tab.isMarkedRedundant(0));
+  EXPECT_FALSE(tab.isMarkedRedundant(1));
+}
+
+TEST(SimplexTest, isMarkedRedundant_zero_var_eq) {
+  Simplex tab(1);
+  tab.addEquality({0, 0}); // 0x == 0
+  tab.detectRedundant();
+  ASSERT_FALSE(tab.isEmpty());
+  EXPECT_TRUE(tab.isMarkedRedundant(0));
+  EXPECT_TRUE(tab.isMarkedRedundant(1));
+}
+
+TEST(SimplexTest, isMarkedRedundant_neg_var_eq) {
+  Simplex tab(1);
+  tab.addEquality({-1, 0}); // -x == 0
+  tab.detectRedundant();
+  ASSERT_FALSE(tab.isEmpty());
+  EXPECT_FALSE(tab.isMarkedRedundant(0));
+  EXPECT_FALSE(tab.isMarkedRedundant(1));
+}
+
+TEST(SimplexTest, isMarkedRedundant_pos_var_ge) {
+  Simplex tab(1);
+  tab.addInequality({1, 0}); // x >= 0
+  tab.detectRedundant();
+  ASSERT_FALSE(tab.isEmpty());
+  EXPECT_FALSE(tab.isMarkedRedundant(0));
+}
+
+TEST(SimplexTest, isMarkedRedundant_zero_var_ge) {
+  Simplex tab(1);
+  tab.addInequality({0, 0}); // 0x >= 0
+  tab.detectRedundant();
+  ASSERT_FALSE(tab.isEmpty());
+  EXPECT_TRUE(tab.isMarkedRedundant(0));
+}
+
+TEST(SimplexTest, isMarkedRedundant_neg_var_ge) {
+  Simplex tab(1);
+  tab.addInequality({-1, 0}); // x <= 0
+  tab.detectRedundant();
+  ASSERT_FALSE(tab.isEmpty());
+  EXPECT_FALSE(tab.isMarkedRedundant(0));
+}
+
+TEST(SimplexTest, isMarkedRedundant_no_redundant) {
+  Simplex tab(3);
+
+  tab.addEquality({-1, 0, 1, 0});     // u = w
+  tab.addInequality({-1, 16, 0, 15}); // 15 - (u - 16v) >= 0
+  tab.addInequality({1, -16, 0, 0});  // (u - 16v) >= 0
+
+  tab.detectRedundant();
+  ASSERT_FALSE(tab.isEmpty());
+
+  for (unsigned i = 0; i < tab.numConstraints(); ++i)
+    EXPECT_FALSE(tab.isMarkedRedundant(i)) << "i = " << i << "\n";
+}
+
+TEST(SimplexTest, isMarkedRedundant_regression_test) {
+  Simplex tab(17);
+
+  tab.addEquality({0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0});
+  tab.addEquality({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, -10});
+  tab.addEquality({0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, -13});
+  tab.addEquality({0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10});
+  tab.addEquality({1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -13});
+  tab.addInequality({0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1});
+  tab.addInequality({0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 500});
+  tab.addInequality({0, 0, 0, -16, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+  tab.addInequality({0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1});
+  tab.addInequality({0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 998});
+  tab.addInequality({0, 0, 0, 16, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15});
+  tab.addInequality({0, 0, 0, 0, -16, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+  tab.addInequality({0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1});
+  tab.addInequality({0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 998});
+  tab.addInequality({0, 0, 0, 0, 16, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15});
+  tab.addInequality({0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+  tab.addInequality({0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1});
+  tab.addInequality({0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, -1});
+  tab.addInequality({0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 500});
+  tab.addInequality({0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 0, 15});
+  tab.addInequality({0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, -16, 0, 0, 0, 0, 0, 0});
+  tab.addInequality({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -16, 0, 1, 0, 0, 0});
+  tab.addInequality({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, -1});
+  tab.addInequality({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 998});
+  tab.addInequality({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, -1, 0, 0, 15});
+  tab.addInequality({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0});
+  tab.addInequality({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 1});
+  tab.addInequality({0, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 8, 8});
+  tab.addInequality({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, 8, 8});
+  tab.addInequality({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, -8, -1});
+  tab.addInequality({0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, -8, -1});
+
+  tab.detectRedundant();
+  ASSERT_FALSE(tab.isEmpty());
+  for (unsigned i = 0; i < tab.numConstraints(); ++i)
+    EXPECT_FALSE(tab.isMarkedRedundant(i)) << "i = " << i << '\n';
+}
+
+TEST(SimplexTest, isMarkedRedundant_repeated_constraints) {
+  Simplex tab(3);
+
+  // [4] to [7] repeat [0] to [3], in a different order.
+  tab.addInequality({0, -1, 0, 1}); // [0]: y <= 1
+  tab.addInequality({-1, 0, 8, 7}); // [1]: 8z >= x - 7
+  tab.addInequality({1, 0, -8, 0}); // [2]: 8z <= x
+  tab.addInequality({0, 1, 0, 0});  // [3]: y >= 0
+  tab.addInequality({-1, 0, 8, 7}); // [4]: 8z >= x - 7
+  tab.addInequality({1, 0, -8, 0}); // [5]: 8z <= x
+  tab.addInequality({0, 1, 0, 0});  // [6]: y >= 0
+  tab.addInequality({0, -1, 0, 1}); // [7]: y <= 1
+
+  tab.detectRedundant();
+  ASSERT_FALSE(tab.isEmpty());
+
+  EXPECT_EQ(tab.isMarkedRedundant(0), true);
+  EXPECT_EQ(tab.isMarkedRedundant(1), true);
+  EXPECT_EQ(tab.isMarkedRedundant(2), true);
+  EXPECT_EQ(tab.isMarkedRedundant(3), true);
+  EXPECT_EQ(tab.isMarkedRedundant(4), false);
+  EXPECT_EQ(tab.isMarkedRedundant(5), false);
+  EXPECT_EQ(tab.isMarkedRedundant(6), false);
+  EXPECT_EQ(tab.isMarkedRedundant(7), false);
+}
+
+TEST(SimplexTest, isMarkedRedundant) {
+  Simplex tab(3);
+  tab.addInequality({0, -1, 0, 1}); // [0]: y <= 1
+  tab.addInequality({1, 0, 0, -1}); // [1]: x >= 1
+  tab.addInequality({-1, 0, 0, 2}); // [2]: x <= 2
+  tab.addInequality({-1, 0, 2, 7}); // [3]: 2z >= x - 7
+  tab.addInequality({1, 0, -2, 0}); // [4]: 2z <= x
+  tab.addInequality({0, 1, 0, 0});  // [5]: y >= 0
+  tab.addInequality({0, 1, -2, 1}); // [6]: y >= 2z - 1
+  tab.addInequality({-1, 1, 0, 1}); // [7]: y >= x - 1
+
+  tab.detectRedundant();
+  ASSERT_FALSE(tab.isEmpty());
+
+  // [0], [1], [3], [4], [7] together imply that [2], [5], [6] must hold.
+  //
+  // From [7], [0]: x <= y + 1 <= 2, so we have [2].
+  // From [7], [1]: y >= x - 1 >= 0, so we have [5].
+  // From [4], [7]: 2z - 1 <= x - 1 <= y, so we have [6].
+  EXPECT_FALSE(tab.isMarkedRedundant(0));
+  EXPECT_FALSE(tab.isMarkedRedundant(1));
+  EXPECT_TRUE(tab.isMarkedRedundant(2));
+  EXPECT_FALSE(tab.isMarkedRedundant(3));
+  EXPECT_FALSE(tab.isMarkedRedundant(4));
+  EXPECT_TRUE(tab.isMarkedRedundant(5));
+  EXPECT_TRUE(tab.isMarkedRedundant(6));
+  EXPECT_FALSE(tab.isMarkedRedundant(7));
+}
+
+TEST(SimplexTest, addInequality_already_redundant) {
+  Simplex tab(1);
+  tab.addInequality({1, -1}); // x >= 1
+  tab.addInequality({1, 0});  // x >= 0
+  tab.detectRedundant();
+  ASSERT_FALSE(tab.isEmpty());
+  EXPECT_FALSE(tab.isMarkedRedundant(0));
+  EXPECT_TRUE(tab.isMarkedRedundant(1));
+}
+
 } // namespace mlir