diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -390,9 +390,14 @@ /// O(VC) time. void removeRedundantConstraints(); - /// Converts identifiers in the column range [idStart, idLimit) to local - /// variables. - void convertDimToLocal(unsigned dimStart, unsigned dimLimit); + /// Converts identifiers of kind srcKind in the range [idStart, idLimit) to + /// variables of kind dstKind, placed after all the other variables of kind + /// dstKind. The internal ordering among the moved variables is preserved. + void convertIdKind(IdKind srcKind, unsigned idStart, unsigned idLimit, + IdKind dstKind); + void convertToLocal(IdKind kind, unsigned idStart, unsigned idLimit) { + convertIdKind(kind, idStart, idLimit, IdKind::Local); + } /// Adds additional local ids to the sets such that they both have the union /// of the local ids in each set, without changing the set of points that diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -1117,23 +1117,34 @@ } } -void IntegerRelation::convertDimToLocal(unsigned dimStart, unsigned dimLimit) { - assert(dimLimit <= getNumDimIds() && "Invalid dim pos range"); +void IntegerRelation::convertIdKind(IdKind srcKind, unsigned idStart, + unsigned idLimit, IdKind dstKind) { + assert(idLimit <= getNumIdKind(srcKind) && "Invalid id range"); - if (dimStart >= dimLimit) + if (idStart >= idLimit) return; - // Append new local variables corresponding to the dimensions to be converted. + // Append new ids of kind `dstKind`, one for each id being converted.
- unsigned convertCount = dimLimit - dimStart; - unsigned newLocalIdStart = getNumIds(); - appendId(IdKind::Local, convertCount); + unsigned newIdsBegin = getIdKindEnd(dstKind); + unsigned convertCount = idLimit - idStart; + appendId(dstKind, convertCount); - // Swap the new local variables with dimensions. + // Swap the ids being converted with the newly appended ids. + // + // Essentially, this moves the information corresponding to the specified ids + // of kind `srcKind` to the `convertCount` newly created ids of kind + // `dstKind`. In particular, this moves the columns in the constraint + // matrices, and zeros out the initially occupied columns (because the newly + // created ids we're swapping with were zero-initialized). + // + // Note that the offset of `srcKind` must be queried after the append, as + // inserting the new ids can shift the positions of ids of kind `srcKind`. + unsigned offset = getIdKindOffset(srcKind); for (unsigned i = 0; i < convertCount; ++i) - swapId(i + dimStart, i + newLocalIdStart); + swapId(offset + idStart + i, newIdsBegin + i); - // Remove dimensions converted to local variables. - removeIdRange(IdKind::SetDim, dimStart, dimLimit); + // Complete the move by deleting the initially occupied columns. + removeIdRange(srcKind, idStart, idLimit); } void IntegerRelation::addBound(BoundType type, unsigned pos, int64_t value) { diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp --- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp @@ -1618,15 +1618,15 @@ FlatAffineValueConstraints FlatAffineRelation::getDomainSet() const { FlatAffineValueConstraints domain = *this; // Convert all range variables to local variables. - domain.convertDimToLocal(getNumDomainDims(), - getNumDomainDims() + getNumRangeDims()); + domain.convertToLocal(IdKind::SetDim, getNumDomainDims(), + getNumDomainDims() + getNumRangeDims()); return domain; } FlatAffineValueConstraints FlatAffineRelation::getRangeSet() const { FlatAffineValueConstraints range = *this; // Convert all domain variables to local variables. 
- range.convertDimToLocal(0, getNumDomainDims()); + range.convertToLocal(IdKind::SetDim, 0, getNumDomainDims()); return range; } @@ -1658,12 +1658,13 @@ // Convert `rel` from [otherDomain] -> [otherRange thisRange] to // [otherDomain] -> [thisRange] by converting first otherRange range ids // to local ids. - rel.convertDimToLocal(rel.getNumDomainDims(), - rel.getNumDomainDims() + removeDims); + rel.convertToLocal(IdKind::SetDim, rel.getNumDomainDims(), + rel.getNumDomainDims() + removeDims); // Convert `this` from [otherDomain thisDomain] -> [thisRange] to // [otherDomain] -> [thisRange] by converting last thisDomain domain ids // to local ids. - convertDimToLocal(getNumDomainDims() - removeDims, getNumDomainDims()); + convertToLocal(IdKind::SetDim, getNumDomainDims() - removeDims, + getNumDomainDims()); auto thisMaybeValues = getMaybeDimValues(); auto relMaybeValues = rel.getMaybeDimValues(); diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp --- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp +++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp @@ -707,7 +707,7 @@ IntegerPolyhedron poly = parsePoly("(i, j, q) : (4*q - i - j + 2 >= 0, -4*q + i + j >= 0)"); // Convert `q` to a local variable. - poly.convertDimToLocal(2, 3); + poly.convertToLocal(IdKind::SetDim, 2, 3); std::vector> divisions = {{1, 1, 0, 1}}; SmallVector denoms = {4}; @@ -721,7 +721,7 @@ { IntegerPolyhedron poly = parsePoly("(i, j, q) : (-4*q + i + j == 0)"); // Convert `q` to a local variable. - poly.convertDimToLocal(2, 3); + poly.convertToLocal(IdKind::SetDim, 2, 3); std::vector> divisions = {{-1, -1, 0, 0}}; SmallVector denoms = {4}; @@ -731,7 +731,7 @@ { IntegerPolyhedron poly = parsePoly("(i, j, q) : (4*q - i - j == 0)"); // Convert `q` to a local variable. 
- poly.convertDimToLocal(2, 3); + poly.convertToLocal(IdKind::SetDim, 2, 3); std::vector> divisions = {{-1, -1, 0, 0}}; SmallVector denoms = {4}; @@ -741,7 +741,7 @@ { IntegerPolyhedron poly = parsePoly("(i, j, q) : (3*q + i + j - 2 == 0)"); // Convert `q` to a local variable. - poly.convertDimToLocal(2, 3); + poly.convertToLocal(IdKind::SetDim, 2, 3); std::vector> divisions = {{1, 1, 0, -2}}; SmallVector denoms = {3}; @@ -756,7 +756,7 @@ parsePoly("(i, j, q, k) : (-3*k + i + j == 0, 4*q - " "i - j + 2 >= 0, -4*q + i + j >= 0)"); // Convert `q` and `k` to local variables. - poly.convertDimToLocal(2, 4); + poly.convertToLocal(IdKind::SetDim, 2, 4); std::vector> divisions = {{1, 1, 0, 0, 1}, {-1, -1, 0, 0, 0}}; @@ -770,7 +770,7 @@ IntegerPolyhedron poly = parsePoly("(x, q) : (x - 3 * q >= 0, -x + 3 * q + 3 >= 0)"); // Convert q to a local variable. - poly.convertDimToLocal(1, 2); + poly.convertToLocal(IdKind::SetDim, 1, 2); std::vector> divisions = {{0, 0, 0}}; SmallVector denoms = {0}; @@ -783,7 +783,7 @@ IntegerPolyhedron poly = parsePoly("(x, q) : (-1 - 3*x - 6 * q >= 0, 6 + 3*x + 6*q >= 0)"); // Convert q to a local variable. - poly.convertDimToLocal(1, 2); + poly.convertToLocal(IdKind::SetDim, 1, 2); // q = floor((-1/3 - x)/2) // = floor((1/3) + (-1 - x)/2)