diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -426,6 +426,11 @@
   /// O(VC) time.
   void removeRedundantConstraints();
 
+  /// Merge local ids of `this` and `other`. This is done by appending local ids
+  /// of `other` to `this` and inserting local ids of `this` to `other` at start
+  /// of its local ids.
+  void mergeLocalIds(FlatAffineConstraints &other);
+
   /// Removes all equalities and inequalities.
   void clearConstraints();
 
@@ -841,6 +846,11 @@
       setValue(i, values[i - start]);
   }
 
+  /// Merge and align symbols of `this` and `other` such that both get the
+  /// union of symbols that are unique. Symbols with Value as `None` are
+  /// considered to be unequal to all other symbols.
+  void mergeSymbolIds(FlatAffineValueConstraints &other);
+
 protected:
   /// Returns false if the fields corresponding to various identifier counts, or
   /// equality/inequality buffer sizes aren't consistent; true otherwise. This
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -440,13 +440,11 @@
                      b->getMaybeValues().begin() + b->getNumDimAndSymbolIds(),
                      [](Optional<Value> id) { return id.hasValue(); }));
 
-  // Place local id's of A after local id's of B.
-  b->insertLocalId(/*pos=*/0, /*num=*/a->getNumLocalIds());
-  a->appendLocalId(/*num=*/b->getNumLocalIds() - a->getNumLocalIds());
+  // Bring A and B to a common local space.
+  a->mergeLocalIds(*b);
 
-  SmallVector<Value, 4> aDimValues, aSymValues;
+  SmallVector<Value, 4> aDimValues;
   a->getValues(offset, a->getNumDimIds(), &aDimValues);
-  a->getValues(a->getNumDimIds(), a->getNumDimAndSymbolIds(), &aSymValues);
 
   {
     // Merge dims from A into B.
@@ -471,29 +469,8 @@
            "expected same number of dims");
   }
 
-  {
-    // Merge symbols: merge A's symbols into B first.
-    unsigned s = 0;
-    for (auto aSymValue : aSymValues) {
-      unsigned loc;
-      if (b->findId(aSymValue, &loc)) {
-        assert(loc >= b->getNumDimIds() && loc < b->getNumDimAndSymbolIds() &&
-               "A's symbol appears in B's non-symbol position");
-        b->swapId(s + b->getNumDimIds(), loc);
-      } else {
-        b->insertSymbolId(s, aSymValue);
-      }
-      s++;
-    }
-    // Symbols that are in B, but not in A, are added at the end.
-    for (unsigned t = a->getNumDimAndSymbolIds(),
-                  e = b->getNumDimAndSymbolIds();
-         t < e; t++) {
-      a->appendSymbolId(b->getValue(t));
-    }
-    assert(a->getNumDimAndSymbolIds() == b->getNumDimAndSymbolIds() &&
-           "expected same number of dims and symbols");
-  }
+  // Merge and align symbols of A and B.
+  a->mergeSymbolIds(*b);
 
   assert(areIdsAligned(*a, *b) && "IDs expected to be aligned");
 }
@@ -571,6 +548,38 @@
   }
 }
 
+/// Merge and align symbols of `this` and `other` such that both get the
+/// union of symbols that are unique. Symbols with Value as `None` are
+/// considered to be unequal to all other symbols.
+void FlatAffineValueConstraints::mergeSymbolIds(
+    FlatAffineValueConstraints &other) {
+  SmallVector<Value, 4> aSymValues;
+  getValues(getNumDimIds(), getNumDimAndSymbolIds(), &aSymValues);
+
+  // Merge symbols: merge symbols into `other` first from `this`.
+  unsigned s = other.getNumDimIds();
+  for (Value aSymValue : aSymValues) {
+    unsigned loc;
+    // `loc` indexes `other`, so test it against `other`'s symbol range: align
+    // the symbol on a hit, otherwise insert it into `other` as a new symbol.
+    if (other.findId(aSymValue, &loc) && loc >= other.getNumDimIds() &&
+        loc < other.getNumDimAndSymbolIds())
+      other.swapId(s, loc);
+    else
+      other.insertSymbolId(s - other.getNumDimIds(), aSymValue);
+    s++;
+  }
+
+  // Symbols that are in `other`, but not in `this`, are added at the end.
+  for (unsigned t = other.getNumDimIds() + getNumSymbolIds(),
+                e = other.getNumDimAndSymbolIds();
+       t < e; t++)
+    insertSymbolId(getNumSymbolIds(), other.getValue(t));
+
+  assert(getNumSymbolIds() == other.getNumSymbolIds() &&
+         "expected same number of symbols");
+}
+
 // Changes all symbol identifiers which are loop IVs to dim identifiers.
 void FlatAffineValueConstraints::convertLoopIVSymbolsToDims() {
   // Gather all symbols which are loop IVs.
@@ -1780,6 +1789,15 @@
   equalities.resizeVertically(pos);
 }
 
+/// Merge local ids of `this` and `other`. This is done by appending local ids
+/// of `other` to `this` and inserting local ids of `this` to `other` at start
+/// of its local ids.
+void FlatAffineConstraints::mergeLocalIds(FlatAffineConstraints &other) {
+  unsigned initLocals = getNumLocalIds();
+  insertLocalId(getNumLocalIds(), other.getNumLocalIds());
+  other.insertLocalId(0, initLocals);
+}
+
 std::pair<AffineMap, AffineMap> FlatAffineConstraints::getLowerAndUpperBound(
     unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
     ArrayRef<AffineExpr> localExprs, MLIRContext *context) const {