diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
@@ -14,7 +14,12 @@
 #ifndef MLIR_ANALYSIS_PRESBURGER_PRESBURGERSPACE_H
 #define MLIR_ANALYSIS_PRESBURGER_PRESBURGERSPACE_H
 
+#include "mlir/Support/TypeID.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Config/abi-breaking.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/PointerLikeTypeTraits.h"
 #include "llvm/Support/raw_ostream.h"
+#include <climits>
 
 namespace mlir {
@@ -62,6 +67,16 @@
 /// Compatibility of two spaces implies that number of identifiers of each kind
 /// other than Locals are equal. Equality of two spaces implies that number of
 /// identifiers of each kind are equal.
+///
+/// PresburgerSpace optionally also supports attaching a pointer-like value,
+/// called an attachement, to each non-local variable in the space. Calling
+/// `resetAttachements<T>()` enables attachements. All attachements in a space
+/// must be of the same type `T`, which must have an
+/// `llvm::PointerLikeTypeTraits` specialization available and should be
+/// supported via mlir::TypeID.
+///
+/// These attachements can be used to check if two variables in two different
+/// spaces correspond to the same variable.
 class PresburgerSpace {
 public:
   static PresburgerSpace getRelationSpace(unsigned numDomain = 0,
@@ -113,12 +128,20 @@
   /// Positions are relative to the kind of identifier. Return the absolute
   /// column position (i.e., not relative to the kind of identifier) of the
   /// first added identifier.
+  ///
+  /// If attachements are being used, the newly added variables have no
+  /// attachements.
   unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1);
 
   /// Removes identifiers of the specified kind in the column range [idStart,
   /// idLimit). The range is relative to the kind of identifier.
   void removeIdRange(IdKind kind, unsigned idStart, unsigned idLimit);
 
+  /// Swaps the posA^th variable of kindA and posB^th variable of kindB. Only
+  /// attachements are affected: if exactly one of the two variables is a
+  /// local, the attachement of the other variable is cleared.
+  void swapId(IdKind kindA, IdKind kindB, unsigned posA, unsigned posB);
+
   /// Returns true if both the spaces are compatible i.e. if both spaces have
   /// the same number of identifiers of each kind (excluding locals).
   bool isCompatible(const PresburgerSpace &other) const;
@@ -137,12 +160,103 @@
   void print(llvm::raw_ostream &os) const;
   void dump() const;
 
+  //===--------------------------------------------------------------------===//
+  // Attachement Interactions
+  //===--------------------------------------------------------------------===//
+
+  /// Set the attachement for `i^th` variable to `attachement`. `T` here should
+  /// match the type used to enable attachements.
+  template <typename T>
+  void setAttachement(IdKind kind, unsigned i, T attachement) {
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+    assert(TypeID::get<T>() == attachementType && "Type mismatch");
+#endif
+    atAttachement(kind, i) =
+        llvm::PointerLikeTypeTraits<T>::getAsVoidPointer(attachement);
+  }
+
+  /// Get the attachement for `i^th` variable cast to type `T`. `T` here
+  /// should match the type used to enable attachements.
+  template <typename T>
+  T getAttachement(IdKind kind, unsigned i) const {
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+    assert(TypeID::get<T>() == attachementType && "Type mismatch");
+#endif
+    return llvm::PointerLikeTypeTraits<T>::getFromVoidPointer(
+        atAttachement(kind, i));
+  }
+
+  /// Check if the i^th variable of the specified kind has a non-null
+  /// attachement.
+  bool hasAttachement(IdKind kind, unsigned i) const {
+    return atAttachement(kind, i) != nullptr;
+  }
+
+  /// Check if the spaces are compatible, as well as have the same attachements
+  /// for each variable.
+  bool isAligned(const PresburgerSpace &other) const;
+  /// Check if the number of variables of the specified kind match, and have
+  /// same attachements with the other space. `kind` must not be IdKind::Local.
+  bool isAligned(const PresburgerSpace &other, IdKind kind) const;
+
+  /// Find the variable of the specified kind with attachement `val`. Returns
+  /// PresburgerSpace::kIdNotFound if attachement is not found.
+  template <typename T>
+  unsigned findId(IdKind kind, T val) const {
+    unsigned i = 0;
+    for (unsigned e = getNumIdKind(kind); i < e; ++i)
+      if (hasAttachement(kind, i) && getAttachement<T>(kind, i) == val)
+        return i;
+    return kIdNotFound;
+  }
+  static const unsigned kIdNotFound = UINT_MAX;
+
+  /// Returns true if attachements are being used.
+  bool isUsingAttachements() const { return usingAttachements; }
+
+  /// Reset the stored attachements in the space. Enables `usingAttachements` if
+  /// it was `false` before.
+  template <typename T>
+  void resetAttachements() {
+    attachements.clear();
+    attachements.resize(getNumDimAndSymbolIds());
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+    attachementType = TypeID::get<T>();
+#endif
+
+    usingAttachements = true;
+  }
+
+  /// Disable attachements being stored in space.
+  void disableAttachements() {
+    attachements.clear();
+    usingAttachements = false;
+  }
+
 protected:
   PresburgerSpace(unsigned numDomain = 0, unsigned numRange = 0,
                   unsigned numSymbols = 0, unsigned numLocals = 0)
       : numDomain(numDomain), numRange(numRange), numSymbols(numSymbols),
         numLocals(numLocals) {}
 
+  void *&atAttachement(IdKind kind, unsigned i) {
+    assert(usingAttachements &&
+           "Cannot access attachements when `usingAttachements` is false.");
+    assert(kind != IdKind::Local &&
+           "Local variables cannot have attachements.");
+    assert(i < getNumIdKind(kind) && "Variable position out of bounds.");
+    return attachements[getIdKindOffset(kind) + i];
+  }
+
+  void *atAttachement(IdKind kind, unsigned i) const {
+    assert(usingAttachements &&
+           "Cannot access attachements when `usingAttachements` is false.");
+    assert(kind != IdKind::Local &&
+           "Local variables cannot have attachements.");
+    assert(i < getNumIdKind(kind) && "Variable position out of bounds.");
+    return attachements[getIdKindOffset(kind) + i];
+  }
+
 private:
   // Number of identifiers corresponding to domain identifiers.
   unsigned numDomain;
@@ -157,6 +271,17 @@
   /// Number of identifers corresponding to locals (identifiers corresponding
   /// to existentially quantified variables).
   unsigned numLocals;
+
+  /// Stores whether or not attachements are being used in this space.
+  bool usingAttachements = false;
+
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+  /// TypeID of the attachements in space. This should be used in asserts only.
+  TypeID attachementType;
+#endif
+
+  /// Stores an attachement for each non-local identifier as a `void` pointer.
+  SmallVector<void *, 0> attachements;
 };
 
 } // namespace presburger
diff --git a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
--- a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
@@ -83,6 +83,11 @@
   else
     numLocals += num;
 
+  // Insert null attachements if `usingAttachements` and variables inserted are
+  // not locals.
+  if (usingAttachements && kind != IdKind::Local)
+    attachements.insert(attachements.begin() + absolutePos, num, nullptr);
+
   return absolutePos;
 }
 
@@ -102,6 +107,33 @@
     numSymbols -= numIdsEliminated;
   else
     numLocals -= numIdsEliminated;
+
+  // Remove attachements if `usingAttachements` and variables removed are not
+  // locals.
+  if (usingAttachements && kind != IdKind::Local)
+    attachements.erase(attachements.begin() + getIdKindOffset(kind) + idStart,
+                       attachements.begin() + getIdKindOffset(kind) + idLimit);
+}
+
+void PresburgerSpace::swapId(IdKind kindA, IdKind kindB, unsigned posA,
+                             unsigned posB) {
+  if (!usingAttachements)
+    return;
+
+  if (kindA == IdKind::Local && kindB == IdKind::Local)
+    return;
+
+  if (kindA == IdKind::Local) {
+    atAttachement(kindB, posB) = nullptr;
+    return;
+  }
+
+  if (kindB == IdKind::Local) {
+    atAttachement(kindA, posA) = nullptr;
+    return;
+  }
+
+  std::swap(atAttachement(kindA, posA), atAttachement(kindB, posB));
 }
 
 bool PresburgerSpace::isCompatible(const PresburgerSpace &other) const {
@@ -114,11 +146,36 @@
   return isCompatible(other) && getNumLocalIds() == other.getNumLocalIds();
 }
 
+bool PresburgerSpace::isAligned(const PresburgerSpace &other) const {
+  assert(isUsingAttachements() && other.isUsingAttachements() &&
+         "Both spaces should be using attachements to check for "
+         "alignment.");
+  return isCompatible(other) && attachements == other.attachements;
+}
+
+bool PresburgerSpace::isAligned(const PresburgerSpace &other,
+                                IdKind kind) const {
+  assert(isUsingAttachements() && other.isUsingAttachements() &&
+         "Both spaces should be using attachements to check for "
+         "alignment.");
+  assert(kind != IdKind::Local && "Local variables cannot have attachements.");
+
+  ArrayRef<void *> kindAttachements =
+      makeArrayRef(attachements)
+          .slice(getIdKindOffset(kind), getNumIdKind(kind));
+  ArrayRef<void *> otherKindAttachements =
+      makeArrayRef(other.attachements)
+          .slice(other.getIdKindOffset(kind), other.getNumIdKind(kind));
+  return kindAttachements == otherKindAttachements;
+}
+
 void PresburgerSpace::setDimSymbolSeparation(unsigned newSymbolCount) {
   assert(newSymbolCount <= getNumDimAndSymbolIds() &&
          "invalid separation position");
   numRange = numRange + numSymbols - newSymbolCount;
   numSymbols = newSymbolCount;
+  // We do not need to change `attachements` since the ordering of
+  // `attachements` remains the same.
 }
 
 void PresburgerSpace::print(llvm::raw_ostream &os) const {
@@ -126,6 +183,18 @@
      << "Range: " << getNumRangeIds() << ", "
      << "Symbols: " << getNumSymbolIds() << ", "
      << "Locals: " << getNumLocalIds() << "\n";
+
+  if (usingAttachements) {
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+    os << "TypeID of attachements: " << attachementType.getAsOpaquePointer()
+       << "\n";
+#endif
+
+    os << "(";
+    for (void *attachement : attachements)
+      os << attachement << " ";
+    os << ")\n";
+  }
 }
 
 void PresburgerSpace::dump() const { print(llvm::errs()); }
diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
--- a/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
@@ -49,3 +49,67 @@
   EXPECT_EQ(space.getNumRangeIds(), 0u);
   EXPECT_EQ(space.getNumSymbolIds(), 2u);
 }
+
+TEST(PresburgerSpaceTest, insertIdAttachement) {
+  PresburgerSpace space = PresburgerSpace::getRelationSpace(2, 2, 1, 0);
+  space.resetAttachements<int *>();
+
+  // Attach attachements to domain ids.
+  int attachements[2] = {0, 1};
+  space.setAttachement<int *>(IdKind::Domain, 0, &attachements[0]);
+  space.setAttachement<int *>(IdKind::Domain, 1, &attachements[1]);
+
+  // Try inserting 2 domain ids.
+  space.insertId(IdKind::Domain, 0, 2);
+  EXPECT_EQ(space.getNumDomainIds(), 4u);
+
+  // Try inserting 1 range id.
+  space.insertId(IdKind::Range, 0, 1);
+  EXPECT_EQ(space.getNumRangeIds(), 3u);
+
+  // Check if the attachements for the old ids are still attached properly.
+  EXPECT_EQ(*space.getAttachement<int *>(IdKind::Domain, 2), attachements[0]);
+  EXPECT_EQ(*space.getAttachement<int *>(IdKind::Domain, 3), attachements[1]);
+
+  // Newly inserted ids have no attachements.
+  EXPECT_FALSE(space.hasAttachement(IdKind::Domain, 0));
+  EXPECT_FALSE(space.hasAttachement(IdKind::Domain, 1));
+  EXPECT_FALSE(space.hasAttachement(IdKind::Range, 0));
+}
+
+TEST(PresburgerSpaceTest, removeIdRangeAttachement) {
+  PresburgerSpace space = PresburgerSpace::getRelationSpace(2, 1, 3, 0);
+  space.resetAttachements<int *>();
+
+  int attachements[6] = {0, 1, 2, 3, 4, 5};
+
+  // Attach attachements to domain identifiers.
+  space.setAttachement<int *>(IdKind::Domain, 0, &attachements[0]);
+  space.setAttachement<int *>(IdKind::Domain, 1, &attachements[1]);
+
+  // Attach attachements to range identifiers.
+  space.setAttachement<int *>(IdKind::Range, 0, &attachements[2]);
+
+  // Attach attachements to symbol identifiers.
+  space.setAttachement<int *>(IdKind::Symbol, 0, &attachements[3]);
+  space.setAttachement<int *>(IdKind::Symbol, 1, &attachements[4]);
+  space.setAttachement<int *>(IdKind::Symbol, 2, &attachements[5]);
+
+  // Remove 1 domain identifier.
+  space.removeIdRange(IdKind::Domain, 0, 1);
+  EXPECT_EQ(space.getNumDomainIds(), 1u);
+
+  // Remove 1 symbol and 1 range identifier.
+  space.removeIdRange(IdKind::Symbol, 0, 1);
+  space.removeIdRange(IdKind::Range, 0, 1);
+  EXPECT_EQ(space.getNumDomainIds(), 1u);
+  EXPECT_EQ(space.getNumRangeIds(), 0u);
+  EXPECT_EQ(space.getNumSymbolIds(), 2u);
+
+  // Check if domain attachements are attached properly.
+  EXPECT_EQ(*space.getAttachement<int *>(IdKind::Domain, 0), attachements[1]);
+
+  // Check if symbol attachements are attached properly.
+  EXPECT_EQ(*space.getAttachement<int *>(IdKind::Symbol, 0), attachements[4]);
+  EXPECT_EQ(*space.getAttachement<int *>(IdKind::Symbol, 1), attachements[5]);
+}