diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -475,6 +475,13 @@ return success(gaussianEliminateIds(position, position + 1) == 1); } + /// Removes local variables using equalities. Each equality is checked to see + /// whether it can be reduced to the form `e = affine-expr`, where `e` is a + /// local variable with a coefficient of magnitude 1 and `affine-expr` is an + /// affine expression not containing `e`. Such a variable is substituted away + /// in every other constraint and then removed, along with its equality. + void removeRedundantLocalVars(); + /// Eliminates identifiers from equality and inequality constraints /// in column range [posStart, posLimit). /// Returns the number of variables eliminated. diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1780,6 +1780,46 @@ equalities.resizeVertically(pos); } +/// Removes local variables using equalities. Each equality is checked to see +/// whether it can be reduced to the form `e = affine-expr`, where `e` is a +/// local variable with a coefficient of magnitude 1 and `affine-expr` is an +/// affine expression not containing `e`. Such a variable is substituted away +/// in every other constraint and then removed, along with its equality.
+void FlatAffineConstraints::removeRedundantLocalVars() {
+  // Normalize by GCD first so that e.g. `2*e - 2*d = 0` exposes a unit
+  // coefficient on local `e`; the scan below only accepts +1/-1 coefficients.
+  normalizeConstraintsByGCD();
+  while (true) {
+    // Find an equality `eqIdx` with a +1/-1 coefficient on some local `var`.
+    unsigned eqIdx = 0, numEqs = getNumEqualities();
+    unsigned var = 0, numIds = getNumIds();
+    for (; eqIdx < numEqs; ++eqIdx) {
+      for (var = getNumDimAndSymbolIds(); var < numIds; ++var)
+        if (std::abs(atEq(eqIdx, var)) == 1)
+          break;
+      if (var < numIds)
+        break;
+    }
+    // No equality can be used to eliminate a local variable: done.
+    if (eqIdx == numEqs)
+      return;
+
+    // Substitute `var` away from the other equalities and all inequalities.
+    // NOTE(review): assumes eliminateFromConstraint skips row `eqIdx` itself,
+    // as it is passed as both the target equality and the pivot.
+    for (unsigned r = 0, e = getNumEqualities(); r < e; ++r)
+      eliminateFromConstraint(this, r, eqIdx, var, var, /*isEq=*/true);
+    for (unsigned r = 0, e = getNumInequalities(); r < e; ++r)
+      eliminateFromConstraint(this, r, eqIdx, var, var, /*isEq=*/false);
+
+    // Drop the eliminated local and its defining equality, then renormalize
+    // because the substitution may have introduced new common factors.
+    removeId(var);
+    removeEquality(eqIdx);
+    normalizeConstraintsByGCD();
+  }
+}
+
 std::pair FlatAffineConstraints::getLowerAndUpperBound(
     unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
     ArrayRef localExprs, MLIRContext *context) const {