diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -70,7 +70,9 @@ : numIds(numDims + numSymbols + numLocals), numDims(numDims), numSymbols(numSymbols), equalities(0, numIds + 1, numReservedEqualities, numReservedCols), - inequalities(0, numIds + 1, numReservedInequalities, numReservedCols) { + inequalities(0, numIds + 1, numReservedInequalities, numReservedCols), + divisions(numLocals, numIds + 1, numLocals, numReservedCols), + divDenom(numLocals, llvm::None) { assert(numReservedCols >= numIds + 1); assert(idArgs.empty() || idArgs.size() == numIds); ids.reserve(numReservedCols); @@ -173,6 +175,17 @@ inline int64_t &atIneq(unsigned i, unsigned j) { return inequalities(i, j); } + /// Returns the dividend coefficient at row i (i^th local id) and column j. + inline int64_t atDiv(unsigned i, unsigned j) const { return divisions(i, j); } + inline int64_t &atDiv(unsigned i, unsigned j) { return divisions(i, j); } + + /// Returns the denominator for the pos^th local identifier. Returns None if + /// the local identifier does not have an explicit representation. + inline Optional<int64_t> atDivDenom(unsigned pos) const { + return divDenom[pos]; + } + inline Optional<int64_t> &atDivDenom(unsigned pos) { return divDenom[pos]; } + /// Returns the number of columns in the constraint system. inline unsigned getNumCols() const { return numIds + 1; } @@ -198,6 +211,10 @@ return inequalities.getRow(idx); } + inline ArrayRef<int64_t> getDivNumerator(unsigned idx) const { + return divisions.getRow(idx); + } + /// Adds constraints (lower and upper bounds) for the specified 'affine.for' /// operation's Value using IR information stored in its bound maps. The /// right identifier is first looked up using forOp's Value. Asserts if the @@ -288,11 +305,26 @@ /// Adds a constant upper bound constraint for the specified identifier.
void addConstantUpperBound(unsigned pos, int64_t ub); + /// Sets the division representation for the pos^th local identifier. + void setDivision(unsigned pos, ArrayRef<int64_t> dividend, int64_t divisor); + + /// Clears the division representation of the pos^th local identifier. + void clearDivision(unsigned pos); + + /// Clears the division representation of all local identifiers. + void clearAllDivisions(); + /// Adds a new local identifier as the floordiv of an affine function of other /// identifiers, the coefficients of which are provided in 'dividend' and with /// respect to a positive constant 'divisor'. Two constraints are added to the /// system to capture equivalence with the floordiv: /// q = dividend floordiv c <=> c*q <= dividend <= c*q + c - 1. + /// The added division representation for the floordiv is considered to be + /// part of the constraint system. If the added inequalities are removed or + /// modified such that the division representation of that floordiv becomes + /// invalid, the division representation of that local identifier must be + /// updated or cleared. Otherwise, the FAC is considered to be in an + /// inconsistent state. void addLocalFloorDiv(ArrayRef<int64_t> dividend, int64_t divisor); /// Adds a constant lower bound constraint for the specified expression. @@ -644,6 +676,14 @@ /// Coefficients of affine inequalities (in >= 0 form). Matrix inequalities; + /// Coefficients of the dividends of the divisions representing local + /// identifiers; row i holds the dividend of the i^th local identifier. + Matrix divisions; + + /// Denominators of divisions representing local identifiers. Identifiers + /// whose explicit representation is unknown are marked None. + SmallVector<Optional<int64_t>, 8> divDenom; + /// Values corresponding to the (column) identifiers of this constraint /// system appearing in the order the identifiers correspond to columns.
/// Temporary ones or those that aren't associated to any Value are set to diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -154,7 +154,8 @@ : numIds(set.getNumDims() + set.getNumSymbols()), numDims(set.getNumDims()), numSymbols(set.getNumSymbols()), equalities(0, numIds + 1, set.getNumEqualities(), numIds + 1), - inequalities(0, numIds + 1, set.getNumInequalities(), numIds + 1) { + inequalities(0, numIds + 1, set.getNumInequalities(), numIds + 1), + divisions(0, numIds + 1) { ids.resize(numIds, None); // Flatten expressions and add them to the constraint system. @@ -221,6 +222,21 @@ for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) { addEquality(other.getEquality(r)); } + + // Mark divisions unknown unless denominators and all dividend columns match. + for (unsigned r = 0, e = getNumLocalIds(); r < e; ++r) { + if (atDivDenom(r) != other.atDivDenom(r)) { + atDivDenom(r) = llvm::None; + continue; + } + + for (unsigned k = 0, f = getNumCols(); k < f; ++k) { + if (atDiv(r, k) != other.atDiv(r, k)) { + atDivDenom(r) = llvm::None; + break; + } + } + } } void FlatAffineConstraints::addLocalId(unsigned pos) { @@ -254,11 +270,16 @@ numSymbols++; } else { absolutePos = pos + getNumDimIds() + getNumSymbolIds(); + + // Insert unknown division for local identifier. + divisions.insertRow(pos); + divDenom.insert(divDenom.begin() + pos, llvm::None); } numIds++; inequalities.insertColumn(absolutePos); equalities.insertColumn(absolutePos); + divisions.insertColumn(absolutePos); // If an 'id' is provided, insert it; otherwise use None.
if (id) @@ -770,8 +791,14 @@ return false; if (!equalities.hasConsistentState()) return false; + if (!divisions.hasConsistentState()) + return false; if (ids.size() != getNumIds()) return false; + if (divDenom.size() != getNumLocalIds()) + return false; + if (divDenom.size() != divisions.getNumRows()) + return false; // Catches errors where numDims, numSymbols, numIds aren't consistent. if (numDims > numIds || numSymbols > numIds || numDims + numSymbols > numIds) @@ -864,6 +891,7 @@ // Remove eliminated identifiers from the constraints.. equalities.removeColumns(idStart, idLimit - idStart); inequalities.removeColumns(idStart, idLimit - idStart); + divisions.removeColumns(idStart, idLimit - idStart); // Update members numDims, numSymbols and numIds. unsigned numDimsEliminated = 0; @@ -877,7 +905,16 @@ if (idLimit > numDims + numSymbols) { numLocalsEliminated = std::min( idLimit - std::max(idStart, numDims + numSymbols), getNumLocalIds()); + + // Remove the division representations of the eliminated local ids, which + // occupy local positions [localStart, localStart + numLocalsEliminated).
+ unsigned localStart = + std::max(idStart, numDims + numSymbols) - (numDims + numSymbols); + divisions.removeRows(localStart, numLocalsEliminated); + divDenom.erase(divDenom.begin() + localStart, + divDenom.begin() + localStart + numLocalsEliminated); } + unsigned numSymbolsEliminated = numColsEliminated - numDimsEliminated - numLocalsEliminated; @@ -2123,6 +2160,28 @@ inequalities(inequalities.getNumRows() - 1, getNumCols() - 1) += ub; } +void FlatAffineConstraints::setDivision(unsigned pos, + ArrayRef<int64_t> dividend, + int64_t divisor) { + assert(pos < getNumLocalIds() && "Local identifier position out of range"); + assert(dividend.size() == getNumCols() && "incorrect dividend size"); + assert(divisor > 0 && "positive divisor expected"); + + for (unsigned i = 0, e = getNumCols(); i < e; ++i) + atDiv(pos, i) = dividend[i]; + atDivDenom(pos) = divisor; +} + +void FlatAffineConstraints::clearDivision(unsigned pos) { + assert(pos < getNumLocalIds() && "Local identifier position out of range"); + atDivDenom(pos) = llvm::None; +} + +void FlatAffineConstraints::clearAllDivisions() { + for (unsigned i = 0, e = getNumLocalIds(); i < e; ++i) + clearDivision(i); +} + /// Adds a new local identifier as the floordiv of an affine function of other /// identifiers, the coefficients of which are provided in 'dividend' and with /// respect to a positive constant 'divisor'. Two constraints are added to the @@ -2136,16 +2195,21 @@ addLocalId(getNumLocalIds()); // Add two constraints for this new identifier 'q'. - SmallVector<int64_t, 8> bound(dividend.size() + 1); + SmallVector<int64_t, 8> bound(dividend.size() + 1, 0); - // dividend - q * divisor >= 0 + // bound = dividend, with zero as coefficient of the new local identifier. std::copy(dividend.begin(), dividend.begin() + dividend.size() - 1, bound.begin()); bound.back() = dividend.back(); + + // Set explicit representation for local identifier.
+ setDivision(getNumLocalIds() - 1, bound, divisor); + + // dividend - q * divisor >= 0 bound[getNumIds() - 1] = -divisor; addInequality(bound); - // -dividend +qdivisor * q + divisor - 1 >= 0 + // -dividend + divisor * q + divisor - 1 >= 0 std::transform(bound.begin(), bound.end(), bound.begin(), std::negate<int64_t>()); bound[bound.size() - 1] += divisor - 1; @@ -2181,7 +2245,24 @@ std::swap(atIneq(r, posA), atIneq(r, posB)); for (unsigned r = 0, e = getNumEqualities(); r < e; r++) std::swap(atEq(r, posA), atEq(r, posB)); + for (unsigned r = 0, e = getNumLocalIds(); r < e; r++) + std::swap(atDiv(r, posA), atDiv(r, posB)); + std::swap(getId(posA), getId(posB)); + + // Swap division representation if both identifiers are local. + // Clear the division representation of the local one if only one of them + // is local. clearDivision takes a position relative to the first local. + unsigned localIdOffset = getNumDimAndSymbolIds(); + if (posA >= localIdOffset && posB >= localIdOffset) { + std::swap(divDenom[posA - localIdOffset], divDenom[posB - localIdOffset]); + divisions.swapRows(posA - localIdOffset, posB - localIdOffset); + } else { + if (posA >= localIdOffset) + clearDivision(posA - localIdOffset); + if (posB >= localIdOffset) + clearDivision(posB - localIdOffset); + } } void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) { @@ -2546,7 +2627,25 @@ } os << ">= 0\n"; } - os << '\n'; + os << "Local Ids:\n"; + + for (unsigned i = 0, e = getNumLocalIds(); i < e; ++i) { + if (!atDivDenom(i)) { + os << "Local Id with Unknown explicit representation\n"; + continue; + } + + os << "["; + for (unsigned j = 0, f = getNumCols() - 1; j < f; ++j) + os << atDiv(i, j) << " "; + os << atDiv(i, getNumCols() - 1); + + os << "] / "; + os << *atDivDenom(i); + os << "\n"; + } + + os << "\n"; } void FlatAffineConstraints::dump() const { print(llvm::errs()); } diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp --- a/mlir/unittests/Analysis/AffineStructuresTest.cpp +++
b/mlir/unittests/Analysis/AffineStructuresTest.cpp @@ -587,4 +587,130 @@ EXPECT_EQ(fac.atIneq(0, 1), 0); } +TEST(FlatAffineConstraintsTest, testDivisions) { + FlatAffineConstraints fac = makeFACFromConstraints(1, {}, {}); + + // Add inequality x >= 0. + fac.addInequality({1, 0}); + EXPECT_EQ(fac.atIneq(0, 0), 1); + EXPECT_EQ(fac.atIneq(0, 1), 0); + + // Add division q = [(x)/4]. + fac.addLocalFloorDiv({1, 0}, 4); + EXPECT_NE(fac.atDivDenom(0), llvm::None); + EXPECT_EQ(*fac.atDivDenom(0), 4); + EXPECT_EQ(fac.atDiv(0, 0), 1); + EXPECT_EQ(fac.atDiv(0, 1), 0); + EXPECT_EQ(fac.atDiv(0, 2), 0); + + // Add local identifier p. + fac.addLocalId(1); + EXPECT_EQ(fac.atDivDenom(1), llvm::None); + + // Set division representation to local identifier at pos 1. + fac.setDivision(1, {0, 1, 0, 0}, 2); + EXPECT_NE(fac.atDivDenom(1), llvm::None); + EXPECT_EQ(*fac.atDivDenom(1), 2); + EXPECT_EQ(fac.atDiv(1, 0), 0); + EXPECT_EQ(fac.atDiv(1, 1), 1); + EXPECT_EQ(fac.atDiv(1, 2), 0); + EXPECT_EQ(fac.atDiv(1, 3), 0); + + // Add local identifier at pos 0. + fac.addLocalId(0); + EXPECT_EQ(fac.atDivDenom(0), llvm::None); + EXPECT_EQ(*fac.atDivDenom(1), 4); + EXPECT_EQ(*fac.atDivDenom(2), 2); + + // Remove local identifier at 0. + fac.removeId(0 + fac.getNumDimAndSymbolIds()); + EXPECT_EQ(*fac.atDivDenom(0), 4); + EXPECT_EQ(*fac.atDivDenom(1), 2); + + // Remove local identifier at pos 1; the division of local 0 is kept. + fac.removeId(1 + fac.getNumDimAndSymbolIds()); + EXPECT_EQ(fac.getNumLocalIds(), 1u); + EXPECT_EQ(*fac.atDivDenom(0), 4); + + // Add local identifier at position 1. + fac.addLocalId(1); + EXPECT_EQ(fac.atDivDenom(1), llvm::None); + + // Match the division for testing of upcoming append test. + fac.setDivision(0, {1, 0, 0, 0}, 4); + + // Test appending divisions.
+ FlatAffineConstraints fac2 = makeFACFromConstraints(1, {}, {}); + fac2.addInequality({-1, 0}); + EXPECT_EQ(fac2.atIneq(0, 0), -1); + EXPECT_EQ(fac2.atIneq(0, 1), 0); + + fac2.addLocalFloorDiv({1, 0}, 4); + EXPECT_NE(fac2.atDivDenom(0), llvm::None); + EXPECT_EQ(*fac2.atDivDenom(0), 4); + EXPECT_EQ(fac2.atDiv(0, 0), 1); + EXPECT_EQ(fac2.atDiv(0, 1), 0); + + fac2.addLocalId(1); + EXPECT_EQ(*fac2.atDivDenom(0), 4); + EXPECT_EQ(fac2.atDivDenom(1), llvm::None); + + fac.append(fac2); + EXPECT_NE(fac.atDivDenom(0), llvm::None); + EXPECT_EQ(*fac.atDivDenom(0), 4); + EXPECT_EQ(fac.atDiv(0, 0), 1); + EXPECT_EQ(fac.atDiv(0, 1), 0); + EXPECT_EQ(fac.atDivDenom(1), llvm::None); + + // Test swapping swaps divisions. + fac.swapId(1, 2); + EXPECT_EQ(fac.atDivDenom(0), llvm::None); + EXPECT_EQ(*fac.atDivDenom(1), 4); + EXPECT_EQ(fac.atDiv(1, 0), 1); + EXPECT_EQ(fac.atDiv(1, 1), 0); + + // Test removal of division. + fac.removeId(2); + EXPECT_EQ(fac.getNumLocalIds(), 1u); + EXPECT_EQ(fac.atDivDenom(0), llvm::None); + + // Test division with constant terms. + fac.addLocalFloorDiv({1, 1, 1}, 5); + EXPECT_NE(fac.atDivDenom(1), llvm::None); + EXPECT_EQ(*fac.atDivDenom(1), 5); + EXPECT_EQ(fac.atDiv(1, 0), 1); + EXPECT_EQ(fac.atDiv(1, 1), 1); + EXPECT_EQ(fac.atDiv(1, 2), 0); + EXPECT_EQ(fac.atDiv(1, 3), 1); + + fac.addLocalFloorDiv({1, 0, 0, 5}, 7); + EXPECT_NE(fac.atDivDenom(2), llvm::None); + EXPECT_EQ(*fac.atDivDenom(2), 7); + EXPECT_EQ(fac.atDiv(2, 0), 1); + EXPECT_EQ(fac.atDiv(2, 1), 0); + EXPECT_EQ(fac.atDiv(2, 2), 0); + EXPECT_EQ(fac.atDiv(2, 3), 0); + EXPECT_EQ(fac.atDiv(2, 4), 5); + + // Test removing division between 2 divisions. + fac.removeId(1 + fac.getNumDimAndSymbolIds()); + EXPECT_EQ(fac.getNumLocalIds(), 2u); + EXPECT_EQ(fac.atDivDenom(0), llvm::None); + EXPECT_EQ(*fac.atDivDenom(1), 7); + EXPECT_EQ(fac.atDiv(1, 0), 1); + EXPECT_EQ(fac.atDiv(1, 3), 5); + + // Swapping a local identifier with a non-local one clears its division. + fac.setDivision(0, {1, 0, 0, 0}, 3); + fac.swapId(0, 2); + EXPECT_EQ(*fac.atDivDenom(0), 3); + EXPECT_EQ(fac.atDiv(0, 0), 0); + EXPECT_EQ(fac.atDiv(0, 2), 1); + EXPECT_EQ(fac.atDivDenom(1), llvm::None); + + // Test division representation clearing. + fac.clearDivision(0); + EXPECT_EQ(fac.atDivDenom(0), llvm::None); +} + } // namespace mlir