diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -57,9 +57,8 @@ /// of constraints and identifiers. IntegerRelation(unsigned numReservedInequalities, unsigned numReservedEqualities, unsigned numReservedCols, - unsigned numDomain, unsigned numRange, unsigned numSymbols, - unsigned numLocals) - : PresburgerSpace(numDomain, numRange, numSymbols, numLocals), + const PresburgerSpace &space) + : PresburgerSpace(space), equalities(0, getNumIds() + 1, numReservedEqualities, numReservedCols), inequalities(0, getNumIds() + 1, numReservedInequalities, numReservedCols) { @@ -67,20 +66,15 @@ } /// Constructs a relation with the specified number of dimensions and symbols. - IntegerRelation(unsigned numDomain = 0, unsigned numRange = 0, - unsigned numSymbols = 0, unsigned numLocals = 0) + IntegerRelation(const PresburgerSpace &space) : IntegerRelation(/*numReservedInequalities=*/0, /*numReservedEqualities=*/0, - /*numReservedCols=*/numDomain + numRange + numSymbols + - numLocals + 1, - numDomain, numRange, numSymbols, numLocals) {} + /*numReservedCols=*/space.getNumIds() + 1, space) {} /// Return a system with no constraints, i.e., one which is satisfied by all /// points. - static IntegerRelation getUniverse(unsigned numDomain = 0, - unsigned numRange = 0, - unsigned numSymbols = 0) { - return IntegerRelation(numDomain, numRange, numSymbols); + static IntegerRelation getUniverse(const PresburgerSpace &space) { + return IntegerRelation(space); } /// Return the kind of this IntegerRelation. @@ -562,25 +556,24 @@ /// of constraints and identifiers. 
IntegerPolyhedron(unsigned numReservedInequalities, unsigned numReservedEqualities, unsigned numReservedCols, - unsigned numDims, unsigned numSymbols, unsigned numLocals) + const PresburgerSpace &space) : IntegerRelation(numReservedInequalities, numReservedEqualities, - numReservedCols, /*numDomain=*/0, /*numRange=*/numDims, - numSymbols, numLocals) {} + numReservedCols, space) { + assert(space.getNumDomainIds() == 0 && + "Number of domain id's should be zero in Set kind space."); + } - /// Constructs a relation with the specified number of dimensions and symbols. - IntegerPolyhedron(unsigned numDims = 0, unsigned numSymbols = 0, - unsigned numLocals = 0) + /// Constructs a relation with the specified number of dimensions and + /// symbols. + IntegerPolyhedron(const PresburgerSpace &space) : IntegerPolyhedron(/*numReservedInequalities=*/0, /*numReservedEqualities=*/0, - /*numReservedCols=*/numDims + numSymbols + numLocals + - 1, - numDims, numSymbols, numLocals) {} + /*numReservedCols=*/space.getNumIds() + 1, space) {} /// Return a system with no constraints, i.e., one which is satisfied by all /// points. - static IntegerPolyhedron getUniverse(unsigned numDims = 0, - unsigned numSymbols = 0) { - return IntegerPolyhedron(numDims, numSymbols); + static IntegerPolyhedron getUniverse(const PresburgerSpace &space) { + return IntegerPolyhedron(space); } /// Return the kind of this IntegerRelation. 
diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h --- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h +++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h @@ -57,9 +57,8 @@ MultiAffineFunction(const IntegerPolyhedron &domain, const Matrix &output) : IntegerPolyhedron(domain), output(output) {} - MultiAffineFunction(const Matrix &output, unsigned numDims, - unsigned numSymbols = 0, unsigned numLocals = 0) - : IntegerPolyhedron(numDims, numSymbols, numLocals), output(output) {} + MultiAffineFunction(const Matrix &output, const PresburgerSpace &space) + : IntegerPolyhedron(space), output(output) {} ~MultiAffineFunction() override = default; Kind getKind() const override { return Kind::MultiAffineFunction; } @@ -137,10 +136,10 @@ /// finding the value of the function at a point. class PWMAFunction : public PresburgerSpace { public: - PWMAFunction(unsigned numDims, unsigned numSymbols, unsigned numOutputs) - : PresburgerSpace(/*numDomain=*/0, /*numRange=*/numDims, numSymbols, - /*numLocals=*/0), - numOutputs(numOutputs) { + PWMAFunction(const PresburgerSpace &space, unsigned numOutputs) + : PresburgerSpace(space), numOutputs(numOutputs) { + assert(getNumDomainIds() == 0 && "Set type space should zero domain ids."); + assert(getNumLocalIds() == 0 && "PWMAFunction cannot have local ids."); assert(numOutputs >= 1 && "The function must output something!"); } diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h --- a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h @@ -37,13 +37,10 @@ class PresburgerRelation : public PresburgerSpace { public: /// Return a universe set of the specified type that contains all points. 
- static PresburgerRelation getUniverse(unsigned numDomain, unsigned numRange, - unsigned numSymbols); + static PresburgerRelation getUniverse(const PresburgerSpace &space); /// Return an empty set of the specified type that contains no points. - static PresburgerRelation getEmpty(unsigned numDomain = 0, - unsigned numRange = 0, - unsigned numSymbols = 0); + static PresburgerRelation getEmpty(const PresburgerSpace &space); explicit PresburgerRelation(const IntegerRelation &disjunct); @@ -119,9 +116,10 @@ protected: /// Construct an empty PresburgerRelation with the specified number of /// dimension and symbols. - PresburgerRelation(unsigned numDomain = 0, unsigned numRange = 0, - unsigned numSymbols = 0) - : PresburgerSpace(numDomain, numRange, numSymbols, /*numLocals=*/0) {} + PresburgerRelation(const PresburgerSpace &space) : PresburgerSpace(space) { + assert(space.getNumLocalIds() == 0 && + "PresburgerRelation cannot have local ids."); + } /// The list of disjuncts that this set is the union of. SmallVector integerRelations; @@ -132,11 +130,10 @@ class PresburgerSet : public PresburgerRelation { public: /// Return a universe set of the specified type that contains all points. - static PresburgerSet getUniverse(unsigned numDims = 0, - unsigned numSymbols = 0); + static PresburgerSet getUniverse(const PresburgerSpace &space); /// Return an empty set of the specified type that contains no points. - static PresburgerSet getEmpty(unsigned numDims = 0, unsigned numSymbols = 0); + static PresburgerSet getEmpty(const PresburgerSpace &space); /// Create a set from a relation. explicit PresburgerSet(const IntegerPolyhedron &disjunct); @@ -154,8 +151,11 @@ protected: /// Construct an empty PresburgerRelation with the specified number of /// dimension and symbols. 
- PresburgerSet(unsigned numDims = 0, unsigned numSymbols = 0) - : PresburgerRelation(/*numDomain=*/0, numDims, numSymbols) {} + PresburgerSet(const PresburgerSpace &space) : PresburgerRelation(space) { + assert(space.getNumDomainIds() == 0 && "Set type cannot have domain ids."); + assert(space.getNumLocalIds() == 0 && + "PresburgerRelation cannot have local ids."); + } }; } // namespace presburger diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h --- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h +++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h @@ -64,10 +64,24 @@ /// identifiers of each kind are equal. class PresburgerSpace { public: - PresburgerSpace(unsigned numDomain = 0, unsigned numRange = 0, - unsigned numSymbols = 0, unsigned numLocals = 0) - : numDomain(numDomain), numRange(numRange), numSymbols(numSymbols), - numLocals(numLocals) {} + static PresburgerSpace getRelationSpace(unsigned numDomain = 0, + unsigned numRange = 0, + unsigned numSymbols = 0, + unsigned numLocals = 0) { + return PresburgerSpace(numDomain, numRange, numSymbols, numLocals); + } + + static PresburgerSpace getSetSpace(unsigned numDims = 0, + unsigned numSymbols = 0, + unsigned numLocals = 0) { + return PresburgerSpace(/*numDomain=*/0, /*numRange=*/numDims, numSymbols, + numLocals); + } + + PresburgerSpace getSpace() const { return *this; } + PresburgerSpace getCompatibleSpace() const { + return PresburgerSpace(numDomain, numRange, numSymbols); + } virtual ~PresburgerSpace() = default; @@ -99,6 +113,9 @@ unsigned getIdKindOverlap(IdKind kind, unsigned idStart, unsigned idLimit) const; + /// Return the IdKind of the id at the specified position. + IdKind getIdKindAt(unsigned pos) const; + /// Insert `num` identifiers of the specified kind at position `pos`. /// Positions are relative to the kind of identifier. 
Return the absolute /// column position (i.e., not relative to the kind of identifier) of the @@ -131,6 +148,12 @@ void print(llvm::raw_ostream &os) const; void dump() const; +protected: + PresburgerSpace(unsigned numDomain = 0, unsigned numRange = 0, + unsigned numSymbols = 0, unsigned numLocals = 0) + : numDomain(numDomain), numRange(numRange), numSymbols(numSymbols), + numLocals(numLocals) {} + private: // Number of identifiers corresponding to domain identifiers. unsigned numDomain; diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h b/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h --- a/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h @@ -65,18 +65,19 @@ unsigned numReservedEqualities, unsigned numReservedCols, unsigned numDims, unsigned numSymbols, unsigned numLocals) - : IntegerPolyhedron(numReservedInequalities, numReservedEqualities, - numReservedCols, numDims, numSymbols, numLocals) {} + : IntegerPolyhedron( + numReservedInequalities, numReservedEqualities, numReservedCols, + PresburgerSpace::getSetSpace(numDims, numSymbols, numLocals)) {} /// Constructs a constraint system with the specified number of /// dimensions and symbols. 
FlatAffineConstraints(unsigned numDims = 0, unsigned numSymbols = 0, unsigned numLocals = 0) - : IntegerPolyhedron(/*numReservedInequalities=*/0, - /*numReservedEqualities=*/0, - /*numReservedCols=*/numDims + numSymbols + numLocals + - 1, - numDims, numSymbols, numLocals) {} + : FlatAffineConstraints(/*numReservedInequalities=*/0, + /*numReservedEqualities=*/0, + /*numReservedCols=*/numDims + numSymbols + + numLocals + 1, + numDims, numSymbols, numLocals) {} explicit FlatAffineConstraints(const IntegerPolyhedron &poly) : IntegerPolyhedron(poly) {} diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -1702,20 +1702,14 @@ } } - // Set the number of dimensions, symbols, locals in the resulting system. - unsigned newNumDomain = - getNumDomainIds() - getIdKindOverlap(IdKind::Domain, pos, pos + 1); - unsigned newNumRange = - getNumRangeIds() - getIdKindOverlap(IdKind::Range, pos, pos + 1); - unsigned newNumSymbols = - getNumSymbolIds() - getIdKindOverlap(IdKind::Symbol, pos, pos + 1); - unsigned newNumLocals = - getNumLocalIds() - getIdKindOverlap(IdKind::Local, pos, pos + 1); + PresburgerSpace newSpace = getSpace(); + IdKind idKindRemove = newSpace.getIdKindAt(pos); + unsigned relativePos = pos - newSpace.getIdKindOffset(idKindRemove); + newSpace.removeIdRange(idKindRemove, relativePos, relativePos + 1); /// Create the new system which has one identifier less. IntegerRelation newRel(lbIndices.size() * ubIndices.size() + nbIndices.size(), - getNumEqualities(), getNumCols() - 1, newNumDomain, - newNumRange, newNumSymbols, newNumLocals); + getNumEqualities(), getNumCols() - 1, newSpace); // This will be used to check if the elimination was integer exact. unsigned lcmProducts = 1; @@ -1864,10 +1858,9 @@ } // namespace // Returns constraints that are common to both A & B. 
-static void getCommonConstraints(const IntegerRelation &a, - const IntegerRelation &b, IntegerRelation &c) { - c = IntegerRelation(a.getNumDomainIds(), a.getNumRangeIds(), - a.getNumSymbolIds(), a.getNumLocalIds()); +static IntegerRelation getCommonConstraints(const IntegerRelation &a, + const IntegerRelation &b) { + IntegerRelation c(a.getSpace()); // a naive O(n^2) check should be enough here given the input sizes. for (unsigned r = 0, e = a.getNumInequalities(); r < e; ++r) { for (unsigned s = 0, f = b.getNumInequalities(); s < f; ++s) { @@ -1885,6 +1878,8 @@ } } } + + return c; } // Computes the bounding box with respect to 'other' by finding the min of the @@ -1896,8 +1891,7 @@ // Get the constraints common to both systems; these will be added as is to // the union. - IntegerRelation commonCst; - getCommonConstraints(*this, otherCst, commonCst); + IntegerRelation commonCst = getCommonConstraints(*this, otherCst); std::vector> boundingLbs; std::vector> boundingUbs; diff --git a/mlir/lib/Analysis/Presburger/LinearTransform.cpp b/mlir/lib/Analysis/Presburger/LinearTransform.cpp --- a/mlir/lib/Analysis/Presburger/LinearTransform.cpp +++ b/mlir/lib/Analysis/Presburger/LinearTransform.cpp @@ -113,8 +113,7 @@ } IntegerRelation LinearTransform::applyTo(const IntegerRelation &rel) const { - IntegerRelation result(rel.getNumDomainIds(), rel.getNumRangeIds(), - rel.getNumSymbolIds(), rel.getNumLocalIds()); + IntegerRelation result(rel.getSpace()); for (unsigned i = 0, e = rel.getNumEqualities(); i < e; ++i) { ArrayRef eq = rel.getEquality(i); diff --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp --- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp +++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp @@ -27,8 +27,7 @@ } PresburgerSet PWMAFunction::getDomain() const { - PresburgerSet domain = - PresburgerSet::getEmpty(getNumDimIds(), getNumSymbolIds()); + PresburgerSet domain = PresburgerSet::getEmpty(getSpace()); for 
(const MultiAffineFunction &piece : pieces) domain.unionInPlace(piece.getDomain()); return domain; diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp --- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp @@ -17,7 +17,7 @@ using namespace presburger; PresburgerRelation::PresburgerRelation(const IntegerRelation &disjunct) - : PresburgerSpace(disjunct) { + : PresburgerSpace(disjunct.getCompatibleSpace()) { unionInPlace(disjunct); } @@ -67,19 +67,15 @@ }); } -PresburgerRelation PresburgerRelation::getUniverse(unsigned numDomain, - unsigned numRange, - unsigned numSymbols) { - PresburgerRelation result(numDomain, numRange, numSymbols); - result.unionInPlace( - IntegerRelation::getUniverse(numDomain, numRange, numSymbols)); +PresburgerRelation +PresburgerRelation::getUniverse(const PresburgerSpace &space) { + PresburgerRelation result(space); + result.unionInPlace(IntegerRelation::getUniverse(space)); return result; } -PresburgerRelation PresburgerRelation::getEmpty(unsigned numDomain, - unsigned numRange, - unsigned numSymbols) { - return PresburgerRelation(numDomain, numRange, numSymbols); +PresburgerRelation PresburgerRelation::getEmpty(const PresburgerSpace &space) { + return PresburgerRelation(space); } // Return the intersection of this set with the given set. 
@@ -93,8 +89,7 @@ PresburgerRelation::intersect(const PresburgerRelation &set) const { assert(isSpaceCompatible(set) && "Spaces should match"); - PresburgerRelation result(getNumDomainIds(), getNumRangeIds(), - getNumSymbolIds()); + PresburgerRelation result(getSpace()); for (const IntegerRelation &csA : integerRelations) { for (const IntegerRelation &csB : set.integerRelations) { IntegerRelation intersection = csA.intersect(csB); @@ -283,13 +278,10 @@ const PresburgerRelation &set) { assert(disjunct.isSpaceCompatible(set) && "Spaces should match"); if (disjunct.isEmptyByGCDTest()) - return PresburgerRelation::getEmpty(disjunct.getNumDomainIds(), - disjunct.getNumRangeIds(), - disjunct.getNumSymbolIds()); + return PresburgerRelation::getEmpty(disjunct.getCompatibleSpace()); - PresburgerRelation result = PresburgerRelation::getEmpty( - disjunct.getNumDomainIds(), disjunct.getNumRangeIds(), - disjunct.getNumSymbolIds()); + PresburgerRelation result = + PresburgerRelation::getEmpty(disjunct.getCompatibleSpace()); Simplex simplex(disjunct); subtractRecursively(disjunct, simplex, set, 0, result); return result; @@ -297,10 +289,7 @@ /// Return the complement of this set. PresburgerRelation PresburgerRelation::complement() const { - return getSetDifference(IntegerRelation::getUniverse(getNumDomainIds(), - getNumRangeIds(), - getNumSymbolIds()), - *this); + return getSetDifference(IntegerRelation::getUniverse(getSpace()), *this); } /// Return the result of subtract the given set from this set, i.e., @@ -308,8 +297,7 @@ PresburgerRelation PresburgerRelation::subtract(const PresburgerRelation &set) const { assert(isSpaceCompatible(set) && "Spaces should match"); - PresburgerRelation result(getNumDomainIds(), getNumRangeIds(), - getNumSymbolIds()); + PresburgerRelation result(getSpace()); // We compute (U_i t_i) \ (U_i set_i) as U_i (t_i \ V_i set_i). 
for (const IntegerRelation &disjunct : integerRelations) result.unionInPlace(getSetDifference(disjunct, set)); @@ -505,7 +493,8 @@ } PresburgerRelation newSet = - PresburgerRelation::getEmpty(numDomainIds, numRangeIds, numSymbolIds); + PresburgerRelation::getEmpty(PresburgerSpace::getRelationSpace( + numDomainIds, numRangeIds, numSymbolIds)); for (unsigned i = 0, e = disjuncts.size(); i < e; ++i) newSet.unionInPlace(disjuncts[i]); @@ -584,8 +573,7 @@ return !isFacetContained(curr, simp); })) return failure(); - IntegerRelation newSet(disjunct.getNumDomainIds(), disjunct.getNumRangeIds(), - disjunct.getNumSymbolIds(), disjunct.getNumLocalIds()); + IntegerRelation newSet(disjunct.getSpace()); for (ArrayRef curr : redundantIneqsA) newSet.addInequality(curr); @@ -707,15 +695,14 @@ void PresburgerRelation::dump() const { print(llvm::errs()); } -PresburgerSet PresburgerSet::getUniverse(unsigned numDims, - unsigned numSymbols) { - PresburgerSet result(numDims, numSymbols); - result.unionInPlace(IntegerPolyhedron::getUniverse(numDims, numSymbols)); +PresburgerSet PresburgerSet::getUniverse(const PresburgerSpace &space) { + PresburgerSet result(space); + result.unionInPlace(IntegerPolyhedron::getUniverse(space)); return result; } -PresburgerSet PresburgerSet::getEmpty(unsigned numDims, unsigned numSymbols) { - return PresburgerSet(numDims, numSymbols); +PresburgerSet PresburgerSet::getEmpty(const PresburgerSpace &space) { + return PresburgerSet(space); } PresburgerSet::PresburgerSet(const IntegerPolyhedron &disjunct) diff --git a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp --- a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp @@ -56,6 +56,19 @@ return overlapEnd - overlapStart; } +IdKind PresburgerSpace::getIdKindAt(unsigned pos) const { + assert(pos < getNumIds() && "`pos` should represent a valid id position"); + if (pos < getIdKindEnd(IdKind::Domain)) + return 
IdKind::Domain; + if (pos < getIdKindEnd(IdKind::Range)) + return IdKind::Range; + if (pos < getIdKindEnd(IdKind::Symbol)) + return IdKind::Symbol; + if (pos < getIdKindEnd(IdKind::Local)) + return IdKind::Local; + llvm_unreachable("`pos` should represent a valid id position"); +} + unsigned PresburgerSpace::insertId(IdKind kind, unsigned pos, unsigned num) { assert(pos <= getNumIdKind(kind)); diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp --- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp @@ -158,8 +158,9 @@ FlatAffineConstraints::FlatAffineConstraints(IntegerSet set) : IntegerPolyhedron(set.getNumInequalities(), set.getNumEqualities(), set.getNumDims() + set.getNumSymbols() + 1, - set.getNumDims(), set.getNumSymbols(), - /*numLocals=*/0) { + PresburgerSpace::getSetSpace(set.getNumDims(), + set.getNumSymbols(), + /*numLocals=*/0)) { // Flatten expressions and add them to the constraint system. 
std::vector> flatExprs; diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp --- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp +++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp @@ -28,8 +28,9 @@ makeSetFromConstraints(unsigned ids, ArrayRef> ineqs, ArrayRef> eqs, unsigned syms = 0) { - IntegerPolyhedron set(ineqs.size(), eqs.size(), ids + 1, ids - syms, syms, - /*numLocals=*/0); + IntegerPolyhedron set( + ineqs.size(), eqs.size(), ids + 1, + PresburgerSpace::getSetSpace(ids - syms, syms, /*numLocals=*/0)); for (const auto &eq : eqs) set.addEquality(eq); for (const auto &ineq : ineqs) @@ -178,7 +179,7 @@ } TEST(IntegerPolyhedronTest, removeIdRange) { - IntegerPolyhedron set(3, 2, 1); + IntegerPolyhedron set(PresburgerSpace::getSetSpace(3, 2, 1)); set.addInequality({10, 11, 12, 20, 21, 30, 40}); set.removeId(IdKind::Symbol, 1); @@ -572,7 +573,7 @@ } TEST(IntegerPolyhedronTest, addConstantUpperBound) { - IntegerPolyhedron poly(2); + IntegerPolyhedron poly(PresburgerSpace::getSetSpace(2)); poly.addBound(IntegerPolyhedron::UB, 0, 1); EXPECT_EQ(poly.atIneq(0, 0), -1); EXPECT_EQ(poly.atIneq(0, 1), 0); @@ -585,7 +586,7 @@ } TEST(IntegerPolyhedronTest, addConstantLowerBound) { - IntegerPolyhedron poly(2); + IntegerPolyhedron poly(PresburgerSpace::getSetSpace(2)); poly.addBound(IntegerPolyhedron::LB, 0, 1); EXPECT_EQ(poly.atIneq(0, 0), 1); EXPECT_EQ(poly.atIneq(0, 1), 0); @@ -626,7 +627,7 @@ } TEST(IntegerPolyhedronTest, computeLocalReprSimple) { - IntegerPolyhedron poly(1); + IntegerPolyhedron poly(PresburgerSpace::getSetSpace(1)); poly.addLocalFloorDiv({1, 4}, 10); poly.addLocalFloorDiv({1, 0, 100}, 10); @@ -641,7 +642,7 @@ } TEST(IntegerPolyhedronTest, computeLocalReprConstantFloorDiv) { - IntegerPolyhedron poly(4); + IntegerPolyhedron poly(PresburgerSpace::getSetSpace(4)); poly.addInequality({1, 0, 3, 1, 2}); poly.addInequality({1, 2, -8, 1, 10}); @@ 
-659,7 +660,7 @@ } TEST(IntegerPolyhedronTest, computeLocalReprRecursive) { - IntegerPolyhedron poly(4); + IntegerPolyhedron poly(PresburgerSpace::getSetSpace(4)); poly.addInequality({1, 0, 3, 1, 2}); poly.addInequality({1, 2, -8, 1, 10}); poly.addEquality({1, 2, -4, 1, 10}); @@ -795,14 +796,14 @@ TEST(IntegerPolyhedronTest, simplifyLocalsTest) { // (x) : (exists y: 2x + y = 1 and y = 2). - IntegerPolyhedron poly(1, 0, 1); + IntegerPolyhedron poly(PresburgerSpace::getSetSpace(1, 0, 1)); poly.addEquality({2, 1, -1}); poly.addEquality({0, 1, -2}); EXPECT_TRUE(poly.isEmpty()); // (x) : (exists y, z, w: 3x + y = 1 and 2y = z and 3y = w and z = w). - IntegerPolyhedron poly2(1, 0, 3); + IntegerPolyhedron poly2(PresburgerSpace::getSetSpace(1, 0, 3)); poly2.addEquality({3, 1, 0, 0, -1}); poly2.addEquality({0, 2, -1, 0, 0}); poly2.addEquality({0, 3, 0, -1, 0}); @@ -811,7 +812,7 @@ EXPECT_TRUE(poly2.isEmpty()); // (x) : (exists y: x >= y + 1 and 2x + y = 0 and y >= -1). - IntegerPolyhedron poly3(1, 0, 1); + IntegerPolyhedron poly3(PresburgerSpace::getSetSpace(1, 0, 1)); poly3.addInequality({1, -1, -1}); poly3.addInequality({0, 1, 1}); poly3.addEquality({2, 1, 0}); @@ -822,13 +823,13 @@ TEST(IntegerPolyhedronTest, mergeDivisionsSimple) { { // (x) : (exists z, y = [x / 2] : x = 3y and x + z + 1 >= 0). - IntegerPolyhedron poly1(1, 0, 1); + IntegerPolyhedron poly1(PresburgerSpace::getSetSpace(1, 0, 1)); poly1.addLocalFloorDiv({1, 0, 0}, 2); // y = [x / 2]. poly1.addEquality({1, 0, -3, 0}); // x = 3y. poly1.addInequality({1, 1, 0, 1}); // x + z + 1 >= 0. // (x) : (exists y = [x / 2], z : x = 5y). - IntegerPolyhedron poly2(1); + IntegerPolyhedron poly2(PresburgerSpace::getSetSpace(1)); poly2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2]. poly2.addEquality({1, -5, 0}); // x = 5y. poly2.appendId(IdKind::Local); // Add local id z. @@ -845,13 +846,13 @@ { // (x) : (exists z = [x / 5], y = [x / 2] : x = 3y). 
- IntegerPolyhedron poly1(1); + IntegerPolyhedron poly1(PresburgerSpace::getSetSpace(1)); poly1.addLocalFloorDiv({1, 0}, 5); // z = [x / 5]. poly1.addLocalFloorDiv({1, 0, 0}, 2); // y = [x / 2]. poly1.addEquality({1, 0, -3, 0}); // x = 3y. // (x) : (exists y = [x / 2], z = [x / 5]: x = 5z). - IntegerPolyhedron poly2(1); + IntegerPolyhedron poly2(PresburgerSpace::getSetSpace(1)); poly2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2]. poly2.addLocalFloorDiv({1, 0, 0}, 5); // z = [x / 5]. poly2.addEquality({1, 0, -5, 0}); // x = 5z. @@ -869,14 +870,14 @@ { // Division Normalization test. // (x) : (exists z, y = [x / 2] : x = 3y and x + z + 1 >= 0). - IntegerPolyhedron poly1(1, 0, 1); + IntegerPolyhedron poly1(PresburgerSpace::getSetSpace(1, 0, 1)); // This division would be normalized. poly1.addLocalFloorDiv({3, 0, 0}, 6); // y = [3x / 6] -> [x/2]. poly1.addEquality({1, 0, -3, 0}); // x = 3z. poly1.addInequality({1, 1, 0, 1}); // x + y + 1 >= 0. // (x) : (exists y = [x / 2], z : x = 5y). - IntegerPolyhedron poly2(1); + IntegerPolyhedron poly2(PresburgerSpace::getSetSpace(1)); poly2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2]. poly2.addEquality({1, -5, 0}); // x = 5y. poly2.appendId(IdKind::Local); // Add local id z. @@ -895,13 +896,13 @@ TEST(IntegerPolyhedronTest, mergeDivisionsNestedDivsions) { { // (x) : (exists y = [x / 2], z = [x + y / 3]: y + z >= x). - IntegerPolyhedron poly1(1); + IntegerPolyhedron poly1(PresburgerSpace::getSetSpace(1)); poly1.addLocalFloorDiv({1, 0}, 2); // y = [x / 2]. poly1.addLocalFloorDiv({1, 1, 0}, 3); // z = [x + y / 3]. poly1.addInequality({-1, 1, 1, 0}); // y + z >= x. // (x) : (exists y = [x / 2], z = [x + y / 3]: y + z <= x). - IntegerPolyhedron poly2(1); + IntegerPolyhedron poly2(PresburgerSpace::getSetSpace(1)); poly2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2]. poly2.addLocalFloorDiv({1, 1, 0}, 3); // z = [x + y / 3]. poly2.addInequality({1, -1, -1, 0}); // y + z <= x. 
@@ -918,14 +919,14 @@ { // (x) : (exists y = [x / 2], z = [x + y / 3], w = [z + 1 / 5]: y + z >= x). - IntegerPolyhedron poly1(1); + IntegerPolyhedron poly1(PresburgerSpace::getSetSpace(1)); poly1.addLocalFloorDiv({1, 0}, 2); // y = [x / 2]. poly1.addLocalFloorDiv({1, 1, 0}, 3); // z = [x + y / 3]. poly1.addLocalFloorDiv({0, 0, 1, 1}, 5); // w = [z + 1 / 5]. poly1.addInequality({-1, 1, 1, 0, 0}); // y + z >= x. // (x) : (exists y = [x / 2], z = [x + y / 3], w = [z + 1 / 5]: y + z <= x). - IntegerPolyhedron poly2(1); + IntegerPolyhedron poly2(PresburgerSpace::getSetSpace(1)); poly2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2]. poly2.addLocalFloorDiv({1, 1, 0}, 3); // z = [x + y / 3]. poly2.addLocalFloorDiv({0, 0, 1, 1}, 5); // w = [z + 1 / 5]. @@ -942,13 +943,13 @@ } { // (x) : (exists y = [x / 2], z = [x + y / 3]: y + z >= x). - IntegerPolyhedron poly1(1); + IntegerPolyhedron poly1(PresburgerSpace::getSetSpace(1)); poly1.addLocalFloorDiv({2, 0}, 4); // y = [2x / 4] -> [x / 2]. poly1.addLocalFloorDiv({1, 1, 0}, 3); // z = [x + y / 3]. poly1.addInequality({-1, 1, 1, 0}); // y + z >= x. // (x) : (exists y = [x / 2], z = [x + y / 3]: y + z <= x). - IntegerPolyhedron poly2(1); + IntegerPolyhedron poly2(PresburgerSpace::getSetSpace(1)); poly2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2]. // This division would be normalized. poly2.addLocalFloorDiv({3, 3, 0}, 9); // z = [3x + 3y / 9] -> [x + y / 3]. @@ -968,13 +969,13 @@ TEST(IntegerPolyhedronTest, mergeDivisionsConstants) { { // (x) : (exists y = [x + 1 / 3], z = [x + 2 / 3]: y + z >= x). - IntegerPolyhedron poly1(1); + IntegerPolyhedron poly1(PresburgerSpace::getSetSpace(1)); poly1.addLocalFloorDiv({1, 1}, 2); // y = [x + 1 / 2]. poly1.addLocalFloorDiv({1, 0, 2}, 3); // z = [x + 2 / 3]. poly1.addInequality({-1, 1, 1, 0}); // y + z >= x. // (x) : (exists y = [x + 1 / 3], z = [x + 2 / 3]: y + z <= x). 
- IntegerPolyhedron poly2(1); + IntegerPolyhedron poly2(PresburgerSpace::getSetSpace(1)); poly2.addLocalFloorDiv({1, 1}, 2); // y = [x + 1 / 2]. poly2.addLocalFloorDiv({1, 0, 2}, 3); // z = [x + 2 / 3]. poly2.addInequality({1, -1, -1, 0}); // y + z <= x. @@ -990,14 +991,14 @@ } { // (x) : (exists y = [x + 1 / 3], z = [x + 2 / 3]: y + z >= x). - IntegerPolyhedron poly1(1); + IntegerPolyhedron poly1(PresburgerSpace::getSetSpace(1)); poly1.addLocalFloorDiv({1, 1}, 2); // y = [x + 1 / 2]. // Normalization test. poly1.addLocalFloorDiv({3, 0, 6}, 9); // z = [3x + 6 / 9] -> [x + 2 / 3]. poly1.addInequality({-1, 1, 1, 0}); // y + z >= x. // (x) : (exists y = [x + 1 / 3], z = [x + 2 / 3]: y + z <= x). - IntegerPolyhedron poly2(1); + IntegerPolyhedron poly2(PresburgerSpace::getSetSpace(1)); // Normalization test. poly2.addLocalFloorDiv({2, 2}, 4); // y = [2x + 2 / 4] -> [x + 1 / 2]. poly2.addLocalFloorDiv({1, 0, 2}, 3); // z = [x + 2 / 3]. @@ -1016,14 +1017,14 @@ TEST(IntegerPolyhedronTest, negativeDividends) { // (x) : (exists y = [-x + 1 / 2], z = [-x - 2 / 3]: y + z >= x). - IntegerPolyhedron poly1(1); + IntegerPolyhedron poly1(PresburgerSpace::getSetSpace(1)); poly1.addLocalFloorDiv({-1, 1}, 2); // y = [x + 1 / 2]. // Normalization test with negative dividends poly1.addLocalFloorDiv({-3, 0, -6}, 9); // z = [3x + 6 / 9] -> [x + 2 / 3]. poly1.addInequality({-1, 1, 1, 0}); // y + z >= x. // (x) : (exists y = [x + 1 / 3], z = [x + 2 / 3]: y + z <= x). - IntegerPolyhedron poly2(1); + IntegerPolyhedron poly2(PresburgerSpace::getSetSpace(1)); // Normalization test. poly2.addLocalFloorDiv({-2, 2}, 4); // y = [-2x + 2 / 4] -> [-x + 1 / 2]. poly2.addLocalFloorDiv({-1, 0, -2}, 3); // z = [-x - 2 / 3]. @@ -1206,7 +1207,7 @@ TEST(IntegerPolyhedronTest, truncateEqualityRegressionTest) { // IntegerRelation::truncate was truncating inequalities to the number of // equalities. 
- IntegerRelation set(1); + IntegerRelation set(PresburgerSpace::getSetSpace(1)); IntegerRelation::CountsSnapshot snapshot = set.getCounts(); set.addEquality({1, 0}); set.truncate(snapshot); diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp --- a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp +++ b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp @@ -89,7 +89,8 @@ /// local ids. static PresburgerSet makeSetFromPoly(unsigned numDims, ArrayRef polys) { - PresburgerSet set = PresburgerSet::getEmpty(numDims); + PresburgerSet set = + PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(numDims)); for (const IntegerPolyhedron &poly : polys) set.unionInPlace(poly); return set; @@ -131,23 +132,26 @@ {"(x) : (x - 2 >= 0, -x + 8 >= 0)", "(x) : (x - 10 >= 0, -x + 20 >= 0)"}); // Universe union set. - testUnionAtPoints(PresburgerSet::getUniverse(1), set, - {{1}, {2}, {8}, {9}, {10}, {20}, {21}}); + testUnionAtPoints(PresburgerSet::getUniverse(PresburgerSpace::getSetSpace(1)), + set, {{1}, {2}, {8}, {9}, {10}, {20}, {21}}); // empty set union set. - testUnionAtPoints(PresburgerSet::getEmpty(1), set, - {{1}, {2}, {8}, {9}, {10}, {20}, {21}}); + testUnionAtPoints(PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(1)), + set, {{1}, {2}, {8}, {9}, {10}, {20}, {21}}); // empty set union Universe. - testUnionAtPoints(PresburgerSet::getEmpty(1), PresburgerSet::getUniverse(1), + testUnionAtPoints(PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(1)), + PresburgerSet::getUniverse(PresburgerSpace::getSetSpace(1)), {{1}, {2}, {0}, {-1}}); // Universe union empty set. - testUnionAtPoints(PresburgerSet::getUniverse(1), PresburgerSet::getEmpty(1), + testUnionAtPoints(PresburgerSet::getUniverse(PresburgerSpace::getSetSpace(1)), + PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(1)), {{1}, {2}, {0}, {-1}}); // empty set union empty set. 
- testUnionAtPoints(PresburgerSet::getEmpty(1), PresburgerSet::getEmpty(1), + testUnionAtPoints(PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(1)), + PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(1)), {{1}, {2}, {0}, {-1}}); } @@ -157,24 +161,32 @@ {"(x) : (x - 2 >= 0, -x + 8 >= 0)", "(x) : (x - 10 >= 0, -x + 20 >= 0)"}); // Universe intersection set. - testIntersectAtPoints(PresburgerSet::getUniverse(1), set, - {{1}, {2}, {8}, {9}, {10}, {20}, {21}}); + testIntersectAtPoints( + PresburgerSet::getUniverse(PresburgerSpace::getSetSpace(1)), set, + {{1}, {2}, {8}, {9}, {10}, {20}, {21}}); // empty set intersection set. - testIntersectAtPoints(PresburgerSet::getEmpty(1), set, - {{1}, {2}, {8}, {9}, {10}, {20}, {21}}); + testIntersectAtPoints( + PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(1)), set, + {{1}, {2}, {8}, {9}, {10}, {20}, {21}}); // empty set intersection Universe. - testIntersectAtPoints(PresburgerSet::getEmpty(1), - PresburgerSet::getUniverse(1), {{1}, {2}, {0}, {-1}}); + testIntersectAtPoints( + PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(1)), + PresburgerSet::getUniverse(PresburgerSpace::getSetSpace(1)), + {{1}, {2}, {0}, {-1}}); // Universe intersection empty set. - testIntersectAtPoints(PresburgerSet::getUniverse(1), - PresburgerSet::getEmpty(1), {{1}, {2}, {0}, {-1}}); + testIntersectAtPoints( + PresburgerSet::getUniverse(PresburgerSpace::getSetSpace(1)), + PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(1)), + {{1}, {2}, {0}, {-1}}); // Universe intersection Universe. - testIntersectAtPoints(PresburgerSet::getUniverse(1), - PresburgerSet::getUniverse(1), {{1}, {2}, {0}, {-1}}); + testIntersectAtPoints( + PresburgerSet::getUniverse(PresburgerSpace::getSetSpace(1)), + PresburgerSet::getUniverse(PresburgerSpace::getSetSpace(1)), + {{1}, {2}, {0}, {-1}}); } TEST(SetTest, Subtract) { @@ -329,12 +341,12 @@ TEST(SetTest, Complement) { // Complement of universe.
testComplementAtPoints( - PresburgerSet::getUniverse(1), + PresburgerSet::getUniverse(PresburgerSpace::getSetSpace(1)), {{-1}, {-2}, {-8}, {1}, {2}, {8}, {9}, {10}, {20}, {21}}); // Complement of empty set. testComplementAtPoints( - PresburgerSet::getEmpty(1), + PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(1)), {{-1}, {-2}, {-8}, {1}, {2}, {8}, {9}, {10}, {20}, {21}}); testComplementAtPoints( @@ -356,8 +368,10 @@ TEST(SetTest, isEqual) { // set = [2, 8] U [10, 20]. - PresburgerSet universe = PresburgerSet::getUniverse(1); - PresburgerSet emptySet = PresburgerSet::getEmpty(1); + PresburgerSet universe = + PresburgerSet::getUniverse(PresburgerSpace::getSetSpace(1)); + PresburgerSet emptySet = + PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(1)); PresburgerSet set = parsePresburgerSetFromPolyStrings( 1, {"(x) : (x - 2 >= 0, -x + 8 >= 0)", "(x) : (x - 10 >= 0, -x + 20 >= 0)"}); @@ -431,7 +445,8 @@ // evens /\ odds = empty. expectEmpty(PresburgerSet(evens).intersect(PresburgerSet(odds))); // evens U odds = universe. - expectEqual(evens.unionSet(odds), PresburgerSet::getUniverse(1)); + expectEqual(evens.unionSet(odds), + PresburgerSet::getUniverse(PresburgerSpace::getSetSpace(1))); expectEqual(evens.complement(), odds); expectEqual(odds.complement(), evens); // even multiples of 3 = multiples of 6. diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp --- a/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp +++ b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp @@ -14,7 +14,7 @@ using namespace presburger; TEST(PresburgerSpaceTest, insertId) { - PresburgerSpace space(2, 2, 1); + PresburgerSpace space = PresburgerSpace::getRelationSpace(2, 2, 1); // Try inserting 2 domain ids.
space.insertId(IdKind::Domain, 0, 2); @@ -26,7 +26,7 @@ } TEST(PresburgerSpaceTest, insertIdSet) { - PresburgerSpace space(0, 2, 1); + PresburgerSpace space = PresburgerSpace::getSetSpace(2, 1); // Try inserting 2 dimension ids. The space should have 4 range ids since // spaces which do not distinguish between domain, range are implemented like @@ -36,7 +36,7 @@ } TEST(PresburgerSpaceTest, removeIdRange) { - PresburgerSpace space(2, 1, 3); + PresburgerSpace space = PresburgerSpace::getRelationSpace(2, 1, 3); // Remove 1 domain identifier. space.removeIdRange(IdKind::Domain, 0, 1); diff --git a/mlir/unittests/Analysis/Presburger/Utils.h b/mlir/unittests/Analysis/Presburger/Utils.h --- a/mlir/unittests/Analysis/Presburger/Utils.h +++ b/mlir/unittests/Analysis/Presburger/Utils.h @@ -41,7 +41,8 @@ /// number of dimensions as is specified by the numDims argument. inline PresburgerSet parsePresburgerSetFromPolyStrings(unsigned numDims, ArrayRef strs) { - PresburgerSet set = PresburgerSet::getEmpty(numDims); + PresburgerSet set = + PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(numDims)); for (StringRef str : strs) set.unionInPlace(parsePoly(str)); return set; @@ -70,7 +71,9 @@ unsigned numSymbols = 0) { static MLIRContext context; - PWMAFunction result(numInputs - numSymbols, numSymbols, numOutputs); + PWMAFunction result( + PresburgerSpace::getSetSpace(numInputs - numSymbols, numSymbols), + numOutputs); for (const auto &pair : data) { IntegerPolyhedron domain = parsePoly(pair.first);