diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -475,6 +475,13 @@
     return success(gaussianEliminateIds(position, position + 1) == 1);
   }
 
+  /// Removes local variables using equalities. Each equality is checked if it
+  /// can be reduced to the form: `e = affine-expr`, where `e` is a local
+  /// variable and `affine-expr` is an affine expression not containing `e`.
+  /// If an equality satisfies this form, the local variable is replaced in
+  /// each constraint and then removed.
+  void removeRedundantLocalVars();
+
   /// Eliminates identifiers from equality and inequality constraints
   /// in column range [posStart, posLimit).
   /// Returns the number of variables eliminated.
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -940,9 +940,15 @@
   if (isEmptyByGCDTest() || hasInvalidConstraint())
     return true;
 
-  // First, eliminate as many identifiers as possible using Gaussian
-  // elimination.
   FlatAffineConstraints tmpCst(*this);
+
+  // First, eliminate as many local variables as possible using equalities.
+  tmpCst.removeRedundantLocalVars();
+  if (tmpCst.isEmptyByGCDTest() || tmpCst.hasInvalidConstraint())
+    return true;
+
+  // Eliminate as many identifiers as possible using Gaussian
+  // elimination.
   unsigned currentPos = 0;
   while (currentPos < tmpCst.getNumIds()) {
     tmpCst.gaussianEliminateIds(currentPos, tmpCst.getNumIds());
@@ -1780,6 +1786,48 @@
   equalities.resizeVertically(pos);
 }
 
+/// Removes local variables using equalities. Each equality is checked if it
+/// can be reduced to the form: `e = affine-expr`, where `e` is a local
+/// variable and `affine-expr` is an affine expression not containing `e`.
+/// If an equality satisfies this form, the local variable is replaced in
+/// each constraint and then removed.
+void FlatAffineConstraints::removeRedundantLocalVars() {
+  // Normalize all constraints by their GCD first: an equality such as
+  // `4x + 2y = 2` only exposes a unit coefficient on `y` as `2x + y = 1`.
+  normalizeConstraintsByGCD();
+
+  while (true) {
+    unsigned i, e, j, f;
+    for (i = 0, e = getNumEqualities(); i < e; ++i) {
+      // Find a local variable with a unit coefficient in the ith equality.
+      for (j = getNumDimAndSymbolIds(), f = getNumIds(); j < f; ++j)
+        if (std::abs(atEq(i, j)) == 1)
+          break;
+
+      // Stop at the first equality that can eliminate a local variable.
+      if (j < f)
+        break;
+    }
+
+    // No equality can be used to eliminate a local variable.
+    if (i == e)
+      break;
+
+    // Use the ith equality to eliminate local variable j from the other
+    // constraints.
+    for (unsigned k = 0, t = getNumEqualities(); k < t; ++k)
+      eliminateFromConstraint(this, k, i, j, j, /*isEq=*/true);
+    for (unsigned k = 0, t = getNumInequalities(); k < t; ++k)
+      eliminateFromConstraint(this, k, i, j, j, /*isEq=*/false);
+
+    // Remove the ith equality and the found local variable, then renormalize
+    // all constraints since the elimination rewrote their coefficients.
+    removeId(j);
+    removeEquality(i);
+    normalizeConstraintsByGCD();
+  }
+}
+
 std::pair<AffineMap, AffineMap> FlatAffineConstraints::getLowerAndUpperBound(
     unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
     ArrayRef<AffineExpr> localExprs, MLIRContext *context) const {
diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp
--- a/mlir/unittests/Analysis/AffineStructuresTest.cpp
+++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp
@@ -769,4 +769,21 @@
   EXPECT_THAT(fac.getInequality(0), testing::ElementsAre(12, 20, 40));
 }
 
+TEST(FlatAffineConstraintsTest, simplifyLocalsTest) {
+  // (x) : (exists y: 2x + y = 1 and y = 2).
+  FlatAffineConstraints fac(1, 0, 1);
+  fac.addEquality({2, 1, -1});
+  fac.addEquality({0, 1, -2});
+
+  EXPECT_TRUE(fac.isEmpty());
+
+  // (x) : (exists y: 4x + 2y = 2 and 2y = 4). The local variable only has a
+  // unit coefficient once the equalities are normalized by their GCD.
+  FlatAffineConstraints fac2(1, 0, 1);
+  fac2.addEquality({4, 2, -2});
+  fac2.addEquality({0, 2, -4});
+
+  EXPECT_TRUE(fac2.isEmpty());
+}
+
 } // namespace mlir