diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -288,6 +288,14 @@
   void removeEquality(unsigned pos);
   void removeInequality(unsigned pos);
 
+  /// Remove the equalities at positions [start, end). This is a no-op when
+  /// `start >= end`.
+  void removeEqualityRange(unsigned start, unsigned end);
+
+  /// Remove the inequalities at positions [start, end). This is a no-op when
+  /// `start >= end`.
+  void removeInequalityRange(unsigned start, unsigned end);
+
   /// Sets the `values.size()` identifiers starting at `po`s to the specified
   /// values and removes them.
   void setAndEliminate(unsigned pos, ArrayRef values);
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -2312,6 +2312,23 @@
   inequalities.removeRow(pos);
 }
 
+void FlatAffineConstraints::removeEqualityRange(unsigned start, unsigned end) {
+  // Nothing to remove for an empty or inverted range; this also keeps the
+  // unsigned `end - start` below from wrapping around.
+  if (start >= end)
+    return;
+  equalities.removeRows(start, end - start);
+}
+
+void FlatAffineConstraints::removeInequalityRange(unsigned start,
+                                                  unsigned end) {
+  // Nothing to remove for an empty or inverted range; this also keeps the
+  // unsigned `end - start` below from wrapping around.
+  if (start >= end)
+    return;
+  inequalities.removeRows(start, end - start);
+}
+
 /// Finds an equality that equates the specified identifier to a constant.
 /// Returns the position of the equality row. If 'symbolic' is set to true,
 /// symbols are also treated like a constant, i.e., an affine function of the
diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp
--- a/mlir/unittests/Analysis/AffineStructuresTest.cpp
+++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp
@@ -17,6 +17,8 @@
 
 namespace mlir {
 
+using testing::ElementsAre;
+
 enum class TestFunction { Sample, Empty };
 
 /// If fn is TestFunction::Sample (default):
@@ -461,7 +463,7 @@
   // The second inequality is redundant and should have been removed. The
   // remaining inequality should be the first one.
   EXPECT_EQ(fac2.getNumInequalities(), 1u);
-  EXPECT_THAT(fac2.getInequality(0), testing::ElementsAre(1, 0, -3));
+  EXPECT_THAT(fac2.getInequality(0), ElementsAre(1, 0, -3));
   EXPECT_EQ(fac2.getNumEqualities(), 1u);
 
   FlatAffineConstraints fac3 =
@@ -575,6 +577,54 @@
   EXPECT_EQ(fac.atIneq(1, 2), 2);
 }
 
+TEST(FlatAffineConstraintsTest, removeInequality) {
+  FlatAffineConstraints fac =
+      makeFACFromConstraints(1, {{0, 0}, {1, 1}, {2, 2}, {3, 3}, {4, 4}}, {});
+
+  // An empty or inverted range removes nothing.
+  fac.removeInequalityRange(0, 0);
+  EXPECT_EQ(fac.getNumInequalities(), 5u);
+  fac.removeInequalityRange(3, 1);
+  EXPECT_EQ(fac.getNumInequalities(), 5u);
+
+  // Removing rows [1, 3) drops {1, 1} and {2, 2}.
+  fac.removeInequalityRange(1, 3);
+  EXPECT_EQ(fac.getNumInequalities(), 3u);
+  EXPECT_THAT(fac.getInequality(0), ElementsAre(0, 0));
+  EXPECT_THAT(fac.getInequality(1), ElementsAre(3, 3));
+  EXPECT_THAT(fac.getInequality(2), ElementsAre(4, 4));
+
+  // Removing the single row at position 1 drops {3, 3}.
+  fac.removeInequality(1);
+  EXPECT_EQ(fac.getNumInequalities(), 2u);
+  EXPECT_THAT(fac.getInequality(0), ElementsAre(0, 0));
+  EXPECT_THAT(fac.getInequality(1), ElementsAre(4, 4));
+}
+
+TEST(FlatAffineConstraintsTest, removeEquality) {
+  FlatAffineConstraints fac =
+      makeFACFromConstraints(1, {}, {{0, 0}, {1, 1}, {2, 2}, {3, 3}, {4, 4}});
+
+  // An empty or inverted range removes nothing.
+  fac.removeEqualityRange(0, 0);
+  EXPECT_EQ(fac.getNumEqualities(), 5u);
+  fac.removeEqualityRange(3, 1);
+  EXPECT_EQ(fac.getNumEqualities(), 5u);
+
+  // Removing rows [1, 3) drops {1, 1} and {2, 2}.
+  fac.removeEqualityRange(1, 3);
+  EXPECT_EQ(fac.getNumEqualities(), 3u);
+  EXPECT_THAT(fac.getEquality(0), ElementsAre(0, 0));
+  EXPECT_THAT(fac.getEquality(1), ElementsAre(3, 3));
+  EXPECT_THAT(fac.getEquality(2), ElementsAre(4, 4));
+
+  // Removing the single row at position 1 drops {3, 3}.
+  fac.removeEquality(1);
+  EXPECT_EQ(fac.getNumEqualities(), 2u);
+  EXPECT_THAT(fac.getEquality(0), ElementsAre(0, 0));
+  EXPECT_THAT(fac.getEquality(1), ElementsAre(4, 4));
+}
+
 TEST(FlatAffineConstraintsTest, clearConstraints) {
   FlatAffineConstraints fac =
       makeFACFromConstraints(1, {}, {});