diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
@@ -14,7 +14,11 @@
 #ifndef MLIR_ANALYSIS_PRESBURGER_PRESBURGERSPACE_H
 #define MLIR_ANALYSIS_PRESBURGER_PRESBURGERSPACE_H
 
+#include "mlir/Support/TypeID.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/PointerLikeTypeTraits.h"
 #include "llvm/Support/raw_ostream.h"
 
 namespace mlir {
@@ -62,6 +66,12 @@
 /// Compatibility of two spaces implies that number of identifiers of each kind
 /// other than Locals are equal. Equality of two spaces implies that number of
 /// identifiers of each kind are equal.
+///
+/// PresburgerSpace optionally also supports attaching values to each non-local
+/// variable in the space. `resetValues<T>` enables attaching values to the
+/// space. All values must be of the same type `T`, which is the type passed as
+/// the template argument to `resetValues<T>`. `T` must have an
+/// `llvm::PointerLikeTypeTraits<T>` specialization available.
 class PresburgerSpace {
 public:
   static PresburgerSpace getRelationSpace(unsigned numDomain = 0,
@@ -119,6 +129,11 @@
   /// idLimit). The range is relative to the kind of identifier.
   void removeIdRange(IdKind kind, unsigned idStart, unsigned idLimit);
 
+  /// Swaps the posA^th variable of kindA and posB^th variable of kindB. Within
+  /// the space, only the attached values (if any) are affected: since locals
+  /// never carry values, swapping a non-local with a local clears its value.
+  void swapId(IdKind kindA, IdKind kindB, unsigned posA, unsigned posB);
+
   /// Returns true if both the spaces are compatible i.e. if both spaces have
   /// the same number of identifiers of each kind (excluding locals).
   bool isCompatible(const PresburgerSpace &other) const;
@@ -137,12 +152,89 @@
   void print(llvm::raw_ostream &os) const;
   void dump() const;
 
+  //===--------------------------------------------------------------------===//
+  // Value Interactions
+  //===--------------------------------------------------------------------===//
+
+  /// Set the value attached to the `i^th` variable to `value`. `T` here should
+  /// match the type used to enable values.
+  template <typename T>
+  void setValue(IdKind kind, unsigned i, T value) {
+    assert(TypeID::get<T>() == valueType && "Type mismatch");
+    atValue(kind, i) = llvm::PointerLikeTypeTraits<T>::getAsVoidPointer(value);
+  }
+
+  /// Get the value attached to the `i^th` variable cast to type `T`. `T` here
+  /// should match the type used to enable values.
+  template <typename T>
+  T getValue(IdKind kind, unsigned i) const {
+    assert(TypeID::get<T>() == valueType && "Type mismatch");
+    return llvm::PointerLikeTypeTraits<T>::getFromVoidPointer(atValue(kind, i));
+  }
+
+  /// Check if the i^th variable of the specified kind has a non-null value.
+  bool hasValue(IdKind kind, unsigned i) const {
+    return atValue(kind, i) != nullptr;
+  }
+
+  /// Check if the spaces are compatible, as well as have the same values
+  /// attached to each variable.
+  bool isAligned(const PresburgerSpace &other) const;
+  /// Check if the number of variables of the specified kind match, and have
+  /// same values with the other space. `kind` must not be IdKind::Local.
+  bool isAligned(const PresburgerSpace &other, IdKind kind) const;
+
+  /// Find the variable of the specified kind with value `val`. Returns number
+  /// of variables of the specified kind if not found.
+  template <typename T>
+  unsigned findId(IdKind kind, T val) const {
+    unsigned i = 0;
+    for (unsigned e = getNumIdKind(kind); i < e; ++i)
+      if (hasValue(kind, i) && getValue<T>(kind, i) == val)
+        break;
+    return i;
+  }
+
+  /// Returns true if values are being used.
+  bool isUsingValues() const { return usingValues; }
+
+  /// Reset the stored values in the space, attaching a null value to every
+  /// non-local variable. Enables `usingValues` if it was `false` before. All
+  /// values subsequently set or read must be of type `T`.
+  template <typename T>
+  void resetValues() {
+    values.clear();
+    values.resize(getNumDimAndSymbolIds());
+    valueType = TypeID::get<T>();
+    usingValues = true;
+  }
+
+  /// Disable values being stored in space.
+  void disableValues() {
+    values.clear();
+    usingValues = false;
+  }
+
 protected:
   PresburgerSpace(unsigned numDomain = 0, unsigned numRange = 0,
                   unsigned numSymbols = 0, unsigned numLocals = 0)
       : numDomain(numDomain), numRange(numRange), numSymbols(numSymbols),
         numLocals(numLocals) {}
 
+  void *&atValue(IdKind kind, unsigned i) {
+    assert(usingValues && "Cannot access values when `usingValues` is false.");
+    assert(kind != IdKind::Local &&
+           "Values cannot be attached to local identifiers.");
+    return values[getIdKindOffset(kind) + i];
+  }
+
+  void *atValue(IdKind kind, unsigned i) const {
+    assert(usingValues && "Cannot access values when `usingValues` is false.");
+    assert(kind != IdKind::Local &&
+           "Values cannot be attached to local identifiers.");
+    return values[getIdKindOffset(kind) + i];
+  }
+
 private:
   // Number of identifiers corresponding to domain identifiers.
   unsigned numDomain;
@@ -157,6 +249,17 @@
   /// Number of identifers corresponding to locals (identifiers corresponding
   /// to existentially quantified variables).
   unsigned numLocals;
+
+  /// Stores whether or not values are attached to this space.
+  bool usingValues = false;
+
+  /// TypeID of the values in space. Only used for assertions and debug
+  /// printing; kept unconditionally so the class layout does not depend on
+  /// NDEBUG or other build-configuration macros.
+  TypeID valueType;
+
+  /// Stores a value for each non-local identifier as a `void` pointer.
+  llvm::SmallVector<void *> values;
 };
 
 } // namespace presburger
diff --git a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
--- a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
@@ -83,6 +83,10 @@
   else
     numLocals += num;
 
+  // Insert null values if `usingValues` and variables inserted are not locals.
+  if (usingValues && kind != IdKind::Local)
+    values.insert(values.begin() + absolutePos, num, nullptr);
+
   return absolutePos;
 }
 
@@ -102,6 +106,36 @@
     numSymbols -= numIdsEliminated;
   else
     numLocals -= numIdsEliminated;
+
+  // Remove values if `usingValues` and variables removed are not locals.
+  if (usingValues && kind != IdKind::Local)
+    values.erase(values.begin() + getIdKindOffset(kind) + idStart,
+                 values.begin() + getIdKindOffset(kind) + idLimit);
+}
+
+void PresburgerSpace::swapId(IdKind kindA, IdKind kindB, unsigned posA,
+                             unsigned posB) {
+  // A swap never changes the number of identifiers of any kind, so only the
+  // attached values (if any) need to be updated.
+  if (!usingValues)
+    return;
+
+  // Locals never carry values, so swapping two locals leaves values untouched.
+  if (kindA == IdKind::Local && kindB == IdKind::Local)
+    return;
+
+  // A non-local swapped with a local takes over the local's (null) value.
+  if (kindA == IdKind::Local) {
+    atValue(kindB, posB) = nullptr;
+    return;
+  }
+
+  if (kindB == IdKind::Local) {
+    atValue(kindA, posA) = nullptr;
+    return;
+  }
+
+  std::swap(atValue(kindA, posA), atValue(kindB, posB));
 }
 
 bool PresburgerSpace::isCompatible(const PresburgerSpace &other) const {
@@ -114,11 +148,37 @@
   return isCompatible(other) && getNumLocalIds() == other.getNumLocalIds();
 }
 
+bool PresburgerSpace::isAligned(const PresburgerSpace &other) const {
+  assert(isUsingValues() && other.isUsingValues() &&
+         "Both spaces should be using values to check for "
+         "alignment.");
+  return isCompatible(other) && values == other.values;
+}
+
+bool PresburgerSpace::isAligned(const PresburgerSpace &other,
+                                IdKind kind) const {
+  assert(isUsingValues() && other.isUsingValues() &&
+         "Both spaces should be using values to check for "
+         "alignment.");
+  // `values` has no entries for locals; slicing them would be out of range.
+  assert(kind != IdKind::Local &&
+         "Values cannot be attached to local identifiers.");
+
+  ArrayRef<void *> kindValues =
+      makeArrayRef(values).slice(getIdKindOffset(kind), getNumIdKind(kind));
+  ArrayRef<void *> otherKindValues =
+      makeArrayRef(other.values)
+          .slice(other.getIdKindOffset(kind), other.getNumIdKind(kind));
+  return kindValues == otherKindValues;
+}
+
 void PresburgerSpace::setDimSymbolSeparation(unsigned newSymbolCount) {
   assert(newSymbolCount <= getNumDimAndSymbolIds() &&
          "invalid separation position");
   numRange = numRange + numSymbols - newSymbolCount;
   numSymbols = newSymbolCount;
+  // We do not need to change `values` since the ordering of `values` remains
+  // the same.
 }
 
 void PresburgerSpace::print(llvm::raw_ostream &os) const {
@@ -126,6 +186,17 @@
      << "Range: " << getNumRangeIds() << ", "
      << "Symbols: " << getNumSymbolIds() << ", "
      << "Locals: " << getNumLocalIds() << "\n";
+
+  if (usingValues) {
+#ifndef NDEBUG
+    os << "TypeID of values: " << valueType.getAsOpaquePointer() << "\n";
+#endif
+
+    os << "(";
+    for (void *value : values)
+      os << value << " ";
+    os << ")\n";
+  }
 }
 
 void PresburgerSpace::dump() const { print(llvm::errs()); }
diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
--- a/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
@@ -49,3 +49,62 @@
   EXPECT_EQ(space.getNumRangeIds(), 0u);
   EXPECT_EQ(space.getNumSymbolIds(), 2u);
 }
+
+TEST(PresburgerSpaceTest, insertIdValue) {
+  PresburgerSpace space = PresburgerSpace::getRelationSpace(2, 2, 1, 0);
+  space.resetValues<int *>();
+
+  // Attach value to domain ids.
+  int values[2] = {0, 1};
+  space.setValue<int *>(IdKind::Domain, 0, &values[0]);
+  space.setValue<int *>(IdKind::Domain, 1, &values[1]);
+
+  // Try inserting 2 domain ids.
+  space.insertId(IdKind::Domain, 0, 2);
+  EXPECT_EQ(space.getNumDomainIds(), 4u);
+
+  // Try inserting 1 range id.
+  space.insertId(IdKind::Range, 0, 1);
+  EXPECT_EQ(space.getNumRangeIds(), 3u);
+
+  // Check if the values for the old ids are still attached properly.
+  EXPECT_EQ(*space.getValue<int *>(IdKind::Domain, 2), values[0]);
+  EXPECT_EQ(*space.getValue<int *>(IdKind::Domain, 3), values[1]);
+}
+
+TEST(PresburgerSpaceTest, removeIdRangeValue) {
+  PresburgerSpace space = PresburgerSpace::getRelationSpace(2, 1, 3, 0);
+  space.resetValues<int *>();
+
+  int values[6] = {0, 1, 2, 3, 4, 5};
+
+  // Attach values to domain identifiers.
+  space.setValue<int *>(IdKind::Domain, 0, &values[0]);
+  space.setValue<int *>(IdKind::Domain, 1, &values[1]);
+
+  // Attach values to range identifiers.
+  space.setValue<int *>(IdKind::Range, 0, &values[2]);
+
+  // Attach values to symbol identifiers.
+  space.setValue<int *>(IdKind::Symbol, 0, &values[3]);
+  space.setValue<int *>(IdKind::Symbol, 1, &values[4]);
+  space.setValue<int *>(IdKind::Symbol, 2, &values[5]);
+
+  // Remove 1 domain identifier.
+  space.removeIdRange(IdKind::Domain, 0, 1);
+  EXPECT_EQ(space.getNumDomainIds(), 1u);
+
+  // Remove 1 symbol and 1 range identifier.
+  space.removeIdRange(IdKind::Symbol, 0, 1);
+  space.removeIdRange(IdKind::Range, 0, 1);
+  EXPECT_EQ(space.getNumDomainIds(), 1u);
+  EXPECT_EQ(space.getNumRangeIds(), 0u);
+  EXPECT_EQ(space.getNumSymbolIds(), 2u);
+
+  // Check if domain values are attached properly.
+  EXPECT_EQ(*space.getValue<int *>(IdKind::Domain, 0), values[1]);
+
+  // Check if symbol values are attached properly.
+  EXPECT_EQ(*space.getValue<int *>(IdKind::Symbol, 0), values[4]);
+  EXPECT_EQ(*space.getValue<int *>(IdKind::Symbol, 1), values[5]);
+}