diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h --- a/mlir/include/mlir/Analysis/Presburger/Simplex.h +++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h @@ -23,6 +23,7 @@ #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/StringSaver.h" #include "llvm/Support/raw_ostream.h" @@ -210,14 +211,18 @@ protected: /// Construct a SimplexBase with the specified number of variables and fixed - /// columns. + /// columns. The first overload should be used when there are no symbols. + /// With the second overload, `isSymbol` is a bitmask denoting which vars + /// are symbols; those vars will be marked as symbols. The size of + /// `isSymbol` must be `nVar`. /// /// For example, Simplex uses two fixed columns: the denominator and the /// constant term, whereas LexSimplex has an extra fixed column for the /// so-called big M parameter. For more information see the documentation for /// LexSimplex. 
- SimplexBase(unsigned nVar, bool mustUseBigM, unsigned symbolOffset, - unsigned nSymbol); + SimplexBase(unsigned nVar, bool mustUseBigM); + SimplexBase(unsigned nVar, bool mustUseBigM, + const llvm::SmallBitVector &isSymbol); enum class Orientation { Row, Column }; @@ -422,12 +427,16 @@ unsigned getSnapshot() { return SimplexBase::getSnapshotBasis(); } protected: - LexSimplexBase(unsigned nVar, unsigned symbolOffset, unsigned nSymbol) - : SimplexBase(nVar, /*mustUseBigM=*/true, symbolOffset, nSymbol) {} + LexSimplexBase(unsigned nVar) : SimplexBase(nVar, /*mustUseBigM=*/true) {} + LexSimplexBase(unsigned nVar, const llvm::SmallBitVector &isSymbol) + : SimplexBase(nVar, /*mustUseBigM=*/true, isSymbol) {} explicit LexSimplexBase(const IntegerRelation &constraints) - : LexSimplexBase(constraints.getNumVars(), - constraints.getVarKindOffset(VarKind::Symbol), - constraints.getNumSymbolVars()) { + : LexSimplexBase(constraints.getNumVars()) { + intersectIntegerRelation(constraints); + } + explicit LexSimplexBase(const IntegerRelation &constraints, + const llvm::SmallBitVector &isSymbol) + : LexSimplexBase(constraints.getNumVars(), isSymbol) { intersectIntegerRelation(constraints); } @@ -470,13 +479,12 @@ /// provides support for integer-exact redundancy and separateness checks. class LexSimplex : public LexSimplexBase { public: - explicit LexSimplex(unsigned nVar) - : LexSimplexBase(nVar, /*symbolOffset=*/0, /*nSymbol=*/0) {} + explicit LexSimplex(unsigned nVar) : LexSimplexBase(nVar) {} + // Note that LexSimplex does NOT support symbolic lexmin; + // use SymbolicLexSimplex if that is required. LexSimplex ignores the VarKinds + // of the passed IntegerRelation. Symbols will be treated as ordinary vars. 
explicit LexSimplex(const IntegerRelation &constraints) - : LexSimplexBase(constraints) { - assert(constraints.getNumSymbolVars() == 0 && - "LexSimplex does not support symbols!"); - } + : LexSimplexBase(constraints) {} /// Return the lexicographically minimum rational solution to the constraints. MaybeOptimum> findRationalLexMin(); @@ -521,10 +529,9 @@ /// Represents the result of a symbolic lexicographic minimization computation. struct SymbolicLexMin { - SymbolicLexMin(unsigned nSymbols, unsigned nNonSymbols) - : lexmin(PresburgerSpace::getSetSpace(nSymbols), nNonSymbols), - unboundedDomain( - PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(nSymbols))) {} + SymbolicLexMin(const PresburgerSpace &domainSpace, unsigned numOutputs) + : lexmin(domainSpace, numOutputs), + unboundedDomain(PresburgerSet::getEmpty(domainSpace)) {} /// This maps assignments of symbols to the corresponding lexmin. /// Takes no value when no integer sample exists for the assignment or if the @@ -569,31 +576,39 @@ /// `constraints` is the set for which the symbolic lexmin will be computed. /// `symbolDomain` is the set of values of the symbols for which the lexmin /// will be computed. `symbolDomain` should have a dim var for every symbol in - /// `constraints`, and no other vars. + /// `constraints`, and no other vars. `isSymbol` specifies which vars of + /// `constraints` should be considered as symbols. + /// + /// The resulting SymbolicLexMin's space will be compatible with that of + /// symbolDomain. 
SymbolicLexSimplex(const IntegerRelation &constraints, - const IntegerPolyhedron &symbolDomain) - : SymbolicLexSimplex(constraints, - constraints.getVarKindOffset(VarKind::Symbol), - symbolDomain) { - assert(constraints.getNumSymbolVars() == symbolDomain.getNumVars()); + const IntegerPolyhedron &symbolDomain, + const llvm::SmallBitVector &isSymbol) + : LexSimplexBase(constraints, isSymbol), domainPoly(symbolDomain), + domainSimplex(symbolDomain) { + // TODO consider supporting this case. It amounts + // to just returning the input constraints. + assert(domainPoly.getNumVars() > 0 && + "there must be some non-symbols to optimize!"); } - /// An overload to select some other subrange of ids as symbols for lexmin. + /// An overload to select some subrange of ids as symbols for lexmin. /// The symbol ids are the range of ids with absolute index /// [symbolOffset, symbolOffset + symbolDomain.getNumVars()) - /// symbolDomain should only have dim ids. SymbolicLexSimplex(const IntegerRelation &constraints, unsigned symbolOffset, const IntegerPolyhedron &symbolDomain) - : LexSimplexBase(/*nVar=*/constraints.getNumVars(), symbolOffset, - symbolDomain.getNumVars()), - domainPoly(symbolDomain), domainSimplex(symbolDomain) { - // TODO consider supporting this case. It amounts - // to just returning the input constraints. - assert(domainPoly.getNumVars() > 0 && - "there must be some non-symbols to optimize!"); - assert(domainPoly.getNumVars() == domainPoly.getNumDimVars()); - intersectIntegerRelation(constraints); - } + : SymbolicLexSimplex(constraints, symbolDomain, + getSubrangeBitVector(constraints.getNumVars(), + symbolOffset, + symbolDomain.getNumVars())) {} + + /// An overload to select the symbols of `constraints` as symbols for lexmin. 
+ SymbolicLexSimplex(const IntegerRelation &constraints, + const IntegerPolyhedron &symbolDomain) + : SymbolicLexSimplex(constraints, constraints.getVarKindOffset(VarKind::Symbol), symbolDomain) { + assert(constraints.getNumSymbolVars() == symbolDomain.getNumVars() && "symbolDomain must have as many vars as constraints has symbols!"); + } + /// The lexmin will be stored as a function `lexmin` from symbols to /// non-symbols in the result. @@ -678,9 +693,7 @@ enum class Direction { Up, Down }; Simplex() = delete; - explicit Simplex(unsigned nVar) - : SimplexBase(nVar, /*mustUseBigM=*/false, /*symbolOffset=*/0, - /*nSymbol=*/0) {} + explicit Simplex(unsigned nVar) : SimplexBase(nVar, /*mustUseBigM=*/false) {} explicit Simplex(const IntegerRelation &constraints) : Simplex(constraints.getNumVars()) { intersectIntegerRelation(constraints); diff --git a/mlir/include/mlir/Analysis/Presburger/Utils.h b/mlir/include/mlir/Analysis/Presburger/Utils.h --- a/mlir/include/mlir/Analysis/Presburger/Utils.h +++ b/mlir/include/mlir/Analysis/Presburger/Utils.h @@ -15,6 +15,7 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallBitVector.h" namespace mlir { namespace presburger { @@ -120,6 +121,9 @@ SmallVector getDivLowerBound(ArrayRef dividend, int64_t divisor, unsigned localVarIdx); +llvm::SmallBitVector getSubrangeBitVector(unsigned len, unsigned setOffset, + unsigned numSet); + /// Check if the pos^th variable can be expressed as a floordiv of an affine /// function of other variables (where the divisor is a positive constant). 
/// `foundRepr` contains a boolean for each variable indicating if the diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp --- a/mlir/lib/Analysis/Presburger/Simplex.cpp +++ b/mlir/lib/Analysis/Presburger/Simplex.cpp @@ -31,23 +31,28 @@ return res; } -SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM, unsigned symbolOffset, - unsigned nSymbol) - : usingBigM(mustUseBigM), nRedundant(0), nSymbol(nSymbol), +SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM) + : usingBigM(mustUseBigM), nRedundant(0), nSymbol(0), tableau(0, getNumFixedCols() + nVar), empty(false) { - assert(symbolOffset + nSymbol <= nVar); - colUnknown.insert(colUnknown.begin(), getNumFixedCols(), nullIndex); for (unsigned i = 0; i < nVar; ++i) { var.emplace_back(Orientation::Column, /*restricted=*/false, /*pos=*/getNumFixedCols() + i); colUnknown.push_back(i); } +} - // Move the symbols to be in columns [3, 3 + nSymbol). - for (unsigned i = 0; i < nSymbol; ++i) { - var[symbolOffset + i].isSymbol = true; - swapColumns(var[symbolOffset + i].pos, getNumFixedCols() + i); +SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM, + const llvm::SmallBitVector &isSymbol) + : SimplexBase(nVar, mustUseBigM) { + assert(isSymbol.size() == nVar && "invalid bitmask!"); + // Invariant: nSymbol is the number of symbols that have been marked + // already and these occupy the columns + // [getNumFixedCols(), getNumFixedCols() + nSymbol). 
+ for (unsigned symbolIdx : isSymbol.set_bits()) { + var[symbolIdx].isSymbol = true; + swapColumns(var[symbolIdx].pos, getNumFixedCols() + nSymbol); + ++nSymbol; } } @@ -502,7 +507,7 @@ } SymbolicLexMin SymbolicLexSimplex::computeSymbolicIntegerLexMin() { - SymbolicLexMin result(nSymbol, var.size() - nSymbol); + SymbolicLexMin result(domainPoly.getSpace(), var.size() - nSymbol); /// The algorithm is more naturally expressed recursively, but we implement /// it iteratively here to avoid potential issues with stack overflows in the diff --git a/mlir/lib/Analysis/Presburger/Utils.cpp b/mlir/lib/Analysis/Presburger/Utils.cpp --- a/mlir/lib/Analysis/Presburger/Utils.cpp +++ b/mlir/lib/Analysis/Presburger/Utils.cpp @@ -253,6 +253,14 @@ return repr; } +llvm::SmallBitVector presburger::getSubrangeBitVector(unsigned len, + unsigned setOffset, + unsigned numSet) { + llvm::SmallBitVector vec(len, false); + vec.set(setOffset, setOffset + numSet); + return vec; +} + void presburger::removeDuplicateDivs( std::vector> &divs, SmallVectorImpl &denoms, unsigned localOffset,