diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
@@ -28,6 +28,96 @@
 /// as relations with zero domain vars.
 enum class VarKind { Symbol, Local, Domain, Range, SetDim = Range };
 
+// An Identifier stores a pointer to an object, such as a Value or an Operation.
+// Identifiers are intended to be attached to a variable in a PresburgerSpace
+// and can be used to check if two variables correspond to the same object.
+//
+// Take for example the following code:
+//
+// for i = 0 to 100
+//   for j = 0 to 100
+//     S0: A[j] = 0
+//   for k = 0 to 100
+//     S1: A[k] = 1
+//
+// If we represent the space of iteration variables surrounding S0, S1 we have:
+// space(S0): {d0, d1}
+// space(S1): {d0, d1}
+//
+// Since the variables are in different spaces, without an identifier, there
+// is no way to distinguish if the variables in the two spaces are
+// different. So, we attach an Identifier corresponding to the loop iteration
+// variable to them. Now,
+//
+// space(S0) = {d0(id = i), d1(id = j)}
+// space(S1) = {d0(id = i), d1(id = k)}.
+//
+// Using the identifiers, we can check that the first iteration variables of
+// both spaces are the same, while the second iteration variables differ.
+//
+// Identifiers storing null pointers are treated as having no
+// attachment and are considered unequal to any other identifier, including
+// identifiers with no attachments.
+class Identifier {
+public:
+  Identifier() = default;
+
+  // Create an identifier from a pointer. The type of the pointer must have
+  // a `llvm::PointerLikeTypeTraits` specialization.
+  template <typename T>
+  explicit Identifier(T value)
+      : value(llvm::PointerLikeTypeTraits<T>::getAsVoidPointer(value)) {
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+    idType = TypeID::get<T>();
+#endif
+  }
+
+  /// Get the value of the identifier cast to type `T`. `T` here should match
+  /// the type that was used to create the identifier.
+  template <typename T>
+  T getValue() const {
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+    assert(TypeID::get<T>() == idType &&
+           "Identifier was initialized with a different type than the one used "
+           "to retrieve it.");
+#endif
+    return llvm::PointerLikeTypeTraits<T>::getFromVoidPointer(value);
+  }
+
+  bool hasValue() const { return value != nullptr; }
+
+  /// Check if the two identifiers are equal. Null identifiers are considered
+  /// not equal. Asserts if two identifiers are equal but their types are not.
+  bool isEqual(const Identifier &other) const {
+    if (value == nullptr || other.value == nullptr)
+      return false;
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+    assert((value != other.value || idType == other.idType) &&
+           "Values of Identifiers are equal but their types do not match.");
+#endif
+    return value == other.value;
+  }
+
+  bool operator==(const Identifier &other) const { return isEqual(other); }
+  bool operator!=(const Identifier &other) const { return !isEqual(other); }
+
+  void print(llvm::raw_ostream &os) const { os << "Id<" << value << ">"; }
+
+  void dump() const {
+    print(llvm::errs());
+    llvm::errs() << "\n";
+  }
+
+private:
+  /// The value of the identifier.
+  void *value = nullptr;
+
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+  /// TypeID of the identifiers in space. This should be used in asserts only.
+  TypeID idType = TypeID::get<void>();
+#endif
+};
+
 /// PresburgerSpace is the space of all possible values of a tuple of integer
 /// valued variables/variables. Each variable has one of the three types:
 ///
@@ -66,14 +156,11 @@
 /// other than Locals are equal. Equality of two spaces implies that number of
 /// variables of each kind are equal.
/// -/// PresburgerSpace optionally also supports attaching some information to each -/// variable in space, called "identifier" of that variable. `resetIds` -/// is used to enable/reset these identifiers. All identifiers must be of the -/// same type, `IdType`. `IdType` must have a `llvm::PointerLikeTypeTraits` -/// specialization available and should be supported via `mlir::TypeID`. -/// -/// These identifiers can be used to check if two variables in two different -/// spaces are actually same variable. +/// PresburgerSpace optionally also supports attaching an Identifier with each +/// non-local variable in space. This is disabled by default. `resetIds` is +/// used to enable/reset these identifiers. The user can identify each variable +/// in the space as corresponding to some Identifier. Some example use cases are +/// described in the `Identifier` documentation above. class PresburgerSpace { public: static PresburgerSpace getRelationSpace(unsigned numDomain = 0, @@ -142,6 +227,20 @@ /// varLimit). The range is relative to the kind of variable. void removeVarRange(VarKind kind, unsigned varStart, unsigned varLimit); + /// Converts variables of the specified kind in the column range [srcPos, + /// srcPos + num) to variables of the specified kind at position dstPos. The + /// ranges are relative to the kind of variable. + /// + /// srcKind and dstKind must be different. + void convertVarKind(VarKind srcKind, unsigned srcPos, unsigned num, + VarKind dstKind, unsigned dstPos); + + /// Changes the partition between dimensions and symbols. Depending on the new + /// symbol count, either a chunk of dimensional variables immediately before + /// the split become symbols, or some of the symbols immediately after the + /// split become dimensions. + void setVarSymbolSeperation(unsigned newSymbolCount); + /// Swaps the posA^th variable of kindA and posB^th variable of kindB.
void swapVar(VarKind kindA, VarKind kindB, unsigned posA, unsigned posB); @@ -154,77 +253,29 @@ /// locals). bool isEqual(const PresburgerSpace &other) const; - /// Changes the partition between dimensions and symbols. Depending on the new - /// symbol count, either a chunk of dimensional variables immediately before - /// the split become symbols, or some of the symbols immediately after the - /// split become dimensions. - void setVarSymbolSeperation(unsigned newSymbolCount); - - void print(llvm::raw_ostream &os) const; - void dump() const; - - //===--------------------------------------------------------------------===// - // Identifier Interactions - //===--------------------------------------------------------------------===// - - /// Set the identifier for `i^th` variable to `id`. `T` here should match the - /// type used to enable identifiers. - template - void setId(VarKind kind, unsigned i, T id) { -#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS - assert(TypeID::get() == idType && "Type mismatch"); -#endif - atId(kind, i) = llvm::PointerLikeTypeTraits::getAsVoidPointer(id); - } - - /// Get the identifier for `i^th` variable casted to type `T`. `T` here - /// should match the type used to enable identifiers. - template - T getId(VarKind kind, unsigned i) const { -#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS - assert(TypeID::get() == idType && "Type mismatch"); -#endif - return llvm::PointerLikeTypeTraits::getFromVoidPointer(atId(kind, i)); + /// Get the identifier of the specified variable. + Identifier &getId(VarKind kind, unsigned pos) { + assert(VarKind::Local != kind && "Local variables have no identifiers"); + return identifiers[getVarKindOffset(kind) + pos]; } - - /// Check if the i^th variable of the specified kind has a non-null - /// identifier.
- bool hasId(VarKind kind, unsigned i) const { - return atId(kind, i) != nullptr; + Identifier getId(VarKind kind, unsigned pos) const { + assert(VarKind::Local != kind && "Local variables have no identifiers"); + return identifiers[getVarKindOffset(kind) + pos]; } - /// Check if the spaces are compatible, as well as have the same identifiers - /// for each variable. - bool isAligned(const PresburgerSpace &other) const; - /// Check if the number of variables of the specified kind match, and have - /// same identifiers with the other space. - bool isAligned(const PresburgerSpace &other, VarKind kind) const; - - /// Find the variable of the specified kind with identifier `id`. - /// Returns PresburgerSpace::kIdNotFound if identifier is not found. - template - unsigned findId(VarKind kind, T id) const { - unsigned i = 0; - for (unsigned e = getNumVarKind(kind); i < e; ++i) - if (hasId(kind, i) && getId(kind, i) == id) - return i; - return kIdNotFound; + ArrayRef getIds(VarKind kind) const { + assert(VarKind::Local != kind && "Local variables have no identifiers"); + return {identifiers.data() + getVarKindOffset(kind), getNumVarKind(kind)}; } - static const unsigned kIdNotFound = UINT_MAX; /// Returns if identifiers are being used. bool isUsingIds() const { return usingIds; } /// Reset the stored identifiers in the space. Enables `usingIds` if it was /// `false` before. - template void resetIds() { identifiers.clear(); identifiers.resize(getNumDimAndSymbolVars()); -#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS - idType = TypeID::get(); -#endif - usingIds = true; } @@ -234,26 +285,22 @@ usingIds = false; } + /// Check if the spaces are compatible, as well as each non-local variable at + /// the same position have equal identifiers. If the space is not using + /// Identifiers, this check is same as isCompatible. + bool isAligned(const PresburgerSpace &other) const; + /// Same as above but only check the specified VarKind.
+ bool isAligned(const PresburgerSpace &other, VarKind kind) const; + + void print(llvm::raw_ostream &os) const; + void dump() const; + protected: - PresburgerSpace(unsigned numDomain = 0, unsigned numRange = 0, - unsigned numSymbols = 0, unsigned numLocals = 0) + PresburgerSpace(unsigned numDomain, unsigned numRange, unsigned numSymbols, + unsigned numLocals) : numDomain(numDomain), numRange(numRange), numSymbols(numSymbols), numLocals(numLocals) {} - void *&atId(VarKind kind, unsigned i) { - assert(usingIds && "Cannot access identifiers when `usingIds` is false."); - assert(kind != VarKind::Local && - "Local variables cannot have identifiers."); - return identifiers[getVarKindOffset(kind) + i]; - } - - void *atId(VarKind kind, unsigned i) const { - assert(usingIds && "Cannot access identifiers when `usingIds` is false."); - assert(kind != VarKind::Local && - "Local variables cannot have identifiers."); - return identifiers[getVarKindOffset(kind) + i]; - } - private: // Number of variables corresponding to domain variables. unsigned numDomain; @@ -272,13 +319,8 @@ /// Stores whether or not identifiers are being used in this space. bool usingIds = false; -#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS - /// TypeID of the identifiers in space. This should be used in asserts only. - TypeID idType; -#endif - /// Stores an identifier for each non-local variable as a `void` pointer. - SmallVector identifiers; + SmallVector identifiers; }; } // namespace presburger diff --git a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp --- a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp @@ -14,17 +14,22 @@ using namespace presburger; PresburgerSpace PresburgerSpace::getDomainSpace() const { - // TODO: Preserve identifiers here.
-  return PresburgerSpace::getSetSpace(numDomain, numSymbols, numLocals);
+  PresburgerSpace newSpace = *this;
+  newSpace.removeVarRange(VarKind::Range, 0, getNumRangeVars());
+  newSpace.convertVarKind(VarKind::Domain, 0, getNumDomainVars(),
+                          VarKind::SetDim, 0);
+  return newSpace;
 }
 
 PresburgerSpace PresburgerSpace::getRangeSpace() const {
-  return PresburgerSpace::getSetSpace(numRange, numSymbols, numLocals);
+  PresburgerSpace newSpace = *this;
+  newSpace.removeVarRange(VarKind::Domain, 0, getNumDomainVars());
+  return newSpace;
 }
 
 PresburgerSpace PresburgerSpace::getSpaceWithoutLocals() const {
   PresburgerSpace space = *this;
-  space.removeVarRange(VarKind::Local, 0, numLocals);
+  space.removeVarRange(VarKind::Local, 0, getNumLocalVars());
   return space;
 }
 
@@ -36,7 +41,7 @@
   if (kind == VarKind::Symbol)
     return getNumSymbolVars();
   if (kind == VarKind::Local)
-    return numLocals;
+    return getNumLocalVars();
   llvm_unreachable("VarKind does not exist!");
 }
 
@@ -101,7 +106,7 @@
   // Insert NULL identifiers if `usingIds` and variables inserted are
   // not locals.
if (usingIds && kind != VarKind::Local) - identifiers.insert(identifiers.begin() + absolutePos, num, nullptr); + identifiers.insert(identifiers.begin() + absolutePos, num, Identifier()); return absolutePos; } @@ -130,26 +135,71 @@ identifiers.begin() + getVarKindOffset(kind) + varLimit); } +void PresburgerSpace::convertVarKind(VarKind srcKind, unsigned srcPos, + unsigned num, VarKind dstKind, + unsigned dstPos) { + assert(srcKind != dstKind && "cannot convert variables to the same kind"); + assert(srcPos + num <= getNumVarKind(srcKind) && + "invalid range for source variables"); + assert(dstPos <= getNumVarKind(dstKind) && + "invalid position for destination variables"); + + auto addVars = [&](VarKind kind, int num) { + switch (kind) { + case VarKind::Domain: + numDomain += num; + break; + case VarKind::Range: + numRange += num; + break; + case VarKind::Symbol: + numSymbols += num; + break; + case VarKind::Local: + numLocals += num; + break; + } + }; + + addVars(srcKind, -(signed)num); + addVars(dstKind, num); + + // Move identifiers if `usingIds` and variables moved are not locals. 
+ unsigned srcOffset = getVarKindOffset(srcKind) + srcPos; + unsigned dstOffset = getVarKindOffset(dstKind) + dstPos; + if (isUsingIds() && srcKind != VarKind::Local && dstKind != VarKind::Local) { + identifiers.insert(identifiers.begin() + dstOffset, + identifiers.begin() + srcOffset, + identifiers.begin() + srcOffset + num); + identifiers.erase(identifiers.begin() + srcOffset, + identifiers.begin() + srcOffset + num); + } else if (isUsingIds() && srcKind != VarKind::Local) { + identifiers.erase(identifiers.begin() + srcOffset, + identifiers.begin() + srcOffset + num); + } else if (isUsingIds() && dstKind != VarKind::Local) { + identifiers.insert(identifiers.begin() + dstOffset, num, Identifier()); + } +} + void PresburgerSpace::swapVar(VarKind kindA, VarKind kindB, unsigned posA, unsigned posB) { - - if (!usingIds) + if (!isUsingIds()) return; if (kindA == VarKind::Local && kindB == VarKind::Local) return; if (kindA == VarKind::Local) { - atId(kindB, posB) = nullptr; + getId(kindB, posB) = Identifier(); return; } if (kindB == VarKind::Local) { - atId(kindA, posA) = nullptr; + getId(kindA, posA) = Identifier(); return; } - std::swap(atId(kindA, posA), atId(kindB, posB)); + std::swap(getId(kindA, posA), getId(kindB, posB)); } bool PresburgerSpace::isCompatible(const PresburgerSpace &other) const { @@ -162,25 +212,53 @@ return isCompatible(other) && getNumLocalVars() == other.getNumLocalVars(); } +/// Checks if the number of ids of the given kind in the two spaces are +/// equal and if the ids are equal. Assumes that both spaces are using +/// ids. 
+static bool areIdsEqual(const PresburgerSpace &spaceA, + const PresburgerSpace &spaceB, VarKind kind) { + assert(spaceA.isUsingIds() && spaceB.isUsingIds() && + "Both spaces should be using ids"); + if (spaceA.getNumVarKind(kind) != spaceB.getNumVarKind(kind)) + return false; + if (kind == VarKind::Local) + return true; + unsigned numVars = spaceA.getNumVarKind(kind); + for (unsigned i = 0; i < numVars; ++i) + if (spaceA.getId(kind, i) != spaceB.getId(kind, i)) + return false; + return true; +} + bool PresburgerSpace::isAligned(const PresburgerSpace &other) const { - assert(isUsingIds() && other.isUsingIds() && - "Both spaces should be using identifiers to check for " - "alignment."); - return isCompatible(other) && identifiers == other.identifiers; + // If only one of the spaces is using identifiers, then they are + // not aligned. + if (isUsingIds() ^ other.isUsingIds()) + return false; + // If both spaces are using identifiers, then they are aligned if + // their identifiers are equal. + if (isUsingIds()) + return areIdsEqual(*this, other, VarKind::Domain) && + areIdsEqual(*this, other, VarKind::Range) && + areIdsEqual(*this, other, VarKind::Symbol); + // If neither space is using identifiers, then they are aligned if + // they are compatible. + return isCompatible(other); } bool PresburgerSpace::isAligned(const PresburgerSpace &other, VarKind kind) const { - assert(isUsingIds() && other.isUsingIds() && - "Both spaces should be using identifiers to check for " - "alignment."); - - ArrayRef kindAttachments = - ArrayRef(identifiers).slice(getVarKindOffset(kind), getNumVarKind(kind)); - ArrayRef otherKindAttachments = - ArrayRef(other.identifiers) - .slice(other.getVarKindOffset(kind), other.getNumVarKind(kind)); - return kindAttachments == otherKindAttachments; + // If only one of the spaces is using identifiers, then they are + // not aligned.
+ if (isUsingIds() ^ other.isUsingIds()) + return false; + // If both spaces are using identifiers, then they are aligned if + // their identifiers are equal. + if (isUsingIds()) + return areIdsEqual(*this, other, kind); + // If neither space is using identifiers, then they are aligned if + // the number of variable kind is equal. + return getNumVarKind(kind) == other.getNumVarKind(kind); } void PresburgerSpace::setVarSymbolSeperation(unsigned newSymbolCount) { @@ -198,16 +276,30 @@ << "Symbols: " << getNumSymbolVars() << ", " << "Locals: " << getNumLocalVars() << "\n"; - if (usingIds) { -#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS - os << "TypeID of identifiers: " << idType.getAsOpaquePointer() << "\n"; -#endif + if (isUsingIds()) { + auto printIds = [&](VarKind kind) { + os << " "; + for (unsigned i = 0; i < getNumVarKind(kind); ++i) { + Identifier id = getId(kind, i); + if (id.hasValue()) + id.print(os); + else + os << "None"; + os << " "; + } + }; os << "("; - for (void *identifier : identifiers) - os << identifier << " "; - os << ")\n"; + printIds(VarKind::Domain); + os << ") -> ("; + printIds(VarKind::Range); + os << ") : ["; + printIds(VarKind::Symbol); + os << "]"; } } -void PresburgerSpace::dump() const { print(llvm::errs()); } +void PresburgerSpace::dump() const { + print(llvm::errs()); + llvm::errs() << "\n"; +} diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp --- a/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp +++ b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp @@ -52,12 +52,13 @@ TEST(PresburgerSpaceTest, insertVarIdentifier) { PresburgerSpace space = PresburgerSpace::getRelationSpace(2, 2, 1, 0); - space.resetIds(); + space.resetIds(); - // Attach identifiers to domain ids. int identifiers[2] = {0, 1}; - space.setId(VarKind::Domain, 0, &identifiers[0]); - space.setId(VarKind::Domain, 1, &identifiers[1]); + + // Attach identifiers to domain ids.
+ space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]); + space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]); // Try inserting 2 domain ids. space.insertVar(VarKind::Domain, 0, 2); @@ -68,27 +69,27 @@ EXPECT_EQ(space.getNumRangeVars(), 3u); // Check if the identifiers for the old ids are still attached properly. - EXPECT_EQ(*space.getId(VarKind::Domain, 2), identifiers[0]); - EXPECT_EQ(*space.getId(VarKind::Domain, 3), identifiers[1]); + EXPECT_EQ(space.getId(VarKind::Domain, 2), Identifier(&identifiers[0])); + EXPECT_EQ(space.getId(VarKind::Domain, 3), Identifier(&identifiers[1])); } TEST(PresburgerSpaceTest, removeVarRangeIdentifier) { PresburgerSpace space = PresburgerSpace::getRelationSpace(2, 1, 3, 0); - space.resetIds(); + space.resetIds(); int identifiers[6] = {0, 1, 2, 3, 4, 5}; // Attach identifiers to domain identifiers. - space.setId(VarKind::Domain, 0, &identifiers[0]); - space.setId(VarKind::Domain, 1, &identifiers[1]); + space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]); + space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]); // Attach identifiers to range identifiers. - space.setId(VarKind::Range, 0, &identifiers[2]); + space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]); // Attach identifiers to symbol identifiers. - space.setId(VarKind::Symbol, 0, &identifiers[3]); - space.setId(VarKind::Symbol, 1, &identifiers[4]); - space.setId(VarKind::Symbol, 2, &identifiers[5]); + space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[3]); + space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[4]); + space.getId(VarKind::Symbol, 2) = Identifier(&identifiers[5]); // Remove 1 domain identifier. space.removeVarRange(VarKind::Domain, 0, 1); @@ -102,9 +103,9 @@ EXPECT_EQ(space.getNumSymbolVars(), 2u); // Check if domain identifiers are attached properly.
-  EXPECT_EQ(*space.getId<int *>(VarKind::Domain, 0), identifiers[1]);
+  EXPECT_EQ(space.getId(VarKind::Domain, 0), Identifier(&identifiers[1]));
 
   // Check if symbol identifiers are attached properly.
-  EXPECT_EQ(*space.getId<int *>(VarKind::Range, 0), identifiers[4]);
-  EXPECT_EQ(*space.getId<int *>(VarKind::Range, 1), identifiers[5]);
+  EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[4]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[5]));
 }