diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h b/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h --- a/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h @@ -500,6 +500,9 @@ /// Return the index at which the specified kind of id starts. unsigned getIdKindOffset(IdKind kind) const; + /// Return the index at which the specified kind of id ends. + unsigned getIdKindEnd(IdKind kind) const; + /// Get the number of ids of the specified kind. unsigned getNumIdKind(IdKind kind) const; diff --git a/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp b/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp --- a/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp @@ -74,7 +74,7 @@ Optional> IntegerPolyhedron::getRationalLexMin() const { - assert(numSymbols == 0 && "Symbols are not supported!"); + assert(getNumSymbolIds() == 0 && "Symbols are not supported!"); Optional> maybeLexMin = LexSimplex(*this).getRationalLexMin(); @@ -172,7 +172,7 @@ return; // We are going to be removing one or more identifiers from the range. - assert(idStart < numIds && "invalid idStart position"); + assert(idStart < getNumIds() && "invalid idStart position"); // TODO: Make 'removeIdRange' a lambda called from here. // Remove eliminated identifiers from the constraints.. @@ -183,14 +183,15 @@ unsigned numDimsEliminated = 0; unsigned numLocalsEliminated = 0; unsigned numColsEliminated = idLimit - idStart; - if (idStart < numDims) { - numDimsEliminated = std::min(numDims, idLimit) - idStart; + if (idStart < getNumDimIds()) { + numDimsEliminated = std::min(getNumDimIds(), idLimit) - idStart; } // Check how many local id's were removed. Note that our identifier order is // [dims, symbols, locals]. Local id start at position numDims + numSymbols. - if (idLimit > numDims + numSymbols) { - numLocalsEliminated = std::min( - idLimit - std::max(idStart, numDims + numSymbols), getNumLocalIds()); + if (idLimit > getIdKindOffset(IdKind::Local)) { + numLocalsEliminated = + std::min(idLimit - std::max(idStart, getIdKindOffset(IdKind::Local)), + getNumLocalIds()); } unsigned numSymbolsEliminated = numColsEliminated - numDimsEliminated - numLocalsEliminated; @@ -243,6 +244,16 @@ llvm_unreachable("IdKind expected to be Dimension, Symbol or Local!"); } +unsigned IntegerPolyhedron::getIdKindEnd(IdKind kind) const { + if (kind == IdKind::Dimension) + return getNumDimIds(); + if (kind == IdKind::Symbol) + return getNumDimAndSymbolIds(); + if (kind == IdKind::Local) + return getNumIds(); + llvm_unreachable("IdKind expected to be Dimension, Symbol or Local!"); +} + unsigned IntegerPolyhedron::getNumIdKind(IdKind kind) const { if (kind == IdKind::Dimension) return getNumDimIds(); @@ -319,7 +330,8 @@ return false; // Catches errors where numDims, numSymbols, numIds aren't consistent. - if (numDims > numIds || numSymbols > numIds || numDims + numSymbols > numIds) + if (getNumDimIds() > getNumIds() || getNumSymbolIds() > getNumIds() || + getNumDimAndSymbolIds() > getNumIds()) return false; return true; @@ -912,7 +924,7 @@ unsigned IntegerPolyhedron::gaussianEliminateIds(unsigned posStart, unsigned posLimit) { // Return if identifier positions to eliminate are out of range. - assert(posLimit <= numIds); + assert(posLimit <= getNumIds()); assert(hasConsistentState()); if (posStart >= posLimit) @@ -1254,7 +1266,7 @@ } void IntegerPolyhedron::setDimSymbolSeparation(unsigned newSymbolCount) { - assert(newSymbolCount <= numDims + numSymbols && + assert(newSymbolCount <= getNumDimAndSymbolIds() && "invalid separation position"); numDims = numDims + numSymbols - newSymbolCount; numSymbols = newSymbolCount; @@ -1925,7 +1937,7 @@ // lower bounds and the max of the upper bounds along each of the dimensions. LogicalResult IntegerPolyhedron::unionBoundingBox(const IntegerPolyhedron &otherCst) { - assert(otherCst.getNumDimIds() == numDims && "dims mismatch"); + assert(otherCst.getNumDimIds() == getNumDimIds() && "dims mismatch"); assert(otherCst.getNumLocalIds() == 0 && "local ids not supported here"); assert(getNumLocalIds() == 0 && "local ids not supported yet here"); diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp --- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp @@ -188,7 +188,7 @@ // Construct from an IntegerSet. FlatAffineValueConstraints::FlatAffineValueConstraints(IntegerSet set) : FlatAffineConstraints(set) { - values.resize(numIds, None); + values.resize(getNumIds(), None); } // Construct a hyperrectangular constraint set from ValueRanges that represent @@ -1212,11 +1212,15 @@ SmallVector *newSymsPtr = nullptr; #endif // NDEBUG - dims.reserve(numDims); - syms.reserve(numSymbols); - for (unsigned i = 0; i < numDims; ++i) + dims.reserve(getNumDimIds()); + syms.reserve(getNumSymbolIds()); + for (unsigned i = getIdKindOffset(IdKind::Dimension), + e = getIdKindEnd(IdKind::Dimension); + i < e; ++i) dims.push_back(values[i] ? *values[i] : Value()); - for (unsigned i = numDims, e = numDims + numSymbols; i < e; ++i) + for (unsigned i = getIdKindOffset(IdKind::Symbol), + e = getIdKindEnd(IdKind::Symbol); + i < e; ++i) syms.push_back(values[i] ? *values[i] : Value()); AffineMap alignedMap = @@ -1371,13 +1375,13 @@ *static_cast(this) = other; values.clear(); - values.resize(numIds, None); + values.resize(getNumIds(), None); } void FlatAffineValueConstraints::fourierMotzkinEliminate( unsigned pos, bool darkShadow, bool *isResultIntegerExact) { SmallVector, 8> newVals; - newVals.reserve(numIds - 1); + newVals.reserve(getNumIds() - 1); newVals.append(values.begin(), values.begin() + pos); newVals.append(values.begin() + pos + 1, values.end()); // Note: Base implementation discards all associated Values. @@ -1397,7 +1401,7 @@ LogicalResult FlatAffineValueConstraints::unionBoundingBox( const FlatAffineValueConstraints &otherCst) { - assert(otherCst.getNumDimIds() == numDims && "dims mismatch"); + assert(otherCst.getNumDimIds() == getNumDimIds() && "dims mismatch"); assert(otherCst.getMaybeValues() .slice(0, getNumDimIds()) .equals(getMaybeValues().slice(0, getNumDimIds())) && @@ -1408,7 +1412,7 @@ // Align `other` to this. if (!areIdsAligned(*this, otherCst)) { FlatAffineValueConstraints otherCopy(otherCst); - mergeAndAlignIds(/*offset=*/numDims, this, &otherCopy); + mergeAndAlignIds(/*offset=*/getNumDimIds(), this, &otherCopy); return FlatAffineConstraints::unionBoundingBox(otherCopy); }