diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -70,7 +70,8 @@
       : numIds(numDims + numSymbols + numLocals), numDims(numDims),
         numSymbols(numSymbols),
         equalities(0, numIds + 1, numReservedEqualities, numReservedCols),
-        inequalities(0, numIds + 1, numReservedInequalities, numReservedCols) {
+        inequalities(0, numIds + 1, numReservedInequalities, numReservedCols),
+        divisions(numLocals, numIds + 1, numLocals, numReservedCols) {
     assert(numReservedCols >= numIds + 1);
     assert(idArgs.empty() || idArgs.size() == numIds);
     ids.reserve(numReservedCols);
@@ -78,6 +79,7 @@
       ids.resize(numIds, None);
     else
       ids.append(idArgs.begin(), idArgs.end());
+    divDenom.resize(numLocals, llvm::None);
   }
 
   /// Constructs a constraint system with the specified number of
@@ -173,6 +175,18 @@
 
   inline int64_t &atIneq(unsigned i, unsigned j) { return inequalities(i, j); }
 
+  /// Returns the value at the specified division dividend row and column.
+  inline int64_t atDiv(unsigned i, unsigned j) const { return divisions(i, j); }
+  inline int64_t &atDiv(unsigned i, unsigned j) { return divisions(i, j); }
+
+  /// Returns the denominator of the division representing the pos^th local
+  /// identifier, or None if that identifier does not have an explicit
+  /// representation.
+  inline Optional<int64_t> atDivDenom(unsigned pos) const {
+    return divDenom[pos];
+  }
+  inline Optional<int64_t> &atDivDenom(unsigned pos) { return divDenom[pos]; }
+
   /// Returns the number of columns in the constraint system.
   inline unsigned getNumCols() const { return numIds + 1; }
 
@@ -288,11 +302,28 @@
   /// Adds a constant upper bound constraint for the specified identifier.
   void addConstantUpperBound(unsigned pos, int64_t ub);
 
+  /// Sets the explicit representation of the pos^th local identifier to
+  /// `dividend floordiv divisor`. 'dividend' has one coefficient per column:
+  /// one per identifier, followed by the constant term.
+  void setDivision(unsigned pos, ArrayRef<int64_t> dividend, int64_t divisor);
+
+  /// Clears the division representation of the pos^th local identifier.
+  void clearDivision(unsigned pos);
+
+  /// Clears the division representations of all local identifiers.
+  void clearAllDivisions();
+
   /// Adds a new local identifier as the floordiv of an affine function of other
   /// identifiers, the coefficients of which are provided in 'dividend' and with
   /// respect to a positive constant 'divisor'. Two constraints are added to the
   /// system to capture equivalence with the floordiv:
   /// q = dividend floordiv c <=> c*q <= dividend <= c*q + c - 1.
+  ///
+  /// The added division representation for floordiv is considered to be part of
+  /// the constraint system. If the added inequalities are removed or modified
+  /// such that the division representation of that floordiv becomes invalid,
+  /// the division representation of that local identifier must be updated or
+  /// cleared. Otherwise, the state of divisions becomes inconsistent.
   void addLocalFloorDiv(ArrayRef<int64_t> dividend, int64_t divisor);
 
   /// Adds a constant lower bound constraint for the specified expression.
@@ -644,6 +675,14 @@
   /// Coefficients of affine inequalities (in >= 0 form).
   Matrix inequalities;
 
+  /// Row i holds the dividend coefficients (one per column) of the division
+  /// representing the i^th local identifier.
+  Matrix divisions;
+
+  /// Denominator of the division representing each local identifier; None if
+  /// the identifier's explicit representation is unknown.
+  SmallVector<Optional<int64_t>, 8> divDenom;
+
   /// Values corresponding to the (column) identifiers of this constraint
   /// system appearing in the order the identifiers correspond to columns.
   /// Temporary ones or those that aren't associated to any Value are set to
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -154,7 +154,8 @@
     : numIds(set.getNumDims() + set.getNumSymbols()),
       numDims(set.getNumDims()), numSymbols(set.getNumSymbols()),
       equalities(0, numIds + 1, set.getNumEqualities(), numIds + 1),
-      inequalities(0, numIds + 1, set.getNumInequalities(), numIds + 1) {
+      inequalities(0, numIds + 1, set.getNumInequalities(), numIds + 1),
+      divisions(0, numIds + 1) {
   ids.resize(numIds, None);
 
   // Flatten expressions and add them to the constraint system.
@@ -221,6 +222,21 @@
   for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) {
     addEquality(other.getEquality(r));
   }
+
+  // Keep a division only if both systems agree on it; otherwise mark unknown.
+  for (unsigned r = 0, e = getNumLocalIds(); r < e; r++) {
+    if (atDivDenom(r) != other.atDivDenom(r)) {
+      atDivDenom(r) = llvm::None;
+      continue;
+    }
+
+    for (unsigned k = 0, f = getNumCols(); k < f; k++) {
+      if (atDiv(r, k) != other.atDiv(r, k)) {
+        atDivDenom(r) = llvm::None;
+        break;
+      }
+    }
+  }
 }
 
 void FlatAffineConstraints::addLocalId(unsigned pos) {
@@ -254,11 +270,16 @@
     numSymbols++;
   } else {
     absolutePos = pos + getNumDimIds() + getNumSymbolIds();
+
+    // The new local identifier starts without a known division representation.
+    divisions.insertRow(pos);
+    divDenom.insert(divDenom.begin() + pos, llvm::None);
   }
 
   numIds++;
   inequalities.insertColumn(absolutePos);
   equalities.insertColumn(absolutePos);
+  divisions.insertColumn(absolutePos);
 
   // If an 'id' is provided, insert it; otherwise use None.
   if (id)
@@ -770,8 +791,14 @@
     return false;
   if (!equalities.hasConsistentState())
     return false;
+  if (!divisions.hasConsistentState())
+    return false;
   if (ids.size() != getNumIds())
     return false;
+  if (divDenom.size() != getNumLocalIds())
+    return false;
+  if (divDenom.size() != divisions.getNumRows())
+    return false;
 
   // Catches errors where numDims, numSymbols, numIds aren't consistent.
   if (numDims > numIds || numSymbols > numIds || numDims + numSymbols > numIds)
@@ -864,6 +891,17 @@
   // Remove eliminated identifiers from the constraints..
   equalities.removeColumns(idStart, idLimit - idStart);
   inequalities.removeColumns(idStart, idLimit - idStart);
+  // A surviving division whose dividend uses an eliminated identifier would
+  // silently change meaning once that column is dropped; mark it unknown.
+  for (unsigned r = 0, e = getNumLocalIds(); r < e; ++r) {
+    for (unsigned c = idStart; c < idLimit; ++c) {
+      if (atDiv(r, c) != 0) {
+        clearDivision(r);
+        break;
+      }
+    }
+  }
+  divisions.removeColumns(idStart, idLimit - idStart);
 
   // Update members numDims, numSymbols and numIds.
   unsigned numDimsEliminated = 0;
@@ -877,7 +915,16 @@
   if (idLimit > numDims + numSymbols) {
     numLocalsEliminated = std::min(
         idLimit - std::max(idStart, numDims + numSymbols), getNumLocalIds());
+
+    // Remove the division representations of the eliminated local
+    // identifiers, which start at local position 'localStart'.
+    unsigned localStart =
+        std::max(idStart, numDims + numSymbols) - (numDims + numSymbols);
+    divisions.removeRows(localStart, numLocalsEliminated);
+    divDenom.erase(divDenom.begin() + localStart,
+                   divDenom.begin() + localStart + numLocalsEliminated);
   }
+
   unsigned numSymbolsEliminated =
       numColsEliminated - numDimsEliminated - numLocalsEliminated;
 
@@ -2123,6 +2170,29 @@
   inequalities(inequalities.getNumRows() - 1, getNumCols() - 1) += ub;
 }
 
+void FlatAffineConstraints::setDivision(unsigned pos,
+                                        ArrayRef<int64_t> dividend,
+                                        int64_t divisor) {
+  assert(pos < getNumLocalIds() && "Local identifier position out of range");
+  assert(dividend.size() == getNumCols() && "incorrect dividend size");
+  assert(divisor > 0 && "positive divisor expected");
+
+  // Copy every column of the dividend, including the trailing constant term.
+  for (unsigned i = 0, e = getNumCols(); i < e; ++i)
+    atDiv(pos, i) = dividend[i];
+  atDivDenom(pos) = divisor;
+}
+
+void FlatAffineConstraints::clearDivision(unsigned pos) {
+  assert(pos < getNumLocalIds() && "Local identifier position out of range");
+  atDivDenom(pos) = llvm::None;
+}
+
+void FlatAffineConstraints::clearAllDivisions() {
+  for (unsigned i = 0, e = getNumLocalIds(); i < e; ++i)
+    clearDivision(i);
+}
+
 /// Adds a new local identifier as the floordiv of an affine function of other
 /// identifiers, the coefficients of which are provided in 'dividend' and with
 /// respect to a positive constant 'divisor'. Two constraints are added to the
@@ -2138,10 +2208,17 @@
   // Add two constraints for this new identifier 'q'.
   SmallVector<int64_t, 8> bound(dividend.size() + 1);
 
-  // dividend - q * divisor >= 0
+  // Copy the dividend into 'bound', with a zero coefficient for 'q'.
   std::copy(dividend.begin(), dividend.begin() + dividend.size() - 1,
             bound.begin());
   bound.back() = dividend.back();
+  bound[getNumIds() - 1] = 0;
+
+  // Record q = dividend floordiv divisor as the explicit representation of
+  // the new local identifier; 'bound' currently holds exactly the dividend.
+  setDivision(getNumLocalIds() - 1, bound, divisor);
+
+  // dividend - q * divisor >= 0
   bound[getNumIds() - 1] = -divisor;
   addInequality(bound);
 
@@ -2181,7 +2258,25 @@
     std::swap(atIneq(r, posA), atIneq(r, posB));
   for (unsigned r = 0, e = getNumEqualities(); r < e; r++)
     std::swap(atEq(r, posA), atEq(r, posB));
+  for (unsigned r = 0, e = getNumLocalIds(); r < e; r++)
+    std::swap(atDiv(r, posA), atDiv(r, posB));
+
   std::swap(getId(posA), getId(posB));
+
+  // Swap the division representations when both identifiers are locals. If
+  // only one of them is local, the identifier now occupying that local slot
+  // was a dimension or symbol and has no known explicit representation.
+  unsigned localIdOffset = getNumDimAndSymbolIds();
+  bool isLocalA = posA >= localIdOffset;
+  bool isLocalB = posB >= localIdOffset;
+  if (isLocalA && isLocalB) {
+    std::swap(divDenom[posA - localIdOffset], divDenom[posB - localIdOffset]);
+    divisions.swapRows(posA - localIdOffset, posB - localIdOffset);
+  } else if (isLocalA) {
+    clearDivision(posA - localIdOffset);
+  } else if (isLocalB) {
+    clearDivision(posB - localIdOffset);
+  }
 }
 
 void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
@@ -2546,7 +2641,25 @@
     }
     os << ">= 0\n";
   }
-  os << '\n';
+  os << "Local Ids:\n";
+
+  for (unsigned i = 0, e = getNumLocalIds(); i < e; ++i) {
+    if (!atDivDenom(i)) {
+      os << "Local Id with Unknown explicit representation\n";
+      continue;
+    }
+
+    os << "[";
+    for (unsigned j = 0, f = getNumCols() - 1; j < f; ++j)
+      os << atDiv(i, j) << " + ";
+    os << atDiv(i, getNumCols() - 1);
+
+    os << "] / ";
+    os << *atDivDenom(i);
+    os << "\n";
+  }
+
+  os << "\n";
 }
 
 void FlatAffineConstraints::dump() const { print(llvm::errs()); }
diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp
--- a/mlir/unittests/Analysis/AffineStructuresTest.cpp
+++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp
@@ -587,4 +587,68 @@
   EXPECT_EQ(fac.atIneq(0, 1), 0);
 }
 
+TEST(FlatAffineConstraintsTest, testDivisions) {
+  FlatAffineConstraints fac = makeFACFromConstraints(1, {}, {});
+
+  // Add inequality x >= 0
+  fac.addInequality({1, 0});
+  EXPECT_EQ(fac.atIneq(0, 0), 1);
+  EXPECT_EQ(fac.atIneq(0, 1), 0);
+
+  // Add division q = floor(x / 4)
+  fac.addLocalFloorDiv({1, 0}, 4);
+
+  EXPECT_EQ(*fac.atDivDenom(0), 4);
+  EXPECT_EQ(fac.atDiv(0, 0), 1);
+  EXPECT_EQ(fac.atDiv(0, 1), 0);
+
+  // Add local identifier p
+  fac.addLocalId(1);
+  EXPECT_EQ(fac.atDivDenom(1), llvm::None);
+
+  // Test appending checks divisions
+  FlatAffineConstraints fac2 = makeFACFromConstraints(1, {}, {});
+  fac2.addInequality({-1, 0});
+  EXPECT_EQ(fac2.atIneq(0, 0), -1);
+  EXPECT_EQ(fac2.atIneq(0, 1), 0);
+
+  fac2.addLocalFloorDiv({1, 0}, 4);
+
+  EXPECT_EQ(*fac2.atDivDenom(0), 4);
+  EXPECT_EQ(fac2.atDiv(0, 0), 1);
+  EXPECT_EQ(fac2.atDiv(0, 1), 0);
+
+  fac2.addLocalId(1);
+  EXPECT_EQ(*fac2.atDivDenom(0), 4);
+  EXPECT_EQ(fac2.atDivDenom(1), llvm::None);
+
+  fac.append(fac2);
+  EXPECT_EQ(*fac.atDivDenom(0), 4);
+  EXPECT_EQ(fac.atDiv(0, 0), 1);
+  EXPECT_EQ(fac.atDiv(0, 1), 0);
+  EXPECT_EQ(fac.atDivDenom(1), llvm::None);
+
+  // Test swapping swaps divisions
+  fac.swapId(1, 2);
+  EXPECT_EQ(fac.atDivDenom(0), llvm::None);
+  EXPECT_EQ(*fac.atDivDenom(1), 4);
+  EXPECT_EQ(fac.atDiv(1, 0), 1);
+  EXPECT_EQ(fac.atDiv(1, 1), 0);
+
+  // Test removal of division
+  fac.removeId(2);
+  EXPECT_EQ(fac.getNumLocalIds(), (unsigned)1);
+  // The remaining local identifier is 'p', whose division is unknown.
+  EXPECT_EQ(fac.atDivDenom(0), llvm::None);
+
+  // Test setting a division, including its constant term
+  fac.setDivision(0, {2, 0, 1}, 3);
+  EXPECT_EQ(*fac.atDivDenom(0), 3);
+  EXPECT_EQ(fac.atDiv(0, 0), 2);
+  EXPECT_EQ(fac.atDiv(0, 2), 1);
+
+  fac.clearDivision(0);
+  EXPECT_EQ(fac.atDivDenom(0), llvm::None);
+}
+
 } // namespace mlir