diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -884,8 +884,9 @@
   }
 
-  /// Merge and align symbols of `this` and `other` such that both get union of
-  /// of symbols that are unique. Symbols with Value as `None` are considered
-  /// to be inequal to all other symbols.
+  /// Merge and align symbols of `this` and `other` such that both end up with
+  /// the union of their symbols. Symbols within each of `this` and `other`
+  /// must be unique. Symbols with Value as `None` are considered to be unequal
+  /// to all other symbols.
   void mergeSymbolIds(FlatAffineValueConstraints &other);
 
 protected:
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -450,9 +450,12 @@
 
-/// Checks if the SSA values associated with `cst`'s identifiers are unique.
+/// Checks if the SSA values associated with `cst`'s identifiers at positions
+/// >= `offset` are unique. Identifiers without a Value are ignored.
 static bool LLVM_ATTRIBUTE_UNUSED
-areIdsUnique(const FlatAffineValueConstraints &cst) {
+areIdsUnique(const FlatAffineValueConstraints &cst, unsigned offset = 0) {
   SmallPtrSet<Value, 8> uniqueIds;
-  for (auto val : cst.getMaybeValues()) {
+  ArrayRef<Optional<Value>> maybeValues = cst.getMaybeValues();
+  for (unsigned i = offset, e = maybeValues.size(); i < e; ++i) {
+    Optional<Value> val = maybeValues[i];
     if (val.hasValue() && !uniqueIds.insert(val.getValue()).second)
       return false;
   }
@@ -592,10 +595,17 @@
 }
 
-/// Merge and align symbols of `this` and `other` such that both get union of
-/// of symbols that are unique. Symbols with Value as `None` are considered
-/// to be inequal to all other symbols.
+/// Merge and align symbols of `this` and `other` such that both end up with
+/// the union of their symbols. Symbols within each of `this` and `other`
+/// must be unique. Symbols with Value as `None` are considered to be unequal
+/// to all other symbols.
 void FlatAffineValueConstraints::mergeSymbolIds(
     FlatAffineValueConstraints &other) {
+  assert(areIdsUnique(*this, /*offset=*/getIdKindOffset(IdKind::Symbol)) &&
+         "Symbol ids are not unique");
+  assert(
+      areIdsUnique(other, /*offset=*/other.getIdKindOffset(IdKind::Symbol)) &&
+      "Symbol ids are not unique");
+
   SmallVector<Value, 4> aSymValues;
   getValues(getNumDimIds(), getNumDimAndSymbolIds(), &aSymValues);
 
@@ -606,7 +616,7 @@
     // If the id is a symbol in `other`, then align it, otherwise assume that
     // it is a new symbol
     if (other.findId(aSymValue, &loc) && loc >= other.getNumDimIds() &&
-        loc < getNumDimAndSymbolIds())
+        loc < other.getNumDimAndSymbolIds())
       other.swapId(s, loc);
     else
       other.insertSymbolId(s - other.getNumDimIds(), aSymValue);
@@ -621,6 +631,11 @@
 
   assert(getNumSymbolIds() == other.getNumSymbolIds() &&
          "expected same number of symbols");
+  assert(areIdsUnique(*this, /*offset=*/getIdKindOffset(IdKind::Symbol)) &&
+         "Symbol ids are not unique");
+  assert(
+      areIdsUnique(other, /*offset=*/other.getIdKindOffset(IdKind::Symbol)) &&
+      "Symbol ids are not unique");
 }
 
 // Changes all symbol identifiers which are loop IVs to dim identifiers.