diff --git a/mlir/include/mlir/Analysis/Presburger/Fraction.h b/mlir/include/mlir/Analysis/Presburger/Fraction.h --- a/mlir/include/mlir/Analysis/Presburger/Fraction.h +++ b/mlir/include/mlir/Analysis/Presburger/Fraction.h @@ -14,6 +14,7 @@ #ifndef MLIR_ANALYSIS_PRESBURGER_FRACTION_H #define MLIR_ANALYSIS_PRESBURGER_FRACTION_H +#include "mlir/Analysis/Presburger/MPInt.h" #include "mlir/Support/MathExtras.h" namespace mlir { @@ -29,30 +30,34 @@ Fraction() = default; /// Construct a Fraction from a numerator and denominator. - Fraction(int64_t oNum, int64_t oDen) : num(oNum), den(oDen) { + Fraction(const MPInt &oNum, const MPInt &oDen) : num(oNum), den(oDen) { if (den < 0) { num = -num; den = -den; } } + /// Overloads for passing literals. + Fraction(const MPInt &num, int64_t den) : Fraction(num, MPInt(den)) {} + Fraction(int64_t num, const MPInt &den) : Fraction(MPInt(num), den) {} + Fraction(int64_t num, int64_t den) : Fraction(MPInt(num), MPInt(den)) {} // Return the value of the fraction as an integer. This should only be called // when the fraction's value is really an integer. - int64_t getAsInteger() const { + MPInt getAsInteger() const { assert(num % den == 0 && "Get as integer called on non-integral fraction!"); return num / den; } /// The numerator and denominator, respectively. The denominator is always /// positive. - int64_t num{0}, den{1}; + MPInt num{0}, den{1}; }; /// Three-way comparison between two fractions. /// Returns +1, 0, and -1 if the first fraction is greater than, equal to, or /// less than the second fraction, respectively. 
-inline int compare(Fraction x, Fraction y) { - int64_t diff = x.num * y.den - y.num * x.den; +inline int compare(const Fraction &x, const Fraction &y) { + MPInt diff = x.num * y.den - y.num * x.den; if (diff > 0) return +1; if (diff < 0) @@ -60,25 +65,37 @@ return 0; } -inline int64_t floor(Fraction f) { return floorDiv(f.num, f.den); } +inline MPInt floor(const Fraction &f) { return floorDiv(f.num, f.den); } -inline int64_t ceil(Fraction f) { return ceilDiv(f.num, f.den); } +inline MPInt ceil(const Fraction &f) { return ceilDiv(f.num, f.den); } -inline Fraction operator-(Fraction x) { return Fraction(-x.num, x.den); } +inline Fraction operator-(const Fraction &x) { return Fraction(-x.num, x.den); } -inline bool operator<(Fraction x, Fraction y) { return compare(x, y) < 0; } +inline bool operator<(const Fraction &x, const Fraction &y) { + return compare(x, y) < 0; +} -inline bool operator<=(Fraction x, Fraction y) { return compare(x, y) <= 0; } +inline bool operator<=(const Fraction &x, const Fraction &y) { + return compare(x, y) <= 0; +} -inline bool operator==(Fraction x, Fraction y) { return compare(x, y) == 0; } +inline bool operator==(const Fraction &x, const Fraction &y) { + return compare(x, y) == 0; +} -inline bool operator!=(Fraction x, Fraction y) { return compare(x, y) != 0; } +inline bool operator!=(const Fraction &x, const Fraction &y) { + return compare(x, y) != 0; +} -inline bool operator>(Fraction x, Fraction y) { return compare(x, y) > 0; } +inline bool operator>(const Fraction &x, const Fraction &y) { + return compare(x, y) > 0; +} -inline bool operator>=(Fraction x, Fraction y) { return compare(x, y) >= 0; } +inline bool operator>=(const Fraction &x, const Fraction &y) { + return compare(x, y) >= 0; +} -inline Fraction operator*(Fraction x, Fraction y) { +inline Fraction operator*(const Fraction &x, const Fraction &y) { return Fraction(x.num * y.num, x.den * y.den); } diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h 
b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -133,14 +133,24 @@ bool isSubsetOf(const IntegerRelation &other) const; /// Returns the value at the specified equality row and column. - inline int64_t atEq(unsigned i, unsigned j) const { return equalities(i, j); } - inline int64_t &atEq(unsigned i, unsigned j) { return equalities(i, j); } + inline MPInt atEq(unsigned i, unsigned j) const { return equalities(i, j); } + /// The same, but casts to int64_t. This is unsafe and will assert-fail if the + /// value does not fit in an int64_t. + inline int64_t atEq64(unsigned i, unsigned j) const { + return int64_t(equalities(i, j)); + } + inline MPInt &atEq(unsigned i, unsigned j) { return equalities(i, j); } /// Returns the value at the specified inequality row and column. - inline int64_t atIneq(unsigned i, unsigned j) const { + inline MPInt atIneq(unsigned i, unsigned j) const { return inequalities(i, j); } - inline int64_t &atIneq(unsigned i, unsigned j) { return inequalities(i, j); } + /// The same, but casts to int64_t. This is unsafe and will assert-fail if the + /// value does not fit in an int64_t. + inline int64_t atIneq64(unsigned i, unsigned j) const { + return int64_t(inequalities(i, j)); + } + inline MPInt &atIneq(unsigned i, unsigned j) { return inequalities(i, j); } unsigned getNumConstraints() const { return getNumInequalities() + getNumEqualities(); @@ -174,13 +184,20 @@ return inequalities.getNumReservedRows(); } - inline ArrayRef getEquality(unsigned idx) const { + inline ArrayRef getEquality(unsigned idx) const { return equalities.getRow(idx); } - - inline ArrayRef getInequality(unsigned idx) const { + inline ArrayRef getInequality(unsigned idx) const { return inequalities.getRow(idx); } + /// The same, but casts to int64_t. This is unsafe and will assert-fail if the + /// value does not fit in an int64_t. 
+ inline SmallVector getEquality64(unsigned idx) const { + return getInt64Vec(equalities.getRow(idx)); + } + inline SmallVector getInequality64(unsigned idx) const { + return getInt64Vec(inequalities.getRow(idx)); + } /// Get the number of vars of the specified kind. unsigned getNumVarKind(VarKind kind) const { @@ -245,9 +262,13 @@ unsigned appendVar(VarKind kind, unsigned num = 1); /// Adds an inequality (>= 0) from the coefficients specified in `inEq`. - void addInequality(ArrayRef inEq); + void addInequality(ArrayRef inEq); + void addInequality(ArrayRef inEq) { + addInequality(getMPIntVec(inEq)); + } /// Adds an equality from the coefficients specified in `eq`. - void addEquality(ArrayRef eq); + void addEquality(ArrayRef eq); + void addEquality(ArrayRef eq) { addEquality(getMPIntVec(eq)); } /// Eliminate the `posB^th` local variable, replacing every instance of it /// with the `posA^th` local variable. This should be used when the two @@ -282,7 +303,7 @@ /// For a generic integer sampling operation, findIntegerSample is more /// robust and should be preferred. Note that Domain is minimized first, then /// range. - MaybeOptimum> findIntegerLexMin() const; + MaybeOptimum> findIntegerLexMin() const; /// Swap the posA^th variable with the posB^th variable. virtual void swapVar(unsigned posA, unsigned posB); @@ -292,7 +313,10 @@ /// Sets the `values.size()` variables starting at `po`s to the specified /// values and removes them. - void setAndEliminate(unsigned pos, ArrayRef values); + void setAndEliminate(unsigned pos, ArrayRef values); + void setAndEliminate(unsigned pos, ArrayRef values) { + setAndEliminate(pos, getMPIntVec(values)); + } /// Replaces the contents of this IntegerRelation with `other`. virtual void clearAndCopyFrom(const IntegerRelation &other); @@ -337,20 +361,27 @@ /// /// Returns an integer sample point if one exists, or an empty Optional /// otherwise. The returned value also includes values of local ids. 
- Optional> findIntegerSample() const; + Optional> findIntegerSample() const; /// Compute an overapproximation of the number of integer points in the /// relation. Symbol vars currently not supported. If the computed /// overapproximation is infinite, an empty optional is returned. - Optional computeVolume() const; + Optional computeVolume() const; /// Returns true if the given point satisfies the constraints, or false /// otherwise. Takes the values of all vars including locals. - bool containsPoint(ArrayRef point) const; + bool containsPoint(ArrayRef point) const; + bool containsPoint(ArrayRef point) const { + return containsPoint(getMPIntVec(point)); + } /// Given the values of non-local vars, return a satisfying assignment to the /// local if one exists, or an empty optional otherwise. - Optional> - containsPointNoLocal(ArrayRef point) const; + Optional> + containsPointNoLocal(ArrayRef point) const; + Optional> + containsPointNoLocal(ArrayRef point) const { + return containsPointNoLocal(getMPIntVec(point)); + } /// Returns a `DivisonRepr` representing the division representation of local /// variables in the constraint system. @@ -367,17 +398,26 @@ enum BoundType { EQ, LB, UB }; /// Adds a constant bound for the specified variable. - void addBound(BoundType type, unsigned pos, int64_t value); + void addBound(BoundType type, unsigned pos, const MPInt &value); + void addBound(BoundType type, unsigned pos, int64_t value) { + addBound(type, pos, MPInt(value)); + } /// Adds a constant bound for the specified expression. - void addBound(BoundType type, ArrayRef expr, int64_t value); + void addBound(BoundType type, ArrayRef expr, const MPInt &value); + void addBound(BoundType type, ArrayRef expr, int64_t value) { + addBound(type, getMPIntVec(expr), MPInt(value)); + } /// Adds a new local variable as the floordiv of an affine function of other /// variables, the coefficients of which are provided in `dividend` and with /// respect to a positive constant `divisor`. 
Two constraints are added to the /// system to capture equivalence with the floordiv: /// q = dividend floordiv c <=> c*q <= dividend <= c*q + c - 1. - void addLocalFloorDiv(ArrayRef dividend, int64_t divisor); + void addLocalFloorDiv(ArrayRef dividend, const MPInt &divisor); + void addLocalFloorDiv(ArrayRef dividend, int64_t divisor) { + addLocalFloorDiv(getMPIntVec(dividend), MPInt(divisor)); + } /// Projects out (aka eliminates) `num` variables starting at position /// `pos`. The resulting constraint system is the shadow along the dimensions @@ -432,15 +472,38 @@ /// lower bound is [(s0 + s2 - 1) floordiv 32] for a system with three /// symbolic variables, *lb = [1, 0, 1], lbDivisor = 32. See comments at /// function definition for examples. - Optional getConstantBoundOnDimSize( + Optional getConstantBoundOnDimSize( + unsigned pos, SmallVectorImpl *lb = nullptr, + MPInt *boundFloorDivisor = nullptr, SmallVectorImpl *ub = nullptr, + unsigned *minLbPos = nullptr, unsigned *minUbPos = nullptr) const; + /// The same, but casts to int64_t. This is unsafe and will assert-fail if the + /// value does not fit in an int64_t. + Optional getConstantBoundOnDimSize64( unsigned pos, SmallVectorImpl *lb = nullptr, int64_t *boundFloorDivisor = nullptr, SmallVectorImpl *ub = nullptr, unsigned *minLbPos = nullptr, - unsigned *minUbPos = nullptr) const; + unsigned *minUbPos = nullptr) const { + SmallVector ubMPInt, lbMPInt; + MPInt boundFloorDivisorMPInt; + Optional result = getConstantBoundOnDimSize( + pos, &lbMPInt, &boundFloorDivisorMPInt, &ubMPInt, minLbPos, minUbPos); + if (lb) + *lb = getInt64Vec(lbMPInt); + if (ub) + *ub = getInt64Vec(ubMPInt); + if (boundFloorDivisor) + *boundFloorDivisor = int64_t(boundFloorDivisorMPInt); + return result.map(int64FromMPInt); + } /// Returns the constant bound for the pos^th variable if there is one; /// None otherwise. 
- Optional getConstantBound(BoundType type, unsigned pos) const; + Optional getConstantBound(BoundType type, unsigned pos) const; + /// The same, but casts to int64_t. This is unsafe and will assert-fail if the + /// value does not fit in an int64_t. + Optional getConstantBound64(BoundType type, unsigned pos) const { + return getConstantBound(type, pos).map(int64FromMPInt); + } /// Removes constraints that are independent of (i.e., do not have a /// coefficient) variables in the range [pos, pos + num). @@ -619,7 +682,13 @@ /// Returns the constant lower bound bound if isLower is true, and the upper /// bound if isLower is false. template - Optional computeConstantLowerOrUpperBound(unsigned pos); + Optional computeConstantLowerOrUpperBound(unsigned pos); + /// The same, but casts to int64_t. This is unsafe and will assert-fail if the + /// value does not fit in an int64_t. + template + Optional computeConstantLowerOrUpperBound64(unsigned pos) { + return computeConstantLowerOrUpperBound(pos).map(int64FromMPInt); + } /// Eliminates a single variable at `position` from equality and inequality /// constraints. Returns `success` if the variable was eliminated, and diff --git a/mlir/include/mlir/Analysis/Presburger/LinearTransform.h b/mlir/include/mlir/Analysis/Presburger/LinearTransform.h --- a/mlir/include/mlir/Analysis/Presburger/LinearTransform.h +++ b/mlir/include/mlir/Analysis/Presburger/LinearTransform.h @@ -40,14 +40,13 @@ // The given vector is interpreted as a row vector v. Post-multiply v with // this transform, say T, and return vT. - SmallVector preMultiplyWithRow(ArrayRef rowVec) const { + SmallVector preMultiplyWithRow(ArrayRef rowVec) const { return matrix.preMultiplyWithRow(rowVec); } // The given vector is interpreted as a column vector v. Pre-multiply v with // this transform, say T, and return Tv. 
- SmallVector - postMultiplyWithColumn(ArrayRef colVec) const { + SmallVector postMultiplyWithColumn(ArrayRef colVec) const { return matrix.postMultiplyWithColumn(colVec); } diff --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h --- a/mlir/include/mlir/Analysis/Presburger/Matrix.h +++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h @@ -14,6 +14,7 @@ #ifndef MLIR_ANALYSIS_PRESBURGER_MATRIX_H #define MLIR_ANALYSIS_PRESBURGER_MATRIX_H +#include "mlir/Analysis/Presburger/MPInt.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/Support/raw_ostream.h" @@ -48,21 +49,21 @@ static Matrix identity(unsigned dimension); /// Access the element at the specified row and column. - int64_t &at(unsigned row, unsigned column) { + MPInt &at(unsigned row, unsigned column) { assert(row < nRows && "Row outside of range"); assert(column < nColumns && "Column outside of range"); return data[row * nReservedColumns + column]; } - int64_t at(unsigned row, unsigned column) const { + MPInt at(unsigned row, unsigned column) const { assert(row < nRows && "Row outside of range"); assert(column < nColumns && "Column outside of range"); return data[row * nReservedColumns + column]; } - int64_t &operator()(unsigned row, unsigned column) { return at(row, column); } + MPInt &operator()(unsigned row, unsigned column) { return at(row, column); } - int64_t operator()(unsigned row, unsigned column) const { + MPInt operator()(unsigned row, unsigned column) const { return at(row, column); } @@ -86,11 +87,11 @@ void reserveRows(unsigned rows); /// Get a [Mutable]ArrayRef corresponding to the specified row. - MutableArrayRef getRow(unsigned row); - ArrayRef getRow(unsigned row) const; + MutableArrayRef getRow(unsigned row); + ArrayRef getRow(unsigned row) const; /// Set the specified row to `elems`. 
- void setRow(unsigned row, ArrayRef elems); + void setRow(unsigned row, ArrayRef elems); /// Insert columns having positions pos, pos + 1, ... pos + count - 1. /// Columns that were at positions 0 to pos - 1 will stay where they are; @@ -124,13 +125,22 @@ void copyRow(unsigned sourceRow, unsigned targetRow); - void fillRow(unsigned row, int64_t value); + void fillRow(unsigned row, const MPInt &value); + void fillRow(unsigned row, int64_t value) { fillRow(row, MPInt(value)); } /// Add `scale` multiples of the source row to the target row. - void addToRow(unsigned sourceRow, unsigned targetRow, int64_t scale); + void addToRow(unsigned sourceRow, unsigned targetRow, const MPInt &scale); + void addToRow(unsigned sourceRow, unsigned targetRow, int64_t scale) { + addToRow(sourceRow, targetRow, MPInt(scale)); + } /// Add `scale` multiples of the source column to the target column. - void addToColumn(unsigned sourceColumn, unsigned targetColumn, int64_t scale); + void addToColumn(unsigned sourceColumn, unsigned targetColumn, + const MPInt &scale); + void addToColumn(unsigned sourceColumn, unsigned targetColumn, + int64_t scale) { + addToColumn(sourceColumn, targetColumn, MPInt(scale)); + } /// Negate the specified column. void negateColumn(unsigned column); @@ -140,19 +150,18 @@ /// Divide the first `nCols` of the specified row by their GCD. /// Returns the GCD of the first `nCols` of the specified row. - int64_t normalizeRow(unsigned row, unsigned nCols); + MPInt normalizeRow(unsigned row, unsigned nCols); /// Divide the columns of the specified row by their GCD. /// Returns the GCD of the columns of the specified row. - int64_t normalizeRow(unsigned row); + MPInt normalizeRow(unsigned row); /// The given vector is interpreted as a row vector v. Post-multiply v with /// this matrix, say M, and return vM. 
- SmallVector preMultiplyWithRow(ArrayRef rowVec) const; + SmallVector preMultiplyWithRow(ArrayRef rowVec) const; /// The given vector is interpreted as a column vector v. Pre-multiply v with /// this matrix, say M, and return Mv. - SmallVector - postMultiplyWithColumn(ArrayRef colVec) const; + SmallVector postMultiplyWithColumn(ArrayRef colVec) const; /// Resize the matrix to the specified dimensions. If a dimension is smaller, /// the values are truncated; if it is bigger, the new values are initialized @@ -169,7 +178,7 @@ unsigned appendExtraRow(); /// Same as above, but copy the given elements into the row. The length of /// `elems` must be equal to the number of columns. - unsigned appendExtraRow(ArrayRef elems); + unsigned appendExtraRow(ArrayRef elems); /// Print the matrix. void print(raw_ostream &os) const; @@ -188,7 +197,7 @@ /// Stores the data. data.size() is equal to nRows * nReservedColumns. /// data.capacity() / nReservedColumns is the number of reserved rows. - SmallVector data; + SmallVector data; }; } // namespace presburger diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h --- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h +++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h @@ -62,7 +62,7 @@ /// Get a matrix with each row representing row^th output expression. const Matrix &getOutputMatrix() const { return output; } /// Get the `i^th` output expression. - ArrayRef getOutputExpr(unsigned i) const { return output.getRow(i); } + ArrayRef getOutputExpr(unsigned i) const { return output.getRow(i); } /// Insert `num` variables of the specified kind at position `pos`. /// Positions are relative to the kind of variable. The coefficient columns @@ -91,7 +91,10 @@ /// Get the value of the function at the specified point. If the point lies /// outside the domain, an empty optional is returned. 
- Optional> valueAt(ArrayRef point) const; + Optional> valueAt(ArrayRef point) const; + Optional> valueAt(ArrayRef point) const { + return valueAt(getMPIntVec(point)); + } /// Truncate the output dimensions to the first `count` dimensions. /// @@ -159,7 +162,10 @@ /// Return the value at the specified point and an empty optional if the /// point does not lie in the domain. - Optional> valueAt(ArrayRef point) const; + Optional> valueAt(ArrayRef point) const; + Optional> valueAt(ArrayRef point) const { + return valueAt(getMPIntVec(point)); + } /// Return whether `this` and `other` are equal as PWMAFunctions, i.e. whether /// they have the same dimensions, the same domain and they take the same diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h --- a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h @@ -83,7 +83,10 @@ PresburgerRelation intersect(const PresburgerRelation &set) const; /// Return true if the set contains the given point, and false otherwise. - bool containsPoint(ArrayRef point) const; + bool containsPoint(ArrayRef point) const; + bool containsPoint(ArrayRef point) const { + return containsPoint(getMPIntVec(point)); + } /// Return the complement of this set. All local variables in the set must /// correspond to floor divisions. @@ -108,7 +111,7 @@ /// Find an integer sample from the given set. This should not be called if /// any of the disjuncts in the union are unbounded. - bool findIntegerSample(SmallVectorImpl &sample); + bool findIntegerSample(SmallVectorImpl &sample); /// Compute an overapproximation of the number of integer points in the /// disjunct. Symbol vars are currently not supported. 
If the computed @@ -117,7 +120,7 @@ /// This currently just sums up the overapproximations of the volumes of the /// disjuncts, so the approximation might be far from the true volume in the /// case when there is a lot of overlap between disjuncts. - Optional computeVolume() const; + Optional computeVolume() const; /// Simplifies the representation of a PresburgerRelation. /// diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h --- a/mlir/include/mlir/Analysis/Presburger/Simplex.h +++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h @@ -166,7 +166,7 @@ /// Add an inequality to the tableau. If coeffs is c_0, c_1, ... c_n, where n /// is the current number of variables, then the corresponding inequality is /// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} >= 0. - virtual void addInequality(ArrayRef coeffs) = 0; + virtual void addInequality(ArrayRef coeffs) = 0; /// Returns the number of variables in the tableau. unsigned getNumVariables() const; @@ -177,14 +177,16 @@ /// Add an equality to the tableau. If coeffs is c_0, c_1, ... c_n, where n /// is the current number of variables, then the corresponding equality is /// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0. - void addEquality(ArrayRef coeffs); + void addEquality(ArrayRef coeffs); /// Add new variables to the end of the list of variables. void appendVariable(unsigned count = 1); /// Append a new variable to the simplex and constrain it such that its only /// integer value is the floor div of `coeffs` and `denom`. - void addDivisionVariable(ArrayRef coeffs, int64_t denom); + /// + /// `denom` must be positive. + void addDivisionVariable(ArrayRef coeffs, const MPInt &denom); /// Mark the tableau as being empty. void markEmpty(); @@ -293,7 +295,7 @@ /// con. /// /// Returns the index of the new Unknown in con. 
- unsigned addRow(ArrayRef coeffs, bool makeRestricted = false); + unsigned addRow(ArrayRef coeffs, bool makeRestricted = false); /// Swap the two rows/columns in the tableau and associated data structures. void swapRows(unsigned i, unsigned j); @@ -421,7 +423,7 @@ /// /// This just adds the inequality to the tableau and does not try to create a /// consistent tableau configuration. - void addInequality(ArrayRef coeffs) final; + void addInequality(ArrayRef coeffs) final; /// Get a snapshot of the current state. This is used for rolling back. unsigned getSnapshot() { return SimplexBase::getSnapshotBasis(); } @@ -493,15 +495,15 @@ /// /// Note: this should be used only when the lexmin is really needed. To obtain /// any integer sample, use Simplex::findIntegerSample as that is more robust. - MaybeOptimum> findIntegerLexMin(); + MaybeOptimum> findIntegerLexMin(); /// Return whether the specified inequality is redundant/separate for the /// polytope. Redundant means every point satisfies the given inequality, and /// separate means no point satisfies it. /// /// These checks are integer-exact. - bool isSeparateInequality(ArrayRef coeffs); - bool isRedundantInequality(ArrayRef coeffs); + bool isSeparateInequality(ArrayRef coeffs); + bool isRedundantInequality(ArrayRef coeffs); private: /// Returns the current sample point, which may contain non-integer (rational) @@ -654,11 +656,11 @@ /// Get the numerator of the symbolic sample of the specific row. /// This is an affine expression in the symbols with integer coefficients. /// The last element is the constant term. This ignores the big M coefficient. - SmallVector getSymbolicSampleNumerator(unsigned row) const; + SmallVector getSymbolicSampleNumerator(unsigned row) const; /// Get an affine inequality in the symbols with integer coefficients that /// holds iff the symbolic sample of the specified row is non-negative. 
- SmallVector getSymbolicSampleIneq(unsigned row) const; + SmallVector getSymbolicSampleIneq(unsigned row) const; /// Return whether all the coefficients of the symbolic sample are integers. /// @@ -708,7 +710,7 @@ /// /// This also tries to restore the tableau configuration to a consistent /// state and marks the Simplex empty if this is not possible. - void addInequality(ArrayRef coeffs) final; + void addInequality(ArrayRef coeffs) final; /// Compute the maximum or minimum value of the given row, depending on /// direction. The specified row is never pivoted. On return, the row may @@ -724,7 +726,7 @@ /// Returns a Fraction denoting the optimum, or a null value if no optimum /// exists, i.e., if the expression is unbounded in this direction. MaybeOptimum computeOptimum(Direction direction, - ArrayRef coeffs); + ArrayRef coeffs); /// Returns whether the perpendicular of the specified constraint is a /// is a direction along which the polytope is bounded. @@ -766,8 +768,8 @@ /// Returns a (min, max) pair denoting the minimum and maximum integer values /// of the given expression. If no integer value exists, both results will be /// of kind Empty. - std::pair, MaybeOptimum> - computeIntegerBounds(ArrayRef coeffs); + std::pair, MaybeOptimum> + computeIntegerBounds(ArrayRef coeffs); /// Returns true if the polytope is unbounded, i.e., extends to infinity in /// some direction. Otherwise, returns false. @@ -779,7 +781,7 @@ /// Returns an integer sample point if one exists, or None /// otherwise. This should only be called for bounded sets. 
- Optional> findIntegerSample(); + Optional> findIntegerSample(); enum class IneqType { Redundant, Cut, Separate }; @@ -789,13 +791,13 @@ /// Redundant The inequality is satisfied in the polytope /// Cut The inequality is satisfied by some points, but not by others /// Separate The inequality is not satisfied by any point - IneqType findIneqType(ArrayRef coeffs); + IneqType findIneqType(ArrayRef coeffs); /// Check if the specified inequality already holds in the polytope. - bool isRedundantInequality(ArrayRef coeffs); + bool isRedundantInequality(ArrayRef coeffs); /// Check if the specified equality already holds in the polytope. - bool isRedundantEquality(ArrayRef coeffs); + bool isRedundantEquality(ArrayRef coeffs); /// Returns true if this Simplex's polytope is a rational subset of `rel`. /// Otherwise, returns false. @@ -803,7 +805,7 @@ /// Returns the current sample point if it is integral. Otherwise, returns /// None. - Optional> getSamplePointIfIntegral() const; + Optional> getSamplePointIfIntegral() const; /// Returns the current sample point, which may contain non-integer (rational) /// coordinates. Returns an empty optional when the tableau is empty. diff --git a/mlir/include/mlir/Analysis/Presburger/Utils.h b/mlir/include/mlir/Analysis/Presburger/Utils.h --- a/mlir/include/mlir/Analysis/Presburger/Utils.h +++ b/mlir/include/mlir/Analysis/Presburger/Utils.h @@ -109,14 +109,15 @@ /// system. The coefficients of the dividends are stored in order: /// [nonLocalVars, localVars, constant]. Each local variable may or may not have /// a representation. If the local does not have a representation, the dividend -/// of the division has no meaning and the denominator is zero. +/// of the division has no meaning and the denominator is zero. If it has a +/// representation, the denominator will be positive. /// /// The i^th division here, represents the division representation of the /// variable at position `divOffset + i` in the constraint system. 
class DivisionRepr { public: DivisionRepr(unsigned numVars, unsigned numDivs) - : dividends(numDivs, numVars + 1), denoms(numDivs, 0) {} + : dividends(numDivs, numVars + 1), denoms(numDivs, MPInt(0)) {} DivisionRepr(unsigned numVars) : dividends(numVars + 1, 0) {} @@ -130,27 +131,23 @@ bool hasRepr(unsigned i) const { return denoms[i] != 0; } // Check whether all the divisions have a division representation or not. bool hasAllReprs() const { - return all_of(denoms, [](unsigned denom) { return denom != 0; }); + return all_of(denoms, [](const MPInt &denom) { return denom != 0; }); } // Clear the division representation of the i^th local variable. void clearRepr(unsigned i) { denoms[i] = 0; } // Get the dividend of the `i^th` division. - MutableArrayRef getDividend(unsigned i) { - return dividends.getRow(i); - } - ArrayRef getDividend(unsigned i) const { - return dividends.getRow(i); - } + MutableArrayRef getDividend(unsigned i) { return dividends.getRow(i); } + ArrayRef getDividend(unsigned i) const { return dividends.getRow(i); } // Get the `i^th` denominator. - unsigned &getDenom(unsigned i) { return denoms[i]; } - unsigned getDenom(unsigned i) const { return denoms[i]; } + MPInt &getDenom(unsigned i) { return denoms[i]; } + MPInt getDenom(unsigned i) const { return denoms[i]; } - ArrayRef getDenoms() const { return denoms; } + ArrayRef getDenoms() const { return denoms; } - void setDividend(unsigned i, ArrayRef dividend) { + void setDividend(unsigned i, ArrayRef dividend) { dividends.setRow(i, dividend); } @@ -176,7 +173,8 @@ /// Denominators of each division. If a denominator of a division is `0`, the /// division variable is considered to not have a division representation. - SmallVector denoms; + /// Otherwise, the denominator is positive. 
+ SmallVector denoms; }; /// If `q` is defined to be equal to `expr floordiv d`, this equivalent to @@ -193,10 +191,13 @@ /// /// The coefficient of `q` in `dividend` must be zero, as it is not allowed for /// local variable to be a floor division of an expression involving itself. -SmallVector getDivUpperBound(ArrayRef dividend, - int64_t divisor, unsigned localVarIdx); -SmallVector getDivLowerBound(ArrayRef dividend, - int64_t divisor, unsigned localVarIdx); +/// The divisor must be positive. +SmallVector getDivUpperBound(ArrayRef dividend, + const MPInt &divisor, + unsigned localVarIdx); +SmallVector getDivLowerBound(ArrayRef dividend, + const MPInt &divisor, + unsigned localVarIdx); llvm::SmallBitVector getSubrangeBitVector(unsigned len, unsigned setOffset, unsigned numSet); @@ -209,14 +210,22 @@ SmallVector getMPIntVec(ArrayRef range); /// Return the given array as an array of int64_t. SmallVector getInt64Vec(ArrayRef range); + /// Returns the `MaybeLocalRepr` struct which contains the indices of the /// constraints that can be expressed as a floordiv of an affine function. If -/// the representation could be computed, `dividend` and `denominator` are set. -/// If the representation could not be computed, the kind attribute in -/// `MaybeLocalRepr` is set to None. +/// the representation could be computed, `dividend` and `divisor` are set, +/// in which case, denominator will be positive. If the representation could +/// not be computed, the kind attribute in `MaybeLocalRepr` is set to None. MaybeLocalRepr computeSingleVarRepr(const IntegerRelation &cst, ArrayRef foundRepr, unsigned pos, - MutableArrayRef dividend, + MutableArrayRef dividend, + MPInt &divisor); + +/// The following overload using int64_t is required for a callsite in +/// AffineStructures.h. 
+MaybeLocalRepr computeSingleVarRepr(const IntegerRelation &cst, + ArrayRef foundRepr, unsigned pos, + SmallVector ÷nd, unsigned &divisor); /// Given two relations, A and B, add additional local vars to the sets such @@ -235,26 +244,25 @@ llvm::function_ref merge); /// Compute the gcd of the range. -int64_t gcdRange(ArrayRef range); +MPInt gcdRange(ArrayRef range); /// Divide the range by its gcd and return the gcd. -int64_t normalizeRange(MutableArrayRef range); +MPInt normalizeRange(MutableArrayRef range); /// Normalize the given (numerator, denominator) pair by dividing out the /// common factors between them. The numerator here is an affine expression -/// with integer coefficients. -void normalizeDiv(MutableArrayRef num, int64_t &denom); +/// with integer coefficients. The denominator must be positive. +void normalizeDiv(MutableArrayRef num, MPInt &denom); /// Return `coeffs` with all the elements negated. -SmallVector getNegatedCoeffs(ArrayRef coeffs); +SmallVector getNegatedCoeffs(ArrayRef coeffs); /// Return the complement of the given inequality. /// /// The complement of a_1 x_1 + ... + a_n x_ + c >= 0 is /// a_1 x_1 + ... + a_n x_ + c < 0, i.e., -a_1 x_1 - ... - a_n x_ - c - 1 >= 0, /// since all the variables are constrained to be integers. -SmallVector getComplementIneq(ArrayRef ineq); - +SmallVector getComplementIneq(ArrayRef ineq); } // namespace presburger } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h --- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h +++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h @@ -317,7 +317,7 @@ SmallVectorImpl *lb = nullptr, int64_t *lbFloorDivisor = nullptr) const { assert(pos < getRank() && "invalid position"); - return cst.getConstantBoundOnDimSize(pos, lb); + return cst.getConstantBoundOnDimSize64(pos, lb); } /// Returns the size of this MemRefRegion in bytes. 
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -103,10 +103,9 @@ return maybeLexMin; } -MaybeOptimum> -IntegerRelation::findIntegerLexMin() const { +MaybeOptimum> IntegerRelation::findIntegerLexMin() const { assert(getNumSymbolVars() == 0 && "Symbols are not supported!"); - MaybeOptimum> maybeLexMin = + MaybeOptimum> maybeLexMin = LexSimplex(*this).findIntegerLexMin(); if (!maybeLexMin.isBounded()) @@ -123,8 +122,8 @@ return maybeLexMin; } -static bool rangeIsZero(ArrayRef range) { - return llvm::all_of(range, [](int64_t x) { return x == 0; }); +static bool rangeIsZero(ArrayRef range) { + return llvm::all_of(range, [](const MPInt &x) { return x == 0; }); } static void removeConstraintsInvolvingVarRange(IntegerRelation &poly, @@ -271,14 +270,14 @@ return insertVar(kind, pos, num); } -void IntegerRelation::addEquality(ArrayRef eq) { +void IntegerRelation::addEquality(ArrayRef eq) { assert(eq.size() == getNumCols()); unsigned row = equalities.appendExtraRow(); for (unsigned i = 0, e = eq.size(); i < e; ++i) equalities(row, i) = eq[i]; } -void IntegerRelation::addInequality(ArrayRef inEq) { +void IntegerRelation::addInequality(ArrayRef inEq) { assert(inEq.size() == getNumCols()); unsigned row = inequalities.appendExtraRow(); for (unsigned i = 0, e = inEq.size(); i < e; ++i) @@ -443,7 +442,7 @@ return true; } -void IntegerRelation::setAndEliminate(unsigned pos, ArrayRef values) { +void IntegerRelation::setAndEliminate(unsigned pos, ArrayRef values) { if (values.empty()) return; assert(pos + values.size() <= getNumVars() && @@ -469,7 +468,7 @@ bool IntegerRelation::findConstraintWithNonZeroAt(unsigned colIdx, bool isEq, unsigned *rowIdx) const { assert(colIdx < getNumCols() && "position out of bounds"); - auto at = [&](unsigned rowIdx) -> int64_t { + auto at = [&](unsigned rowIdx) -> MPInt { 
return isEq ? atEq(rowIdx, colIdx) : atIneq(rowIdx, colIdx); }; unsigned e = isEq ? getNumEqualities() : getNumInequalities(); @@ -496,7 +495,7 @@ for (unsigned i = 0, e = numRows; i < e; ++i) { unsigned j; for (j = 0; j < numCols - 1; ++j) { - int64_t v = isEq ? atEq(i, j) : atIneq(i, j); + MPInt v = isEq ? atEq(i, j) : atIneq(i, j); // Skip rows with non-zero variable coefficients. if (v != 0) break; @@ -506,7 +505,7 @@ } // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'. // Example invalid constraints include: '1 == 0' or '-1 >= 0' - int64_t v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1); + MPInt v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1); if ((isEq && v != 0) || (!isEq && v < 0)) { return true; } @@ -528,26 +527,26 @@ // Skip if equality 'rowIdx' if same as 'pivotRow'. if (isEq && rowIdx == pivotRow) return; - auto at = [&](unsigned i, unsigned j) -> int64_t { + auto at = [&](unsigned i, unsigned j) -> MPInt { return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j); }; - int64_t leadCoeff = at(rowIdx, pivotCol); + MPInt leadCoeff = at(rowIdx, pivotCol); // Skip if leading coefficient at 'rowIdx' is already zero. if (leadCoeff == 0) return; - int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol); - int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1; - int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff); - int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff)); - int64_t rowMultiplier = lcm / std::abs(leadCoeff); + MPInt pivotCoeff = constraints->atEq(pivotRow, pivotCol); + int sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1; + MPInt lcm = presburger::lcm(pivotCoeff, leadCoeff); + MPInt pivotMultiplier = sign * (lcm / abs(pivotCoeff)); + MPInt rowMultiplier = lcm / abs(leadCoeff); unsigned numCols = constraints->getNumCols(); for (unsigned j = 0; j < numCols; ++j) { // Skip updating column 'j' if it was just eliminated. 
if (j >= elimColStart && j < pivotCol) continue; - int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) + - rowMultiplier * at(rowIdx, j); + MPInt v = pivotMultiplier * constraints->atEq(pivotRow, j) + + rowMultiplier * at(rowIdx, j); isEq ? constraints->atEq(rowIdx, j) = v : constraints->atIneq(rowIdx, j) = v; } @@ -651,16 +650,15 @@ // has an integer solution iff: // // GCD of c_1, c_2, ..., c_n divides c_0. -// bool IntegerRelation::isEmptyByGCDTest() const { assert(hasConsistentState()); unsigned numCols = getNumCols(); for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { - uint64_t gcd = std::abs(atEq(i, 0)); + MPInt gcd = abs(atEq(i, 0)); for (unsigned j = 1; j < numCols - 1; ++j) { - gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atEq(i, j))); + gcd = presburger::gcd(gcd, abs(atEq(i, j))); } - int64_t v = std::abs(atEq(i, numCols - 1)); + MPInt v = abs(atEq(i, numCols - 1)); if (gcd > 0 && (v % gcd != 0)) { return true; } @@ -763,7 +761,7 @@ /// /// Concatenating the samples from B and C gives a sample v in S*T, so the /// returned sample T*v is a sample in S. -Optional> IntegerRelation::findIntegerSample() const { +Optional> IntegerRelation::findIntegerSample() const { // First, try the GCD test heuristic. if (isEmptyByGCDTest()) return {}; @@ -802,7 +800,7 @@ boundedSet.removeVarRange(numBoundedDims, boundedSet.getNumVars()); // 3) Try to obtain a sample from the bounded set. - Optional> boundedSample = + Optional> boundedSample = Simplex(boundedSet).findIntegerSample(); if (!boundedSample) return {}; @@ -841,7 +839,7 @@ // amount for the shrunken cone. 
for (unsigned i = 0, e = cone.getNumInequalities(); i < e; ++i) { for (unsigned j = 0; j < cone.getNumVars(); ++j) { - int64_t coeff = cone.atIneq(i, j); + MPInt coeff = cone.atIneq(i, j); if (coeff < 0) cone.atIneq(i, cone.getNumVars()) += coeff; } @@ -858,10 +856,10 @@ SmallVector shrunkenConeSample = *shrunkenConeSimplex.getRationalSample(); - SmallVector coneSample(llvm::map_range(shrunkenConeSample, ceil)); + SmallVector coneSample(llvm::map_range(shrunkenConeSample, ceil)); // 6) Return transform * concat(boundedSample, coneSample). - SmallVector &sample = *boundedSample; + SmallVector &sample = *boundedSample; sample.append(coneSample.begin(), coneSample.end()); return transform.postMultiplyWithColumn(sample); } @@ -869,10 +867,10 @@ /// Helper to evaluate an affine expression at a point. /// The expression is a list of coefficients for the dimensions followed by the /// constant term. -static int64_t valueAt(ArrayRef expr, ArrayRef point) { +static MPInt valueAt(ArrayRef expr, ArrayRef point) { assert(expr.size() == 1 + point.size() && "Dimensionalities of point and expression don't match!"); - int64_t value = expr.back(); + MPInt value = expr.back(); for (unsigned i = 0; i < point.size(); ++i) value += expr[i] * point[i]; return value; @@ -881,7 +879,7 @@ /// A point satisfies an equality iff the value of the equality at the /// expression is zero, and it satisfies an inequality iff the value of the /// inequality at that point is non-negative. -bool IntegerRelation::containsPoint(ArrayRef point) const { +bool IntegerRelation::containsPoint(ArrayRef point) const { for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { if (valueAt(getEquality(i), point) != 0) return false; @@ -901,8 +899,8 @@ /// compute the values of the locals that have division representations and /// only use the integer emptiness check for the locals that don't have this. /// Handling this correctly requires ordering the divs, though. 
-Optional> -IntegerRelation::containsPointNoLocal(ArrayRef point) const { +Optional> +IntegerRelation::containsPointNoLocal(ArrayRef point) const { assert(point.size() == getNumVars() - getNumLocalVars() && "Point should contain all vars except locals!"); assert(getVarKindOffset(VarKind::Local) == getNumVars() - getNumLocalVars() && @@ -959,9 +957,9 @@ unsigned numCols = getNumCols(); for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { // Normalize the constraint and tighten the constant term by the GCD. - int64_t gcd = inequalities.normalizeRow(i, getNumCols() - 1); + MPInt gcd = inequalities.normalizeRow(i, getNumCols() - 1); if (gcd > 1) - atIneq(i, numCols - 1) = mlir::floorDiv(atIneq(i, numCols - 1), gcd); + atIneq(i, numCols - 1) = floorDiv(atIneq(i, numCols - 1), gcd); } } @@ -1080,14 +1078,14 @@ equalities.resizeVertically(pos); } -Optional IntegerRelation::computeVolume() const { +Optional IntegerRelation::computeVolume() const { assert(getNumSymbolVars() == 0 && "Symbols are not yet supported!"); Simplex simplex(*this); // If the polytope is rationally empty, there are certainly no integer // points. if (simplex.isEmpty()) - return 0; + return MPInt(0); // Just find the maximum and minimum integer value of each non-local var // separately, thus finding the number of integer values each such var can @@ -1103,12 +1101,12 @@ // // If there is no such empty dimension, if any dimension is unbounded we // just return the result as unbounded. - uint64_t count = 1; - SmallVector dim(getNumVars() + 1); + MPInt count(1); + SmallVector dim(getNumVars() + 1); bool hasUnboundedVar = false; for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; ++i) { dim[i] = 1; - MaybeOptimum min, max; + MaybeOptimum min, max; std::tie(min, max) = simplex.computeIntegerBounds(dim); dim[i] = 0; @@ -1125,13 +1123,13 @@ // In this case there are no valid integer points and the volume is // definitely zero. 
if (min.getBoundedOptimum() > max.getBoundedOptimum()) - return 0; + return MPInt(0); count *= (*max - *min + 1); } if (count == 0) - return 0; + return MPInt(0); if (hasUnboundedVar) return {}; return count; @@ -1223,7 +1221,7 @@ for (i = 0, e = getNumEqualities(); i < e; ++i) { // Find a local variable to eliminate using ith equality. for (j = getNumDimAndSymbolVars(), f = getNumVars(); j < f; ++j) - if (std::abs(atEq(i, j)) == 1) + if (abs(atEq(i, j)) == 1) break; // Local variable can be eliminated using ith equality. @@ -1281,7 +1279,8 @@ removeVarRange(srcKind, varStart, varLimit); } -void IntegerRelation::addBound(BoundType type, unsigned pos, int64_t value) { +void IntegerRelation::addBound(BoundType type, unsigned pos, + const MPInt &value) { assert(pos < getNumCols()); if (type == BoundType::EQ) { unsigned row = equalities.appendExtraRow(); @@ -1295,8 +1294,8 @@ } } -void IntegerRelation::addBound(BoundType type, ArrayRef expr, - int64_t value) { +void IntegerRelation::addBound(BoundType type, ArrayRef expr, + const MPInt &value) { assert(type != BoundType::EQ && "EQ not implemented"); assert(expr.size() == getNumCols()); unsigned row = inequalities.appendExtraRow(); @@ -1311,15 +1310,15 @@ /// respect to a positive constant 'divisor'. Two constraints are added to the /// system to capture equivalence with the floordiv. /// q = expr floordiv c <=> c*q <= expr <= c*q + c - 1. 
-void IntegerRelation::addLocalFloorDiv(ArrayRef dividend, - int64_t divisor) { +void IntegerRelation::addLocalFloorDiv(ArrayRef dividend, + const MPInt &divisor) { assert(dividend.size() == getNumCols() && "incorrect dividend size"); assert(divisor > 0 && "positive divisor expected"); appendVar(VarKind::Local); - SmallVector dividendCopy(dividend.begin(), dividend.end()); - dividendCopy.insert(dividendCopy.end() - 1, 0); + SmallVector dividendCopy(dividend.begin(), dividend.end()); + dividendCopy.insert(dividendCopy.end() - 1, MPInt(0)); addInequality( getDivLowerBound(dividendCopy, divisor, dividendCopy.size() - 2)); addInequality( @@ -1335,7 +1334,7 @@ bool symbolic = false) { assert(pos < cst.getNumVars() && "invalid position"); for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) { - int64_t v = cst.atEq(r, pos); + MPInt v = cst.atEq(r, pos); if (v * v != 1) continue; unsigned c; @@ -1364,7 +1363,7 @@ // atEq(rowIdx, pos) is either -1 or 1. assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1); - int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos); + MPInt constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos); setAndEliminate(pos, constVal); return success(); } @@ -1390,10 +1389,9 @@ // s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16. // s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb = // ceil(s0 - 7 / 8) = floor(s0 / 8)). 
-Optional IntegerRelation::getConstantBoundOnDimSize( - unsigned pos, SmallVectorImpl *lb, int64_t *boundFloorDivisor, - SmallVectorImpl *ub, unsigned *minLbPos, - unsigned *minUbPos) const { +Optional IntegerRelation::getConstantBoundOnDimSize( + unsigned pos, SmallVectorImpl *lb, MPInt *boundFloorDivisor, + SmallVectorImpl *ub, unsigned *minLbPos, unsigned *minUbPos) const { assert(pos < getNumDimVars() && "Invalid variable position"); // Find an equality for 'pos'^th variable that equates it to some function @@ -1405,7 +1403,7 @@ // TODO: this can be handled in the future by using the explicit // representation of the local vars. if (!std::all_of(eq.begin() + getNumDimAndSymbolVars(), eq.end() - 1, - [](int64_t coeff) { return coeff == 0; })) + [](const MPInt &coeff) { return coeff == 0; })) return None; // This variable can only take a single value. @@ -1415,7 +1413,7 @@ if (ub) ub->resize(getNumSymbolVars() + 1); for (unsigned c = 0, f = getNumSymbolVars() + 1; c < f; c++) { - int64_t v = atEq(eqPos, pos); + MPInt v = atEq(eqPos, pos); // atEq(eqRow, pos) is either -1 or 1. assert(v * v == 1); (*lb)[c] = v < 0 ? atEq(eqPos, getNumDimVars() + c) / -v @@ -1432,7 +1430,7 @@ *minLbPos = eqPos; if (minUbPos) *minUbPos = eqPos; - return 1; + return MPInt(1); } // Check if the variable appears at all in any of the inequalities. @@ -1456,7 +1454,7 @@ /*eqIndices=*/nullptr, /*offset=*/0, /*num=*/getNumDimVars()); - Optional minDiff = None; + Optional minDiff = None; unsigned minLbPosition = 0, minUbPosition = 0; for (auto ubPos : ubIndices) { for (auto lbPos : lbIndices) { @@ -1473,11 +1471,11 @@ } if (j < getNumCols() - 1) continue; - int64_t diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) + - atIneq(lbPos, getNumCols() - 1) + 1, - atIneq(lbPos, pos)); + MPInt diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) + + atIneq(lbPos, getNumCols() - 1) + 1, + atIneq(lbPos, pos)); // This bound is non-negative by definition. 
- diff = std::max(diff, 0); + diff = std::max(diff, MPInt(0)); if (minDiff == None || diff < minDiff) { minDiff = diff; minLbPosition = lbPos; @@ -1517,7 +1515,7 @@ } template -Optional +Optional IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) { assert(pos < getNumVars() && "invalid position"); // Project to 'pos'. @@ -1539,7 +1537,7 @@ // If it doesn't, there isn't a bound on it. return None; - Optional minOrMaxConst = None; + Optional minOrMaxConst = None; // Take the max across all const lower bounds (or min across all constant // upper bounds). @@ -1560,9 +1558,9 @@ // Not a constant bound. continue; - int64_t boundConst = - isLower ? mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0)) - : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0)); + MPInt boundConst = + isLower ? ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0)) + : floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0)); if (isLower) { if (minOrMaxConst == None || boundConst > minOrMaxConst) minOrMaxConst = boundConst; @@ -1574,8 +1572,8 @@ return minOrMaxConst; } -Optional IntegerRelation::getConstantBound(BoundType type, - unsigned pos) const { +Optional IntegerRelation::getConstantBound(BoundType type, + unsigned pos) const { if (type == BoundType::LB) return IntegerRelation(*this) .computeConstantLowerOrUpperBound(pos); @@ -1584,13 +1582,13 @@ .computeConstantLowerOrUpperBound(pos); assert(type == BoundType::EQ && "expected EQ"); - Optional lb = + Optional lb = IntegerRelation(*this).computeConstantLowerOrUpperBound( pos); - Optional ub = + Optional ub = IntegerRelation(*this) .computeConstantLowerOrUpperBound(pos); - return (lb && ub && *lb == *ub) ? Optional(*ub) : None; + return (lb && ub && *lb == *ub) ? Optional(*ub) : None; } // A simple (naive and conservative) check for hyper-rectangularity. @@ -1631,10 +1629,10 @@ // A map used to detect redundancy stemming from constraints that only differ // in their constant term. The value stored is // for a given row. 
- SmallDenseMap, std::pair> + SmallDenseMap, std::pair> rowsWithoutConstTerm; // To unique rows. - SmallDenseSet, 8> rowSet; + SmallDenseSet, 8> rowSet; // Check if constraint is of the form >= 0. auto isTriviallyValid = [&](unsigned r) -> bool { @@ -1648,8 +1646,8 @@ // Detect and mark redundant constraints. SmallVector redunIneq(getNumInequalities(), false); for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { - int64_t *rowStart = &inequalities(r, 0); - auto row = ArrayRef(rowStart, getNumCols()); + MPInt *rowStart = &inequalities(r, 0); + auto row = ArrayRef(rowStart, getNumCols()); if (isTriviallyValid(r) || !rowSet.insert(row).second) { redunIneq[r] = true; continue; @@ -1659,8 +1657,8 @@ // everything other than the one with the smallest constant term redundant. // (eg: among i - 16j - 5 >= 0, i - 16j - 1 >=0, i - 16j - 7 >= 0, the // former two are redundant). - int64_t constTerm = atIneq(r, getNumCols() - 1); - auto rowWithoutConstTerm = ArrayRef(rowStart, getNumCols() - 1); + MPInt constTerm = atIneq(r, getNumCols() - 1); + auto rowWithoutConstTerm = ArrayRef(rowStart, getNumCols() - 1); const auto &ret = rowsWithoutConstTerm.insert({rowWithoutConstTerm, {r, constTerm}}); if (!ret.second) { @@ -1816,19 +1814,19 @@ // integer exact. for (auto ubPos : ubIndices) { for (auto lbPos : lbIndices) { - SmallVector ineq; + SmallVector ineq; ineq.reserve(newRel.getNumCols()); - int64_t lbCoeff = atIneq(lbPos, pos); + MPInt lbCoeff = atIneq(lbPos, pos); // Note that in the comments above, ubCoeff is the negation of the // coefficient in the canonical form as the view taken here is that of the // term being moved to the other size of '>='. - int64_t ubCoeff = -atIneq(ubPos, pos); + MPInt ubCoeff = -atIneq(ubPos, pos); // TODO: refactor this loop to avoid all branches inside. 
for (unsigned l = 0, e = getNumCols(); l < e; l++) { if (l == pos) continue; assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified"); - int64_t lcm = mlir::lcm(lbCoeff, ubCoeff); + MPInt lcm = presburger::lcm(lbCoeff, ubCoeff); ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) + atIneq(lbPos, l) * (lcm / lbCoeff)); assert(lcm > 0 && "lcm should be positive!"); @@ -1853,7 +1851,7 @@ // Copy over the constraints not involving this variable. for (auto nbPos : nbIndices) { - SmallVector ineq; + SmallVector ineq; ineq.reserve(getNumCols() - 1); for (unsigned l = 0, e = getNumCols(); l < e; l++) { if (l == pos) @@ -1868,7 +1866,7 @@ // Copy over the equalities. for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { - SmallVector eq; + SmallVector eq; eq.reserve(newRel.getNumCols()); for (unsigned l = 0, e = getNumCols(); l < e; l++) { if (l == pos) @@ -1932,7 +1930,7 @@ /// Compares two affine bounds whose coefficients are provided in 'first' and /// 'second'. The last coefficient is the constant term. -static BoundCmpResult compareBounds(ArrayRef a, ArrayRef b) { +static BoundCmpResult compareBounds(ArrayRef a, ArrayRef b) { assert(a.size() == b.size()); // For the bounds to be comparable, their corresponding variable @@ -1984,20 +1982,20 @@ IntegerRelation commonCst(PresburgerSpace::getRelationSpace()); getCommonConstraints(*this, otherCst, commonCst); - std::vector> boundingLbs; - std::vector> boundingUbs; + std::vector> boundingLbs; + std::vector> boundingUbs; boundingLbs.reserve(2 * getNumDimVars()); boundingUbs.reserve(2 * getNumDimVars()); // To hold lower and upper bounds for each dimension. - SmallVector lb, otherLb, ub, otherUb; + SmallVector lb, otherLb, ub, otherUb; // To compute min of lower bounds and max of upper bounds for each dimension. 
- SmallVector minLb(getNumSymbolVars() + 1); - SmallVector maxUb(getNumSymbolVars() + 1); + SmallVector minLb(getNumSymbolVars() + 1); + SmallVector maxUb(getNumSymbolVars() + 1); // To compute final new lower and upper bounds for the union. - SmallVector newLb(getNumCols()), newUb(getNumCols()); + SmallVector newLb(getNumCols()), newUb(getNumCols()); - int64_t lbFloorDivisor, otherLbFloorDivisor; + MPInt lbFloorDivisor, otherLbFloorDivisor; for (unsigned d = 0, e = getNumDimVars(); d < e; ++d) { auto extent = getConstantBoundOnDimSize(d, &lb, &lbFloorDivisor, &ub); if (!extent.has_value()) @@ -2060,7 +2058,7 @@ // Copy over the symbolic part + constant term. std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimVars()); std::transform(newLb.begin() + getNumDimVars(), newLb.end(), - newLb.begin() + getNumDimVars(), std::negate()); + newLb.begin() + getNumDimVars(), std::negate()); std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimVars()); boundingLbs.push_back(newLb); diff --git a/mlir/lib/Analysis/Presburger/LinearTransform.cpp b/mlir/lib/Analysis/Presburger/LinearTransform.cpp --- a/mlir/lib/Analysis/Presburger/LinearTransform.cpp +++ b/mlir/lib/Analysis/Presburger/LinearTransform.cpp @@ -25,7 +25,7 @@ assert(m(row, sourceCol) != 0 && "Cannot divide by zero!"); assert((m(row, sourceCol) > 0 && m(row, targetCol) > 0) && "Operands must be positive!"); - int64_t ratio = m(row, targetCol) / m(row, sourceCol); + MPInt ratio = m(row, targetCol) / m(row, sourceCol); m.addToColumn(sourceCol, targetCol, -ratio); otherMatrix.addToColumn(sourceCol, targetCol, -ratio); } @@ -116,21 +116,21 @@ IntegerRelation result(rel.getSpace()); for (unsigned i = 0, e = rel.getNumEqualities(); i < e; ++i) { - ArrayRef eq = rel.getEquality(i); + ArrayRef eq = rel.getEquality(i); - int64_t c = eq.back(); + const MPInt &c = eq.back(); - SmallVector newEq = preMultiplyWithRow(eq.drop_back()); + SmallVector newEq = preMultiplyWithRow(eq.drop_back()); newEq.push_back(c); 
result.addEquality(newEq); } for (unsigned i = 0, e = rel.getNumInequalities(); i < e; ++i) { - ArrayRef ineq = rel.getInequality(i); + ArrayRef ineq = rel.getInequality(i); - int64_t c = ineq.back(); + const MPInt &c = ineq.back(); - SmallVector newIneq = preMultiplyWithRow(ineq.drop_back()); + SmallVector newIneq = preMultiplyWithRow(ineq.drop_back()); newIneq.push_back(c); result.addInequality(newIneq); } diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp --- a/mlir/lib/Analysis/Presburger/Matrix.cpp +++ b/mlir/lib/Analysis/Presburger/Matrix.cpp @@ -41,7 +41,7 @@ return nRows - 1; } -unsigned Matrix::appendExtraRow(ArrayRef elems) { +unsigned Matrix::appendExtraRow(ArrayRef elems) { assert(elems.size() == nColumns && "elems must match row length!"); unsigned row = appendExtraRow(); for (unsigned col = 0; col < nColumns; ++col) @@ -84,15 +84,15 @@ std::swap(at(row, column), at(row, otherColumn)); } -MutableArrayRef Matrix::getRow(unsigned row) { +MutableArrayRef Matrix::getRow(unsigned row) { return {&data[row * nReservedColumns], nColumns}; } -ArrayRef Matrix::getRow(unsigned row) const { +ArrayRef Matrix::getRow(unsigned row) const { return {&data[row * nReservedColumns], nColumns}; } -void Matrix::setRow(unsigned row, ArrayRef elems) { +void Matrix::setRow(unsigned row, ArrayRef elems) { assert(elems.size() == getNumColumns() && "elems size must match row length!"); for (unsigned i = 0, e = getNumColumns(); i < e; ++i) @@ -115,7 +115,7 @@ for (int ci = nReservedColumns - 1; ci >= 0; --ci) { unsigned r = ri; unsigned c = ci; - int64_t &dest = data[r * nReservedColumns + c]; + MPInt &dest = data[r * nReservedColumns + c]; if (c >= nColumns) { // NOLINT // Out of bounds columns are zero-initialized. NOLINT because clang-tidy // complains about this branch being the same as the c >= pos one. 
@@ -186,12 +186,13 @@ at(targetRow, c) = at(sourceRow, c); } -void Matrix::fillRow(unsigned row, int64_t value) { +void Matrix::fillRow(unsigned row, const MPInt &value) { for (unsigned col = 0; col < nColumns; ++col) at(row, col) = value; } -void Matrix::addToRow(unsigned sourceRow, unsigned targetRow, int64_t scale) { +void Matrix::addToRow(unsigned sourceRow, unsigned targetRow, + const MPInt &scale) { if (scale == 0) return; for (unsigned col = 0; col < nColumns; ++col) @@ -199,7 +200,7 @@ } void Matrix::addToColumn(unsigned sourceColumn, unsigned targetColumn, - int64_t scale) { + const MPInt &scale) { if (scale == 0) return; for (unsigned row = 0, e = getNumRows(); row < e; ++row) @@ -216,31 +217,30 @@ at(row, column) = -at(row, column); } -int64_t Matrix::normalizeRow(unsigned row, unsigned cols) { +MPInt Matrix::normalizeRow(unsigned row, unsigned cols) { return normalizeRange(getRow(row).slice(0, cols)); } -int64_t Matrix::normalizeRow(unsigned row) { +MPInt Matrix::normalizeRow(unsigned row) { return normalizeRow(row, getNumColumns()); } -SmallVector -Matrix::preMultiplyWithRow(ArrayRef rowVec) const { +SmallVector Matrix::preMultiplyWithRow(ArrayRef rowVec) const { assert(rowVec.size() == getNumRows() && "Invalid row vector dimension!"); - SmallVector result(getNumColumns(), 0); + SmallVector result(getNumColumns(), MPInt(0)); for (unsigned col = 0, e = getNumColumns(); col < e; ++col) for (unsigned i = 0, e = getNumRows(); i < e; ++i) result[col] += rowVec[i] * at(i, col); return result; } -SmallVector -Matrix::postMultiplyWithColumn(ArrayRef colVec) const { +SmallVector +Matrix::postMultiplyWithColumn(ArrayRef colVec) const { assert(getNumColumns() == colVec.size() && "Invalid column vector dimension!"); - SmallVector result(getNumRows(), 0); + SmallVector result(getNumRows(), MPInt(0)); for (unsigned row = 0, e = getNumRows(); row < e; row++) for (unsigned i = 0, e = getNumColumns(); i < e; i++) result[row] += at(row, i) * colVec[i]; diff --git 
a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp --- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp +++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp @@ -15,11 +15,11 @@ // Return the result of subtracting the two given vectors pointwise. // The vectors must be of the same size. // e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5]. -static SmallVector subtract(ArrayRef vecA, - ArrayRef vecB) { +static SmallVector subtract(ArrayRef vecA, + ArrayRef vecB) { assert(vecA.size() == vecB.size() && "Cannot subtract vectors of differing lengths!"); - SmallVector result; + SmallVector result; result.reserve(vecA.size()); for (unsigned i = 0, e = vecA.size(); i < e; ++i) result.push_back(vecA[i] - vecB[i]); @@ -33,18 +33,18 @@ return domain; } -Optional> -MultiAffineFunction::valueAt(ArrayRef point) const { +Optional> +MultiAffineFunction::valueAt(ArrayRef point) const { assert(point.size() == domainSet.getNumDimAndSymbolVars() && "Point has incorrect dimensionality!"); - Optional> maybeLocalValues = + Optional> maybeLocalValues = getDomain().containsPointNoLocal(point); if (!maybeLocalValues) return {}; // The point lies in the domain, so we need to compute the output value. - SmallVector pointHomogenous{llvm::to_vector(point)}; + SmallVector pointHomogenous{llvm::to_vector(point)}; // The given point didn't include the values of locals which the output is a // function of; we have computed one possible set of values and use them // here. The function is not allowed to have local vars that take more than @@ -56,18 +56,17 @@ // a 1 appended at the end. We can see that output * v gives the desired // output vector. 
pointHomogenous.emplace_back(1); - SmallVector result = - output.postMultiplyWithColumn(pointHomogenous); + SmallVector result = output.postMultiplyWithColumn(pointHomogenous); assert(result.size() == getNumOutputs()); return result; } -Optional> -PWMAFunction::valueAt(ArrayRef point) const { +Optional> +PWMAFunction::valueAt(ArrayRef point) const { assert(point.size() == getNumInputs() && "Point has incorrect dimensionality!"); for (const MultiAffineFunction &piece : pieces) - if (Optional> output = piece.valueAt(point)) + if (Optional> output = piece.valueAt(point)) return output; return {}; } @@ -318,7 +317,7 @@ for (unsigned level = 0; level < mafA.getNumOutputs(); ++level) { // Create the expression `outA - outB` for this level. - SmallVector subExpr = + SmallVector subExpr = subtract(mafA.getOutputExpr(level), mafB.getOutputExpr(level)); if (lexMin) { @@ -326,13 +325,13 @@ // outA - outB <= -1 // outA <= outB - 1 // outA < outB - levelSet.addBound(IntegerPolyhedron::BoundType::UB, subExpr, -1); + levelSet.addBound(IntegerPolyhedron::BoundType::UB, subExpr, MPInt(-1)); } else { // For lexMax, we add a lower bound of 1: // outA - outB >= 1 // outA > outB + 1 // outA > outB - levelSet.addBound(IntegerPolyhedron::BoundType::LB, subExpr, 1); + levelSet.addBound(IntegerPolyhedron::BoundType::LB, subExpr, MPInt(1)); } // Union the set with the result. diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp --- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp @@ -68,7 +68,7 @@ } /// A point is contained in the union iff any of the parts contain the point. 
-bool PresburgerRelation::containsPoint(ArrayRef point) const { +bool PresburgerRelation::containsPoint(ArrayRef point) const { return llvm::any_of(disjuncts, [&](const IntegerRelation &disjunct) { return (disjunct.containsPointNoLocal(point)); }); @@ -121,15 +121,15 @@ /// /// For every eq `coeffs == 0` there are two possible ineqs to index into. /// The first is coeffs >= 0 and the second is coeffs <= 0. -static SmallVector getIneqCoeffsFromIdx(const IntegerRelation &rel, - unsigned idx) { +static SmallVector getIneqCoeffsFromIdx(const IntegerRelation &rel, + unsigned idx) { assert(idx < rel.getNumInequalities() + 2 * rel.getNumEqualities() && "idx out of bounds!"); if (idx < rel.getNumInequalities()) return llvm::to_vector<8>(rel.getInequality(idx)); idx -= rel.getNumInequalities(); - ArrayRef eqCoeffs = rel.getEquality(idx / 2); + ArrayRef eqCoeffs = rel.getEquality(idx / 2); if (idx % 2 == 0) return llvm::to_vector<8>(eqCoeffs); @@ -389,7 +389,7 @@ // state before adding this complement constraint, and add s_ij to b. simplex.rollback(frame.simplexSnapshot); b.truncate(frame.bCounts); - SmallVector ineq = + SmallVector ineq = getIneqCoeffsFromIdx(frame.sI, *frame.lastIneqProcessed); b.addInequality(ineq); simplex.addInequality(ineq); @@ -407,7 +407,7 @@ frame.simplexSnapshot = simplex.getSnapshot(); unsigned idx = frame.ineqsToProcess.back(); - SmallVector ineq = + SmallVector ineq = getComplementIneq(getIneqCoeffsFromIdx(frame.sI, idx)); b.addInequality(ineq); simplex.addInequality(ineq); @@ -459,10 +459,10 @@ return llvm::all_of(disjuncts, std::mem_fn(&IntegerRelation::isIntegerEmpty)); } -bool PresburgerRelation::findIntegerSample(SmallVectorImpl &sample) { +bool PresburgerRelation::findIntegerSample(SmallVectorImpl &sample) { // A sample exists iff any of the disjuncts contains a sample. 
for (const IntegerRelation &disjunct : disjuncts) { - if (Optional> opt = disjunct.findIntegerSample()) { + if (Optional> opt = disjunct.findIntegerSample()) { sample = std::move(*opt); return true; } @@ -470,13 +470,13 @@ return false; } -Optional PresburgerRelation::computeVolume() const { +Optional PresburgerRelation::computeVolume() const { assert(getNumSymbolVars() == 0 && "Symbols are not yet supported!"); // The sum of the volumes of the disjuncts is a valid overapproximation of the // volume of their union, even if they overlap. - uint64_t result = 0; + MPInt result(0); for (const IntegerRelation &disjunct : disjuncts) { - Optional volume = disjunct.computeVolume(); + Optional volume = disjunct.computeVolume(); if (!volume) return {}; result += *volume; @@ -511,20 +511,20 @@ /// The list of all inversed equalities during typing. This ensures that /// the constraints exist even after the typing function has concluded. - SmallVector, 2> negEqs; + SmallVector, 2> negEqs; /// `redundantIneqsA` is the inequalities of `a` that are redundant for `b` /// (similarly for `cuttingIneqsA`, `redundantIneqsB`, and `cuttingIneqsB`). - SmallVector, 2> redundantIneqsA; - SmallVector, 2> cuttingIneqsA; + SmallVector, 2> redundantIneqsA; + SmallVector, 2> cuttingIneqsA; - SmallVector, 2> redundantIneqsB; - SmallVector, 2> cuttingIneqsB; + SmallVector, 2> redundantIneqsB; + SmallVector, 2> cuttingIneqsB; /// Given a Simplex `simp` and one of its inequalities `ineq`, check /// that the facet of `simp` where `ineq` holds as an equality is contained /// within `a`. - bool isFacetContained(ArrayRef ineq, Simplex &simp); + bool isFacetContained(ArrayRef ineq, Simplex &simp); /// Removes redundant constraints from `disjunct`, adds it to `disjuncts` and /// removes the disjuncts at position `i` and `j`. Updates `simplices` to @@ -548,13 +548,13 @@ /// Types the inequality `ineq` according to its `IneqType` for `simp` into /// `redundantIneqsB` and `cuttingIneqsB`. 
Returns success, if no separate /// inequalities were encountered. Otherwise, returns failure. - LogicalResult typeInequality(ArrayRef ineq, Simplex &simp); + LogicalResult typeInequality(ArrayRef ineq, Simplex &simp); /// Types the equality `eq`, i.e. for `eq` == 0, types both `eq` >= 0 and /// -`eq` >= 0 according to their `IneqType` for `simp` into /// `redundantIneqsB` and `cuttingIneqsB`. Returns success, if no separate /// inequalities were encountered. Otherwise, returns failure. - LogicalResult typeEquality(ArrayRef eq, Simplex &simp); + LogicalResult typeEquality(ArrayRef eq, Simplex &simp); /// Replaces the element at position `i` with the last element and erases /// the last element for both `disjuncts` and `simplices`. @@ -631,10 +631,10 @@ /// Given a Simplex `simp` and one of its inequalities `ineq`, check /// that all inequalities of `cuttingIneqsB` are redundant for the facet of /// `simp` where `ineq` holds as an equality is contained within `a`. -bool SetCoalescer::isFacetContained(ArrayRef ineq, Simplex &simp) { +bool SetCoalescer::isFacetContained(ArrayRef ineq, Simplex &simp) { SimplexRollbackScopeExit scopeExit(simp); simp.addEquality(ineq); - return llvm::all_of(cuttingIneqsB, [&simp](ArrayRef curr) { + return llvm::all_of(cuttingIneqsB, [&simp](ArrayRef curr) { return simp.isRedundantInequality(curr); }); } @@ -696,23 +696,23 @@ /// redundant ones are, so only the cutting ones remain to be checked. 
Simplex &simp = simplices[i]; IntegerRelation &disjunct = disjuncts[i]; - if (llvm::any_of(cuttingIneqsA, [this, &simp](ArrayRef curr) { + if (llvm::any_of(cuttingIneqsA, [this, &simp](ArrayRef curr) { return !isFacetContained(curr, simp); })) return failure(); IntegerRelation newSet(disjunct.getSpace()); - for (ArrayRef curr : redundantIneqsA) + for (ArrayRef curr : redundantIneqsA) newSet.addInequality(curr); - for (ArrayRef curr : redundantIneqsB) + for (ArrayRef curr : redundantIneqsB) newSet.addInequality(curr); addCoalescedDisjunct(i, j, newSet); return success(); } -LogicalResult SetCoalescer::typeInequality(ArrayRef ineq, +LogicalResult SetCoalescer::typeInequality(ArrayRef ineq, Simplex &simp) { Simplex::IneqType type = simp.findIneqType(ineq); if (type == Simplex::IneqType::Redundant) @@ -724,11 +724,11 @@ return success(); } -LogicalResult SetCoalescer::typeEquality(ArrayRef eq, Simplex &simp) { +LogicalResult SetCoalescer::typeEquality(ArrayRef eq, Simplex &simp) { if (typeInequality(eq, simp).failed()) return failure(); negEqs.push_back(getNegatedCoeffs(eq)); - ArrayRef inv(negEqs.back()); + ArrayRef inv(negEqs.back()); if (typeInequality(inv, simp).failed()) return failure(); return success(); diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp --- a/mlir/lib/Analysis/Presburger/Simplex.cpp +++ b/mlir/lib/Analysis/Presburger/Simplex.cpp @@ -21,10 +21,10 @@ // Return a + scale*b; LLVM_ATTRIBUTE_UNUSED -static SmallVector -scaleAndAddForAssert(ArrayRef a, int64_t scale, ArrayRef b) { +static SmallVector +scaleAndAddForAssert(ArrayRef a, const MPInt &scale, ArrayRef b) { assert(a.size() == b.size()); - SmallVector res; + SmallVector res; res.reserve(a.size()); for (unsigned i = 0, e = a.size(); i < e; ++i) res.push_back(a[i] + scale * b[i]); @@ -100,7 +100,7 @@ /// Add a new row to the tableau corresponding to the given constant term and /// list of coefficients. 
The coefficients are specified as a vector of /// (variable index, coefficient) pairs. -unsigned SimplexBase::addRow(ArrayRef coeffs, bool makeRestricted) { +unsigned SimplexBase::addRow(ArrayRef coeffs, bool makeRestricted) { assert(coeffs.size() == var.size() + 1 && "Incorrect number of coefficients!"); assert(var.size() + getNumFixedCols() == getNumColumns() && @@ -123,7 +123,7 @@ // // Symbols don't use the big M parameter since they do not get lex // optimized. - int64_t bigMCoeff = 0; + MPInt bigMCoeff(0); for (unsigned i = 0; i < coeffs.size() - 1; ++i) if (!var[i].isSymbol) bigMCoeff -= coeffs[i]; @@ -149,9 +149,9 @@ // row, scaled by the coefficient for the variable, accounting for the two // rows potentially having different denominators. The new denominator is // the lcm of the two. - int64_t lcm = mlir::lcm(tableau(newRow, 0), tableau(pos, 0)); - int64_t nRowCoeff = lcm / tableau(newRow, 0); - int64_t idxRowCoeff = coeffs[i] * (lcm / tableau(pos, 0)); + MPInt lcm = presburger::lcm(tableau(newRow, 0), tableau(pos, 0)); + MPInt nRowCoeff = lcm / tableau(newRow, 0); + MPInt idxRowCoeff = coeffs[i] * (lcm / tableau(pos, 0)); tableau(newRow, 0) = lcm; for (unsigned col = 1, e = getNumColumns(); col < e; ++col) tableau(newRow, col) = @@ -164,7 +164,7 @@ } namespace { -bool signMatchesDirection(int64_t elem, Direction direction) { +bool signMatchesDirection(const MPInt &elem, Direction direction) { assert(elem != 0 && "elem should not be 0"); return direction == Direction::Up ? elem > 0 : elem < 0; } @@ -260,7 +260,7 @@ /// The constraint is violated when added (it would be useless otherwise) /// so we immediately try to move it to a column. LogicalResult LexSimplexBase::addCut(unsigned row) { - int64_t d = tableau(row, 0); + MPInt d = tableau(row, 0); unsigned cutRow = addZeroRow(/*makeRestricted=*/true); tableau(cutRow, 0) = d; tableau(cutRow, 1) = -mod(-tableau(row, 1), d); // -c%d. 
@@ -284,7 +284,7 @@ return {}; } -MaybeOptimum> LexSimplex::findIntegerLexMin() { +MaybeOptimum> LexSimplex::findIntegerLexMin() { // We first try to make the tableau consistent. if (restoreRationalConsistency().failed()) return OptimumKind::Empty; @@ -315,19 +315,19 @@ llvm::map_range(*sample, std::mem_fn(&Fraction::getAsInteger))); } -bool LexSimplex::isSeparateInequality(ArrayRef coeffs) { +bool LexSimplex::isSeparateInequality(ArrayRef coeffs) { SimplexRollbackScopeExit scopeExit(*this); addInequality(coeffs); return findIntegerLexMin().isEmpty(); } -bool LexSimplex::isRedundantInequality(ArrayRef coeffs) { +bool LexSimplex::isRedundantInequality(ArrayRef coeffs) { return isSeparateInequality(getComplementIneq(coeffs)); } -SmallVector +SmallVector SymbolicLexSimplex::getSymbolicSampleNumerator(unsigned row) const { - SmallVector sample; + SmallVector sample; sample.reserve(nSymbol + 1); for (unsigned col = 3; col < 3 + nSymbol; ++col) sample.push_back(tableau(row, col)); @@ -335,9 +335,9 @@ return sample; } -SmallVector +SmallVector SymbolicLexSimplex::getSymbolicSampleIneq(unsigned row) const { - SmallVector sample = getSymbolicSampleNumerator(row); + SmallVector sample = getSymbolicSampleNumerator(row); // The inequality is equivalent to the GCD-normalized one. 
normalizeRange(sample); return sample; @@ -350,13 +350,14 @@ nSymbol++; } -static bool isRangeDivisibleBy(ArrayRef range, int64_t divisor) { +static bool isRangeDivisibleBy(ArrayRef range, const MPInt &divisor) { assert(divisor > 0 && "divisor must be positive!"); - return llvm::all_of(range, [divisor](int64_t x) { return x % divisor == 0; }); + return llvm::all_of(range, + [divisor](const MPInt &x) { return x % divisor == 0; }); } bool SymbolicLexSimplex::isSymbolicSampleIntegral(unsigned row) const { - int64_t denom = tableau(row, 0); + MPInt denom = tableau(row, 0); return tableau(row, 1) % denom == 0 && isRangeDivisibleBy(tableau.getRow(row).slice(3, nSymbol), denom); } @@ -395,7 +396,7 @@ /// This constraint is violated when added so we immediately try to move it to a /// column. LogicalResult SymbolicLexSimplex::addSymbolicCut(unsigned row) { - int64_t d = tableau(row, 0); + MPInt d = tableau(row, 0); if (isRangeDivisibleBy(tableau.getRow(row).slice(3, nSymbol), d)) { // The coefficients of symbols in the symbol numerator are divisible // by the denominator, so we can add the constraint directly, @@ -404,9 +405,9 @@ } // Construct the division variable `q = ((-c%d) + sum_i (-a_i%d)s_i)/d`. - SmallVector divCoeffs; + SmallVector divCoeffs; divCoeffs.reserve(nSymbol + 1); - int64_t divDenom = d; + MPInt divDenom = d; for (unsigned col = 3; col < 3 + nSymbol; ++col) divCoeffs.push_back(mod(-tableau(row, col), divDenom)); // (-a_i%d)s_i divCoeffs.push_back(mod(-tableau(row, 1), divDenom)); // -c%d. 
@@ -447,7 +448,7 @@ return; } - int64_t denom = tableau(u.pos, 0); + MPInt denom = tableau(u.pos, 0); if (tableau(u.pos, 2) < denom) { // M + u has a sample value of fM + something, where f < 1, so // u = (f - 1)M + something, which has a negative coefficient for M, @@ -458,8 +459,8 @@ assert(tableau(u.pos, 2) == denom && "Coefficient of M should not be greater than 1!"); - SmallVector sample = getSymbolicSampleNumerator(u.pos); - for (int64_t &elem : sample) { + SmallVector sample = getSymbolicSampleNumerator(u.pos); + for (MPInt &elem : sample) { assert(elem % denom == 0 && "coefficients must be integral!"); elem /= denom; } @@ -546,7 +547,7 @@ continue; } - SmallVector symbolicSample; + SmallVector symbolicSample; unsigned splitRow = 0; for (unsigned e = getNumRows(); splitRow < e; ++splitRow) { if (tableau(splitRow, 2) > 0) @@ -631,7 +632,7 @@ // was negative. assert(u.orientation == Orientation::Row && "The split row should have been returned to row orientation!"); - SmallVector splitIneq = + SmallVector splitIneq = getComplementIneq(getSymbolicSampleIneq(u.pos)); normalizeRange(splitIneq); if (moveRowUnknownToColumn(u.pos).failed()) { @@ -807,7 +808,7 @@ // all possible values of the symbols. auto getSampleChangeCoeffForVar = [this, row](unsigned col, const Unknown &u) -> Fraction { - int64_t a = tableau(row, col); + MPInt a = tableau(row, col); if (u.orientation == Orientation::Column) { // Pivot column case. if (u.pos == col) @@ -822,7 +823,7 @@ return {1, 1}; // Non-pivot row case. - int64_t c = tableau(u.pos, col); + MPInt c = tableau(u.pos, col); return {c, a}; }; @@ -856,7 +857,7 @@ Direction direction) const { Optional col; for (unsigned j = 2, e = getNumColumns(); j < e; ++j) { - int64_t elem = tableau(row, j); + MPInt elem = tableau(row, j); if (elem == 0) continue; @@ -1005,18 +1006,18 @@ // retConst being used uninitialized in the initialization of `diff` below. 
In // reality, these are always initialized when that line is reached since these // are set whenever retRow is set. - int64_t retElem = 0, retConst = 0; + MPInt retElem, retConst; for (unsigned row = nRedundant, e = getNumRows(); row < e; ++row) { if (skipRow && row == *skipRow) continue; - int64_t elem = tableau(row, col); + MPInt elem = tableau(row, col); if (elem == 0) continue; if (!unknownFromRow(row).restricted) continue; if (signMatchesDirection(elem, direction)) continue; - int64_t constTerm = tableau(row, 1); + MPInt constTerm = tableau(row, 1); if (!retRow) { retRow = row; @@ -1025,7 +1026,7 @@ continue; } - int64_t diff = retConst * elem - constTerm * retElem; + MPInt diff = retConst * elem - constTerm * retElem; if ((diff == 0 && rowUnknown[row] < rowUnknown[*retRow]) || (diff != 0 && !signMatchesDirection(diff, direction))) { retRow = row; @@ -1076,7 +1077,7 @@ /// We add the inequality and mark it as restricted. We then try to make its /// sample value non-negative. If this is not possible, the tableau has become /// empty and we mark it as such. -void Simplex::addInequality(ArrayRef coeffs) { +void Simplex::addInequality(ArrayRef coeffs) { unsigned conIndex = addRow(coeffs, /*makeRestricted=*/true); LogicalResult result = restoreRow(con[conIndex]); if (failed(result)) @@ -1089,10 +1090,10 @@ /// /// We simply add two opposing inequalities, which force the expression to /// be zero. -void SimplexBase::addEquality(ArrayRef coeffs) { +void SimplexBase::addEquality(ArrayRef coeffs) { addInequality(coeffs); - SmallVector negatedCoeffs; - for (int64_t coeff : coeffs) + SmallVector negatedCoeffs; + for (const MPInt &coeff : coeffs) negatedCoeffs.emplace_back(-coeff); addInequality(negatedCoeffs); } @@ -1267,17 +1268,18 @@ /// /// This constrains the remainder `coeffs - denom*q` to be in the /// range `[0, denom - 1]`, which fixes the integer value of the quotient `q`. 
-void SimplexBase::addDivisionVariable(ArrayRef coeffs, int64_t denom) { - assert(denom != 0 && "Cannot divide by zero!\n"); +void SimplexBase::addDivisionVariable(ArrayRef coeffs, + const MPInt &denom) { + assert(denom > 0 && "Denominator must be positive!"); appendVariable(); - SmallVector ineq(coeffs.begin(), coeffs.end()); - int64_t constTerm = ineq.back(); + SmallVector ineq(coeffs.begin(), coeffs.end()); + MPInt constTerm = ineq.back(); ineq.back() = -denom; ineq.push_back(constTerm); addInequality(ineq); - for (int64_t &coeff : ineq) + for (MPInt &coeff : ineq) coeff = -coeff; ineq.back() += denom - 1; addInequality(ineq); @@ -1327,7 +1329,7 @@ /// Compute the optimum of the specified expression in the specified direction, /// or None if it is unbounded. MaybeOptimum Simplex::computeOptimum(Direction direction, - ArrayRef coeffs) { + ArrayRef coeffs) { if (empty) return OptimumKind::Empty; @@ -1436,7 +1438,7 @@ if (empty) return false; - SmallVector dir(var.size() + 1); + SmallVector dir(var.size() + 1); for (unsigned i = 0; i < var.size(); ++i) { dir[i] = 1; @@ -1546,14 +1548,14 @@ } else { // If the variable is in row position, its sample value is the // entry in the constant column divided by the denominator. - int64_t denom = tableau(u.pos, 0); + MPInt denom = tableau(u.pos, 0); sample.emplace_back(tableau(u.pos, 1), denom); } } return sample; } -void LexSimplexBase::addInequality(ArrayRef coeffs) { +void LexSimplexBase::addInequality(ArrayRef coeffs) { addRow(coeffs, /*makeRestricted=*/true); } @@ -1578,7 +1580,7 @@ // If the variable is in row position, its sample value is the // entry in the constant column divided by the denominator. 
- int64_t denom = tableau(u.pos, 0); + MPInt denom = tableau(u.pos, 0); if (usingBigM) if (tableau(u.pos, 2) != denom) return OptimumKind::Unbounded; @@ -1587,14 +1589,14 @@ return sample; } -Optional> Simplex::getSamplePointIfIntegral() const { +Optional> Simplex::getSamplePointIfIntegral() const { // If the tableau is empty, no sample point exists. if (empty) return {}; // The value will always exist since the Simplex is non-empty. SmallVector rationalSample = *getRationalSample(); - SmallVector integerSample; + SmallVector integerSample; integerSample.reserve(var.size()); for (const Fraction &coord : rationalSample) { // If the sample is non-integral, return None. @@ -1626,14 +1628,14 @@ /// Add an equality dotProduct(dir, x - y) == 0. /// First pushes a snapshot for the current simplex state to the stack so /// that this can be rolled back later. - void addEqualityForDirection(ArrayRef dir) { - assert(llvm::any_of(dir, [](int64_t x) { return x != 0; }) && + void addEqualityForDirection(ArrayRef dir) { + assert(llvm::any_of(dir, [](const MPInt &x) { return x != 0; }) && "Direction passed is the zero vector!"); snapshotStack.push_back(simplex.getSnapshot()); simplex.addEquality(getCoeffsForDirection(dir)); } /// Compute max(dotProduct(dir, x - y)). - Fraction computeWidth(ArrayRef dir) { + Fraction computeWidth(ArrayRef dir) { MaybeOptimum maybeWidth = simplex.computeOptimum(Direction::Up, getCoeffsForDirection(dir)); assert(maybeWidth.isBounded() && "Width should be bounded!"); @@ -1642,9 +1644,9 @@ /// Compute max(dotProduct(dir, x - y)) and save the dual variables for only /// the direction equalities to `dual`. 
- Fraction computeWidthAndDuals(ArrayRef dir, - SmallVectorImpl &dual, - int64_t &dualDenom) { + Fraction computeWidthAndDuals(ArrayRef dir, + SmallVectorImpl &dual, + MPInt &dualDenom) { // We can't just call into computeWidth or computeOptimum since we need to // access the state of the tableau after computing the optimum, and these // functions rollback the insertion of the objective function into the @@ -1712,12 +1714,12 @@ /// i.e., dir_1 * x_1 + dir_2 * x_2 + ... + dir_n * x_n /// - dir_1 * y_1 - dir_2 * y_2 - ... - dir_n * y_n, /// where n is the dimension of the original polytope. - SmallVector getCoeffsForDirection(ArrayRef dir) { + SmallVector getCoeffsForDirection(ArrayRef dir) { assert(2 * dir.size() == simplex.getNumVariables() && "Direction vector has wrong dimensionality"); - SmallVector coeffs(dir.begin(), dir.end()); + SmallVector coeffs(dir.begin(), dir.end()); coeffs.reserve(2 * dir.size()); - for (int64_t coeff : dir) + for (const MPInt &coeff : dir) coeffs.push_back(-coeff); coeffs.emplace_back(0); // constant term return coeffs; @@ -1794,8 +1796,8 @@ GBRSimplex gbrSimplex(*this); SmallVector width; - SmallVector dual; - int64_t dualDenom; + SmallVector dual; + MPInt dualDenom; // Finds the value of u that minimizes width_i(b_{i+1} + u*b_i), caches the // duals from this computation, sets b_{i+1} to b_{i+1} + u*b_i, and returns @@ -1818,11 +1820,11 @@ auto updateBasisWithUAndGetFCandidate = [&](unsigned i) -> Fraction { assert(i < level + dual.size() && "dual_i is not known!"); - int64_t u = floorDiv(dual[i - level], dualDenom); + MPInt u = floorDiv(dual[i - level], dualDenom); basis.addToRow(i, i + 1, u); if (dual[i - level] % dualDenom != 0) { - SmallVector candidateDual[2]; - int64_t candidateDualDenom[2]; + SmallVector candidateDual[2]; + MPInt candidateDualDenom[2]; Fraction widthI[2]; // Initially u is floor(dual) and basis reflects this. @@ -1849,11 +1851,13 @@ // Check the value at u - 1. 
assert(gbrSimplex.computeWidth(scaleAndAddForAssert( - basis.getRow(i + 1), -1, basis.getRow(i))) >= widthI[j] && + basis.getRow(i + 1), MPInt(-1), basis.getRow(i))) >= + widthI[j] && "Computed u value does not minimize the width!"); // Check the value at u + 1. assert(gbrSimplex.computeWidth(scaleAndAddForAssert( - basis.getRow(i + 1), +1, basis.getRow(i))) >= widthI[j] && + basis.getRow(i + 1), MPInt(+1), basis.getRow(i))) >= + widthI[j] && "Computed u value does not minimize the width!"); dual = std::move(candidateDual[j]); @@ -1953,7 +1957,7 @@ /// /// To avoid potentially arbitrarily large recursion depths leading to stack /// overflows, this algorithm is implemented iteratively. -Optional> Simplex::findIntegerSample() { +Optional> Simplex::findIntegerSample() { if (empty) return {}; @@ -1964,9 +1968,9 @@ // The snapshot just before constraining a direction to a value at each level. SmallVector snapshotStack; // The maximum value in the range of the direction for each level. - SmallVector upperBoundStack; + SmallVector upperBoundStack; // The next value to try constraining the basis vector to at each level. - SmallVector nextValueStack; + SmallVector nextValueStack; snapshotStack.reserve(basis.getNumRows()); upperBoundStack.reserve(basis.getNumRows()); @@ -1986,11 +1990,11 @@ // just come down a level ("recursed"). Find the lower and upper bounds. // If there is more than one integer point in the range, perform // generalized basis reduction. - SmallVector basisCoeffs = + SmallVector basisCoeffs = llvm::to_vector<8>(basis.getRow(level)); basisCoeffs.emplace_back(0); - MaybeOptimum minRoundedUp, maxRoundedDown; + MaybeOptimum minRoundedUp, maxRoundedDown; std::tie(minRoundedUp, maxRoundedDown) = computeIntegerBounds(basisCoeffs); @@ -2040,7 +2044,7 @@ // to the snapshot of the starting state at this level. 
(in the "recursed" // case this has no effect) rollback(snapshotStack.back()); - int64_t nextValue = nextValueStack.back(); + MPInt nextValue = nextValueStack.back(); ++nextValueStack.back(); if (nextValue > upperBoundStack.back()) { // We have exhausted the range and found no solution. Pop the stack and @@ -2053,8 +2057,8 @@ } // Try the next value in the range and "recurse" into the next level. - SmallVector basisCoeffs(basis.getRow(level).begin(), - basis.getRow(level).end()); + SmallVector basisCoeffs(basis.getRow(level).begin(), + basis.getRow(level).end()); basisCoeffs.push_back(-nextValue); addEquality(basisCoeffs); level++; @@ -2065,11 +2069,11 @@ /// Compute the minimum and maximum integer values the expression can take. We /// compute each separately. -std::pair, MaybeOptimum> -Simplex::computeIntegerBounds(ArrayRef coeffs) { - MaybeOptimum minRoundedUp( +std::pair, MaybeOptimum> +Simplex::computeIntegerBounds(ArrayRef coeffs) { + MaybeOptimum minRoundedUp( computeOptimum(Simplex::Direction::Down, coeffs).map(ceil)); - MaybeOptimum maxRoundedDown( + MaybeOptimum maxRoundedDown( computeOptimum(Simplex::Direction::Up, coeffs).map(floor)); return {minRoundedUp, maxRoundedDown}; } @@ -2140,7 +2144,7 @@ /// maximum satisfy it. Hence, it is a cut inequality. If both are < 0, no /// points of the polytope satisfy the inequality, which means it is a separate /// inequality. -Simplex::IneqType Simplex::findIneqType(ArrayRef coeffs) { +Simplex::IneqType Simplex::findIneqType(ArrayRef coeffs) { MaybeOptimum minimum = computeOptimum(Direction::Down, coeffs); if (minimum.isBounded() && *minimum >= Fraction(0, 1)) { return IneqType::Redundant; @@ -2155,7 +2159,7 @@ /// Checks whether the type of the inequality with coefficients `coeffs` /// is Redundant. 
-bool Simplex::isRedundantInequality(ArrayRef coeffs) { +bool Simplex::isRedundantInequality(ArrayRef coeffs) { assert(!empty && "It is not meaningful to ask about redundancy in an empty set!"); return findIneqType(coeffs) == IneqType::Redundant; @@ -2165,7 +2169,7 @@ /// the existing constraints. This is redundant when `coeffs` is already /// always zero under the existing constraints. `coeffs` is always zero /// when the minimum and maximum value that `coeffs` can take are both zero. -bool Simplex::isRedundantEquality(ArrayRef coeffs) { +bool Simplex::isRedundantEquality(ArrayRef coeffs) { assert(!empty && "It is not meaningful to ask about redundancy in an empty set!"); MaybeOptimum minimum = computeOptimum(Direction::Down, coeffs); diff --git a/mlir/lib/Analysis/Presburger/Utils.cpp b/mlir/lib/Analysis/Presburger/Utils.cpp --- a/mlir/lib/Analysis/Presburger/Utils.cpp +++ b/mlir/lib/Analysis/Presburger/Utils.cpp @@ -12,6 +12,7 @@ #include "mlir/Analysis/Presburger/Utils.h" #include "mlir/Analysis/Presburger/IntegerRelation.h" +#include "mlir/Analysis/Presburger/MPInt.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Support/MathExtras.h" @@ -20,15 +21,16 @@ /// Normalize a division's `dividend` and the `divisor` by their GCD. For /// example: if the dividend and divisor are [2,0,4] and 4 respectively, -/// they get normalized to [1,0,2] and 2. -static void normalizeDivisionByGCD(MutableArrayRef dividend, - unsigned &divisor) { +/// they get normalized to [1,0,2] and 2. The divisor must be non-negative; +/// it is allowed for the divisor to be zero, but nothing is done in this case. +static void normalizeDivisionByGCD(MutableArrayRef dividend, + MPInt &divisor) { + assert(divisor >= 0 && "divisor must be non-negative!"); if (divisor == 0 || dividend.empty()) return; // We take the absolute value of dividend's coefficients to make sure that // `gcd` is positive.
- int64_t gcd = - llvm::greatestCommonDivisor(std::abs(dividend.front()), int64_t(divisor)); + MPInt gcd = presburger::gcd(abs(dividend.front()), divisor); // The reason for ignoring the constant term is as follows. // For a division: @@ -38,14 +40,14 @@ // Since `{a/m}/d` in the dividend satisfies 0 <= {a/m}/d < 1/d, it will not // influence the result of the floor division and thus, can be ignored. for (size_t i = 1, m = dividend.size() - 1; i < m; i++) { - gcd = llvm::greatestCommonDivisor(std::abs(dividend[i]), gcd); + gcd = presburger::gcd(abs(dividend[i]), gcd); if (gcd == 1) return; } // Normalize the dividend and the denominator. std::transform(dividend.begin(), dividend.end(), dividend.begin(), - [gcd](int64_t &n) { return floorDiv(n, gcd); }); + [gcd](MPInt &n) { return floorDiv(n, gcd); }); divisor /= gcd; } @@ -85,12 +87,11 @@ /// -divisor * var + expr - c >= 0 <-- Upper bound for 'var' /// /// If successful, `expr` is set to dividend of the division and `divisor` is -/// set to the denominator of the division. The final division expression is -/// normalized by GCD. +/// set to the denominator of the division, which will be positive. +/// The final division expression is normalized by GCD. static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos, unsigned ubIneq, unsigned lbIneq, - MutableArrayRef expr, - unsigned &divisor) { + MutableArrayRef expr, MPInt &divisor) { assert(pos <= cst.getNumVars() && "Invalid variable position"); assert(ubIneq <= cst.getNumInequalities() && @@ -98,6 +99,8 @@ assert(lbIneq <= cst.getNumInequalities() && "Invalid upper bound inequality position"); assert(expr.size() == cst.getNumCols() && "Invalid expression size"); + assert(cst.atIneq(lbIneq, pos) > 0 && "lbIneq is not a lower bound!"); + assert(cst.atIneq(ubIneq, pos) < 0 && "ubIneq is not an upper bound!"); // Extract divisor from the lower bound. 
divisor = cst.atIneq(lbIneq, pos); @@ -115,12 +118,12 @@ // Then, check if the constant term is of the proper form. // Due to the form of the upper/lower bound inequalities, the sum of their // constants is `divisor - 1 - c`. From this, we can extract c: - int64_t constantSum = cst.atIneq(lbIneq, cst.getNumCols() - 1) + - cst.atIneq(ubIneq, cst.getNumCols() - 1); - int64_t c = divisor - 1 - constantSum; + MPInt constantSum = cst.atIneq(lbIneq, cst.getNumCols() - 1) + + cst.atIneq(ubIneq, cst.getNumCols() - 1); + MPInt c = divisor - 1 - constantSum; - // Check if `c` satisfies the condition `0 <= c <= divisor - 1`. This also - // implictly checks that `divisor` is positive. + // Check if `c` satisfies the condition `0 <= c <= divisor - 1`. + // This also implicitly checks that `divisor` is positive. if (!(0 <= c && c <= divisor - 1)) // NOLINT return failure(); @@ -152,8 +155,8 @@ /// set to the denominator of the division. The final division expression is /// normalized by GCD. static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos, - unsigned eqInd, MutableArrayRef expr, - unsigned &divisor) { + unsigned eqInd, MutableArrayRef expr, + MPInt &divisor) { assert(pos <= cst.getNumVars() && "Invalid variable position"); assert(eqInd <= cst.getNumEqualities() && "Invalid equality position"); @@ -162,10 +165,10 @@ // Extract divisor, the divisor can be negative and hence its sign information // is stored in `signDiv` to reverse the sign of dividend's coefficients. // Equality must involve the pos-th variable and hence `tempDiv` != 0. - int64_t tempDiv = cst.atEq(eqInd, pos); + MPInt tempDiv = cst.atEq(eqInd, pos); if (tempDiv == 0) return failure(); - int64_t signDiv = tempDiv < 0 ? -1 : 1; + int signDiv = tempDiv < 0 ? -1 : 1; // The divisor is always a positive integer. divisor = tempDiv * signDiv; @@ -184,7 +187,7 @@ // explicit representation has not been found yet, otherwise returns `true`.
static bool checkExplicitRepresentation(const IntegerRelation &cst, ArrayRef foundRepr, - ArrayRef dividend, + ArrayRef dividend, unsigned pos) { // Exit to avoid circular dependencies between divisions. for (unsigned c = 0, e = cst.getNumVars(); c < e; ++c) { @@ -213,9 +216,11 @@ /// the representation could be computed, `dividend` and `denominator` are set. /// If the representation could not be computed, the kind attribute in /// `MaybeLocalRepr` is set to None. -MaybeLocalRepr presburger::computeSingleVarRepr( - const IntegerRelation &cst, ArrayRef foundRepr, unsigned pos, - MutableArrayRef dividend, unsigned &divisor) { +MaybeLocalRepr presburger::computeSingleVarRepr(const IntegerRelation &cst, + ArrayRef foundRepr, + unsigned pos, + MutableArrayRef dividend, + MPInt &divisor) { assert(pos < cst.getNumVars() && "invalid position"); assert(foundRepr.size() == cst.getNumVars() && "Size of foundRepr does not match total number of variables"); @@ -254,6 +259,18 @@ return repr; } +MaybeLocalRepr presburger::computeSingleVarRepr( + const IntegerRelation &cst, ArrayRef foundRepr, unsigned pos, + SmallVector ÷nd, unsigned &divisor) { + SmallVector dividendMPInt(cst.getNumCols()); + MPInt divisorMPInt; + MaybeLocalRepr result = + computeSingleVarRepr(cst, foundRepr, pos, dividendMPInt, divisorMPInt); + dividend = getInt64Vec(dividendMPInt); + divisor = unsigned(int64_t(divisorMPInt)); + return result; +} + llvm::SmallBitVector presburger::getSubrangeBitVector(unsigned len, unsigned setOffset, unsigned numSet) { @@ -290,68 +307,70 @@ divsA.removeDuplicateDivs(merge); } -SmallVector presburger::getDivUpperBound(ArrayRef dividend, - int64_t divisor, - unsigned localVarIdx) { +SmallVector presburger::getDivUpperBound(ArrayRef dividend, + const MPInt &divisor, + unsigned localVarIdx) { + assert(divisor > 0 && "divisor must be positive!"); assert(dividend[localVarIdx] == 0 && "Local to be set to division must have zero coeff!"); - SmallVector ineq(dividend.begin(), 
dividend.end()); + SmallVector ineq(dividend.begin(), dividend.end()); ineq[localVarIdx] = -divisor; return ineq; } -SmallVector presburger::getDivLowerBound(ArrayRef dividend, - int64_t divisor, - unsigned localVarIdx) { +SmallVector presburger::getDivLowerBound(ArrayRef dividend, + const MPInt &divisor, + unsigned localVarIdx) { + assert(divisor > 0 && "divisor must be positive!"); assert(dividend[localVarIdx] == 0 && "Local to be set to division must have zero coeff!"); - SmallVector ineq(dividend.size()); + SmallVector ineq(dividend.size()); std::transform(dividend.begin(), dividend.end(), ineq.begin(), - std::negate()); + std::negate()); ineq[localVarIdx] = divisor; ineq.back() += divisor - 1; return ineq; } -int64_t presburger::gcdRange(ArrayRef range) { - int64_t gcd = 0; - for (int64_t elem : range) { - gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(elem)); +MPInt presburger::gcdRange(ArrayRef range) { + MPInt gcd(0); + for (const MPInt &elem : range) { + gcd = presburger::gcd(gcd, abs(elem)); if (gcd == 1) return gcd; } return gcd; } -int64_t presburger::normalizeRange(MutableArrayRef range) { - int64_t gcd = gcdRange(range); - if (gcd == 0 || gcd == 1) +MPInt presburger::normalizeRange(MutableArrayRef range) { + MPInt gcd = gcdRange(range); + if ((gcd == 0) || (gcd == 1)) return gcd; - for (int64_t &elem : range) + for (MPInt &elem : range) elem /= gcd; return gcd; } -void presburger::normalizeDiv(MutableArrayRef num, int64_t &denom) { +void presburger::normalizeDiv(MutableArrayRef num, MPInt &denom) { assert(denom > 0 && "denom must be positive!"); - int64_t gcd = llvm::greatestCommonDivisor(gcdRange(num), denom); - for (int64_t &coeff : num) + MPInt gcd = presburger::gcd(gcdRange(num), denom); + for (MPInt &coeff : num) coeff /= gcd; denom /= gcd; } -SmallVector presburger::getNegatedCoeffs(ArrayRef coeffs) { - SmallVector negatedCoeffs; +SmallVector presburger::getNegatedCoeffs(ArrayRef coeffs) { + SmallVector negatedCoeffs; 
negatedCoeffs.reserve(coeffs.size()); - for (int64_t coeff : coeffs) + for (const MPInt &coeff : coeffs) negatedCoeffs.emplace_back(-coeff); return negatedCoeffs; } -SmallVector presburger::getComplementIneq(ArrayRef ineq) { - SmallVector coeffs; +SmallVector presburger::getComplementIneq(ArrayRef ineq) { + SmallVector coeffs; coeffs.reserve(ineq.size()); - for (int64_t coeff : ineq) + for (const MPInt &coeff : ineq) coeffs.emplace_back(-coeff); --coeffs.back(); return coeffs; diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp --- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp @@ -441,10 +441,12 @@ dependenceComponents->resize(numCommonLoops); for (unsigned j = 0; j < numCommonLoops; ++j) { (*dependenceComponents)[j].op = commonLoops[j].getOperation(); - auto lbConst = dependenceDomain->getConstantBound(IntegerPolyhedron::LB, j); + auto lbConst = + dependenceDomain->getConstantBound64(IntegerPolyhedron::LB, j); (*dependenceComponents)[j].lb = lbConst.value_or(std::numeric_limits::min()); - auto ubConst = dependenceDomain->getConstantBound(IntegerPolyhedron::UB, j); + auto ubConst = + dependenceDomain->getConstantBound64(IntegerPolyhedron::UB, j); (*dependenceComponents)[j].ub = ubConst.value_or(std::numeric_limits::max()); } diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp --- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp @@ -757,14 +757,14 @@ // Check for the aforementioned conditions in each equality. 
for (unsigned curEquality = 0, numEqualities = cst.getNumEqualities(); curEquality < numEqualities; curEquality++) { - int64_t coefficientAtPos = cst.atEq(curEquality, pos); + int64_t coefficientAtPos = cst.atEq64(curEquality, pos); // If current equality does not involve `var_r`, continue to the next // equality. if (coefficientAtPos == 0) continue; // Constant term should be 0 in this equality. - if (cst.atEq(curEquality, cst.getNumCols() - 1) != 0) + if (cst.atEq64(curEquality, cst.getNumCols() - 1) != 0) continue; // Traverse through the equality and construct the dividend expression @@ -786,7 +786,7 @@ // Ignore var_r. if (curVar == pos) continue; - int64_t coefficientOfCurVar = cst.atEq(curEquality, curVar); + int64_t coefficientOfCurVar = cst.atEq64(curEquality, curVar); // Ignore vars that do not contribute to the current equality. if (coefficientOfCurVar == 0) continue; @@ -827,8 +827,8 @@ // Express `var_r` as `var_n % divisor` and store the expression in `memo`. if (quotientCount >= 1) { - auto ub = cst.getConstantBound(FlatAffineValueConstraints::BoundType::UB, - dimExpr.getPosition()); + auto ub = cst.getConstantBound64( + FlatAffineValueConstraints::BoundType::UB, dimExpr.getPosition()); // If `var_n` has an upperbound that is less than the divisor, mod can be // eliminated altogether. if (ub && *ub < divisor) @@ -912,7 +912,7 @@ lbExprs.reserve(lbIndices.size() + eqIndices.size()); // Lower bound expressions. for (auto idx : lbIndices) { - auto ineq = getInequality(idx); + auto ineq = getInequality64(idx); // Extract the lower bound (in terms of other coeff's + const), i.e., if // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j // - 1. @@ -930,7 +930,7 @@ ubExprs.reserve(ubIndices.size() + eqIndices.size()); // Upper bound expressions. for (auto idx : ubIndices) { - auto ineq = getInequality(idx); + auto ineq = getInequality64(idx); // Extract the upper bound (in terms of other coeff's + const). 
addCoeffs(ineq, ub); auto expr = @@ -943,7 +943,7 @@ // Equalities. It's both a lower and a upper bound. SmallVector b; for (auto idx : eqIndices) { - auto eq = getEquality(idx); + auto eq = getEquality64(idx); addCoeffs(eq, b); if (eq[pos + offset] > 0) std::transform(b.begin(), b.end(), b.begin(), std::negate()); @@ -1006,8 +1006,8 @@ if (memo[pos]) continue; - auto lbConst = getConstantBound(BoundType::LB, pos); - auto ubConst = getConstantBound(BoundType::UB, pos); + auto lbConst = getConstantBound64(BoundType::LB, pos); + auto ubConst = getConstantBound64(BoundType::UB, pos); if (lbConst.has_value() && ubConst.has_value()) { // Detect equality to a constant. if (lbConst.value() == ubConst.value()) { @@ -1044,7 +1044,7 @@ for (j = 0, e = getNumVars(); j < e; ++j) { if (j == pos) continue; - int64_t c = atEq(idx, j); + int64_t c = atEq64(idx, j); if (c == 0) continue; // If any of the involved IDs hasn't been found yet, we can't proceed. @@ -1058,8 +1058,8 @@ continue; // Add constant term to AffineExpr. 
- expr = expr + atEq(idx, getNumVars()); - int64_t vPos = atEq(idx, pos); + expr = expr + atEq64(idx, getNumVars()); + int64_t vPos = atEq64(idx, pos); assert(vPos != 0 && "expected non-zero here"); if (vPos > 0) expr = (-expr).floorDiv(vPos); @@ -1118,7 +1118,7 @@ if (!lbMap || lbMap.getNumResults() > 1) { LLVM_DEBUG(llvm::dbgs() << "WARNING: Potentially over-approximating slice lb\n"); - auto lbConst = getConstantBound(BoundType::LB, pos + offset); + auto lbConst = getConstantBound64(BoundType::LB, pos + offset); if (lbConst.has_value()) { lbMap = AffineMap::get(numMapDims, numMapSymbols, @@ -1128,7 +1128,7 @@ if (!ubMap || ubMap.getNumResults() > 1) { LLVM_DEBUG(llvm::dbgs() << "WARNING: Potentially over-approximating slice ub\n"); - auto ubConst = getConstantBound(BoundType::UB, pos + offset); + auto ubConst = getConstantBound64(BoundType::UB, pos + offset); if (ubConst.has_value()) { ubMap = AffineMap::get( numMapDims, numMapSymbols, @@ -1488,7 +1488,7 @@ auto localExprs = ArrayRef(memo).take_back(getNumLocalVars()); // Compute the AffineExpr lower/upper bound for this inequality. - ArrayRef inequality = getInequality(ineqPos); + SmallVector inequality = getInequality64(ineqPos); SmallVector bound; bound.reserve(getNumCols() - 1); // Everything other than the coefficient at `pos`. 
@@ -1562,10 +1562,10 @@ exprs.reserve(getNumConstraints()); for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) - exprs.push_back(getAffineExprFromFlatForm(getEquality(i), numDims, numSyms, - localExprs, context)); + exprs.push_back(getAffineExprFromFlatForm(getEquality64(i), numDims, + numSyms, localExprs, context)); for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) - exprs.push_back(getAffineExprFromFlatForm(getInequality(i), numDims, + exprs.push_back(getAffineExprFromFlatForm(getInequality64(i), numDims, numSyms, localExprs, context)); return IntegerSet::get(numDims, numSyms, exprs, eqFlags); } diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp --- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp @@ -374,7 +374,7 @@ for (unsigned d = 0; d < rank; d++) { SmallVector lb; Optional diff = - cstWithShapeBounds.getConstantBoundOnDimSize(d, &lb, &lbDivisor); + cstWithShapeBounds.getConstantBoundOnDimSize64(d, &lb, &lbDivisor); if (diff.has_value()) { diffConstant = diff.value(); assert(diffConstant >= 0 && "Dim size bound can't be negative"); diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -1788,7 +1788,7 @@ newShape[d] = -1; } else { // The lower bound for the shape is always zero. - auto ubConst = fac.getConstantBound(IntegerPolyhedron::UB, d); + auto ubConst = fac.getConstantBound64(IntegerPolyhedron::UB, d); // For a static memref and an affine map with no symbols, this is // always bounded. assert(ubConst && "should always have an upper bound"); diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -303,7 +303,7 @@ // of the terminals of the index computation. 
unsigned pos = getPosition(value); if (constantRequired) { - auto ubConst = constraints.getConstantBound( + auto ubConst = constraints.getConstantBound64( FlatAffineValueConstraints::BoundType::UB, pos); if (!ubConst) return; diff --git a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp --- a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp @@ -189,7 +189,7 @@ // Skip unused operands and operands that are already constants. if (!newOperands[i] || getConstantIntValue(newOperands[i])) continue; - if (auto bound = constraints.getConstantBound(IntegerPolyhedron::EQ, i)) + if (auto bound = constraints.getConstantBound64(IntegerPolyhedron::EQ, i)) newOperands[i] = rewriter.create(op->getLoc(), *bound); } diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp --- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp +++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp @@ -39,8 +39,8 @@ return set; } -static void dump(ArrayRef vec) { - for (int64_t x : vec) +static void dump(ArrayRef vec) { + for (const MPInt &x : vec) llvm::errs() << x << ' '; llvm::errs() << '\n'; } @@ -58,8 +58,8 @@ /// opposite of hasSample. 
static void checkSample(bool hasSample, const IntegerPolyhedron &poly, TestFunction fn = TestFunction::Sample) { - Optional> maybeSample; - MaybeOptimum> maybeLexMin; + Optional> maybeSample; + MaybeOptimum> maybeLexMin; switch (fn) { case TestFunction::Sample: maybeSample = poly.findIntegerSample(); @@ -426,6 +426,12 @@ "-7*x - 4*y + z + 1 >= 0," "2*x - 7*y - 8*z - 7 >= 0," "9*x + 8*y - 9*z - 7 >= 0)")); + + checkSample( + true, + parsePoly( + "(x) : (1152921504606846977*(x floordiv 1152921504606846977) == x, " + "1152921504606846976*(x floordiv 1152921504606846976) == x)")); } TEST(IntegerPolyhedronTest, IsIntegerEmptyTest) { @@ -569,10 +575,10 @@ // y >= 128x >= 0. poly5.removeRedundantConstraints(); EXPECT_EQ(poly5.getNumInequalities(), 3u); - SmallVector redundantConstraint = {0, 1, 0}; + SmallVector redundantConstraint = getMPIntVec({0, 1, 0}); for (unsigned i = 0; i < 3; ++i) { // Ensure that the removed constraint was the redundant constraint [3]. - EXPECT_NE(poly5.getInequality(i), ArrayRef(redundantConstraint)); + EXPECT_NE(poly5.getInequality(i), ArrayRef(redundantConstraint)); } } @@ -611,11 +617,12 @@ static void checkDivisionRepresentation( IntegerPolyhedron &poly, const std::vector> &expectedDividends, - ArrayRef expectedDenominators) { + ArrayRef expectedDenominators) { DivisionRepr divs = poly.getLocalReprs(); // Check that the `denominators` and `expectedDenominators` match. - EXPECT_TRUE(expectedDenominators == divs.getDenoms()); + EXPECT_EQ(ArrayRef(getMPIntVec(expectedDenominators)), + divs.getDenoms()); // Check that the `dividends` and `expectedDividends` match. If the // denominator for a division is zero, we ignore its dividend. @@ -637,7 +644,7 @@ std::vector> divisions = {{1, 0, 0, 4}, {1, 0, 0, 100}}; - SmallVector denoms = {10, 10}; + SmallVector denoms = {10, 10}; // Check if floordivs can be computed when no other inequalities exist // and floor divs do not depend on each other. 
@@ -656,7 +663,7 @@ std::vector> divisions = {{0, 0, 0, 0, 0, 0, 3}, {0, 0, 0, 0, 0, 0, 2}}; - SmallVector denoms = {1, 1}; + SmallVector denoms = {1, 1}; // Check if floordivs with constant numerator can be computed. checkDivisionRepresentation(poly, divisions, denoms); @@ -680,7 +687,7 @@ {3, 0, 9, 2, 2, 0, 0, 10}, {0, 1, -123, 2, 0, -4, 0, 10}}; - SmallVector denoms = {3, 5, 3}; + SmallVector denoms = {3, 5, 3}; // Check if floordivs which may depend on other floordivs can be computed. checkDivisionRepresentation(poly, divisions, denoms); @@ -701,7 +708,7 @@ poly.removeRedundantConstraints(); std::vector> divisions = {{1, 0, 0}}; - SmallVector denoms = {3}; + SmallVector denoms = {3}; // Check if the divisions can be computed even with a tighter upper bound. checkDivisionRepresentation(poly, divisions, denoms); @@ -714,7 +721,7 @@ poly.convertToLocal(VarKind::SetDim, 2, 3); std::vector> divisions = {{1, 1, 0, 1}}; - SmallVector denoms = {4}; + SmallVector denoms = {4}; // Check if the divisions can be computed even with a tighter upper bound. 
checkDivisionRepresentation(poly, divisions, denoms); @@ -728,7 +735,7 @@ poly.convertToLocal(VarKind::SetDim, 2, 3); std::vector> divisions = {{1, 1, 0, 0}}; - SmallVector denoms = {4}; + SmallVector denoms = {4}; checkDivisionRepresentation(poly, divisions, denoms); } @@ -738,7 +745,7 @@ poly.convertToLocal(VarKind::SetDim, 2, 3); std::vector> divisions = {{1, 1, 0, 0}}; - SmallVector denoms = {4}; + SmallVector denoms = {4}; checkDivisionRepresentation(poly, divisions, denoms); } @@ -748,7 +755,7 @@ poly.convertToLocal(VarKind::SetDim, 2, 3); std::vector> divisions = {{-1, -1, 0, 2}}; - SmallVector denoms = {3}; + SmallVector denoms = {3}; checkDivisionRepresentation(poly, divisions, denoms); } @@ -764,7 +771,7 @@ std::vector> divisions = {{1, 1, 0, 0, 1}, {1, 1, 0, 0, 0}}; - SmallVector denoms = {4, 3}; + SmallVector denoms = {4, 3}; checkDivisionRepresentation(poly, divisions, denoms); } @@ -777,7 +784,7 @@ poly.convertToLocal(VarKind::SetDim, 1, 2); std::vector> divisions = {{0, 0, 0}}; - SmallVector denoms = {0}; + SmallVector denoms = {0}; // Check that no division is computed. checkDivisionRepresentation(poly, divisions, denoms); @@ -793,7 +800,7 @@ // = floor((1/3) + (-1 - x)/2) // = floor((-1 - x)/2). std::vector> divisions = {{-1, 0, -1}}; - SmallVector denoms = {2}; + SmallVector denoms = {2}; checkDivisionRepresentation(poly, divisions, denoms); } @@ -1061,7 +1068,7 @@ // Merging triggers normalization. 
std::vector> divisions = {{-1, 0, 0, 1}, {-1, 0, 0, -2}}; - SmallVector denoms = {2, 3}; + SmallVector denoms = {2, 3}; checkDivisionRepresentation(poly1, divisions, denoms); } @@ -1139,9 +1146,9 @@ } void expectIntegerLexMin(const IntegerPolyhedron &poly, ArrayRef min) { - auto lexMin = poly.findIntegerLexMin(); + MaybeOptimum> lexMin = poly.findIntegerLexMin(); ASSERT_TRUE(lexMin.isBounded()); - EXPECT_EQ(ArrayRef(*lexMin), min); + EXPECT_EQ(*lexMin, getMPIntVec(min)); } void expectNoIntegerLexMin(OptimumKind kind, const IntegerPolyhedron &poly) { @@ -1389,8 +1396,8 @@ static void expectComputedVolumeIsValidOverapprox(const IntegerPolyhedron &poly, - Optional trueVolume, - Optional resultBound) { + Optional trueVolume, + Optional resultBound) { expectComputedVolumeIsValidOverapprox(poly.computeVolume(), trueVolume, resultBound); } @@ -1442,19 +1449,24 @@ /*trueVolume=*/{}, /*resultBound=*/{}); } +bool containsPointNoLocal(const IntegerPolyhedron &poly, + ArrayRef point) { + return poly.containsPointNoLocal(getMPIntVec(point)).hasValue(); +} + TEST(IntegerPolyhedronTest, containsPointNoLocal) { IntegerPolyhedron poly1 = parsePoly("(x) : ((x floordiv 2) - x == 0)"); - EXPECT_TRUE(poly1.containsPointNoLocal({0})); - EXPECT_FALSE(poly1.containsPointNoLocal({1})); + EXPECT_TRUE(containsPointNoLocal(poly1, {0})); + EXPECT_FALSE(containsPointNoLocal(poly1, {1})); IntegerPolyhedron poly2 = parsePoly( "(x) : (x - 2*(x floordiv 2) == 0, x - 4*(x floordiv 4) - 2 == 0)"); - EXPECT_TRUE(poly2.containsPointNoLocal({6})); - EXPECT_FALSE(poly2.containsPointNoLocal({4})); + EXPECT_TRUE(containsPointNoLocal(poly2, {6})); + EXPECT_FALSE(containsPointNoLocal(poly2, {4})); IntegerPolyhedron poly3 = parsePoly("(x, y) : (2*x - y >= 0, y - 3*x >= 0)"); - EXPECT_TRUE(poly3.containsPointNoLocal({0, 0})); - EXPECT_FALSE(poly3.containsPointNoLocal({1, 0})); + EXPECT_TRUE(containsPointNoLocal(poly3, {0, 0})); + EXPECT_FALSE(containsPointNoLocal(poly3, {1, 0})); } TEST(IntegerPolyhedronTest, 
truncateEqualityRegressionTest) { diff --git a/mlir/unittests/Analysis/Presburger/LinearTransformTest.cpp b/mlir/unittests/Analysis/Presburger/LinearTransformTest.cpp --- a/mlir/unittests/Analysis/Presburger/LinearTransformTest.cpp +++ b/mlir/unittests/Analysis/Presburger/LinearTransformTest.cpp @@ -23,8 +23,7 @@ // In column echelon form, each row's last non-zero value can be at most one // column to the right of the last non-zero column among the previous rows. for (unsigned row = 0, nRows = m.getNumRows(); row < nRows; ++row) { - SmallVector rowVec = - transform.preMultiplyWithRow(m.getRow(row)); + SmallVector rowVec = transform.preMultiplyWithRow(m.getRow(row)); for (unsigned col = lastAllowedNonZeroCol + 1, nCols = m.getNumColumns(); col < nCols; ++col) { EXPECT_EQ(rowVec[col], 0); diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp --- a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp +++ b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp @@ -780,8 +780,8 @@ static void expectComputedVolumeIsValidOverapprox(const PresburgerSet &set, - Optional trueVolume, - Optional resultBound) { + Optional trueVolume, + Optional resultBound) { expectComputedVolumeIsValidOverapprox(set.computeVolume(), trueVolume, resultBound); } diff --git a/mlir/unittests/Analysis/Presburger/SimplexTest.cpp b/mlir/unittests/Analysis/Presburger/SimplexTest.cpp --- a/mlir/unittests/Analysis/Presburger/SimplexTest.cpp +++ b/mlir/unittests/Analysis/Presburger/SimplexTest.cpp @@ -17,6 +17,30 @@ using namespace mlir; using namespace presburger; +/// Convenience functions to pass literals to Simplex. 
+void addInequality(SimplexBase &simplex, ArrayRef coeffs) { + simplex.addInequality(getMPIntVec(coeffs)); +} +void addEquality(SimplexBase &simplex, ArrayRef coeffs) { + simplex.addEquality(getMPIntVec(coeffs)); +} +bool isRedundantInequality(Simplex &simplex, ArrayRef coeffs) { + return simplex.isRedundantInequality(getMPIntVec(coeffs)); +} +bool isRedundantInequality(LexSimplex &simplex, ArrayRef coeffs) { + return simplex.isRedundantInequality(getMPIntVec(coeffs)); +} +bool isRedundantEquality(Simplex &simplex, ArrayRef coeffs) { + return simplex.isRedundantEquality(getMPIntVec(coeffs)); +} +bool isSeparateInequality(LexSimplex &simplex, ArrayRef coeffs) { + return simplex.isSeparateInequality(getMPIntVec(coeffs)); +} + +Simplex::IneqType findIneqType(Simplex &simplex, ArrayRef coeffs) { + return simplex.findIneqType(getMPIntVec(coeffs)); +} + /// Take a snapshot, add constraints making the set empty, and rollback. /// The set should not be empty after rolling back. We add additional /// constraints after the set is already empty and roll back the addition @@ -25,17 +49,17 @@ TEST(SimplexTest, emptyRollback) { Simplex simplex(2); // (u - v) >= 0 - simplex.addInequality({1, -1, 0}); + addInequality(simplex, {1, -1, 0}); ASSERT_FALSE(simplex.isEmpty()); unsigned snapshot = simplex.getSnapshot(); // (u - v) <= -1 - simplex.addInequality({-1, 1, -1}); + addInequality(simplex, {-1, 1, -1}); ASSERT_TRUE(simplex.isEmpty()); unsigned snapshot2 = simplex.getSnapshot(); // (u - v) <= -3 - simplex.addInequality({-1, 1, -3}); + addInequality(simplex, {-1, 1, -3}); ASSERT_TRUE(simplex.isEmpty()); simplex.rollback(snapshot2); @@ -49,9 +73,9 @@ /// constraints. TEST(SimplexTest, addEquality_separate) { Simplex simplex(1); - simplex.addInequality({1, -1}); // x >= 1. + addInequality(simplex, {1, -1}); // x >= 1. ASSERT_FALSE(simplex.isEmpty()); - simplex.addEquality({1, 0}); // x == 0. + addEquality(simplex, {1, 0}); // x == 0. 
EXPECT_TRUE(simplex.isEmpty()); } @@ -59,7 +83,7 @@ bool expect) { ASSERT_FALSE(simplex.isEmpty()); unsigned snapshot = simplex.getSnapshot(); - simplex.addInequality(coeffs); + addInequality(simplex, coeffs); EXPECT_EQ(simplex.isEmpty(), expect); simplex.rollback(snapshot); } @@ -82,7 +106,7 @@ expectInequalityMakesSetEmpty(simplex, checkCoeffs[1], false); for (int i = 0; i < 4; i++) - simplex.addInequality(coeffs[(run + i) % 4]); + addInequality(simplex, coeffs[(run + i) % 4]); expectInequalityMakesSetEmpty(simplex, checkCoeffs[0], true); expectInequalityMakesSetEmpty(simplex, checkCoeffs[1], true); @@ -100,9 +124,9 @@ ArrayRef> eqs) { Simplex simplex(nDim); for (const auto &ineq : ineqs) - simplex.addInequality(ineq); + addInequality(simplex, ineq); for (const auto &eq : eqs) - simplex.addEquality(eq); + addEquality(simplex, eq); return simplex; } @@ -235,7 +259,7 @@ /// Some basic sanity checks involving zero or one variables. TEST(SimplexTest, isMarkedRedundant_no_var_ge_zero) { Simplex simplex(0); - simplex.addInequality({0}); // 0 >= 0. + addInequality(simplex, {0}); // 0 >= 0. simplex.detectRedundant(); ASSERT_FALSE(simplex.isEmpty()); @@ -244,7 +268,7 @@ TEST(SimplexTest, isMarkedRedundant_no_var_eq) { Simplex simplex(0); - simplex.addEquality({0}); // 0 == 0. + addEquality(simplex, {0}); // 0 == 0. simplex.detectRedundant(); ASSERT_FALSE(simplex.isEmpty()); EXPECT_TRUE(simplex.isMarkedRedundant(0)); @@ -252,7 +276,7 @@ TEST(SimplexTest, isMarkedRedundant_pos_var_eq) { Simplex simplex(1); - simplex.addEquality({1, 0}); // x == 0. + addEquality(simplex, {1, 0}); // x == 0. simplex.detectRedundant(); ASSERT_FALSE(simplex.isEmpty()); @@ -261,7 +285,7 @@ TEST(SimplexTest, isMarkedRedundant_zero_var_eq) { Simplex simplex(1); - simplex.addEquality({0, 0}); // 0x == 0. + addEquality(simplex, {0, 0}); // 0x == 0. 
simplex.detectRedundant(); ASSERT_FALSE(simplex.isEmpty()); EXPECT_TRUE(simplex.isMarkedRedundant(0)); @@ -269,7 +293,7 @@ TEST(SimplexTest, isMarkedRedundant_neg_var_eq) { Simplex simplex(1); - simplex.addEquality({-1, 0}); // -x == 0. + addEquality(simplex, {-1, 0}); // -x == 0. simplex.detectRedundant(); ASSERT_FALSE(simplex.isEmpty()); EXPECT_FALSE(simplex.isMarkedRedundant(0)); @@ -277,7 +301,7 @@ TEST(SimplexTest, isMarkedRedundant_pos_var_ge) { Simplex simplex(1); - simplex.addInequality({1, 0}); // x >= 0. + addInequality(simplex, {1, 0}); // x >= 0. simplex.detectRedundant(); ASSERT_FALSE(simplex.isEmpty()); EXPECT_FALSE(simplex.isMarkedRedundant(0)); @@ -285,7 +309,7 @@ TEST(SimplexTest, isMarkedRedundant_zero_var_ge) { Simplex simplex(1); - simplex.addInequality({0, 0}); // 0x >= 0. + addInequality(simplex, {0, 0}); // 0x >= 0. simplex.detectRedundant(); ASSERT_FALSE(simplex.isEmpty()); EXPECT_TRUE(simplex.isMarkedRedundant(0)); @@ -293,7 +317,7 @@ TEST(SimplexTest, isMarkedRedundant_neg_var_ge) { Simplex simplex(1); - simplex.addInequality({-1, 0}); // x <= 0. + addInequality(simplex, {-1, 0}); // x <= 0. simplex.detectRedundant(); ASSERT_FALSE(simplex.isEmpty()); EXPECT_FALSE(simplex.isMarkedRedundant(0)); @@ -304,9 +328,9 @@ TEST(SimplexTest, isMarkedRedundant_no_redundant) { Simplex simplex(3); - simplex.addEquality({-1, 0, 1, 0}); // u = w. - simplex.addInequality({-1, 16, 0, 15}); // 15 - (u - 16v) >= 0. - simplex.addInequality({1, -16, 0, 0}); // (u - 16v) >= 0. + addEquality(simplex, {-1, 0, 1, 0}); // u = w. + addInequality(simplex, {-1, 16, 0, 15}); // 15 - (u - 16v) >= 0. + addInequality(simplex, {1, -16, 0, 0}); // (u - 16v) >= 0. simplex.detectRedundant(); ASSERT_FALSE(simplex.isEmpty()); @@ -319,14 +343,14 @@ Simplex simplex(3); // [4] to [7] are repeats of [0] to [3]. - simplex.addInequality({0, -1, 0, 1}); // [0]: y <= 1. - simplex.addInequality({-1, 0, 8, 7}); // [1]: 8z >= x - 7. - simplex.addInequality({1, 0, -8, 0}); // [2]: 8z <= x. 
- simplex.addInequality({0, 1, 0, 0}); // [3]: y >= 0. - simplex.addInequality({-1, 0, 8, 7}); // [4]: 8z >= 7 - x. - simplex.addInequality({1, 0, -8, 0}); // [5]: 8z <= x. - simplex.addInequality({0, 1, 0, 0}); // [6]: y >= 0. - simplex.addInequality({0, -1, 0, 1}); // [7]: y <= 1. + addInequality(simplex, {0, -1, 0, 1}); // [0]: y <= 1. + addInequality(simplex, {-1, 0, 8, 7}); // [1]: 8z >= x - 7. + addInequality(simplex, {1, 0, -8, 0}); // [2]: 8z <= x. + addInequality(simplex, {0, 1, 0, 0}); // [3]: y >= 0. + addInequality(simplex, {-1, 0, 8, 7}); // [4]: 8z >= 7 - x. + addInequality(simplex, {1, 0, -8, 0}); // [5]: 8z <= x. + addInequality(simplex, {0, 1, 0, 0}); // [6]: y >= 0. + addInequality(simplex, {0, -1, 0, 1}); // [7]: y <= 1. simplex.detectRedundant(); ASSERT_FALSE(simplex.isEmpty()); @@ -343,14 +367,14 @@ TEST(SimplexTest, isMarkedRedundant) { Simplex simplex(3); - simplex.addInequality({0, -1, 0, 1}); // [0]: y <= 1. - simplex.addInequality({1, 0, 0, -1}); // [1]: x >= 1. - simplex.addInequality({-1, 0, 0, 2}); // [2]: x <= 2. - simplex.addInequality({-1, 0, 2, 7}); // [3]: 2z >= x - 7. - simplex.addInequality({1, 0, -2, 0}); // [4]: 2z <= x. - simplex.addInequality({0, 1, 0, 0}); // [5]: y >= 0. - simplex.addInequality({0, 1, -2, 1}); // [6]: y >= 2z - 1. - simplex.addInequality({-1, 1, 0, 1}); // [7]: y >= x - 1. + addInequality(simplex, {0, -1, 0, 1}); // [0]: y <= 1. + addInequality(simplex, {1, 0, 0, -1}); // [1]: x >= 1. + addInequality(simplex, {-1, 0, 0, 2}); // [2]: x <= 2. + addInequality(simplex, {-1, 0, 2, 7}); // [3]: 2z >= x - 7. + addInequality(simplex, {1, 0, -2, 0}); // [4]: 2z <= x. + addInequality(simplex, {0, 1, 0, 0}); // [5]: y >= 0. + addInequality(simplex, {0, 1, -2, 1}); // [6]: y >= 2z - 1. + addInequality(simplex, {-1, 1, 0, 1}); // [7]: y >= x - 1. 
simplex.detectRedundant(); ASSERT_FALSE(simplex.isEmpty()); @@ -372,12 +396,12 @@ TEST(SimplexTest, isMarkedRedundantTiledLoopNestConstraints) { Simplex simplex(3); // Variables are x, y, N. - simplex.addInequality({1, 0, 0, 0}); // [0]: x >= 0. - simplex.addInequality({-32, 0, 1, -1}); // [1]: 32x <= N - 1. - simplex.addInequality({0, 1, 0, 0}); // [2]: y >= 0. - simplex.addInequality({-32, 1, 0, 0}); // [3]: y >= 32x. - simplex.addInequality({32, -1, 0, 31}); // [4]: y <= 32x + 31. - simplex.addInequality({0, -1, 1, -1}); // [5]: y <= N - 1. + addInequality(simplex, {1, 0, 0, 0}); // [0]: x >= 0. + addInequality(simplex, {-32, 0, 1, -1}); // [1]: 32x <= N - 1. + addInequality(simplex, {0, 1, 0, 0}); // [2]: y >= 0. + addInequality(simplex, {-32, 1, 0, 0}); // [3]: y >= 32x. + addInequality(simplex, {32, -1, 0, 31}); // [4]: y <= 32x + 31. + addInequality(simplex, {0, -1, 1, -1}); // [5]: y <= N - 1. // [3] and [0] imply [2], as we have y >= 32x >= 0. // [3] and [5] imply [1], as we have 32x <= y <= N - 1. simplex.detectRedundant(); @@ -391,11 +415,11 @@ TEST(SimplexTest, pivotRedundantRegressionTest) { Simplex simplex(2); - simplex.addInequality({-1, 0, -1}); // x <= -1. + addInequality(simplex, {-1, 0, -1}); // x <= -1. unsigned snapshot = simplex.getSnapshot(); - simplex.addInequality({-1, 0, -2}); // x <= -2. - simplex.addInequality({-3, 0, -6}); + addInequality(simplex, {-1, 0, -2}); // x <= -2. + addInequality(simplex, {-3, 0, -6}); // This first marks x <= -1 as redundant. Then it performs some more pivots // to check if the other constraints are redundant. Pivot must update the @@ -408,14 +432,14 @@ // The maximum value of x should be -1. 
simplex.rollback(snapshot); MaybeOptimum maxX = - simplex.computeOptimum(Simplex::Direction::Up, {1, 0, 0}); + simplex.computeOptimum(Simplex::Direction::Up, getMPIntVec({1, 0, 0})); EXPECT_TRUE(maxX.isBounded() && *maxX == Fraction(-1, 1)); } TEST(SimplexTest, addInequality_already_redundant) { Simplex simplex(1); - simplex.addInequality({1, -1}); // x >= 1. - simplex.addInequality({1, 0}); // x >= 0. + addInequality(simplex, {1, -1}); // x >= 1. + addInequality(simplex, {1, 0}); // x >= 0. simplex.detectRedundant(); ASSERT_FALSE(simplex.isEmpty()); EXPECT_FALSE(simplex.isMarkedRedundant(0)); @@ -431,8 +455,8 @@ EXPECT_EQ(simplex.getNumVariables(), 2u); int64_t yMin = 2, yMax = 5; - simplex.addInequality({0, 1, -yMin}); // y >= 2. - simplex.addInequality({0, -1, yMax}); // y <= 5. + addInequality(simplex, {0, 1, -yMin}); // y >= 2. + addInequality(simplex, {0, -1, yMax}); // y <= 5. unsigned snapshot2 = simplex.getSnapshot(); simplex.appendVariable(2); @@ -441,9 +465,9 @@ EXPECT_EQ(simplex.getNumVariables(), 2u); EXPECT_EQ(simplex.getNumConstraints(), 2u); - EXPECT_EQ( - simplex.computeIntegerBounds({0, 1, 0}), - std::make_pair(MaybeOptimum(yMin), MaybeOptimum(yMax))); + EXPECT_EQ(simplex.computeIntegerBounds(getMPIntVec({0, 1, 0})), + std::make_pair(MaybeOptimum(MPInt(yMin)), + MaybeOptimum(MPInt(yMax)))); simplex.rollback(snapshot1); EXPECT_EQ(simplex.getNumVariables(), 1u); @@ -452,54 +476,54 @@ TEST(SimplexTest, isRedundantInequality) { Simplex simplex(2); - simplex.addInequality({0, -1, 2}); // y <= 2. - simplex.addInequality({1, 0, 0}); // x >= 0. - simplex.addEquality({-1, 1, 0}); // y = x. + addInequality(simplex, {0, -1, 2}); // y <= 2. + addInequality(simplex, {1, 0, 0}); // x >= 0. + addEquality(simplex, {-1, 1, 0}); // y = x. - EXPECT_TRUE(simplex.isRedundantInequality({-1, 0, 2})); // x <= 2. - EXPECT_TRUE(simplex.isRedundantInequality({0, 1, 0})); // y >= 0. + EXPECT_TRUE(isRedundantInequality(simplex, {-1, 0, 2})); // x <= 2. 
+ EXPECT_TRUE(isRedundantInequality(simplex, {0, 1, 0})); // y >= 0. - EXPECT_FALSE(simplex.isRedundantInequality({-1, 0, -1})); // x <= -1. - EXPECT_FALSE(simplex.isRedundantInequality({0, 1, -2})); // y >= 2. - EXPECT_FALSE(simplex.isRedundantInequality({0, 1, -1})); // y >= 1. + EXPECT_FALSE(isRedundantInequality(simplex, {-1, 0, -1})); // x <= -1. + EXPECT_FALSE(isRedundantInequality(simplex, {0, 1, -2})); // y >= 2. + EXPECT_FALSE(isRedundantInequality(simplex, {0, 1, -1})); // y >= 1. } TEST(SimplexTest, ineqType) { Simplex simplex(2); - simplex.addInequality({0, -1, 2}); // y <= 2. - simplex.addInequality({1, 0, 0}); // x >= 0. - simplex.addEquality({-1, 1, 0}); // y = x. - - EXPECT_TRUE(simplex.findIneqType({-1, 0, 2}) == - Simplex::IneqType::Redundant); // x <= 2. - EXPECT_TRUE(simplex.findIneqType({0, 1, 0}) == - Simplex::IneqType::Redundant); // y >= 0. - - EXPECT_TRUE(simplex.findIneqType({0, 1, -1}) == - Simplex::IneqType::Cut); // y >= 1. - EXPECT_TRUE(simplex.findIneqType({-1, 0, 1}) == - Simplex::IneqType::Cut); // x <= 1. - EXPECT_TRUE(simplex.findIneqType({0, 1, -2}) == - Simplex::IneqType::Cut); // y >= 2. - - EXPECT_TRUE(simplex.findIneqType({-1, 0, -1}) == - Simplex::IneqType::Separate); // x <= -1. + addInequality(simplex, {0, -1, 2}); // y <= 2. + addInequality(simplex, {1, 0, 0}); // x >= 0. + addEquality(simplex, {-1, 1, 0}); // y = x. + + EXPECT_EQ(findIneqType(simplex, {-1, 0, 2}), + Simplex::IneqType::Redundant); // x <= 2. + EXPECT_EQ(findIneqType(simplex, {0, 1, 0}), + Simplex::IneqType::Redundant); // y >= 0. + + EXPECT_EQ(findIneqType(simplex, {0, 1, -1}), + Simplex::IneqType::Cut); // y >= 1. + EXPECT_EQ(findIneqType(simplex, {-1, 0, 1}), + Simplex::IneqType::Cut); // x <= 1. + EXPECT_EQ(findIneqType(simplex, {0, 1, -2}), + Simplex::IneqType::Cut); // y >= 2. + + EXPECT_EQ(findIneqType(simplex, {-1, 0, -1}), + Simplex::IneqType::Separate); // x <= -1. 
} TEST(SimplexTest, isRedundantEquality) { Simplex simplex(2); - simplex.addInequality({0, -1, 2}); // y <= 2. - simplex.addInequality({1, 0, 0}); // x >= 0. - simplex.addEquality({-1, 1, 0}); // y = x. + addInequality(simplex, {0, -1, 2}); // y <= 2. + addInequality(simplex, {1, 0, 0}); // x >= 0. + addEquality(simplex, {-1, 1, 0}); // y = x. - EXPECT_TRUE(simplex.isRedundantEquality({-1, 1, 0})); // y = x. - EXPECT_TRUE(simplex.isRedundantEquality({1, -1, 0})); // x = y. + EXPECT_TRUE(isRedundantEquality(simplex, {-1, 1, 0})); // y = x. + EXPECT_TRUE(isRedundantEquality(simplex, {1, -1, 0})); // x = y. - EXPECT_FALSE(simplex.isRedundantEquality({0, 1, -1})); // y = 1. + EXPECT_FALSE(isRedundantEquality(simplex, {0, 1, -1})); // y = 1. - simplex.addEquality({0, -1, 2}); // y = 2. + addEquality(simplex, {0, -1, 2}); // y = 2. - EXPECT_TRUE(simplex.isRedundantEquality({-1, 0, 2})); // x = 2. + EXPECT_TRUE(isRedundantEquality(simplex, {-1, 0, 2})); // x = 2. } TEST(SimplexTest, IsRationalSubsetOf) { @@ -541,27 +565,27 @@ TEST(SimplexTest, addDivisionVariable) { Simplex simplex(/*nVar=*/1); - simplex.addDivisionVariable({1, 0}, 2); - simplex.addInequality({1, 0, -3}); // x >= 3. - simplex.addInequality({-1, 0, 9}); // x <= 9. - Optional> sample = simplex.findIntegerSample(); + simplex.addDivisionVariable(getMPIntVec({1, 0}), MPInt(2)); + addInequality(simplex, {1, 0, -3}); // x >= 3. + addInequality(simplex, {-1, 0, 9}); // x <= 9. + Optional> sample = simplex.findIntegerSample(); ASSERT_TRUE(sample.has_value()); EXPECT_EQ((*sample)[0] / 2, (*sample)[1]); } TEST(SimplexTest, LexIneqType) { LexSimplex simplex(/*nVar=*/1); - simplex.addInequality({2, -1}); // x >= 1/2. + addInequality(simplex, {2, -1}); // x >= 1/2. // Redundant inequality x >= 2/3. 
- EXPECT_TRUE(simplex.isRedundantInequality({3, -2})); - EXPECT_FALSE(simplex.isSeparateInequality({3, -2})); + EXPECT_TRUE(isRedundantInequality(simplex, {3, -2})); + EXPECT_FALSE(isSeparateInequality(simplex, {3, -2})); // Separate inequality x <= 2/3. - EXPECT_FALSE(simplex.isRedundantInequality({-3, 2})); - EXPECT_TRUE(simplex.isSeparateInequality({-3, 2})); + EXPECT_FALSE(isRedundantInequality(simplex, {-3, 2})); + EXPECT_TRUE(isSeparateInequality(simplex, {-3, 2})); // Cut inequality x <= 1. - EXPECT_FALSE(simplex.isRedundantInequality({-1, 1})); - EXPECT_FALSE(simplex.isSeparateInequality({-1, 1})); + EXPECT_FALSE(isRedundantInequality(simplex, {-1, 1})); + EXPECT_FALSE(isSeparateInequality(simplex, {-1, 1})); } diff --git a/mlir/unittests/Analysis/Presburger/Utils.h b/mlir/unittests/Analysis/Presburger/Utils.h --- a/mlir/unittests/Analysis/Presburger/Utils.h +++ b/mlir/unittests/Analysis/Presburger/Utils.h @@ -87,7 +87,7 @@ /// lhs and rhs represent non-negative integers or positive infinity. The /// infinity case corresponds to when the Optional is empty. -inline bool infinityOrUInt64LE(Optional lhs, Optional rhs) { +inline bool infinityOrUInt64LE(Optional lhs, Optional rhs) { // No constraint. if (!rhs) return true; @@ -101,15 +101,24 @@ /// the true volume `trueVolume`, while also being at least as good an /// approximation as `resultBound`. 
inline void -expectComputedVolumeIsValidOverapprox(Optional computedVolume, - Optional trueVolume, - Optional resultBound) { +expectComputedVolumeIsValidOverapprox(const Optional &computedVolume, + const Optional &trueVolume, + const Optional &resultBound) { assert(infinityOrUInt64LE(trueVolume, resultBound) && "can't expect result to be less than the true volume"); EXPECT_TRUE(infinityOrUInt64LE(trueVolume, computedVolume)); EXPECT_TRUE(infinityOrUInt64LE(computedVolume, resultBound)); } +inline void +expectComputedVolumeIsValidOverapprox(const Optional &computedVolume, + Optional trueVolume, + Optional resultBound) { + expectComputedVolumeIsValidOverapprox(computedVolume, + trueVolume.map(mpintFromInt64), + resultBound.map(mpintFromInt64)); +} + } // namespace presburger } // namespace mlir