diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -480,6 +480,14 @@
     return success(gaussianEliminateIds(position, position + 1) == 1);
   }
 
+  /// Removes local variables using equalities. After each equality is
+  /// normalized by GCD, a local variable `e` whose coefficient in some
+  /// equality is +1 or -1 can be expressed as `e = affine-expr`, where
+  /// `affine-expr` does not contain `e`. Such a variable is replaced in every
+  /// other constraint and then removed, along with the equality that was used
+  /// to replace it. This repeats until no equality has such a local variable.
+  void removeRedundantLocalVars();
+
   /// Eliminates identifiers from equality and inequality constraints
   /// in column range [posStart, posLimit).
   /// Returns the number of variables eliminated.
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -949,9 +949,14 @@
   if (isEmptyByGCDTest() || hasInvalidConstraint())
     return true;
 
-  // First, eliminate as many identifiers as possible using Gaussian
-  // elimination.
   FlatAffineConstraints tmpCst(*this);
+
+  // First, eliminate as many local variables as possible using equalities.
+  tmpCst.removeRedundantLocalVars();
+  if (tmpCst.isEmptyByGCDTest() || tmpCst.hasInvalidConstraint())
+    return true;
+
+  // Eliminate as many identifiers as possible using Gaussian elimination.
   unsigned currentPos = 0;
   while (currentPos < tmpCst.getNumIds()) {
     tmpCst.gaussianEliminateIds(currentPos, tmpCst.getNumIds());
@@ -1798,6 +1803,54 @@
   other.insertLocalId(0, initLocals);
 }
 
+/// Removes local variables using equalities. After each equality is
+/// normalized by GCD, a local variable `e` whose coefficient in some
+/// equality is +1 or -1 can be expressed as `e = affine-expr`, where
+/// `affine-expr` does not contain `e`. Such a variable is replaced in every
+/// other constraint and then removed, along with the equality that was used
+/// to replace it. This repeats until no equality has such a local variable.
+void FlatAffineConstraints::removeRedundantLocalVars() {
+  // Normalize the equality constraints to reduce coefficients of local
+  // variables to 1 wherever possible.
+  for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
+    normalizeConstraintByGCD(this, i);
+
+  while (true) {
+    unsigned i, e, j, f;
+    for (i = 0, e = getNumEqualities(); i < e; ++i) {
+      // Find a local variable with coefficient +1 or -1 in the ith equality.
+      for (j = getNumDimAndSymbolIds(), f = getNumIds(); j < f; ++j)
+        if (std::abs(atEq(i, j)) == 1)
+          break;
+
+      // Found one: local variable j can be eliminated using the ith equality.
+      if (j < f)
+        break;
+    }
+
+    // No equality can be used to eliminate a local variable.
+    if (i == e)
+      break;
+
+    // Use the ith equality to simplify other equalities. If any changes
+    // are made to an equality constraint, it is normalized by GCD.
+    for (unsigned k = 0, t = getNumEqualities(); k < t; ++k) {
+      if (atEq(k, j) != 0) {
+        eliminateFromConstraint(this, k, i, j, j, /*isEq=*/true);
+        normalizeConstraintByGCD(this, k);
+      }
+    }
+
+    // Use the ith equality to simplify inequalities.
+    for (unsigned k = 0, t = getNumInequalities(); k < t; ++k)
+      eliminateFromConstraint(this, k, i, j, j, /*isEq=*/false);
+
+    // Remove the ith equality and the found local variable.
+    removeId(j);
+    removeEquality(i);
+  }
+}
+
 std::pair FlatAffineConstraints::getLowerAndUpperBound(
     unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
     ArrayRef localExprs, MLIRContext *context) const {
diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp
--- a/mlir/unittests/Analysis/AffineStructuresTest.cpp
+++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp
@@ -769,4 +769,30 @@
   EXPECT_THAT(fac.getInequality(0), testing::ElementsAre(12, 20, 40));
 }
 
+TEST(FlatAffineConstraintsTest, simplifyLocalsTest) {
+  // (x) : (exists y: 2x + y = 1 and y = 2). Empty: 2x = -1 has no integer x.
+  FlatAffineConstraints fac(1, 0, 1);
+  fac.addEquality({2, 1, -1});
+  fac.addEquality({0, 1, -2});
+
+  EXPECT_TRUE(fac.isEmpty());
+
+  // (x) : (exists y, z, w: 3x + y = 1 and 2y = z and 3y = w and z = w).
+  FlatAffineConstraints fac2(1, 0, 3);
+  fac2.addEquality({3, 1, 0, 0, -1});
+  fac2.addEquality({0, 2, -1, 0, 0});
+  fac2.addEquality({0, 3, 0, -1, 0});
+  fac2.addEquality({0, 0, 1, -1, 0});
+
+  EXPECT_TRUE(fac2.isEmpty());
+
+  // (x) : (exists y: x >= y + 1 and 2x + y = 0 and y >= -1).
+  FlatAffineConstraints fac3(1, 0, 1);
+  fac3.addInequality({1, -1, -1});
+  fac3.addInequality({0, 1, 1});
+  fac3.addEquality({2, 1, 0});
+
+  EXPECT_TRUE(fac3.isEmpty());
+}
+
 } // namespace mlir