diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
@@ -15,6 +15,7 @@
 #define MLIR_ANALYSIS_PRESBURGER_PRESBURGERSPACE_H
 
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
 
 namespace mlir {
 
@@ -31,23 +32,47 @@
 ///
 /// Local: Local identifiers correspond to existentially quantified variables.
 ///
-/// PresburgerSpace only supports identifiers of kind Dimension and Symbol.
+/// Dimension identifiers are further divided into Domain and Range identifiers
+/// to support building relations.
+///
+/// Spaces with distinction between domain and range identifiers should use
+/// IdKind::Domain and IdKind::Range to refer to domain and range identifiers.
+/// IdKind::Dimension should be used to refer to both domain and range
+/// together.
+///
+/// Spaces with no distinction between domain and range identifiers should use
+/// IdKind::Dimension to refer to dimension identifiers.
+///
+/// PresburgerSpace does not support identifiers of Local type.
 class PresburgerSpace {
   friend PresburgerLocalSpace;
 
 public:
-  /// Kind of identifier (column).
-  enum IdKind { Dimension, Symbol, Local };
+  /// Kind of identifier.
+  enum IdKind { Dimension, Symbol, Local, Domain, Range };
 
+  /// Space constructor for Relation space type.
+  PresburgerSpace(unsigned numDomain, unsigned numRange, unsigned numSymbols)
+      : PresburgerSpace(numDomain, numRange, numSymbols, /*numLocals=*/0,
+                        Relation) {}
+
+  /// Space constructor for Set space type.
   PresburgerSpace(unsigned numDims, unsigned numSymbols)
-      : numDims(numDims), numSymbols(numSymbols), numLocals(0) {}
+      : PresburgerSpace(/*numDomain=*/0, numDims, numSymbols, /*numLocals=*/0,
+                        Set) {}
 
   virtual ~PresburgerSpace() = default;
 
-  unsigned getNumIds() const { return numDims + numSymbols + numLocals; }
-  unsigned getNumDimIds() const { return numDims; }
+  unsigned getNumDomainIds() const { return numDomain; }
+  unsigned getNumRangeIds() const { return numRange; }
+  unsigned getNumDimIds() const { return numDomain + numRange; }
   unsigned getNumSymbolIds() const { return numSymbols; }
-  unsigned getNumDimAndSymbolIds() const { return numDims + numSymbols; }
+  unsigned getNumDimAndSymbolIds() const {
+    return numDomain + numRange + numSymbols;
+  }
+  unsigned getNumIds() const {
+    return numDomain + numRange + numSymbols + numLocals;
+  }
 
   /// Get the number of ids of the specified kind.
   unsigned getNumIdKind(IdKind kind) const;
@@ -78,12 +103,23 @@
   /// split become dimensions.
   void setDimSymbolSeparation(unsigned newSymbolCount);
 
+  void print(llvm::raw_ostream &os) const;
+  void dump() const;
+
 private:
-  PresburgerSpace(unsigned numDims, unsigned numSymbols, unsigned numLocals)
-      : numDims(numDims), numSymbols(numSymbols), numLocals(numLocals) {}
+  /// Kind of space.
+  enum SpaceType { Set, Relation };
 
-  /// Number of identifiers corresponding to real dimensions.
-  unsigned numDims;
+  PresburgerSpace(unsigned numDomain, unsigned numRange, unsigned numSymbols,
+                  unsigned numLocals, SpaceType spaceType)
+      : numDomain(numDomain), numRange(numRange), numSymbols(numSymbols),
+        numLocals(numLocals), spaceType(spaceType) {}
+
+  /// Number of identifiers corresponding to domain identifiers.
+  unsigned numDomain;
+
+  /// Number of identifiers corresponding to range identifiers.
+  unsigned numRange;
 
   /// Number of identifiers corresponding to symbols (unknown but constant for
   /// analysis).
@@ -91,14 +127,22 @@
 
-  /// Total number of identifiers.
+  /// Number of local identifiers.
   unsigned numLocals;
+
+  SpaceType spaceType;
 };
 
 /// Extension of PresburgerSpace supporting Local identifiers.
 class PresburgerLocalSpace : public PresburgerSpace {
 public:
+  /// Local Space constructor for Relation space type.
+  PresburgerLocalSpace(unsigned numDomain, unsigned numRange,
+                       unsigned numSymbols, unsigned numLocals)
+      : PresburgerSpace(numDomain, numRange, numSymbols, numLocals, Relation) {}
+
+  /// Local Space constructor for Set space type.
   PresburgerLocalSpace(unsigned numDims, unsigned numSymbols,
                        unsigned numLocals)
-      : PresburgerSpace(numDims, numSymbols, numLocals) {}
+      : PresburgerSpace(/*numDomain=*/0, numDims, numSymbols, numLocals, Set) {}
 
   unsigned getNumLocalIds() const { return numLocals; }
 
@@ -110,6 +154,9 @@
 
   /// Removes identifiers in the column range [idStart, idLimit).
   void removeIdRange(unsigned idStart, unsigned idLimit) override;
+
+  void print(llvm::raw_ostream &os) const;
+  void dump() const;
 };
 
 } // namespace mlir
diff --git a/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp b/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp
--- a/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp
@@ -107,11 +107,10 @@
 unsigned IntegerPolyhedron::insertId(IdKind kind, unsigned pos, unsigned num) {
   assert(pos <= getNumIdKind(kind));
 
-  unsigned absolutePos = getIdKindOffset(kind) + pos;
-  inequalities.insertColumns(absolutePos, num);
-  equalities.insertColumns(absolutePos, num);
-
-  return PresburgerLocalSpace::insertId(kind, pos, num);
+  unsigned insertPos = PresburgerLocalSpace::insertId(kind, pos, num);
+  inequalities.insertColumns(insertPos, num);
+  equalities.insertColumns(insertPos, num);
+  return insertPos;
 }
 
 unsigned IntegerPolyhedron::appendDimId(unsigned num) {
diff --git a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
--- a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
@@ -15,21 +15,29 @@
 unsigned PresburgerSpace::getNumIdKind(IdKind kind) const {
   if (kind == IdKind::Dimension)
     return getNumDimIds();
+  if (kind == IdKind::Domain)
+    return getNumDomainIds();
+  if (kind == IdKind::Range)
+    return getNumRangeIds();
   if (kind == IdKind::Symbol)
     return getNumSymbolIds();
   if (kind == IdKind::Local)
     return numLocals;
-  llvm_unreachable("IdKind does not exit!");
+  llvm_unreachable("IdKind does not exist!");
 }
 
 unsigned PresburgerSpace::getIdKindOffset(IdKind kind) const {
   if (kind == IdKind::Dimension)
     return 0;
+  if (kind == IdKind::Domain)
+    return 0;
+  if (kind == IdKind::Range)
+    return getNumDomainIds();
   if (kind == IdKind::Symbol)
     return getNumDimIds();
   if (kind == IdKind::Local)
     return getNumDimAndSymbolIds();
-  llvm_unreachable("IdKind does not exit!");
+  llvm_unreachable("IdKind does not exist!");
 }
 
 unsigned PresburgerSpace::getIdKindEnd(IdKind kind) const {
@@ -56,13 +64,23 @@
 
   unsigned absolutePos = getIdKindOffset(kind) + pos;
 
-  if (kind == IdKind::Dimension)
-    numDims += num;
-  else if (kind == IdKind::Symbol)
+  if (kind == IdKind::Dimension) {
+    assert(spaceType == Set && "IdKind::Dimension can only be used to add "
+                               "identifiers in a Space with type Set.");
+    numRange += num;
+  } else if (kind == IdKind::Range) {
+    numRange += num;
+  } else if (kind == IdKind::Domain) {
+    assert(spaceType == Relation &&
+           "IdKind::Domain can only be used to add identifiers in a Space with "
+           "type Relation.");
+    numDomain += num;
+  } else if (kind == IdKind::Symbol) {
     numSymbols += num;
-  else
-    llvm_unreachable(
-        "PresburgerSpace only supports Dimensions and Symbol identifiers!");
+  } else {
+    llvm_unreachable("PresburgerSpace only supports Dimension, Domain, Range "
+                     "and Symbol identifiers!");
+  }
 
   return absolutePos;
 }
@@ -76,13 +94,16 @@
   // We are going to be removing one or more identifiers from the range.
   assert(idStart < getNumIds() && "invalid idStart position");
 
-  // Update members numDims, numSymbols and numIds.
+  // Update members numDomain, numRange, numSymbols and numIds.
-  unsigned numDimsEliminated =
-      getIdKindOverlap(IdKind::Dimension, idStart, idLimit);
+  unsigned numDomainEliminated =
+      getIdKindOverlap(IdKind::Domain, idStart, idLimit);
+  unsigned numRangeEliminated =
+      getIdKindOverlap(IdKind::Range, idStart, idLimit);
   unsigned numSymbolsEliminated =
       getIdKindOverlap(IdKind::Symbol, idStart, idLimit);
 
-  numDims -= numDimsEliminated;
+  numDomain -= numDomainEliminated;
+  numRange -= numRangeEliminated;
   numSymbols -= numSymbolsEliminated;
 }
 
@@ -108,8 +129,7 @@
       getIdKindOverlap(IdKind::Local, idStart, idLimit);
 
   // Update space parameters.
-  PresburgerSpace::removeIdRange(
-      idStart, std::min(idLimit, PresburgerSpace::getNumIds()));
+  PresburgerSpace::removeIdRange(idStart, idLimit);
 
   // Update local ids.
   numLocals -= numLocalsEliminated;
@@ -118,6 +138,25 @@
 void PresburgerSpace::setDimSymbolSeparation(unsigned newSymbolCount) {
-  assert(newSymbolCount <= getNumDimAndSymbolIds() &&
-         "invalid separation position");
-  numDims = numDims + numSymbols - newSymbolCount;
+  // Only range identifiers are exchanged with symbols; domain identifiers are
+  // never converted, so a larger count would underflow numRange.
+  assert(newSymbolCount <= getNumRangeIds() + getNumSymbolIds() &&
+         "invalid separation position");
+  numRange = numRange + numSymbols - newSymbolCount;
   numSymbols = newSymbolCount;
 }
+
+void PresburgerSpace::print(llvm::raw_ostream &os) const {
+  os << "Domain: " << getNumDomainIds() << ", "
+     << "Range: " << getNumRangeIds() << ", "
+     << "Symbols: " << getNumSymbolIds() << "\n";
+}
+
+void PresburgerSpace::dump() const { print(llvm::errs()); }
+
+void PresburgerLocalSpace::print(llvm::raw_ostream &os) const {
+  os << "Domain: " << getNumDomainIds() << ", "
+     << "Range: " << getNumRangeIds() << ", "
+     << "Symbols: " << getNumSymbolIds() << ", "
+     << "Locals: " << getNumLocalIds() << "\n";
+}
+
+void PresburgerLocalSpace::dump() const { print(llvm::errs()); }
diff --git a/mlir/unittests/Analysis/Presburger/CMakeLists.txt b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
--- a/mlir/unittests/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
@@ -3,6 +3,7 @@
   LinearTransformTest.cpp
   MatrixTest.cpp
   PresburgerSetTest.cpp
+  PresburgerSpaceTest.cpp
   PWMAFunctionTest.cpp
   SimplexTest.cpp
   ../../Dialect/Affine/Analysis/AffineStructuresParser.cpp
diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
@@ -0,0 +1,50 @@
+//===- PresburgerSpaceTest.cpp - Tests for PresburgerSpace ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Presburger/PresburgerSpace.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+using namespace mlir;
+using IdKind = PresburgerSpace::IdKind;
+
+TEST(PresburgerSpaceTest, insertId) {
+  PresburgerSpace space(2, 2, 1);
+
+  // Try inserting 2 domain ids.
+  space.insertId(IdKind::Domain, 0, 2);
+  EXPECT_EQ(space.getNumDomainIds(), 4u);
+
+  // Try inserting 1 range id.
+  space.insertId(IdKind::Range, 0, 1);
+  EXPECT_EQ(space.getNumRangeIds(), 3u);
+}
+
+TEST(PresburgerSpaceTest, insertIdSet) {
+  PresburgerSpace space(2, 1);
+
+  // Try inserting 2 dimension ids. The space should have 4 range ids since
+  // spaces which do not distinguish between domain, range are implemented like
+  // this.
+  space.insertId(IdKind::Dimension, 0, 2);
+  EXPECT_EQ(space.getNumRangeIds(), 4u);
+}
+
+TEST(PresburgerSpaceTest, removeIdRange) {
+  PresburgerSpace space(2, 1, 3);
+
+  // Remove 1 domain identifier.
+  space.removeIdRange(0, 1);
+  EXPECT_EQ(space.getNumDomainIds(), 1u);
+
+  // Remove 1 symbol and 1 range identifier.
+  space.removeIdRange(1, 3);
+  EXPECT_EQ(space.getNumDomainIds(), 1u);
+  EXPECT_EQ(space.getNumRangeIds(), 0u);
+  EXPECT_EQ(space.getNumSymbolIds(), 2u);
+}