diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
@@ -15,6 +15,7 @@
 #define MLIR_ANALYSIS_PRESBURGER_PRESBURGERSPACE_H
 
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
 
 namespace mlir {
 
@@ -31,23 +32,45 @@
 ///
 /// Local: Local identifiers correspond to existentially quantified variables.
 ///
-/// PresburgerSpace only supports identifiers of kind Dimension and Symbol.
+/// Dimension identifiers are further divided into Domain and Range identifiers
+/// to support building relations.
+///
+/// Spaces with distinction between domain and range identifiers should use
+/// IdKind::Domain and IdKind::Range to refer to domain and range identifiers.
+///
+/// Spaces with no distinction between domain and range identifiers should use
+/// IdKind::SetDim to refer to dimension identifiers.
+///
+/// PresburgerSpace does not support identifiers of kind Local. See
+/// PresburgerLocalSpace for an extension that supports Local ids.
 class PresburgerSpace {
   friend PresburgerLocalSpace;
 
 public:
-  /// Kind of identifier (column).
-  enum IdKind { Dimension, Symbol, Local };
+  /// Kind of identifier. Implementation-wise, SetDim ids are treated as Range
+  /// ids, and spaces with no distinction between domain and range ids are
+  /// treated as relations with zero domain ids.
+  enum IdKind { Symbol, Local, Domain, Range, SetDim = Range };
 
-  PresburgerSpace(unsigned numDims, unsigned numSymbols)
-      : numDims(numDims), numSymbols(numSymbols), numLocals(0) {}
+  static PresburgerSpace getRelationSpace(unsigned numDomain, unsigned numRange,
+                                          unsigned numSymbols);
+
+  static PresburgerSpace getSetSpace(unsigned numDims, unsigned numSymbols);
 
   virtual ~PresburgerSpace() = default;
 
-  unsigned getNumIds() const { return numDims + numSymbols + numLocals; }
-  unsigned getNumDimIds() const { return numDims; }
+  unsigned getNumDomainIds() const { return numDomain; }
+  unsigned getNumRangeIds() const { return numRange; }
   unsigned getNumSymbolIds() const { return numSymbols; }
-  unsigned getNumDimAndSymbolIds() const { return numDims + numSymbols; }
+  unsigned getNumSetDimIds() const { return numRange; }
+
+  unsigned getNumDimIds() const { return numDomain + numRange; }
+  unsigned getNumDimAndSymbolIds() const {
+    return numDomain + numRange + numSymbols;
+  }
+  unsigned getNumIds() const {
+    return numDomain + numRange + numSymbols + numLocals;
+  }
 
   /// Get the number of ids of the specified kind.
   unsigned getNumIdKind(IdKind kind) const;
@@ -78,12 +101,39 @@
   /// split become dimensions.
   void setDimSymbolSeparation(unsigned newSymbolCount);
 
+  /// Print the number of identifiers of each kind to `os`.
+  void print(llvm::raw_ostream &os) const;
+  /// Print the space to llvm::errs().
+  void dump() const;
+
+protected:
+  /// Space constructor for Relation space type.
+  PresburgerSpace(unsigned numDomain, unsigned numRange, unsigned numSymbols)
+      : PresburgerSpace(Relation, numDomain, numRange, numSymbols,
+                        /*numLocals=*/0) {}
+
+  /// Space constructor for Set space type.
+  PresburgerSpace(unsigned numDims, unsigned numSymbols)
+      : PresburgerSpace(Set, /*numDomain=*/0, numDims, numSymbols,
+                        /*numLocals=*/0) {}
+
 private:
-  PresburgerSpace(unsigned numDims, unsigned numSymbols, unsigned numLocals)
-      : numDims(numDims), numSymbols(numSymbols), numLocals(numLocals) {}
+  /// Kind of space.
+  enum SpaceKind { Set, Relation };
+
+  PresburgerSpace(SpaceKind spaceKind, unsigned numDomain, unsigned numRange,
+                  unsigned numSymbols, unsigned numLocals)
+      : spaceKind(spaceKind), numDomain(numDomain), numRange(numRange),
+        numSymbols(numSymbols), numLocals(numLocals) {}
 
-  /// Number of identifiers corresponding to real dimensions.
-  unsigned numDims;
+  /// Kind of this space. Set spaces never have domain identifiers.
+  SpaceKind spaceKind;
+
+  /// Number of identifiers corresponding to domain identifiers.
+  unsigned numDomain;
+
+  /// Number of identifiers corresponding to range identifiers.
+  unsigned numRange;
 
   /// Number of identifiers corresponding to symbols (unknown but constant for
   /// analysis).
@@ -96,9 +146,13 @@
 /// Extension of PresburgerSpace supporting Local identifiers.
 class PresburgerLocalSpace : public PresburgerSpace {
 public:
-  PresburgerLocalSpace(unsigned numDims, unsigned numSymbols,
-                       unsigned numLocals)
-      : PresburgerSpace(numDims, numSymbols, numLocals) {}
+  static PresburgerLocalSpace getRelationSpace(unsigned numDomain,
+                                               unsigned numRange,
+                                               unsigned numSymbols,
+                                               unsigned numLocals);
+
+  static PresburgerLocalSpace getSetSpace(unsigned numDims, unsigned numSymbols,
+                                          unsigned numLocals);
 
   unsigned getNumLocalIds() const { return numLocals; }
 
@@ -110,6 +164,22 @@
 
   /// Removes identifiers in the column range [idStart, idLimit).
   void removeIdRange(unsigned idStart, unsigned idLimit) override;
+
+  /// Print the number of identifiers of each kind, including locals, to `os`.
+  void print(llvm::raw_ostream &os) const;
+  /// Print the space to llvm::errs().
+  void dump() const;
+
+protected:
+  /// Local Space constructor for Relation space type.
+  PresburgerLocalSpace(unsigned numDomain, unsigned numRange,
+                       unsigned numSymbols, unsigned numLocals)
+      : PresburgerSpace(Relation, numDomain, numRange, numSymbols, numLocals) {}
+
+  /// Local Space constructor for Set space type.
+  PresburgerLocalSpace(unsigned numDims, unsigned numSymbols,
+                       unsigned numLocals)
+      : PresburgerSpace(Set, /*numDomain=*/0, numDims, numSymbols, numLocals) {}
 };
 
 } // namespace mlir
diff --git a/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp b/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp
--- a/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp
@@ -93,7 +93,7 @@
 }
 
 unsigned IntegerPolyhedron::insertDimId(unsigned pos, unsigned num) {
-  return insertId(IdKind::Dimension, pos, num);
+  return insertId(IdKind::SetDim, pos, num);
 }
 
 unsigned IntegerPolyhedron::insertSymbolId(unsigned pos, unsigned num) {
@@ -107,16 +107,15 @@
 
 unsigned IntegerPolyhedron::insertId(IdKind kind, unsigned pos, unsigned num) {
   assert(pos <= getNumIdKind(kind));
-  unsigned absolutePos = getIdKindOffset(kind) + pos;
-  inequalities.insertColumns(absolutePos, num);
-  equalities.insertColumns(absolutePos, num);
-
-  return PresburgerLocalSpace::insertId(kind, pos, num);
+  unsigned insertPos = PresburgerLocalSpace::insertId(kind, pos, num);
+  inequalities.insertColumns(insertPos, num);
+  equalities.insertColumns(insertPos, num);
+  return insertPos;
 }
 
 unsigned IntegerPolyhedron::appendDimId(unsigned num) {
   unsigned pos = getNumDimIds();
-  insertId(IdKind::Dimension, pos, num);
+  insertId(IdKind::SetDim, pos, num);
   return pos;
 }
 
diff --git a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
--- a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
@@ -12,24 +12,56 @@
 
 using namespace mlir;
 
+PresburgerSpace PresburgerSpace::getRelationSpace(unsigned numDomain,
+                                                  unsigned numRange,
+                                                  unsigned numSymbols) {
+  return PresburgerSpace(numDomain, numRange, numSymbols);
+}
+
+PresburgerSpace PresburgerSpace::getSetSpace(unsigned numDims,
+                                             unsigned numSymbols) {
+  return PresburgerSpace(numDims, numSymbols);
+}
+
+PresburgerLocalSpace
+PresburgerLocalSpace::getRelationSpace(unsigned numDomain, unsigned numRange,
+                                       unsigned numSymbols,
+                                       unsigned numLocals) {
+  return PresburgerLocalSpace(numDomain, numRange, numSymbols, numLocals);
+}
+
+PresburgerLocalSpace PresburgerLocalSpace::getSetSpace(unsigned numDims,
+                                                       unsigned numSymbols,
+                                                       unsigned numLocals) {
+  return PresburgerLocalSpace(numDims, numSymbols, numLocals);
+}
+
 unsigned PresburgerSpace::getNumIdKind(IdKind kind) const {
-  if (kind == IdKind::Dimension)
-    return getNumDimIds();
+  if (kind == IdKind::Domain) {
+    assert(spaceKind == Relation && "IdKind::Domain is not supported in Set.");
+    return getNumDomainIds();
+  }
+  if (kind == IdKind::Range)
+    return getNumRangeIds();
   if (kind == IdKind::Symbol)
     return getNumSymbolIds();
   if (kind == IdKind::Local)
     return numLocals;
-  llvm_unreachable("IdKind does not exit!");
+  llvm_unreachable("IdKind does not exist!");
 }
 
 unsigned PresburgerSpace::getIdKindOffset(IdKind kind) const {
-  if (kind == IdKind::Dimension)
+  if (kind == IdKind::Domain) {
+    assert(spaceKind == Relation && "IdKind::Domain is not supported in Set.");
     return 0;
+  }
+  if (kind == IdKind::Range)
+    return getNumDomainIds();
   if (kind == IdKind::Symbol)
     return getNumDimIds();
   if (kind == IdKind::Local)
     return getNumDimAndSymbolIds();
-  llvm_unreachable("IdKind does not exit!");
+  llvm_unreachable("IdKind does not exist!");
 }
 
 unsigned PresburgerSpace::getIdKindEnd(IdKind kind) const {
@@ -56,13 +88,16 @@
 
   unsigned absolutePos = getIdKindOffset(kind) + pos;
 
-  if (kind == IdKind::Dimension)
-    numDims += num;
-  else if (kind == IdKind::Symbol)
+  if (kind == IdKind::Domain) {
+    assert(spaceKind == Relation && "IdKind::Domain is not supported in Set.");
+    numDomain += num;
+  } else if (kind == IdKind::Range) {
+    numRange += num;
+  } else if (kind == IdKind::Symbol) {
     numSymbols += num;
-  else
-    llvm_unreachable(
-        "PresburgerSpace only supports Dimensions and Symbol identifiers!");
+  } else {
+    llvm_unreachable("PresburgerSpace does not support local identifiers!");
+  }
 
   return absolutePos;
 }
@@ -76,13 +111,17 @@
   // We are going to be removing one or more identifiers from the range.
   assert(idStart < getNumIds() && "invalid idStart position");
 
-  // Update members numDims, numSymbols and numIds.
-  unsigned numDimsEliminated =
-      getIdKindOverlap(IdKind::Dimension, idStart, idLimit);
+  // Update members numDomain, numRange, numSymbols and numIds.
+  unsigned numDomainEliminated = 0;
+  if (spaceKind == Relation)
+    numDomainEliminated = getIdKindOverlap(IdKind::Domain, idStart, idLimit);
+  unsigned numRangeEliminated =
+      getIdKindOverlap(IdKind::Range, idStart, idLimit);
   unsigned numSymbolsEliminated =
       getIdKindOverlap(IdKind::Symbol, idStart, idLimit);
 
-  numDims -= numDimsEliminated;
+  numDomain -= numDomainEliminated;
+  numRange -= numRangeEliminated;
   numSymbols -= numSymbolsEliminated;
 }
 
@@ -108,8 +147,7 @@
       getIdKindOverlap(IdKind::Local, idStart, idLimit);
 
   // Update space parameters.
-  PresburgerSpace::removeIdRange(
-      idStart, std::min(idLimit, PresburgerSpace::getNumIds()));
+  PresburgerSpace::removeIdRange(idStart, idLimit);
 
   // Update local ids.
   numLocals -= numLocalsEliminated;
@@ -118,6 +156,31 @@
 void PresburgerSpace::setDimSymbolSeparation(unsigned newSymbolCount) {
   assert(newSymbolCount <= getNumDimAndSymbolIds() &&
          "invalid separation position");
-  numDims = numDims + numSymbols - newSymbolCount;
+  numRange = numRange + numSymbols - newSymbolCount;
   numSymbols = newSymbolCount;
 }
+
+void PresburgerSpace::print(llvm::raw_ostream &os) const {
+  if (spaceKind == Relation) {
+    os << "Domain: " << getNumDomainIds() << ", "
+       << "Range: " << getNumRangeIds() << ", ";
+  } else {
+    os << "Dimension: " << getNumSetDimIds() << ", ";
+  }
+  os << "Symbols: " << getNumSymbolIds() << "\n";
+}
+
+void PresburgerSpace::dump() const { print(llvm::errs()); }
+
+void PresburgerLocalSpace::print(llvm::raw_ostream &os) const {
+  if (spaceKind == Relation) {
+    os << "Domain: " << getNumDomainIds() << ", "
+       << "Range: " << getNumRangeIds() << ", ";
+  } else {
+    os << "Dimension: " << getNumSetDimIds() << ", ";
+  }
+  os << "Symbols: " << getNumSymbolIds() << ", "
+     << "Locals: " << getNumLocalIds() << "\n";
+}
+
+void PresburgerLocalSpace::dump() const { print(llvm::errs()); }
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
--- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
@@ -268,7 +268,7 @@
 
 unsigned FlatAffineValueConstraints::appendDimId(ValueRange vals) {
   unsigned pos = getNumDimIds();
-  insertId(IdKind::Dimension, pos, vals);
+  insertId(IdKind::SetDim, pos, vals);
   return pos;
 }
 
@@ -280,7 +280,7 @@
 
 unsigned FlatAffineValueConstraints::insertDimId(unsigned pos,
                                                  ValueRange vals) {
-  return insertId(IdKind::Dimension, pos, vals);
+  return insertId(IdKind::SetDim, pos, vals);
 }
 
 unsigned FlatAffineValueConstraints::insertSymbolId(unsigned pos,
@@ -365,7 +365,7 @@
 
 static bool LLVM_ATTRIBUTE_UNUSED areIdsUnique(
     const FlatAffineValueConstraints &cst, FlatAffineConstraints::IdKind kind) {
-  if (kind == FlatAffineConstraints::IdKind::Dimension)
+  if (kind == FlatAffineConstraints::IdKind::SetDim)
     return areIdsUnique(cst, 0, cst.getNumDimIds());
   if (kind == FlatAffineConstraints::IdKind::Symbol)
     return areIdsUnique(cst, cst.getNumDimIds(), cst.getNumDimAndSymbolIds());
@@ -1214,8 +1214,8 @@
 
   dims.reserve(getNumDimIds());
   syms.reserve(getNumSymbolIds());
-  for (unsigned i = getIdKindOffset(IdKind::Dimension),
-                e = getIdKindEnd(IdKind::Dimension);
+  for (unsigned i = getIdKindOffset(IdKind::SetDim),
+                e = getIdKindEnd(IdKind::SetDim);
        i < e; ++i)
     dims.push_back(values[i] ? *values[i] : Value());
   for (unsigned i = getIdKindOffset(IdKind::Symbol),
diff --git a/mlir/unittests/Analysis/Presburger/CMakeLists.txt b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
--- a/mlir/unittests/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
@@ -3,6 +3,7 @@
   LinearTransformTest.cpp
   MatrixTest.cpp
   PresburgerSetTest.cpp
+  PresburgerSpaceTest.cpp
   PWMAFunctionTest.cpp
   SimplexTest.cpp
   ../../Dialect/Affine/Analysis/AffineStructuresParser.cpp
diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
--- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
@@ -158,7 +158,7 @@
   EXPECT_THAT(set.getInequality(0),
               testing::ElementsAre(10, 11, 12, 20, 30, 40));
 
-  set.removeIdRange(IntegerPolyhedron::IdKind::Dimension, 0, 2);
+  set.removeIdRange(IntegerPolyhedron::IdKind::SetDim, 0, 2);
   EXPECT_THAT(set.getInequality(0), testing::ElementsAre(12, 20, 30, 40));
 
   set.removeIdRange(IntegerPolyhedron::IdKind::Local, 1, 1);
diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
@@ -0,0 +1,68 @@
+//===- PresburgerSpaceTest.cpp - Tests for PresburgerSpace ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Presburger/PresburgerSpace.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+using namespace mlir;
+using IdKind = PresburgerSpace::IdKind;
+
+TEST(PresburgerSpaceTest, insertId) {
+  PresburgerSpace space = PresburgerSpace::getRelationSpace(2, 2, 1);
+
+  // Try inserting 2 domain ids.
+  space.insertId(IdKind::Domain, 0, 2);
+  EXPECT_EQ(space.getNumDomainIds(), 4u);
+
+  // Try inserting 1 range id.
+  space.insertId(IdKind::Range, 0, 1);
+  EXPECT_EQ(space.getNumRangeIds(), 3u);
+}
+
+TEST(PresburgerSpaceTest, insertIdSet) {
+  PresburgerSpace space = PresburgerSpace::getSetSpace(2, 1);
+
+  // Try inserting 2 dimension ids. The space should have 4 range ids since
+  // spaces which do not distinguish between domain and range are implemented
+  // like this.
+  space.insertId(IdKind::SetDim, 0, 2);
+  EXPECT_EQ(space.getNumRangeIds(), 4u);
+}
+
+TEST(PresburgerSpaceTest, removeIdRange) {
+  PresburgerSpace space = PresburgerSpace::getRelationSpace(2, 1, 3);
+
+  // Remove 1 domain identifier.
+  space.removeIdRange(0, 1);
+  EXPECT_EQ(space.getNumDomainIds(), 1u);
+
+  // Remove 1 symbol and 1 range identifier.
+  space.removeIdRange(1, 3);
+  EXPECT_EQ(space.getNumDomainIds(), 1u);
+  EXPECT_EQ(space.getNumRangeIds(), 0u);
+  EXPECT_EQ(space.getNumSymbolIds(), 2u);
+}
+
+TEST(PresburgerSpaceTest, printSetSpace) {
+  PresburgerSpace space = PresburgerSpace::getSetSpace(2, 1);
+
+  std::string str;
+  llvm::raw_string_ostream os(str);
+  space.print(os);
+  EXPECT_EQ(os.str(), "Dimension: 2, Symbols: 1\n");
+}
+
+TEST(PresburgerSpaceTest, printLocalSpace) {
+  PresburgerLocalSpace space = PresburgerLocalSpace::getSetSpace(2, 1, 3);
+
+  std::string str;
+  llvm::raw_string_ostream os(str);
+  space.print(os);
+  EXPECT_EQ(os.str(), "Dimension: 2, Symbols: 1, Locals: 3\n");
+}