diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h --- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h +++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h @@ -16,6 +16,7 @@ #include "mlir/Support/TypeID.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallBitVector.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/PointerLikeTypeTraits.h" #include "llvm/Support/raw_ostream.h" @@ -26,7 +27,21 @@ /// Kind of variable. Implementation wise SetDims are treated as Range /// vars, and spaces with no distinction between dimension vars are treated /// as relations with zero domain vars. -enum class VarKind { Symbol, Local, Domain, Range, SetDim = Range }; +enum class VarKind : unsigned { + Symbol = 1 << 0, + Domain = 1 << 1, + Range = 1 << 2, + Local = 1 << 3, + SetDim = Range +}; +/// Note: this reflects the number of distinct VarKinds. SetDim doesn't count. +const unsigned numVarKinds = 4; + +/// These are useful when a mask of VarKinds needs to be passed. +const unsigned varKindFlagSymbol = unsigned(VarKind::Symbol); +const unsigned varKindFlagDomain = unsigned(VarKind::Domain); +const unsigned varKindFlagRange = unsigned(VarKind::Range); +const unsigned varKindFlagLocal = unsigned(VarKind::Local); /// PresburgerSpace is the space of all possible values of a tuple of integer /// valued variables/variables. Each variable has one of the three types: @@ -153,6 +168,11 @@ /// split become dimensions. void setVarSymbolSeperation(unsigned newSymbolCount); + /// Given a mask representing a subset of VarKinds, return a mask representing + /// the subset of this space's variables that are of one of the specified + /// VarKinds. 
+ llvm::SmallBitVector getVarMaskFromVarKindMask(unsigned varKindMask) const; + void print(llvm::raw_ostream &os) const; void dump() const; diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h --- a/mlir/include/mlir/Analysis/Presburger/Simplex.h +++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h @@ -23,6 +23,7 @@ #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/StringSaver.h" #include "llvm/Support/raw_ostream.h" @@ -210,14 +211,18 @@ protected: /// Construct a SimplexBase with the specified number of variables and fixed - /// columns. + /// columns. The first overload should be used when there are no symbols. + /// With the second overload, `isSymbol` is a bitmask denoting which vars + /// are symbols; exactly the vars whose bit is set are marked as symbols. + /// The size of `isSymbol` must be `nVar`. /// /// For example, Simplex uses two fixed columns: the denominator and the /// constant term, whereas LexSimplex has an extra fixed column for the /// so-called big M parameter. For more information see the documentation for /// LexSimplex. 
- SimplexBase(unsigned nVar, bool mustUseBigM, unsigned symbolOffset, - unsigned nSymbol); + SimplexBase(unsigned nVar, bool mustUseBigM); + SimplexBase(unsigned nVar, bool mustUseBigM, + const llvm::SmallBitVector &isSymbol); enum class Orientation { Row, Column }; @@ -422,12 +427,16 @@ unsigned getSnapshot() { return SimplexBase::getSnapshotBasis(); } protected: - LexSimplexBase(unsigned nVar, unsigned symbolOffset, unsigned nSymbol) - : SimplexBase(nVar, /*mustUseBigM=*/true, symbolOffset, nSymbol) {} + LexSimplexBase(unsigned nVar) : SimplexBase(nVar, /*mustUseBigM=*/true) {} + LexSimplexBase(unsigned nVar, const llvm::SmallBitVector &isSymbol) + : SimplexBase(nVar, /*mustUseBigM=*/true, isSymbol) {} explicit LexSimplexBase(const IntegerRelation &constraints) - : LexSimplexBase(constraints.getNumVars(), - constraints.getVarKindOffset(VarKind::Symbol), - constraints.getNumSymbolVars()) { + : LexSimplexBase(constraints.getNumVars()) { + intersectIntegerRelation(constraints); + } + explicit LexSimplexBase(const IntegerRelation &constraints, + const llvm::SmallBitVector &isSymbol) + : LexSimplexBase(constraints.getNumVars(), isSymbol) { intersectIntegerRelation(constraints); } @@ -470,13 +479,11 @@ /// provides support for integer-exact redundancy and separateness checks. class LexSimplex : public LexSimplexBase { public: - explicit LexSimplex(unsigned nVar) - : LexSimplexBase(nVar, /*symbolOffset=*/0, /*nSymbol=*/0) {} + explicit LexSimplex(unsigned nVar) : LexSimplexBase(nVar) {} + // Note that LexSimplex does NOT support symbols: no var is marked as a + // symbol here, i.e. the symbol-ness of vars in `constraints` is IGNORED. explicit LexSimplex(const IntegerRelation &constraints) - : LexSimplexBase(constraints) { - assert(constraints.getNumSymbolVars() == 0 && - "LexSimplex does not support symbols!"); - } + : LexSimplexBase(constraints) {} /// Return the lexicographically minimum rational solution to the constraints. 
MaybeOptimum> findRationalLexMin(); @@ -521,10 +528,9 @@ /// Represents the result of a symbolic lexicographic minimization computation. struct SymbolicLexMin { - SymbolicLexMin(unsigned nSymbols, unsigned nNonSymbols) - : lexmin(PresburgerSpace::getSetSpace(nSymbols), nNonSymbols), - unboundedDomain( - PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(nSymbols))) {} + SymbolicLexMin(const PresburgerSpace &domainSpace, unsigned numOutputs) + : lexmin(domainSpace, numOutputs), + unboundedDomain(PresburgerSet::getEmpty(domainSpace)) {} /// This maps assignments of symbols to the corresponding lexmin. /// Takes no value when no integer sample exists for the assignment or if the @@ -569,31 +575,42 @@ /// `constraints` is the set for which the symbolic lexmin will be computed. /// `symbolDomain` is the set of values of the symbols for which the lexmin /// will be computed. `symbolDomain` should have a dim var for every symbol in - /// `constraints`, and no other vars. + /// `constraints`, and no other vars. `isSymbol` specifies which vars of + /// `constraints` should be considered as symbols. + /// + /// The resulting SymbolicLexMin's space will be compatible with that of + /// symbolDomain. SymbolicLexSimplex(const IntegerRelation &constraints, - const IntegerPolyhedron &symbolDomain) - : SymbolicLexSimplex(constraints, - constraints.getVarKindOffset(VarKind::Symbol), - symbolDomain) { - assert(constraints.getNumSymbolVars() == symbolDomain.getNumVars()); + const IntegerPolyhedron &symbolDomain, + const llvm::SmallBitVector &isSymbol) + : LexSimplexBase(constraints, isSymbol), domainPoly(symbolDomain), + domainSimplex(symbolDomain) { + // TODO consider supporting this case. It amounts + // to just returning the input constraints. 
+ assert(domainPoly.getNumVars() > 0 && + "there must be some non-symbols to optimize!"); } + /// Same as above, but `isSymbolKindMask` specifies which VarKinds in + /// `constraints` should be considered as symbols; by default, only the + /// vars of kind VarKind::Symbol are considered as symbols. + SymbolicLexSimplex(const IntegerRelation &constraints, + const IntegerPolyhedron &symbolDomain, + unsigned isSymbolKindMask = varKindFlagSymbol) + : SymbolicLexSimplex(constraints, symbolDomain, + constraints.getSpace().getVarMaskFromVarKindMask( + isSymbolKindMask)) {} + /// An overload to select some other subrange of ids as symbols for lexmin. /// The symbol ids are the range of ids with absolute index /// [symbolOffset, symbolOffset + symbolDomain.getNumVars()) /// symbolDomain should only have dim ids. SymbolicLexSimplex(const IntegerRelation &constraints, unsigned symbolOffset, const IntegerPolyhedron &symbolDomain) - : LexSimplexBase(/*nVar=*/constraints.getNumVars(), symbolOffset, - symbolDomain.getNumVars()), - domainPoly(symbolDomain), domainSimplex(symbolDomain) { - // TODO consider supporting this case. It amounts - // to just returning the input constraints. - assert(domainPoly.getNumVars() > 0 && - "there must be some non-symbols to optimize!"); - assert(domainPoly.getNumVars() == domainPoly.getNumDimVars()); - intersectIntegerRelation(constraints); - } + : SymbolicLexSimplex(constraints, symbolDomain, + getSubrangeBitVector(constraints.getNumVars(), + symbolOffset, + symbolDomain.getNumVars())) {} /// The lexmin will be stored as a function `lexmin` from symbols to /// non-symbols in the result. 
@@ -678,9 +695,7 @@ enum class Direction { Up, Down }; Simplex() = delete; - explicit Simplex(unsigned nVar) - : SimplexBase(nVar, /*mustUseBigM=*/false, /*symbolOffset=*/0, - /*nSymbol=*/0) {} + explicit Simplex(unsigned nVar) : SimplexBase(nVar, /*mustUseBigM=*/false) {} explicit Simplex(const IntegerRelation &constraints) : Simplex(constraints.getNumVars()) { intersectIntegerRelation(constraints); diff --git a/mlir/include/mlir/Analysis/Presburger/Utils.h b/mlir/include/mlir/Analysis/Presburger/Utils.h --- a/mlir/include/mlir/Analysis/Presburger/Utils.h +++ b/mlir/include/mlir/Analysis/Presburger/Utils.h @@ -15,6 +15,7 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallBitVector.h" namespace mlir { namespace presburger { @@ -120,6 +121,9 @@ SmallVector getDivLowerBound(ArrayRef dividend, int64_t divisor, unsigned localVarIdx); +llvm::SmallBitVector getSubrangeBitVector(unsigned len, unsigned setOffset, + unsigned numSet); + /// Check if the pos^th variable can be expressed as a floordiv of an affine /// function of other variables (where the divisor is a positive constant). 
/// `foundRepr` contains a boolean for each variable indicating if the diff --git a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp --- a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp @@ -37,6 +37,16 @@ llvm_unreachable("VarKind does not exist!"); } +llvm::SmallBitVector +PresburgerSpace::getVarMaskFromVarKindMask(unsigned varKindMask) const { + llvm::SmallBitVector result(getNumVars(), false); + for (unsigned varKind = 1; varKind < (1 << numVarKinds); varKind <<= 1) + if (varKindMask & varKind) + result.set(getVarKindOffset(VarKind(varKind)), + getVarKindEnd(VarKind(varKind))); + return result; +} + unsigned PresburgerSpace::getVarKindEnd(VarKind kind) const { return getVarKindOffset(kind) + getNumVarKind(kind); } diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp --- a/mlir/lib/Analysis/Presburger/Simplex.cpp +++ b/mlir/lib/Analysis/Presburger/Simplex.cpp @@ -31,23 +31,28 @@ return res; } -SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM, unsigned symbolOffset, - unsigned nSymbol) - : usingBigM(mustUseBigM), nRedundant(0), nSymbol(nSymbol), +SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM) + : usingBigM(mustUseBigM), nRedundant(0), nSymbol(0), tableau(0, getNumFixedCols() + nVar), empty(false) { - assert(symbolOffset + nSymbol <= nVar); - colUnknown.insert(colUnknown.begin(), getNumFixedCols(), nullIndex); for (unsigned i = 0; i < nVar; ++i) { var.emplace_back(Orientation::Column, /*restricted=*/false, /*pos=*/getNumFixedCols() + i); colUnknown.push_back(i); } +} - // Move the symbols to be in columns [3, 3 + nSymbol). 
- for (unsigned i = 0; i < nSymbol; ++i) { - var[symbolOffset + i].isSymbol = true; - swapColumns(var[symbolOffset + i].pos, getNumFixedCols() + i); +SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM, + const llvm::SmallBitVector &isSymbol) + : SimplexBase(nVar, mustUseBigM) { + assert(isSymbol.size() == nVar && "invalid bitmask!"); + // Invariant: nSymbol is the number of symbols that have been marked + // already and these occupy the columns + // [getNumFixedCols(), getNumFixedCols() + nSymbol). + for (unsigned symbolIdx : isSymbol.set_bits()) { + var[symbolIdx].isSymbol = true; + swapColumns(var[symbolIdx].pos, getNumFixedCols() + nSymbol); + ++nSymbol; } } @@ -502,7 +507,7 @@ } SymbolicLexMin SymbolicLexSimplex::computeSymbolicIntegerLexMin() { - SymbolicLexMin result(nSymbol, var.size() - nSymbol); + SymbolicLexMin result(domainPoly.getSpace(), var.size() - nSymbol); /// The algorithm is more naturally expressed recursively, but we implement /// it iteratively here to avoid potential issues with stack overflows in the diff --git a/mlir/lib/Analysis/Presburger/Utils.cpp b/mlir/lib/Analysis/Presburger/Utils.cpp --- a/mlir/lib/Analysis/Presburger/Utils.cpp +++ b/mlir/lib/Analysis/Presburger/Utils.cpp @@ -253,6 +253,14 @@ return repr; } +llvm::SmallBitVector presburger::getSubrangeBitVector(unsigned len, + unsigned setOffset, + unsigned numSet) { + llvm::SmallBitVector vec(len, false); + vec.set(setOffset, setOffset + numSet); + return vec; +} + void presburger::removeDuplicateDivs( std::vector> &divs, SmallVectorImpl &denoms, unsigned localOffset,