diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -884,8 +884,9 @@
   }
 
   /// Merge and align symbols of `this` and `other` such that both get union of
-  /// of symbols that are unique. Symbols with Value as `None` are considered
-  /// to be inequal to all other symbols.
+  /// symbols that are unique. Symbols in `this` and `other` should be
+  /// unique. Symbols with Value as `None` are considered to be unequal to all
+  /// other symbols.
   void mergeSymbolIds(FlatAffineValueConstraints &other);
 
 protected:
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -448,17 +448,40 @@
   return areIdsAligned(*this, other);
 }
 
-/// Checks if the SSA values associated with `cst`'s identifiers are unique.
-static bool LLVM_ATTRIBUTE_UNUSED
-areIdsUnique(const FlatAffineValueConstraints &cst) {
+/// Checks if the SSA values associated with `cst`'s identifiers in range
+/// [start, end) are unique. Identifiers without an SSA value are ignored.
+static bool LLVM_ATTRIBUTE_UNUSED areIdsUnique(
+    const FlatAffineValueConstraints &cst, unsigned start, unsigned end) {
   SmallPtrSet<Value, 8> uniqueIds;
-  for (auto val : cst.getMaybeValues()) {
+  ArrayRef<Optional<Value>> maybeValues = cst.getMaybeValues();
+  for (unsigned i = start; i < end; ++i) {
+    Optional<Value> val = maybeValues[i];
     if (val.hasValue() && !uniqueIds.insert(val.getValue()).second)
       return false;
   }
   return true;
 }
 
+/// Checks if the SSA values associated with `cst`'s identifiers are unique.
+static bool LLVM_ATTRIBUTE_UNUSED
+areIdsUnique(const FlatAffineValueConstraints &cst) {
+  return areIdsUnique(cst, 0, cst.getNumIds());
+}
+
+/// Checks if the SSA values associated with `cst`'s identifiers of kind `kind`
+/// are unique.
+static bool LLVM_ATTRIBUTE_UNUSED areIdsUnique(
+    const FlatAffineValueConstraints &cst, FlatAffineConstraints::IdKind kind) {
+  // Identifiers are laid out as [dims | symbols | locals]; check one range.
+  if (kind == FlatAffineConstraints::IdKind::Dimension)
+    return areIdsUnique(cst, 0, cst.getNumDimIds());
+  if (kind == FlatAffineConstraints::IdKind::Symbol)
+    return areIdsUnique(cst, cst.getNumDimIds(), cst.getNumDimAndSymbolIds());
+  if (kind == FlatAffineConstraints::IdKind::Local)
+    return areIdsUnique(cst, cst.getNumDimAndSymbolIds(), cst.getNumIds());
+  llvm_unreachable("Unexpected IdKind");
+}
+
 /// Merge and align the identifiers of A and B starting at 'offset', so that
 /// both constraint systems get the union of the contained identifiers that is
 /// dimension-wise and symbol-wise unique; both constraint systems are updated
@@ -592,10 +615,15 @@
 }
 
 /// Merge and align symbols of `this` and `other` such that both get union of
-/// of symbols that are unique. Symbols with Value as `None` are considered
-/// to be inequal to all other symbols.
+/// symbols that are unique. Symbols in `this` and `other` should be
+/// unique. Symbols with Value as `None` are considered to be unequal to all
+/// other symbols.
void FlatAffineValueConstraints::mergeSymbolIds( FlatAffineValueConstraints &other) { + + assert(areIdsUnique(*this, IdKind::Symbol) && "Symbol ids are not unique"); + assert(areIdsUnique(other, IdKind::Symbol) && "Symbol ids are not unique"); + SmallVector aSymValues; getValues(getNumDimIds(), getNumDimAndSymbolIds(), &aSymValues); @@ -606,7 +634,7 @@ // If the id is a symbol in `other`, then align it, otherwise assume that // it is a new symbol if (other.findId(aSymValue, &loc) && loc >= other.getNumDimIds() && - loc < getNumDimAndSymbolIds()) + loc < other.getNumDimAndSymbolIds()) other.swapId(s, loc); else other.insertSymbolId(s - other.getNumDimIds(), aSymValue); @@ -621,6 +649,8 @@ assert(getNumSymbolIds() == other.getNumSymbolIds() && "expected same number of symbols"); + assert(areIdsUnique(*this, IdKind::Symbol) && "Symbol ids are not unique"); + assert(areIdsUnique(other, IdKind::Symbol) && "Symbol ids are not unique"); } // Changes all symbol identifiers which are loop IVs to dim identifiers.