diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -475,6 +475,13 @@
     return success(gaussianEliminateIds(position, position + 1) == 1);
   }
 
+  /// Removes local variables using equalities. Each equality is checked if it
+  /// can be reduced to the form: `e = affine-expr`, where `e` is a local
+  /// variable and `affine-expr` is an affine expression not containing `e`.
+  /// If an equality satisfies this form, the local variable is replaced in
+  /// each constraint and then removed.
+  void removeRedundantLocalVars();
+
   /// Eliminates identifiers from equality and inequality constraints
   /// in column range [posStart, posLimit).
   /// Returns the number of variables eliminated.
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -940,9 +940,15 @@
   if (isEmptyByGCDTest() || hasInvalidConstraint())
     return true;
 
-  // First, eliminate as many identifiers as possible using Gaussian
-  // elimination.
   FlatAffineConstraints tmpCst(*this);
+
+  // First, eliminate as many local variables as possible using equalities.
+  tmpCst.removeRedundantLocalVars();
+  if (tmpCst.isEmptyByGCDTest() || tmpCst.hasInvalidConstraint())
+    return true;
+
+  // Eliminate as many identifiers as possible using Gaussian
+  // elimination.
   unsigned currentPos = 0;
   while (currentPos < tmpCst.getNumIds()) {
     tmpCst.gaussianEliminateIds(currentPos, tmpCst.getNumIds());
@@ -1780,6 +1786,46 @@
   equalities.resizeVertically(pos);
 }
 
+/// Removes local variables using equalities. Each equality is checked if it
+/// can be reduced to the form: `e = affine-expr`, where `e` is a local
+/// variable and `affine-expr` is an affine expression not containing `e`.
+/// If an equality satisfies this form, the local variable is replaced in
+/// each constraint and then removed.
+void FlatAffineConstraints::removeRedundantLocalVars() {
+  bool change = true;
+  while (change) {
+    change = false;
+    for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
+      // Find a local id with a unit (+1/-1) coefficient in equality `i`; such
+      // an id can be solved for exactly as an integer affine expression.
+      unsigned eliminateVar = getNumIds();
+      for (unsigned j = getNumDimAndSymbolIds(), f = getNumIds(); j < f; ++j) {
+        if (std::abs(atEq(i, j)) == 1) {
+          eliminateVar = j;
+          break;
+        }
+      }
+      if (eliminateVar == getNumIds())
+        continue;
+
+      change = true;
+      // Substitute the id out of every other constraint using equality `i`.
+      for (unsigned j = 0, f = getNumEqualities(); j < f; ++j)
+        eliminateFromConstraint(this, j, i, eliminateVar, eliminateVar,
+                                /*isEq=*/true);
+      for (unsigned j = 0, f = getNumInequalities(); j < f; ++j)
+        eliminateFromConstraint(this, j, i, eliminateVar, eliminateVar,
+                                /*isEq=*/false);
+      removeId(eliminateVar);
+      removeEquality(i);
+      normalizeConstraintsByGCD();
+      // Ids and equalities were just removed, so indices are stale: restart
+      // the scan from the first equality via the outer `while`.
+      break;
+    }
+  }
+}
+
 std::pair<AffineMap, AffineMap> FlatAffineConstraints::getLowerAndUpperBound(
     unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
     ArrayRef<AffineExpr> localExprs, MLIRContext *context) const {
diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp
--- a/mlir/unittests/Analysis/AffineStructuresTest.cpp
+++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp
@@ -769,4 +769,13 @@
   EXPECT_THAT(fac.getInequality(0), testing::ElementsAre(12, 20, 40));
 }
 
+TEST(FlatAffineConstraintsTest, simplifyLocalsTest) {
+  // (x) : (exists y: 2x + y = 1 and y = 2) => 2x = -1, no integer x.
+  FlatAffineConstraints fac(1, 0, 1);
+  fac.addEquality({2, 1, -1});
+  fac.addEquality({0, 1, -2});
+
+  EXPECT_TRUE(fac.isEmpty());
+}
+
 } // namespace mlir