diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -77,69 +77,79 @@
 //===----------------------------------------------------------------------===//
 // Resource Entry
-/// This class is used to build resource entries for use by the printer. Each
-/// resource entry is represented using a key/value pair. The provided key must
-/// be unique within the current context, which allows for a client to provide
-/// resource entries without worrying about overlap with other clients.
-class AsmResourceBuilder {
-public:
-  virtual ~AsmResourceBuilder();
-
-  /// Build a resource entry represented by the given bool.
-  virtual void buildBool(StringRef key, bool data) = 0;
-
-  /// Build a resource entry represented by the given human-readable string
-  /// value.
-  virtual void buildString(StringRef key, StringRef data) = 0;
-
-  /// Build an resource entry represented by the given binary blob data.
-  virtual void buildBlob(StringRef key, ArrayRef data,
-                         uint32_t dataAlignment) = 0;
-  /// Build an resource entry represented by the given binary blob data. This is
-  /// a useful overload if the data type is known. Note that this does not
-  /// support `char` element types to avoid accidentally not providing the
-  /// expected alignment of data in situations that treat blobs generically.
-  template
-  std::enable_if_t::value> buildBlob(StringRef key,
-                                                             ArrayRef data) {
-    buildBlob(
-        key, ArrayRef((const char *)data.data(), data.size() * sizeof(T)),
-        alignof(T));
-  }
-};
-
 /// This class represents a processed binary blob of data. A resource blob is
 /// essentially a collection of data, potentially mutable, with an associated
 /// deleter function (used if the data needs to be destroyed).
 class AsmResourceBlob {
 public:
-  /// A deleter function that frees a blob given the data and allocation size.
-  using DeleterFn = llvm::unique_function;
+  /// A deleter function that frees a blob given the data, allocation size, and
+  /// allocation alignment.
+  using DeleterFn =
+      llvm::unique_function;
+
+  //===--------------------------------------------------------------------===//
+  // Construction
+  //===--------------------------------------------------------------------===//
 
   AsmResourceBlob() = default;
-  AsmResourceBlob(ArrayRef data, DeleterFn deleter, bool dataIsMutable)
-      : data(data), deleter(std::move(deleter)), dataIsMutable(dataIsMutable) {}
+  AsmResourceBlob(ArrayRef data, size_t dataAlignment, DeleterFn deleter,
+                  bool dataIsMutable)
+      : data(data), dataAlignment(dataAlignment), deleter(std::move(deleter)),
+        dataIsMutable(dataIsMutable) {}
 
   /// Utility constructor that initializes a blob with a non-char type T.
   template
   AsmResourceBlob(ArrayRef data, DelT &&deleteFn, bool dataIsMutable)
       : data((const char *)data.data(), data.size() * sizeof(T)),
-        deleter([deleteFn = std::forward(deleteFn)](const void *data,
-                                                          size_t size) {
-          return deleteFn((const T *)data, size);
+        dataAlignment(alignof(T)),
+        deleter([deleteFn = std::forward(deleteFn)](
+                    void *data, size_t size, size_t align) {
+          return deleteFn((T *)data, size, align);
         }),
         dataIsMutable(dataIsMutable) {}
 
   AsmResourceBlob(AsmResourceBlob &&) = default;
-  AsmResourceBlob &operator=(AsmResourceBlob &&) = default;
+  AsmResourceBlob &operator=(AsmResourceBlob &&rhs) {
+    // Self-move-assignment must be a no-op: releasing the data first would
+    // leave this blob holding freed memory that the destructor frees again.
+    if (this == &rhs)
+      return *this;
+
+    // Delete the current blob if necessary.
+    if (deleter)
+      deleter(const_cast(data.data()), data.size(), dataAlignment);
+
+    // Take the data entries from rhs.
+    data = rhs.data;
+    dataAlignment = rhs.dataAlignment;
+    deleter = std::move(rhs.deleter);
+    dataIsMutable = rhs.dataIsMutable;
+    return *this;
+  }
   AsmResourceBlob(const AsmResourceBlob &) = delete;
   AsmResourceBlob &operator=(const AsmResourceBlob &) = delete;
   ~AsmResourceBlob() {
     if (deleter)
-      deleter(data.data(), data.size());
+      deleter(const_cast(data.data()), data.size(), dataAlignment);
   }
 
+  //===--------------------------------------------------------------------===//
+  // Data Access
+  //===--------------------------------------------------------------------===//
+
+  /// Return the alignment of the underlying data.
+  size_t getDataAlignment() const { return dataAlignment; }
+
   /// Return the raw underlying data of this blob.
   ArrayRef getData() const { return data; }
 
+  /// Return the underlying data as an array of the given type. This is an
+  /// inherently unsafe operation, and should only be used when the data is
+  /// known to be of the correct type.
+  template
+  ArrayRef getDataAs() const {
+    return llvm::makeArrayRef((const T *)data.data(),
+                                 data.size() / sizeof(T));
+  }
+
   /// Return a mutable reference to the raw underlying data of this blob.
   /// Asserts that the blob `isMutable`.
   MutableArrayRef getMutableData() {
@@ -159,6 +169,9 @@
   /// The raw, properly aligned, blob data.
   ArrayRef data;
 
+  /// The alignment of the data.
+  size_t dataAlignment = 0;
+
   /// An optional deleter function used to deallocate the underlying data when
   /// necessary.
   DeleterFn deleter;
@@ -167,6 +180,95 @@
-  bool dataIsMutable;
+  bool dataIsMutable = false;
 };
 
+/// This class provides a simple utility wrapper for creating heap allocated
+/// AsmResourceBlobs.
+class HeapAsmResourceBlob {
+public:
+  /// Create a new heap allocated blob with the given size and alignment.
+  /// `dataIsMutable` indicates if the allocated data can be mutated. By
+  /// default, we treat heap allocated blobs as mutable.
+  static AsmResourceBlob allocate(size_t size, size_t align,
+                                  bool dataIsMutable = true) {
+    return AsmResourceBlob(
+        ArrayRef((char *)llvm::allocate_buffer(size, align), size), align,
+        llvm::deallocate_buffer, dataIsMutable);
+  }
+  /// Create a new heap allocated blob and copy the provided data into it.
+  static AsmResourceBlob allocateAndCopy(ArrayRef data, size_t align,
+                                         bool dataIsMutable = true) {
+    // Copy into the raw buffer before wrapping it: going through
+    // `getMutableData()` would assert when `dataIsMutable` is false.
+    char *buffer = (char *)llvm::allocate_buffer(data.size(), align);
+    std::memcpy(buffer, data.data(), data.size());
+    return AsmResourceBlob(ArrayRef<char>(buffer, data.size()), align,
+                           llvm::deallocate_buffer, dataIsMutable);
+  }
+  template
+  static std::enable_if_t::value, AsmResourceBlob>
+  allocateAndCopy(ArrayRef data, bool dataIsMutable = true) {
+    return allocateAndCopy(
+        ArrayRef((const char *)data.data(), data.size() * sizeof(T)),
+        alignof(T), dataIsMutable);
+  }
+};
+/// This class provides a simple utility wrapper for creating "unmanaged"
+/// AsmResourceBlobs. The data is never freed by these blobs, so the caller must
+/// guarantee that it outlives every blob that references it.
+class UnmanagedAsmResourceBlob {
+public:
+  /// Create a new unmanaged resource directly referencing the provided data.
+  /// `dataIsMutable` indicates if the referenced data can be mutated through
+  /// the blob. By default, we treat unmanaged blobs as immutable.
+  static AsmResourceBlob allocate(ArrayRef data, size_t align,
+                                  bool dataIsMutable = false) {
+    return AsmResourceBlob(data, align, /*deleter=*/{},
+                           dataIsMutable);
+  }
+  template
+  static std::enable_if_t::value, AsmResourceBlob>
+  allocate(ArrayRef data, bool dataIsMutable = false) {
+    return allocate(
+        ArrayRef((const char *)data.data(), data.size() * sizeof(T)),
+        alignof(T), dataIsMutable);
+  }
+};
+
+/// This class is used to build resource entries for use by the printer. Each
+/// resource entry is represented using a key/value pair. The provided key must
+/// be unique within the current context, which allows for a client to provide
+/// resource entries without worrying about overlap with other clients.
+class AsmResourceBuilder {
+public:
+  virtual ~AsmResourceBuilder();
+
+  /// Build a resource entry represented by the given bool.
+  virtual void buildBool(StringRef key, bool data) = 0;
+
+  /// Build a resource entry represented by the given human-readable string
+  /// value.
+  virtual void buildString(StringRef key, StringRef data) = 0;
+
+  /// Build a resource entry represented by the given binary blob data.
+  virtual void buildBlob(StringRef key, ArrayRef data,
+                         uint32_t dataAlignment) = 0;
+  /// Build a resource entry represented by the given binary blob data. This is
+  /// a useful overload if the data type is known. Note that this does not
+  /// support `char` element types to avoid accidentally not providing the
+  /// expected alignment of data in situations that treat blobs generically.
+  template
+  std::enable_if_t::value> buildBlob(StringRef key,
+                                                             ArrayRef data) {
+    buildBlob(
+        key, ArrayRef((const char *)data.data(), data.size() * sizeof(T)),
+        alignof(T));
+  }
+  /// Build a resource entry represented by the given resource blob. This is
+  /// a useful overload if a blob already exists in-memory.
+  void buildBlob(StringRef key, const AsmResourceBlob &blob) {
+    buildBlob(key, blob.getData(), blob.getDataAlignment());
+  }
+};
+
 /// This class represents a single parsed resource entry.
 class AsmParsedResourceEntry {
 public:
@@ -186,17 +288,24 @@
   /// failure if the entry does not correspond to a string.
   virtual FailureOr parseAsString() const = 0;
 
-  /// The type of an allocator function used to allocate memory for a blob when
-  /// required. The function is provided a size and alignment, and should return
-  /// an aligned allocation buffer.
+  /// An allocator function used to allocate memory for a blob when required.
+  /// The function is provided a size and alignment, and should return an
+  /// aligned allocation buffer.
   using BlobAllocatorFn =
-      function_ref;
+      function_ref;
 
   /// Parse the resource entry represented by a binary blob. Returns failure if
   /// the entry does not correspond to a blob. If the blob needed to be
   /// allocated, the given allocator function is invoked.
   virtual FailureOr parseAsBlob(BlobAllocatorFn allocator) const = 0;
+  /// Parse the resource entry represented by a binary blob, allocating any
+  /// required storage on the heap via `HeapAsmResourceBlob::allocate`.
+  FailureOr parseAsBlob() const {
+    return parseAsBlob([](size_t size, size_t align) {
+      return HeapAsmResourceBlob::allocate(size, align);
+    });
+  }
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -157,13 +157,13 @@
   /// Lookup an interface for the given ID if one is registered, otherwise
   /// nullptr.
-  const DialectInterface *getRegisteredInterface(TypeID interfaceID) {
+  DialectInterface *getRegisteredInterface(TypeID interfaceID) {
     auto it = registeredInterfaces.find(interfaceID);
     return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr;
   }
   template
-  const InterfaceT *getRegisteredInterface() {
-    return static_cast(
+  InterfaceT *getRegisteredInterface() {
+    return static_cast(
         getRegisteredInterface(InterfaceT::getInterfaceID()));
   }
 
@@ -189,6 +189,15 @@
     (void)std::initializer_list{
         0, (addInterface(std::make_unique(this)), 0)...};
   }
+  /// Construct an interface of type `InterfaceT` from this dialect and `args`,
+  /// add it to this dialect (which takes ownership through the unique_ptr
+  /// overload), and return a reference to the added interface.
+  template
+  InterfaceT &addInterface(Args &&...args) {
+    InterfaceT *interface = new InterfaceT(this, std::forward(args)...);
+    addInterface(std::unique_ptr(interface));
+    return *interface;
+  }
 
 protected:
   /// The constructor takes a unique namespace for this dialect as well as the
@@ -305,15 +314,11 @@
 };
 template
 struct cast_retty_impl {
-  using ret_type =
-      std::conditional_t::value, T *,
-                         const T *>;
+  using ret_type = T *;
 };
 template
 struct cast_retty_impl {
-  using ret_type =
-      std::conditional_t::value, T &,
-                         const T &>;
+  using ret_type = T &;
 };
 
 template
@@ -325,7 +330,7 @@
   }
   template
   static std::enable_if_t::value,
-                          const To &>
+                          To &>
   doitImpl(::mlir::Dialect &dialect) {
     return *dialect.getRegisteredInterface();
   }
diff --git a/mlir/include/mlir/IR/DialectResourceBlobManager.h b/mlir/include/mlir/IR/DialectResourceBlobManager.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/IR/DialectResourceBlobManager.h
@@ -0,0 +1,210 @@
+//===- DialectResourceBlobManager.h - Dialect Blob Management ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines utility classes for referencing and managing asm resource
+// blobs. These classes are intended to more easily facilitate the sharing of
+// large blobs, and their definition.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_DIALECTRESOURCEBLOBMANAGER_H
+#define MLIR_IR_DIALECTRESOURCEBLOBMANAGER_H
+
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/RWMutex.h"
+#include "llvm/Support/SMLoc.h"
+
+namespace mlir {
+//===----------------------------------------------------------------------===//
+// DialectResourceBlobManager
+//===----------------------------------------------------------------------===//
+
+/// This class defines a manager for dialect resource blobs. Blobs are uniqued
+/// by a given key, and represented using AsmResourceBlobs.
+class DialectResourceBlobManager {
+public:
+  /// This class represents an individual entry of a blob.
+  class BlobEntry {
+  public:
+    /// Return the key used to reference this blob.
+    StringRef getKey() const { return key; }
+
+    /// Return the blob owned by this entry.
+    const AsmResourceBlob &getBlob() const { return blob; }
+    AsmResourceBlob &getBlob() { return blob; }
+
+    /// Set the blob owned by this entry.
+    void setBlob(AsmResourceBlob &&newBlob) { blob = std::move(newBlob); }
+
+  private:
+    BlobEntry() = default;
+    BlobEntry(BlobEntry &&) = default;
+    BlobEntry &operator=(const BlobEntry &) = delete;
+    BlobEntry &operator=(BlobEntry &&) = delete;
+
+    /// Initialize this entry with the given key and blob.
+    void initialize(StringRef newKey, AsmResourceBlob &&newBlob) {
+      key = newKey;
+      blob = std::move(newBlob);
+    }
+
+    /// The key used for this blob.
+    StringRef key;
+
+    /// The blob that is referenced by this entry.
+    AsmResourceBlob blob;
+
+    /// Allow access to the constructors.
+    friend DialectResourceBlobManager;
+    friend class llvm::StringMapEntryStorage;
+  };
+
+  /// Return the blob registered for the given name, or nullptr if no blob
+  /// is registered.
+  BlobEntry *lookup(StringRef name);
+  const BlobEntry *lookup(StringRef name) const {
+    return const_cast(this)->lookup(name);
+  }
+
+  /// Update the blob for the entry defined by the provided name. This method
+  /// asserts that an entry for the given name exists in the manager.
+  void update(StringRef name, AsmResourceBlob &&newBlob);
+
+  /// Insert a new entry with the provided name and optional blob data. The name
+  /// may be modified during insertion if another entry already exists with that
+  /// name; a unique `_N` suffix is then appended. Returns the inserted entry.
+  BlobEntry &insert(StringRef name, AsmResourceBlob &&blob = {});
+  /// Insertion method that returns a dialect specific handle to the inserted
+  /// entry.
+  template
+  HandleT insert(typename HandleT::Dialect *dialect, StringRef name,
+                 AsmResourceBlob &&blob = {}) {
+    BlobEntry &entry = insert(name, std::move(blob));
+    return HandleT(&entry, dialect);
+  }
+
+private:
+  /// A mutex to protect access to the blob map.
+  llvm::sys::SmartRWMutex blobMapLock;
+
+  /// The internal map of tracked blobs. StringMap stores entries in distinct
+  /// allocations, so we can freely take references to the data without fear of
+  /// invalidation during additional insertion/deletion.
+  llvm::StringMap blobMap;
+};
+
+//===----------------------------------------------------------------------===//
+// ResourceBlobManagerDialectInterface
+//===----------------------------------------------------------------------===//
+
+/// This class implements a dialect interface that provides common functionality
+/// for interacting with a resource blob manager.
+class ResourceBlobManagerDialectInterface
+    : public DialectInterface::Base {
+public:
+  ResourceBlobManagerDialectInterface(Dialect *dialect)
+      : Base(dialect),
+        blobManager(std::make_shared()) {}
+
+  /// Return the blob manager held by this interface.
+  DialectResourceBlobManager &getBlobManager() { return *blobManager; }
+  const DialectResourceBlobManager &getBlobManager() const {
+    return *blobManager;
+  }
+
+  /// Set the blob manager held by this interface.
+  void
+  setBlobManager(std::shared_ptr newBlobManager) {
+    blobManager = std::move(newBlobManager);
+  }
+
+private:
+  /// The blob manager owned by the dialect implementing this interface.
+  std::shared_ptr blobManager;
+};
+
+/// This class provides a base class for dialects implementing the resource blob
+/// interface. It provides several additional dialect specific utilities on top
+/// of the generic interface. `HandleT` is the type of the handle used to
+/// reference a resource blob.
+template
+class ResourceBlobManagerDialectInterfaceBase
+    : public ResourceBlobManagerDialectInterface {
+public:
+  using ResourceBlobManagerDialectInterface::
+      ResourceBlobManagerDialectInterface;
+
+  /// Update the blob for the entry defined by the provided name. This method
+  /// asserts that an entry for the given name exists in the manager.
+  void update(StringRef name, AsmResourceBlob &&newBlob) {
+    getBlobManager().update(name, std::move(newBlob));
+  }
+
+  /// Insert a new resource blob entry with the provided name and optional blob
+  /// data. The name may be modified during insertion if another entry already
+  /// exists with that name. Returns a dialect specific handle to the inserted
+  /// entry.
+  HandleT insert(StringRef name, AsmResourceBlob &&blob = {}) {
+    return getBlobManager().template insert(
+        cast(getDialect()), name, std::move(blob));
+  }
+
+  /// Build resources for each of the referenced blobs within this manager.
+  void buildResources(AsmResourceBuilder &provider,
+                      ArrayRef referencedResources) {
+    for (const AsmDialectResourceHandle &handle : referencedResources)
+      if (const auto *dialectHandle = dyn_cast(&handle))
+        provider.buildBlob(dialectHandle->getKey(), dialectHandle->getBlob());
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// DialectResourceBlobHandle
+//===----------------------------------------------------------------------===//
+
+/// This class defines a dialect specific handle to a resource blob. These
+/// handles utilize a StringRef for the internal key, and an AsmResourceBlob as
+/// the underlying data.
+template
+struct DialectResourceBlobHandle
+    : public AsmDialectResourceHandleBase,
+                                          DialectResourceBlobManager::BlobEntry,
+                                          DialectT> {
+  using AsmDialectResourceHandleBase,
+                                     DialectResourceBlobManager::BlobEntry,
+                                     DialectT>::AsmDialectResourceHandleBase;
+  using ManagerInterface = ResourceBlobManagerDialectInterfaceBase<
+      DialectResourceBlobHandle>;
+
+  /// Return the human readable string key for this handle.
+  StringRef getKey() const { return this->getResource()->getKey(); }
+
+  /// Return the blob referenced by this handle.
+  AsmResourceBlob &getBlob() { return this->getResource()->getBlob(); }
+  const AsmResourceBlob &getBlob() const {
+    return this->getResource()->getBlob();
+  }
+
+  /// Get the interface for the dialect that owns handles of this type. Asserts
+  /// that the dialect is registered.
+  static ManagerInterface &getManagerInterface(MLIRContext *ctx) {
+    auto *dialect = ctx->getOrLoadDialect();
+    assert(dialect && "dialect not registered");
+
+    auto *iface = dialect->template getRegisteredInterface();
+    assert(iface && "dialect doesn't provide the blob manager interface?");
+    return *iface;
+  }
+};
+
+} // namespace mlir
+
+#endif // MLIR_IR_DIALECTRESOURCEBLOBMANAGER_H
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -12,6 +12,7 @@
   BuiltinTypeInterfaces.cpp
   Diagnostics.cpp
   Dialect.cpp
+  DialectResourceBlobManager.cpp
   Dominance.cpp
   ExtensibleDialect.cpp
   FunctionImplementation.cpp
diff --git a/mlir/lib/IR/DialectResourceBlobManager.cpp b/mlir/lib/IR/DialectResourceBlobManager.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/IR/DialectResourceBlobManager.cpp
@@ -0,0 +1,65 @@
+//===- DialectResourceBlobManager.cpp - Dialect Blob Management -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/DialectResourceBlobManager.h"
+#include "llvm/ADT/SmallString.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// DialectResourceBlobManager
+//===----------------------------------------------------------------------===//
+
+auto DialectResourceBlobManager::lookup(StringRef name) -> BlobEntry * {
+  llvm::sys::SmartScopedReader reader(blobMapLock);
+
+  auto it = blobMap.find(name);
+  return it != blobMap.end() ? &it->second : nullptr;
+}
+
+void DialectResourceBlobManager::update(StringRef name,
+                                        AsmResourceBlob &&newBlob) {
+  BlobEntry *entry = lookup(name);
+  assert(entry && "`update` expects an existing entry for the provided name");
+  entry->setBlob(std::move(newBlob));
+}
+
+auto DialectResourceBlobManager::insert(StringRef name, AsmResourceBlob &&blob)
+    -> BlobEntry & {
+  llvm::sys::SmartScopedWriter writer(blobMapLock);
+
+  // Functor used to attempt insertion with a given name. `blob` is only moved
+  // from when the insertion succeeds, so it remains intact across retries.
+  auto tryInsertion = [&](StringRef candidate) -> BlobEntry * {
+    auto it = blobMap.try_emplace(candidate, BlobEntry());
+    if (it.second) {
+      it.first->second.initialize(it.first->getKey(), std::move(blob));
+      return &it.first->second;
+    }
+    return nullptr;
+  };
+
+  // Try inserting with the name provided by the user.
+  if (BlobEntry *entry = tryInsertion(name))
+    return *entry;
+
+  // If an entry already exists for the user provided name, tweak the name and
+  // re-attempt insertion until we find one that is unique.
+  llvm::SmallString<32> nameStorage(name);
+  nameStorage.push_back('_');
+  size_t nameCounter = 1;
+  do {
+    Twine(nameCounter++).toVector(nameStorage);
+
+    // Try inserting with the suffixed name. This must use `nameStorage`, not
+    // the original `name`, which is already known to be taken.
+    if (BlobEntry *entry = tryInsertion(nameStorage))
+      return *entry;
+    nameStorage.resize(name.size() + 1);
+  } while (true);
+}
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -243,7 +243,7 @@
   let mnemonic = "e1di64_elements";
   let parameters = (ins
     AttributeSelfTypeParameter<"", "::mlir::ShapedType">:$type,
-    ResourceHandleParameter<"TestExternalElementsDataHandle">:$handle
+    ResourceHandleParameter<"TestDialectResourceBlobHandle">:$handle
   );
   let extraClassDeclaration = [{
     /// Return the elements referenced by this attribute.
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h
--- a/mlir/test/lib/Dialect/Test/TestAttributes.h
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.h
@@ -24,9 +24,14 @@
 
 #include "TestAttrInterfaces.h.inc"
 #include "TestOpEnums.h.inc"
+#include "mlir/IR/DialectResourceBlobManager.h"
 
 namespace test {
-struct TestExternalElementsDataHandle;
+class TestDialect;
+
+/// A handle used to reference the test dialect's resource blobs.
+using TestDialectResourceBlobHandle =
+    mlir::DialectResourceBlobHandle;
 } // namespace test
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -199,7 +199,7 @@
 //===----------------------------------------------------------------------===//
 
 ArrayRef TestExtern1DI64ElementsAttr::getElements() const {
-  return getHandle().getData()->getData();
+  return getHandle().getBlob().getDataAs();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -25,6 +25,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
 #include "mlir/IR/ExtensibleDialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
@@ -45,68 +46,6 @@
 class RewritePatternSet;
 } // namespace mlir
 
-namespace test {
-class TestDialect;
-
-//===----------------------------------------------------------------------===//
-// External Elements Data
-//===----------------------------------------------------------------------===//
-
-/// This class represents a single external elements instance. It keeps track of
-/// the data, and deallocates when destructed.
-class TestExternalElementsData : public mlir::AsmResourceBlob {
-public:
-  using mlir::AsmResourceBlob::AsmResourceBlob;
-  TestExternalElementsData(mlir::AsmResourceBlob &&blob)
-      : mlir::AsmResourceBlob(std::move(blob)) {}
-
-  /// Return the data of this external elements instance.
-  llvm::ArrayRef getData() const;
-
-  /// Allocate a new external elements instance with the given number of
-  /// elements.
-  static TestExternalElementsData allocate(size_t numElements);
-};
-
-/// A handle used to reference external elements instances.
-struct TestExternalElementsDataHandle
-    : public mlir::AsmDialectResourceHandleBase<
-          TestExternalElementsDataHandle,
-          llvm::StringMapEntry>,
-          TestDialect> {
-  using AsmDialectResourceHandleBase::AsmDialectResourceHandleBase;
-
-  /// Return a key to use for this handle.
-  llvm::StringRef getKey() const { return getResource()->getKey(); }
-
-  /// Return the data referenced by this handle.
-  TestExternalElementsData *getData() const {
-    return getResource()->getValue().get();
-  }
-};
-
-/// This class acts as a manager for external elements data. It provides API
-/// for creating and accessing registered elements data.
-class TestExternalElementsDataManager {
-  using DataMap = llvm::StringMap>;
-
-public:
-  /// Return the data registered for the given name, or nullptr if no data is
-  /// registered.
-  const TestExternalElementsData *getData(llvm::StringRef name) const;
-
-  /// Register an entry with the provided name, which may be modified if another
-  /// entry was already inserted with that name. Returns the inserted entry.
-  std::pair insert(llvm::StringRef name);
-
-  /// Set the data for the given entry, which is expected to exist.
-  void setData(llvm::StringRef name, TestExternalElementsData &&data);
-
-private:
-  llvm::StringMap> dataMap;
-};
-} // namespace test
-
 //===----------------------------------------------------------------------===//
 // TestDialect
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -44,55 +44,6 @@
   registry.insert();
 }
 
-//===----------------------------------------------------------------------===//
-// External Elements Data
-//===----------------------------------------------------------------------===//
-
-ArrayRef TestExternalElementsData::getData() const {
-  ArrayRef data = AsmResourceBlob::getData();
-  return ArrayRef((const uint64_t *)data.data(),
-                            data.size() / sizeof(uint64_t));
-}
-
-TestExternalElementsData
-TestExternalElementsData::allocate(size_t numElements) {
-  return TestExternalElementsData(
-      llvm::ArrayRef(new uint64_t[numElements], numElements),
-      [](const uint64_t *data, size_t) { delete[] data; },
-      /*dataIsMutable=*/true);
-}
-
-const TestExternalElementsData *
-TestExternalElementsDataManager::getData(StringRef name) const {
-  auto it = dataMap.find(name);
-  return it != dataMap.end() ? &*it->second : nullptr;
-}
-
-std::pair
-TestExternalElementsDataManager::insert(StringRef name) {
-  auto it = dataMap.try_emplace(name, nullptr);
-  if (it.second)
-    return it;
-
-  llvm::SmallString<32> nameStorage(name);
-  nameStorage.push_back('_');
-  size_t nameCounter = 1;
-  do {
-    nameStorage += std::to_string(nameCounter++);
-    auto it = dataMap.try_emplace(nameStorage, nullptr);
-    if (it.second)
-      return it;
-    nameStorage.resize(name.size() + 1);
-  } while (true);
-}
-
-void TestExternalElementsDataManager::setData(StringRef name,
-                                              TestExternalElementsData &&data) {
-  auto it = dataMap.find(name);
-  assert(it != dataMap.end() && "data not registered");
-  it->second = std::make_unique(std::move(data));
-}
-
 //===----------------------------------------------------------------------===//
 // TestDialect Interfaces
 //===----------------------------------------------------------------------===//
@@ -109,9 +60,18 @@
               "hasSingleBlockImplicitTerminator does not match "
               "SingleBlockImplicitTerminatorOp");
 
+struct TestResourceBlobManagerInterface
+    : public ResourceBlobManagerDialectInterfaceBase<
+          TestDialectResourceBlobHandle> {
+  using ResourceBlobManagerDialectInterfaceBase<
+      TestDialectResourceBlobHandle>::ResourceBlobManagerDialectInterfaceBase;
+};
+
 // Test support for interacting with the AsmPrinter.
 struct TestOpAsmInterface : public OpAsmDialectInterface {
   using OpAsmDialectInterface::OpAsmDialectInterface;
+  TestOpAsmInterface(Dialect *dialect, TestResourceBlobManagerInterface &mgr)
+      : OpAsmDialectInterface(dialect), blobManager(mgr) {}
 
   //===------------------------------------------------------------------===//
   // Aliases
@@ -176,33 +136,21 @@
 
   std::string
   getResourceKey(const AsmDialectResourceHandle &handle) const override {
-    return cast(handle).getKey().str();
+    return cast(handle).getKey().str();
   }
 
   FailureOr declareResource(StringRef key) const final {
-    TestDialect *dialect = cast(getDialect());
-    TestExternalElementsDataManager &mgr = dialect->getExternalDataManager();
-
-    // Resolve the reference by inserting a new entry into the manager.
-    auto it = mgr.insert(key).first;
-    return TestExternalElementsDataHandle(&*it, dialect);
+    return blobManager.insert(key);
   }
 
   LogicalResult parseResource(AsmParsedResourceEntry &entry) const final {
-    TestDialect *dialect = cast(getDialect());
-    TestExternalElementsDataManager &mgr = dialect->getExternalDataManager();
-
-    // The resource entries are external constant data.
-    auto blobAllocFn = [](unsigned size, unsigned align) {
-      assert(align == alignof(uint64_t) && "unexpected data alignment");
-      return TestExternalElementsData::allocate(size / sizeof(uint64_t));
-    };
-    FailureOr blob = entry.parseAsBlob(blobAllocFn);
+    FailureOr blob = entry.parseAsBlob();
     if (failed(blob))
       return failure();
 
-    mgr.setData(entry.getKey(), std::move(*blob));
+    // Replace this entry's blob; `update` asserts that the entry exists.
+    blobManager.update(entry.getKey(), std::move(*blob));
     return success();
   }
 
@@ -210,11 +158,12 @@
                  const SetVector &referencedResources,
                  AsmResourceBuilder &provider) const final {
-    for (const AsmDialectResourceHandle &handle : referencedResources) {
-      const auto &testHandle = cast(handle);
-      provider.buildBlob(testHandle.getKey(), testHandle.getData()->getData());
-    }
+    blobManager.buildResources(provider, referencedResources.getArrayRef());
   }
+
+private:
+  /// Non-owning reference to the dialect's blob manager interface.
+  TestResourceBlobManagerInterface &blobManager;
 };
 
 struct TestDialectFoldInterface : public DialectFoldInterface {
@@ -412,8 +361,11 @@
   registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
   registerDynamicOp(getDynamicCustomParserPrinterOp(this));
 
-  addInterfaces();
+  auto &blobInterface = addInterface();
+  addInterface(blobInterface);
+
+  addInterfaces();
   allowUnknownOperations();
 
   // Instantiate our fallback op interface that we'll use on specific
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td
--- a/mlir/test/lib/Dialect/Test/TestDialect.td
+++ b/mlir/test/lib/Dialect/Test/TestDialect.td
@@ -42,11 +42,6 @@
     void printType(::mlir::Type type,
                    ::mlir::DialectAsmPrinter &printer) const override;
 
-    /// Returns the external elements data manager for this dialect.
-    TestExternalElementsDataManager &getExternalDataManager() {
-      return externalDataManager;
-    }
-
   private:
     // Storage for a custom fallback interface.
     void *fallbackEffectOpInterfaces;
@@ -55,9 +50,6 @@
                       ::llvm::SetVector<::mlir::Type> &stack) const;
     void printTestType(::mlir::Type type, ::mlir::AsmPrinter &printer,
                        ::llvm::SetVector<::mlir::Type> &stack) const;
-
-    /// An external data manager used to test external elements data.
-    TestExternalElementsDataManager externalDataManager;
   }];
 }