diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -71,7 +71,8 @@ numSymbols(numSymbols), equalities(0, numIds + 1, numReservedEqualities, numReservedCols), inequalities(0, numIds + 1, numReservedInequalities, numReservedCols), - divisions(numLocals, numIds + 1, numLocals, numReservedCols) { + divisions(numLocals, numIds + 1, numLocals, numReservedCols), + divDenom(numLocals, llvm::None) { assert(numReservedCols >= numIds + 1); assert(idArgs.empty() || idArgs.size() == numIds); ids.reserve(numReservedCols); @@ -79,7 +80,6 @@ ids.resize(numIds, None); else ids.append(idArgs.begin(), idArgs.end()); - divDenom.resize(numLocals, llvm::None); } /// Constructs a constraint system with the specified number of @@ -175,13 +175,12 @@ inline int64_t &atIneq(unsigned i, unsigned j) { return inequalities(i, j); } - /// Returns the value at specified division dividend row and column + /// Returns the value at specified division dividend row and column. inline int64_t atDiv(unsigned i, unsigned j) const { return divisions(i, j); } inline int64_t &atDiv(unsigned i, unsigned j) { return divisions(i, j); } - /// Returns the denominator for specified division - /// Returns None if the local identifier does not have an explicit - /// representation + /// Returns the denominator for the pos^th local identifier. Returns None if + /// the local identifier does not have an explicit representation. inline Optional<int64_t> atDivDenom(unsigned pos) const { return divDenom[pos]; } @@ -212,6 +211,11 @@ return inequalities.getRow(idx); } + /// Returns the dividend coefficients of the idx^th local identifier. + inline ArrayRef<int64_t> getDivNumerator(unsigned idx) const { + return divisions.getRow(idx); + } + /// Adds constraints (lower and upper bounds) for the specified 'affine.for' /// operation's Value using IR information stored in its bound maps. The /// right identifier is first looked up using forOp's Value.
Asserts if the @@ -302,13 +306,13 @@ /// Adds a constant upper bound constraint for the specified identifier. void addConstantUpperBound(unsigned pos, int64_t ub); - /// Sets the division for pos^th local identifier + /// Sets the division representation for the pos^th local identifier. void setDivision(unsigned pos, ArrayRef<int64_t> dividend, int64_t divisor); - /// Clears division representation for the pos^th local identifier + /// Clears division representation for the pos^th local identifier. void clearDivision(unsigned pos); - /// Clears division representation for all local identifiers + /// Clears division representation for all local identifiers. void clearAllDivisions(); /// Adds a new local identifier as the floordiv of an affine function of other @@ -316,12 +320,12 @@ /// respect to a positive constant 'divisor'. Two constraints are added to the /// system to capture equivalence with the floordiv: /// q = dividend floordiv c <=> c*q <= dividend <= c*q + c - 1. - /// - /// The added division representation for floordiv is considered to be part of - /// the constraint system. If the added inequalities are removed or modified - /// such that the division representation of that floordiv becomes invalid, - /// the division representation of that local identifier must be updated or - /// cleared. Otherwise, the state of divisions becomes inconsistent. + /// The added division representation for the floordiv is considered to be + /// part of the constraint system. If the added inequalities are removed or + /// modified such that the division representation of that floordiv becomes + /// invalid, the division representation of that local identifier must be + /// updated or cleared. Otherwise, the FAC is considered to be in an + /// inconsistent state. void addLocalFloorDiv(ArrayRef<int64_t> dividend, int64_t divisor); /// Adds a constant lower bound constraint for the specified expression. @@ -673,12 +677,12 @@ /// Coefficients of affine inequalities (in >= 0 form).
Matrix inequalities; - /// Coefficients of affine expressions corresponding the dividend of division - /// representing local identifiers. + /// Coefficients of affine expressions corresponding to the dividends of the + /// divisions representing local identifiers. Matrix divisions; - /// Denominators of divisions representing local identifiers. - /// Identifiers whose explicit representation is unknown are marked None + /// Denominators of divisions representing local identifiers. Identifiers + /// whose explicit representation is unknown are marked None. SmallVector<Optional<int64_t>, 8> divDenom; /// Values corresponding to the (column) identifiers of this constraint diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -223,14 +223,14 @@ addEquality(other.getEquality(r)); } - // Division representations that do not match are marked as unknown - for (unsigned r = 0, e = getNumLocalIds(); r < e; r++) { + // Division representations that do not match are marked as unknown. + for (unsigned r = 0, e = getNumLocalIds(); r < e; ++r) { if (atDivDenom(r) != other.atDivDenom(r)) { atDivDenom(r) = llvm::None; continue; } - for (unsigned k = 0; k < getNumIds(); k++) { + for (unsigned k = 0, f = getNumCols(); k < f; ++k) { if (atDiv(r, k) != other.atDiv(r, k)) { atDivDenom(r) = llvm::None; break; @@ -271,7 +271,7 @@ } else { absolutePos = pos + getNumDimIds() + getNumSymbolIds(); - // Insert unknown division for local identifier + // Insert unknown division for local identifier. divisions.insertRow(pos); divDenom.insert(divDenom.begin() + pos, llvm::None); } @@ -906,7 +906,11 @@ numLocalsEliminated = std::min( idLimit - std::max(idStart, numDims + numSymbols), getNumLocalIds()); - // Remove local identifier representations from [0, numLocalsEliminated] - divisions.removeRows(0, numLocalsEliminated); - divDenom.erase(divDenom.begin(), divDenom.begin() + numLocalsEliminated); + // Remove the division representations of the eliminated local identifiers, + // which need not start at the first local identifier.
+ unsigned firstLocalPos = + std::max(idStart, numDims + numSymbols) - (numDims + numSymbols); + divisions.removeRows(firstLocalPos, numLocalsEliminated); + divDenom.erase(divDenom.begin() + firstLocalPos, + divDenom.begin() + firstLocalPos + numLocalsEliminated); } @@ -2163,7 +2167,7 @@ assert(dividend.size() == getNumCols() && "incorrect dividend size"); assert(divisor > 0 && "positive divisor expected"); - for (unsigned i = 0, e = getNumIds(); i < e; ++i) + for (unsigned i = 0, e = getNumCols(); i < e; ++i) atDiv(pos, i) = dividend[i]; atDivDenom(pos) = divisor; } @@ -2173,7 +2177,7 @@ } void FlatAffineConstraints::clearAllDivisions() { - for (unsigned i = 0, e = getNumLocalIds(); i < e; ++i) + for (unsigned i = 0, e = getNumLocalIds(); i < e; ++i) clearDivision(i); } @@ -2190,23 +2194,21 @@ addLocalId(getNumLocalIds()); // Add two constraints for this new identifier 'q'. - SmallVector<int64_t, 8> bound(dividend.size() + 1); + SmallVector<int64_t, 8> bound(dividend.size() + 1, 0); - // dividend >= 0 + // bound = dividend, with zero as coefficient of the new local identifier. std::copy(dividend.begin(), dividend.begin() + dividend.size() - 1, bound.begin()); bound.back() = dividend.back(); - bound[getNumIds() - 1] = 0; - // Set explicit representation for local identifier - // "bound" currently exactly represents the divisions expression + // Set explicit representation for local identifier. setDivision(getNumLocalIds() - 1, bound, divisor); // dividend - q * divisor >= 0 bound[getNumIds() - 1] = -divisor; addInequality(bound); - // -dividend +qdivisor * q + divisor - 1 >= 0 + // -dividend + divisor * q + divisor - 1 >= 0 std::transform(bound.begin(), bound.end(), bound.begin(), std::negate<int64_t>()); bound[bound.size() - 1] += divisor - 1; @@ -2247,11 +2249,17 @@ std::swap(getId(posA), getId(posB)); - // Swap division representation if both local identifiers - int64_t localIdOffset = getNumDimAndSymbolIds(); + // Swap division representations if both identifiers are local. If exactly + // one of them is local, that local position now holds a former dimension or + // symbol identifier, so its division representation is cleared.
+ unsigned localIdOffset = getNumDimAndSymbolIds(); if (posA >= localIdOffset && posB >= localIdOffset) { std::swap(divDenom[posA - localIdOffset], divDenom[posB - localIdOffset]); divisions.swapRows(posA - localIdOffset, posB - localIdOffset); + } else if (posA >= localIdOffset) { + clearDivision(posA - localIdOffset); + } else if (posB >= localIdOffset) { + clearDivision(posB - localIdOffset); } } @@ -2627,7 +2635,7 @@ os << "["; for (unsigned j = 0, f = getNumCols() - 1; j < f; ++j) - os << atDiv(i, j) << " + "; + os << atDiv(i, j) << " "; os << atDiv(i, getNumCols() - 1); os << "] / "; diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp --- a/mlir/unittests/Analysis/AffineStructuresTest.cpp +++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp @@ -590,30 +590,63 @@ TEST(FlatAffineConstraintsTest, testDivisions) { FlatAffineConstraints fac = makeFACFromConstraints(1, {}, {}); - // Add inequality x >= 0 + // Add inequality x >= 0. fac.addInequality({1, 0}); EXPECT_EQ(fac.atIneq(0, 0), 1); EXPECT_EQ(fac.atIneq(0, 1), 0); - // Add division q = [(x)/4] + // Add division q = [(x)/4]. fac.addLocalFloorDiv({1, 0}, 4); - + EXPECT_NE(fac.atDivDenom(0), llvm::None); EXPECT_EQ(*fac.atDivDenom(0), 4); EXPECT_EQ(fac.atDiv(0, 0), 1); EXPECT_EQ(fac.atDiv(0, 1), 0); + EXPECT_EQ(fac.atDiv(0, 2), 0); - // Add local identifier p + // Add local identifier p. fac.addLocalId(1); EXPECT_EQ(fac.atDivDenom(1), llvm::None); - // Test appending checks divisions + // Set division representation to local identifier at pos 1. + fac.setDivision(1, {0, 1, 0, 0}, 2); + EXPECT_NE(fac.atDivDenom(1), llvm::None); + EXPECT_EQ(*fac.atDivDenom(1), 2); + EXPECT_EQ(fac.atDiv(1, 0), 0); + EXPECT_EQ(fac.atDiv(1, 1), 1); + EXPECT_EQ(fac.atDiv(1, 2), 0); + EXPECT_EQ(fac.atDiv(1, 3), 0); + + // Add local identifier at pos 0.
+ fac.addLocalId(0); + EXPECT_EQ(fac.atDivDenom(0), llvm::None); + EXPECT_EQ(*fac.atDivDenom(1), 4); + EXPECT_EQ(*fac.atDivDenom(2), 2); + + // Remove local identifier at 0. + fac.removeId(0 + fac.getNumDimAndSymbolIds()); + EXPECT_EQ(*fac.atDivDenom(0), 4); + EXPECT_EQ(*fac.atDivDenom(1), 2); + + // Remove local identifier at pos 1. + fac.removeId(1 + fac.getNumDimAndSymbolIds()); + EXPECT_EQ(fac.getNumLocalIds(), 1u); + EXPECT_EQ(*fac.atDivDenom(0), 4); + + // Add local identifier at position 1. + fac.addLocalId(1); + EXPECT_EQ(fac.atDivDenom(1), llvm::None); + + // Match the division for testing of upcoming append test. + fac.setDivision(0, {1, 0, 0, 0}, 4); + + // Test appending divisions. FlatAffineConstraints fac2 = makeFACFromConstraints(1, {}, {}); fac2.addInequality({-1, 0}); - EXPECT_EQ(fac.atIneq(0, 0), 1); - EXPECT_EQ(fac.atIneq(0, 1), 0); + EXPECT_EQ(fac2.atIneq(0, 0), -1); + EXPECT_EQ(fac2.atIneq(0, 1), 0); fac2.addLocalFloorDiv({1, 0}, 4); - + EXPECT_NE(fac2.atDivDenom(0), llvm::None); EXPECT_EQ(*fac2.atDivDenom(0), 4); EXPECT_EQ(fac2.atDiv(0, 0), 1); EXPECT_EQ(fac2.atDiv(0, 1), 0); @@ -623,22 +656,48 @@ EXPECT_EQ(fac2.atDivDenom(1), llvm::None); fac.append(fac2); + EXPECT_NE(fac.atDivDenom(0), llvm::None); EXPECT_EQ(*fac.atDivDenom(0), 4); EXPECT_EQ(fac.atDiv(0, 0), 1); EXPECT_EQ(fac.atDiv(0, 1), 0); EXPECT_EQ(fac.atDivDenom(1), llvm::None); - // Test swapping swaps divisions + // Test swapping swaps divisions. fac.swapId(1, 2); EXPECT_EQ(fac.atDivDenom(0), llvm::None); EXPECT_EQ(*fac.atDivDenom(1), 4); EXPECT_EQ(fac.atDiv(1, 0), 1); EXPECT_EQ(fac.atDiv(1, 1), 0); - // Test removal of division + // Test removal of division. fac.removeId(2); - EXPECT_EQ(fac.getNumLocalIds(), (unsigned)1); + EXPECT_EQ(fac.getNumLocalIds(), 1u); + // Test division with constant terms.
+ fac.addLocalFloorDiv({1, 1, 1}, 5); + EXPECT_NE(fac.atDivDenom(1), llvm::None); + EXPECT_EQ(*fac.atDivDenom(1), 5); + EXPECT_EQ(fac.atDiv(1, 0), 1); + EXPECT_EQ(fac.atDiv(1, 1), 1); + EXPECT_EQ(fac.atDiv(1, 2), 0); + EXPECT_EQ(fac.atDiv(1, 3), 1); + + fac.addLocalFloorDiv({1, 0, 0, 5}, 7); + EXPECT_NE(fac.atDivDenom(2), llvm::None); + EXPECT_EQ(*fac.atDivDenom(2), 7); + EXPECT_EQ(fac.atDiv(2, 0), 1); + EXPECT_EQ(fac.atDiv(2, 1), 0); + EXPECT_EQ(fac.atDiv(2, 2), 0); + EXPECT_EQ(fac.atDiv(2, 3), 0); + EXPECT_EQ(fac.atDiv(2, 4), 5); + + // Test removing division between 2 divisions. + fac.removeId(1 + fac.getNumDimAndSymbolIds()); + EXPECT_EQ(fac.getNumLocalIds(), 2u); + EXPECT_EQ(fac.atDivDenom(0), llvm::None); + EXPECT_EQ(*fac.atDivDenom(1), 7); + + // Test division representation clearing. - fac.clearDivision(0); - EXPECT_EQ(fac.atDivDenom(0), llvm::None); + fac.clearDivision(1); + EXPECT_EQ(fac.atDivDenom(1), llvm::None); }