diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h
--- a/mlir/include/mlir/Analysis/Presburger/Simplex.h
+++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h
@@ -166,21 +166,21 @@
   /// false otherwise.
   bool isEmpty() const;
 
-  /// Add an inequality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
-  /// is the current number of variables, then the corresponding inequality is
-  /// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} >= 0.
-  virtual void addInequality(ArrayRef<int64_t> coeffs) = 0;
-
   /// Returns the number of variables in the tableau.
   unsigned getNumVariables() const;
 
   /// Returns the number of constraints in the tableau.
   unsigned getNumConstraints() const;
 
+  /// Add an inequality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
+  /// is the current number of variables, then the corresponding inequality is
+  /// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} >= 0.
+  virtual void addInequality(ArrayRef<int64_t> coeffs) = 0;
+
   /// Add an equality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
   /// is the current number of variables, then the corresponding equality is
   /// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0.
-  void addEquality(ArrayRef<int64_t> coeffs);
+  virtual void addEquality(ArrayRef<int64_t> coeffs) = 0;
 
   /// Add new variables to the end of the list of variables.
   void appendVariable(unsigned count = 1);
@@ -249,6 +249,14 @@
   /// coefficient for it.
   Optional<unsigned> findAnyPivotRow(unsigned col);
 
+  /// Return any column that this row can be pivoted with, ignoring tableau
+  /// consistency. Fixed columns (the denominator, constant, big M if used, and
+  /// equalities already pivoted into column position) are not considered.
+  ///
+  /// Returns an empty optional if no pivot is possible, which happens only when
+  /// the row has a zero coefficient in every non-fixed column.
+  Optional<unsigned> findAnyPivotCol(unsigned row);
+
   /// Swap the row with the column in the tableau's data structures but not the
   /// tableau itself. This is used by pivot.
   void swapRowWithCol(unsigned row, unsigned col);
@@ -295,6 +303,7 @@
     RemoveLastVariable,
     UnmarkEmpty,
     UnmarkLastRedundant,
+    UnmarkLastEquality,
     RestoreBasis
   };
 
@@ -308,13 +317,18 @@
   /// Undo the operation represented by the log entry.
   void undo(UndoLogEntry entry);
 
-  /// Return the number of fixed columns, as described in the constructor above,
-  /// this is the number of columns beyond those for the variables in var.
-  unsigned getNumFixedCols() const { return usingBigM ? 3u : 2u; }
+  /// Return the number of fixed columns: the denominator and constant columns,
+  /// the big M column if usingBigM is set, and one column for each equality
+  /// that has been pivoted into column position by LexSimplex::addEquality.
+  unsigned getNumFixedCols() const { return numFixedCols; }
 
   /// Stores whether or not a big M column is present in the tableau.
   bool usingBigM;
 
+  /// The number of fixed columns; see getNumFixedCols(). The fixed columns are
+  /// always the first numFixedCols columns of the tableau.
+  unsigned numFixedCols;
+
   /// The number of rows in the tableau.
   unsigned nRow;
 
@@ -435,9 +449,15 @@
   ///
   /// This just adds the inequality to the tableau and does not try to create a
   /// consistent tableau configuration.
-  void addInequality(ArrayRef<int64_t> coeffs) final {
-    addRow(coeffs, /*makeRestricted=*/true);
-  }
+  void addInequality(ArrayRef<int64_t> coeffs) final;
+
+  /// Add an equality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
+  /// is the current number of variables, then the corresponding equality is
+  /// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0.
+  ///
+  /// When possible, the equality is pivoted into a fixed column; otherwise it
+  /// is enforced as a pair of opposing inequalities.
+  void addEquality(ArrayRef<int64_t> coeffs) final;
 
   /// Get a snapshot of the current state. This is used for rolling back.
   unsigned getSnapshot() { return SimplexBase::getSnapshotBasis(); }
@@ -533,6 +553,11 @@
   /// state and marks the Simplex empty if this is not possible.
   void addInequality(ArrayRef<int64_t> coeffs) final;
 
+  /// Add an equality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
+  /// is the current number of variables, then the corresponding equality is
+  /// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0.
+  void addEquality(ArrayRef<int64_t> coeffs) final;
+
   /// Compute the maximum or minimum value of the given row, depending on
   /// direction. The specified row is never pivoted. On return, the row may
   /// have a negative sample value if the direction is down.
diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -19,12 +19,12 @@
 const int nullIndex = std::numeric_limits<int>::max();
 
 SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM)
-    : usingBigM(mustUseBigM), nRow(0), nCol(getNumFixedCols() + nVar),
-      nRedundant(0), tableau(0, nCol), empty(false) {
-  colUnknown.insert(colUnknown.begin(), getNumFixedCols(), nullIndex);
+    : usingBigM(mustUseBigM), numFixedCols(mustUseBigM ? 3 : 2), nRow(0),
+      nCol(numFixedCols + nVar), nRedundant(0), tableau(0, nCol), empty(false) {
+  colUnknown.insert(colUnknown.begin(), numFixedCols, nullIndex);
   for (unsigned i = 0; i < nVar; ++i) {
     var.emplace_back(Orientation::Column, /*restricted=*/false,
-                     /*pos=*/getNumFixedCols() + i);
+                     /*pos=*/numFixedCols + i);
     colUnknown.push_back(i);
   }
 }
@@ -309,7 +309,7 @@
 // minimizes the change in sample value.
 LogicalResult LexSimplex::moveRowUnknownToColumn(unsigned row) {
   Optional<unsigned> maybeColumn;
-  for (unsigned col = 3; col < nCol; ++col) {
+  for (unsigned col = getNumFixedCols(); col < nCol; ++col) {
     if (tableau(row, col) <= 0)
       continue;
     maybeColumn =
@@ -648,7 +648,7 @@
 ///
 /// We simply add two opposing inequalities, which force the expression to
 /// be zero.
-void SimplexBase::addEquality(ArrayRef<int64_t> coeffs) {
+void Simplex::addEquality(ArrayRef<int64_t> coeffs) {
   addInequality(coeffs);
   SmallVector<int64_t, 8> negatedCoeffs;
   for (int64_t coeff : coeffs)
@@ -705,6 +705,15 @@
   return {};
 }
 
+// Return the first non-fixed column in which the row has a non-zero
+// coefficient, or an empty optional if the row is zero in every such column.
+Optional<unsigned> SimplexBase::findAnyPivotCol(unsigned row) {
+  for (unsigned col = getNumFixedCols(); col < nCol; ++col)
+    if (tableau(row, col) != 0)
+      return col;
+  return {};
+}
+
 // It's not valid to remove the constraint by deleting the column since this
 // would result in an invalid basis.
 void Simplex::undoLastConstraint() {
@@ -780,6 +789,10 @@
     empty = false;
   } else if (entry == UndoLogEntry::UnmarkLastRedundant) {
     nRedundant--;
+  } else if (entry == UndoLogEntry::UnmarkLastEquality) {
+    numFixedCols--;
+    assert(getNumFixedCols() >= 2 + usingBigM &&
+           "The denominator, constant and big M (if used) are always fixed!");
   } else if (entry == UndoLogEntry::RestoreBasis) {
     assert(!savedBases.empty() && "No bases saved!");
 
@@ -1110,6 +1123,35 @@
   return sample;
 }
 
+void LexSimplex::addInequality(ArrayRef<int64_t> coeffs) {
+  addRow(coeffs, /*makeRestricted=*/true);
+}
+
+/// Try to make the equality a fixed column by finding any pivot and performing
+/// it. This is not possible only when the new row has a zero coefficient in
+/// every non-fixed column, i.e., its value is already determined by the fixed
+/// columns. In that case the row is left in row position, and since a
+/// restricted row only enforces `expr >= 0`, the opposing inequality
+/// `-expr >= 0` is added as well so that the equality is actually enforced.
+void LexSimplex::addEquality(ArrayRef<int64_t> coeffs) {
+  unsigned row = con[addRow(coeffs, /*makeRestricted=*/true)].pos;
+  Optional<unsigned> pivotCol = findAnyPivotCol(row);
+  if (!pivotCol) {
+    SmallVector<int64_t, 8> negatedCoeffs;
+    for (int64_t coeff : coeffs)
+      negatedCoeffs.emplace_back(-coeff);
+    addInequality(negatedCoeffs);
+    return;
+  }
+
+  pivot(row, *pivotCol);
+  // Move the equality's column to just after the existing fixed columns and
+  // grow the fixed region to include it.
+  swapColumns(*pivotCol, getNumFixedCols());
+  numFixedCols++;
+  undoLog.push_back(UndoLogEntry::UnmarkLastEquality);
+}
+
 MaybeOptimum<SmallVector<Fraction, 8>> LexSimplex::getRationalSample() const {
   if (empty)
     return OptimumKind::Empty;
diff --git a/mlir/unittests/Analysis/Presburger/SimplexTest.cpp b/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
--- a/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
@@ -548,3 +548,19 @@
   ASSERT_TRUE(sample.hasValue());
   EXPECT_EQ((*sample)[0] / 2, (*sample)[1]);
 }
+
+TEST(LexSimplexTest, addEquality) {
+  IntegerRelation rel(/*numDomain=*/0, /*numRange=*/1);
+  rel.addEquality({1, 0});
+  LexSimplex simplex(rel);
+  EXPECT_EQ(simplex.getNumConstraints(), 1u);
+}
+
+TEST(LexSimplexTest, addEqualityInSpanOfFixedEqualities) {
+  // {x = 0, x + 1 = 0} is empty. The second equality cannot be pivoted into a
+  // fixed column, so it must still be enforced in both directions.
+  IntegerRelation rel(/*numDomain=*/0, /*numRange=*/1);
+  rel.addEquality({1, 0});
+  rel.addEquality({1, 1});
+  EXPECT_TRUE(LexSimplex(rel).findRationalLexMin().isEmpty());
+}