diff --git a/mlir/docs/ExtensibleDialects.md b/mlir/docs/ExtensibleDialects.md new file mode 100644 --- /dev/null +++ b/mlir/docs/ExtensibleDialects.md @@ -0,0 +1,369 @@ +# Extensible dialects + +This file documents the design and API of the extensible dialects. Extensible +dialects are dialects that can be extended with new operations, types, and +attributes defined at runtime. This allows users to define dialects via +meta-programming, or from another language, without having to recompile C++ code. + +[TOC] + +## Usage + +### Defining an extensible dialect + +Dialects defined in C++ can be extended with new operations, types, etc., at +runtime by inheriting from `mlir::ExtensibleDialect` instead of `mlir::Dialect` +(note that `ExtensibleDialect` inherits from `Dialect`). The `ExtensibleDialect` +class contains the necessary fields and methods to extend the dialect at +runtime. + +```c++ +class MyDialect : public mlir::ExtensibleDialect { + ... +}; +``` + +For dialects defined in TableGen, this is done by setting the `isExtensible` +flag to `1`. + +```tablegen +def Test_Dialect : Dialect { + let isExtensible = 1; + ... +} +``` + +An extensible `Dialect` can be cast back to `ExtensibleDialect` using +`llvm::dyn_cast`, or `llvm::cast`: + +```c++ +if (auto extensibleDialect = llvm::dyn_cast<ExtensibleDialect>(dialect)) { + ... +} +``` + +### Defining an operation at runtime + +The `DynamicOpDefinition` class represents the definition of an operation +defined at runtime. It is created using the `DynamicOpDefinition::get` +functions. An operation defined at runtime must provide a name, a dialect in +which the operation will be registered, an operation verifier, and a region +verifier. It may also optionally define a custom parser and printer, a fold hook, and more. + +```c++ +// The operation name, without the dialect name prefix. +StringRef name = "my_operation_name"; + +// The dialect defining the operation. +ExtensibleDialect* dialect = ctx->getOrLoadDialect<MyDialect>(); + +// Operation verifier definition.
+OperationName::VerifyInvariantsFn verifyFn = [](Operation* op) { + // Logic for the operation verification. + ... +}; +OperationName::VerifyRegionInvariantsFn verifyRegionFn = [](Operation* op) { /* Logic for the region verification. */ ... }; +// Parser function definition. +OperationName::ParseAssemblyFn parseFn = + [](OpAsmParser &parser, OperationState &state) { + // Parse the operation, given that the name is already parsed. + ... +}; + +// Printer function definition. +auto printFn = [](Operation *op, OpAsmPrinter &printer) { + printer << op->getName(); + // Print the operation, given that the name is already printed. + ... +}; + +// General folder implementation, see RegisteredOperationName::foldHook for more +// information. +auto foldHookFn = [](Operation * op, ArrayRef<Attribute> operands, + SmallVectorImpl<OpFoldResult> &result) { + ... +}; + +// Returns any canonicalization pattern rewrites that the operation +// supports, for use by the canonicalization pass. +auto getCanonicalizationPatterns = + [](RewritePatternSet &results, MLIRContext *context) { + ... +}; + +// Definition of the operation. +std::unique_ptr<DynamicOpDefinition> opDef = + DynamicOpDefinition::get(name, dialect, std::move(verifyFn), std::move(verifyRegionFn), + std::move(parseFn), std::move(printFn), std::move(foldHookFn), + std::move(getCanonicalizationPatterns)); +``` + +Once the operation is defined, it can be registered by an `ExtensibleDialect`: + +```c++ +extensibleDialect->registerDynamicOp(std::move(opDef)); +``` + +Note that the `Dialect` given to the operation should be the one registering +the operation. + +### Using an operation defined at runtime + +It is possible to match on an operation defined at runtime using its name: + +```c++ +if (op->getName().getStringRef() == "my_dialect.my_dynamic_op") { + ... +} +``` + +An operation defined at runtime can be created by instantiating an +`OperationState` with the operation name, and using it with a rewriter +(for instance a `PatternRewriter`) to create the operation.
+ +```c++ +OperationState state(location, "my_dialect.my_dynamic_op", + operands, resultTypes, attributes); + +rewriter.createOperation(state); +``` + + +### Defining a type at runtime + +Contrary to types defined in C++ or in TableGen, types defined at runtime can +only have a list of `Attribute` as parameters. + +Similarly to operations, a type is defined at runtime using the class +`DynamicTypeDefinition`, which is created using the `DynamicTypeDefinition::get` +functions. A type definition requires a name, the dialect that will register the +type, and a parameter verifier. It can also optionally define a custom parser +and printer for the arguments (the type name is assumed to be already +parsed/printed). + +```c++ +// The type name, without the dialect name prefix. +StringRef name = "my_type_name"; + +// The dialect defining the type. +ExtensibleDialect* dialect = ctx->getOrLoadDialect<MyDialect>(); + +// The type verifier. +// A type defined at runtime has a list of attributes as parameters. +auto verifier = [](function_ref<InFlightDiagnostic()> emitError, + ArrayRef<Attribute> args) { + ... +}; + +// The type parameters parser. +auto parser = [](AsmParser &parser, + llvm::SmallVectorImpl<Attribute> &parsedParams) { + ... +}; + +// The type parameters printer. +auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) { + ... +}; + +std::unique_ptr<DynamicTypeDefinition> typeDef = + DynamicTypeDefinition::get(name, dialect, + std::move(verifier), std::move(parser), + std::move(printer)); +``` + +If the printer and the parser are omitted, a default parser and printer are +generated with the format `!dialect.typename<arg1, arg2, ..., argN>`. + +The type can then be registered by the `ExtensibleDialect`: + +```c++ +dialect->registerDynamicType(std::move(typeDef)); +``` + +### Parsing types defined at runtime in an extensible dialect + +`parseType` methods generated by TableGen can parse types defined at runtime, +though overridden `parseType` methods need to add the necessary support for them.
+ +```c++ +Type MyDialect::parseType(DialectAsmParser &parser) const { + ... + + // The type name. + StringRef typeTag; + if (failed(parser.parseKeyword(&typeTag))) + return Type(); + + // Try to parse a dynamic type with 'typeTag' name. + Type dynType; + auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType); + if (parseResult.hasValue()) { + if (succeeded(parseResult.getValue())) + return dynType; + return Type(); + } + + ... +} +``` + +### Using a type defined at runtime + +Dynamic types are instances of `DynamicType`. It is possible to get a dynamic +type with `DynamicType::get` and `ExtensibleDialect::lookupTypeDefinition`. + +```c++ +auto typeDef = extensibleDialect->lookupTypeDefinition("my_dynamic_type"); +ArrayRef<Attribute> params = ...; +auto type = DynamicType::get(typeDef, params); +``` + +It is also possible to cast a `Type` known to be defined at runtime to a +`DynamicType`. + +```c++ +auto dynType = type.cast<DynamicType>(); +auto typeDef = dynType.getTypeDef(); +auto args = dynType.getParams(); +``` + +### Defining an attribute at runtime + +Similar to types defined at runtime, attributes defined at runtime can only have +a list of `Attribute` as parameters. + +Similarly to types, an attribute is defined at runtime using the class +`DynamicAttrDefinition`, which is created using the `DynamicAttrDefinition::get` +functions. An attribute definition requires a name, the dialect that will +register the attribute, and a parameter verifier. It can also optionally define +a custom parser and printer for the arguments (the attribute name is assumed to +be already parsed/printed). + +```c++ +// The attribute name, without the dialect name prefix. +StringRef name = "my_attribute_name"; + +// The dialect defining the attribute. +ExtensibleDialect* dialect = ctx->getOrLoadDialect<MyDialect>(); + +// The attribute verifier. +// An attribute defined at runtime has a list of attributes as parameters. +auto verifier = [](function_ref<InFlightDiagnostic()> emitError, + ArrayRef<Attribute> args) { + ...
+}; + +// The attribute parameters parser. +auto parser = [](AsmParser &parser, + llvm::SmallVectorImpl<Attribute> &parsedParams) { + ... +}; + +// The attribute parameters printer. +auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) { + ... +}; + +std::unique_ptr<DynamicAttrDefinition> attrDef = + DynamicAttrDefinition::get(name, dialect, + std::move(verifier), std::move(parser), + std::move(printer)); +``` + +If the printer and the parser are omitted, a default parser and printer are +generated with the format `#dialect.attrname<arg1, arg2, ..., argN>`. + +The attribute can then be registered by the `ExtensibleDialect`: + +```c++ +dialect->registerDynamicAttr(std::move(attrDef)); +``` + +### Parsing attributes defined at runtime in an extensible dialect + +`parseAttribute` methods generated by TableGen can parse attributes defined at +runtime, though overridden `parseAttribute` methods need to add the necessary +support for them. + +```c++ +Attribute MyDialect::parseAttribute(DialectAsmParser &parser, + Type type) const { + ... + // The attribute name. + StringRef attrTag; + if (failed(parser.parseKeyword(&attrTag))) + return Attribute(); + + // Try to parse a dynamic attribute with 'attrTag' name. + Attribute dynAttr; + auto parseResult = parseOptionalDynamicAttr(attrTag, parser, dynAttr); + if (parseResult.hasValue()) { + if (succeeded(parseResult.getValue())) + return dynAttr; + return Attribute(); + } +``` + +### Using an attribute defined at runtime + +Similar to types, attributes defined at runtime are instances of `DynamicAttr`. +It is possible to get a dynamic attribute with `DynamicAttr::get` and +`ExtensibleDialect::lookupAttrDefinition`. + +```c++ +auto attrDef = extensibleDialect->lookupAttrDefinition("my_dynamic_attr"); +ArrayRef<Attribute> params = ...; +auto attr = DynamicAttr::get(attrDef, params); +``` + +It is also possible to cast an `Attribute` known to be defined at runtime to a +`DynamicAttr`.
+ +```c++ +auto dynAttr = attr.cast<DynamicAttr>(); +auto attrDef = dynAttr.getAttrDef(); +auto args = dynAttr.getParams(); +``` + +## Implementation details + +### Extensible dialect + +The role of extensible dialects is to own the necessary data for defined +operations, types, and attributes. They also contain the necessary accessors to +easily access them. + +In order to cast a `Dialect` back to an `ExtensibleDialect`, we attach the +`IsExtensibleDialect` interface to every `ExtensibleDialect`. The casting is done +by checking if the `Dialect` implements `IsExtensibleDialect` or not. + +### Operation representation and registration + +Operations are represented in MLIR using the `RegisteredOperationName` class. They are +registered in dialects the same way operations defined in C++ are registered, +which is by calling `RegisteredOperationName::insert`. + +The only difference is that a new `TypeID` needs to be created for each +operation, since operations are not represented by a C++ class. This is done +using a `TypeIDAllocator`, which can allocate a new unique `TypeID` at runtime. + +### Type representation and registration + +Unlike operations, types need to define a C++ storage class that takes care of +type parameters. They also need to define another C++ class to access that +storage. `DynamicTypeStorage` defines the storage of types defined at runtime, +and `DynamicType` gives access to the storage, as well as defining useful +functions. A `DynamicTypeStorage` contains a list of `Attribute` type +parameters, as well as a pointer to the type definition. + +Types are registered using the `Dialect::addType` method, which expects a +`TypeID` that is generated using a `TypeIDAllocator`. The type uniquer also +registers the type with the given `TypeID`. This means that we can reuse our +single `DynamicType` with different `TypeID`s to represent the different types +defined at runtime.
+ +Since the different types defined at runtime have different `TypeID`, it is not +possible to use `TypeID` to cast a `Type` into a `DynamicType`. Thus, similar to +`Dialect`, all `DynamicType` define a `IsDynamicTypeTrait`, so casting a `Type` +to a `DynamicType` boils down to querying the `IsDynamicTypeTrait` trait. diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h --- a/mlir/include/mlir/IR/AttributeSupport.h +++ b/mlir/include/mlir/IR/AttributeSupport.h @@ -45,6 +45,17 @@ T::getTypeID()); } + /// This method is used by Dialect objects to register attributes with + /// custom TypeIDs. + /// The use of this method is in general discouraged in favor of + /// 'get(dialect)'. + static AbstractAttribute get(Dialect &dialect, + detail::InterfaceMap &&interfaceMap, + HasTraitFn &&hasTrait, TypeID typeID) { + return AbstractAttribute(dialect, std::move(interfaceMap), + std::move(hasTrait), typeID); + } + /// Return the dialect this attribute was registered to. Dialect &getDialect() const { return const_cast(dialect); } @@ -175,14 +186,22 @@ // MLIRContext. This class manages all creation and uniquing of attributes. class AttributeUniquer { public: + /// Get an uniqued instance of an attribute T. + template + static T get(MLIRContext *ctx, Args &&...args) { + return getWithTypeID(ctx, T::getTypeID(), + std::forward(args)...); + } + /// Get an uniqued instance of a parametric attribute T. + /// The use of this method is in general discouraged in favor of + /// 'get(ctx, args)'. 
template static typename std::enable_if_t< !std::is_same::value, T> - get(MLIRContext *ctx, Args &&...args) { + getWithTypeID(MLIRContext *ctx, TypeID typeID, Args &&...args) { #ifndef NDEBUG - if (!ctx->getAttributeUniquer().isParametricStorageInitialized( - T::getTypeID())) + if (!ctx->getAttributeUniquer().isParametricStorageInitialized(typeID)) llvm::report_fatal_error( llvm::Twine("can't create Attribute '") + llvm::getTypeName() + "' because storage uniquer isn't initialized: the dialect was likely " @@ -190,30 +209,31 @@ "in the Dialect::initialize() method."); #endif return ctx->getAttributeUniquer().get( - [ctx](AttributeStorage *storage) { - initializeAttributeStorage(storage, ctx, T::getTypeID()); + [typeID, ctx](AttributeStorage *storage) { + initializeAttributeStorage(storage, ctx, typeID); // Execute any additional attribute storage initialization with the // context. static_cast(storage)->initialize(ctx); }, - T::getTypeID(), std::forward(args)...); + typeID, std::forward(args)...); } /// Get an uniqued instance of a singleton attribute T. + /// The use of this method is in general discouraged in favor of + /// 'get(ctx, args)'. template static typename std::enable_if_t< std::is_same::value, T> - get(MLIRContext *ctx) { + getWithTypeID(MLIRContext *ctx, TypeID typeID) { #ifndef NDEBUG - if (!ctx->getAttributeUniquer().isSingletonStorageInitialized( - T::getTypeID())) + if (!ctx->getAttributeUniquer().isSingletonStorageInitialized(typeID)) llvm::report_fatal_error( llvm::Twine("can't create Attribute '") + llvm::getTypeName() + "' because storage uniquer isn't initialized: the dialect was likely " "not loaded, or the attribute wasn't added with addAttributes<...>() " "in the Dialect::initialize() method."); #endif - return ctx->getAttributeUniquer().get(T::getTypeID()); + return ctx->getAttributeUniquer().get(typeID); } template @@ -224,23 +244,33 @@ std::forward(args)...); } + /// Register an attribute instance T with the uniquer. 
+ template + static void registerAttribute(MLIRContext *ctx) { + registerAttribute(ctx, T::getTypeID()); + } + /// Register a parametric attribute instance T with the uniquer. + /// The use of this method is in general discouraged in favor of + /// 'registerAttribute(ctx)'. template static typename std::enable_if_t< !std::is_same::value> - registerAttribute(MLIRContext *ctx) { + registerAttribute(MLIRContext *ctx, TypeID typeID) { ctx->getAttributeUniquer() - .registerParametricStorageType(T::getTypeID()); + .registerParametricStorageType(typeID); } /// Register a singleton attribute instance T with the uniquer. + /// The use of this method is in general discouraged in favor of + /// 'registerAttribute(ctx)'. template static typename std::enable_if_t< std::is_same::value> - registerAttribute(MLIRContext *ctx) { + registerAttribute(MLIRContext *ctx, TypeID typeID) { ctx->getAttributeUniquer() .registerSingletonStorageType( - T::getTypeID(), [ctx](AttributeStorage *storage) { - initializeAttributeStorage(storage, ctx, T::getTypeID()); + typeID, [ctx, typeID](AttributeStorage *storage) { + initializeAttributeStorage(storage, ctx, typeID); }); } diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -221,6 +221,11 @@ (void)std::initializer_list{0, (addAttribute(), 0)...}; } + /// Register an attribute instance with this dialect. + /// The use of this method is in general discouraged in favor of + /// 'addAttributes()'. + void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo); + /// Enable support for unregistered operations. void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; } @@ -237,7 +242,6 @@ addAttribute(T::getTypeID(), AbstractAttribute::get(*this)); detail::AttributeUniquer::registerAttribute(context); } - void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo); /// Register a type instance with this dialect. 
template void addType() { diff --git a/mlir/include/mlir/IR/DialectBase.td b/mlir/include/mlir/IR/DialectBase.td --- a/mlir/include/mlir/IR/DialectBase.td +++ b/mlir/include/mlir/IR/DialectBase.td @@ -94,6 +94,9 @@ // UpperCamel) and prefixed with `get` or `set` depending on if it is a getter // or setter. int emitAccessorPrefix = kEmitAccessorPrefix_Raw; + + // If this dialect can be extended at runtime with new operations or types. + bit isExtensible = 0; } #endif // DIALECTBASE_TD diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/ExtensibleDialect.h @@ -0,0 +1,556 @@ +//===- ExtensibleDialect.h - Extensible dialect -----------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the DynamicOpDefinition class, the DynamicTypeDefinition +// class, and the DynamicAttrDefinition class, which represent respectively +// operations, types, and attributes that can be defined at runtime. They can +// be registered at runtime to an extensible dialect, using the +// ExtensibleDialect class defined in this file. +// +// For a more complete documentation, see +// https://mlir.llvm.org/docs/ExtensibleDialects/ . 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_EXTENSIBLEDIALECT_H +#define MLIR_IR_EXTENSIBLEDIALECT_H + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectInterface.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/TypeID.h" +#include "llvm/ADT/StringMap.h" + +namespace mlir { +class AsmParser; +class AsmPrinter; +class DynamicAttr; +class DynamicType; +class ExtensibleDialect; +class MLIRContext; +class OptionalParseResult; +class ParseResult; + +namespace detail { +struct DynamicAttrStorage; +struct DynamicTypeStorage; +} // namespace detail + +//===----------------------------------------------------------------------===// +// Dynamic attribute +//===----------------------------------------------------------------------===// + +/// The definition of a dynamic attribute. A dynamic attribute is an attribute +/// that is defined at runtime, and that can be registered at runtime by an +/// extensible dialect (a dialect inheriting ExtensibleDialect). This class +/// stores the parser, the printer, and the verifier of the attribute. Each +/// dynamic attribute definition refers to one instance of this class. +class DynamicAttrDefinition : SelfOwningTypeID { +public: + using VerifierFn = llvm::unique_function, ArrayRef) const>; + using ParserFn = llvm::unique_function &parsedAttributes) + const>; + using PrinterFn = llvm::unique_function params) const>; + + /// Create a new attribute definition at runtime. The attribute is registered + /// only after passing it to ExtensibleDialect::registerDynamicAttr. + static std::unique_ptr + get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier); + static std::unique_ptr + get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier, + ParserFn &&parser, PrinterFn &&printer); + + /// Check that the attribute parameters are valid using the attribute verifier.
+ LogicalResult verify(function_ref emitError, + ArrayRef params) const { + return verifier(emitError, params); + } + + /// Return the MLIRContext in which the dynamic attributes are uniqued. + MLIRContext &getContext() const { return *ctx; } + + /// Return the name of the attribute, in the format 'attrname' and + /// not 'dialectname.attrname'. + StringRef getName() const { return name; } + + /// Return the dialect defining the attribute. + ExtensibleDialect *getDialect() const { return dialect; } + +private: + DynamicAttrDefinition(StringRef name, ExtensibleDialect *dialect, + VerifierFn &&verifier, ParserFn &&parser, + PrinterFn &&printer); + + /// This constructor should only be used when we need a pointer to + /// the DynamicAttrDefinition in the verifier, the parser, or the printer. + /// The verifier, parser, and printer thus need to be initialized after the + /// constructor. + DynamicAttrDefinition(ExtensibleDialect *dialect, StringRef name); + + /// Register the concrete attribute in the attribute uniquer. + void registerInAttrUniquer(); + + /// The attribute name, without the dialect namespace prefix (see getName). + std::string name; + + /// Dialect in which this attribute is defined. + ExtensibleDialect *dialect; + + /// The attribute verifier. It checks that the attribute parameters satisfy + /// the invariants. + VerifierFn verifier; + + /// The attribute parameters parser. It parses only the parameters, and + /// expects the attribute name to have already been parsed. + ParserFn parser; + + /// The attribute parameters printer. It prints only the parameters, and + /// expects the attribute name to have already been printed. + PrinterFn printer; + + /// Context in which the concrete attributes are uniqued. + MLIRContext *ctx; + + friend ExtensibleDialect; + friend DynamicAttr; +}; + +/// This trait is used to determine if an attribute is a dynamic attribute or +/// not; it should only be implemented by dynamic attributes.
+/// Note: This is only required because dynamic attributes do not have a +/// static/single TypeID. +namespace AttributeTrait { +template +class IsDynamicAttr : public TraitBase {}; +} // namespace AttributeTrait + +/// A dynamic attribute instance. This is an attribute whose definition is +/// defined at runtime. +/// It is possible to check if an attribute is a dynamic attribute using +/// `my_attr.isa<DynamicAttr>()`, and getting the attribute definition of a +/// dynamic attribute using the `DynamicAttr::getAttrDef` method. +/// All dynamic attributes have the same storage, which is an array of +/// attributes. + +class DynamicAttr : public Attribute::AttrBase { +public: + // Inherit Base constructors. + using Base::Base; + + /// Return an instance of a dynamic attribute given a dynamic attribute + /// definition and attribute parameters. + /// This asserts that the attribute verifier succeeded. + static DynamicAttr get(DynamicAttrDefinition *attrDef, + ArrayRef params = {}); + + /// Return an instance of a dynamic attribute given a dynamic attribute + /// definition and attribute parameters. If the parameters provided are + /// invalid, errors are emitted through the `emitError` callback and a null + /// object is returned. + static DynamicAttr getChecked(function_ref emitError, + DynamicAttrDefinition *attrDef, + ArrayRef params = {}); + + /// Return the attribute definition of the concrete attribute. + DynamicAttrDefinition *getAttrDef(); + + /// Return the attribute parameters. + ArrayRef getParams(); + + /// Check if an attribute is an instance of the given dynamic attribute definition (TypeID comparison). + static bool isa(Attribute attr, DynamicAttrDefinition *attrDef) { + return attr.getTypeID() == attrDef->getTypeID(); + } + + /// Check if an attribute is a dynamic attribute. + static bool classof(Attribute attr); + + /// Parse the dynamic attribute parameters and construct the attribute. + /// The parameters are either empty, and nothing is parsed, + /// or they are in the format '<>' or '<params>'.
+ static ParseResult parse(AsmParser &parser, DynamicAttrDefinition *attrDef, + DynamicAttr &parsedAttr); + + /// Print the dynamic attribute with the format 'attrname' if there are no + /// parameters, or 'attrname<params>'. + void print(AsmPrinter &printer); +}; + +//===----------------------------------------------------------------------===// +// Dynamic type +//===----------------------------------------------------------------------===// + +/// The definition of a dynamic type. A dynamic type is a type that is +/// defined at runtime, and that can be registered at runtime by an +/// extensible dialect (a dialect inheriting ExtensibleDialect). This class +/// stores the parser, the printer, and the verifier of the type. Each dynamic +/// type definition refers to one instance of this class. +class DynamicTypeDefinition : SelfOwningTypeID { +public: + using VerifierFn = llvm::unique_function, ArrayRef) const>; + using ParserFn = llvm::unique_function &parsedAttributes) + const>; + using PrinterFn = llvm::unique_function params) const>; + + /// Create a new dynamic type definition. The type is registered only after + /// passing it to ExtensibleDialect::registerDynamicType. + static std::unique_ptr + get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier); + static std::unique_ptr + get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier, + ParserFn &&parser, PrinterFn &&printer); + + /// Check that the type parameters are valid using the type verifier. + LogicalResult verify(function_ref emitError, + ArrayRef params) const { + return verifier(emitError, params); + } + + /// Return the MLIRContext in which the dynamic types are uniqued. + MLIRContext &getContext() const { return *ctx; } + + /// Return the name of the type, in the format 'typename' and + /// not 'dialectname.typename'. + StringRef getName() const { return name; } + + /// Return the dialect defining the type.
+ ExtensibleDialect *getDialect() const { return dialect; } + +private: + DynamicTypeDefinition(StringRef name, ExtensibleDialect *dialect, + VerifierFn &&verifier, ParserFn &&parser, + PrinterFn &&printer); + + /// This constructor should only be used when we need a pointer to + /// the DynamicTypeDefinition in the verifier, the parser, or the printer. + /// The verifier, parser, and printer thus need to be initialized after the + /// constructor. + DynamicTypeDefinition(ExtensibleDialect *dialect, StringRef name); + + /// Register the concrete type in the type uniquer. + void registerInTypeUniquer(); + + /// The type name, without the dialect namespace prefix (see getName). + std::string name; + + /// Dialect in which this type is defined. + ExtensibleDialect *dialect; + + /// The type verifier. It checks that the type parameters satisfy the + /// invariants. + VerifierFn verifier; + + /// The type parameters parser. It parses only the parameters, and expects the + /// type name to have already been parsed. + ParserFn parser; + + /// The type parameters printer. It prints only the parameters, and expects + /// the type name to have already been printed. + PrinterFn printer; + + /// Context in which the concrete types are uniqued. + MLIRContext *ctx; + + friend ExtensibleDialect; + friend DynamicType; +}; + +/// This trait is used to determine if a type is a dynamic type or not; +/// it should only be implemented by dynamic types. +/// Note: This is only required because dynamic types do not have a +/// static/single TypeID. +namespace TypeTrait { +template +class IsDynamicType : public TypeTrait::TraitBase { +}; +} // namespace TypeTrait + +/// A dynamic type instance. This is a type whose definition is defined at +/// runtime. +/// It is possible to check if a type is a dynamic type using +/// `my_type.isa<DynamicType>()`, and getting the type definition of a dynamic +/// type using the `DynamicType::getTypeDef` method.
+/// All dynamic types have the same storage, which is an array of attributes. +class DynamicType + : public Type::TypeBase { +public: + // Inherit Base constructors. + using Base::Base; + + /// Return an instance of a dynamic type given a dynamic type definition and + /// type parameters. + /// This asserts that the type verifier succeeded. + static DynamicType get(DynamicTypeDefinition *typeDef, + ArrayRef params = {}); + + /// Return an instance of a dynamic type given a dynamic type definition and + /// type parameters. If the parameters provided are invalid, errors are + /// emitted through the `emitError` callback and a null object is returned. + static DynamicType getChecked(function_ref emitError, + DynamicTypeDefinition *typeDef, + ArrayRef params = {}); + + /// Return the type definition of the concrete type. + DynamicTypeDefinition *getTypeDef(); + + /// Return the type parameters. + ArrayRef getParams(); + + /// Check if a type is an instance of the given dynamic type definition (TypeID comparison). + static bool isa(Type type, DynamicTypeDefinition *typeDef) { + return type.getTypeID() == typeDef->getTypeID(); + } + + /// Check if a type is a dynamic type. + static bool classof(Type type); + + /// Parse the dynamic type parameters and construct the type. + /// The parameters are either empty, and nothing is parsed, + /// or they are in the format '<>' or '<params>'. + static ParseResult parse(AsmParser &parser, DynamicTypeDefinition *typeDef, + DynamicType &parsedType); + + /// Print the dynamic type with the format + /// 'type' or 'type<>' if there are no parameters, or 'type<params>'. + void print(AsmPrinter &printer); +}; + +//===----------------------------------------------------------------------===// +// Dynamic operation +//===----------------------------------------------------------------------===// + +/// The definition of a dynamic op. A dynamic op is an op that is defined at +/// runtime, and that can be registered at runtime by an extensible dialect (a +/// dialect inheriting ExtensibleDialect).
This class stores the functions that +/// are in the OperationName class, and in addition defines the TypeID of the op +/// that will be defined. +/// Each dynamic operation definition refers to one instance of this class. +class DynamicOpDefinition { +public: + /// Create a new op at runtime. The op is registered only after passing it to + /// the dialect using registerDynamicOp. + static std::unique_ptr + get(StringRef name, ExtensibleDialect *dialect, + OperationName::VerifyInvariantsFn &&verifyFn, + OperationName::VerifyRegionInvariantsFn &&verifyRegionFn); + static std::unique_ptr + get(StringRef name, ExtensibleDialect *dialect, + OperationName::VerifyInvariantsFn &&verifyFn, + OperationName::VerifyRegionInvariantsFn &&verifyRegionFn, + OperationName::ParseAssemblyFn &&parseFn, + OperationName::PrintAssemblyFn &&printFn); + static std::unique_ptr + get(StringRef name, ExtensibleDialect *dialect, + OperationName::VerifyInvariantsFn &&verifyFn, + OperationName::VerifyRegionInvariantsFn &&verifyRegionFn, + OperationName::ParseAssemblyFn &&parseFn, + OperationName::PrintAssemblyFn &&printFn, + OperationName::FoldHookFn &&foldHookFn, + OperationName::GetCanonicalizationPatternsFn + &&getCanonicalizationPatternsFn); + + /// Returns the op typeID. + TypeID getTypeID() { return typeID; } + + /// Sets the verifier function for this operation. It should emits an error + /// message and returns failure if a problem is detected, or returns success + /// if everything is ok. + void setVerifyFn(OperationName::VerifyInvariantsFn &&verify) { + verifyFn = std::move(verify); + } + + /// Sets the region verifier function for this operation. It should emits an + /// error message and returns failure if a problem is detected, or returns + /// success if everything is ok. + void setVerifyRegionFn(OperationName::VerifyRegionInvariantsFn &&verify) { + verifyRegionFn = std::move(verify); + } + + /// Sets the static hook for parsing this op assembly. 
+ void setParseFn(OperationName::ParseAssemblyFn &&parse) { + parseFn = std::move(parse); + } + + /// Sets the static hook for printing this op assembly. + void setPrintFn(OperationName::PrintAssemblyFn &&print) { + printFn = std::move(print); + } + + /// Sets the hook implementing a generalized folder for the op. See + /// `RegisteredOperationName::foldHook` for more details + void setFoldHookFn(OperationName::FoldHookFn &&foldHook) { + foldHookFn = std::move(foldHook); + } + + /// Set the hook returning any canonicalization pattern rewrites that the op + /// supports, for use by the canonicalization pass. + void + setGetCanonicalizationPatternsFn(OperationName::GetCanonicalizationPatternsFn + &&getCanonicalizationPatterns) { + getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns); + } + +private: + DynamicOpDefinition(StringRef name, ExtensibleDialect *dialect, + OperationName::VerifyInvariantsFn &&verifyFn, + OperationName::VerifyRegionInvariantsFn &&verifyRegionFn, + OperationName::ParseAssemblyFn &&parseFn, + OperationName::PrintAssemblyFn &&printFn, + OperationName::FoldHookFn &&foldHookFn, + OperationName::GetCanonicalizationPatternsFn + &&getCanonicalizationPatternsFn); + + /// Unique identifier for this operation. + TypeID typeID; + + /// Name of the operation. + /// The name is prefixed with the dialect name. + std::string name; + + /// Dialect defining this operation. 
+ ExtensibleDialect *dialect; + + OperationName::VerifyInvariantsFn verifyFn; + OperationName::VerifyRegionInvariantsFn verifyRegionFn; + OperationName::ParseAssemblyFn parseFn; + OperationName::PrintAssemblyFn printFn; + OperationName::FoldHookFn foldHookFn; + OperationName::GetCanonicalizationPatternsFn getCanonicalizationPatternsFn; + + friend ExtensibleDialect; +}; + +//===----------------------------------------------------------------------===// +// Extensible dialect +//===----------------------------------------------------------------------===// + +/// A dialect that can be extended with new operations/types/attributes at +/// runtime. +class ExtensibleDialect : public mlir::Dialect { +public: + ExtensibleDialect(StringRef name, MLIRContext *ctx, TypeID typeID); + + /// Add a new type defined at runtime to the dialect. + void registerDynamicType(std::unique_ptr &&type); + + /// Add a new attribute defined at runtime to the dialect. + void registerDynamicAttr(std::unique_ptr &&attr); + + /// Add a new operation defined at runtime to the dialect. + void registerDynamicOp(std::unique_ptr &&type); + + /// Check if the dialect is an extensible dialect. + static bool classof(const Dialect *dialect); + + /// Returns nullptr if the definition was not found. + DynamicTypeDefinition *lookupTypeDefinition(StringRef name) const { + auto it = nameToDynTypes.find(name); + if (it == nameToDynTypes.end()) + return nullptr; + return it->second; + } + + /// Returns nullptr if the definition was not found. + DynamicTypeDefinition *lookupTypeDefinition(TypeID id) const { + auto it = dynTypes.find(id); + if (it == dynTypes.end()) + return nullptr; + return it->second.get(); + } + + /// Returns nullptr if the definition was not found. + DynamicAttrDefinition *lookupAttrDefinition(StringRef name) const { + auto it = nameToDynAttrs.find(name); + if (it == nameToDynAttrs.end()) + return nullptr; + return it->second; + } + + /// Returns nullptr if the definition was not found. 
+ DynamicAttrDefinition *lookupAttrDefinition(TypeID id) const { + auto it = dynAttrs.find(id); + if (it == dynAttrs.end()) + return nullptr; + return it->second.get(); + } + +protected: + /// Parse the dynamic type 'typeName' in the dialect 'dialect'. + /// typename should not be prefixed with the dialect name. + /// If the dynamic type does not exist, return no value. + /// Otherwise, parse it, and return the parse result. + /// If the parsing succeed, put the resulting type in 'resultType'. + OptionalParseResult parseOptionalDynamicType(StringRef typeName, + AsmParser &parser, + Type &resultType) const; + + /// If 'type' is a dynamic type, print it. + /// Returns success if the type was printed, and failure if the type was not a + /// dynamic type. + static LogicalResult printIfDynamicType(Type type, AsmPrinter &printer); + + /// Parse the dynamic attribute 'attrName' in the dialect 'dialect'. + /// attrname should not be prefixed with the dialect name. + /// If the dynamic attribute does not exist, return no value. + /// Otherwise, parse it, and return the parse result. + /// If the parsing succeed, put the resulting attribute in 'resultAttr'. + OptionalParseResult parseOptionalDynamicAttr(StringRef attrName, + AsmParser &parser, + Attribute &resultAttr) const; + + /// If 'attr' is a dynamic attribute, print it. + /// Returns success if the attribute was printed, and failure if the + /// attribute was not a dynamic attribute. + static LogicalResult printIfDynamicAttr(Attribute attr, AsmPrinter &printer); + +private: + /// The set of all dynamic types registered. + DenseMap> dynTypes; + + /// This structure allows to get in O(1) a dynamic type given its name. + llvm::StringMap nameToDynTypes; + + /// The set of all dynamic attributes registered. + DenseMap> dynAttrs; + + /// This structure allows to get in O(1) a dynamic attribute given its name. + llvm::StringMap nameToDynAttrs; + + /// Give DynamicOpDefinition access to allocateTypeID. 
+ friend DynamicOpDefinition; + + /// Allocates a type ID to uniquify operations. + TypeID allocateTypeID() { return typeIDAllocator.allocate(); } + + /// Owns the TypeID generated at runtime for operations. + TypeIDAllocator typeIDAllocator; +}; +} // namespace mlir + +namespace llvm { +/// Provide isa functionality for ExtensibleDialect. +/// This is to override the isa functionality for Dialect. +template <> +struct isa_impl { + static inline bool doit(const ::mlir::Dialect &dialect) { + return mlir::ExtensibleDialect::classof(&dialect); + } +}; +} // namespace llvm + +#endif // MLIR_IR_EXTENSIBLEDIALECT_H diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h --- a/mlir/include/mlir/IR/TypeSupport.h +++ b/mlir/include/mlir/IR/TypeSupport.h @@ -163,13 +163,22 @@ /// A utility class to get, or create, unique instances of types within an /// MLIRContext. This class manages all creation and uniquing of types. struct TypeUniquer { + /// Get an uniqued instance of a type T. + template + static T get(MLIRContext *ctx, Args &&...args) { + return getWithTypeID(ctx, T::getTypeID(), + std::forward(args)...); + } + /// Get an uniqued instance of a parametric type T. + /// The use of this method is in general discouraged in favor of + /// 'get(ctx, args)'. 
template static typename std::enable_if_t< !std::is_same::value, T> - get(MLIRContext *ctx, Args &&...args) { + getWithTypeID(MLIRContext *ctx, TypeID typeID, Args &&...args) { #ifndef NDEBUG - if (!ctx->getTypeUniquer().isParametricStorageInitialized(T::getTypeID())) + if (!ctx->getTypeUniquer().isParametricStorageInitialized(typeID)) llvm::report_fatal_error( llvm::Twine("can't create type '") + llvm::getTypeName() + "' because storage uniquer isn't initialized: the dialect was likely " @@ -177,25 +186,27 @@ "in the Dialect::initialize() method."); #endif return ctx->getTypeUniquer().get( - [&](TypeStorage *storage) { - storage->initialize(AbstractType::lookup(T::getTypeID(), ctx)); + [&, typeID](TypeStorage *storage) { + storage->initialize(AbstractType::lookup(typeID, ctx)); }, - T::getTypeID(), std::forward(args)...); + typeID, std::forward(args)...); } /// Get an uniqued instance of a singleton type T. + /// The use of this method is in general discouraged in favor of + /// 'get(ctx, args)'. template static typename std::enable_if_t< std::is_same::value, T> - get(MLIRContext *ctx) { + getWithTypeID(MLIRContext *ctx, TypeID typeID) { #ifndef NDEBUG - if (!ctx->getTypeUniquer().isSingletonStorageInitialized(T::getTypeID())) + if (!ctx->getTypeUniquer().isSingletonStorageInitialized(typeID)) llvm::report_fatal_error( llvm::Twine("can't create type '") + llvm::getTypeName() + "' because storage uniquer isn't initialized: the dialect was likely " "not loaded, or the type wasn't added with addTypes<...>() " "in the Dialect::initialize() method."); #endif - return ctx->getTypeUniquer().get(T::getTypeID()); + return ctx->getTypeUniquer().get(typeID); } /// Change the mutable component of the given type instance in the provided @@ -208,22 +219,32 @@ std::forward(args)...); } + /// Register a type instance T with the uniquer. 
+ template + static void registerType(MLIRContext *ctx) { + registerType(ctx, T::getTypeID()); + } + /// Register a parametric type instance T with the uniquer. + /// The use of this method is in general discouraged in favor of + /// 'registerType(ctx)'. template static typename std::enable_if_t< !std::is_same::value> - registerType(MLIRContext *ctx) { + registerType(MLIRContext *ctx, TypeID typeID) { ctx->getTypeUniquer().registerParametricStorageType( - T::getTypeID()); + typeID); } /// Register a singleton type instance T with the uniquer. + /// The use of this method is in general discouraged in favor of + /// 'registerType(ctx)'. template static typename std::enable_if_t< std::is_same::value> - registerType(MLIRContext *ctx) { + registerType(MLIRContext *ctx, TypeID typeID) { ctx->getTypeUniquer().registerSingletonStorageType( - T::getTypeID(), [&](TypeStorage *storage) { - storage->initialize(AbstractType::lookup(T::getTypeID(), ctx)); + typeID, [&ctx, typeID](TypeStorage *storage) { + storage->initialize(AbstractType::lookup(typeID, ctx)); }); } }; diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h --- a/mlir/include/mlir/TableGen/Dialect.h +++ b/mlir/include/mlir/TableGen/Dialect.h @@ -82,6 +82,10 @@ /// type printing/parsing. bool useDefaultTypePrinterParser() const; + /// Returns true if this dialect can be extended at runtime with new + /// operations or types. + bool isExtensible() const; + // Returns whether two dialects are equal by checking the equality of the // underlying record. 
bool operator==(const Dialect &other) const; diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt --- a/mlir/lib/IR/CMakeLists.txt +++ b/mlir/lib/IR/CMakeLists.txt @@ -13,6 +13,7 @@ Diagnostics.cpp Dialect.cpp Dominance.cpp + ExtensibleDialect.cpp FunctionImplementation.cpp FunctionInterfaces.cpp IntegerSet.cpp diff --git a/mlir/lib/IR/ExtensibleDialect.cpp b/mlir/lib/IR/ExtensibleDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/IR/ExtensibleDialect.cpp @@ -0,0 +1,500 @@ +//===- ExtensibleDialect.cpp - Extensible dialect ---------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/ExtensibleDialect.h" +#include "mlir/IR/AttributeSupport.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/StorageUniquerSupport.h" +#include "mlir/Support/LogicalResult.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Dynamic types and attributes shared functions +//===----------------------------------------------------------------------===// + +/// Default parser for dynamic attribute or type parameters. +/// Parse in the format '(<>)?' or ''. 
+static LogicalResult
+typeOrAttrParser(AsmParser &parser, SmallVectorImpl &parsedParams) {
+  // No parameters: there is either no '<', or an empty '<>' list.
+  if (parser.parseOptionalLess() || !parser.parseOptionalGreater())
+    return success();
+
+  Attribute attr;
+  if (parser.parseAttribute(attr))
+    return failure();
+  parsedParams.push_back(attr);
+
+  while (parser.parseOptionalGreater()) {
+    Attribute attr;
+    if (parser.parseComma() || parser.parseAttribute(attr))
+      return failure();
+    parsedParams.push_back(attr);
+  }
+
+  return success();
+}
+
+/// Default printer for dynamic attribute or type parameters.
+/// Print in the format '(<>)?' or ''.
+static void typeOrAttrPrinter(AsmPrinter &printer, ArrayRef params) {
+  if (params.empty())
+    return;
+
+  printer << "<";
+  interleaveComma(params, printer.getStream());
+  printer << ">";
+}
+
+//===----------------------------------------------------------------------===//
+// Dynamic type
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr
+DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect,
+                           VerifierFn &&verifier) {
+  return DynamicTypeDefinition::get(name, dialect, std::move(verifier),
+                                    typeOrAttrParser, typeOrAttrPrinter);
+}
+
+std::unique_ptr
+DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect,
+                           VerifierFn &&verifier, ParserFn &&parser,
+                           PrinterFn &&printer) {
+  return std::unique_ptr(
+      new DynamicTypeDefinition(name, dialect, std::move(verifier),
+                                std::move(parser), std::move(printer)));
+}
+
+DynamicTypeDefinition::DynamicTypeDefinition(StringRef nameRef,
+                                             ExtensibleDialect *dialect,
+                                             VerifierFn &&verifier,
+                                             ParserFn &&parser,
+                                             PrinterFn &&printer)
+    : name(nameRef), dialect(dialect), verifier(std::move(verifier)),
+      parser(std::move(parser)), printer(std::move(printer)),
+      ctx(dialect->getContext()) {}
+
+DynamicTypeDefinition::DynamicTypeDefinition(ExtensibleDialect *dialect,
+                                             StringRef nameRef)
+    : name(nameRef), dialect(dialect), ctx(dialect->getContext()) {}
+
+void DynamicTypeDefinition::registerInTypeUniquer() {
+  detail::TypeUniquer::registerType(&getContext(), getTypeID());
+}
+
+namespace mlir {
+namespace detail {
+/// Storage of DynamicType.
+/// Contains a pointer to the type definition and the type parameters.
+struct DynamicTypeStorage : public TypeStorage {
+
+  using KeyTy = std::pair>;
+
+  explicit DynamicTypeStorage(DynamicTypeDefinition *typeDef,
+                              ArrayRef params)
+      : typeDef(typeDef), params(params) {}
+
+  bool operator==(const KeyTy &key) const {
+    return typeDef == key.first && params == key.second;
+  }
+
+  static llvm::hash_code hashKey(const KeyTy &key) {
+    return llvm::hash_value(key);
+  }
+
+  static DynamicTypeStorage *construct(TypeStorageAllocator &alloc,
+                                       const KeyTy &key) {
+    return new (alloc.allocate())
+        DynamicTypeStorage(key.first, alloc.copyInto(key.second));
+  }
+
+  /// Definition of the type.
+  DynamicTypeDefinition *typeDef;
+
+  /// The type parameters.
+  ArrayRef params;
+};
+} // namespace detail
+} // namespace mlir
+
+DynamicType DynamicType::get(DynamicTypeDefinition *typeDef,
+                             ArrayRef params) {
+  auto &ctx = typeDef->getContext();
+  auto emitError = detail::getDefaultDiagnosticEmitFn(&ctx);
+  assert(succeeded(typeDef->verify(emitError, params)));
+  return detail::TypeUniquer::getWithTypeID(
+      &ctx, typeDef->getTypeID(), typeDef, params);
+}
+
+DynamicType
+DynamicType::getChecked(function_ref emitError,
+                        DynamicTypeDefinition *typeDef,
+                        ArrayRef params) {
+  if (failed(typeDef->verify(emitError, params)))
+    return {};
+  auto &ctx = typeDef->getContext();
+  return detail::TypeUniquer::getWithTypeID(
+      &ctx, typeDef->getTypeID(), typeDef, params);
+}
+
+DynamicTypeDefinition *DynamicType::getTypeDef() { return getImpl()->typeDef; }
+
+ArrayRef DynamicType::getParams() { return getImpl()->params; }
+
+bool DynamicType::classof(Type type) {
+  return type.hasTrait();
+}
+
+ParseResult DynamicType::parse(AsmParser &parser,
+                               DynamicTypeDefinition *typeDef,
+                               DynamicType &parsedType) {
+  SmallVector params;
+  if (failed(typeDef->parser(parser, params)))
+    return failure();
+  parsedType = parser.getChecked(typeDef, params);
+  if (!parsedType)
+    return failure();
+  return success();
+}
+
+void DynamicType::print(AsmPrinter &printer) {
+  printer << getTypeDef()->getName();
+  getTypeDef()->printer(printer, getParams());
+}
+
+//===----------------------------------------------------------------------===//
+// Dynamic attribute
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr
+DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect,
+                           VerifierFn &&verifier) {
+  return DynamicAttrDefinition::get(name, dialect, std::move(verifier),
+                                    typeOrAttrParser, typeOrAttrPrinter);
+}
+
+std::unique_ptr
+DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect,
+                           VerifierFn &&verifier, ParserFn &&parser,
+                           PrinterFn &&printer) {
+  return std::unique_ptr(
+      new DynamicAttrDefinition(name, dialect, std::move(verifier),
+                                std::move(parser), std::move(printer)));
+}
+
+DynamicAttrDefinition::DynamicAttrDefinition(StringRef nameRef,
+                                             ExtensibleDialect *dialect,
+                                             VerifierFn &&verifier,
+                                             ParserFn &&parser,
+                                             PrinterFn &&printer)
+    : name(nameRef), dialect(dialect), verifier(std::move(verifier)),
+      parser(std::move(parser)), printer(std::move(printer)),
+      ctx(dialect->getContext()) {}
+
+DynamicAttrDefinition::DynamicAttrDefinition(ExtensibleDialect *dialect,
+                                             StringRef nameRef)
+    : name(nameRef), dialect(dialect), ctx(dialect->getContext()) {}
+
+void DynamicAttrDefinition::registerInAttrUniquer() {
+  detail::AttributeUniquer::registerAttribute(&getContext(),
+                                                           getTypeID());
+}
+
+namespace mlir {
+namespace detail {
+/// Storage of DynamicAttr.
+/// Contains a pointer to the attribute definition and attribute parameters.
+struct DynamicAttrStorage : public AttributeStorage {
+  using KeyTy = std::pair>;
+
+  explicit DynamicAttrStorage(DynamicAttrDefinition *attrDef,
+                              ArrayRef params)
+      : attrDef(attrDef), params(params) {}
+
+  bool operator==(const KeyTy &key) const {
+    return attrDef == key.first && params == key.second;
+  }
+
+  static llvm::hash_code hashKey(const KeyTy &key) {
+    return llvm::hash_value(key);
+  }
+
+  static DynamicAttrStorage *construct(AttributeStorageAllocator &alloc,
+                                       const KeyTy &key) {
+    return new (alloc.allocate())
+        DynamicAttrStorage(key.first, alloc.copyInto(key.second));
+  }
+
+  /// Definition of the attribute.
+  DynamicAttrDefinition *attrDef;
+
+  /// The attribute parameters.
+  ArrayRef params;
+};
+} // namespace detail
+} // namespace mlir
+
+DynamicAttr DynamicAttr::get(DynamicAttrDefinition *attrDef,
+                             ArrayRef params) {
+  auto &ctx = attrDef->getContext();
+  return detail::AttributeUniquer::getWithTypeID(
+      &ctx, attrDef->getTypeID(), attrDef, params);
+}
+
+DynamicAttr
+DynamicAttr::getChecked(function_ref emitError,
+                        DynamicAttrDefinition *attrDef,
+                        ArrayRef params) {
+  if (failed(attrDef->verify(emitError, params)))
+    return {};
+  return get(attrDef, params);
+}
+
+DynamicAttrDefinition *DynamicAttr::getAttrDef() { return getImpl()->attrDef; }
+
+ArrayRef DynamicAttr::getParams() { return getImpl()->params; }
+
+bool DynamicAttr::classof(Attribute attr) {
+  return attr.hasTrait();
+}
+
+ParseResult DynamicAttr::parse(AsmParser &parser,
+                               DynamicAttrDefinition *attrDef,
+                               DynamicAttr &parsedAttr) {
+  SmallVector params;
+  if (failed(attrDef->parser(parser, params)))
+    return failure();
+  parsedAttr = parser.getChecked(attrDef, params);
+  if (!parsedAttr)
+    return failure();
+  return success();
+}
+
+void DynamicAttr::print(AsmPrinter &printer) {
+  printer << getAttrDef()->getName();
+  getAttrDef()->printer(printer, getParams());
+}
+
+//===----------------------------------------------------------------------===//
+// Dynamic operation
+//===----------------------------------------------------------------------===// + +DynamicOpDefinition::DynamicOpDefinition( + StringRef name, ExtensibleDialect *dialect, + OperationName::VerifyInvariantsFn &&verifyFn, + OperationName::VerifyRegionInvariantsFn &&verifyRegionFn, + OperationName::ParseAssemblyFn &&parseFn, + OperationName::PrintAssemblyFn &&printFn, + OperationName::FoldHookFn &&foldHookFn, + OperationName::GetCanonicalizationPatternsFn + &&getCanonicalizationPatternsFn) + : typeID(dialect->allocateTypeID()), + name((dialect->getNamespace() + "." + name).str()), dialect(dialect), + verifyFn(std::move(verifyFn)), verifyRegionFn(std::move(verifyRegionFn)), + parseFn(std::move(parseFn)), printFn(std::move(printFn)), + foldHookFn(std::move(foldHookFn)), + getCanonicalizationPatternsFn(std::move(getCanonicalizationPatternsFn)) {} + +std::unique_ptr DynamicOpDefinition::get( + StringRef name, ExtensibleDialect *dialect, + OperationName::VerifyInvariantsFn &&verifyFn, + OperationName::VerifyRegionInvariantsFn &&verifyRegionFn) { + auto parseFn = [](OpAsmParser &parser, OperationState &result) { + return parser.emitError( + parser.getCurrentLocation(), + "dynamic operation do not define any parser function"); + }; + + auto printFn = [](Operation *op, OpAsmPrinter &printer, StringRef) { + printer.printGenericOp(op); + }; + + return DynamicOpDefinition::get(name, dialect, std::move(verifyFn), + std::move(verifyRegionFn), std::move(parseFn), + std::move(printFn)); +} + +std::unique_ptr DynamicOpDefinition::get( + StringRef name, ExtensibleDialect *dialect, + OperationName::VerifyInvariantsFn &&verifyFn, + OperationName::VerifyRegionInvariantsFn &&verifyRegionFn, + OperationName::ParseAssemblyFn &&parseFn, + OperationName::PrintAssemblyFn &&printFn) { + auto foldHookFn = [](Operation *op, ArrayRef operands, + SmallVectorImpl &results) { + return failure(); + }; + + auto getCanonicalizationPatternsFn = [](RewritePatternSet &, MLIRContext *) { + }; + + return 
DynamicOpDefinition::get(name, dialect, std::move(verifyFn), + std::move(verifyRegionFn), std::move(parseFn), + std::move(printFn), std::move(foldHookFn), + std::move(getCanonicalizationPatternsFn)); +} + +std::unique_ptr +DynamicOpDefinition::get(StringRef name, ExtensibleDialect *dialect, + OperationName::VerifyInvariantsFn &&verifyFn, + OperationName::VerifyInvariantsFn &&verifyRegionFn, + OperationName::ParseAssemblyFn &&parseFn, + OperationName::PrintAssemblyFn &&printFn, + OperationName::FoldHookFn &&foldHookFn, + OperationName::GetCanonicalizationPatternsFn + &&getCanonicalizationPatternsFn) { + return std::unique_ptr(new DynamicOpDefinition( + name, dialect, std::move(verifyFn), std::move(verifyRegionFn), + std::move(parseFn), std::move(printFn), std::move(foldHookFn), + std::move(getCanonicalizationPatternsFn))); +} + +//===----------------------------------------------------------------------===// +// Extensible dialect +//===----------------------------------------------------------------------===// + +namespace { +/// Interface that can only be implemented by extensible dialects. +/// The interface is used to check if a dialect is extensible or not. +class IsExtensibleDialect : public DialectInterface::Base { +public: + IsExtensibleDialect(Dialect *dialect) : Base(dialect) {} + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IsExtensibleDialect) +}; +} // namespace + +ExtensibleDialect::ExtensibleDialect(StringRef name, MLIRContext *ctx, + TypeID typeID) + : Dialect(name, ctx, typeID) { + addInterfaces(); +} + +void ExtensibleDialect::registerDynamicType( + std::unique_ptr &&type) { + DynamicTypeDefinition *typePtr = type.get(); + TypeID typeID = type->getTypeID(); + StringRef name = type->getName(); + ExtensibleDialect *dialect = type->getDialect(); + + assert(dialect == this && + "trying to register a dynamic type in the wrong dialect"); + + // If a type with the same name is already defined, fail. 
+ auto registered = dynTypes.try_emplace(typeID, std::move(type)).second; + (void)registered; + assert(registered && "type TypeID was not unique"); + + registered = nameToDynTypes.insert({name, typePtr}).second; + (void)registered; + assert(registered && + "Trying to create a new dynamic type with an existing name"); + + auto abstractType = + AbstractType::get(*dialect, DynamicAttr::getInterfaceMap(), + DynamicType::getHasTraitFn(), typeID); + + /// Add the type to the dialect and the type uniquer. + addType(typeID, std::move(abstractType)); + typePtr->registerInTypeUniquer(); +} + +void ExtensibleDialect::registerDynamicAttr( + std::unique_ptr &&attr) { + auto *attrPtr = attr.get(); + auto typeID = attr->getTypeID(); + auto name = attr->getName(); + auto *dialect = attr->getDialect(); + + assert(dialect == this && + "trying to register a dynamic attribute in the wrong dialect"); + + // If an attribute with the same name is already defined, fail. + auto registered = dynAttrs.try_emplace(typeID, std::move(attr)).second; + (void)registered; + assert(registered && "attribute TypeID was not unique"); + + registered = nameToDynAttrs.insert({name, attrPtr}).second; + (void)registered; + assert(registered && + "Trying to create a new dynamic attribute with an existing name"); + + auto abstractAttr = + AbstractAttribute::get(*dialect, DynamicAttr::getInterfaceMap(), + DynamicAttr::getHasTraitFn(), typeID); + + /// Add the type to the dialect and the type uniquer. 
+ addAttribute(typeID, std::move(abstractAttr)); + attrPtr->registerInAttrUniquer(); +} + +void ExtensibleDialect::registerDynamicOp( + std::unique_ptr &&op) { + assert(op->dialect == this && + "trying to register a dynamic op in the wrong dialect"); + auto hasTraitFn = [](TypeID traitId) { return false; }; + + RegisteredOperationName::insert( + op->name, *op->dialect, op->typeID, std::move(op->parseFn), + std::move(op->printFn), std::move(op->verifyFn), + std::move(op->verifyRegionFn), std::move(op->foldHookFn), + std::move(op->getCanonicalizationPatternsFn), + detail::InterfaceMap::get<>(), std::move(hasTraitFn), {}); +} + +bool ExtensibleDialect::classof(const Dialect *dialect) { + return const_cast(dialect) + ->getRegisteredInterface(); +} + +OptionalParseResult ExtensibleDialect::parseOptionalDynamicType( + StringRef typeName, AsmParser &parser, Type &resultType) const { + DynamicTypeDefinition *typeDef = lookupTypeDefinition(typeName); + if (!typeDef) + return llvm::None; + + DynamicType dynType; + if (DynamicType::parse(parser, typeDef, dynType)) + return failure(); + resultType = dynType; + return success(); +} + +LogicalResult ExtensibleDialect::printIfDynamicType(Type type, + AsmPrinter &printer) { + if (auto dynType = type.dyn_cast()) { + dynType.print(printer); + return success(); + } + return failure(); +} + +OptionalParseResult ExtensibleDialect::parseOptionalDynamicAttr( + StringRef attrName, AsmParser &parser, Attribute &resultAttr) const { + DynamicAttrDefinition *attrDef = lookupAttrDefinition(attrName); + if (!attrDef) + return llvm::None; + + DynamicAttr dynAttr; + if (DynamicAttr::parse(parser, attrDef, dynAttr)) + return failure(); + resultAttr = dynAttr; + return success(); +} + +LogicalResult ExtensibleDialect::printIfDynamicAttr(Attribute attribute, + AsmPrinter &printer) { + if (auto dynAttr = attribute.dyn_cast()) { + dynAttr.print(printer); + return success(); + } + return failure(); +} diff --git a/mlir/lib/TableGen/Dialect.cpp 
b/mlir/lib/TableGen/Dialect.cpp --- a/mlir/lib/TableGen/Dialect.cpp +++ b/mlir/lib/TableGen/Dialect.cpp @@ -102,9 +102,14 @@ int prefix = def->getValueAsInt("emitAccessorPrefix"); if (prefix < 0 || prefix > static_cast(EmitPrefix::Both)) PrintFatalError(def->getLoc(), "Invalid accessor prefix value"); + return static_cast(prefix); } +bool Dialect::isExtensible() const { + return def->getValueAsBit("isExtensible"); +} + bool Dialect::operator==(const Dialect &other) const { return def == other.def; } diff --git a/mlir/test/IR/dynamic.mlir b/mlir/test/IR/dynamic.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/dynamic.mlir @@ -0,0 +1,126 @@ +// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics | FileCheck %s +// Verify that extensible dialects can register dynamic operations and types. + +//===----------------------------------------------------------------------===// +// Dynamic type +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @succeededDynamicTypeVerifier +func @succeededDynamicTypeVerifier() { + // CHECK: %{{.*}} = "unregistered_op"() : () -> !test.singleton_dyntype + "unregistered_op"() : () -> !test.singleton_dyntype + // CHECK-NEXT: "unregistered_op"() : () -> !test.pair_dyntype + "unregistered_op"() : () -> !test.pair_dyntype + // CHECK_NEXT: %{{.*}} = "unregistered_op"() : () -> !test.pair_dyntype, !test.singleton_dyntype> + "unregistered_op"() : () -> !test.pair_dyntype, !test.singleton_dyntype> + return +} + +// ----- + +func @failedDynamicTypeVerifier() { + // expected-error@+1 {{expected 0 type arguments, but had 1}} + "unregistered_op"() : () -> !test.singleton_dyntype + return +} + +// ----- + +func @failedDynamicTypeVerifier2() { + // expected-error@+1 {{expected 2 type arguments, but had 1}} + "unregistered_op"() : () -> !test.pair_dyntype + return +} + +// ----- + +// CHECK-LABEL: func @customTypeParserPrinter +func @customTypeParserPrinter() { + 
// CHECK: "unregistered_op"() : () -> !test.custom_assembly_format_dyntype + "unregistered_op"() : () -> !test.custom_assembly_format_dyntype + return +} + +// ----- + +//===----------------------------------------------------------------------===// +// Dynamic attribute +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @succeededDynamicAttributeVerifier +func @succeededDynamicAttributeVerifier() { + // CHECK: "unregistered_op"() {test_attr = #test.singleton_dynattr} : () -> () + "unregistered_op"() {test_attr = #test.singleton_dynattr} : () -> () + // CHECK-NEXT: "unregistered_op"() {test_attr = #test.pair_dynattr<3 : i32, 5 : i32>} : () -> () + "unregistered_op"() {test_attr = #test.pair_dynattr<3 : i32, 5 : i32>} : () -> () + // CHECK_NEXT: "unregistered_op"() {test_attr = #test.pair_dynattr<3 : i32, 5 : i32>} : () -> () + "unregistered_op"() {test_attr = #test.pair_dynattr<#test.pair_dynattr<3 : i32, 5 : i32>, f64>} : () -> () + return +} + +// ----- + +func @failedDynamicAttributeVerifier() { + // expected-error@+1 {{expected 0 attribute arguments, but had 1}} + "unregistered_op"() {test_attr = #test.singleton_dynattr} : () -> () + return +} + +// ----- + +func @failedDynamicAttributeVerifier2() { + // expected-error@+1 {{expected 2 attribute arguments, but had 1}} + "unregistered_op"() {test_attr = #test.pair_dynattr : () -> () + return +} + +// ----- + +// CHECK-LABEL: func @customAttributeParserPrinter +func @customAttributeParserPrinter() { + // CHECK: "unregistered_op"() {test_attr = #test.custom_assembly_format_dynattr} : () -> () + "unregistered_op"() {test_attr = #test.custom_assembly_format_dynattr} : () -> () + return +} + +//===----------------------------------------------------------------------===// +// Dynamic op +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @succeededDynamicOpVerifier +func 
@succeededDynamicOpVerifier(%a: f32) { + // CHECK: "test.generic_dynamic_op"() : () -> () + // CHECK-NEXT: %{{.*}} = "test.generic_dynamic_op"(%{{.*}}) : (f32) -> f64 + // CHECK-NEXT: %{{.*}}:2 = "test.one_operand_two_results"(%{{.*}}) : (f32) -> (f64, f64) + "test.generic_dynamic_op"() : () -> () + "test.generic_dynamic_op"(%a) : (f32) -> f64 + "test.one_operand_two_results"(%a) : (f32) -> (f64, f64) + return +} + +// ----- + +func @failedDynamicOpVerifier() { + // expected-error@+1 {{expected 1 operand, but had 0}} + "test.one_operand_two_results"() : () -> (f64, f64) + return +} + +// ----- + +func @failedDynamicOpVerifier2(%a: f32) { + // expected-error@+1 {{expected 2 results, but had 0}} + "test.one_operand_two_results"(%a) : (f32) -> () + return +} + +// ----- + +// CHECK-LABEL: func @customOpParserPrinter +func @customOpParserPrinter() { + // CHECK: test.custom_parser_printer_dynamic_op custom_keyword + test.custom_parser_printer_dynamic_op custom_keyword + return +} diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -15,12 +15,14 @@ #include "TestDialect.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/Types.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/bit.h" +#include "llvm/Support/ErrorHandling.h" using namespace mlir; using namespace test; @@ -216,6 +218,74 @@ #define GET_ATTRDEF_CLASSES #include "TestAttrDefs.cpp.inc" +//===----------------------------------------------------------------------===// +// Dynamic Attributes +//===----------------------------------------------------------------------===// + +/// Define a singleton dynamic attribute, which takes no parameters.
+static std::unique_ptr +getSingletonDynamicAttr(TestDialect *testDialect) { + return DynamicAttrDefinition::get( + "singleton_dynattr", testDialect, + [](function_ref emitError, + ArrayRef args) { + if (!args.empty()) { + emitError() << "expected 0 attribute arguments, but had " + << args.size(); + return failure(); + } + return success(); + }); +} + +/// Define a dynamic attribute representing a pair of attributes. +static std::unique_ptr +getPairDynamicAttr(TestDialect *testDialect) { + return DynamicAttrDefinition::get( + "pair_dynattr", testDialect, + [](function_ref emitError, + ArrayRef args) { + if (args.size() != 2) { + emitError() << "expected 2 attribute arguments, but had " + << args.size(); + return failure(); + } + return success(); + }); +} + +static std::unique_ptr +getCustomAssemblyFormatDynamicAttr(TestDialect *testDialect) { + auto verifier = [](function_ref emitError, + ArrayRef args) { + if (args.size() != 2) { + emitError() << "expected 2 attribute arguments, but had " << args.size(); + return failure(); + } + return success(); + }; + + auto parser = [](AsmParser &parser, + llvm::SmallVectorImpl &parsedParams) { + Attribute leftAttr, rightAttr; + if (parser.parseLess() || parser.parseAttribute(leftAttr) || + parser.parseColon() || parser.parseAttribute(rightAttr) || + parser.parseGreater()) + return failure(); + parsedParams.push_back(leftAttr); + parsedParams.push_back(rightAttr); + return success(); + }; + + auto printer = [](AsmPrinter &printer, ArrayRef params) { + printer << "<" << params[0] << ":" << params[1] << ">"; + }; + + return DynamicAttrDefinition::get("custom_assembly_format_dynattr", + testDialect, std::move(verifier), + std::move(parser), std::move(printer)); +} + //===----------------------------------------------------------------------===// // TestDialect //===----------------------------------------------------------------------===// @@ -225,4 +295,7 @@ #define GET_ATTRDEF_LIST #include "TestAttrDefs.cpp.inc" >(); +
registerDynamicAttr(getSingletonDynamicAttr(this)); + registerDynamicAttr(getPairDynamicAttr(this)); + registerDynamicAttr(getCustomAssemblyFormatDynamicAttr(this)); } diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -24,6 +24,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/RegionKindInterface.h" diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Verifier.h" @@ -207,6 +208,62 @@ } // namespace +//===----------------------------------------------------------------------===// +// Dynamic operations +//===----------------------------------------------------------------------===// + +/// Define a generic dynamic operation; its verifiers accept any operation. +static std::unique_ptr +getGenericDynamicOp(TestDialect *dialect) { + return DynamicOpDefinition::get( + "generic_dynamic_op", dialect, [](Operation *op) { return success(); }, + [](Operation *op) { return success(); }); +} + +/// Define a dynamic operation that must have one operand and two results. 
+static std::unique_ptr +getOneOperandTwoResultsDynamicOp(TestDialect *dialect) { + return DynamicOpDefinition::get( + "one_operand_two_results", dialect, + [](Operation *op) { + if (op->getNumOperands() != 1) { + op->emitOpError() + << "expected 1 operand, but had " << op->getNumOperands(); + return failure(); + } + if (op->getNumResults() != 2) { + op->emitOpError() + << "expected 2 results, but had " << op->getNumResults(); + return failure(); + } + return
success(); + }, + [](Operation *op) { return success(); }); +} + +/// Define a dynamic operation with a custom parser and printer. +static std::unique_ptr +getCustomParserPrinterDynamicOp(TestDialect *dialect) { + auto verifier = [](Operation *op) { + if (op->getNumOperands() == 0 && op->getNumResults() == 0) + return success(); + op->emitError() << "operation should have no operands and no results"; + return failure(); + }; + auto regionVerifier = [](Operation *op) { return success(); }; + + auto parser = [](OpAsmParser &parser, OperationState &state) { + return parser.parseKeyword("custom_keyword"); + }; + + auto printer = [](Operation *op, OpAsmPrinter &printer, llvm::StringRef) { + printer << op->getName() << " custom_keyword"; + }; + + return DynamicOpDefinition::get("custom_parser_printer_dynamic_op", dialect, + verifier, regionVerifier, parser, printer); +} + //===----------------------------------------------------------------------===// // TestDialect //===----------------------------------------------------------------------===// @@ -241,6 +298,10 @@ #define GET_OP_LIST #include "TestOps.cpp.inc" >(); + registerDynamicOp(getGenericDynamicOp(this)); + registerDynamicOp(getOneOperandTwoResultsDynamicOp(this)); + registerDynamicOp(getCustomParserPrinterDynamicOp(this)); + addInterfaces(); allowUnknownOperations(); diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td --- a/mlir/test/lib/Dialect/Test/TestDialect.td +++ b/mlir/test/lib/Dialect/Test/TestDialect.td @@ -23,6 +23,8 @@ let hasOperationInterfaceFallback = 1; let hasNonDefaultDestructor = 1; let useDefaultTypePrinterParser = 0; + let useDefaultAttributePrinterParser = 1; + let isExtensible = 1; let dependentDialects = ["::mlir::DLTIDialect"]; let extraClassDeclaration = [{ @@ -43,6 +45,10 @@ // Storage for a custom fallback interface.
void *fallbackEffectOpInterfaces; + ::mlir::Type parseTestType(::mlir::AsmParser &parser, + ::llvm::SetVector<::mlir::Type> &stack) const; + void printTestType(::mlir::Type type, ::mlir::AsmPrinter &printer, + ::llvm::SetVector<::mlir::Type> &stack) const; }]; } diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/Types.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/SetVector.h" @@ -315,6 +316,72 @@ return 1; } +//===----------------------------------------------------------------------===// +// Dynamic Types +//===----------------------------------------------------------------------===// + +/// Define a singleton dynamic type, which takes no parameters. +static std::unique_ptr +getSingletonDynamicType(TestDialect *testDialect) { + return DynamicTypeDefinition::get( + "singleton_dyntype", testDialect, + [](function_ref emitError, + ArrayRef args) { + if (!args.empty()) { + emitError() << "expected 0 type arguments, but had " << args.size(); + return failure(); + } + return success(); + }); +} + +/// Define a dynamic type representing a pair (exactly two parameters).
+static std::unique_ptr +getPairDynamicType(TestDialect *testDialect) { + return DynamicTypeDefinition::get( + "pair_dyntype", testDialect, + [](function_ref emitError, + ArrayRef args) { + if (args.size() != 2) { + emitError() << "expected 2 type arguments, but had " << args.size(); + return failure(); + } + return success(); + }); +} + +static std::unique_ptr +getCustomAssemblyFormatDynamicType(TestDialect *testDialect) { + auto verifier = [](function_ref emitError, + ArrayRef args) { + if (args.size() != 2) { + emitError() << "expected 2 type arguments, but had " << args.size(); + return failure(); + } + return success(); + }; + + auto parser = [](AsmParser &parser, + llvm::SmallVectorImpl &parsedParams) { + Attribute leftAttr, rightAttr; + if (parser.parseLess() || parser.parseAttribute(leftAttr) || + parser.parseColon() || parser.parseAttribute(rightAttr) || + parser.parseGreater()) + return failure(); + parsedParams.push_back(leftAttr); + parsedParams.push_back(rightAttr); + return success(); + }; + + auto printer = [](AsmPrinter &printer, ArrayRef params) { + printer << "<" << params[0] << ":" << params[1] << ">"; + }; + + return DynamicTypeDefinition::get("custom_assembly_format_dyntype", + testDialect, std::move(verifier), + std::move(parser), std::move(printer)); +} + //===----------------------------------------------------------------------===// // TestDialect //===----------------------------------------------------------------------===// @@ -332,9 +399,14 @@ #include "TestTypeDefs.cpp.inc" >(); SimpleAType::attachInterface(*getContext()); + + registerDynamicType(getSingletonDynamicType(this)); + registerDynamicType(getPairDynamicType(this)); + registerDynamicType(getCustomAssemblyFormatDynamicType(this)); } -static Type parseTestType(AsmParser &parser, SetVector &stack) { +Type TestDialect::parseTestType(AsmParser &parser, + SetVector &stack) const { StringRef typeTag; if (failed(parser.parseKeyword(&typeTag))) return Type(); @@ -346,6 +418,16 @@
return genType; } + { + Type dynType; + auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType); + if (parseResult.hasValue()) { + if (succeeded(parseResult.getValue())) + return dynType; + return Type(); + } + } + if (typeTag != "test_rec") { parser.emitError(parser.getNameLoc()) << "unknown type!"; return Type(); @@ -381,11 +463,14 @@ return parseTestType(parser, stack); } -static void printTestType(Type type, AsmPrinter &printer, - SetVector &stack) { +void TestDialect::printTestType(Type type, AsmPrinter &printer, + SetVector &stack) const { if (succeeded(generatedTypePrinter(type, printer))) return; + if (succeeded(printIfDynamicType(type, printer))) + return; + auto rec = type.cast(); printer << "test_rec<" << rec.getName(); if (!stack.contains(rec)) { diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -689,6 +689,8 @@ /// The code block for default attribute parser/printer dispatch boilerplate. /// {0}: the dialect fully qualified class name. +/// {1}: dynamic attribute parser dispatch code (empty if not extensible). +/// {2}: dynamic attribute printer dispatch code (empty if not extensible). static const char *const dialectDefaultAttrPrinterParserDispatch = R"( /// Parse an attribute registered to this dialect. ::mlir::Attribute {0}::parseAttribute(::mlir::DialectAsmParser &parser, @@ -703,6 +705,7 @@ if (parseResult.hasValue()) return attr; } + {1} parser.emitError(typeLoc) << "unknown attribute `" << attrTag << "` in dialect `" << getNamespace() << "`"; return {{}; @@ -712,11 +715,33 @@ ::mlir::DialectAsmPrinter &printer) const {{ if (::mlir::succeeded(generatedAttributePrinter(attr, printer))) return; + {2} } )"; +/// The code block for dynamic attribute parser dispatch boilerplate.
+static const char *const dialectDynamicAttrParserDispatch = R"( + { + ::mlir::Attribute genAttr; + auto parseResult = parseOptionalDynamicAttr(attrTag, parser, genAttr); + if (parseResult.hasValue()) { + if (::mlir::succeeded(parseResult.getValue())) + return genAttr; + return ::mlir::Attribute(); + } + } +)"; + +/// The code block for dynamic attribute printer dispatch boilerplate. +static const char *const dialectDynamicAttrPrinterDispatch = R"( + if (::mlir::succeeded(printIfDynamicAttr(attr, printer))) + return; +)"; + /// The code block for default type parser/printer dispatch boilerplate. /// {0}: the dialect fully qualified class name. +/// {1}: dynamic type parser dispatch code (empty if not extensible). +/// {2}: dynamic type printer dispatch code (empty if not extensible). static const char *const dialectDefaultTypePrinterParserDispatch = R"( /// Parse a type registered to this dialect. ::mlir::Type {0}::parseType(::mlir::DialectAsmParser &parser) const {{ @@ -728,6 +753,7 @@ auto parseResult = generatedTypeParser(parser, mnemonic, genType); if (parseResult.hasValue()) return genType; + {1} parser.emitError(typeLoc) << "unknown type `" << mnemonic << "` in dialect `" << getNamespace() << "`"; return {{}; @@ -737,9 +763,28 @@ ::mlir::DialectAsmPrinter &printer) const {{ if (::mlir::succeeded(generatedTypePrinter(type, printer))) return; + {2} } )"; +/// The code block for dynamic type parser dispatch boilerplate. +static const char *const dialectDynamicTypeParserDispatch = R"( + { + auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType); + if (parseResult.hasValue()) { + if (::mlir::succeeded(parseResult.getValue())) + return genType; + return ::mlir::Type(); + } + } +)"; + +/// The code block for dynamic type printer dispatch boilerplate. +static const char *const dialectDynamicTypePrinterDispatch = R"( + if (::mlir::succeeded(printIfDynamicType(type, printer))) + return; +)"; + /// Emit the dialect printer/parser dispatcher.
User's code should call these /// functions from their dialect's print/parse methods. void DefGenerator::emitParsePrintDispatch(ArrayRef defs) { @@ -839,16 +884,30 @@ if (valueType == "Attribute" && needsDialectParserPrinter && firstDialect.useDefaultAttributePrinterParser()) { NamespaceEmitter nsEmitter(os, firstDialect); - os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch, - firstDialect.getCppClassName()); + if (firstDialect.isExtensible()) { + os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch, + firstDialect.getCppClassName(), + dialectDynamicAttrParserDispatch, + dialectDynamicAttrPrinterDispatch); + } else { + os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch, + firstDialect.getCppClassName(), "", ""); + } } // Emit the default parser/printer for Types if the dialect asked for it. if (valueType == "Type" && needsDialectParserPrinter && firstDialect.useDefaultTypePrinterParser()) { NamespaceEmitter nsEmitter(os, firstDialect); - os << llvm::formatv(dialectDefaultTypePrinterParserDispatch, - firstDialect.getCppClassName()); + if (firstDialect.isExtensible()) { + os << llvm::formatv(dialectDefaultTypePrinterParserDispatch, + firstDialect.getCppClassName(), + dialectDynamicTypeParserDispatch, + dialectDynamicTypePrinterDispatch); + } else { + os << llvm::formatv(dialectDefaultTypePrinterParserDispatch, + firstDialect.getCppClassName(), "", ""); + } } return false; diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp --- a/mlir/tools/mlir-tblgen/DialectGen.cpp +++ b/mlir/tools/mlir-tblgen/DialectGen.cpp @@ -87,8 +87,9 @@ /// /// {0}: The name of the dialect class. /// {1}: The dialect namespace. +/// {2}: The dialect parent class name (`Dialect` or `ExtensibleDialect`). static const char *const dialectDeclBeginStr = R"( -class {0} : public ::mlir::Dialect { +class {0} : public ::mlir::{2} { explicit {0}(::mlir::MLIRContext *context); void initialize(); @@ -189,7 +190,10 @@ // Emit the start of the decl.
std::string cppName = dialect.getCppClassName(); - os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName()); + StringRef superClassName = + dialect.isExtensible() ? "ExtensibleDialect" : "Dialect"; + os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(), + superClassName); // Check for any attributes/types registered to this dialect. If there are, // add the hooks for parsing/printing. @@ -250,9 +254,10 @@ /// {0}: The name of the dialect class. /// {1}: initialization code that is emitted in the ctor body before calling /// initialize(). +/// {2}: The dialect parent class name (`Dialect` or `ExtensibleDialect`). static const char *const dialectConstructorStr = R"( {0}::{0}(::mlir::MLIRContext *context) - : ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {{ + : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {{ {1} initialize(); } @@ -287,8 +292,10 @@ } // Emit the constructor and destructor. + StringRef superClassName = + dialect.isExtensible() ? "ExtensibleDialect" : "Dialect"; os << llvm::formatv(dialectConstructorStr, cppClassName, - dependentDialectRegistrations); + dependentDialectRegistrations, superClassName); if (!dialect.hasNonDefaultDestructor()) os << llvm::formatv(dialectDestructorStr, cppClassName); }