diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -386,9 +386,19 @@
   /// O(VC) time.
   void removeRedundantConstraints();
 
-  /// Converts identifiers in the column range [idStart, idLimit) to local
-  /// variables.
-  void convertDimToLocal(unsigned dimStart, unsigned dimLimit);
+  /// Converts identifiers of kind `srcKind` in the range [idStart, idLimit) to
+  /// identifiers of kind `dstKind`, placing them after all the other
+  /// identifiers of kind `dstKind`. The relative order among the moved
+  /// identifiers is preserved. `idStart` and `idLimit` are positions relative
+  /// to the start of `srcKind`.
+  void convertIdKind(IdKind srcKind, unsigned idStart, unsigned idLimit,
+                     IdKind dstKind);
+
+  /// Converts identifiers of kind `kind` in the range [idStart, idLimit) to
+  /// local identifiers, placing them after all the existing local identifiers.
+  void convertToLocal(IdKind kind, unsigned idStart, unsigned idLimit) {
+    convertIdKind(kind, idStart, idLimit, IdKind::Local);
+  }
 
   /// Adds additional local ids to the sets such that they both have the union
   /// of the local ids in each set, without changing the set of points that
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -1119,23 +1119,29 @@
   }
 }
 
-void IntegerRelation::convertDimToLocal(unsigned dimStart, unsigned dimLimit) {
-  assert(dimLimit <= getNumDimIds() && "Invalid dim pos range");
+void IntegerRelation::convertIdKind(IdKind srcKind, unsigned idStart,
+                                    unsigned idLimit, IdKind dstKind) {
+  assert(idLimit <= getNumIdKind(srcKind) && "Invalid id range");
 
-  if (dimStart >= dimLimit)
+  if (idStart >= idLimit)
     return;
 
-  // Append new local variables corresponding to the dimensions to be converted.
-  unsigned convertCount = dimLimit - dimStart;
-  unsigned newLocalIdStart = getNumIds();
-  appendId(IdKind::Local, convertCount);
+  // Append `convertCount` new ids of kind `dstKind`. They are inserted at the
+  // end of the `dstKind` range, i.e., starting at position `newIdsBegin`.
+  unsigned newIdsBegin = getIdKindEnd(dstKind);
+  unsigned convertCount = idLimit - idStart;
+  appendId(dstKind, convertCount);
 
-  // Swap the new local variables with dimensions.
+  // Swap the ids being converted with the newly appended ids. This moves their
+  // columns into the `dstKind` range and leaves the newly appended ids at the
+  // original `srcKind` positions. The `srcKind` offset is queried only after
+  // appending because appending to `dstKind` may shift the `srcKind` ids.
+  unsigned offset = getIdKindOffset(srcKind);
   for (unsigned i = 0; i < convertCount; ++i)
-    swapId(i + dimStart, i + newLocalIdStart);
+    swapId(offset + idStart + i, newIdsBegin + i);
 
-  // Remove dimensions converted to local variables.
-  removeIdRange(IdKind::SetDim, dimStart, dimLimit);
+  // Remove the vacated `srcKind` positions, completing the move.
+  removeIdRange(srcKind, idStart, idLimit);
 }
 
 void IntegerRelation::addBound(BoundType type, unsigned pos, int64_t value) {
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
--- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
@@ -1618,7 +1618,7 @@
 FlatAffineValueConstraints FlatAffineRelation::getDomainSet() const {
   FlatAffineValueConstraints domain = *this;
   // Convert all range variables to local variables.
-  domain.convertDimToLocal(getNumDomainDims(),
-                           getNumDomainDims() + getNumRangeDims());
+  domain.convertToLocal(IdKind::SetDim, getNumDomainDims(),
+                        getNumDomainDims() + getNumRangeDims());
   return domain;
 }
@@ -1626,7 +1626,7 @@
 FlatAffineValueConstraints FlatAffineRelation::getRangeSet() const {
   FlatAffineValueConstraints range = *this;
   // Convert all domain variables to local variables.
-  range.convertDimToLocal(0, getNumDomainDims());
+  range.convertToLocal(IdKind::SetDim, 0, getNumDomainDims());
   return range;
 }
 
@@ -1658,12 +1658,13 @@
   // Convert `rel` from [otherDomain] -> [otherRange thisRange] to
   // [otherDomain] -> [thisRange] by converting first otherRange range ids
   // to local ids.
-  rel.convertDimToLocal(rel.getNumDomainDims(),
-                        rel.getNumDomainDims() + removeDims);
+  rel.convertToLocal(IdKind::SetDim, rel.getNumDomainDims(),
+                     rel.getNumDomainDims() + removeDims);
   // Convert `this` from [otherDomain thisDomain] -> [thisRange] to
   // [otherDomain] -> [thisRange] by converting last thisDomain domain ids
   // to local ids.
-  convertDimToLocal(getNumDomainDims() - removeDims, getNumDomainDims());
+  convertToLocal(IdKind::SetDim, getNumDomainDims() - removeDims,
+                 getNumDomainDims());
 
   auto thisMaybeValues = getMaybeDimValues();
   auto relMaybeValues = rel.getMaybeDimValues();
diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
--- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
@@ -707,7 +707,7 @@
     IntegerPolyhedron poly =
         parsePoly("(i, j, q) : (4*q - i - j + 2 >= 0, -4*q + i + j >= 0)");
     // Convert `q` to a local variable.
-    poly.convertDimToLocal(2, 3);
+    poly.convertToLocal(IdKind::SetDim, 2, 3);
 
     std::vector<SmallVector<int64_t, 8>> divisions = {{1, 1, 0, 1}};
     SmallVector<unsigned, 8> denoms = {4};
@@ -721,7 +721,7 @@
   {
     IntegerPolyhedron poly = parsePoly("(i, j, q) : (-4*q + i + j == 0)");
     // Convert `q` to a local variable.
-    poly.convertDimToLocal(2, 3);
+    poly.convertToLocal(IdKind::SetDim, 2, 3);
 
     std::vector<SmallVector<int64_t, 8>> divisions = {{-1, -1, 0, 0}};
     SmallVector<unsigned, 8> denoms = {4};
@@ -731,7 +731,7 @@
   {
     IntegerPolyhedron poly = parsePoly("(i, j, q) : (4*q - i - j == 0)");
     // Convert `q` to a local variable.
-    poly.convertDimToLocal(2, 3);
+    poly.convertToLocal(IdKind::SetDim, 2, 3);
 
     std::vector<SmallVector<int64_t, 8>> divisions = {{-1, -1, 0, 0}};
     SmallVector<unsigned, 8> denoms = {4};
@@ -741,7 +741,7 @@
   {
     IntegerPolyhedron poly = parsePoly("(i, j, q) : (3*q + i + j - 2 == 0)");
     // Convert `q` to a local variable.
-    poly.convertDimToLocal(2, 3);
+    poly.convertToLocal(IdKind::SetDim, 2, 3);
 
     std::vector<SmallVector<int64_t, 8>> divisions = {{1, 1, 0, -2}};
     SmallVector<unsigned, 8> denoms = {3};
@@ -756,7 +756,7 @@
         parsePoly("(i, j, q, k) : (-3*k + i + j == 0, 4*q - "
                   "i - j + 2 >= 0, -4*q + i + j >= 0)");
     // Convert `q` and `k` to local variables.
-    poly.convertDimToLocal(2, 4);
+    poly.convertToLocal(IdKind::SetDim, 2, 4);
 
     std::vector<SmallVector<int64_t, 8>> divisions = {{1, 1, 0, 0, 1},
                                                       {-1, -1, 0, 0, 0}};
@@ -770,7 +770,7 @@
     IntegerPolyhedron poly =
         parsePoly("(x, q) : (x - 3 * q >= 0, -x + 3 * q + 3 >= 0)");
     // Convert q to a local variable.
-    poly.convertDimToLocal(1, 2);
+    poly.convertToLocal(IdKind::SetDim, 1, 2);
 
     std::vector<SmallVector<int64_t, 8>> divisions = {{0, 0, 0}};
     SmallVector<unsigned, 8> denoms = {0};
@@ -783,7 +783,7 @@
     IntegerPolyhedron poly =
         parsePoly("(x, q) : (-1 - 3*x - 6 * q >= 0, 6 + 3*x + 6*q >= 0)");
     // Convert q to a local variable.
-    poly.convertDimToLocal(1, 2);
+    poly.convertToLocal(IdKind::SetDim, 1, 2);
 
     // q = floor((-1/3 - x)/2)
     //   = floor((1/3) + (-1 - x)/2)