diff --git a/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h b/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h --- a/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h +++ b/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h @@ -17,7 +17,7 @@ class SDBMDialect : public Dialect { public: - SDBMDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) {} + SDBMDialect(MLIRContext *context); /// Since there are no other virtual methods in this derived class, override /// the destructor so that key methods get defined in the corresponding diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h --- a/mlir/include/mlir/IR/AttributeSupport.h +++ b/mlir/include/mlir/IR/AttributeSupport.h @@ -133,17 +133,19 @@ template static T get(MLIRContext *ctx, unsigned kind, Args &&... args) { return ctx->getAttributeUniquer().get( + T::getTypeID(), [ctx](AttributeStorage *storage) { initializeAttributeStorage(storage, ctx, T::getTypeID()); }, kind, std::forward(args)...); } - template - static LogicalResult mutate(MLIRContext *ctx, ImplType *impl, + template + static LogicalResult mutate(MLIRContext *ctx, typename T::ImplType *impl, Args &&...args) { assert(impl && "cannot mutate null attribute"); - return ctx->getAttributeUniquer().mutate(impl, std::forward(args)...); + return ctx->getAttributeUniquer().mutate(T::getTypeID(), impl, + std::forward(args)...); } private: diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h --- a/mlir/include/mlir/IR/StorageUniquerSupport.h +++ b/mlir/include/mlir/IR/StorageUniquerSupport.h @@ -109,8 +109,8 @@ /// The arguments are forwarded to 'ConcreteT::mutate'. template LogicalResult mutate(Args &&...args) { - return UniquerT::mutate(this->getContext(), getImpl(), - std::forward(args)...); + return UniquerT::template mutate(this->getContext(), getImpl(), + std::forward(args)...); } /// Default implementation that just returns success. diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h --- a/mlir/include/mlir/IR/TypeSupport.h +++ b/mlir/include/mlir/IR/TypeSupport.h @@ -127,6 +127,7 @@ template static T get(MLIRContext *ctx, unsigned kind, Args &&... args) { return ctx->getTypeUniquer().get( + T::getTypeID(), [&](TypeStorage *storage) { storage->initialize(AbstractType::lookup(T::getTypeID(), ctx)); }, @@ -135,11 +136,12 @@ /// Change the mutable component of the given type instance in the provided /// context. - template - static LogicalResult mutate(MLIRContext *ctx, ImplType *impl, + template + static LogicalResult mutate(MLIRContext *ctx, typename T::ImplType *impl, Args &&...args) { assert(impl && "cannot mutate null type"); - return ctx->getTypeUniquer().mutate(impl, std::forward(args)...); + return ctx->getTypeUniquer().mutate(T::getTypeID(), impl, + std::forward(args)...); } }; } // namespace detail diff --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h --- a/mlir/include/mlir/Support/StorageUniquer.h +++ b/mlir/include/mlir/Support/StorageUniquer.h @@ -15,6 +15,8 @@ #include "llvm/Support/Allocator.h" namespace mlir { +class TypeID; + namespace detail { struct StorageUniquerImpl; @@ -75,6 +77,10 @@ /// value of the function is used to indicate whether the mutation was /// successful, e.g., to limit the number of mutations or enable deferred /// one-time assignment of the mutable component. +/// +/// All storage classes must be registered with the uniquer via +/// `registerStorageType` using an appropriate unique `TypeID` for the storage +/// class. class StorageUniquer { public: StorageUniquer(); @@ -83,6 +89,10 @@ /// Set the flag specifying if multi-threading is disabled within the uniquer. void disableMultithreading(bool disable = true); + /// Register a new storage object with this uniquer using the given unique + /// type id. + void registerStorageType(TypeID id); + /// This class acts as the base storage that all storage classes must derived /// from. class BaseStorage { @@ -140,8 +150,8 @@ /// function is used for derived types that have complex storage or uniquing /// constraints. template - Storage *get(function_ref initFn, unsigned kind, Arg &&arg, - Args &&... args) { + Storage *get(const TypeID &id, function_ref initFn, + unsigned kind, Arg &&arg, Args &&...args) { // Construct a value of the derived key type. auto derivedKey = getKey(std::forward(arg), std::forward(args)...); @@ -163,7 +173,8 @@ }; // Get an instance for the derived storage. - return static_cast(getImpl(kind, hashValue, isEqual, ctorFn)); + return static_cast( + getImpl(id, kind, hashValue, isEqual, ctorFn)); } /// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter @@ -171,31 +182,32 @@ /// function is used for derived types that use no additional storage or /// uniquing outside of the kind. template - Storage *get(function_ref initFn, unsigned kind) { + Storage *get(const TypeID &id, function_ref initFn, + unsigned kind) { auto ctorFn = [&](StorageAllocator &allocator) { auto *storage = new (allocator.allocate()) Storage(); if (initFn) initFn(storage); return storage; }; - return static_cast(getImpl(kind, ctorFn)); + return static_cast(getImpl(id, kind, ctorFn)); } /// Changes the mutable component of 'storage' by forwarding the trailing /// arguments to the 'mutate' function of the derived class. template - LogicalResult mutate(Storage *storage, Args &&...args) { + LogicalResult mutate(const TypeID &id, Storage *storage, Args &&...args) { auto mutationFn = [&](StorageAllocator &allocator) -> LogicalResult { return static_cast(*storage).mutate( allocator, std::forward(args)...); }; - return mutateImpl(mutationFn); + return mutateImpl(id, mutationFn); } /// Erases a uniqued instance of 'Storage'. This function is used for derived /// types that have complex storage or uniquing constraints. template - void erase(unsigned kind, Arg &&arg, Args &&... args) { + void erase(const TypeID &id, unsigned kind, Arg &&arg, Args &&...args) { // Construct a value of the derived key type. auto derivedKey = getKey(std::forward(arg), std::forward(args)...); @@ -209,7 +221,7 @@ }; // Attempt to erase the storage instance. - eraseImpl(kind, hashValue, isEqual, [](BaseStorage *storage) { + eraseImpl(id, kind, hashValue, isEqual, [](BaseStorage *storage) { static_cast(storage)->cleanup(); }); } @@ -217,24 +229,25 @@ private: /// Implementation for getting/creating an instance of a derived type with /// complex storage. - BaseStorage *getImpl(unsigned kind, unsigned hashValue, + BaseStorage *getImpl(const TypeID &id, unsigned kind, unsigned hashValue, function_ref isEqual, function_ref ctorFn); /// Implementation for getting/creating an instance of a derived type with /// default storage. - BaseStorage *getImpl(unsigned kind, + BaseStorage *getImpl(const TypeID &id, unsigned kind, function_ref ctorFn); /// Implementation for erasing an instance of a derived type with complex /// storage. - void eraseImpl(unsigned kind, unsigned hashValue, + void eraseImpl(const TypeID &id, unsigned kind, unsigned hashValue, function_ref isEqual, function_ref cleanupFn); /// Implementation for mutating an instance of a derived storage. LogicalResult - mutateImpl(function_ref mutationFn); + mutateImpl(const TypeID &id, + function_ref mutationFn); /// The internal implementation class. std::unique_ptr impl; @@ -249,7 +262,7 @@ static typename std::enable_if< llvm::is_detected::value, typename ImplTy::KeyTy>::type - getKey(Args &&... args) { + getKey(Args &&...args) { return ImplTy::getKey(args...); } /// If there is no 'ImplTy::getKey' method, then we try to directly construct @@ -258,7 +271,7 @@ static typename std::enable_if< !llvm::is_detected::value, typename ImplTy::KeyTy>::type - getKey(Args &&... args) { + getKey(Args &&...args) { return typename ImplTy::KeyTy(args...); } diff --git a/mlir/lib/Dialect/SDBM/SDBMDialect.cpp b/mlir/lib/Dialect/SDBM/SDBMDialect.cpp --- a/mlir/lib/Dialect/SDBM/SDBMDialect.cpp +++ b/mlir/lib/Dialect/SDBM/SDBMDialect.cpp @@ -7,7 +7,17 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/SDBM/SDBMDialect.h" +#include "SDBMExprDetail.h" using namespace mlir; +SDBMDialect::SDBMDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context) { + uniquer.registerStorageType(TypeID::get()); + uniquer.registerStorageType(TypeID::get()); + uniquer.registerStorageType(TypeID::get()); + uniquer.registerStorageType(TypeID::get()); + uniquer.registerStorageType(TypeID::get()); +} + SDBMDialect::~SDBMDialect() = default; diff --git a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp --- a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp +++ b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp @@ -246,6 +246,7 @@ StorageUniquer &uniquer = lhs.getDialect()->getUniquer(); return uniquer.get( + TypeID::get(), /*initFn=*/{}, static_cast(SDBMExprKind::Add), lhs, rhs); } @@ -533,6 +534,7 @@ StorageUniquer &uniquer = lhs.getDialect()->getUniquer(); return uniquer.get( + TypeID::get(), /*initFn=*/{}, static_cast(SDBMExprKind::Diff), lhs, rhs); } @@ -573,6 +575,7 @@ StorageUniquer &uniquer = var.getDialect()->getUniquer(); return uniquer.get( + TypeID::get(), /*initFn=*/{}, static_cast(SDBMExprKind::Stripe), var, stripeFactor); } @@ -608,7 +611,8 @@ StorageUniquer &uniquer = dialect->getUniquer(); return uniquer.get( - assignDialect, static_cast(SDBMExprKind::DimId), position); + TypeID::get(), assignDialect, + static_cast(SDBMExprKind::DimId), position); } //===----------------------------------------------------------------------===// @@ -624,7 +628,8 @@ StorageUniquer &uniquer = dialect->getUniquer(); return uniquer.get( - assignDialect, static_cast(SDBMExprKind::SymbolId), position); + TypeID::get(), assignDialect, + static_cast(SDBMExprKind::SymbolId), position); } //===----------------------------------------------------------------------===// @@ -640,7 +645,8 @@ StorageUniquer &uniquer = dialect->getUniquer(); return uniquer.get( - assignCtx, static_cast(SDBMExprKind::Constant), value); + TypeID::get(), assignCtx, + static_cast(SDBMExprKind::Constant), value); } int64_t SDBMConstantExpr::getValue() const { @@ -656,6 +662,7 @@ StorageUniquer &uniquer = var.getDialect()->getUniquer(); return uniquer.get( + TypeID::get(), /*initFn=*/{}, static_cast(SDBMExprKind::Neg), var); } diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -12,6 +12,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/IntegerSet.h" #include "mlir/Support/MathExtras.h" +#include "mlir/Support/TypeID.h" #include "llvm/ADT/STLExtras.h" using namespace mlir; @@ -417,7 +418,8 @@ StorageUniquer &uniquer = context->getAffineUniquer(); return uniquer.get( - assignCtx, static_cast(kind), position); + TypeID::get(), assignCtx, + static_cast(kind), position); } AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) { @@ -452,7 +454,8 @@ StorageUniquer &uniquer = context->getAffineUniquer(); return uniquer.get( - assignCtx, static_cast(AffineExprKind::Constant), constant); + TypeID::get(), assignCtx, + static_cast(AffineExprKind::Constant), constant); } /// Simplify add expression. Return nullptr if it can't be simplified. @@ -560,6 +563,7 @@ StorageUniquer &uniquer = getContext()->getAffineUniquer(); return uniquer.get( + TypeID::get(), /*initFn=*/{}, static_cast(AffineExprKind::Add), *this, other); } @@ -620,6 +624,7 @@ StorageUniquer &uniquer = getContext()->getAffineUniquer(); return uniquer.get( + TypeID::get(), /*initFn=*/{}, static_cast(AffineExprKind::Mul), *this, other); } @@ -686,6 +691,7 @@ StorageUniquer &uniquer = getContext()->getAffineUniquer(); return uniquer.get( + TypeID::get(), /*initFn=*/{}, static_cast(AffineExprKind::FloorDiv), *this, other); } @@ -729,6 +735,7 @@ StorageUniquer &uniquer = getContext()->getAffineUniquer(); return uniquer.get( + TypeID::get(), /*initFn=*/{}, static_cast(AffineExprKind::CeilDiv), *this, other); } @@ -776,6 +783,7 @@ StorageUniquer &uniquer = getContext()->getAffineUniquer(); return uniquer.get( + TypeID::get(), /*initFn=*/{}, static_cast(AffineExprKind::Mod), *this, other); } diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -400,6 +400,13 @@ /// The empty dictionary attribute. impl->emptyDictionaryAttr = AttributeUniquer::get( this, StandardAttributes::Dictionary, ArrayRef()); + + // Register the affine storage objects with the uniquer. + impl->affineUniquer.registerStorageType( + TypeID::get()); + impl->affineUniquer.registerStorageType( + TypeID::get()); + impl->affineUniquer.registerStorageType(TypeID::get()); } MLIRContext::~MLIRContext() {} @@ -561,6 +568,7 @@ AbstractType(std::move(typeInfo)); if (!impl.registeredTypes.insert({typeID, newInfo}).second) llvm::report_fatal_error("Dialect Type already registered."); + impl.typeUniquer.registerStorageType(typeID); } void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) { @@ -570,6 +578,7 @@ AbstractAttribute(std::move(attrInfo)); if (!impl.registeredAttributes.insert({typeID, newInfo}).second) llvm::report_fatal_error("Dialect Attribute already registered."); + impl.attributeUniquer.registerStorageType(typeID); } /// Get the dialect that registered the attribute with the provided typeid. diff --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp --- a/mlir/lib/Support/StorageUniquer.cpp +++ b/mlir/lib/Support/StorageUniquer.cpp @@ -9,15 +9,14 @@ #include "mlir/Support/StorageUniquer.h" #include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" #include "llvm/Support/RWMutex.h" using namespace mlir; using namespace mlir::detail; -namespace mlir { -namespace detail { -/// This is the implementation of the StorageUniquer class. -struct StorageUniquerImpl { +namespace { +struct InstSpecificUniquer { using BaseStorage = StorageUniquer::BaseStorage; using StorageAllocator = StorageUniquer::StorageAllocator; @@ -40,98 +39,158 @@ BaseStorage *storage; }; + /// Storage info for derived TypeStorage objects. + struct StorageKeyInfo : DenseMapInfo { + static HashedStorage getEmptyKey() { + return HashedStorage{0, DenseMapInfo::getEmptyKey()}; + } + static HashedStorage getTombstoneKey() { + return HashedStorage{0, DenseMapInfo::getTombstoneKey()}; + } + + static unsigned getHashValue(const HashedStorage &key) { + return key.hashValue; + } + static unsigned getHashValue(LookupKey key) { return key.hashValue; } + + static bool isEqual(const HashedStorage &lhs, const HashedStorage &rhs) { + return lhs.storage == rhs.storage; + } + static bool isEqual(const LookupKey &lhs, const HashedStorage &rhs) { + if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey())) + return false; + // If the lookup kind matches the kind of the storage, then invoke the + // equality function on the lookup key. + return lhs.kind == rhs.storage->getKind() && lhs.isEqual(rhs.storage); + } + }; + + /// Unique types with specific hashing or storage constraints. + using StorageTypeSet = DenseSet; + StorageTypeSet complexInstances; + + /// Instances of this storage object. + llvm::SmallDenseMap simpleInstances; + + /// Allocator to use when constructing derived instances. + StorageAllocator allocator; + + /// A mutex to keep type uniquing thread-safe. + llvm::sys::SmartRWMutex mutex; +}; +} // end anonymous namespace + +namespace mlir { +namespace detail { +/// This is the implementation of the StorageUniquer class. +struct StorageUniquerImpl { + using BaseStorage = StorageUniquer::BaseStorage; + using StorageAllocator = StorageUniquer::StorageAllocator; + /// Get or create an instance of a complex derived type. BaseStorage * - getOrCreate(unsigned kind, unsigned hashValue, + getOrCreate(TypeID id, unsigned kind, unsigned hashValue, function_ref isEqual, function_ref ctorFn) { - LookupKey lookupKey{kind, hashValue, isEqual}; + assert(instUniquers.count(id) && "creating unregistered storage instance"); + InstSpecificUniquer::LookupKey lookupKey{kind, hashValue, isEqual}; + InstSpecificUniquer &storageUniquer = *instUniquers[id]; if (!threadingIsEnabled) - return getOrCreateUnsafe(kind, hashValue, lookupKey, ctorFn); + return getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn); // Check for an existing instance in read-only mode. { - llvm::sys::SmartScopedReader typeLock(mutex); - auto it = storageTypes.find_as(lookupKey); - if (it != storageTypes.end()) + llvm::sys::SmartScopedReader typeLock(storageUniquer.mutex); + auto it = storageUniquer.complexInstances.find_as(lookupKey); + if (it != storageUniquer.complexInstances.end()) return it->storage; } // Acquire a writer-lock so that we can safely create the new type instance. - llvm::sys::SmartScopedWriter typeLock(mutex); - return getOrCreateUnsafe(kind, hashValue, lookupKey, ctorFn); + llvm::sys::SmartScopedWriter typeLock(storageUniquer.mutex); + return getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn); } /// Get or create an instance of a complex derived type in an unsafe fashion. BaseStorage * - getOrCreateUnsafe(unsigned kind, unsigned hashValue, LookupKey &lookupKey, + getOrCreateUnsafe(InstSpecificUniquer &storageUniquer, unsigned kind, + InstSpecificUniquer::LookupKey &lookupKey, function_ref ctorFn) { - auto existing = storageTypes.insert_as({}, lookupKey); + auto existing = storageUniquer.complexInstances.insert_as({}, lookupKey); if (!existing.second) return existing.first->storage; // Otherwise, construct and initialize the derived storage for this type // instance. - BaseStorage *storage = initializeStorage(kind, ctorFn); - *existing.first = HashedStorage{hashValue, storage}; + BaseStorage *storage = + initializeStorage(kind, storageUniquer.allocator, ctorFn); + *existing.first = + InstSpecificUniquer::HashedStorage{lookupKey.hashValue, storage}; return storage; } /// Get or create an instance of a simple derived type. BaseStorage * - getOrCreate(unsigned kind, + getOrCreate(TypeID id, unsigned kind, function_ref ctorFn) { + assert(instUniquers.count(id) && "creating unregistered storage instance"); + InstSpecificUniquer &storageUniquer = *instUniquers[id]; if (!threadingIsEnabled) - return getOrCreateUnsafe(kind, ctorFn); + return getOrCreateUnsafe(storageUniquer, kind, ctorFn); // Check for an existing instance in read-only mode. { - llvm::sys::SmartScopedReader typeLock(mutex); - auto it = simpleTypes.find(kind); - if (it != simpleTypes.end()) + llvm::sys::SmartScopedReader typeLock(storageUniquer.mutex); + auto it = storageUniquer.simpleInstances.find(kind); + if (it != storageUniquer.simpleInstances.end()) return it->second; } // Acquire a writer-lock so that we can safely create the new type instance. - llvm::sys::SmartScopedWriter typeLock(mutex); - return getOrCreateUnsafe(kind, ctorFn); + llvm::sys::SmartScopedWriter typeLock(storageUniquer.mutex); + return getOrCreateUnsafe(storageUniquer, kind, ctorFn); } /// Get or create an instance of a simple derived type in an unsafe fashion. BaseStorage * - getOrCreateUnsafe(unsigned kind, + getOrCreateUnsafe(InstSpecificUniquer &storageUniquer, unsigned kind, function_ref ctorFn) { - auto &result = simpleTypes[kind]; + auto &result = storageUniquer.simpleInstances[kind]; if (result) return result; // Otherwise, create and return a new storage instance. - return result = initializeStorage(kind, ctorFn); + return result = initializeStorage(kind, storageUniquer.allocator, ctorFn); } /// Erase an instance of a complex derived type. - void erase(unsigned kind, unsigned hashValue, + void erase(TypeID id, unsigned kind, unsigned hashValue, function_ref isEqual, function_ref cleanupFn) { - LookupKey lookupKey{kind, hashValue, isEqual}; + assert(instUniquers.count(id) && "erasing unregistered storage instance"); + InstSpecificUniquer &storageUniquer = *instUniquers[id]; + InstSpecificUniquer::LookupKey lookupKey{kind, hashValue, isEqual}; // Acquire a writer-lock so that we can safely erase the type instance. - llvm::sys::SmartScopedWriter typeLock(mutex); - auto existing = storageTypes.find_as(lookupKey); - if (existing == storageTypes.end()) + llvm::sys::SmartScopedWriter lock(storageUniquer.mutex); + auto existing = storageUniquer.complexInstances.find_as(lookupKey); + if (existing == storageUniquer.complexInstances.end()) return; // Cleanup the storage and remove it from the map. cleanupFn(existing->storage); - storageTypes.erase(existing); + storageUniquer.complexInstances.erase(existing); } /// Mutates an instance of a derived storage in a thread-safe way. LogicalResult - mutate(function_ref mutationFn) { + mutate(TypeID id, + function_ref mutationFn) { + assert(instUniquers.count(id) && "mutating unregistered storage instance"); + InstSpecificUniquer &storageUniquer = *instUniquers[id]; if (!threadingIsEnabled) - return mutationFn(allocator); + return mutationFn(storageUniquer.allocator); - llvm::sys::SmartScopedWriter lock(mutex); - return mutationFn(allocator); + llvm::sys::SmartScopedWriter lock(storageUniquer.mutex); + return mutationFn(storageUniquer.allocator); } //===--------------------------------------------------------------------===// @@ -140,51 +199,15 @@ /// Utility to create and initialize a storage instance. BaseStorage * - initializeStorage(unsigned kind, + initializeStorage(unsigned kind, StorageAllocator &allocator, function_ref ctorFn) { BaseStorage *storage = ctorFn(allocator); storage->kind = kind; return storage; } - /// Storage info for derived TypeStorage objects. - struct StorageKeyInfo : DenseMapInfo { - static HashedStorage getEmptyKey() { - return HashedStorage{0, DenseMapInfo::getEmptyKey()}; - } - static HashedStorage getTombstoneKey() { - return HashedStorage{0, DenseMapInfo::getTombstoneKey()}; - } - - static unsigned getHashValue(const HashedStorage &key) { - return key.hashValue; - } - static unsigned getHashValue(LookupKey key) { return key.hashValue; } - - static bool isEqual(const HashedStorage &lhs, const HashedStorage &rhs) { - return lhs.storage == rhs.storage; - } - static bool isEqual(const LookupKey &lhs, const HashedStorage &rhs) { - if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey())) - return false; - // If the lookup kind matches the kind of the storage, then invoke the - // equality function on the lookup key. - return lhs.kind == rhs.storage->getKind() && lhs.isEqual(rhs.storage); - } - }; - - /// Unique types with specific hashing or storage constraints. - using StorageTypeSet = DenseSet; - StorageTypeSet storageTypes; - - /// Unique types with just the kind. - DenseMap simpleTypes; - - /// Allocator to use when constructing derived type instances. - StorageUniquer::StorageAllocator allocator; - - /// A mutex to keep type uniquing thread-safe. - llvm::sys::SmartRWMutex mutex; + /// Map of type ids to the storage uniquer to use for registered objects. + DenseMap> instUniquers; /// Flag specifying if multi-threading is enabled within the uniquer. bool threadingIsEnabled = true; @@ -200,33 +223,41 @@ impl->threadingIsEnabled = !disable; } +/// Register a new storage object with this uniquer using the given unique type +/// id. +void StorageUniquer::registerStorageType(TypeID id) { + impl->instUniquers.try_emplace(id, std::make_unique()); +} + /// Implementation for getting/creating an instance of a derived type with /// complex storage. auto StorageUniquer::getImpl( - unsigned kind, unsigned hashValue, + const TypeID &id, unsigned kind, unsigned hashValue, function_ref isEqual, function_ref ctorFn) -> BaseStorage * { - return impl->getOrCreate(kind, hashValue, isEqual, ctorFn); + return impl->getOrCreate(id, kind, hashValue, isEqual, ctorFn); } /// Implementation for getting/creating an instance of a derived type with /// default storage. auto StorageUniquer::getImpl( - unsigned kind, function_ref ctorFn) - -> BaseStorage * { - return impl->getOrCreate(kind, ctorFn); + const TypeID &id, unsigned kind, + function_ref ctorFn) -> BaseStorage * { + return impl->getOrCreate(id, kind, ctorFn); } /// Implementation for erasing an instance of a derived type with complex /// storage. -void StorageUniquer::eraseImpl(unsigned kind, unsigned hashValue, +void StorageUniquer::eraseImpl(const TypeID &id, unsigned kind, + unsigned hashValue, function_ref isEqual, function_ref cleanupFn) { - impl->erase(kind, hashValue, isEqual, cleanupFn); + impl->erase(id, kind, hashValue, isEqual, cleanupFn); } /// Implementation for mutating an instance of a derived storage. LogicalResult StorageUniquer::mutateImpl( + const TypeID &id, function_ref mutationFn) { - return impl->mutate(mutationFn); + return impl->mutate(id, mutationFn); }