diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -171,7 +171,7 @@
   /// within the specified range) from the system. The specified location is
   /// relative to the first identifier of the specified kind.
   void removeId(IdKind kind, unsigned pos);
-  void removeIdRange(IdKind kind, unsigned idStart, unsigned idLimit);
+  void removeIdRange(IdKind kind, unsigned idStart, unsigned idLimit) override;
 
   /// Removes the specified identifier from the system.
   void removeId(unsigned pos);
@@ -495,7 +495,7 @@
   /// Removes identifiers in the column range [idStart, idLimit), and copies any
   /// remaining valid data into place, updates member variables, and resizes
   /// arrays as needed.
-  void removeIdRange(unsigned idStart, unsigned idLimit) override;
+  void removeIdRange(unsigned idStart, unsigned idLimit);
 
   /// A parameter that controls detection of an unrealistic number of
   /// constraints. If the number of constraints is this many times the number of
diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
--- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
+++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
@@ -84,7 +84,8 @@
   void swapId(unsigned posA, unsigned posB) override;
 
   /// Remove the specified range of ids.
-  void removeIdRange(unsigned idStart, unsigned idLimit) override;
+  void removeIdRange(IdKind kind, unsigned idStart, unsigned idLimit) override;
+  using IntegerRelation::removeIdRange;
 
   /// Eliminate the `posB^th` local identifier, replacing every instance of it
   /// with the `posA^th` local identifier. This should be used when the two
diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
@@ -110,8 +110,9 @@
   /// first added identifier.
   virtual unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1);
 
-  /// Removes identifiers in the column range [idStart, idLimit).
-  virtual void removeIdRange(unsigned idStart, unsigned idLimit);
+  /// Removes identifiers of the specified kind in the column range [idStart,
+  /// idLimit). The range is relative to the kind of identifier.
+  virtual void removeIdRange(IdKind kind, unsigned idStart, unsigned idLimit);
 
   /// Returns true if both the spaces are equal i.e. if both spaces have the
   /// same number of identifiers of each kind (excluding Local Identifiers).
@@ -182,7 +183,8 @@
   unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override;
 
-  /// Removes identifiers in the column range [idStart, idLimit).
-  void removeIdRange(unsigned idStart, unsigned idLimit) override;
+  /// Removes identifiers of the specified kind in the column range [idStart,
+  /// idLimit). The range is relative to the kind of identifier.
+  void removeIdRange(IdKind kind, unsigned idStart, unsigned idLimit) override;
 
   /// Returns true if both the spaces are equal i.e. if both spaces have the
   /// same number of identifiers of each kind.
diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h b/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h
--- a/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h
@@ -412,6 +412,12 @@
   unsigned appendSymbolId(ValueRange vals);
   using FlatAffineConstraints::appendSymbolId;
 
+  /// Removes identifiers of the specified kind with positions in the range
+  /// [idStart, idLimit) relative to that kind, copies any remaining valid data
+  /// into place, updates member variables, and resizes arrays as needed.
+  void removeIdRange(IdKind kind, unsigned idStart, unsigned idLimit) override;
+  using IntegerRelation::removeIdRange;
+
   /// Add the specified values as a dim or symbol id depending on its nature, if
   /// it already doesn't exist in the system. `val` has to be either a terminal
   /// symbol or a loop IV, i.e., it cannot be the result affine.apply of any
@@ -557,11 +563,6 @@
   /// is meant to be used within an assert internally.
   bool hasConsistentState() const override;
 
-  /// Removes identifiers in the column range [idStart, idLimit), and copies any
-  /// remaining valid data into place, updates member variables, and resizes
-  /// arrays as needed.
-  void removeIdRange(unsigned idStart, unsigned idLimit) override;
-
   /// Eliminates the identifier at the specified position using Fourier-Motzkin
   /// variable elimination, but uses Gaussian elimination if there is an
   /// equality involving that identifier. If the result of the elimination is
@@ -643,17 +644,18 @@
   void appendDomainId(unsigned num = 1);
   void appendRangeId(unsigned num = 1);
 
+  /// Removes identifiers of the specified kind with positions in the range
+  /// [idStart, idLimit) relative to that kind, copies any remaining valid data
+  /// into place, updates member variables, and resizes arrays as needed.
+  void removeIdRange(IdKind kind, unsigned idStart, unsigned idLimit) override;
+  using IntegerRelation::removeIdRange;
+
 protected:
   // Number of dimension identifers corresponding to domain identifers.
   unsigned numDomainDims;
 
   // Number of dimension identifers corresponding to range identifers.
   unsigned numRangeDims;
-
-  /// Removes identifiers in the column range [idStart, idLimit), and copies any
-  /// remaining valid data into place, updates member variables, and resizes
-  /// arrays as needed.
-  void removeIdRange(unsigned idStart, unsigned idLimit) override;
 };
 
 /// Flattens 'expr' into 'flattenedExpr', which contains the coefficients of the
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -153,17 +153,55 @@
 void IntegerRelation::removeIdRange(IdKind kind, unsigned idStart,
                                     unsigned idLimit) {
   assert(idLimit <= getNumIdKind(kind));
-  removeIdRange(getIdKindOffset(kind) + idStart,
-                getIdKindOffset(kind) + idLimit);
+
+  if (idStart >= idLimit)
+    return;
+
+  // Remove eliminated identifiers from the constraints.
+  unsigned offset = getIdKindOffset(kind);
+  equalities.removeColumns(offset + idStart, idLimit - idStart);
+  inequalities.removeColumns(offset + idStart, idLimit - idStart);
+
+  // Remove eliminated identifiers from the space.
+  PresburgerLocalSpace::removeIdRange(kind, idStart, idLimit);
 }
 
 void IntegerRelation::removeIdRange(unsigned idStart, unsigned idLimit) {
-  // Update space paramaters.
-  PresburgerLocalSpace::removeIdRange(idStart, idLimit);
+  assert(idLimit <= getNumIds());
+
+  if (idStart >= idLimit)
+    return;
+
+  // Helper function to remove ids of the specified kind in the given range
+  // [start, limit). The range is absolute (i.e. it is not relative to the kind
+  // of identifier). Also updates `limit` to reflect the deleted identifiers.
+  auto removeIdKindInRange = [this](IdKind kind, unsigned &start,
+                                    unsigned &limit) {
+    if (start >= limit)
+      return;
+
+    unsigned offset = getIdKindOffset(kind);
+    unsigned num = getNumIdKind(kind);
+
+    // Get `start`, `limit` relative to the specified kind.
+    unsigned relativeStart =
+        start <= offset ? 0 : std::min(num, start - offset);
+    unsigned relativeLimit =
+        limit <= offset ? 0 : std::min(num, limit - offset);
+
+    // Remove ids of the specified kind in the relative range.
+    removeIdRange(kind, relativeStart, relativeLimit);
+
+    // Update `limit` to reflect deleted identifiers.
+    // `start` does not need to be updated because any identifiers that are
+    // deleted are after position `start`.
+    limit -= relativeLimit - relativeStart;
+  };
 
-  // Remove eliminated identifiers from the constraints..
-  equalities.removeColumns(idStart, idLimit - idStart);
-  inequalities.removeColumns(idStart, idLimit - idStart);
+  removeIdKindInRange(IdKind::Domain, idStart, idLimit);
+  removeIdKindInRange(IdKind::Range, idStart, idLimit);
+  removeIdKindInRange(IdKind::Symbol, idStart, idLimit);
+  removeIdKindInRange(IdKind::Local, idStart, idLimit);
 }
 
 void IntegerRelation::removeEquality(unsigned pos) {
@@ -1111,7 +1149,7 @@
     swapId(i + dimStart, i + newLocalIdStart);
 
   // Remove dimensions converted to local variables.
-  removeIdRange(dimStart, dimLimit);
+  removeIdRange(IdKind::SetDim, dimStart, dimLimit);
 }
 
 void IntegerRelation::addBound(BoundType type, unsigned pos, int64_t value) {
diff --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
--- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
+++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
@@ -94,8 +94,9 @@
   IntegerPolyhedron::swapId(posA, posB);
 }
 
-void MultiAffineFunction::removeIdRange(unsigned idStart, unsigned idLimit) {
-  output.removeColumns(idStart, idLimit - idStart);
-  IntegerPolyhedron::removeIdRange(idStart, idLimit);
+void MultiAffineFunction::removeIdRange(IdKind kind, unsigned idStart,
+                                        unsigned idLimit) {
+  output.removeColumns(idStart + getIdKindOffset(kind), idLimit - idStart);
+  IntegerPolyhedron::removeIdRange(kind, idStart, idLimit);
 }
 
diff --git a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
--- a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
@@ -99,27 +99,24 @@
   return absolutePos;
 }
 
-void PresburgerSpace::removeIdRange(unsigned idStart, unsigned idLimit) {
-  assert(idLimit <= getNumIds() && "invalid id limit");
+void PresburgerSpace::removeIdRange(IdKind kind, unsigned idStart,
+                                    unsigned idLimit) {
+  assert(idLimit <= getNumIdKind(kind) && "invalid id limit");
 
   if (idStart >= idLimit)
     return;
 
-  // We are going to be removing one or more identifiers from the range.
-  assert(idStart < getNumIds() && "invalid idStart position");
-
-  // Update members numDomain, numRange, numSymbols and numIds.
-  unsigned numDomainEliminated = 0;
-  if (spaceKind == Relation)
-    numDomainEliminated = getIdKindOverlap(IdKind::Domain, idStart, idLimit);
-  unsigned numRangeEliminated =
-      getIdKindOverlap(IdKind::Range, idStart, idLimit);
-  unsigned numSymbolsEliminated =
-      getIdKindOverlap(IdKind::Symbol, idStart, idLimit);
-
-  numDomain -= numDomainEliminated;
-  numRange -= numRangeEliminated;
-  numSymbols -= numSymbolsEliminated;
+  unsigned numIdsEliminated = idLimit - idStart;
+  if (kind == IdKind::Domain) {
+    assert(spaceKind == Relation && "IdKind::Domain is not supported in Set.");
+    numDomain -= numIdsEliminated;
+  } else if (kind == IdKind::Range) {
+    numRange -= numIdsEliminated;
+  } else if (kind == IdKind::Symbol) {
+    numSymbols -= numIdsEliminated;
+  } else {
+    llvm_unreachable("PresburgerSpace does not support local identifiers!");
+  }
 }
 
 unsigned PresburgerLocalSpace::insertId(IdKind kind, unsigned pos,
@@ -131,23 +128,17 @@
   return PresburgerSpace::insertId(kind, pos, num);
 }
 
-void PresburgerLocalSpace::removeIdRange(unsigned idStart, unsigned idLimit) {
-  assert(idLimit <= getNumIds() && "invalid id limit");
+void PresburgerLocalSpace::removeIdRange(IdKind kind, unsigned idStart,
+                                         unsigned idLimit) {
+  assert(idLimit <= getNumIdKind(kind) && "invalid id limit");
 
   if (idStart >= idLimit)
     return;
 
-  // We are going to be removing one or more identifiers from the range.
-  assert(idStart < getNumIds() && "invalid idStart position");
-
-  unsigned numLocalsEliminated =
-      getIdKindOverlap(IdKind::Local, idStart, idLimit);
-
-  // Update space parameters.
-  PresburgerSpace::removeIdRange(idStart, idLimit);
-
-  // Update local ids.
-  numLocals -= numLocalsEliminated;
+  if (kind == IdKind::Local)
+    numLocals -= idLimit - idStart;
+  else
+    PresburgerSpace::removeIdRange(kind, idStart, idLimit);
 }
 
 bool PresburgerSpace::isEqual(const PresburgerSpace &other) const {
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
--- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
@@ -720,10 +720,12 @@
          values.size() == getNumIds();
 }
 
-void FlatAffineValueConstraints::removeIdRange(unsigned idStart,
+void FlatAffineValueConstraints::removeIdRange(IdKind kind, unsigned idStart,
                                                unsigned idLimit) {
-  FlatAffineConstraints::removeIdRange(idStart, idLimit);
-  values.erase(values.begin() + idStart, values.begin() + idLimit);
+  FlatAffineConstraints::removeIdRange(kind, idStart, idLimit);
+  unsigned offset = getIdKindOffset(kind);
+  values.erase(values.begin() + idStart + offset,
+               values.begin() + idLimit + offset);
 }
 
 // Determine whether the identifier at 'pos' (say id_r) can be expressed as
@@ -1723,10 +1725,18 @@
   numRangeDims += num;
 }
 
-void FlatAffineRelation::removeIdRange(unsigned idStart, unsigned idLimit) {
+void FlatAffineRelation::removeIdRange(IdKind kind, unsigned idStart,
+                                       unsigned idLimit) {
+  assert(idLimit <= getNumIdKind(kind));
   if (idStart >= idLimit)
     return;
 
+  // If kind is not SetDim, domain and range don't need to be updated.
+  if (kind != IdKind::SetDim) {
+    FlatAffineValueConstraints::removeIdRange(kind, idStart, idLimit);
+    return;
+  }
+
   // Compute number of domain and range identifiers to remove. This is done by
   // intersecting the range of domain/range ids with range of ids to remove.
   unsigned intersectDomainLHS = std::min(idLimit, getNumDomainDims());
@@ -1734,8 +1744,10 @@
   unsigned intersectRangeLHS = std::min(idLimit, getNumDimIds());
   unsigned intersectRangeRHS = std::max(idStart, getNumDomainDims());
 
-  FlatAffineValueConstraints::removeIdRange(idStart, idLimit);
+  // Remove the identifiers only after the intersections above have been
+  // computed, since the removal shrinks `getNumDimIds()`.
+  FlatAffineValueConstraints::removeIdRange(kind, idStart, idLimit);
 
   if (intersectDomainLHS > intersectDomainRHS)
     numDomainDims -= intersectDomainLHS - intersectDomainRHS;
   if (intersectRangeLHS > intersectRangeRHS)
diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
--- a/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
@@ -39,11 +39,12 @@
   PresburgerSpace space = PresburgerSpace::getRelationSpace(2, 1, 3);
 
   // Remove 1 domain identifier.
-  space.removeIdRange(0, 1);
+  space.removeIdRange(IdKind::Domain, 0, 1);
   EXPECT_EQ(space.getNumDomainIds(), 1u);
 
   // Remove 1 symbol and 1 range identifier.
-  space.removeIdRange(1, 3);
+  space.removeIdRange(IdKind::Symbol, 0, 1);
+  space.removeIdRange(IdKind::Range, 0, 1);
   EXPECT_EQ(space.getNumDomainIds(), 1u);
   EXPECT_EQ(space.getNumRangeIds(), 0u);
   EXPECT_EQ(space.getNumSymbolIds(), 2u);