diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h --- a/mlir/include/mlir/IR/AttributeSupport.h +++ b/mlir/include/mlir/IR/AttributeSupport.h @@ -92,13 +92,16 @@ template static T get(MLIRContext *ctx, unsigned kind, Args &&... args) { return ctx->getAttributeUniquer().get( - getInitFn(ctx, T::getTypeID()), kind, std::forward(args)...); + [ctx](AttributeStorage *storage) { + initializeAttributeStorage(storage, ctx, T::getTypeID()); + }, + kind, std::forward(args)...); } private: - /// Returns a functor used to initialize new attribute storage instances. - static std::function getInitFn(MLIRContext *ctx, - TypeID attrID); + /// Initialize the given attribute storage instance. + static void initializeAttributeStorage(AttributeStorage *storage, + MLIRContext *ctx, TypeID attrID); }; } // namespace detail diff --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h --- a/mlir/include/mlir/Support/StorageUniquer.h +++ b/mlir/include/mlir/Support/StorageUniquer.h @@ -123,7 +123,7 @@ /// function is used for derived types that have complex storage or uniquing /// constraints. template - Storage *get(std::function initFn, unsigned kind, Arg &&arg, + Storage *get(function_ref initFn, unsigned kind, Arg &&arg, Args &&... args) { // Construct a value of the derived key type. auto derivedKey = @@ -133,19 +133,17 @@ unsigned hashValue = getHash(kind, derivedKey); // Generate an equality function for the derived storage. - std::function isEqual = - [&derivedKey](const BaseStorage *existing) { - return static_cast(*existing) == derivedKey; - }; + auto isEqual = [&derivedKey](const BaseStorage *existing) { + return static_cast(*existing) == derivedKey; + }; // Generate a constructor function for the derived storage. - std::function ctorFn = - [&](StorageAllocator &allocator) { - auto *storage = Storage::construct(allocator, derivedKey); - if (initFn) - initFn(storage); - return storage; - }; + auto ctorFn = [&](StorageAllocator &allocator) { + auto *storage = Storage::construct(allocator, derivedKey); + if (initFn) + initFn(storage); + return storage; + }; // Get an instance for the derived storage. return static_cast(getImpl(kind, hashValue, isEqual, ctorFn)); @@ -156,7 +154,7 @@ /// function is used for derived types that use no additional storage or /// uniquing outside of the kind. template - Storage *get(std::function initFn, unsigned kind) { + Storage *get(function_ref initFn, unsigned kind) { auto ctorFn = [&](StorageAllocator &allocator) { auto *storage = new (allocator.allocate()) Storage(); if (initFn) @@ -178,10 +176,9 @@ unsigned hashValue = getHash(kind, derivedKey); // Generate an equality function for the derived storage. - std::function isEqual = - [&derivedKey](const BaseStorage *existing) { - return static_cast(*existing) == derivedKey; - }; + auto isEqual = [&derivedKey](const BaseStorage *existing) { + return static_cast(*existing) == derivedKey; + }; // Attempt to erase the storage instance. eraseImpl(kind, hashValue, isEqual, [](BaseStorage *storage) { @@ -194,18 +191,18 @@ /// complex storage. BaseStorage *getImpl(unsigned kind, unsigned hashValue, function_ref isEqual, - std::function ctorFn); + function_ref ctorFn); /// Implementation for getting/creating an instance of a derived type with /// default storage. BaseStorage *getImpl(unsigned kind, - std::function ctorFn); + function_ref ctorFn); /// Implementation for erasing an instance of a derived type with complex /// storage. void eraseImpl(unsigned kind, unsigned hashValue, function_ref isEqual, - std::function cleanupFn); + function_ref cleanupFn); /// The internal implementation class. std::unique_ptr impl; diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -624,16 +624,15 @@ return getImpl().attributeUniquer; } -/// Returns a functor used to initialize new attribute storage instances. -std::function -AttributeUniquer::getInitFn(MLIRContext *ctx, TypeID attrID) { - return [ctx, attrID](AttributeStorage *storage) { - storage->initializeDialect(lookupDialectForSymbol(ctx, attrID)); - - // If the attribute did not provide a type, then default to NoneType. - if (!storage->getType()) - storage->setType(NoneType::get(ctx)); - }; +/// Initialize the given attribute storage instance. +void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage, + MLIRContext *ctx, + TypeID attrID) { + storage->initializeDialect(lookupDialectForSymbol(ctx, attrID)); + + // If the attribute did not provide a type, then default to NoneType. + if (!storage->getType()) + storage->setType(NoneType::get(ctx)); } BoolAttr BoolAttr::get(bool value, MLIRContext *context) { diff --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp --- a/mlir/lib/Support/StorageUniquer.cpp +++ b/mlir/lib/Support/StorageUniquer.cpp @@ -176,14 +176,14 @@ auto StorageUniquer::getImpl( unsigned kind, unsigned hashValue, function_ref isEqual, - std::function ctorFn) -> BaseStorage * { + function_ref ctorFn) -> BaseStorage * { return impl->getOrCreate(kind, hashValue, isEqual, ctorFn); } /// Implementation for getting/creating an instance of a derived type with /// default storage. auto StorageUniquer::getImpl( - unsigned kind, std::function ctorFn) + unsigned kind, function_ref ctorFn) -> BaseStorage * { return impl->getOrCreate(kind, ctorFn); } @@ -192,6 +192,6 @@ /// storage. void StorageUniquer::eraseImpl(unsigned kind, unsigned hashValue, function_ref isEqual, - std::function cleanupFn) { + function_ref cleanupFn) { impl->erase(kind, hashValue, isEqual, cleanupFn); }