Patch description (free text placed before the first diff header).

Summary: make FlatAffineConstraints::mergeLocalIds require that `this` and
`other` already have matching dimension and symbol id counts, and update the
call sites touched by this patch so that the requirement holds.

  1. AffineStructures.h / AffineStructures.cpp: the mergeLocalIds doc comment
     gains the sentence "Number of dimension and symbol ids should match in
     `this` and `other`.", and the definition gains two asserts comparing
     getNumDimIds() and getNumSymbolIds() of `this` and `other`.
  2. AffineAnalysis.cpp (hunk at old line 455): domainRel.mergeSymbolIds(rel)
     is now called before domainRel.mergeLocalIds(rel), instead of after it.
  3. AffineStructures.cpp (hunks at old lines 512 and 543; the function name
     is not visible in the hunk headers): a->mergeLocalIds(*b) moves from
     before the aDimValues / dimension-value handling to immediately after
     a->mergeSymbolIds(*b), ahead of the areIdsAligned(*a, *b) assert.
  4. AffineStructures.cpp (hunk at old line 3665; presumably
     FlatAffineRelation::compose, confirm against the full file): dimension
     ids are inserted/appended first, so that `this` becomes
     [otherDomain thisDomain] -> [thisRange] and `rel` (a copy of `other`)
     becomes [otherDomain] -> [otherRange thisRange]. Only then are symbol and
     local ids merged. Afterwards the otherRange range ids of `rel` and the
     trailing thisDomain domain ids of `this` (removeDims ids each, where
     removeDims is the range dim count of `other`) are converted to local ids.
     Previously the domain of `this` and the range of `rel` were converted to
     local ids after the merges, and the padding dimensions were appended last.

NOTE(review): the hunks below appear to have lost their line breaks and
indentation (many hunk lines are joined onto one physical line), so this
patch will not apply as-is; restore the original formatting before applying.

diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -432,7 +432,8 @@ /// Merge local ids of `this` and `other`. This is done by appending local ids /// of `other` to `this` and inserting local ids of `this` to `other` at start - /// of its local ids. + /// of its local ids. Number of dimension and symbol ids should match in + /// `this` and `other`. void mergeLocalIds(FlatAffineConstraints &other); /// Removes all equalities and inequalities. diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -455,8 +455,8 @@ // Append domain constraints to `ret`. domainRel.appendRangeId(rel.getNumRangeDims()); - domainRel.mergeLocalIds(rel); domainRel.mergeSymbolIds(rel); + domainRel.mergeLocalIds(rel); rel.append(domainRel); return success(); diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -512,9 +512,6 @@ b->getMaybeValues().begin() + b->getNumDimAndSymbolIds(), [](Optional id) { return id.hasValue(); })); - // Bring A and B to common local space - a->mergeLocalIds(*b); - SmallVector aDimValues; a->getValues(offset, a->getNumDimIds(), &aDimValues); @@ -543,6 +540,8 @@ // Merge and align symbols of A and B a->mergeSymbolIds(*b); + // Merge and align local ids of A and B + a->mergeLocalIds(*b); assert(areIdsAligned(*a, *b) && "IDs expected to be aligned"); } @@ -1875,8 +1874,13 @@ /// Merge local ids of `this` and `other`. This is done by appending local ids /// of `other` to `this` and inserting local ids of `this` to `other` at start -/// of its local ids. +/// of its local ids. Number of dimension and symbol ids should match in +/// `this` and `other`. 
void FlatAffineConstraints::mergeLocalIds(FlatAffineConstraints &other) { + assert(getNumDimIds() == other.getNumDimIds() && + "Number of dimension ids should match"); + assert(getNumSymbolIds() == other.getNumSymbolIds() && + "Number of symbol ids should match"); unsigned initLocals = getNumLocalIds(); insertLocalId(getNumLocalIds(), other.getNumLocalIds()); other.insertLocalId(0, initLocals); @@ -3665,15 +3669,32 @@ "Domain of this and range of other do not match"); FlatAffineRelation rel = other; + + // Convert `rel` from + // [otherDomain] -> [otherRange] + // to + // [otherDomain] -> [otherRange thisRange] + // and `this` from + // [thisDomain] -> [thisRange] + // to + // [otherDomain thisDomain] -> [thisRange]. + unsigned removeDims = rel.getNumRangeDims(); + insertDomainId(0, rel.getNumDomainDims()); + rel.appendRangeId(getNumRangeDims()); + + // Merge symbol and local identifiers. mergeSymbolIds(rel); mergeLocalIds(rel); - // Convert domain of `this` and range of `rel` to local identifiers. - convertDimToLocal(0, getNumDomainDims()); - rel.convertDimToLocal(rel.getNumDomainDims(), rel.getNumDimIds()); - // Add dimensions such that both relations become `domainRel -> rangeThis`. - appendDomainId(rel.getNumDomainDims()); - rel.appendRangeId(getNumRangeDims()); + // Convert `rel` from [otherDomain] -> [otherRange thisRange] to + // [otherDomain] -> [thisRange] by converting first otherRange range ids + // to local ids. + rel.convertDimToLocal(rel.getNumDomainDims(), + rel.getNumDomainDims() + removeDims); + // Convert `this` from [otherDomain thisDomain] -> [thisRange] to + // [otherDomain] -> [thisRange] by converting last thisDomain domain ids + // to local ids. + convertDimToLocal(getNumDomainDims() - removeDims, getNumDomainDims()); auto thisMaybeValues = getMaybeDimValues(); auto relMaybeValues = rel.getMaybeDimValues();