diff --git a/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h b/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h --- a/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h +++ b/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h @@ -17,7 +17,7 @@ class SDBMDialect : public Dialect { public: - SDBMDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) {} + SDBMDialect(MLIRContext *context); /// Since there are no other virtual methods in this derived class, override /// the destructor so that key methods get defined in the corresponding diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h --- a/mlir/include/mlir/IR/AttributeSupport.h +++ b/mlir/include/mlir/IR/AttributeSupport.h @@ -135,6 +135,7 @@ template static T get(MLIRContext *ctx, unsigned kind, Args &&... args) { return ctx->getAttributeUniquer().get( + T::getTypeID(), [ctx](AttributeStorage *storage) { initializeAttributeStorage(storage, ctx, T::getTypeID()); }, diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h --- a/mlir/include/mlir/IR/TypeSupport.h +++ b/mlir/include/mlir/IR/TypeSupport.h @@ -129,6 +129,7 @@ template static T get(MLIRContext *ctx, unsigned kind, Args &&... args) { return ctx->getTypeUniquer().get( + T::getTypeID(), [&](TypeStorage *storage) { storage->initialize(AbstractType::lookup(T::getTypeID(), ctx)); }, diff --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h --- a/mlir/include/mlir/Support/StorageUniquer.h +++ b/mlir/include/mlir/Support/StorageUniquer.h @@ -14,6 +14,8 @@ #include "llvm/Support/Allocator.h" namespace mlir { +class TypeID; + namespace detail { struct StorageUniquerImpl; @@ -60,6 +62,10 @@ /// that is called when erasing a storage instance. This should cleanup any /// fields of the storage as necessary and not attempt to free the memory /// of the storage itself. 
+/// +/// All storage classes must be registered with the uniquer via +/// `registerStorageType` using an appropriate unique `TypeID` for the storage +/// class. class StorageUniquer { public: StorageUniquer(); @@ -68,6 +74,10 @@ /// Set the flag specifying if multi-threading is disabled within the uniquer. void disableMultithreading(bool disable = true); + /// Register a new storage object with this uniquer using the given unique + /// type id. + void registerStorageType(TypeID id); + /// This class acts as the base storage that all storage classes must derived /// from. class BaseStorage { @@ -125,8 +135,8 @@ /// function is used for derived types that have complex storage or uniquing /// constraints. template - Storage *get(function_ref initFn, unsigned kind, Arg &&arg, - Args &&... args) { + Storage *get(const TypeID &id, function_ref initFn, + unsigned kind, Arg &&arg, Args &&... args) { // Construct a value of the derived key type. auto derivedKey = getKey(std::forward(arg), std::forward(args)...); @@ -148,7 +158,8 @@ }; // Get an instance for the derived storage. - return static_cast(getImpl(kind, hashValue, isEqual, ctorFn)); + return static_cast( + getImpl(id, kind, hashValue, isEqual, ctorFn)); } /// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter @@ -156,20 +167,21 @@ /// function is used for derived types that use no additional storage or /// uniquing outside of the kind. template - Storage *get(function_ref initFn, unsigned kind) { + Storage *get(const TypeID &id, function_ref initFn, + unsigned kind) { auto ctorFn = [&](StorageAllocator &allocator) { auto *storage = new (allocator.allocate()) Storage(); if (initFn) initFn(storage); return storage; }; - return static_cast(getImpl(kind, ctorFn)); + return static_cast(getImpl(id, kind, ctorFn)); } /// Erases a uniqued instance of 'Storage'. This function is used for derived /// types that have complex storage or uniquing constraints. 
template - void erase(unsigned kind, Arg &&arg, Args &&... args) { + void erase(const TypeID &id, unsigned kind, Arg &&arg, Args &&... args) { // Construct a value of the derived key type. auto derivedKey = getKey(std::forward(arg), std::forward(args)...); @@ -183,7 +195,7 @@ }; // Attempt to erase the storage instance. - eraseImpl(kind, hashValue, isEqual, [](BaseStorage *storage) { + eraseImpl(id, kind, hashValue, isEqual, [](BaseStorage *storage) { static_cast(storage)->cleanup(); }); } @@ -191,18 +203,18 @@ private: /// Implementation for getting/creating an instance of a derived type with /// complex storage. - BaseStorage *getImpl(unsigned kind, unsigned hashValue, + BaseStorage *getImpl(const TypeID &id, unsigned kind, unsigned hashValue, function_ref isEqual, function_ref ctorFn); /// Implementation for getting/creating an instance of a derived type with /// default storage. - BaseStorage *getImpl(unsigned kind, + BaseStorage *getImpl(const TypeID &id, unsigned kind, function_ref ctorFn); /// Implementation for erasing an instance of a derived type with complex /// storage. 
- void eraseImpl(unsigned kind, unsigned hashValue, + void eraseImpl(const TypeID &id, unsigned kind, unsigned hashValue, function_ref isEqual, function_ref cleanupFn); diff --git a/mlir/lib/Dialect/SDBM/SDBMDialect.cpp b/mlir/lib/Dialect/SDBM/SDBMDialect.cpp --- a/mlir/lib/Dialect/SDBM/SDBMDialect.cpp +++ b/mlir/lib/Dialect/SDBM/SDBMDialect.cpp @@ -7,7 +7,17 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/SDBM/SDBMDialect.h" +#include "SDBMExprDetail.h" using namespace mlir; +SDBMDialect::SDBMDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context) { + uniquer.registerStorageType(TypeID::get()); + uniquer.registerStorageType(TypeID::get()); + uniquer.registerStorageType(TypeID::get()); + uniquer.registerStorageType(TypeID::get()); + uniquer.registerStorageType(TypeID::get()); +} + SDBMDialect::~SDBMDialect() = default; diff --git a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp --- a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp +++ b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp @@ -246,6 +246,7 @@ StorageUniquer &uniquer = lhs.getDialect()->getUniquer(); return uniquer.get( + TypeID::get(), /*initFn=*/{}, static_cast(SDBMExprKind::Add), lhs, rhs); } @@ -533,6 +534,7 @@ StorageUniquer &uniquer = lhs.getDialect()->getUniquer(); return uniquer.get( + TypeID::get(), /*initFn=*/{}, static_cast(SDBMExprKind::Diff), lhs, rhs); } @@ -573,6 +575,7 @@ StorageUniquer &uniquer = var.getDialect()->getUniquer(); return uniquer.get( + TypeID::get(), /*initFn=*/{}, static_cast(SDBMExprKind::Stripe), var, stripeFactor); } @@ -608,7 +611,8 @@ StorageUniquer &uniquer = dialect->getUniquer(); return uniquer.get( - assignDialect, static_cast(SDBMExprKind::DimId), position); + TypeID::get(), assignDialect, + static_cast(SDBMExprKind::DimId), position); } //===----------------------------------------------------------------------===// @@ -624,7 +628,8 @@ StorageUniquer &uniquer = dialect->getUniquer(); 
return uniquer.get( - assignDialect, static_cast(SDBMExprKind::SymbolId), position); + TypeID::get(), assignDialect, + static_cast(SDBMExprKind::SymbolId), position); } //===----------------------------------------------------------------------===// @@ -640,7 +645,8 @@ StorageUniquer &uniquer = dialect->getUniquer(); return uniquer.get( - assignCtx, static_cast(SDBMExprKind::Constant), value); + TypeID::get(), assignCtx, + static_cast(SDBMExprKind::Constant), value); } int64_t SDBMConstantExpr::getValue() const { @@ -656,6 +662,7 @@ StorageUniquer &uniquer = var.getDialect()->getUniquer(); return uniquer.get( + TypeID::get(), /*initFn=*/{}, static_cast(SDBMExprKind::Neg), var); } diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -12,6 +12,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/IntegerSet.h" #include "mlir/Support/MathExtras.h" +#include "mlir/Support/TypeID.h" #include "llvm/ADT/STLExtras.h" using namespace mlir; @@ -245,7 +246,8 @@ StorageUniquer &uniquer = context->getAffineUniquer(); return uniquer.get( - assignCtx, static_cast(kind), position); + TypeID::get(), assignCtx, + static_cast(kind), position); } AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) { @@ -280,7 +282,8 @@ StorageUniquer &uniquer = context->getAffineUniquer(); return uniquer.get( - assignCtx, static_cast(AffineExprKind::Constant), constant); + TypeID::get(), assignCtx, + static_cast(AffineExprKind::Constant), constant); } /// Simplify add expression. Return nullptr if it can't be simplified. 
@@ -388,6 +391,7 @@ StorageUniquer &uniquer = getContext()->getAffineUniquer(); return uniquer.get( + TypeID::get(), /*initFn=*/{}, static_cast(AffineExprKind::Add), *this, other); } @@ -448,6 +452,7 @@ StorageUniquer &uniquer = getContext()->getAffineUniquer(); return uniquer.get( + TypeID::get(), /*initFn=*/{}, static_cast(AffineExprKind::Mul), *this, other); } @@ -514,6 +519,7 @@ StorageUniquer &uniquer = getContext()->getAffineUniquer(); return uniquer.get( + TypeID::get(), /*initFn=*/{}, static_cast(AffineExprKind::FloorDiv), *this, other); } @@ -557,6 +563,7 @@ StorageUniquer &uniquer = getContext()->getAffineUniquer(); return uniquer.get( + TypeID::get(), /*initFn=*/{}, static_cast(AffineExprKind::CeilDiv), *this, other); } @@ -604,6 +611,7 @@ StorageUniquer &uniquer = getContext()->getAffineUniquer(); return uniquer.get( + TypeID::get(), /*initFn=*/{}, static_cast(AffineExprKind::Mod), *this, other); } diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -394,6 +394,13 @@ /// The empty dictionary attribute. impl->emptyDictionaryAttr = AttributeUniquer::get( this, StandardAttributes::Dictionary, ArrayRef()); + + // Register the affine storage objects with the uniquer. + impl->affineUniquer.registerStorageType( + TypeID::get()); + impl->affineUniquer.registerStorageType(TypeID::get()); + impl->affineUniquer.registerStorageType( + TypeID::get()); } MLIRContext::~MLIRContext() {} @@ -557,6 +564,7 @@ llvm::errs() << "error: dialect type already registered.\n"; abort(); } + impl.typeUniquer.registerStorageType(typeID); } void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) { @@ -568,6 +576,7 @@ llvm::errs() << "error: dialect attribute already registered.\n"; abort(); } + impl.attributeUniquer.registerStorageType(typeID); } /// Get the dialect that registered the attribute with the provided typeid. 
diff --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp --- a/mlir/lib/Support/StorageUniquer.cpp +++ b/mlir/lib/Support/StorageUniquer.cpp @@ -9,15 +9,29 @@ #include "mlir/Support/StorageUniquer.h" #include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" #include "llvm/Support/RWMutex.h" using namespace mlir; using namespace mlir::detail; -namespace mlir { -namespace detail { -/// This is the implementation of the StorageUniquer class. -struct StorageUniquerImpl { +namespace { +/// This class stores instances of a simple storage object, i.e. the storage +/// objects that are only uniqued on kind. +struct SimpleUniquer { + using BaseStorage = StorageUniquer::BaseStorage; + using StorageAllocator = StorageUniquer::StorageAllocator; + + /// Instances of this storage object. + llvm::SmallDenseMap instances; + + /// Allocator to use when constructing derived instances. + StorageAllocator allocator; + + /// A mutex to keep uniquing thread-safe. + llvm::sys::SmartRWMutex mutex; +}; +struct ComplexUniquer { using BaseStorage = StorageUniquer::BaseStorage; using StorageAllocator = StorageUniquer::StorageAllocator; @@ -40,88 +54,157 @@ BaseStorage *storage; }; + /// Storage info for derived TypeStorage objects. 
+ struct StorageKeyInfo : DenseMapInfo { + static HashedStorage getEmptyKey() { + return HashedStorage{0, DenseMapInfo::getEmptyKey()}; + } + static HashedStorage getTombstoneKey() { + return HashedStorage{0, DenseMapInfo::getTombstoneKey()}; + } + + static unsigned getHashValue(const HashedStorage &key) { + return key.hashValue; + } + static unsigned getHashValue(LookupKey key) { return key.hashValue; } + + static bool isEqual(const HashedStorage &lhs, const HashedStorage &rhs) { + return lhs.storage == rhs.storage; + } + static bool isEqual(const LookupKey &lhs, const HashedStorage &rhs) { + if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey())) + return false; + // If the lookup kind matches the kind of the storage, then invoke the + // equality function on the lookup key. + return lhs.kind == rhs.storage->getKind() && lhs.isEqual(rhs.storage); + } + }; + + /// Unique types with specific hashing or storage constraints. + using StorageTypeSet = DenseSet; + StorageTypeSet instances; + + /// Allocator to use when constructing derived instances. + StorageAllocator allocator; + + /// A mutex to keep type uniquing thread-safe. + llvm::sys::SmartRWMutex mutex; +}; + +/// This class represents the uniquer for a specific storage object type. It +/// provides uniquers for creating instances with simple and complex storage. +class ObjectStorageUniquer { +public: + ComplexUniquer &getComplexStorage() { return complexStorage; } + SimpleUniquer &getSimpleStorage() { return simpleStorage; } + +private: + ComplexUniquer complexStorage; + SimpleUniquer simpleStorage; +}; +} // end anonymous namespace + +namespace mlir { +namespace detail { +/// This is the implementation of the StorageUniquer class. +struct StorageUniquerImpl { + using BaseStorage = StorageUniquer::BaseStorage; + using StorageAllocator = StorageUniquer::StorageAllocator; + /// Get or create an instance of a complex derived type. 
BaseStorage * - getOrCreate(unsigned kind, unsigned hashValue, + getOrCreate(TypeID id, unsigned kind, unsigned hashValue, function_ref isEqual, function_ref ctorFn) { - LookupKey lookupKey{kind, hashValue, isEqual}; + assert(objectStorage.count(id) && "creating unregistered storage instance"); + + ComplexUniquer::LookupKey lookupKey{kind, hashValue, isEqual}; + auto &storageUniquer = objectStorage[id]->getComplexStorage(); if (!threadingIsEnabled) - return getOrCreateUnsafe(kind, hashValue, lookupKey, ctorFn); + return getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn); // Check for an existing instance in read-only mode. { - llvm::sys::SmartScopedReader typeLock(mutex); - auto it = storageTypes.find_as(lookupKey); - if (it != storageTypes.end()) + llvm::sys::SmartScopedReader typeLock(storageUniquer.mutex); + auto it = storageUniquer.instances.find_as(lookupKey); + if (it != storageUniquer.instances.end()) return it->storage; } // Acquire a writer-lock so that we can safely create the new type instance. - llvm::sys::SmartScopedWriter typeLock(mutex); - return getOrCreateUnsafe(kind, hashValue, lookupKey, ctorFn); + llvm::sys::SmartScopedWriter typeLock(storageUniquer.mutex); + return getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn); } /// Get or create an instance of a complex derived type in an unsafe fashion. BaseStorage * - getOrCreateUnsafe(unsigned kind, unsigned hashValue, LookupKey &lookupKey, + getOrCreateUnsafe(ComplexUniquer &storageUniquer, unsigned kind, + ComplexUniquer::LookupKey &lookupKey, function_ref ctorFn) { - auto existing = storageTypes.insert_as({}, lookupKey); + auto existing = storageUniquer.instances.insert_as({}, lookupKey); if (!existing.second) return existing.first->storage; // Otherwise, construct and initialize the derived storage for this type // instance. 
- BaseStorage *storage = initializeStorage(kind, ctorFn); - *existing.first = HashedStorage{hashValue, storage}; + BaseStorage *storage = + initializeStorage(kind, storageUniquer.allocator, ctorFn); + *existing.first = + ComplexUniquer::HashedStorage{lookupKey.hashValue, storage}; return storage; } /// Get or create an instance of a simple derived type. BaseStorage * - getOrCreate(unsigned kind, + getOrCreate(TypeID id, unsigned kind, function_ref ctorFn) { + assert(objectStorage.count(id) && "creating unregistered storage instance"); + + auto &storageUniquer = objectStorage[id]->getSimpleStorage(); if (!threadingIsEnabled) - return getOrCreateUnsafe(kind, ctorFn); + return getOrCreateUnsafe(storageUniquer, kind, ctorFn); // Check for an existing instance in read-only mode. { - llvm::sys::SmartScopedReader typeLock(mutex); - auto it = simpleTypes.find(kind); - if (it != simpleTypes.end()) + llvm::sys::SmartScopedReader typeLock(storageUniquer.mutex); + auto it = storageUniquer.instances.find(kind); + if (it != storageUniquer.instances.end()) return it->second; } // Acquire a writer-lock so that we can safely create the new type instance. - llvm::sys::SmartScopedWriter typeLock(mutex); - return getOrCreateUnsafe(kind, ctorFn); + llvm::sys::SmartScopedWriter typeLock(storageUniquer.mutex); + return getOrCreateUnsafe(storageUniquer, kind, ctorFn); } /// Get or create an instance of a simple derived type in an unsafe fashion. BaseStorage * - getOrCreateUnsafe(unsigned kind, + getOrCreateUnsafe(SimpleUniquer &storageUniquer, unsigned kind, function_ref ctorFn) { - auto &result = simpleTypes[kind]; + auto &result = storageUniquer.instances[kind]; if (result) return result; // Otherwise, create and return a new storage instance. - return result = initializeStorage(kind, ctorFn); + return result = initializeStorage(kind, storageUniquer.allocator, ctorFn); } /// Erase an instance of a complex derived type. 
- void erase(unsigned kind, unsigned hashValue, + void erase(TypeID id, unsigned kind, unsigned hashValue, function_ref isEqual, function_ref cleanupFn) { - LookupKey lookupKey{kind, hashValue, isEqual}; + assert(objectStorage.count(id) && "erasing unregistered storage instance"); + + auto &storageUniquer = objectStorage[id]->getComplexStorage(); + ComplexUniquer::LookupKey lookupKey{kind, hashValue, isEqual}; // Acquire a writer-lock so that we can safely erase the type instance. - llvm::sys::SmartScopedWriter typeLock(mutex); - auto existing = storageTypes.find_as(lookupKey); - if (existing == storageTypes.end()) + llvm::sys::SmartScopedWriter typeLock(storageUniquer.mutex); + auto existing = storageUniquer.instances.find_as(lookupKey); + if (existing == storageUniquer.instances.end()) return; // Cleanup the storage and remove it from the map. cleanupFn(existing->storage); - storageTypes.erase(existing); + storageUniquer.instances.erase(existing); } //===--------------------------------------------------------------------===// @@ -130,51 +213,15 @@ /// Utility to create and initialize a storage instance. BaseStorage * - initializeStorage(unsigned kind, + initializeStorage(unsigned kind, StorageAllocator &allocator, function_ref ctorFn) { BaseStorage *storage = ctorFn(allocator); storage->kind = kind; return storage; } - /// Storage info for derived TypeStorage objects.
- struct StorageKeyInfo : DenseMapInfo { - static HashedStorage getEmptyKey() { - return HashedStorage{0, DenseMapInfo::getEmptyKey()}; - } - static HashedStorage getTombstoneKey() { - return HashedStorage{0, DenseMapInfo::getTombstoneKey()}; - } - - static unsigned getHashValue(const HashedStorage &key) { - return key.hashValue; - } - static unsigned getHashValue(LookupKey key) { return key.hashValue; } - - static bool isEqual(const HashedStorage &lhs, const HashedStorage &rhs) { - return lhs.storage == rhs.storage; - } - static bool isEqual(const LookupKey &lhs, const HashedStorage &rhs) { - if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey())) - return false; - // If the lookup kind matches the kind of the storage, then invoke the - // equality function on the lookup key. - return lhs.kind == rhs.storage->getKind() && lhs.isEqual(rhs.storage); - } - }; - - /// Unique types with specific hashing or storage constraints. - using StorageTypeSet = DenseSet; - StorageTypeSet storageTypes; - - /// Unique types with just the kind. - DenseMap simpleTypes; - - /// Allocator to use when constructing derived type instances. - StorageUniquer::StorageAllocator allocator; - - /// A mutex to keep type uniquing thread-safe. - llvm::sys::SmartRWMutex mutex; + /// Map of type ids to the storage uniquer to use for registered objects. + DenseMap> objectStorage; /// Flag specifying if multi-threading is enabled within the uniquer. bool threadingIsEnabled = true; @@ -190,27 +237,34 @@ impl->threadingIsEnabled = !disable; } +/// Register a new storage object with this uniquer using the given unique type +/// id. +void StorageUniquer::registerStorageType(TypeID id) { + impl->objectStorage.try_emplace(id, std::make_unique()); +} + /// Implementation for getting/creating an instance of a derived type with /// complex storage. 
auto StorageUniquer::getImpl( - unsigned kind, unsigned hashValue, + const TypeID &id, unsigned kind, unsigned hashValue, function_ref isEqual, function_ref ctorFn) -> BaseStorage * { - return impl->getOrCreate(kind, hashValue, isEqual, ctorFn); + return impl->getOrCreate(id, kind, hashValue, isEqual, ctorFn); } /// Implementation for getting/creating an instance of a derived type with /// default storage. auto StorageUniquer::getImpl( - unsigned kind, function_ref ctorFn) - -> BaseStorage * { - return impl->getOrCreate(kind, ctorFn); + const TypeID &id, unsigned kind, + function_ref ctorFn) -> BaseStorage * { + return impl->getOrCreate(id, kind, ctorFn); } /// Implementation for erasing an instance of a derived type with complex /// storage. -void StorageUniquer::eraseImpl(unsigned kind, unsigned hashValue, +void StorageUniquer::eraseImpl(const TypeID &id, unsigned kind, + unsigned hashValue, function_ref isEqual, function_ref cleanupFn) { - impl->erase(kind, hashValue, isEqual, cleanupFn); + impl->erase(id, kind, hashValue, isEqual, cleanupFn); }