diff --git a/mlir/docs/ExtensibleDialects.md b/mlir/docs/ExtensibleDialects.md new file mode 100644 --- /dev/null +++ b/mlir/docs/ExtensibleDialects.md @@ -0,0 +1,265 @@ +# Extensible dialects + +This file documents the design and API of the extensible dialects. Extensible +dialects are dialects that can be extended with new operations and types defined +at runtime. This lets users define dialects with metaprogramming, or from another +language, without having to recompile C++ code. + +[TOC] + +## Usage + +### Making a dialect extensible at runtime + +Dialects defined in C++ can be extended with new operations and types at runtime +by making them inherit from `mlir::ExtensibleDialect` instead of `mlir::Dialect` +(note that `ExtensibleDialect` inherits from `Dialect`). The `ExtensibleDialect` +class contains the necessary fields and methods to extend the dialect at +runtime with new types and operations. + +```c++ +class MyDialect : public mlir::ExtensibleDialect { + ... +}; +``` + +For dialects defined in TableGen, this is done by setting the `isExtensible` +flag to `1`. + +```tablegen +def Test_Dialect : Dialect { + let isExtensible = 1; + ... +} +``` + +An extensible `Dialect` can be cast back to `ExtensibleDialect` using +`llvm::dyn_cast`, or `llvm::cast`: + +```c++ +if (auto extensibleDialect = llvm::dyn_cast<ExtensibleDialect>(dialect)) { + ... +} +``` + +### Defining an operation at runtime + +The `DynamicOpDefinition` class represents the definition of an operation +defined at runtime. It is created using the `DynamicOpDefinition::get` +functions. An operation defined at runtime needs to define a name, a dialect in +which the operation will be registered, and an operation verifier. It can also +optionally define a custom parser and a printer, an operation fold hook, and a +function that returns the canonicalization patterns. + +```c++ +// The operation name, without the dialect name prefix. +StringRef name = "my_operation_name"; + +// The dialect defining the operation. 
+Dialect* dialect = ctx->getOrLoadDialect<MyDialect>(); + +// Operation verifier definition. +AbstractOperation::VerifyInvariantsFn verifyFn = [](Operation* op) { + // Logic for the operation verification. + ... +}; + +// Parser function definition. +AbstractOperation::ParseAssemblyFn parseFn = + [](OpAsmParser &parser, OperationState &state) { + // Parse the operation, given that the name is already parsed. + ... +}; + +// Printer function definition. +auto printFn = [](Operation *op, OpAsmPrinter &printer) { + printer << op->getName(); + // Print the operation, given that the name is already printed. + ... +}; + +// General folder implementation, see AbstractOperation::foldHook for more +// information. +auto foldHookFn = [](Operation * op, ArrayRef<Attribute> operands, + SmallVectorImpl<OpFoldResult> &result) { + ... +}; + +// Returns any canonicalization pattern rewrites that the operation +// supports, for use by the canonicalization pass. +auto getCanonicalizationPatterns = + [](RewritePatternSet &results, MLIRContext *context) { + ... +}; + +// Definition of the operation. +std::unique_ptr<DynamicOpDefinition> opDef = + DynamicOpDefinition::get(name, dialect, std::move(verifyFn), + std::move(parseFn), std::move(printFn), std::move(foldHookFn), + std::move(getCanonicalizationPatterns)); +``` + +Once the operation is defined, it can be registered by an `ExtensibleDialect`: + +```c++ +extensibleDialect->addDynamicOp(std::move(opDef)); +``` + +Note that the `Dialect` given to the operation should be the one registering +the operation. + +### Using an operation defined at runtime + +It is possible to match on an operation defined at runtime using its name: + +```c++ +if (op->getName().getStringRef() == "my_dialect.my_dynamic_op") { + ... +} +``` + +An operation defined at runtime can be created by creating an `OperationState` +with its name, and passing it to a rewriter such as a `PatternRewriter`.
+ +```c++ +OperationState state(location, "my_dialect.my_dynamic_op", + operands, resultTypes, attributes); + +rewriter.createOperation(state); +``` + + +### Defining a type at runtime + +Contrary to types defined in C++ or in TableGen, types defined at runtime can +only have a list of `Attribute` as parameters. + +Similarly to operations, a type is defined at runtime using the class +`DynamicTypeDefinition`, which is created using the `DynamicTypeDefinition::get` +functions. A type definition requires a name, the dialect that will register the +type, and a parameter verifier. It can also optionally define a custom parser +and printer for the parameters (the type name is assumed to be already +parsed/printed). + +```c++ +// The type name, without the dialect name prefix. +StringRef name = "my_type_name"; + +// The dialect defining the type. +Dialect* dialect = ctx->getOrLoadDialect<MyDialect>(); + +// The type verifier. A type defined at runtime has a list of attributes as parameters. +auto verifier = [](function_ref<InFlightDiagnostic()> emitError, + ArrayRef<Attribute> args) { + ... +}; + +// The type parameters parser. +auto parser = [](DialectAsmParser &parser, + llvm::SmallVectorImpl<Attribute> &parsedParams) { + ... +}; + +// The type parameters printer. +auto printer = [](DialectAsmPrinter &printer, ArrayRef<Attribute> params) { + ... +}; + +std::unique_ptr<DynamicTypeDefinition> typeDef = + DynamicTypeDefinition::get(name, dialect, + std::move(verifier), std::move(parser), std::move(printer)); +``` + +If the printer and the parser are omitted, a default parser and printer are +generated with the format `!dialect.typename<arg1, arg2, ...>` (or `!dialect.typename` when there are no parameters). + +The type can then be registered by the `ExtensibleDialect`: + +```c++ +extensibleDialect->addDynamicType(std::move(typeDef)); +``` + +### Parsing types defined at runtime in an extensible dialect + +In order to parse types defined at runtime, it is necessary to add support for +them in the `MyDialect::parseType` method. + +```c++ +Type MyDialect::parseType(DialectAsmParser &parser) const { + ...
+ // The type name. + StringRef typeTag; + if (failed(parser.parseKeyword(&typeTag))) + return Type(); + + // Try to parse a dynamic type with 'typeTag' name. + Type dynType; + auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType); + if (parseResult.hasValue()) { + if (succeeded(parseResult.getValue())) + return dynType; + return Type(); + } +``` + +### Using a type defined at runtime + +Dynamic types are instances of `DynamicType`. It is possible to get a dynamic +type with `DynamicType::get` and `ExtensibleDialect::lookupTypeDefinition`. + +```c++ +auto typeDef = extensibleDialect->lookupTypeDefinition("my_dynamic_type"); +ArrayRef<Attribute> params = ...; +auto type = DynamicType::get(typeDef, params); +``` + +It is also possible to cast a `Type` known to be defined at runtime to a +`DynamicType`. + +```c++ +auto dynType = type.cast<DynamicType>(); +auto typeDef = dynType.getTypeDef(); +auto args = dynType.getParams(); +``` + +## Implementation details + +### Extensible dialect + +The role of extensible dialects is to own the necessary data for the operations +and types defined at runtime. They also contain the necessary accessors to +easily access them. + +In order to cast a `Dialect` back to an `ExtensibleDialect`, we implement the +`IsExtensibleDialect` interface for all `ExtensibleDialect`s. The casting is done +by checking if the `Dialect` implements `IsExtensibleDialect` or not. + +### Operation representation and registration + +Operations are represented in MLIR using the `AbstractOperation` class. They are +registered in dialects the same way operations defined in C++ are registered, +which is by calling `AbstractOperation::insert`. + +The only difference is that a new `TypeID` needs to be created for each +operation, since operations are not represented by a C++ class. This is done +using a `TypeIDAllocator`, which can allocate a new unique `TypeID` at runtime.
+ +### Type representation and registration + +Unlike operations, types need to define a C++ storage class that takes care of +type parameters. They also need to define another C++ class to access that +storage. `DynamicTypeStorage` defines the storage of types defined at runtime, +and `DynamicType` gives access to the storage, as well as defining useful +functions. A `DynamicTypeStorage` contains a list of `Attribute` type +parameters, as well as a pointer to the type definition. + +Types are registered using the `Dialect::addType` method, which expects a +`TypeID` that is generated using a `TypeIDAllocator`. The type uniquer also +registers the type with the given `TypeID`. This means that we can reuse our +single `DynamicType` class with different `TypeID`s to represent the different types +defined at runtime. + +Since the different types defined at runtime have different `TypeID`s, it is not +possible to use `TypeID` to cast a `Type` into a `DynamicType`. Thus, similar to +`Dialect`, all `DynamicType`s implement `IsDynamicTypeInterface`, so casting a `Type` +to a `DynamicType` boils down to checking for the `IsDynamicTypeInterface` interface.
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt --- a/mlir/include/mlir/IR/CMakeLists.txt +++ b/mlir/include/mlir/IR/CMakeLists.txt @@ -38,6 +38,11 @@ mlir_tablegen(SubElementTypeInterfaces.cpp.inc -gen-type-interface-defs) add_public_tablegen_target(MLIRSubElementInterfacesIncGen) +set(LLVM_TARGET_DEFINITIONS IsDynamicInterfaces.td) +mlir_tablegen(IsDynamicTypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(IsDynamicTypeInterfaces.cpp.inc -gen-type-interface-defs) +add_public_tablegen_target(MLIRIsDynamicTypeInterfaces) + set(LLVM_TARGET_DEFINITIONS TensorEncoding.td) mlir_tablegen(TensorEncInterfaces.h.inc -gen-attr-interface-decls) mlir_tablegen(TensorEncInterfaces.cpp.inc -gen-attr-interface-defs) diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/ExtensibleDialect.h @@ -0,0 +1,315 @@ +//===- ExtensibleDialect.h - Extensible dialect -----------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Dialects that can register new operations/types/attributes at runtime. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_EXTENSIBLEDIALECT_H +#define MLIR_IR_EXTENSIBLEDIALECT_H + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectInterface.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "llvm/ADT/StringMap.h" + +namespace mlir { +class MLIRContext; +class DialectAsmPrinter; +class DialectAsmParser; +class ParseResult; +class OptionalParseResult; +class ExtensibleDialect; + +namespace detail { +struct DynamicTypeStorage; +} // namespace detail + +//===----------------------------------------------------------------------===// +// Dynamic type +//===----------------------------------------------------------------------===// + +class DynamicType; + +/// This is the definition of a dynamic type. It stores the parser, +/// the printer, and the verifier. +/// Each dynamic type instance refers to one instance of this class. +class DynamicTypeDefinition { +public: + using VerifierFn = llvm::unique_function, ArrayRef) const>; + using ParserFn = llvm::unique_function &parsedAttributes) const>; + using PrinterFn = llvm::unique_function params) const>; + + static std::unique_ptr + get(llvm::StringRef name, Dialect *dialect, VerifierFn &&verifier); + + static std::unique_ptr + get(llvm::StringRef name, Dialect *dialect, VerifierFn &&verifier, + ParserFn &&parser, PrinterFn &&printer); + + /// Check that the type parameters are valid. + LogicalResult verify(function_ref emitError, + ArrayRef params) const { + return verifier(emitError, params); + } + + /// Get the unique identifier associated with the concrete type. + TypeID getTypeID() const { return typeID; } + + /// Get the MLIRContext in which the dynamic types are uniqued. + MLIRContext &getContext() const { return *ctx; } + + /// Get the name of the type, in the format 'typename' and + /// not 'dialectname.typename'. 
+ StringRef getName() const { return name; } + + /// Get the dialect defining the type. + Dialect *getDialect() const { return dialect; } + +private: + DynamicTypeDefinition(llvm::StringRef name, Dialect *dialect, + VerifierFn &&verifier, ParserFn &&parser, + PrinterFn &&printer); + + /// This constructor should only be used when we need a pointer to + /// the DynamicTypeDefinition in the verifier, the parser, or the printer. + /// The verifier, parser, and printer thus need to be initialized after the + /// constructor. + DynamicTypeDefinition(Dialect *dialect, llvm::StringRef name); + + /// Register the concrete type in the type Uniquer. + void registerInTypeUniquer(); + + /// Name of the type, without the dialect name prefix (see getName()). + std::string name; + + /// Dialect in which this type is defined. + Dialect *dialect; + + /// Verifier for the type parameters. + VerifierFn verifier; + + /// Parse the type parameters. + ParserFn parser; + + /// Print the type parameters. + PrinterFn printer; + + /// Unique identifier for the concrete type. + TypeID typeID; + + /// Context in which the concrete types are uniqued. + MLIRContext *ctx; + + friend ExtensibleDialect; + friend DynamicType; +}; + +/// A type defined at runtime. +/// All types defined at runtime share this class; they differ by TypeID. +class DynamicType + : public Type::TypeBase { +public: + // Inherit Base constructors. + using Base::Base; + + /// Get an instance of a dynamic type given a dynamic type definition and + /// type parameters. + /// This function does not call the type verifier. + static DynamicType get(DynamicTypeDefinition *typeDef, + ArrayRef params = {}); + + /// Get an instance of a dynamic type given a dynamic type definition and type + /// parameters. + /// This function also calls the verifier to check if the parameters are valid. 
+ static DynamicType getChecked(function_ref emitError, + DynamicTypeDefinition *typeDef, + ArrayRef params = {}); + + /// Get the type definition of the concrete type. + DynamicTypeDefinition *getTypeDef(); + + /// Get the type parameters. + ArrayRef getParams(); + + /// Check if a type is a specific dynamic type. + static bool isa(Type type, DynamicTypeDefinition *typeDef) { + return type.getTypeID() == typeDef->getTypeID(); + } + + /// Check if a type is a dynamic type. + static bool classof(Type type); + + /// Parse the dynamic type parameters and construct the type. + /// The parameters are either empty, and nothing is parsed, + /// or they are in the format '<>' or '<param1, param2, ...>'. + static ParseResult parse(DialectAsmParser &parser, + DynamicTypeDefinition *typeDef, + DynamicType &parsedType); + + /// Print the dynamic type with the format + /// 'type' or 'type<>' if there are no parameters, or 'type<param1, param2, ...>'. + void print(DialectAsmPrinter &printer); +}; + +//===----------------------------------------------------------------------===// +// Dynamic operation +//===----------------------------------------------------------------------===// + +/// The definition of a dynamic operation. +/// It contains the name of the operation, its owning dialect, a verifier, +/// a printer, a parser, a fold hook, and a canonicalization patterns hook. 
+class DynamicOpDefinition { +public: + static std::unique_ptr + get(StringRef name, Dialect *dialect, + AbstractOperation::VerifyInvariantsFn &&verifyFn); + + static std::unique_ptr + get(StringRef name, Dialect *dialect, + AbstractOperation::VerifyInvariantsFn &&verifyFn, + AbstractOperation::ParseAssemblyFn &&parseFn, + AbstractOperation::PrintAssemblyFn &&printFn); + + static std::unique_ptr + get(StringRef name, Dialect *dialect, + AbstractOperation::VerifyInvariantsFn &&verifyFn, + AbstractOperation::ParseAssemblyFn &&parseFn, + AbstractOperation::PrintAssemblyFn &&printFn, + AbstractOperation::FoldHookFn &&foldHookFn, + AbstractOperation::GetCanonicalizationPatternsFn + &&getCanonicalizationPatternsFn); + + void setVerifyFn(AbstractOperation::VerifyInvariantsFn &&verify) { + verifyFn = std::move(verify); + } + + void setParseFn(AbstractOperation::ParseAssemblyFn &&parse) { + parseFn = std::move(parse); + } + + void setPrintFn(AbstractOperation::PrintAssemblyFn &&print) { + printFn = std::move(print); + } + + void setFoldHookFn(AbstractOperation::FoldHookFn &&foldHook) { + foldHookFn = std::move(foldHook); + } + + void setGetCanonicalizationPatternsFn( + AbstractOperation::GetCanonicalizationPatternsFn + &&getCanonicalizationPatterns) { + getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns); + } + +private: + DynamicOpDefinition(StringRef name, Dialect *dialect, + AbstractOperation::VerifyInvariantsFn &&verifyFn, + AbstractOperation::ParseAssemblyFn &&parseFn, + AbstractOperation::PrintAssemblyFn &&printFn, + AbstractOperation::FoldHookFn &&foldHookFn, + AbstractOperation::GetCanonicalizationPatternsFn + &&getCanonicalizationPatternsFn); + + /// Unique identifier for this operation. + TypeID typeID; + + /// Name of the operation. + /// The name is prefixed with the dialect name. + std::string name; + + /// Dialect defining this operation. 
+ Dialect *dialect; + + AbstractOperation::VerifyInvariantsFn verifyFn; + AbstractOperation::ParseAssemblyFn parseFn; + AbstractOperation::PrintAssemblyFn printFn; + AbstractOperation::FoldHookFn foldHookFn; + AbstractOperation::GetCanonicalizationPatternsFn + getCanonicalizationPatternsFn; + + friend ExtensibleDialect; +}; + +//===----------------------------------------------------------------------===// +// Extensible dialect +//===----------------------------------------------------------------------===// + +/// A dialect that can be extended with new operations and types at +/// runtime. +class ExtensibleDialect : public mlir::Dialect { +public: + ExtensibleDialect(StringRef name, MLIRContext *ctx, TypeID typeID); + + /// Add a new type defined at runtime to the dialect. + void addDynamicType(std::unique_ptr &&type); + + /// Add a new operation defined at runtime to the dialect. + void addDynamicOp(std::unique_ptr &&type); + + /// Check if the dialect is an extensible dialect. + static bool classof(const mlir::Dialect *dialect); + + /// Look up a dynamic type by name; returns nullptr if not found. + DynamicTypeDefinition *lookupTypeDefinition(StringRef name) const { + auto it = nameToDynTypes.find(name); + if (it == nameToDynTypes.end()) + return nullptr; + return it->second; + } + + /// Look up a dynamic type by TypeID; returns nullptr if not found. + DynamicTypeDefinition *lookupTypeDefinition(TypeID id) const { + auto it = dynTypes.find(id); + if (it == dynTypes.end()) + return nullptr; + return it->second.get(); + } + +protected: + /// Parse the dynamic type 'typeName' defined in this dialect. + /// 'typeName' should not be prefixed with the dialect name. + /// If the dynamic type does not exist, return no value. + /// Otherwise, parse it, and return the parse result. + /// If the parsing succeeds, put the resulting type in 'resultType'. 
+ OptionalParseResult parseOptionalDynamicType(StringRef typeName, + DialectAsmParser &parser, + Type &resultType) const; + + /// If 'type' is a dynamic type, print it. + /// Returns success if the type was printed, and failure if the type was not a + /// dynamic type. + static LogicalResult printIfDynamicType(Type type, + DialectAsmPrinter &printer); + +private: + /// All registered dynamic type definitions, owned by the dialect, keyed by TypeID. + llvm::DenseMap> dynTypes; + + /// Non-owning map from type name to definition, for fast lookup by name. + llvm::StringMap nameToDynTypes; +}; +} // namespace mlir + +namespace llvm { +/// Provide isa functionality for ExtensibleDialect. +/// This is to override the isa functionality for Dialect. +template <> +struct isa_impl { + static inline bool doit(const ::mlir::Dialect &dialect) { + return mlir::ExtensibleDialect::classof(&dialect); + } +}; +} // namespace llvm + +#endif // MLIR_IR_EXTENSIBLEDIALECT_H diff --git a/mlir/include/mlir/IR/IsDynamicInterfaces.h b/mlir/include/mlir/IR/IsDynamicInterfaces.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/IsDynamicInterfaces.h @@ -0,0 +1,24 @@ +//===- IsDynamicInterfaces.h - Dynamic objects interfaces -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// Define an interface that is only implemented by dynamic types. +// The interface is used to check if a type is a DynamicType or not. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_ISDYNAMICINTERFACES_H +#define MLIR_IR_ISDYNAMICINTERFACES_H + +#include "mlir/IR/OpDefinition.h" + +//===----------------------------------------------------------------------===// +// Type interfaces +//===----------------------------------------------------------------------===// + +#include "mlir/IR/IsDynamicTypeInterfaces.h.inc" + +#endif // MLIR_IR_ISDYNAMICINTERFACES_H diff --git a/mlir/include/mlir/IR/IsDynamicInterfaces.td b/mlir/include/mlir/IR/IsDynamicInterfaces.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/IsDynamicInterfaces.td @@ -0,0 +1,29 @@ +//===- IsDynamicInterfaces.td - Dynamic objects interfaces -*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// Define an interface that is only implemented by dynamic types. +// The interface is used to check if a type is a DynamicType or not. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_ISDYNAMICINTERFACES +#define MLIR_IR_ISDYNAMICINTERFACES + +include "mlir/IR/OpBase.td" + +def IsDynamicTypeInterface : TypeInterface<"IsDynamicTypeInterface"> { + let cppNamespace = "::mlir"; + let description = [{ + This interface is implemented by all dynamic types, and should not be + implemented by any other type. + The interface is used to check if a type is a dynamic type or not. 
+ }]; + + let methods = []; +} + +#endif // MLIR_IR_ISDYNAMICINTERFACES \ No newline at end of file diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -282,6 +282,9 @@ // If this dialect overrides the hook for canonicalization patterns. bit hasCanonicalizer = 0; + + // If this dialect can be extended at runtime with new operations or types. + bit isExtensible = 0; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h --- a/mlir/include/mlir/IR/TypeSupport.h +++ b/mlir/include/mlir/IR/TypeSupport.h @@ -80,7 +80,9 @@ /// Return the unique identifier representing the concrete type class. TypeID getTypeID() const { return typeID; } -private: + /// This constructor should in general not be used directly; its use is + /// discouraged in favor of + /// 'AbstractType::get()'. AbstractType(Dialect &dialect, detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait, TypeID typeID) : dialect(dialect), interfaceMap(std::move(interfaceMap)), @@ -96,6 +98,7 @@ /// be found in the context. static AbstractType *lookupMutable(TypeID typeID, MLIRContext *context); +private: /// This is the dialect that this type was registered to. const Dialect &dialect; @@ -163,37 +166,48 @@ /// A utility class to get, or create, unique instances of types within an /// MLIRContext. This class manages all creation and uniquing of types. struct TypeUniquer { + /// Get an uniqued instance of a type T, using T's static TypeID. + template + static T get(MLIRContext *ctx, Args &&...args) { + return getWithTypeID(ctx, T::getTypeID(), + std::forward(args)...); + } + /// Get an uniqued instance of a parametric type T. + /// The use of this method is in general discouraged in favor of + /// 'get(ctx, args)'. 
+ template + static typename std::enable_if_t< !std::is_same::value, T> - get(MLIRContext *ctx, Args &&...args) { + getWithTypeID(MLIRContext *ctx, TypeID typeID, Args &&...args) { #ifndef NDEBUG - if (!ctx->getTypeUniquer().isParametricStorageInitialized(T::getTypeID())) + if (!ctx->getTypeUniquer().isParametricStorageInitialized(typeID)) llvm::report_fatal_error(llvm::Twine("can't create type '") + llvm::getTypeName() + "' because storage uniquer isn't initialized: " "the dialect was likely not loaded."); #endif return ctx->getTypeUniquer().get( - [&](TypeStorage *storage) { - storage->initialize(AbstractType::lookup(T::getTypeID(), ctx)); + [&, typeID](TypeStorage *storage) { + storage->initialize(AbstractType::lookup(typeID, ctx)); }, - T::getTypeID(), std::forward(args)...); + typeID, std::forward(args)...); } /// Get an uniqued instance of a singleton type T. + /// The use of this method is in general discouraged in favor of + /// 'get(ctx, args)'. template static typename std::enable_if_t< std::is_same::value, T> - get(MLIRContext *ctx) { + getWithTypeID(MLIRContext *ctx, TypeID typeID) { #ifndef NDEBUG - if (!ctx->getTypeUniquer().isSingletonStorageInitialized(T::getTypeID())) + if (!ctx->getTypeUniquer().isSingletonStorageInitialized(typeID)) llvm::report_fatal_error(llvm::Twine("can't create type '") + llvm::getTypeName() + "' because storage uniquer isn't initialized: " "the dialect was likely not loaded."); #endif - return ctx->getTypeUniquer().get(T::getTypeID()); + return ctx->getTypeUniquer().get(typeID); } /// Change the mutable component of the given type instance in the provided @@ -206,22 +220,32 @@ std::forward(args)...); } + /// Register a type T with the uniquer, using T's static TypeID. + template + static void registerType(MLIRContext *ctx) { + registerType(ctx, T::getTypeID()); + } + /// Register a parametric type instance T with the uniquer. + /// The use of this method is in general discouraged in favor of + /// 'registerType(ctx)'. 
+ template static typename std::enable_if_t< !std::is_same::value> - registerType(MLIRContext *ctx) { + registerType(MLIRContext *ctx, TypeID typeID) { ctx->getTypeUniquer().registerParametricStorageType( - T::getTypeID()); + typeID); } /// Register a singleton type instance T with the uniquer. + /// The use of this method is in general discouraged in favor of + /// 'registerType(ctx)'. template static typename std::enable_if_t< std::is_same::value> - registerType(MLIRContext *ctx) { + registerType(MLIRContext *ctx, TypeID typeID) { ctx->getTypeUniquer().registerSingletonStorageType( - T::getTypeID(), [&](TypeStorage *storage) { - storage->initialize(AbstractType::lookup(T::getTypeID(), ctx)); + typeID, [&ctx, typeID](TypeStorage *storage) { + storage->initialize(AbstractType::lookup(typeID, ctx)); }); } }; diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h --- a/mlir/include/mlir/TableGen/Dialect.h +++ b/mlir/include/mlir/TableGen/Dialect.h @@ -69,6 +69,10 @@ /// Returns true if this dialect has fallback interfaces for its operations. bool hasOperationInterfaceFallback() const; + /// Returns true if this dialect can be extended at runtime with new + /// operations or types. + bool isExtensible() const; + // Returns whether two dialects are equal by checking the equality of the // underlying record. 
bool operator==(const Dialect &other) const; diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt --- a/mlir/lib/IR/CMakeLists.txt +++ b/mlir/lib/IR/CMakeLists.txt @@ -11,9 +11,11 @@ Diagnostics.cpp Dialect.cpp Dominance.cpp + ExtensibleDialect.cpp FunctionImplementation.cpp FunctionSupport.cpp IntegerSet.cpp + IsDynamicInterfaces.cpp Location.cpp MLIRContext.cpp Operation.cpp diff --git a/mlir/lib/IR/ExtensibleDialect.cpp b/mlir/lib/IR/ExtensibleDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/IR/ExtensibleDialect.cpp @@ -0,0 +1,333 @@ +//===- ExtensibleDialect.cpp - Extensible dialect ---------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Dialects that can register new operations/types/attributes at runtime. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/ExtensibleDialect.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Identifier.h" +#include "mlir/IR/IsDynamicInterfaces.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Dynamic type +//===----------------------------------------------------------------------===// + +std::unique_ptr +DynamicTypeDefinition::get(llvm::StringRef name, Dialect *dialect, + VerifierFn &&verifier) { + auto *typeDef = new DynamicTypeDefinition(dialect, name); + + auto parser = [](DialectAsmParser &parser, + llvm::SmallVectorImpl &parsedParams) { + // No parameters + if (parser.parseOptionalLess() || !parser.parseOptionalGreater()) + return success(); + + Attribute attr; + if (parser.parseAttribute(attr)) + return failure(); + parsedParams.push_back(attr); + + while (parser.parseOptionalGreater()) { + Attribute attr; + if (parser.parseComma() || parser.parseAttribute(attr)) + return failure(); + parsedParams.push_back(attr); + } + + return success(); + }; + + auto printer = [](DialectAsmPrinter &printer, ArrayRef params) { + if (params.empty()) + return; + + printer << "<"; + llvm::interleaveComma(params, printer.getStream()); + printer << ">"; + }; + + typeDef->verifier = std::move(verifier); + typeDef->parser = std::move(parser); + typeDef->printer = std::move(printer); + + return std::unique_ptr(typeDef); +} + +std::unique_ptr +DynamicTypeDefinition::get(llvm::StringRef name, Dialect *dialect, + VerifierFn &&verifier, ParserFn &&parser, + PrinterFn &&printer) { + return std::unique_ptr( + new DynamicTypeDefinition(name, dialect, std::move(verifier), + std::move(parser), std::move(printer))); +} + +DynamicTypeDefinition::DynamicTypeDefinition(llvm::StringRef nameRef, + Dialect *dialect, + VerifierFn &&verifier, + ParserFn &&parser, + PrinterFn &&printer) + : name(nameRef), dialect(dialect), 
verifier(std::move(verifier)), + parser(std::move(parser)), printer(std::move(printer)), + typeID(dialect->getContext()->allocateTypeID()), + ctx(dialect->getContext()) { + assert(!nameRef.contains('.') && + "name should not be prefixed by the dialect name"); +} + +DynamicTypeDefinition::DynamicTypeDefinition(Dialect *dialect, + llvm::StringRef nameRef) + : name(nameRef), dialect(dialect), + typeID(dialect->getContext()->allocateTypeID()), + ctx(dialect->getContext()) { + assert(!nameRef.contains('.') && + "name should not be prefixed by the dialect name"); +} + +void DynamicTypeDefinition::registerInTypeUniquer() { + detail::TypeUniquer::registerType(&getContext(), getTypeID()); +} + +namespace mlir { +namespace detail { +/// Storage of DynamicType. +/// Contains a pointer to the type definition and type parameters. +struct DynamicTypeStorage : public TypeStorage { + + using KeyTy = std::pair>; + + explicit DynamicTypeStorage(DynamicTypeDefinition *typeDef, + ArrayRef params) + : typeDef(typeDef), params(params) {} + + bool operator==(const KeyTy &key) const { + return typeDef == key.first && params == key.second; + } + + static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_value(key); + } + + static DynamicTypeStorage *construct(TypeStorageAllocator &alloc, + const KeyTy &key) { + return new (alloc.allocate()) + DynamicTypeStorage(key.first, alloc.copyInto(key.second)); + } + + /// Definition of the type. + DynamicTypeDefinition *typeDef; + + /// The type parameters. 
+ ArrayRef params; +}; +} // namespace detail +} // namespace mlir + +DynamicType DynamicType::get(DynamicTypeDefinition *typeDef, + ArrayRef params) { + auto &ctx = typeDef->getContext(); + return detail::TypeUniquer::getWithTypeID( + &ctx, typeDef->getTypeID(), typeDef, params); +} + +DynamicType +DynamicType::getChecked(function_ref emitError, + DynamicTypeDefinition *typeDef, + ArrayRef params) { + if (failed(typeDef->verify(emitError, params))) + return {}; + return get(typeDef, params); +} + +DynamicTypeDefinition *DynamicType::getTypeDef() { return getImpl()->typeDef; } + +ArrayRef DynamicType::getParams() { return getImpl()->params; } + +bool DynamicType::classof(Type type) { + return type.isa(); +} + +ParseResult DynamicType::parse(DialectAsmParser &parser, + DynamicTypeDefinition *typeDef, + DynamicType &parsedType) { + llvm::SmallVector params; + if (failed(typeDef->parser(parser, params))) + return failure(); + auto emitError = [&]() { + return parser.emitError(parser.getCurrentLocation()); + }; + parsedType = DynamicType::getChecked(emitError, typeDef, params); + return success(); +} + +void DynamicType::print(DialectAsmPrinter &printer) { + printer << getTypeDef()->getName(); + getTypeDef()->printer(printer, getParams()); +} + +//===----------------------------------------------------------------------===// +// Dynamic operation +//===----------------------------------------------------------------------===// + +DynamicOpDefinition::DynamicOpDefinition( + StringRef name, Dialect *dialect, + AbstractOperation::VerifyInvariantsFn &&verifyFn, + AbstractOperation::ParseAssemblyFn &&parseFn, + AbstractOperation::PrintAssemblyFn &&printFn, + AbstractOperation::FoldHookFn &&foldHookFn, + AbstractOperation::GetCanonicalizationPatternsFn + &&getCanonicalizationPatternsFn) + : typeID(dialect->getContext()->allocateTypeID()), + name((dialect->getNamespace() + "." 
+ name).str()), dialect(dialect), + verifyFn(std::move(verifyFn)), parseFn(std::move(parseFn)), + printFn(std::move(printFn)), foldHookFn(std::move(foldHookFn)), + getCanonicalizationPatternsFn(std::move(getCanonicalizationPatternsFn)) { + assert(!name.contains('.') && + "name should not be prefixed by the dialect name"); +} + +std::unique_ptr +DynamicOpDefinition::get(StringRef name, Dialect *dialect, + AbstractOperation::VerifyInvariantsFn &&verifyFn) { + auto parseFn = [](OpAsmParser &parser, OperationState &result) { + parser.emitError(parser.getCurrentLocation(), + "dynamic operation do not define any parser function"); + return failure(); + }; + + auto printFn = [](Operation *op, OpAsmPrinter &printer) { + printer.printGenericOp(op); + }; + + return DynamicOpDefinition::get(name, dialect, std::move(verifyFn), + std::move(parseFn), std::move(printFn)); +} + +std::unique_ptr +DynamicOpDefinition::get(StringRef name, Dialect *dialect, + AbstractOperation::VerifyInvariantsFn &&verifyFn, + AbstractOperation::ParseAssemblyFn &&parseFn, + AbstractOperation::PrintAssemblyFn &&printFn) { + auto foldHookFn = [](mlir::Operation *op, + llvm::ArrayRef operands, + llvm::SmallVectorImpl &results) { + return failure(); + }; + + auto getCanonicalizationPatternsFn = [](OwningRewritePatternList &, + MLIRContext *) {}; + + return DynamicOpDefinition::get(name, dialect, std::move(verifyFn), + std::move(parseFn), std::move(printFn), + std::move(foldHookFn), + std::move(getCanonicalizationPatternsFn)); +} + +std::unique_ptr +DynamicOpDefinition::get(StringRef name, Dialect *dialect, + AbstractOperation::VerifyInvariantsFn &&verifyFn, + AbstractOperation::ParseAssemblyFn &&parseFn, + AbstractOperation::PrintAssemblyFn &&printFn, + AbstractOperation::FoldHookFn &&foldHookFn, + AbstractOperation::GetCanonicalizationPatternsFn + &&getCanonicalizationPatternsFn) { + return std::unique_ptr(new DynamicOpDefinition( + name, dialect, std::move(verifyFn), std::move(parseFn), + 
std::move(printFn), std::move(foldHookFn), + std::move(getCanonicalizationPatternsFn))); +} + +//===----------------------------------------------------------------------===// +// Extensible dialect +//===----------------------------------------------------------------------===// + +/// Interface that can only be implemented by extensible dialects. +/// The interface is used to check if a dialect is extensible or not. +class IsExtensibleDialect : public DialectInterface::Base { +public: + IsExtensibleDialect(Dialect *dialect) : Base(dialect) {} +}; + +ExtensibleDialect::ExtensibleDialect(StringRef name, MLIRContext *ctx, + TypeID typeID) + : Dialect(Identifier::get(name, ctx), ctx, typeID) { + addInterfaces(); +} + +void ExtensibleDialect::addDynamicType( + std::unique_ptr &&type) { + auto *typePtr = type.get(); + auto typeID = type->getTypeID(); + auto name = type->getName(); + auto *dialect = type->getDialect(); + + assert(dialect == this && + "trying to register a dynamic type in the wrong dialect"); + + // If a type with the same name is already defined, fail. + auto registered = dynTypes.try_emplace(typeID, std::move(type)).second; + assert(registered && "generated TypeID was not unique"); + + registered = nameToDynTypes.insert({name, typePtr}).second; + assert(registered && + "Trying to create a new dynamic type with an existing name"); + + auto interfaceMap = + detail::InterfaceMap::get>(); + auto abstractType = AbstractType(*dialect, std::move(interfaceMap), {}, typeID); + + /// Add the type to the dialect and the type uniquer. 
+ addType(typeID, std::move(abstractType)); + typePtr->registerInTypeUniquer(); +} + +void ExtensibleDialect::addDynamicOp( + std::unique_ptr &&op) { + assert(op->dialect == this && + "trying to register a dynamic op in the wrong dialect"); + auto hasTraitFn = [](TypeID traitId) { return false; }; + + AbstractOperation::insert( + op->name, *op->dialect, op->typeID, std::move(op->parseFn), + std::move(op->printFn), std::move(op->verifyFn), + std::move(op->foldHookFn), std::move(op->getCanonicalizationPatternsFn), + detail::InterfaceMap::get<>(), std::move(hasTraitFn), {}); +} + +bool ExtensibleDialect::classof(const Dialect *dialect) { + return const_cast(dialect) + ->getRegisteredInterface(); +} + +OptionalParseResult ExtensibleDialect::parseOptionalDynamicType( + StringRef typeName, DialectAsmParser &parser, Type &resultType) const { + auto *typeDef = lookupTypeDefinition(typeName); + if (typeDef) { + DynamicType dynType; + if (DynamicType::parse(parser, typeDef, dynType)) + return failure(); + resultType = dynType; + return {success()}; + } + + return {}; +} + +LogicalResult +ExtensibleDialect::printIfDynamicType(Type type, DialectAsmPrinter &printer) { + if (auto dynType = type.dyn_cast()) { + dynType.print(printer); + return success(); + } + return failure(); +} diff --git a/mlir/lib/IR/IsDynamicInterfaces.cpp b/mlir/lib/IR/IsDynamicInterfaces.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/IR/IsDynamicInterfaces.cpp @@ -0,0 +1,15 @@ +//===- IsDynamicInterfaces.cpp - Dynamic objects interfaces *- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// Define an interface that is only implemented on dynamic types. +// The interface is used to check if a type is a DynamicType or not. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/IsDynamicInterfaces.h" + +#include "mlir/IR/IsDynamicTypeInterfaces.cpp.inc" diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp --- a/mlir/lib/TableGen/Dialect.cpp +++ b/mlir/lib/TableGen/Dialect.cpp @@ -85,6 +85,10 @@ return def->getValueAsBit("hasOperationInterfaceFallback"); } +bool Dialect::isExtensible() const { + return def->getValueAsBit("isExtensible"); +} + bool Dialect::operator==(const Dialect &other) const { return def == other.def; } diff --git a/mlir/test/IR/dynamic.mlir b/mlir/test/IR/dynamic.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/dynamic.mlir @@ -0,0 +1,84 @@ +// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics | FileCheck %s +// Verify that extensible dialects can register dynamic operations and types. + +//===----------------------------------------------------------------------===// +// Dynamic type +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @succeededDynamicTypeVerifier +func @succeededDynamicTypeVerifier() { + // CHECK: %{{.*}} = "unregistered_op"() : () -> !test.singleton_dyntype + // CHECK-NEXT: "unregistered_op"() : () -> !test.pair_dyntype + // CHECK_NEXT: %{{.*}} = "unregistered_op"() : () -> !test.pair_dyntype, !test.singleton_dyntype> + "unregistered_op"() : () -> !test.singleton_dyntype + "unregistered_op"() : () -> !test.pair_dyntype + "unregistered_op"() : () -> !test.pair_dyntype, !test.singleton_dyntype> + return +} + +// ----- + +func @failedDynamicTypeVerifier() { + // expected-error@+1 {{expected 0 type arguments, but had 1}} + "unregistered_op"() : () -> !test.singleton_dyntype + return +} + +// ----- + +func @failedDynamicTypeVerifier2() { + // expected-error@+1 {{expected 2 type arguments, but had 1}} + "unregistered_op"() : () -> !test.pair_dyntype + return +} + +// ----- + +// 
CHECK-LABEL: func @customTypeParserPrinter +func @customTypeParserPrinter() { + // CHECK: "unregistered_op"() : () -> !test.custom_assembly_format_dyntype + "unregistered_op"() : () -> !test.custom_assembly_format_dyntype + return +} + +//===----------------------------------------------------------------------===// +// Dynamic op +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @succeededDynamicOpVerifier +func @succeededDynamicOpVerifier(%a: f32) { + // CHECK: "test.generic_dynamic_op"() : () -> () + // CHECK-NEXT: %{{.*}} = "test.generic_dynamic_op"(%{{.*}}) : (f32) -> f64 + // CHECK-NEXT: %{{.*}}:2 = "test.one_operand_two_results"(%{{.*}}) : (f32) -> (f64, f64) + "test.generic_dynamic_op"() : () -> () + "test.generic_dynamic_op"(%a) : (f32) -> f64 + "test.one_operand_two_results"(%a) : (f32) -> (f64, f64) + return +} + +// ----- + +func @failedDynamicOpVerifier() { + // expected-error@+1 {{expected 1 operand, but had 0}} + "test.one_operand_two_results"() : () -> (f64, f64) + return +} + +// ----- + +func @failedDynamicOpVerifier2(%a: f32) { + // expected-error@+1 {{expected 2 results, but had 0}} + "test.one_operand_two_results"(%a) : (f32) -> () + return +} + +// ----- + +// CHECK-LABEL: func @customOpParserPrinter +func @customOpParserPrinter() { + // CHECK: test.custom_parser_printer_dynamic_op custom_keyword + test.custom_parser_printer_dynamic_op custom_keyword + return +} diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -21,6 +21,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/RegionKindInterface.h" diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp 
b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Reducer/ReductionPatternInterface.h" @@ -201,6 +202,54 @@ } // end anonymous namespace +//===----------------------------------------------------------------------===// +// Dynamic operations +//===----------------------------------------------------------------------===// + +std::unique_ptr getGenericDynamicOp(Dialect *dialect) { + return DynamicOpDefinition::get("generic_dynamic_op", dialect, + [](Operation *op) { return success(); }); +} + +std::unique_ptr +getOneOperandTwoResultsDynamicOp(Dialect *dialect) { + return DynamicOpDefinition::get( + "one_operand_two_results", dialect, [](Operation *op) { + if (op->getNumOperands() != 1) { + op->emitOpError() + << "expected 1 operand, but had " << op->getNumOperands(); + return failure(); + } + if (op->getNumResults() != 2) { + op->emitOpError() + << "expected 2 results, but had " << op->getNumResults(); + return failure(); + } + return success(); + }); +} + +std::unique_ptr +getCustomParserPrinterDynamicOp(Dialect *dialect) { + auto verifier = [](Operation *op) { + if (op->getNumOperands() == 0 && op->getNumResults() == 0) + return success(); + op->emitError() << "operation should have no operands and no results"; + return failure(); + }; + + auto parser = [](OpAsmParser &parser, OperationState &state) { + return parser.parseKeyword("custom_keyword"); + }; + + auto printer = [](Operation *op, OpAsmPrinter &printer) { + printer << op->getName() << " custom_keyword"; + }; + + return DynamicOpDefinition::get("custom_parser_printer_dynamic_op", dialect, + verifier, parser, printer); +} + 
//===----------------------------------------------------------------------===// // TestDialect //===----------------------------------------------------------------------===// @@ -235,6 +284,10 @@ #define GET_OP_LIST #include "TestOps.cpp.inc" >(); + addDynamicOp(getGenericDynamicOp(this)); + addDynamicOp(getOneOperandTwoResultsDynamicOp(this)); + addDynamicOp(getCustomParserPrinterDynamicOp(this)); + addInterfaces(); allowUnknownOperations(); diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -31,6 +31,7 @@ let hasRegionArgAttrVerify = 1; let hasRegionResultAttrVerify = 1; let hasOperationInterfaceFallback = 1; + let isExtensible = 1; let dependentDialects = ["::mlir::DLTIDialect"]; let extraClassDeclaration = [{ @@ -53,6 +54,10 @@ // Storage for a custom fallback interface. void *fallbackEffectOpInterfaces; + Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser, + SetVector &stack) const; + void printTestType(Type type, DialectAsmPrinter &printer, + SetVector &stack) const; }]; } diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -218,6 +218,75 @@ #define GET_TYPEDEF_CLASSES #include "TestTypeDefs.cpp.inc" +//===----------------------------------------------------------------------===// +// Dynamic Types +//===----------------------------------------------------------------------===// + +namespace { +/// Define a singleton dynamic type. 
+std::unique_ptr +getSingletonDynamicType(Dialect *testDialect) { + return DynamicTypeDefinition::get( + "singleton_dyntype", testDialect, + [](function_ref emitError, + ArrayRef args) { + if (!args.empty()) { + emitError() << "expected 0 type arguments, but had " << args.size(); + return failure(); + } + return success(); + }); +} + +/// Define a dynamic type representing a pair. +std::unique_ptr +getPairDynamicType(Dialect *testDialect) { + return DynamicTypeDefinition::get( + "pair_dyntype", testDialect, + [](function_ref emitError, + ArrayRef args) { + if (args.size() != 2) { + emitError() << "expected 2 type arguments, but had " << args.size(); + return failure(); + } + return success(); + }); +} + +std::unique_ptr +getCustomAssemblyFormatDynamicType(Dialect *testDialect) { + auto verifier = [](function_ref emitError, + ArrayRef args) { + if (args.size() != 2) { + emitError() << "expected 2 type arguments, but had " << args.size(); + return failure(); + } + return success(); + }; + + auto parser = [](DialectAsmParser &parser, + llvm::SmallVectorImpl &parsedParams) { + Attribute leftAttr, rightAttr; + if (parser.parseLess() || parser.parseAttribute(leftAttr) || + parser.parseColon() || parser.parseAttribute(rightAttr) || + parser.parseGreater()) + return failure(); + parsedParams.push_back(leftAttr); + parsedParams.push_back(rightAttr); + return success(); + }; + + auto printer = [](DialectAsmPrinter &printer, ArrayRef params) { + printer << "<" << params[0] << ":" << params[1] << ">"; + }; + + return DynamicTypeDefinition::get("custom_assembly_format_dyntype", + testDialect, std::move(verifier), + std::move(parser), std::move(printer)); +} + +} // namespace + //===----------------------------------------------------------------------===// // TestDialect //===----------------------------------------------------------------------===// @@ -227,10 +296,14 @@ #define GET_TYPEDEF_LIST #include "TestTypeDefs.cpp.inc" >(); + + 
addDynamicType(getSingletonDynamicType(this)); + addDynamicType(getPairDynamicType(this)); + addDynamicType(getCustomAssemblyFormatDynamicType(this)); } -static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser, - SetVector &stack) { +Type TestDialect::parseTestType(MLIRContext *ctxt, DialectAsmParser &parser, + SetVector &stack) const { StringRef typeTag; if (failed(parser.parseKeyword(&typeTag))) return Type(); @@ -242,6 +315,16 @@ return genType; } + { + Type dynType; + auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType); + if (parseResult.hasValue()) { + if (succeeded(parseResult.getValue())) + return dynType; + return Type(); + } + } + if (typeTag != "test_rec") { parser.emitError(parser.getNameLoc()) << "unknown type!"; return Type(); @@ -277,11 +360,14 @@ return parseTestType(getContext(), parser, stack); } -static void printTestType(Type type, DialectAsmPrinter &printer, - SetVector &stack) { +void TestDialect::printTestType(Type type, DialectAsmPrinter &printer, + SetVector &stack) const { if (succeeded(generatedTypePrinter(type, printer))) return; + if (succeeded(printIfDynamicType(type, printer))) + return; + auto rec = type.cast(); printer << "test_rec<" << rec.getName(); if (!stack.contains(rec)) { diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp --- a/mlir/tools/mlir-tblgen/DialectGen.cpp +++ b/mlir/tools/mlir-tblgen/DialectGen.cpp @@ -65,9 +65,9 @@ /// {2}: initialization code that is emitted in the ctor body before calling /// initialize() static const char *const dialectDeclBeginStr = R"( -class {0} : public ::mlir::Dialect { +class {0} : public ::mlir::{3} { explicit {0}(::mlir::MLIRContext *context) - : ::mlir::Dialect(getDialectNamespace(), context, + : ::mlir::{3}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {{ {2} initialize(); @@ -176,8 +176,10 @@ // Emit the start of the decl. 
std::string cppName = dialect.getCppClassName(); + StringRef superClassName = + dialect.isExtensible() ? "ExtensibleDialect" : "Dialect"; os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(), - dependentDialectRegistrations); + dependentDialectRegistrations, superClassName); // Check for any attributes/types registered to this dialect. If there are, // add the hooks for parsing/printing.