diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -13,6 +13,7 @@ #ifndef MLIR_IR_DIALECT_H #define MLIR_IR_DIALECT_H +#include "mlir/IR/DialectRegistry.h" #include "mlir/IR/OperationSupport.h" #include "mlir/Support/TypeID.h" @@ -26,11 +27,9 @@ class OpBuilder; class Type; -using DialectAllocatorFunction = std::function; -using DialectAllocatorFunctionRef = function_ref; -using DialectInterfaceAllocatorFunction = - std::function(Dialect *)>; -using ObjectInterfaceAllocatorFunction = std::function; +//===----------------------------------------------------------------------===// +// Dialect +//===----------------------------------------------------------------------===// /// Dialects are groups of MLIR operations, types and attributes, as well as /// behavior associated with the entire group. For example, hooks into other @@ -180,6 +179,16 @@ getRegisteredInterfaceForOp(InterfaceT::getInterfaceID(), opName)); } + /// Register a dialect interface with this dialect instance. + void addInterface(std::unique_ptr interface); + + /// Register a set of dialect interfaces with this dialect instance. + template + void addInterfaces() { + (void)std::initializer_list{ + 0, (addInterface(std::make_unique(this)), 0)...}; + } + protected: /// The constructor takes a unique namespace for this dialect as well as the /// context to bind to. @@ -218,15 +227,6 @@ /// Enable support for unregistered types. void allowUnknownTypes(bool allow = true) { unknownTypesAllowed = allow; } - /// Register a dialect interface with this dialect instance. - void addInterface(std::unique_ptr interface); - - /// Register a set of dialect interfaces with this dialect instance. - template void addInterfaces() { - (void)std::initializer_list{ - 0, (addInterface(std::make_unique(this)), 0)...}; - } - private: Dialect(const Dialect &) = delete; void operator=(Dialect &) = delete; @@ -274,168 +274,6 @@ friend class MLIRContext; }; -/// The DialectRegistry maps a dialect namespace to a constructor for the -/// matching dialect. -/// This allows for decoupling the list of dialects "available" from the -/// dialects loaded in the Context. The parser in particular will lazily load -/// dialects in the Context as operations are encountered. -class DialectRegistry { - /// Lists of interfaces that need to be registered when the dialect is loaded. - struct DelayedInterfaces { - /// Dialect interfaces. - SmallVector, 2> - dialectInterfaces; - /// Attribute/Operation/Type interfaces. - SmallVector, 2> - objectInterfaces; - }; - - using MapTy = - std::map>; - using InterfaceMapTy = DenseMap; - -public: - explicit DialectRegistry(); - - template void insert() { - insert(TypeID::get(), - ConcreteDialect::getDialectNamespace(), - static_cast(([](MLIRContext *ctx) { - // Just allocate the dialect, the context - // takes ownership of it. - return ctx->getOrLoadDialect(); - }))); - } - - template - void insert() { - insert(); - insert(); - } - - /// Add a new dialect constructor to the registry. The constructor must be - /// calling MLIRContext::getOrLoadDialect in order for the context to take - /// ownership of the dialect and for delayed interface registration to happen. - void insert(TypeID typeID, StringRef name, - const DialectAllocatorFunction &ctor); - - /// Return an allocation function for constructing the dialect identified by - /// its namespace, or nullptr if the namespace is not in this registry. - DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const; - - // Register all dialects available in the current registry with the registry - // in the provided context. - void appendTo(DialectRegistry &destination) const { - for (const auto &nameAndRegistrationIt : registry) - destination.insert(nameAndRegistrationIt.second.first, - nameAndRegistrationIt.first, - nameAndRegistrationIt.second.second); - // Merge interfaces. - for (auto it : interfaces) { - TypeID dialect = it.first; - auto destInterfaces = destination.interfaces.find(dialect); - if (destInterfaces == destination.interfaces.end()) { - destination.interfaces[dialect] = it.second; - continue; - } - // The destination already has delayed interface registrations for this - // dialect. Merge registrations into the destination registry. - destInterfaces->second.dialectInterfaces.append( - it.second.dialectInterfaces.begin(), - it.second.dialectInterfaces.end()); - destInterfaces->second.objectInterfaces.append( - it.second.objectInterfaces.begin(), it.second.objectInterfaces.end()); - } - } - - /// Return the names of dialects known to this registry. - auto getDialectNames() const { - return llvm::map_range( - registry, - [](const MapTy::value_type &item) -> StringRef { return item.first; }); - } - - /// Add an interface constructed with the given allocation function to the - /// dialect provided as template parameter. The dialect must be present in - /// the registry. - template - void addDialectInterface(TypeID interfaceTypeID, - DialectInterfaceAllocatorFunction allocator) { - addDialectInterface(DialectTy::getDialectNamespace(), interfaceTypeID, - allocator); - } - - /// Add an interface to the dialect, both provided as template parameter. The - /// dialect must be present in the registry. - template - void addDialectInterface() { - addDialectInterface( - InterfaceTy::getInterfaceID(), [](Dialect *dialect) { - return std::make_unique(dialect); - }); - } - - /// Add an external op interface model for an op that belongs to a dialect, - /// both provided as template parameters. The dialect must be present in the - /// registry. - template void addOpInterface() { - StringRef opName = OpTy::getOperationName(); - StringRef dialectName = opName.split('.').first; - addObjectInterface(dialectName, TypeID::get(), - ModelTy::Interface::getInterfaceID(), - [](MLIRContext *context) { - OpTy::template attachInterface(*context); - }); - } - - /// Add an external attribute interface model for an attribute type `AttrTy` - /// that is going to belong to `DialectTy`. The dialect must be present in the - /// registry. - template - void addAttrInterface() { - addStorageUserInterface(DialectTy::getDialectNamespace()); - } - - /// Add an external type interface model for an type class `TypeTy` that is - /// going to belong to `DialectTy`. The dialect must be present in the - /// registry. - template - void addTypeInterface() { - addStorageUserInterface(DialectTy::getDialectNamespace()); - } - - /// Register any interfaces required for the given dialect (based on its - /// TypeID). Users are not expected to call this directly. - void registerDelayedInterfaces(Dialect *dialect) const; - -private: - /// Add an interface constructed with the given allocation function to the - /// dialect identified by its namespace. - void addDialectInterface(StringRef dialectName, TypeID interfaceTypeID, - const DialectInterfaceAllocatorFunction &allocator); - - /// Add an attribute/operation/type interface constructible with the given - /// allocation function to the dialect identified by its namespace. - void addObjectInterface(StringRef dialectName, TypeID objectID, - TypeID interfaceTypeID, - const ObjectInterfaceAllocatorFunction &allocator); - - /// Add an external model for an attribute/type interface to the dialect - /// identified by its namespace. - template - void addStorageUserInterface(StringRef dialectName) { - addObjectInterface(dialectName, TypeID::get(), - ModelTy::Interface::getInterfaceID(), - [](MLIRContext *context) { - ObjectTy::template attachInterface(*context); - }); - } - - MapTy registry; - InterfaceMapTy interfaces; -}; - } // namespace mlir namespace llvm { diff --git a/mlir/include/mlir/IR/DialectRegistry.h b/mlir/include/mlir/IR/DialectRegistry.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/DialectRegistry.h @@ -0,0 +1,222 @@ +//===- DialectRegistry.h - Dialect Registration and Extension ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines functionality for registring and extending dialects. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_DIALECTREGISTRY_H +#define MLIR_IR_DIALECTREGISTRY_H + +#include "mlir/IR/MLIRContext.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" + +#include +#include + +namespace mlir { +class Dialect; + +using DialectAllocatorFunction = std::function; +using DialectAllocatorFunctionRef = function_ref; + +//===----------------------------------------------------------------------===// +// DialectExtension +//===----------------------------------------------------------------------===// + +/// This class represents an opaque dialect extension. It contains a set of +/// required dialects and an application function. The required dialects control +/// when the extension is applied, i.e. the extension is applied when all +/// required dialects are loaded. The application function can be used to attach +/// additional functionality to attributes, dialects, operations, types, etc., +/// and may also load additional necessary dialects. +class DialectExtensionBase { +public: + virtual ~DialectExtensionBase(); + + /// Return the dialects that our required by this extension to be loaded + /// before applying. + ArrayRef getRequiredDialects() const { return dialectNames; } + + /// Apply this extension to the given context and the required dialects. + virtual void apply(MLIRContext *context, + MutableArrayRef dialects) const = 0; + + /// Return a copy of this extension. + virtual std::unique_ptr clone() const = 0; + +protected: + /// Initialize the extension with a set of required dialects. Note that there + /// should always be at least one affected dialect. + DialectExtensionBase(ArrayRef dialectNames) + : dialectNames(dialectNames.begin(), dialectNames.end()) { + assert(!dialectNames.empty() && "expected at least one affected dialect"); + } + +private: + /// The names of the dialects affected by this extension. + SmallVector dialectNames; +}; + +/// This class represents a dialect extension anchored on the given set of +/// dialects. When all of the specified dialects have been loaded, the +/// application function of this extension will be executed. +template +class DialectExtension : public DialectExtensionBase { +public: + /// Applies this extension to the given context and set of required dialects. + virtual void apply(MLIRContext *context, DialectsT *...dialects) const = 0; + + /// Return a copy of this extension. + std::unique_ptr clone() const final { + return std::make_unique(static_cast(*this)); + } + +protected: + DialectExtension() + : DialectExtensionBase( + ArrayRef({DialectsT::getDialectNamespace()...})) {} + + /// Override the base apply method to allow providing the exact dialect types. + void apply(MLIRContext *context, + MutableArrayRef dialects) const final { + unsigned dialectIdx = 0; + auto derivedDialects = std::tuple{ + static_cast(dialects[dialectIdx++])...}; + llvm::apply_tuple( + [&](DialectsT *...dialect) { apply(context, dialect...); }, + derivedDialects); + } +}; + +//===----------------------------------------------------------------------===// +// DialectRegistry +//===----------------------------------------------------------------------===// + +/// The DialectRegistry maps a dialect namespace to a constructor for the +/// matching dialect. This allows for decoupling the list of dialects +/// "available" from the dialects loaded in the Context. The parser in +/// particular will lazily load dialects in the Context as operations are +/// encountered. +class DialectRegistry { + using MapTy = + std::map>; + +public: + explicit DialectRegistry(); + + template + void insert() { + insert(TypeID::get(), + ConcreteDialect::getDialectNamespace(), + static_cast(([](MLIRContext *ctx) { + // Just allocate the dialect, the context + // takes ownership of it. + return ctx->getOrLoadDialect(); + }))); + } + + template + void insert() { + insert(); + insert(); + } + + /// Add a new dialect constructor to the registry. The constructor must be + /// calling MLIRContext::getOrLoadDialect in order for the context to take + /// ownership of the dialect and for delayed interface registration to happen. + void insert(TypeID typeID, StringRef name, + const DialectAllocatorFunction &ctor); + + /// Return an allocation function for constructing the dialect identified by + /// its namespace, or nullptr if the namespace is not in this registry. + DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const; + + // Register all dialects available in the current registry with the registry + // in the provided context. + void appendTo(DialectRegistry &destination) const { + for (const auto &nameAndRegistrationIt : registry) + destination.insert(nameAndRegistrationIt.second.first, + nameAndRegistrationIt.first, + nameAndRegistrationIt.second.second); + // Merge the extensions. + for (const auto &extension : extensions) + destination.extensions.push_back(extension->clone()); + } + + /// Return the names of dialects known to this registry. + auto getDialectNames() const { + return llvm::map_range( + registry, + [](const MapTy::value_type &item) -> StringRef { return item.first; }); + } + + /// Apply any held extensions that require the given dialect. Users are not + /// expected to call this directly. + void applyExtensions(Dialect *dialect) const; + + /// Apply any applicable extensions to the given context. Users are not + /// expected to call this directly. + void applyExtensions(MLIRContext *ctx) const; + + /// Add the given extension to the registry. + void addExtension(std::unique_ptr extension) { + extensions.push_back(std::move(extension)); + } + + /// Add the given extensions to the registry. + template + void addExtensions() { + (void)std::initializer_list{ + addExtension(std::make_unique())...}; + } + + /// Add an extension function that requires the given dialects. + /// Note: This bare functor overload is provided in addition to the + /// std::function variant to enable dialect type deduction, e.g.: + /// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) { ... }) + /// + /// is equivalent to: + /// registry.addExtension( + /// [](MLIRContext *ctx, MyDialect *dialect){ ... } + /// ) + template + void addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) { + addExtension( + std::function(extensionFn)); + } + template + void + addExtension(std::function extensionFn) { + using ExtensionFnT = std::function; + + struct Extension : public DialectExtension { + Extension(const Extension &) = default; + Extension(ExtensionFnT extensionFn) + : extensionFn(std::move(extensionFn)) {} + ~Extension() override = default; + + void apply(MLIRContext *context, DialectsT *...dialects) const final { + extensionFn(context, dialects...); + } + ExtensionFnT extensionFn; + }; + addExtension(std::make_unique(std::move(extensionFn))); + } + +private: + MapTy registry; + std::vector> extensions; +}; + +} // namespace mlir + +#endif // MLIR_IR_DIALECTREGISTRY_H diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp @@ -154,7 +154,9 @@ void mlir::arith::registerBufferizableOpInterfaceExternalModels( DialectRegistry ®istry) { - registry.addOpInterface(); - registry.addOpInterface(); - registry.addOpInterface(); + registry.addExtension(+[](MLIRContext *ctx, ArithmeticDialect *dialect) { + ConstantOp::attachInterface(*ctx); + IndexCastOp::attachInterface(*ctx); + SelectOp::attachInterface(*ctx); + }); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp @@ -695,7 +695,9 @@ void bufferization::registerAllocationOpInterfaceExternalModels( DialectRegistry ®istry) { - registry.addOpInterface(); + registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) { + memref::AllocOp::attachInterface(*ctx); + }); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -962,9 +962,11 @@ void mlir::linalg::comprehensive_bufferize::std_ext:: registerModuleBufferizationExternalModels(DialectRegistry ®istry) { - registry.addOpInterface(); - registry.addOpInterface(); - registry.addOpInterface(); + registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) { + func::CallOp::attachInterface(*ctx); + func::ReturnOp::attachInterface(*ctx); + func::FuncOp::attachInterface(*ctx); + }); } /// Set the attribute that triggers inplace bufferization on a FuncOp argument diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp @@ -246,22 +246,13 @@ /// Helper structure that iterates over all LinalgOps in `OpTys` and registers /// the `BufferizableOpInterface` with each of them. -template -struct LinalgOpInterfaceHelper; - -template -struct LinalgOpInterfaceHelper { - static void registerOpInterface(DialectRegistry ®istry) { - registry.addOpInterface>(); - LinalgOpInterfaceHelper::registerOpInterface(registry); +template +struct LinalgOpInterfaceHelper { + static void registerOpInterface(MLIRContext *ctx) { + (void)std::initializer_list{ + 0, (Ops::template attachInterface>(*ctx), 0)...}; } }; - -template <> -struct LinalgOpInterfaceHelper<> { - static void registerOpInterface(DialectRegistry ®istry) {} -}; - } // namespace /// Return true if all `neededValues` are in scope at the given @@ -501,13 +492,15 @@ void mlir::linalg::registerBufferizableOpInterfaceExternalModels( DialectRegistry ®istry) { - registry.addOpInterface(); + registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) { + linalg::InitTensorOp::attachInterface(*ctx); - // Register all Linalg structured ops. `LinalgOp` is an interface and it is - // not possible to attach an external interface to an existing interface. - // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one. - LinalgOpInterfaceHelper< + // Register all Linalg structured ops. `LinalgOp` is an interface and it is + // not possible to attach an external interface to an existing interface. + // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one. + LinalgOpInterfaceHelper< #define GET_OP_LIST #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" - >::registerOpInterface(registry); + >::registerOpInterface(ctx); + }); } diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -503,8 +503,10 @@ void mlir::scf::registerBufferizableOpInterfaceExternalModels( DialectRegistry ®istry) { - registry.addOpInterface(); - registry.addOpInterface(); - registry.addOpInterface(); - registry.addOpInterface(); + registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) { + ExecuteRegionOp::attachInterface(*ctx); + ForOp::attachInterface(*ctx); + IfOp::attachInterface(*ctx); + YieldOp::attachInterface(*ctx); + }); } diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp @@ -168,6 +168,8 @@ void mlir::shape::registerBufferizableOpInterfaceExternalModels( DialectRegistry ®istry) { - registry.addOpInterface(); - registry.addOpInterface(); + registry.addExtension(+[](MLIRContext *ctx, shape::ShapeDialect *dialect) { + shape::AssumingOp::attachInterface(*ctx); + shape::AssumingYieldOp::attachInterface(*ctx); + }); } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp @@ -205,11 +205,11 @@ void mlir::tensor::registerInferTypeOpInterfaceExternalModels( DialectRegistry ®istry) { - registry - .addOpInterface>(); - registry - .addOpInterface>(); - registry.addOpInterface(); + registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) { + ExpandShapeOp::attachInterface< + ReifyExpandOrCollapseShapeOp>(*ctx); + CollapseShapeOp::attachInterface< + ReifyExpandOrCollapseShapeOp>(*ctx); + PadOp::attachInterface(*ctx); + }); } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp @@ -283,5 +283,7 @@ void mlir::tensor::registerTilingOpInterfaceExternalModels( DialectRegistry ®istry) { - registry.addOpInterface(); + registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) { + tensor::PadOp::attachInterface(*ctx); + }); } diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -700,15 +700,17 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels( DialectRegistry ®istry) { - registry.addOpInterface(); - registry.addOpInterface(); - registry.addOpInterface(); - registry.addOpInterface(); - registry.addOpInterface(); - registry.addOpInterface(); - registry.addOpInterface(); - registry.addOpInterface(); - registry.addOpInterface(); - registry.addOpInterface(); - registry.addOpInterface(); + registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) { + CastOp::attachInterface(*ctx); + CollapseShapeOp::attachInterface(*ctx); + DimOp::attachInterface(*ctx); + ExpandShapeOp::attachInterface(*ctx); + ExtractSliceOp::attachInterface(*ctx); + ExtractOp::attachInterface(*ctx); + FromElementsOp::attachInterface(*ctx); + GenerateOp::attachInterface(*ctx); + InsertOp::attachInterface(*ctx); + InsertSliceOp::attachInterface(*ctx); + RankOp::attachInterface(*ctx); + }); } diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -121,6 +121,8 @@ void mlir::vector::registerBufferizableOpInterfaceExternalModels( DialectRegistry ®istry) { - registry.addOpInterface(); - registry.addOpInterface(); + registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) { + TransferReadOp::attachInterface(*ctx); + TransferWriteOp::attachInterface(*ctx); + }); } diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -24,97 +24,6 @@ using namespace mlir; using namespace detail; -//===----------------------------------------------------------------------===// -// DialectRegistry -//===----------------------------------------------------------------------===// - -DialectRegistry::DialectRegistry() { insert(); } - -void DialectRegistry::addDialectInterface( - StringRef dialectName, TypeID interfaceTypeID, - const DialectInterfaceAllocatorFunction &allocator) { - assert(allocator && "unexpected null interface allocation function"); - auto it = registry.find(dialectName.str()); - assert(it != registry.end() && - "adding an interface for an unregistered dialect"); - - // Bail out if the interface with the given ID is already in the registry for - // the given dialect. We expect a small number (dozens) of interfaces so a - // linear search is fine here. - auto &ifaces = interfaces[it->second.first]; - for (const auto &kvp : ifaces.dialectInterfaces) { - if (kvp.first == interfaceTypeID) { - LLVM_DEBUG(llvm::dbgs() - << "[" DEBUG_TYPE - "] repeated interface registration for dialect " - << dialectName); - return; - } - } - - ifaces.dialectInterfaces.emplace_back(interfaceTypeID, allocator); -} - -void DialectRegistry::addObjectInterface( - StringRef dialectName, TypeID objectID, TypeID interfaceTypeID, - const ObjectInterfaceAllocatorFunction &allocator) { - assert(allocator && "unexpected null interface allocation function"); - - auto it = registry.find(dialectName.str()); - assert(it != registry.end() && - "adding an interface for an op from an unregistered dialect"); - - auto dialectID = it->second.first; - auto &ifaces = interfaces[dialectID]; - - for (const auto &info : ifaces.objectInterfaces) { - if (std::get<0>(info) == objectID && std::get<1>(info) == interfaceTypeID) { - LLVM_DEBUG(llvm::dbgs() - << "[" DEBUG_TYPE - "] repeated interface object interface registration"); - return; - } - } - - ifaces.objectInterfaces.emplace_back(objectID, interfaceTypeID, allocator); -} - -DialectAllocatorFunctionRef -DialectRegistry::getDialectAllocator(StringRef name) const { - auto it = registry.find(name.str()); - if (it == registry.end()) - return nullptr; - return it->second.second; -} - -void DialectRegistry::insert(TypeID typeID, StringRef name, - const DialectAllocatorFunction &ctor) { - auto inserted = registry.insert( - std::make_pair(std::string(name), std::make_pair(typeID, ctor))); - if (!inserted.second && inserted.first->second.first != typeID) { - llvm::report_fatal_error( - "Trying to register different dialects for the same namespace: " + - name); - } -} - -void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const { - auto it = interfaces.find(dialect->getTypeID()); - if (it == interfaces.end()) - return; - - // Add an interface if it is not already present. - for (const auto &kvp : it->getSecond().dialectInterfaces) { - if (dialect->getRegisteredInterface(kvp.first)) - continue; - dialect->addInterface(kvp.second(dialect)); - } - - // Add attribute, operation and type interfaces. - for (const auto &info : it->getSecond().objectInterfaces) - std::get<2>(info)(dialect->getContext()); -} - //===----------------------------------------------------------------------===// // Dialect //===----------------------------------------------------------------------===// @@ -189,7 +98,13 @@ auto it = registeredInterfaces.try_emplace(interface->getID(), std::move(interface)); (void)it; - assert(it.second && "interface kind has already been registered"); + LLVM_DEBUG({ + if (!it.second) { + llvm::dbgs() << "[" DEBUG_TYPE + "] repeated interface registration for dialect " + << getNamespace(); + } + }); } //===----------------------------------------------------------------------===// @@ -216,3 +131,100 @@ DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const { return getInterfaceFor(op->getDialect()); } + +//===----------------------------------------------------------------------===// +// DialectExtension +//===----------------------------------------------------------------------===// + +DialectExtensionBase::~DialectExtensionBase() = default; + +//===----------------------------------------------------------------------===// +// DialectRegistry +//===----------------------------------------------------------------------===// + +DialectRegistry::DialectRegistry() { insert(); } + +DialectAllocatorFunctionRef +DialectRegistry::getDialectAllocator(StringRef name) const { + auto it = registry.find(name.str()); + if (it == registry.end()) + return nullptr; + return it->second.second; +} + +void DialectRegistry::insert(TypeID typeID, StringRef name, + const DialectAllocatorFunction &ctor) { + auto inserted = registry.insert( + std::make_pair(std::string(name), std::make_pair(typeID, ctor))); + if (!inserted.second && inserted.first->second.first != typeID) { + llvm::report_fatal_error( + "Trying to register different dialects for the same namespace: " + + name); + } +} + +void DialectRegistry::applyExtensions(Dialect *dialect) const { + MLIRContext *ctx = dialect->getContext(); + StringRef dialectName = dialect->getNamespace(); + + // Functor used to try to apply the given extension. + auto applyExtension = [&](const DialectExtensionBase &extension) { + ArrayRef dialectNames = extension.getRequiredDialects(); + + // Handle the simple case of a single dialect name. In this case, the + // required dialect should be the current dialect. + if (dialectNames.size() == 1) { + if (dialectNames.front() == dialectName) + extension.apply(ctx, dialect); + return; + } + + // Otherwise, check to see if this extension requires this dialect. + const StringRef *nameIt = llvm::find(dialectNames, dialectName); + if (nameIt == dialectNames.end()) + return; + + // If it does, ensure that all of the other required dialects have been + // loaded. + SmallVector requiredDialects; + requiredDialects.reserve(dialectNames.size()); + for (auto it = dialectNames.begin(), e = dialectNames.end(); it != e; + ++it) { + // The current dialect is known to be loaded. + if (it == nameIt) { + requiredDialects.push_back(dialect); + continue; + } + // Otherwise, check if it is loaded. + Dialect *loadedDialect = ctx->getLoadedDialect(*it); + if (!loadedDialect) + return; + requiredDialects.push_back(loadedDialect); + } + extension.apply(ctx, requiredDialects); + }; + + for (const auto &extension : extensions) + applyExtension(*extension); +} + +void DialectRegistry::applyExtensions(MLIRContext *ctx) const { + // Functor used to try to apply the given extension. + auto applyExtension = [&](const DialectExtensionBase &extension) { + ArrayRef dialectNames = extension.getRequiredDialects(); + + // Check to see if all of the dialects for this extension are loaded. + SmallVector requiredDialects; + requiredDialects.reserve(dialectNames.size()); + for (StringRef dialectName : dialectNames) { + Dialect *loadedDialect = ctx->getLoadedDialect(dialectName); + if (!loadedDialect) + return; + requiredDialects.push_back(loadedDialect); + } + extension.apply(ctx, requiredDialects); + }; + + for (const auto &extension : extensions) + applyExtension(*extension); +} diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -357,9 +357,8 @@ void MLIRContext::appendDialectRegistry(const DialectRegistry ®istry) { registry.appendTo(impl->dialectsRegistry); - // For the already loaded dialects, register the interfaces immediately. - for (const auto &kvp : impl->loadedDialects) - registry.registerDelayedInterfaces(kvp.second.get()); + // For the already loaded dialects, apply any possible extensions immediately. + registry.applyExtensions(this); } const DialectRegistry &MLIRContext::getDialectRegistry() { @@ -437,8 +436,8 @@ impl.dialectReferencingStrAttrs.erase(stringAttrsIt); } - // Actually register the interfaces with delayed registration. - impl.dialectsRegistry.registerDelayedInterfaces(dialect.get()); + // Apply any extensions to this newly loaded dialect. + impl.dialectsRegistry.applyExtensions(dialect.get()); return dialect.get(); } diff --git a/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp @@ -44,8 +44,9 @@ void mlir::registerAMXDialectTranslation(DialectRegistry ®istry) { registry.insert(); - registry.addDialectInterface(); + registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) { + dialect->addInterfaces(); + }); } void mlir::registerAMXDialectTranslation(MLIRContext &context) { diff --git a/mlir/lib/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.cpp @@ -45,8 +45,10 @@ void mlir::registerArmNeonDialectTranslation(DialectRegistry ®istry) { registry.insert(); - registry.addDialectInterface(); + registry.addExtension( + +[](MLIRContext *ctx, arm_neon::ArmNeonDialect *dialect) { + dialect->addInterfaces(); + }); } void mlir::registerArmNeonDialectTranslation(MLIRContext &context) { diff --git a/mlir/lib/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.cpp @@ -44,8 +44,9 @@ void mlir::registerArmSVEDialectTranslation(DialectRegistry ®istry) { registry.insert(); - registry.addDialectInterface(); + registry.addExtension(+[](MLIRContext *ctx, arm_sve::ArmSVEDialect *dialect) { + dialect->addInterfaces(); + }); } void mlir::registerArmSVEDialectTranslation(MLIRContext &context) { diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -503,8 +503,9 @@ void mlir::registerLLVMDialectTranslation(DialectRegistry ®istry) { registry.insert(); - registry.addDialectInterface(); + registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) { + dialect->addInterfaces(); + }); } void mlir::registerLLVMDialectTranslation(MLIRContext &context) { diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -141,8 +141,9 @@ void mlir::registerNVVMDialectTranslation(DialectRegistry ®istry) { registry.insert(); - registry.addDialectInterface(); + registry.addExtension(+[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) { + dialect->addInterfaces(); + }); } void mlir::registerNVVMDialectTranslation(MLIRContext &context) { diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp @@ -533,8 +533,9 @@ void mlir::registerOpenACCDialectTranslation(DialectRegistry ®istry) { registry.insert(); - registry.addDialectInterface(); + registry.addExtension(+[](MLIRContext *ctx, acc::OpenACCDialect *dialect) { + dialect->addInterfaces(); + }); } void mlir::registerOpenACCDialectTranslation(MLIRContext &context) { diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -1270,8 +1270,9 @@ void mlir::registerOpenMPDialectTranslation(DialectRegistry ®istry) { registry.insert(); - registry.addDialectInterface(); + registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) { + dialect->addInterfaces(); + }); } void mlir::registerOpenMPDialectTranslation(MLIRContext &context) { diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp @@ -107,8 +107,9 @@ void mlir::registerROCDLDialectTranslation(DialectRegistry ®istry) { registry.insert(); - registry.addDialectInterface(); + registry.addExtension(+[](MLIRContext *ctx, ROCDL::ROCDLDialect *dialect) { + dialect->addInterfaces(); + }); } void mlir::registerROCDLDialectTranslation(MLIRContext &context) { diff --git a/mlir/lib/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.cpp @@ -45,8 +45,10 @@ void mlir::registerX86VectorDialectTranslation(DialectRegistry ®istry) { registry.insert(); - registry.addDialectInterface(); + registry.addExtension( + +[](MLIRContext *ctx, x86vector::X86VectorDialect *dialect) { + dialect->addInterfaces(); + }); } void mlir::registerX86VectorDialectTranslation(MLIRContext &context) { diff --git a/mlir/unittests/IR/DialectTest.cpp b/mlir/unittests/IR/DialectTest.cpp --- a/mlir/unittests/IR/DialectTest.cpp +++ b/mlir/unittests/IR/DialectTest.cpp @@ -63,7 +63,9 @@ registry.insert(); // Delayed registration of an interface for TestDialect. - registry.addDialectInterface(); + registry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) { + dialect->addInterfaces(); + }); MLIRContext context(registry); @@ -85,8 +87,10 @@ // loaded dialect and check that the interface is now registered. DialectRegistry secondRegistry; secondRegistry.insert(); - secondRegistry - .addDialectInterface(); + secondRegistry.addExtension( + +[](MLIRContext *ctx, SecondTestDialect *dialect) { + dialect->addInterfaces(); + }); context.appendDialectRegistry(secondRegistry); secondTestDialectInterface = dyn_cast(secondTestDialect); @@ -97,7 +101,9 @@ // Set up the delayed registration. DialectRegistry registry; registry.insert(); - registry.addDialectInterface(); + registry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) { + dialect->addInterfaces(); + }); MLIRContext context(registry); // Load the TestDialect and check that the interface got registered for it. @@ -110,33 +116,12 @@ // on repeated interface registration. DialectRegistry secondRegistry; secondRegistry.insert(); - secondRegistry.addDialectInterface(); + secondRegistry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) { + dialect->addInterfaces(); + }); context.appendDialectRegistry(secondRegistry); testDialectInterface = dyn_cast(testDialect); EXPECT_TRUE(testDialectInterface != nullptr); } -// A dialect that registers two interfaces with the same InterfaceID, triggering -// an assertion failure. -struct RepeatedRegistrationDialect : public Dialect { - static StringRef getDialectNamespace() { return "repeatedreg"; } - RepeatedRegistrationDialect(MLIRContext *context) - : Dialect(getDialectNamespace(), context, - TypeID::get()) { - addInterfaces(); - addInterfaces(); - } -}; - -TEST(Dialect, RepeatedInterfaceRegistrationDeath) { - MLIRContext context; - (void)context; - - // This triggers an assertion in debug mode. -#ifndef NDEBUG - ASSERT_DEATH(context.loadDialect(), - "interface kind has already been registered"); -#endif -} - } // namespace diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp --- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp +++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp @@ -102,7 +102,9 @@ // Put the interface in the registry. DialectRegistry registry; registry.insert(); - registry.addTypeInterface(); + registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) { + test::TestType::attachInterface(*ctx); + }); // Check that when a context is constructed with the given registry, the type // interface gets registered. @@ -119,7 +121,9 @@ // Put the interface in the registry. DialectRegistry registry; registry.insert(); - registry.addTypeInterface(); + registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) { + test::TestType::attachInterface(*ctx); + }); // Check that when the registry gets appended to the context, the interface // becomes available for objects in loaded dialects. @@ -133,7 +137,9 @@ TEST(InterfaceAttachment, RepeatedRegistration) { DialectRegistry registry; - registry.addTypeInterface(); + registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) { + IntegerType::attachInterface(*ctx); + }); MLIRContext context(registry); // Should't fail on repeated registration through the dialect registry. @@ -144,7 +150,9 @@ // Builtin dialect needs to registration or loading, but delayed interface // registration must still work. DialectRegistry registry; - registry.addTypeInterface(); + registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) { + IntegerType::attachInterface(*ctx); + }); MLIRContext context(registry); IntegerType i16 = IntegerType::get(&context, 16); @@ -238,8 +246,9 @@ // that the delayed registration work for attributes. DialectRegistry registry; registry.insert(); - registry.addAttrInterface(); + registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) { + test::SimpleAAttr::attachInterface(*ctx); + }); MLIRContext context(registry); context.loadDialect(); @@ -343,12 +352,16 @@ TEST(InterfaceAttachment, OperationDelayedContextConstruct) { DialectRegistry registry; registry.insert(); - registry.addOpInterface(); - registry.addOpInterface>(); - registry.addOpInterface>(); - - // Construct the context directly from a registry. The interfaces are expected - // to be readily available on operations. + registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) { + ModuleOp::attachInterface(*ctx); + }); + registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) { + test::OpJ::attachInterface>(*ctx); + test::OpH::attachInterface>(*ctx); + }); + + // Construct the context directly from a registry. The interfaces are + // expected to be readily available on operations. MLIRContext context(registry); context.loadDialect(); @@ -370,9 +383,13 @@ TEST(InterfaceAttachment, OperationDelayedContextAppend) { DialectRegistry registry; registry.insert(); - registry.addOpInterface(); - registry.addOpInterface>(); - registry.addOpInterface>(); + registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) { + ModuleOp::attachInterface(*ctx); + }); + registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) { + test::OpJ::attachInterface>(*ctx); + test::OpH::attachInterface>(*ctx); + }); // Construct the context, create ops, and only then append the registry. The // interfaces are expected to be available after appending the registry.