diff --git a/mlir/docs/ExtensibleDialects.md b/mlir/docs/ExtensibleDialects.md new file mode 100644 --- /dev/null +++ b/mlir/docs/ExtensibleDialects.md @@ -0,0 +1,361 @@ +# Extensible dialects + +This file documents the design and API of the extensible dialects. Extensible +dialects are dialects that can be extended with new operations, types, and +attributes defined at runtime. This allows users to define dialects using +meta-programming or from another language, without recompiling C++ code. + +[TOC] + +## Usage + +### Defining an extensible dialect + +Dialects defined in C++ can be extended with new operations, types, and +attributes at runtime by inheriting from `mlir::ExtensibleDialect` instead of +`mlir::Dialect` (note that `ExtensibleDialect` inherits from `Dialect`). The +`ExtensibleDialect` class contains the necessary fields and methods to extend +the dialect at runtime. + +```c++ +class MyDialect : public mlir::ExtensibleDialect { + ... +}; +``` + +For dialects defined in TableGen, this is done by setting the `isExtensible` +flag to `1`. + +```tablegen +def Test_Dialect : Dialect { + let isExtensible = 1; + ... +} +``` + +An extensible `Dialect` can be cast back to `ExtensibleDialect` using +`llvm::dyn_cast` or `llvm::cast`: + +```c++ +if (auto extensibleDialect = llvm::dyn_cast(dialect)) { + ... +} +``` + +### Defining an operation at runtime + +The `DynamicOpDefinition` class represents the definition of an operation +defined at runtime. It is created using the `DynamicOpDefinition::get` +functions. An operation defined at runtime must provide a name, a dialect in +which the operation will be registered, and an operation verifier. It may also +optionally define a custom parser and printer, a fold hook, and more. 
+ +```c++ +// The operation name, without the dialect name prefix. +StringRef name = "my_operation_name"; + +// The dialect defining the operation. +ExtensibleDialect* dialect = ctx->getOrLoadDialect(); + +// Operation verifier definition. 
+OperationName::VerifyInvariantsFn verifyFn = [](Operation* op) { + // Logic for the operation verification. + ... +}; + +// Parser function definition. +OperationName::ParseAssemblyFn parseFn = + [](OpAsmParser &parser, OperationState &state) { + // Parse the operation, given that the name is already parsed. + ... +}; + +// Printer function definition. +auto printFn = [](Operation *op, OpAsmPrinter &printer) { + printer << op->getName(); + // Print the operation, given that the name is already printed. + ... +}; + +// General folder implementation, see OperationName::FoldHookFn for more +// information. +auto foldHookFn = [](Operation *op, ArrayRef operands, + SmallVectorImpl &result) { + ... +}; + +// Returns any canonicalization pattern rewrites that the operation +// supports, for use by the canonicalization pass. +auto getCanonicalizationPatterns = + [](RewritePatternSet &results, MLIRContext *context) { + ... +}; + +// Definition of the operation. +std::unique_ptr opDef = + DynamicOpDefinition::get(name, dialect, std::move(verifyFn), + std::move(parseFn), std::move(printFn), std::move(foldHookFn), + std::move(getCanonicalizationPatterns)); +``` + +Once the operation is defined, it can be registered by an `ExtensibleDialect`: + +```c++ +extensibleDialect->registerDynamicOp(std::move(opDef)); +``` + +Note that the `Dialect` given to the operation should be the one registering +the operation. + +### Using an operation defined at runtime + +It is possible to match on an operation defined at runtime using its name: + +```c++ +if (op->getName().getStringRef() == "my_dialect.my_dynamic_op") { + ... +} +``` + +An operation defined at runtime can be created by building an `OperationState` +with its name and passing it to a rewriter such as a `PatternRewriter`. 
+ +```c++ +OperationState state(location, "my_dialect.my_dynamic_op", + operands, resultTypes, attributes); + +rewriter.createOperation(state); +``` + + +### Defining a type at runtime + +Contrary to types defined in C++ or in TableGen, types defined at runtime can +only have a list of `Attribute` as parameters. + +Similarly to operations, a type is defined at runtime using the class +`DynamicTypeDefinition`, which is created using the `DynamicTypeDefinition::get` +functions. A type definition requires a name, the dialect that will register the +type, and a parameter verifier. It can also optionally define a custom parser +and printer for the parameters (the type name is assumed to be already +parsed/printed). + +```c++ +// The type name, without the dialect name prefix. +StringRef name = "my_type_name"; + +// The dialect defining the type. +ExtensibleDialect* dialect = ctx->getOrLoadDialect(); + +// The type verifier. +// A type defined at runtime has a list of attributes as parameters. +auto verifier = [](function_ref emitError, + ArrayRef args) { + ... +}; + +// The type parameters parser. +auto parser = [](DialectAsmParser &parser, + llvm::SmallVectorImpl &parsedParams) { + ... +}; + +// The type parameters printer. +auto printer = [](DialectAsmPrinter &printer, ArrayRef params) { + ... +}; + +std::unique_ptr typeDef = + DynamicTypeDefinition::get(name, dialect, + std::move(verifier), std::move(parser), std::move(printer)); +``` + +If the printer and the parser are omitted, a default parser and printer are +generated with the format `!dialect.typename`. + +The type can then be registered by the `ExtensibleDialect`: + +```c++ +dialect->registerDynamicType(std::move(typeDef)); +``` + +### Parsing types defined at runtime in an extensible dialect + +In order to parse types defined at runtime, the necessary support must be added +to the `MyDialect::parseType` method. 
+ +```c++ +Type MyDialect::parseType(DialectAsmParser &parser) const { + ... 
+ // The type name. + StringRef typeTag; + if (failed(parser.parseKeyword(&typeTag))) + return Type(); + + // Try to parse a dynamic type with 'typeTag' name. + Type dynType; + auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType); + if (parseResult.hasValue()) { + if (succeeded(parseResult.getValue())) + return dynType; + return Type(); + } +``` + +### Using a type defined at runtime + +Dynamic types are instances of `DynamicType`. It is possible to get a dynamic +type with `DynamicType::get` and `ExtensibleDialect::lookupTypeDefinition`. + +```c++ +auto typeDef = extensibleDialect->lookupTypeDefinition("my_dynamic_type"); +ArrayRef params = ...; +auto type = DynamicType::get(typeDef, params); +``` + +It is also possible to cast a `Type` known to be defined at runtime to a +`DynamicType`. + +```c++ +auto dynType = type.cast(); +auto typeDef = dynType.getTypeDef(); +auto args = dynType.getParams(); +``` + +### Defining an attribute at runtime + +Similar to types defined at runtime, attributes defined at runtime can only have +a list of `Attribute` as parameters. + +Similarly to types, an attribute is defined at runtime using the class +`DynamicAttrDefinition`, which is created using the `DynamicAttrDefinition::get` +functions. An attribute definition requires a name, the dialect that will +register the attribute, and a parameter verifier. It can also optionally define +a custom parser and printer for the parameters (the attribute name is assumed to +be already parsed/printed). + +```c++ +// The attribute name, without the dialect name prefix. +StringRef name = "my_attribute_name"; + +// The dialect defining the attribute. +ExtensibleDialect* dialect = ctx->getOrLoadDialect(); + +// The attribute verifier. +// An attribute defined at runtime has a list of attributes as parameters. +auto verifier = [](function_ref emitError, + ArrayRef args) { + ... +}; + +// The attribute parameters parser. 
+auto parser = [](DialectAsmParser &parser, + llvm::SmallVectorImpl &parsedParams) { + ... +}; + +// The attribute parameters printer. +auto printer = [](DialectAsmPrinter &printer, ArrayRef params) { + ... +}; + +std::unique_ptr attrDef = + DynamicAttrDefinition::get(name, dialect, + std::move(verifier), std::move(parser), std::move(printer)); +``` + +If the printer and the parser are omitted, a default parser and printer are +generated with the format `#dialect.attrname`. + +The attribute can then be registered by the `ExtensibleDialect`: + +```c++ +dialect->registerDynamicAttr(std::move(attrDef)); +``` + +### Parsing attributes defined at runtime in an extensible dialect + +In order to parse attributes defined at runtime, the necessary support must be +added to the `MyDialect::parseAttribute` method. + +```c++ +Attribute MyDialect::parseAttribute(DialectAsmParser &parser, + Type type) const { + ... + // The attribute name. + StringRef attrTag; + if (failed(parser.parseKeyword(&attrTag))) + return Attribute(); + + // Try to parse a dynamic attribute with 'attrTag' name. + Attribute dynAttr; + auto parseResult = parseOptionalDynamicAttr(attrTag, parser, dynAttr); + if (parseResult.hasValue()) { + if (succeeded(parseResult.getValue())) + return dynAttr; + return Attribute(); + } +``` + +### Using an attribute defined at runtime + +Similar to types, attributes defined at runtime are instances of `DynamicAttr`. +It is possible to get a dynamic attribute with `DynamicAttr::get` and +`ExtensibleDialect::lookupAttrDefinition`. + +```c++ +auto attrDef = extensibleDialect->lookupAttrDefinition("my_dynamic_attr"); +ArrayRef params = ...; +auto attr = DynamicAttr::get(attrDef, params); +``` + +It is also possible to cast an `Attribute` known to be defined at runtime to a +`DynamicAttr`. 
+ +```c++ +auto dynAttr = attr.cast(); +auto attrDef = dynAttr.getAttrDef(); +auto args = dynAttr.getParams(); +``` + +## Implementation details + +### Extensible dialect + +The role of extensible dialects is to own the necessary data for runtime-defined +operations, types, and attributes. They also contain the necessary accessors to +easily access them. + +In order to cast a `Dialect` back to an `ExtensibleDialect`, we implement the +`IsExtensibleDialect` interface in all `ExtensibleDialect`s. The casting is done +by checking if the `Dialect` implements `IsExtensibleDialect` or not. + +### Operation representation and registration + +Operations are represented in MLIR using the `AbstractOperation` class. They are +registered in dialects the same way operations defined in C++ are registered, +which is by calling `AbstractOperation::insert`. + +The only difference is that a new `TypeID` needs to be created for each +operation, since operations are not represented by a C++ class. This is done +using a `TypeIDAllocator`, which can allocate a new unique `TypeID` at runtime. + +### Type representation and registration + +Unlike operations, types need to define a C++ storage class that takes care of +type parameters. They also need to define another C++ class to access that +storage. `DynamicTypeStorage` defines the storage of types defined at runtime, +and `DynamicType` gives access to the storage, as well as defining useful +functions. A `DynamicTypeStorage` contains a list of `Attribute` type +parameters, as well as a pointer to the type definition. + +Types are registered using the `Dialect::addType` method, which expects a +`TypeID` that is generated using a `TypeIDAllocator`. The type uniquer also +registers the type with the given `TypeID`. This means that we can reuse our +single `DynamicType` with different `TypeID`s to represent the different types +defined at runtime. 
+ +Since the different types defined at runtime have different `TypeID`s, it is not +possible to use `TypeID` to cast a `Type` into a `DynamicType`. Thus, similar to +`Dialect`, all `DynamicType`s define an `IsDynamicTypeTrait`, so casting a `Type` +to a `DynamicType` boils down to querying the `IsDynamicTypeTrait` trait. diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h --- a/mlir/include/mlir/IR/AttributeSupport.h +++ b/mlir/include/mlir/IR/AttributeSupport.h @@ -45,6 +45,17 @@ T::getTypeID()); } + /// This method is used by Dialect objects to register attributes with + /// custom TypeIDs. + /// The use of this method is in general discouraged in favor of + /// 'get(dialect)'. + static AbstractAttribute get(Dialect &dialect, + detail::InterfaceMap &&interfaceMap, + HasTraitFn &&hasTrait, TypeID typeID) { + return AbstractAttribute(dialect, std::move(interfaceMap), + std::move(hasTrait), typeID); + } + /// Return the dialect this attribute was registered to. Dialect &getDialect() const { return const_cast(dialect); } @@ -175,14 +186,23 @@ // MLIRContext. This class manages all creation and uniquing of attributes. class AttributeUniquer { public: + /// Get an uniqued instance of an attribute T. + template + static T get(MLIRContext *ctx, Args &&...args) { + return getWithTypeID(ctx, T::getTypeID(), + std::forward(args)...); + } + /// Get an uniqued instance of a parametric attribute T. + /// The use of this method is in general discouraged in favor of + /// 'get(ctx, args)'.
template static typename std::enable_if_t< !std::is_same::value, T> - get(MLIRContext *ctx, Args &&...args) { + getWithTypeID(MLIRContext *ctx, TypeID typeID, Args &&...args) { #ifndef NDEBUG if (!ctx->getAttributeUniquer().isParametricStorageInitialized( - T::getTypeID())) + typeID)) llvm::report_fatal_error( llvm::Twine("can't create Attribute '") + llvm::getTypeName() + "' because storage uniquer isn't initialized: the dialect was likely " @@ -190,30 +210,32 @@ "in the Dialect::initialize() method."); #endif return ctx->getAttributeUniquer().get( - [ctx](AttributeStorage *storage) { - initializeAttributeStorage(storage, ctx, T::getTypeID()); + [typeID, ctx](AttributeStorage *storage) { + initializeAttributeStorage(storage, ctx, typeID); // Execute any additional attribute storage initialization with the // context. static_cast(storage)->initialize(ctx); }, - T::getTypeID(), std::forward(args)...); + typeID, std::forward(args)...); } /// Get an uniqued instance of a singleton attribute T. + /// The use of this method is in general discouraged in favor of + /// 'get(ctx)'. template static typename std::enable_if_t< std::is_same::value, T> - get(MLIRContext *ctx) { + getWithTypeID(MLIRContext *ctx, TypeID typeID) { #ifndef NDEBUG if (!ctx->getAttributeUniquer().isSingletonStorageInitialized( - T::getTypeID())) + typeID)) llvm::report_fatal_error( llvm::Twine("can't create Attribute '") + llvm::getTypeName() + "' because storage uniquer isn't initialized: the dialect was likely " "not loaded, or the attribute wasn't added with addAttributes<...>() " "in the Dialect::initialize() method."); #endif - return ctx->getAttributeUniquer().get(T::getTypeID()); + return ctx->getAttributeUniquer().get(typeID); } template @@ -224,23 +246,33 @@ std::forward(args)...); } + /// Register an attribute instance T with the uniquer, using T::getTypeID().
+ template + static void registerAttribute(MLIRContext *ctx) { + registerAttribute(ctx, T::getTypeID()); + } + /// Register a parametric attribute instance T with the uniquer. + /// The use of this method is in general discouraged in favor of + /// 'registerAttribute(ctx)'. template static typename std::enable_if_t< !std::is_same::value> - registerAttribute(MLIRContext *ctx) { + registerAttribute(MLIRContext *ctx, TypeID typeID) { ctx->getAttributeUniquer() - .registerParametricStorageType(T::getTypeID()); + .registerParametricStorageType(typeID); } /// Register a singleton attribute instance T with the uniquer. + /// The use of this method is in general discouraged in favor of + /// 'registerAttribute(ctx)'. template static typename std::enable_if_t< std::is_same::value> - registerAttribute(MLIRContext *ctx) { + registerAttribute(MLIRContext *ctx, TypeID typeID) { ctx->getAttributeUniquer() .registerSingletonStorageType( - T::getTypeID(), [ctx](AttributeStorage *storage) { - initializeAttributeStorage(storage, ctx, T::getTypeID()); + typeID, [ctx, typeID](AttributeStorage *storage) { + initializeAttributeStorage(storage, ctx, typeID); }); } diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -212,6 +212,11 @@ (void)std::initializer_list{0, (addAttribute(), 0)...}; } + /// Register an attribute instance with this dialect under an explicit TypeID. + /// The use of this method is in general discouraged in favor of + /// 'addAttributes()'. + void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo); + /// Enable support for unregistered operations. void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; } @@ -237,7 +242,6 @@ addAttribute(T::getTypeID(), AbstractAttribute::get(*this)); detail::AttributeUniquer::registerAttribute(context); } - void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo); /// Register a type instance with this dialect.
template void addType() { diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/ExtensibleDialect.h @@ -0,0 +1,500 @@ +//===- ExtensibleDialect.h - Extensible dialect -----------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Dialects that can register new operations/types/attributes at runtime. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_EXTENSIBLEDIALECT_H +#define MLIR_IR_EXTENSIBLEDIALECT_H + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectInterface.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/TypeID.h" +#include "llvm/ADT/StringMap.h" + +namespace mlir { +class MLIRContext; +class AsmPrinter; +class AsmParser; +class ParseResult; +class OptionalParseResult; +class ExtensibleDialect; + +namespace detail { +struct DynamicTypeStorage; +struct DynamicAttrStorage; +} // namespace detail + +//===----------------------------------------------------------------------===// +// Dynamic type +//===----------------------------------------------------------------------===// + +class DynamicType; + +/// This is the definition of a dynamic type. It stores the parser, +/// the printer, and the verifier. +/// Each dynamic type definition refers to one instance of this class. +class DynamicTypeDefinition : NonMovableTypeIDOwner { +public: + using VerifierFn = llvm::unique_function, ArrayRef) const>; + using ParserFn = llvm::unique_function &parsedAttributes) + const>; + using PrinterFn = llvm::unique_function params) const>; + + /// Create a new type definition at runtime.
The type is registered only after + /// passing it to the dialect using registerDynamicType. + static std::unique_ptr + get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier); + static std::unique_ptr + get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier, + ParserFn &&parser, PrinterFn &&printer); + + /// Check that the type parameters are valid. + LogicalResult verify(function_ref emitError, + ArrayRef params) const { + return verifier(emitError, params); + } + + /// Get the MLIRContext in which the dynamic types are uniqued. + MLIRContext &getContext() const { return *ctx; } + + /// Get the name of the type, in the format 'typename' and + /// not 'dialectname.typename'. + StringRef getName() const { return name; } + + /// Get the dialect defining the type. + ExtensibleDialect *getDialect() const { return dialect; } + +private: + DynamicTypeDefinition(StringRef name, ExtensibleDialect *dialect, + VerifierFn &&verifier, ParserFn &&parser, + PrinterFn &&printer); + + /// This constructor should only be used when we need a pointer to + /// the DynamicTypeDefinition in the verifier, the parser, or the printer. + /// The verifier, parser, and printer thus need to be initialized after the + /// constructor. + DynamicTypeDefinition(ExtensibleDialect *dialect, StringRef name); + + /// Register the concrete type in the type Uniquer. + void registerInTypeUniquer(); + + /// The type name, without the dialect name prefix (as returned by getName). + std::string name; + + /// Dialect in which this type is defined. + ExtensibleDialect *dialect; + + /// Verifier for the type parameters. + VerifierFn verifier; + + /// Parse the type parameters. + ParserFn parser; + + /// Print the type parameters. + PrinterFn printer; + + /// Context in which the concrete types are uniqued. + MLIRContext *ctx; + + friend ExtensibleDialect; + friend DynamicType; +}; + +/// This trait is implemented by all dynamic types, and should not be + /// implemented by any other type.
+/// The trait is only used to check if a type is a dynamic type or not. +/// This is required because dynamic types do not have a single TypeID. +template +class IsDynamicTypeTrait + : TypeTrait::TraitBase {}; + +/// A type defined at runtime. +/// Each DynamicType instance represents a different dynamic type. +class DynamicType + : public Type::TypeBase { +public: + // Inherit Base constructors. + using Base::Base; + + /// Get an instance of a dynamic type given a dynamic type definition and + /// type parameters. + /// This function does not call the type verifier. + static DynamicType get(DynamicTypeDefinition *typeDef, + ArrayRef params = {}); + + /// Get an instance of a dynamic type given a dynamic type definition and type + /// parameters. + /// This function also calls the verifier to check if the parameters are valid. + static DynamicType getChecked(function_ref emitError, + DynamicTypeDefinition *typeDef, + ArrayRef params = {}); + + /// Get the type definition of the concrete type. + DynamicTypeDefinition *getTypeDef(); + + /// Get the type parameters. + ArrayRef getParams(); + + /// Check if a type is a specific dynamic type. + static bool isa(Type type, DynamicTypeDefinition *typeDef) { + return type.getTypeID() == typeDef->getTypeID(); + } + + /// Check if a type is a dynamic type. + static bool classof(Type type); + + /// Parse the dynamic type parameters and construct the type. + /// The parameters are either empty, and nothing is parsed, + /// or they are in the format '<>' or ''. + static ParseResult parse(AsmParser &parser, DynamicTypeDefinition *typeDef, + DynamicType &parsedType); + + /// Print the dynamic type with the format + /// 'type' or 'type<>' if there are no parameters, or 'type'.
+ void print(AsmPrinter &printer); +}; + +//===----------------------------------------------------------------------===// +// Dynamic attribute +//===----------------------------------------------------------------------===// + +class DynamicAttr; + +/// This is the definition of a dynamic attribute. It stores the parser, +/// the printer, and the verifier. +/// Each dynamic attribute definition refers to one instance of this class. +class DynamicAttrDefinition : NonMovableTypeIDOwner { +public: + using VerifierFn = llvm::unique_function, ArrayRef) const>; + using ParserFn = llvm::unique_function &parsedAttributes) + const>; + using PrinterFn = llvm::unique_function params) const>; + + /// Create a new attribute definition at runtime. The attribute is registered + /// only after passing it to the dialect using registerDynamicAttr. + static std::unique_ptr + get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier); + static std::unique_ptr + get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier, + ParserFn &&parser, PrinterFn &&printer); + + /// Check that the attribute parameters are valid. + LogicalResult verify(function_ref emitError, + ArrayRef params) const { + return verifier(emitError, params); + } + + /// Get the MLIRContext in which the dynamic attributes are uniqued. + MLIRContext &getContext() const { return *ctx; } + + /// Get the name of the attribute, in the format 'attrname' and + /// not 'dialectname.attrname'. + StringRef getName() const { return name; } + + /// Get the dialect defining the attribute. + ExtensibleDialect *getDialect() const { return dialect; } + +private: + DynamicAttrDefinition(StringRef name, ExtensibleDialect *dialect, + VerifierFn &&verifier, ParserFn &&parser, + PrinterFn &&printer); + + /// This constructor should only be used when we need a pointer to + /// the DynamicAttrDefinition in the verifier, the parser, or the printer.
+ /// The verifier, parser, and printer thus need to be initialized after the + /// constructor. + DynamicAttrDefinition(ExtensibleDialect *dialect, StringRef name); + + /// Register the concrete attribute in the attribute Uniquer. + void registerInAttrUniquer(); + + /// The attribute name, without the dialect name prefix (see getName). + std::string name; + + /// Dialect in which this attribute is defined. + ExtensibleDialect *dialect; + + /// Verifier for the attribute parameters. + VerifierFn verifier; + + /// Parse the attribute parameters. + ParserFn parser; + + /// Print the attribute parameters. + PrinterFn printer; + + /// Context in which the concrete attributes are uniqued. + MLIRContext *ctx; + + friend ExtensibleDialect; + friend DynamicAttr; +}; + +/// This trait is implemented by all dynamic attributes, and should not be +/// implemented by any other attribute. +/// The trait is only used to check if an attribute is a dynamic attribute or +/// not. This is required because dynamic attributes lack a single TypeID. +template +class IsDynamicAttrTrait + : AttributeTrait::TraitBase {}; + +/// An attribute defined at runtime. +/// Each DynamicAttr instance represents a different dynamic attribute. +class DynamicAttr : public Attribute::AttrBase { +public: + // Inherit Base constructors. + using Base::Base; + + /// Get an instance of a dynamic attribute given a dynamic attribute + /// definition and attribute parameters. + /// This function does not call the attribute verifier. + static DynamicAttr get(DynamicAttrDefinition *attrDef, + ArrayRef params = {}); + + /// Get an instance of a dynamic attribute given a dynamic attribute + /// definition and attribute parameters. + /// This function also calls the verifier to check if the parameters are valid. + static DynamicAttr getChecked(function_ref emitError, + DynamicAttrDefinition *attrDef, + ArrayRef params = {}); + + /// Get the attribute definition of the concrete attribute.
+ DynamicAttrDefinition *getAttrDef(); + + /// Get the attribute parameters. + ArrayRef getParams(); + + /// Check if an attribute is a specific dynamic attribute. + static bool isa(Attribute attr, DynamicAttrDefinition *attrDef) { + return attr.getTypeID() == attrDef->getTypeID(); + } + + /// Check if an attribute is a dynamic attribute. + static bool classof(Attribute attr); + + /// Parse the dynamic attribute parameters and construct the attribute. + /// The parameters are either empty, and nothing is parsed, + /// or they are in the format '<>' or ''. + static ParseResult parse(AsmParser &parser, DynamicAttrDefinition *attrDef, + DynamicAttr &parsedAttr); + + /// Print the dynamic attribute with the format 'attrname' if there are no + /// parameters, or 'attrname'. + void print(AsmPrinter &printer); +}; + +//===----------------------------------------------------------------------===// +// Dynamic operation +//===----------------------------------------------------------------------===// + +/// The definition of a dynamic operation. +/// It contains the name of the operation, its owning dialect, a verifier, +/// a printer, and a parser. +class DynamicOpDefinition { +public: + /// Create a new op at runtime. The op is registered only after passing it to + /// the dialect using registerDynamicOp.
+ static std::unique_ptr + get(StringRef name, ExtensibleDialect *dialect, + OperationName::VerifyInvariantsFn &&verifyFn); + static std::unique_ptr + get(StringRef name, ExtensibleDialect *dialect, + OperationName::VerifyInvariantsFn &&verifyFn, + OperationName::ParseAssemblyFn &&parseFn, + OperationName::PrintAssemblyFn &&printFn); + static std::unique_ptr + get(StringRef name, ExtensibleDialect *dialect, + OperationName::VerifyInvariantsFn &&verifyFn, + OperationName::ParseAssemblyFn &&parseFn, + OperationName::PrintAssemblyFn &&printFn, + OperationName::FoldHookFn &&foldHookFn, + OperationName::GetCanonicalizationPatternsFn + &&getCanonicalizationPatternsFn); + + void setVerifyFn(OperationName::VerifyInvariantsFn &&verify) { + verifyFn = std::move(verify); + } + + void setParseFn(OperationName::ParseAssemblyFn &&parse) { + parseFn = std::move(parse); + } + + void setPrintFn(OperationName::PrintAssemblyFn &&print) { + printFn = std::move(print); + } + + void setFoldHookFn(OperationName::FoldHookFn &&foldHook) { + foldHookFn = std::move(foldHook); + } + + void + setGetCanonicalizationPatternsFn(OperationName::GetCanonicalizationPatternsFn + &&getCanonicalizationPatterns) { + getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns); + } + +private: + DynamicOpDefinition(StringRef name, ExtensibleDialect *dialect, + OperationName::VerifyInvariantsFn &&verifyFn, + OperationName::ParseAssemblyFn &&parseFn, + OperationName::PrintAssemblyFn &&printFn, + OperationName::FoldHookFn &&foldHookFn, + OperationName::GetCanonicalizationPatternsFn + &&getCanonicalizationPatternsFn); + + /// Unique identifier for this operation, allocated by the owning dialect. + TypeID typeID; + + /// Name of the operation. + /// The name is prefixed with the dialect name. + std::string name; + + /// Dialect defining this operation.
+ ExtensibleDialect *dialect; + + OperationName::VerifyInvariantsFn verifyFn; + OperationName::ParseAssemblyFn parseFn; + OperationName::PrintAssemblyFn printFn; + OperationName::FoldHookFn foldHookFn; + OperationName::GetCanonicalizationPatternsFn getCanonicalizationPatternsFn; + + friend ExtensibleDialect; +}; + +//===----------------------------------------------------------------------===// +// Extensible dialect +//===----------------------------------------------------------------------===// + +/// A dialect that can be extended with new operations/types/attributes at +/// runtime. +class ExtensibleDialect : public mlir::Dialect { +public: + ExtensibleDialect(StringRef name, MLIRContext *ctx, TypeID typeID); + + /// Add a new type defined at runtime to the dialect. + void registerDynamicType(std::unique_ptr &&type); + + /// Add a new attribute defined at runtime to the dialect. + void registerDynamicAttr(std::unique_ptr &&attr); + + /// Add a new operation defined at runtime to the dialect. + void registerDynamicOp(std::unique_ptr &&op); + + /// Check if the dialect is an extensible dialect. + static bool classof(const mlir::Dialect *dialect); + + /// Looks up a dynamic type definition by name; nullptr if not found. + DynamicTypeDefinition *lookupTypeDefinition(StringRef name) const { + auto it = nameToDynTypes.find(name); + if (it == nameToDynTypes.end()) + return nullptr; + return it->second; + } + + /// Looks up a dynamic type definition by TypeID; nullptr if not found. + DynamicTypeDefinition *lookupTypeDefinition(TypeID id) const { + auto it = dynTypes.find(id); + if (it == dynTypes.end()) + return nullptr; + return it->second.get(); + } + + /// Looks up a dynamic attribute definition by name; nullptr if not found. + DynamicAttrDefinition *lookupAttrDefinition(StringRef name) const { + auto it = nameToDynAttrs.find(name); + if (it == nameToDynAttrs.end()) + return nullptr; + return it->second; + } + + /// Looks up a dynamic attribute definition by TypeID; nullptr if not found.
+ DynamicAttrDefinition *lookupAttrDefinition(TypeID id) const { + auto it = dynAttrs.find(id); + if (it == dynAttrs.end()) + return nullptr; + return it->second.get(); + } + +protected: + /// Parse the dynamic type 'typeName' in this dialect. + /// 'typeName' should not be prefixed with the dialect name. + /// If the dynamic type does not exist, return no value. + /// Otherwise, parse it, and return the parse result. + /// If the parsing succeeds, put the resulting type in 'resultType'. + OptionalParseResult parseOptionalDynamicType(StringRef typeName, + AsmParser &parser, + Type &resultType) const; + + /// If 'type' is a dynamic type, print it. + /// Returns success if the type was printed, and failure if the type was not a + /// dynamic type. + static LogicalResult printIfDynamicType(Type type, AsmPrinter &printer); + + /// Parse the dynamic attribute 'attrName' in this dialect. + /// 'attrName' should not be prefixed with the dialect name. + /// If the dynamic attribute does not exist, return no value. + /// Otherwise, parse it, and return the parse result. + /// If the parsing succeeds, put the resulting attribute in 'resultAttr'. + OptionalParseResult parseOptionalDynamicAttr(StringRef attrName, + AsmParser &parser, + Attribute &resultAttr) const; + + /// If 'attr' is a dynamic attribute, print it. + /// Returns success if the attribute was printed, and failure if the + /// attribute was not a dynamic attribute. + static LogicalResult printIfDynamicAttr(Attribute attr, AsmPrinter &printer); + +private: + /// The set of all dynamic types registered. + llvm::DenseMap> dynTypes; + + /// Maps a dynamic type name to its definition, for O(1) lookup by name. + llvm::StringMap nameToDynTypes; + + /// The set of all dynamic attributes registered. + llvm::DenseMap> dynAttrs; + + /// Maps a dynamic attribute name to its definition, for O(1) lookup by name. + llvm::StringMap nameToDynAttrs; + + /// Give DynamicOpDefinition access to allocateTypeID.
+ friend DynamicOpDefinition; + + /// Allocates a new TypeID to uniquely identify a dynamic operation. + TypeID allocateTypeID() { return typeIDAllocator.allocate(); } + + /// Owns the TypeIDs generated at runtime for operations. + TypeIDAllocator typeIDAllocator; +}; +} // namespace mlir + +namespace llvm { +/// Provide isa functionality for ExtensibleDialect. +/// This is to override the isa functionality for Dialect. +template <> +struct isa_impl { + static inline bool doit(const ::mlir::Dialect &dialect) { + return mlir::ExtensibleDialect::classof(&dialect); + } +}; +} // namespace llvm + +#endif // MLIR_IR_EXTENSIBLEDIALECT_H diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -324,6 +324,9 @@ // UpperCamel) and prefixed with `get` or `set` depending on if it is a getter // or setter. int emitAccessorPrefix = kEmitAccessorPrefix_Raw; + + // If this dialect can be extended at runtime with new operations/types/attrs. + bit isExtensible = 0; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h --- a/mlir/include/mlir/IR/TypeSupport.h +++ b/mlir/include/mlir/IR/TypeSupport.h @@ -163,13 +163,22 @@ /// A utility class to get, or create, unique instances of types within an /// MLIRContext. This class manages all creation and uniquing of types. struct TypeUniquer { + /// Get an uniqued instance of a type T. + template + static T get(MLIRContext *ctx, Args &&...args) { + return getWithTypeID(ctx, T::getTypeID(), + std::forward(args)...); + } + /// Get an uniqued instance of a parametric type T. + /// The use of this method is in general discouraged in favor of + /// 'get(ctx, args)'.
template static typename std::enable_if_t< !std::is_same::value, T> - get(MLIRContext *ctx, Args &&...args) { + getWithTypeID(MLIRContext *ctx, TypeID typeID, Args &&...args) { #ifndef NDEBUG - if (!ctx->getTypeUniquer().isParametricStorageInitialized(T::getTypeID())) + if (!ctx->getTypeUniquer().isParametricStorageInitialized(typeID)) llvm::report_fatal_error( llvm::Twine("can't create type '") + llvm::getTypeName() + "' because storage uniquer isn't initialized: the dialect was likely " @@ -177,25 +186,27 @@ "in the Dialect::initialize() method."); #endif return ctx->getTypeUniquer().get( - [&](TypeStorage *storage) { - storage->initialize(AbstractType::lookup(T::getTypeID(), ctx)); + [&, typeID](TypeStorage *storage) { + storage->initialize(AbstractType::lookup(typeID, ctx)); }, - T::getTypeID(), std::forward(args)...); + typeID, std::forward(args)...); } /// Get an uniqued instance of a singleton type T. + /// The use of this method is in general discouraged in favor of + /// 'get(ctx)', which uses T::getTypeID() as the TypeID. template static typename std::enable_if_t< std::is_same::value, T> - get(MLIRContext *ctx) { + getWithTypeID(MLIRContext *ctx, TypeID typeID) { #ifndef NDEBUG - if (!ctx->getTypeUniquer().isSingletonStorageInitialized(T::getTypeID())) + if (!ctx->getTypeUniquer().isSingletonStorageInitialized(typeID)) llvm::report_fatal_error( llvm::Twine("can't create type '") + llvm::getTypeName() + "' because storage uniquer isn't initialized: the dialect was likely " "not loaded, or the type wasn't added with addTypes<...>() " "in the Dialect::initialize() method."); #endif - return ctx->getTypeUniquer().get(T::getTypeID()); + return ctx->getTypeUniquer().get(typeID); } /// Change the mutable component of the given type instance in the provided @@ -208,22 +219,32 @@ std::forward(args)...); } + /// Register a type instance T with the uniquer under T::getTypeID().
+ template + static void registerType(MLIRContext *ctx) { + registerType(ctx, T::getTypeID()); + } + /// Register a parametric type instance T with the uniquer. + /// The use of this method is in general discouraged in favor of + /// 'registerType(ctx)', which uses T::getTypeID() as the TypeID. template static typename std::enable_if_t< !std::is_same::value> - registerType(MLIRContext *ctx) { + registerType(MLIRContext *ctx, TypeID typeID) { ctx->getTypeUniquer().registerParametricStorageType( - T::getTypeID()); + typeID); } /// Register a singleton type instance T with the uniquer. + /// The use of this method is in general discouraged in favor of + /// 'registerType(ctx)', which uses T::getTypeID() as the TypeID. template static typename std::enable_if_t< std::is_same::value> - registerType(MLIRContext *ctx) { + registerType(MLIRContext *ctx, TypeID typeID) { ctx->getTypeUniquer().registerSingletonStorageType( - T::getTypeID(), [&](TypeStorage *storage) { - storage->initialize(AbstractType::lookup(T::getTypeID(), ctx)); + typeID, [&ctx, typeID](TypeStorage *storage) { + storage->initialize(AbstractType::lookup(typeID, ctx)); }); } }; diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h --- a/mlir/include/mlir/TableGen/Dialect.h +++ b/mlir/include/mlir/TableGen/Dialect.h @@ -82,6 +82,10 @@ /// type printing/parsing. bool useDefaultTypePrinterParser() const; + /// Returns true if this dialect can be extended at runtime with new + /// operations, types, or attributes. + bool isExtensible() const; + // Returns whether two dialects are equal by checking the equality of the // underlying record.
bool operator==(const Dialect &other) const; diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt --- a/mlir/lib/IR/CMakeLists.txt +++ b/mlir/lib/IR/CMakeLists.txt @@ -12,6 +12,7 @@ Diagnostics.cpp Dialect.cpp Dominance.cpp + ExtensibleDialect.cpp FunctionImplementation.cpp FunctionSupport.cpp IntegerSet.cpp diff --git a/mlir/lib/IR/ExtensibleDialect.cpp b/mlir/lib/IR/ExtensibleDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/IR/ExtensibleDialect.cpp @@ -0,0 +1,507 @@ +//===- ExtensibleDialect.cpp - Extensible dialect ---------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Dialects that can register new operations/types/attributes at runtime. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/ExtensibleDialect.h" +#include "mlir/IR/AttributeSupport.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Identifier.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Support/LogicalResult.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Dynamic types and attributes shared functions +//===----------------------------------------------------------------------===// + +/// Default parser for dynamic attribute or type parameters. +/// Accept nothing, '<>', or '<' attr (',' attr)* '>'.
+static LogicalResult +typeOrAttrParser(AsmParser &parser, SmallVectorImpl &parsedParams) { + // No parameters: either no '<' at all, or an empty '<>'. + if (parser.parseOptionalLess() || !parser.parseOptionalGreater()) + return success(); + + Attribute attr; + if (parser.parseAttribute(attr)) + return failure(); + parsedParams.push_back(attr); + + while (parser.parseOptionalGreater()) { + Attribute nextAttr; + if (parser.parseComma() || parser.parseAttribute(nextAttr)) + return failure(); + parsedParams.push_back(nextAttr); + } + + return success(); +} + +/// Default printer for dynamic attribute or type parameters. +/// Print nothing for empty parameters, otherwise '<' attr (',' attr)* '>'. +static void typeOrAttrPrinter(AsmPrinter &printer, ArrayRef params) { + if (params.empty()) + return; + + printer << "<"; + interleaveComma(params, printer.getStream()); + printer << ">"; +} + +//===----------------------------------------------------------------------===// +// Dynamic type +//===----------------------------------------------------------------------===// + +std::unique_ptr +DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect, + VerifierFn &&verifier) { + return DynamicTypeDefinition::get(name, dialect, std::move(verifier), + typeOrAttrParser, typeOrAttrPrinter); +} + +std::unique_ptr +DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect, + VerifierFn &&verifier, ParserFn &&parser, + PrinterFn &&printer) { + return std::unique_ptr( + new DynamicTypeDefinition(name, dialect, std::move(verifier), + std::move(parser), std::move(printer))); +} + +DynamicTypeDefinition::DynamicTypeDefinition(StringRef nameRef, + ExtensibleDialect *dialect, + VerifierFn &&verifier, + ParserFn &&parser, + PrinterFn &&printer) + : name(nameRef), dialect(dialect), verifier(std::move(verifier)), + parser(std::move(parser)), printer(std::move(printer)), + ctx(dialect->getContext()) { + assert(!nameRef.contains('.') && + "name should not be prefixed by the dialect name"); +} + +DynamicTypeDefinition::DynamicTypeDefinition(ExtensibleDialect
*dialect, + StringRef nameRef) + : name(nameRef), dialect(dialect), ctx(dialect->getContext()) { + assert(!nameRef.contains('.') && + "name should not be prefixed by the dialect name"); +} + +void DynamicTypeDefinition::registerInTypeUniquer() { + detail::TypeUniquer::registerType(&getContext(), getTypeID()); +} + +namespace mlir { +namespace detail { +/// Storage of DynamicType. +/// Contains a pointer to the type definition and type parameters. +struct DynamicTypeStorage : public TypeStorage { + + using KeyTy = std::pair>; + + explicit DynamicTypeStorage(DynamicTypeDefinition *typeDef, + ArrayRef params) + : typeDef(typeDef), params(params) {} + + bool operator==(const KeyTy &key) const { + return typeDef == key.first && params == key.second; + } + + static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_value(key); + } + + static DynamicTypeStorage *construct(TypeStorageAllocator &alloc, + const KeyTy &key) { + return new (alloc.allocate()) + DynamicTypeStorage(key.first, alloc.copyInto(key.second)); + } + + /// Definition of the type. + DynamicTypeDefinition *typeDef; + + /// The type parameters.
+ ArrayRef params; +}; +} // namespace detail +} // namespace mlir + +DynamicType DynamicType::get(DynamicTypeDefinition *typeDef, + ArrayRef params) { + auto &ctx = typeDef->getContext(); + return detail::TypeUniquer::getWithTypeID( + &ctx, typeDef->getTypeID(), typeDef, params); +} + +DynamicType +DynamicType::getChecked(function_ref emitError, + DynamicTypeDefinition *typeDef, + ArrayRef params) { + if (failed(typeDef->verify(emitError, params))) + return {}; + return get(typeDef, params); +} + +DynamicTypeDefinition *DynamicType::getTypeDef() { return getImpl()->typeDef; } + +ArrayRef DynamicType::getParams() { return getImpl()->params; } + +bool DynamicType::classof(Type type) { + return type.hasTrait(); +} + +ParseResult DynamicType::parse(AsmParser &parser, + DynamicTypeDefinition *typeDef, + DynamicType &parsedType) { + SmallVector params; + if (failed(typeDef->parser(parser, params))) + return failure(); + parsedType = parser.getChecked(typeDef, params); + if (!parsedType) + return failure(); + return success(); +} + +void DynamicType::print(AsmPrinter &printer) { + printer << getTypeDef()->getName(); + getTypeDef()->printer(printer, getParams()); +} + +//===----------------------------------------------------------------------===// +// Dynamic attribute +//===----------------------------------------------------------------------===// + +std::unique_ptr +DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect, + VerifierFn &&verifier) { + return DynamicAttrDefinition::get(name, dialect, std::move(verifier), + typeOrAttrParser, typeOrAttrPrinter); +} + +std::unique_ptr +DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect, + VerifierFn &&verifier, ParserFn &&parser, + PrinterFn &&printer) { + return std::unique_ptr( + new DynamicAttrDefinition(name, dialect, std::move(verifier), + std::move(parser), std::move(printer))); +} + +DynamicAttrDefinition::DynamicAttrDefinition(StringRef nameRef, + ExtensibleDialect *dialect, +
VerifierFn &&verifier, + ParserFn &&parser, + PrinterFn &&printer) + : name(nameRef), dialect(dialect), verifier(std::move(verifier)), + parser(std::move(parser)), printer(std::move(printer)), + ctx(dialect->getContext()) { + assert(!nameRef.contains('.') && + "name should not be prefixed by the dialect name"); +} + +DynamicAttrDefinition::DynamicAttrDefinition(ExtensibleDialect *dialect, + StringRef nameRef) + : name(nameRef), dialect(dialect), ctx(dialect->getContext()) { + assert(!nameRef.contains('.') && + "name should not be prefixed by the dialect name"); +} + +void DynamicAttrDefinition::registerInAttrUniquer() { + detail::AttributeUniquer::registerAttribute(&getContext(), + getTypeID()); +} + +namespace mlir { +namespace detail { +/// Storage of DynamicAttr. +/// Contains a pointer to the attribute definition and attribute parameters. +struct DynamicAttrStorage : public AttributeStorage { + + using KeyTy = std::pair>; + + explicit DynamicAttrStorage(DynamicAttrDefinition *attrDef, + ArrayRef params) + : attrDef(attrDef), params(params) {} + + bool operator==(const KeyTy &key) const { + return attrDef == key.first && params == key.second; + } + + static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_value(key); + } + + static DynamicAttrStorage *construct(AttributeStorageAllocator &alloc, + const KeyTy &key) { + return new (alloc.allocate()) + DynamicAttrStorage(key.first, alloc.copyInto(key.second)); + } + + /// Definition of the attribute. + DynamicAttrDefinition *attrDef; + + /// The attribute parameters.
+ ArrayRef params; +}; +} // namespace detail +} // namespace mlir + +DynamicAttr DynamicAttr::get(DynamicAttrDefinition *attrDef, + ArrayRef params) { + auto &ctx = attrDef->getContext(); + return detail::AttributeUniquer::getWithTypeID( + &ctx, attrDef->getTypeID(), attrDef, params); +} + +DynamicAttr +DynamicAttr::getChecked(function_ref emitError, + DynamicAttrDefinition *attrDef, + ArrayRef params) { + if (failed(attrDef->verify(emitError, params))) + return {}; + return get(attrDef, params); +} + +DynamicAttrDefinition *DynamicAttr::getAttrDef() { return getImpl()->attrDef; } + +ArrayRef DynamicAttr::getParams() { return getImpl()->params; } + +bool DynamicAttr::classof(Attribute attr) { + return attr.hasTrait(); +} + +ParseResult DynamicAttr::parse(AsmParser &parser, + DynamicAttrDefinition *attrDef, + DynamicAttr &parsedAttr) { + SmallVector params; + if (failed(attrDef->parser(parser, params))) + return failure(); + parsedAttr = parser.getChecked(attrDef, params); + if (!parsedAttr) + return failure(); + return success(); +} + +void DynamicAttr::print(AsmPrinter &printer) { + printer << getAttrDef()->getName(); + getAttrDef()->printer(printer, getParams()); +} + +//===----------------------------------------------------------------------===// +// Dynamic operation +//===----------------------------------------------------------------------===// + +DynamicOpDefinition::DynamicOpDefinition( + StringRef name, ExtensibleDialect *dialect, + OperationName::VerifyInvariantsFn &&verifyFn, + OperationName::ParseAssemblyFn &&parseFn, + OperationName::PrintAssemblyFn &&printFn, + OperationName::FoldHookFn &&foldHookFn, + OperationName::GetCanonicalizationPatternsFn + &&getCanonicalizationPatternsFn) + : typeID(dialect->allocateTypeID()), + name((dialect->getNamespace() + "." 
+ name).str()), dialect(dialect), + verifyFn(std::move(verifyFn)), parseFn(std::move(parseFn)), + printFn(std::move(printFn)), foldHookFn(std::move(foldHookFn)), + getCanonicalizationPatternsFn(std::move(getCanonicalizationPatternsFn)) { + assert(!name.contains('.') && + "name should not be prefixed by the dialect name"); +} + +std::unique_ptr +DynamicOpDefinition::get(StringRef name, ExtensibleDialect *dialect, + OperationName::VerifyInvariantsFn &&verifyFn) { + auto parseFn = [](OpAsmParser &parser, OperationState &result) { + return parser.emitError( + parser.getCurrentLocation(), + "dynamic operation do not define any parser function"); + }; + + auto printFn = [](Operation *op, OpAsmPrinter &printer, StringRef) { + printer.printGenericOp(op); + }; + + return DynamicOpDefinition::get(name, dialect, std::move(verifyFn), + std::move(parseFn), std::move(printFn)); +} + +std::unique_ptr +DynamicOpDefinition::get(StringRef name, ExtensibleDialect *dialect, + OperationName::VerifyInvariantsFn &&verifyFn, + OperationName::ParseAssemblyFn &&parseFn, + OperationName::PrintAssemblyFn &&printFn) { + auto foldHookFn = [](Operation *op, ArrayRef operands, + SmallVectorImpl &results) { + return failure(); + }; + + auto getCanonicalizationPatternsFn = [](OwningRewritePatternList &, + MLIRContext *) {}; + + return DynamicOpDefinition::get(name, dialect, std::move(verifyFn), + std::move(parseFn), std::move(printFn), + std::move(foldHookFn), + std::move(getCanonicalizationPatternsFn)); +} + +std::unique_ptr +DynamicOpDefinition::get(StringRef name, ExtensibleDialect *dialect, + OperationName::VerifyInvariantsFn &&verifyFn, + OperationName::ParseAssemblyFn &&parseFn, + OperationName::PrintAssemblyFn &&printFn, + OperationName::FoldHookFn &&foldHookFn, + OperationName::GetCanonicalizationPatternsFn + &&getCanonicalizationPatternsFn) { + return std::unique_ptr(new DynamicOpDefinition( + name, dialect, std::move(verifyFn), std::move(parseFn), + std::move(printFn),
+ std::move(foldHookFn), + std::move(getCanonicalizationPatternsFn))); +} + +//===----------------------------------------------------------------------===// +// Extensible dialect +//===----------------------------------------------------------------------===// + +namespace { +/// Interface that can only be implemented by extensible dialects. +/// The interface is used to check if a dialect is extensible or not. +class IsExtensibleDialect : public DialectInterface::Base { +public: + explicit IsExtensibleDialect(Dialect *dialect) : Base(dialect) {} +}; +} // namespace + +ExtensibleDialect::ExtensibleDialect(StringRef name, MLIRContext *ctx, + TypeID typeID) + : Dialect(Identifier::get(name, ctx), ctx, typeID) { + addInterfaces(); +} + +void ExtensibleDialect::registerDynamicType( + std::unique_ptr &&type) { + auto *typePtr = type.get(); + auto typeID = type->getTypeID(); + auto name = type->getName(); + auto *dialect = type->getDialect(); + + assert(dialect == this && + "trying to register a dynamic type in the wrong dialect"); + + // Assert that neither the TypeID nor the name is already registered. + auto registered = dynTypes.try_emplace(typeID, std::move(type)).second; + (void)registered; + assert(registered && "type TypeID was not unique"); + + registered = nameToDynTypes.insert({name, typePtr}).second; + (void)registered; + assert(registered && + "Trying to create a new dynamic type with an existing name"); + + auto abstractType = + AbstractType::get(*dialect, DynamicType::getInterfaceMap(), + DynamicType::getHasTraitFn(), typeID); + + // Add the type to the dialect and the type uniquer.
+ addType(typeID, std::move(abstractType)); + typePtr->registerInTypeUniquer(); +} + +void ExtensibleDialect::registerDynamicAttr( + std::unique_ptr &&attr) { + auto *attrPtr = attr.get(); + auto typeID = attr->getTypeID(); + auto name = attr->getName(); + auto *dialect = attr->getDialect(); + + assert(dialect == this && + "trying to register a dynamic attribute in the wrong dialect"); + + // Assert that neither the TypeID nor the name is already registered. + auto registered = dynAttrs.try_emplace(typeID, std::move(attr)).second; + (void)registered; + assert(registered && "attribute TypeID was not unique"); + + registered = nameToDynAttrs.insert({name, attrPtr}).second; + (void)registered; + assert(registered && + "Trying to create a new dynamic attribute with an existing name"); + + auto abstractAttr = + AbstractAttribute::get(*dialect, DynamicAttr::getInterfaceMap(), + DynamicAttr::getHasTraitFn(), typeID); + + // Add the attribute to the dialect and the attribute uniquer. + addAttribute(typeID, std::move(abstractAttr)); + attrPtr->registerInAttrUniquer(); +} + +void ExtensibleDialect::registerDynamicOp( + std::unique_ptr &&op) { + assert(op->dialect == this && + "trying to register a dynamic op in the wrong dialect"); + auto hasTraitFn = [](TypeID traitId) { return false; }; + + RegisteredOperationName::insert( + op->name, *op->dialect, op->typeID, std::move(op->parseFn), + std::move(op->printFn), std::move(op->verifyFn), + std::move(op->foldHookFn), std::move(op->getCanonicalizationPatternsFn), + detail::InterfaceMap::get<>(), std::move(hasTraitFn), {}); +} + +bool ExtensibleDialect::classof(const Dialect *dialect) { + return const_cast(dialect) + ->getRegisteredInterface(); +} + +OptionalParseResult ExtensibleDialect::parseOptionalDynamicType( + StringRef typeName, AsmParser &parser, Type &resultType) const { + DynamicTypeDefinition *typeDef = lookupTypeDefinition(typeName); + if (!typeDef) + return llvm::None; + + DynamicType dynType; + if (DynamicType::parse(parser,
typeDef, dynType)) + return failure(); + resultType = dynType; + return success(); +} + +LogicalResult ExtensibleDialect::printIfDynamicType(Type type, + AsmPrinter &printer) { + if (auto dynType = type.dyn_cast()) { + dynType.print(printer); + return success(); + } + return failure(); +} + +OptionalParseResult ExtensibleDialect::parseOptionalDynamicAttr( + StringRef attrName, AsmParser &parser, Attribute &resultAttr) const { + DynamicAttrDefinition *attrDef = lookupAttrDefinition(attrName); + if (!attrDef) + return llvm::None; + + DynamicAttr dynAttr; + if (DynamicAttr::parse(parser, attrDef, dynAttr)) + return failure(); + resultAttr = dynAttr; + return success(); +} + +LogicalResult ExtensibleDialect::printIfDynamicAttr(Attribute attribute, + AsmPrinter &printer) { + if (auto dynAttr = attribute.dyn_cast()) { + dynAttr.print(printer); + return success(); + } + return failure(); +} diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp --- a/mlir/lib/TableGen/Dialect.cpp +++ b/mlir/lib/TableGen/Dialect.cpp @@ -102,9 +102,13 @@ int prefix = def->getValueAsInt("emitAccessorPrefix"); if (prefix < 0 || prefix > static_cast(EmitPrefix::Both)) PrintFatalError(def->getLoc(), "Invalid accessor prefix value"); return static_cast(prefix); } +bool Dialect::isExtensible() const { + return def->getValueAsBit("isExtensible"); +} + bool Dialect::operator==(const Dialect &other) const { return def == other.def; } diff --git a/mlir/test/IR/dynamic.mlir b/mlir/test/IR/dynamic.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/dynamic.mlir @@ -0,0 +1,126 @@ +// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics | FileCheck %s +// Verify that extensible dialects can register dynamic operations, types, and attributes.
+ +//===----------------------------------------------------------------------===// +// Dynamic type +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @succeededDynamicTypeVerifier +func @succeededDynamicTypeVerifier() { + // CHECK: %{{.*}} = "unregistered_op"() : () -> !test.singleton_dyntype + "unregistered_op"() : () -> !test.singleton_dyntype + // CHECK-NEXT: %{{.*}} = "unregistered_op"() : () -> !test.pair_dyntype + "unregistered_op"() : () -> !test.pair_dyntype + // CHECK-NEXT: %{{.*}} = "unregistered_op"() : () -> !test.pair_dyntype, !test.singleton_dyntype> + "unregistered_op"() : () -> !test.pair_dyntype, !test.singleton_dyntype> + return +} + +// ----- + +func @failedDynamicTypeVerifier() { + // expected-error@+1 {{expected 0 type arguments, but had 1}} + "unregistered_op"() : () -> !test.singleton_dyntype + return +} + +// ----- + +func @failedDynamicTypeVerifier2() { + // expected-error@+1 {{expected 2 type arguments, but had 1}} + "unregistered_op"() : () -> !test.pair_dyntype + return +} + +// ----- + +// CHECK-LABEL: func @customTypeParserPrinter +func @customTypeParserPrinter() { + // CHECK: "unregistered_op"() : () -> !test.custom_assembly_format_dyntype + "unregistered_op"() : () -> !test.custom_assembly_format_dyntype + return +} + +// ----- + +//===----------------------------------------------------------------------===// +// Dynamic attribute +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @succeededDynamicAttributeVerifier +func @succeededDynamicAttributeVerifier() { + // CHECK: "unregistered_op"() {test_attr = #test.singleton_dynattr} : () -> () + "unregistered_op"() {test_attr = #test.singleton_dynattr} : () -> () + // CHECK-NEXT: "unregistered_op"() {test_attr = #test.pair_dynattr<3 : i32, 5 : i32>} : () -> () + "unregistered_op"() {test_attr = #test.pair_dynattr<3 : i32, 5 : i32>} : () -> () + // CHECK-NEXT: "unregistered_op"()
{test_attr = #test.pair_dynattr<#test.pair_dynattr<3 : i32, 5 : i32>, f64>} : () -> () + "unregistered_op"() {test_attr = #test.pair_dynattr<#test.pair_dynattr<3 : i32, 5 : i32>, f64>} : () -> () + return +} + +// ----- + +func @failedDynamicAttributeVerifier() { + // expected-error@+1 {{expected 0 attribute arguments, but had 1}} + "unregistered_op"() {test_attr = #test.singleton_dynattr} : () -> () + return +} + +// ----- + +func @failedDynamicAttributeVerifier2() { + // expected-error@+1 {{expected 2 attribute arguments, but had 1}} + "unregistered_op"() {test_attr = #test.pair_dynattr} : () -> () + return +} + +// ----- + +// CHECK-LABEL: func @customAttributeParserPrinter +func @customAttributeParserPrinter() { + // CHECK: "unregistered_op"() {test_attr = #test.custom_assembly_format_dynattr} : () -> () + "unregistered_op"() {test_attr = #test.custom_assembly_format_dynattr} : () -> () + return +} + +// ----- + +//===----------------------------------------------------------------------===// +// Dynamic op +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @succeededDynamicOpVerifier +func @succeededDynamicOpVerifier(%a: f32) { + // CHECK: "test.generic_dynamic_op"() : () -> () + // CHECK-NEXT: %{{.*}} = "test.generic_dynamic_op"(%{{.*}}) : (f32) -> f64 + // CHECK-NEXT: %{{.*}}:2 = "test.one_operand_two_results"(%{{.*}}) : (f32) -> (f64, f64) + "test.generic_dynamic_op"() : () -> () + "test.generic_dynamic_op"(%a) : (f32) -> f64 + "test.one_operand_two_results"(%a) : (f32) -> (f64, f64) + return +} + +// ----- + +func @failedDynamicOpVerifier() { + // expected-error@+1 {{expected 1 operand, but had 0}} + "test.one_operand_two_results"() : () -> (f64, f64) + return +} + +// ----- + +func @failedDynamicOpVerifier2(%a: f32) { + // expected-error@+1 {{expected 2 results, but had 0}} + "test.one_operand_two_results"(%a) : (f32) -> () + return +} + +// ----- + +// CHECK-LABEL: func @customOpParserPrinter +func @customOpParserPrinter() {
+ // CHECK: test.custom_parser_printer_dynamic_op custom_keyword + test.custom_parser_printer_dynamic_op custom_keyword + return +} diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -15,12 +15,14 @@ #include "TestDialect.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/Types.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/bit.h" +#include "llvm/Support/ErrorHandling.h" using namespace mlir; using namespace test; @@ -216,6 +218,74 @@ #define GET_ATTRDEF_CLASSES #include "TestAttrDefs.cpp.inc" +//===----------------------------------------------------------------------===// +// Dynamic Attributes +//===----------------------------------------------------------------------===// + +/// Define a singleton dynamic attribute. +static std::unique_ptr +getSingletonDynamicAttr(TestDialect *testDialect) { + return DynamicAttrDefinition::get( + "singleton_dynattr", testDialect, + [](function_ref emitError, + ArrayRef args) { + if (!args.empty()) { + emitError() << "expected 0 attribute arguments, but had " + << args.size(); + return failure(); + } + return success(); + }); +} + +/// Define a dynamic attribute representing a pair of attributes.
+static std::unique_ptr +getPairDynamicAttr(TestDialect *testDialect) { + return DynamicAttrDefinition::get( + "pair_dynattr", testDialect, + [](function_ref emitError, + ArrayRef args) { + if (args.size() != 2) { + emitError() << "expected 2 attribute arguments, but had " + << args.size(); + return failure(); + } + return success(); + }); +} + +static std::unique_ptr +getCustomAssemblyFormatDynamicAttr(TestDialect *testDialect) { + auto verifier = [](function_ref emitError, + ArrayRef args) { + if (args.size() != 2) { + emitError() << "expected 2 attribute arguments, but had " << args.size(); + return failure(); + } + return success(); + }; + + auto parser = [](AsmParser &parser, + llvm::SmallVectorImpl &parsedParams) { + Attribute leftAttr, rightAttr; + if (parser.parseLess() || parser.parseAttribute(leftAttr) || + parser.parseColon() || parser.parseAttribute(rightAttr) || + parser.parseGreater()) + return failure(); + parsedParams.push_back(leftAttr); + parsedParams.push_back(rightAttr); + return success(); + }; + + auto printer = [](AsmPrinter &printer, ArrayRef params) { + printer << "<" << params[0] << ":" << params[1] << ">"; + }; + + return DynamicAttrDefinition::get("custom_assembly_format_dynattr", + testDialect, std::move(verifier), + std::move(parser), std::move(printer)); +} + //===----------------------------------------------------------------------===// // TestDialect //===----------------------------------------------------------------------===// @@ -225,4 +295,44 @@ #define GET_ATTRDEF_LIST #include "TestAttrDefs.cpp.inc" >(); + registerDynamicAttr(getSingletonDynamicAttr(this)); + registerDynamicAttr(getPairDynamicAttr(this)); + registerDynamicAttr(getCustomAssemblyFormatDynamicAttr(this)); +} + +Attribute TestDialect::parseAttribute(DialectAsmParser &parser, + Type type) const { + StringRef attrTag; + if (failed(parser.parseKeyword(&attrTag))) + return Attribute(); + { + Attribute attr; + auto parseResult = generatedAttributeParser(parser,
attrTag, type, attr); + if (parseResult.hasValue()) + return attr; + } + + { + Attribute dynAttr; + auto parseResult = parseOptionalDynamicAttr(attrTag, parser, dynAttr); + if (parseResult.hasValue()) { + if (succeeded(parseResult.getValue())) + return dynAttr; + return Attribute(); + } + } + + parser.emitError(parser.getNameLoc(), "unknown test attribute"); + return Attribute(); +} + +void TestDialect::printAttribute(Attribute attr, + DialectAsmPrinter &printer) const { + if (succeeded(generatedAttributePrinter(attr, printer))) + return; + + if (succeeded(printIfDynamicAttr(attr, printer))) + return; + + llvm_unreachable("unknown test attribute"); } diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -23,6 +23,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/RegionKindInterface.h" diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Reducer/ReductionPatternInterface.h" @@ -226,6 +227,58 @@ } // end anonymous namespace +//===----------------------------------------------------------------------===// +// Dynamic operations +//===----------------------------------------------------------------------===// + +/// Define a dynamic operation whose verifier accepts any operands and results. +static std::unique_ptr +getGenericDynamicOp(TestDialect *dialect) { + return DynamicOpDefinition::get("generic_dynamic_op", dialect, + [](Operation *op) {
return success(); }); +} + +/// Define a dynamic operation requiring one operand and two results. +static std::unique_ptr +getOneOperandTwoResultsDynamicOp(TestDialect *dialect) { + return DynamicOpDefinition::get( + "one_operand_two_results", dialect, [](Operation *op) { + if (op->getNumOperands() != 1) { + op->emitOpError() + << "expected 1 operand, but had " << op->getNumOperands(); + return failure(); + } + if (op->getNumResults() != 2) { + op->emitOpError() + << "expected 2 results, but had " << op->getNumResults(); + return failure(); + } + return success(); + }); +} + +/// Define a dynamic operation with a custom parser and printer. +static std::unique_ptr +getCustomParserPrinterDynamicOp(TestDialect *dialect) { + auto verifier = [](Operation *op) { + if (op->getNumOperands() == 0 && op->getNumResults() == 0) + return success(); + op->emitError() << "operation should have no operands and no results"; + return failure(); + }; + + auto parser = [](OpAsmParser &parser, OperationState &state) { + return parser.parseKeyword("custom_keyword"); + }; + + auto printer = [](Operation *op, OpAsmPrinter &printer, llvm::StringRef) { + printer << op->getName() << " custom_keyword"; + }; + + return DynamicOpDefinition::get("custom_parser_printer_dynamic_op", dialect, + verifier, parser, printer); +} + //===----------------------------------------------------------------------===// // TestDialect //===----------------------------------------------------------------------===// @@ -260,6 +313,10 @@ #define GET_OP_LIST #include "TestOps.cpp.inc" >(); + registerDynamicOp(getGenericDynamicOp(this)); + registerDynamicOp(getOneOperandTwoResultsDynamicOp(this)); + registerDynamicOp(getCustomParserPrinterDynamicOp(this)); + addInterfaces(); allowUnknownOperations(); diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -34,7 +34,7 @@ let hasRegionResultAttrVerify = 1; let hasOperationInterfaceFallback = 1; let hasNonDefaultDestructor = 1; - let useDefaultAttributePrinterParser = 1; + let isExtensible = 1;
let dependentDialects = ["::mlir::DLTIDialect"]; let extraClassDeclaration = [{ @@ -52,6 +52,10 @@ // Storage for a custom fallback interface. void *fallbackEffectOpInterfaces; + ::mlir::Type parseTestType(::mlir::AsmParser &parser, + ::llvm::SetVector<::mlir::Type> &stack) const; + void printTestType(::mlir::Type type, ::mlir::AsmPrinter &printer, + ::llvm::SetVector<::mlir::Type> &stack) const; }]; } diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/Types.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/SetVector.h" @@ -215,6 +216,72 @@ #define GET_TYPEDEF_CLASSES #include "TestTypeDefs.cpp.inc" +//===----------------------------------------------------------------------===// +// Dynamic Types +//===----------------------------------------------------------------------===// + +/// Define a singleton dynamic type. +static std::unique_ptr +getSingletonDynamicType(TestDialect *testDialect) { + return DynamicTypeDefinition::get( + "singleton_dyntype", testDialect, + [](function_ref emitError, + ArrayRef args) { + if (!args.empty()) { + emitError() << "expected 0 type arguments, but had " << args.size(); + return failure(); + } + return success(); + }); +} + +/// Define a dynamic type representing a pair.
+static std::unique_ptr<DynamicTypeDefinition>
+getPairDynamicType(TestDialect *testDialect) {
+  return DynamicTypeDefinition::get(
+      "pair_dyntype", testDialect,
+      [](function_ref<InFlightDiagnostic()> emitError,
+         ArrayRef<Attribute> args) {
+        if (args.size() != 2) {
+          emitError() << "expected 2 type arguments, but had " << args.size();
+          return failure();
+        }
+        return success();
+      });
+}
+
+static std::unique_ptr<DynamicTypeDefinition>
+getCustomAssemblyFormatDynamicType(TestDialect *testDialect) {
+  auto verifyFn = [](function_ref<InFlightDiagnostic()> emitError,
+                     ArrayRef<Attribute> args) {
+    if (args.size() != 2) {
+      emitError() << "expected 2 type arguments, but had " << args.size();
+      return failure();
+    }
+    return success();
+  };
+
+  auto parseFn = [](AsmParser &parser,
+                    llvm::SmallVectorImpl<Attribute> &parsedParams) {
+    Attribute leftAttr, rightAttr;
+    if (parser.parseLess() || parser.parseAttribute(leftAttr) ||
+        parser.parseColon() || parser.parseAttribute(rightAttr) ||
+        parser.parseGreater())
+      return failure();
+    parsedParams.push_back(leftAttr);
+    parsedParams.push_back(rightAttr);
+    return success();
+  };
+
+  auto printFn = [](AsmPrinter &printer, ArrayRef<Attribute> params) {
+    printer << "<" << params[0] << ":" << params[1] << ">";
+  };
+
+  return DynamicTypeDefinition::get("custom_assembly_format_dyntype",
+                                    testDialect, std::move(verifyFn),
+                                    std::move(parseFn), std::move(printFn));
+}
+
 //===----------------------------------------------------------------------===//
 // TestDialect
 //===----------------------------------------------------------------------===//
@@ -232,9 +299,14 @@
 #include "TestTypeDefs.cpp.inc"
            >();
   SimpleAType::attachInterface<PtrElementModel>(*getContext());
+
+  registerDynamicType(getSingletonDynamicType(this));
+  registerDynamicType(getPairDynamicType(this));
+  registerDynamicType(getCustomAssemblyFormatDynamicType(this));
 }
 
-static Type parseTestType(AsmParser &parser, SetVector<Type> &stack) {
+Type TestDialect::parseTestType(AsmParser &parser,
+                                SetVector<Type> &stack) const {
   StringRef typeTag;
   if (failed(parser.parseKeyword(&typeTag)))
     return Type();
@@ -246,6 +318,17 @@
       return genType;
   }
 
+  {
+    // Try the dynamic types; no value means `typeTag` names none of them.
+    Type dynType;
+    auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType);
+    if (parseResult.hasValue()) {
+      if (succeeded(parseResult.getValue()))
+        return dynType;
+      return Type();
+    }
+  }
+
   if (typeTag != "test_rec") {
     parser.emitError(parser.getNameLoc()) << "unknown type!";
     return Type();
@@ -281,11 +364,14 @@
   return parseTestType(parser, stack);
 }
 
-static void printTestType(Type type, AsmPrinter &printer,
-                          SetVector<Type> &stack) {
+void TestDialect::printTestType(Type type, AsmPrinter &printer,
+                                SetVector<Type> &stack) const {
   if (succeeded(generatedTypePrinter(type, printer)))
     return;
 
+  if (succeeded(printIfDynamicType(type, printer)))
+    return;
+
   auto rec = type.cast<TestRecursiveType>();
   printer << "test_rec<" << rec.getName();
   if (!stack.contains(rec)) {
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -89,9 +89,10 @@
 /// {2}: initialization code that is emitted in the ctor body before calling
 ///      initialize()
+/// {3}: the dialect base class, either `Dialect` or `ExtensibleDialect`.
 static const char *const dialectDeclBeginStr = R"(
-class {0} : public ::mlir::Dialect {
+class {0} : public ::mlir::{3} {
   explicit {0}(::mlir::MLIRContext *context)
-    : ::mlir::Dialect(getDialectNamespace(), context,
+    : ::mlir::{3}(getDialectNamespace(), context,
       ::mlir::TypeID::get<{0}>()) {{
     {2}
     initialize();
@@ -203,8 +204,10 @@
 
   // Emit the start of the decl.
   std::string cppName = dialect.getCppClassName();
+  StringRef superClassName =
+      dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
   os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(),
-                      dependentDialectRegistrations);
+                      dependentDialectRegistrations, superClassName);
 
   // Check for any attributes/types registered to this dialect. If there are,
   // add the hooks for parsing/printing.