diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h
--- a/mlir/include/mlir/Analysis/Presburger/Simplex.h
+++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h
@@ -162,6 +162,10 @@
   /// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0.
   void addEquality(ArrayRef<int64_t> coeffs);
 
+  /// Add `count` new variables to the end of the list of variables. Each
+  /// addition is recorded in the undo log so it can be undone on rollback.
+  void appendVariable(unsigned count = 1);
+
   /// Mark the tableau as being empty.
   void markEmpty();
 
@@ -301,8 +305,9 @@
   /// and the denominator.
   void normalizeRow(unsigned row);
 
-  /// Swap the two rows in the tableau and associated data structures.
+  /// Swap the two rows/columns in the tableau and associated data structures.
   void swapRows(unsigned i, unsigned j);
+  void swapColumns(unsigned i, unsigned j);
 
   /// Restore the unknown to a non-negative sample value.
   ///
@@ -327,6 +332,7 @@
   /// Enum to denote operations that need to be undone during rollback.
   enum class UndoLogEntry {
     RemoveLastConstraint,
+    RemoveLastVariable,
     UnmarkEmpty,
     UnmarkLastRedundant
   };
diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -345,6 +345,18 @@
   unknownFromRow(j).pos = j;
 }
 
+/// Swap columns i and j of the tableau, along with their colUnknown entries,
+/// and update the positions stored in the two affected unknowns.
+void Simplex::swapColumns(unsigned i, unsigned j) {
+  assert(i < nCol && j < nCol && "Invalid columns provided!");
+  if (i == j)
+    return;
+  tableau.swapColumns(i, j);
+  std::swap(colUnknown[i], colUnknown[j]);
+  unknownFromColumn(i).pos = i;
+  unknownFromColumn(j).pos = j;
+}
+
 /// Mark this tableau empty and push an entry to the undo stack.
 void Simplex::markEmpty() {
   undoLog.push_back(UndoLogEntry::UnmarkEmpty);
@@ -434,6 +446,26 @@
     nRow--;
     rowUnknown.pop_back();
     con.pop_back();
+  } else if (entry == UndoLogEntry::RemoveLastVariable) {
+    // Whenever we are rolling back the addition of a variable, it is guaranteed
+    // that the variable will be in column position.
+    //
+    // We can see this as follows: any constraint that depends on this variable
+    // was added after this variable was added, so the addition of such
+    // constraints should already have been rolled back by the time we get to
+    // rolling back the addition of the variable. Therefore, no constraint
+    // currently has a component along the variable, so the variable itself must
+    // be part of the basis.
+    assert(var.back().orientation == Orientation::Column &&
+           "Variable to be removed must be in column orientation!");
+
+    // Move this variable to the last column and remove the column from the
+    // tableau.
+    swapColumns(var.back().pos, nCol - 1);
+    tableau.resizeHorizontally(nCol - 1);
+    var.pop_back();
+    colUnknown.pop_back();
+    nCol--;
   } else if (entry == UndoLogEntry::UnmarkEmpty) {
     empty = false;
   } else if (entry == UndoLogEntry::UnmarkLastRedundant) {
@@ -452,6 +484,22 @@
   }
 }
 
+/// Append `count` new unrestricted variables in column orientation, widening
+/// the tableau accordingly. One RemoveLastVariable undo entry is logged per
+/// variable so that a rollback can remove them again.
+void Simplex::appendVariable(unsigned count) {
+  var.reserve(var.size() + count);
+  colUnknown.reserve(colUnknown.size() + count);
+  for (unsigned i = 0; i < count; ++i) {
+    nCol++;
+    var.emplace_back(Orientation::Column, /*restricted=*/false,
+                     /*pos=*/nCol - 1);
+    colUnknown.push_back(var.size() - 1);
+  }
+  tableau.resizeHorizontally(nCol);
+  undoLog.insert(undoLog.end(), count, UndoLogEntry::RemoveLastVariable);
+}
+
 /// Add all the constraints from the given FlatAffineConstraints.
 void Simplex::intersectFlatAffineConstraints(const FlatAffineConstraints &fac) {
   assert(fac.getNumIds() == numVariables() &&
diff --git a/mlir/unittests/Analysis/Presburger/SimplexTest.cpp b/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
--- a/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
@@ -383,4 +383,30 @@
   EXPECT_TRUE(simplex.isMarkedRedundant(1));
 }
 
+TEST(SimplexTest, appendVariable) {
+  Simplex simplex(1);
+
+  unsigned snapshot1 = simplex.getSnapshot();
+  simplex.appendVariable();
+  EXPECT_EQ(simplex.numVariables(), 2u);
+
+  int64_t yMin = 2, yMax = 5;
+  simplex.addInequality({0, 1, -yMin}); // y >= 2.
+  simplex.addInequality({0, -1, yMax}); // y <= 5.
+
+  unsigned snapshot2 = simplex.getSnapshot();
+  simplex.appendVariable(2);
+  EXPECT_EQ(simplex.numVariables(), 4u);
+  simplex.rollback(snapshot2);
+
+  EXPECT_EQ(simplex.numVariables(), 2u);
+  EXPECT_EQ(simplex.numConstraints(), 2u);
+  EXPECT_EQ(simplex.computeIntegerBounds({0, 1, 0}),
+            std::make_pair(yMin, yMax));
+
+  simplex.rollback(snapshot1);
+  EXPECT_EQ(simplex.numVariables(), 1u);
+  EXPECT_EQ(simplex.numConstraints(), 0u);
+}
+
 } // namespace mlir