diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt --- a/mlir/include/mlir/IR/CMakeLists.txt +++ b/mlir/include/mlir/IR/CMakeLists.txt @@ -26,6 +26,11 @@ mlir_tablegen(BuiltinTypes.cpp.inc -gen-typedef-defs) add_public_tablegen_target(MLIRBuiltinTypesIncGen) +set(LLVM_TARGET_DEFINITIONS IsDynamicInterfaces.td) +mlir_tablegen(IsDynamicTypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(IsDynamicTypeInterfaces.cpp.inc -gen-type-interface-defs) +add_public_tablegen_target(MLIRIsDynamicTypeInterfaces) + set(LLVM_TARGET_DEFINITIONS TensorEncoding.td) mlir_tablegen(TensorEncInterfaces.h.inc -gen-attr-interface-decls) mlir_tablegen(TensorEncInterfaces.cpp.inc -gen-attr-interface-defs) diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/ExtensibleDialect.h @@ -0,0 +1,283 @@ +//===- ExtensibleDialect.h - Extensible dialect -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Dialects that can register new operations/types/attributes at runtime.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_EXTENSIBLEDIALECT_H +#define MLIR_IR_EXTENSIBLEDIALECT_H + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectInterface.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "llvm/ADT/StringMap.h" + +namespace mlir { +class MLIRContext; +class DialectAsmPrinter; +class DialectAsmParser; +class ParseResult; +class OptionalParseResult; +class ExtensibleDialect; + +namespace detail { +struct DynamicTypeStorage; +} // namespace detail + +//===----------------------------------------------------------------------===// +// Dynamic type +//===----------------------------------------------------------------------===// + +class DynamicType; + +/// This is the definition of a dynamic type. It stores the parser, +/// the printer, and the verifier. +/// Each dynamic type instance refers to one instance of this class. +class DynamicTypeDefinition { +public: + using VerifierFn = llvm::unique_function, ArrayRef) const>; + using ParserFn = llvm::unique_function &parsedAttributes) + const>; + using PrinterFn = llvm::unique_function params) const>; + + static std::unique_ptr + get(llvm::StringRef name, Dialect *dialect, VerifierFn &&verifier); // Uses the default comma-separated attribute list syntax. + + static std::unique_ptr + get(llvm::StringRef name, Dialect *dialect, VerifierFn &&verifier, + ParserFn &&parser, PrinterFn &&printer); + + /// Check that the type parameters are valid. + LogicalResult verify(function_ref emitError, + ArrayRef params) const { + return verifier(emitError, params); + } + + /// Get the unique identifier associated with the concrete type. + TypeID getTypeID() const { return typeID; } + + /// Get the MLIRContext in which the dynamic types are uniqued. + MLIRContext &getContext() const { return *ctx; } + + /// Get the name of the type, in the format 'typename' and + /// not 'dialectname.typename'.
+ StringRef getName() const { return name; } + + /// Get the dialect defining the type. + Dialect *getDialect() const { return dialect; } + +private: + DynamicTypeDefinition(llvm::StringRef name, Dialect *dialect, + VerifierFn &&verifier, ParserFn &&parser, + PrinterFn &&printer); + + /// This constructor should only be used when we need a pointer to + /// the DynamicTypeDefinition in the verifier, the parser, or the printer. + /// The verifier, parser, and printer thus need to be initialized after the + /// constructor. + DynamicTypeDefinition(Dialect *dialect, llvm::StringRef name); + + /// Register the concrete type in the type Uniquer. + void registerInTypeUniquer(); + + /// The name of the type, NOT prefixed with the dialect name (no '.'). + std::string name; + + /// Dialect in which this type is defined. + Dialect *dialect; + + /// Verifier for the type parameters. + VerifierFn verifier; + + /// Parse the type parameters. + ParserFn parser; + + /// Print the type parameters. + PrinterFn printer; + + /// Unique identifier for the concrete type. + TypeID typeID; + + /// Context in which the concrete types are uniqued. + MLIRContext *ctx; + + friend ExtensibleDialect; + friend DynamicType; +}; + +/// A type defined at runtime. +/// Each DynamicType instance represents a different dynamic type. +class DynamicType + : public Type::TypeBase { +public: + // Inherit Base constructors. + using Base::Base; + + /// Get an instance of a dynamic type given a dynamic type definition and + /// type parameters. + /// This function does not call the type verifier. + static DynamicType get(DynamicTypeDefinition *typeDef, + ArrayRef params = {}); + + /// Get an instance of a dynamic type given a dynamic type definition and type + /// parameters. + /// This function also calls the verifier to check if the parameters are valid.
+ static DynamicType getChecked(function_ref emitError, + DynamicTypeDefinition *typeDef, + ArrayRef params = {}); + + /// Get the type definition of the concrete type. + DynamicTypeDefinition *getTypeDef(); + + /// Get the type parameters. + ArrayRef getParams(); + + /// Check if a type is a specific dynamic type. + static bool isa(Type type, DynamicTypeDefinition *typeDef) { + return type.getTypeID() == typeDef->getTypeID(); + } + + /// Check if a type is a dynamic type. + static bool classof(Type type); + + /// Parse the dynamic type parameters and construct the type. + /// The parameters are either empty, and nothing is parsed, + /// or they are in the format '<>' or ''. + static ParseResult parse(DialectAsmParser &parser, + DynamicTypeDefinition *typeDef, + DynamicType &parsedType); + + /// Print the dynamic type with the format + /// 'type' or 'type<>' if there are no parameters, or 'type'. + void print(DialectAsmPrinter &printer); +}; + +//===----------------------------------------------------------------------===// +// Dynamic operation +//===----------------------------------------------------------------------===// + +/// The definition of a dynamic operation. +/// It contains the name of the operation, its owning dialect, a verifier, +/// a printer, a parser, a fold hook, and a canonicalization hook.
+class DynamicOpDefinition { +public: + static std::unique_ptr + get(StringRef name, Dialect *dialect, + AbstractOperation::VerifyInvariantsFn &&verifyFn); // Parsing always fails; printing uses the generic form. + + static std::unique_ptr + get(StringRef name, Dialect *dialect, + AbstractOperation::VerifyInvariantsFn &&verifyFn, + AbstractOperation::ParseAssemblyFn &&parseFn, + AbstractOperation::PrintAssemblyFn &&printFn); // No folding, no canonicalization patterns. + + static std::unique_ptr + get(StringRef name, Dialect *dialect, + AbstractOperation::VerifyInvariantsFn &&verifyFn, + AbstractOperation::ParseAssemblyFn &&parseFn, + AbstractOperation::PrintAssemblyFn &&printFn, + AbstractOperation::FoldHookFn &&foldHookFn, + AbstractOperation::GetCanonicalizationPatternsFn + &&getCanonicalizationPatternsFn); + +private: + DynamicOpDefinition(StringRef name, Dialect *dialect, + AbstractOperation::VerifyInvariantsFn &&verifyFn, + AbstractOperation::ParseAssemblyFn &&parseFn, + AbstractOperation::PrintAssemblyFn &&printFn, + AbstractOperation::FoldHookFn &&foldHookFn, + AbstractOperation::GetCanonicalizationPatternsFn + &&getCanonicalizationPatternsFn); + + /// Unique identifier for this operation. + TypeID typeID; + + /// Name of the operation. + /// The name is prefixed with the dialect name followed by '.'. + std::string name; + + /// Dialect defining this operation. + Dialect *dialect; + + AbstractOperation::VerifyInvariantsFn verifyFn; // Checks the op invariants. + AbstractOperation::ParseAssemblyFn parseFn; // Parses the custom assembly form. + AbstractOperation::PrintAssemblyFn printFn; // Prints the op. + AbstractOperation::FoldHookFn foldHookFn; // Folding hook. + AbstractOperation::GetCanonicalizationPatternsFn + getCanonicalizationPatternsFn; // Populates canonicalization patterns. + + friend ExtensibleDialect; +}; + +//===----------------------------------------------------------------------===// +// Extensible dialect +//===----------------------------------------------------------------------===// + +/// A dialect that can be extended with new operations/types/attributes at +/// runtime.
+class ExtensibleDialect : public mlir::Dialect { +public: + ExtensibleDialect(StringRef name, MLIRContext *ctx, TypeID typeID); + + /// Add a new type defined at runtime to the dialect. + void addDynamicType(std::unique_ptr &&type); + + /// Add a new operation defined at runtime to the dialect. + void addDynamicOp(std::unique_ptr &&op); + + /// Check if the dialect is an extensible dialect. + static bool classof(Dialect *dialect); + + /// Look up a dynamic type by name; on success the pointer is non-null. + FailureOr + lookupTypeDefinition(StringRef name) const { + auto it = nameToDynTypes.find(name); + if (it == nameToDynTypes.end()) + return failure(); + return it->second; + } + + /// Look up a dynamic type by TypeID; on success the pointer is non-null. + FailureOr lookupTypeDefinition(TypeID id) const { + auto it = dynTypes.find(id); + if (it == dynTypes.end()) + return failure(); + return it->second.get(); + } + +protected: + /// Parse the dynamic type 'typeName' defined in this dialect. + /// 'typeName' should not be prefixed with the dialect name. + /// If the dynamic type does not exist, return no value. + /// Otherwise, parse it, and return the parse result. + /// If the parsing succeeds, put the resulting type in 'resultType'. + OptionalParseResult parseOptionalDynamicType(StringRef typeName, + DialectAsmParser &parser, + Type &resultType) const; + + /// If 'type' is a dynamic type, print it. + /// Returns success if the type was printed, and failure if the type was not a + /// dynamic type. + static LogicalResult printIfDynamicType(Type type, + DialectAsmPrinter &printer); + +private: + /// All registered dynamic types, owned by the dialect and keyed by TypeID. + llvm::DenseMap> dynTypes; + + /// Map from dynamic type names (without dialect prefix) to definitions.
+ llvm::StringMap nameToDynTypes; +}; +} // namespace mlir + +#endif // MLIR_IR_EXTENSIBLEDIALECT_H diff --git a/mlir/include/mlir/IR/IsDynamicInterfaces.h b/mlir/include/mlir/IR/IsDynamicInterfaces.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/IsDynamicInterfaces.h @@ -0,0 +1,24 @@ +//===- IsDynamicInterfaces.h - Dynamic objects interfaces -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// Define an interface that is only implemented on dynamic types. +// The interface is used to check if a type is a DynamicType or not. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_ISDYNAMICINTERFACES_H +#define MLIR_IR_ISDYNAMICINTERFACES_H + +#include "mlir/IR/OpDefinition.h" + +//===----------------------------------------------------------------------===// +// Type interfaces +//===----------------------------------------------------------------------===// + +#include "mlir/IR/IsDynamicTypeInterfaces.h.inc" + +#endif // MLIR_IR_ISDYNAMICINTERFACES_H diff --git a/mlir/include/mlir/IR/IsDynamicInterfaces.td b/mlir/include/mlir/IR/IsDynamicInterfaces.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/IsDynamicInterfaces.td @@ -0,0 +1,29 @@ +//===- IsDynamicInterfaces.td - Dynamic objects interfaces -*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// Define an interface that is only implemented on dynamic types.
+// The interface is used to check if a type is a DynamicType or not. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_ISDYNAMICINTERFACES +#define MLIR_IR_ISDYNAMICINTERFACES + +include "mlir/IR/OpBase.td" + +def IsDynamicTypeInterface : TypeInterface<"IsDynamicTypeInterface"> { + let cppNamespace = "::mlir"; + let description = [{ + This interface is implemented by all dynamic types, and should not be + implemented by any other type. + The interface is used to check if a type is a dynamic type or not. + }]; + + let methods = []; +} + +#endif // MLIR_IR_ISDYNAMICINTERFACES diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -275,6 +275,9 @@ // If this dialect overrides the hook for op interface fallback. bit hasOperationInterfaceFallback = 0; + + // If this dialect can be extended at runtime with new operations or types. + bit isExtensible = 0; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h --- a/mlir/include/mlir/IR/TypeSupport.h +++ b/mlir/include/mlir/IR/TypeSupport.h @@ -61,12 +61,15 @@ /// Return the unique identifier representing the concrete type class. TypeID getTypeID() const { return typeID; } -private: + /// Construct an abstract type directly. This is public so that types + /// defined at runtime can be registered, but its use is in general + /// discouraged in favor of 'AbstractType::get()'. AbstractType(Dialect &dialect, detail::InterfaceMap &&interfaceMap, TypeID typeID) : dialect(dialect), interfaceMap(std::move(interfaceMap)), typeID(typeID) {} +private: /// This is the dialect that this type was registered to. Dialect &dialect; @@ -131,37 +134,48 @@ /// A utility class to get, or create, unique instances of types within an /// MLIRContext. This class manages all creation and uniquing of types.
struct TypeUniquer { + /// Get an uniqued instance of a type T. + template + static T get(MLIRContext *ctx, Args &&...args) { + return getWithTypeID(ctx, T::getTypeID(), + std::forward(args)...); + } + /// Get an uniqued instance of a parametric type T. + /// The use of this method is in general discouraged in favor of + /// 'get(ctx, args)'. template static typename std::enable_if_t< !std::is_same::value, T> - get(MLIRContext *ctx, Args &&...args) { + getWithTypeID(MLIRContext *ctx, TypeID typeID, Args &&...args) { #ifndef NDEBUG - if (!ctx->getTypeUniquer().isParametricStorageInitialized(T::getTypeID())) + if (!ctx->getTypeUniquer().isParametricStorageInitialized(typeID)) llvm::report_fatal_error(llvm::Twine("can't create type '") + llvm::getTypeName() + "' because storage uniquer isn't initialized: " "the dialect was likely not loaded."); #endif return ctx->getTypeUniquer().get( - [&](TypeStorage *storage) { - storage->initialize(AbstractType::lookup(T::getTypeID(), ctx)); + [&, typeID](TypeStorage *storage) { + storage->initialize(AbstractType::lookup(typeID, ctx)); }, - T::getTypeID(), std::forward(args)...); + typeID, std::forward(args)...); } /// Get an uniqued instance of a singleton type T. + /// The use of this method is in general discouraged in favor of + /// 'get(ctx, args)'. 
template static typename std::enable_if_t< std::is_same::value, T> - get(MLIRContext *ctx) { + getWithTypeID(MLIRContext *ctx, TypeID typeID) { #ifndef NDEBUG - if (!ctx->getTypeUniquer().isSingletonStorageInitialized(T::getTypeID())) + if (!ctx->getTypeUniquer().isSingletonStorageInitialized(typeID)) llvm::report_fatal_error(llvm::Twine("can't create type '") + llvm::getTypeName() + "' because storage uniquer isn't initialized: " "the dialect was likely not loaded."); #endif - return ctx->getTypeUniquer().get(T::getTypeID()); + return ctx->getTypeUniquer().get(typeID); } /// Change the mutable component of the given type instance in the provided @@ -174,22 +188,32 @@ std::forward(args)...); } + /// Register a type instance T with the uniquer. + template + static void registerType(MLIRContext *ctx) { + registerType(ctx, T::getTypeID()); + } + /// Register a parametric type instance T with the uniquer. + /// The use of this method is in general discouraged in favor of + /// 'registerType(ctx)'. template static typename std::enable_if_t< !std::is_same::value> - registerType(MLIRContext *ctx) { + registerType(MLIRContext *ctx, TypeID typeID) { ctx->getTypeUniquer().registerParametricStorageType( - T::getTypeID()); + typeID); } /// Register a singleton type instance T with the uniquer. + /// The use of this method is in general discouraged in favor of + /// 'registerType(ctx)'. 
template static typename std::enable_if_t< std::is_same::value> - registerType(MLIRContext *ctx) { + registerType(MLIRContext *ctx, TypeID typeID) { ctx->getTypeUniquer().registerSingletonStorageType( - T::getTypeID(), [&](TypeStorage *storage) { - storage->initialize(AbstractType::lookup(T::getTypeID(), ctx)); + typeID, [&ctx, typeID](TypeStorage *storage) { + storage->initialize(AbstractType::lookup(typeID, ctx)); }); } }; diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h --- a/mlir/include/mlir/TableGen/Dialect.h +++ b/mlir/include/mlir/TableGen/Dialect.h @@ -66,6 +66,10 @@ /// Returns true if this dialect has fallback interfaces for its operations. bool hasOperationInterfaceFallback() const; + /// Returns true if this dialect can be extended at runtime with new + /// operations or types. + bool isExtensible() const; + // Returns whether two dialects are equal by checking the equality of the // underlying record. bool operator==(const Dialect &other) const; diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt --- a/mlir/lib/IR/CMakeLists.txt +++ b/mlir/lib/IR/CMakeLists.txt @@ -11,9 +11,11 @@ Diagnostics.cpp Dialect.cpp Dominance.cpp + ExtensibleDialect.cpp FunctionImplementation.cpp FunctionSupport.cpp IntegerSet.cpp + IsDynamicInterfaces.cpp Location.cpp MLIRContext.cpp Operation.cpp diff --git a/mlir/lib/IR/ExtensibleDialect.cpp b/mlir/lib/IR/ExtensibleDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/IR/ExtensibleDialect.cpp @@ -0,0 +1,332 @@ +//===- ExtensibleDialect.cpp - Extensible dialect ---------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Dialects that can register new operations/types/attributes at runtime. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/ExtensibleDialect.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/IsDynamicInterfaces.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Dynamic type +//===----------------------------------------------------------------------===// + +std::unique_ptr +DynamicTypeDefinition::get(llvm::StringRef name, Dialect *dialect, + VerifierFn &&verifier) { + auto *typeDef = new DynamicTypeDefinition(dialect, name); + + auto parser = [](DialectAsmParser &parser, + std::vector &parsedParams) { + // No parameters: either no '<' at all, or an empty '<>'. + if (parser.parseOptionalLess() || !parser.parseOptionalGreater()) + return success(); + + Attribute attr; + if (parser.parseAttribute(attr)) + return failure(); + parsedParams.push_back(attr); + + while (parser.parseOptionalGreater()) { + Attribute attr; + if (parser.parseComma() || parser.parseAttribute(attr)) + return failure(); + parsedParams.push_back(attr); + } + + return success(); + }; + + auto printer = [](DialectAsmPrinter &printer, ArrayRef params) { + if (params.empty()) + return; + + printer << "<"; + llvm::interleaveComma(params, printer.getStream()); + printer << ">"; + }; + + typeDef->verifier = std::move(verifier); + typeDef->parser = std::move(parser); + typeDef->printer = std::move(printer); + + return std::unique_ptr(typeDef); +} + +std::unique_ptr +DynamicTypeDefinition::get(llvm::StringRef name, Dialect *dialect, + VerifierFn &&verifier, ParserFn &&parser, + PrinterFn &&printer) { + return std::unique_ptr( + new DynamicTypeDefinition(name, dialect, std::move(verifier), + std::move(parser), std::move(printer))); +} + +DynamicTypeDefinition::DynamicTypeDefinition(llvm::StringRef nameRef, + Dialect *dialect, + VerifierFn &&verifier, + ParserFn &&parser, + PrinterFn &&printer) + : name(nameRef), dialect(dialect), verifier(std::move(verifier)), +
parser(std::move(parser)), printer(std::move(printer)), + typeID(dialect->getContext()->allocateTypeID()), + ctx(dialect->getContext()) { + assert(!nameRef.contains('.') && + "name should not be prefixed by the dialect name"); +} + +DynamicTypeDefinition::DynamicTypeDefinition(Dialect *dialect, + llvm::StringRef nameRef) + : name(nameRef), dialect(dialect), + typeID(dialect->getContext()->allocateTypeID()), + ctx(dialect->getContext()) { + assert(!nameRef.contains('.') && + "name should not be prefixed by the dialect name"); +} + +void DynamicTypeDefinition::registerInTypeUniquer() { + detail::TypeUniquer::registerType(&getContext(), getTypeID()); +} + +namespace mlir { +namespace detail { +/// Storage of DynamicType. +/// Contains a pointer to the type definition and type parameters. +struct DynamicTypeStorage : public TypeStorage { + + using KeyTy = std::pair>; + + explicit DynamicTypeStorage(DynamicTypeDefinition *typeDef, + ArrayRef params) + : typeDef(typeDef), params(params) {} + + bool operator==(const KeyTy &key) const { + return typeDef == key.first && params == key.second; + } + + static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_value(key); + } + + static DynamicTypeStorage *construct(TypeStorageAllocator &alloc, + const KeyTy &key) { + return new (alloc.allocate()) + DynamicTypeStorage(key.first, alloc.copyInto(key.second)); + } + + /// Definition of the type. + DynamicTypeDefinition *typeDef; + + /// The type parameters.
+ ArrayRef params; +}; +} // namespace detail +} // namespace mlir + +DynamicType DynamicType::get(DynamicTypeDefinition *typeDef, + ArrayRef params) { + auto &ctx = typeDef->getContext(); + return detail::TypeUniquer::getWithTypeID( + &ctx, typeDef->getTypeID(), typeDef, params); +} + +DynamicType +DynamicType::getChecked(function_ref emitError, + DynamicTypeDefinition *typeDef, + ArrayRef params) { + if (failed(typeDef->verify(emitError, params))) + return {}; + return get(typeDef, params); +} + +DynamicTypeDefinition *DynamicType::getTypeDef() { return getImpl()->typeDef; } + +ArrayRef DynamicType::getParams() { return getImpl()->params; } + +bool DynamicType::classof(Type type) { + return type.isa(); +} + +ParseResult DynamicType::parse(DialectAsmParser &parser, + DynamicTypeDefinition *typeDef, + DynamicType &parsedType) { + auto loc = parser.getCurrentLocation(); + std::vector params; + if (failed(typeDef->parser(parser, params))) + return failure(); + auto emitError = [&]() { return parser.emitError(loc); }; + parsedType = DynamicType::getChecked(emitError, typeDef, params); + // getChecked returns a null type when the verifier rejects the parameters. + return failure(!parsedType); +} + +void DynamicType::print(DialectAsmPrinter &printer) { + printer << getTypeDef()->getName(); + getTypeDef()->printer(printer, getParams()); +} + +//===----------------------------------------------------------------------===// +// Dynamic operation +//===----------------------------------------------------------------------===// + +DynamicOpDefinition::DynamicOpDefinition( + StringRef name, Dialect *dialect, + AbstractOperation::VerifyInvariantsFn &&verifyFn, + AbstractOperation::ParseAssemblyFn &&parseFn, + AbstractOperation::PrintAssemblyFn &&printFn, + AbstractOperation::FoldHookFn &&foldHookFn, + AbstractOperation::GetCanonicalizationPatternsFn + &&getCanonicalizationPatternsFn) + : typeID(dialect->getContext()->allocateTypeID()), + name((dialect->getNamespace() + "." 
+ name).str()), dialect(dialect), + verifyFn(std::move(verifyFn)), parseFn(std::move(parseFn)), + printFn(std::move(printFn)), foldHookFn(std::move(foldHookFn)), + getCanonicalizationPatternsFn(std::move(getCanonicalizationPatternsFn)) { + assert(!name.contains('.') && // 'name' is the parameter, not the member. + "name should not be prefixed by the dialect name"); +} + +std::unique_ptr +DynamicOpDefinition::get(StringRef name, Dialect *dialect, + AbstractOperation::VerifyInvariantsFn &&verifyFn) { + auto parseFn = [](OpAsmParser &parser, OperationState &result) { + parser.emitError(parser.getCurrentLocation(), + "dynamic operation do not define any parser function"); + return failure(); + }; + + auto printFn = [](Operation *op, OpAsmPrinter &printer) { + printer.printGenericOp(op); + }; + + return DynamicOpDefinition::get(name, dialect, std::move(verifyFn), + std::move(parseFn), std::move(printFn)); +} + +std::unique_ptr +DynamicOpDefinition::get(StringRef name, Dialect *dialect, + AbstractOperation::VerifyInvariantsFn &&verifyFn, + AbstractOperation::ParseAssemblyFn &&parseFn, + AbstractOperation::PrintAssemblyFn &&printFn) { + auto foldHookFn = [](mlir::Operation *op, + llvm::ArrayRef operands, + llvm::SmallVectorImpl &results) { + return failure(); + }; + + auto getCanonicalizationPatternsFn = [](OwningRewritePatternList &, + MLIRContext *) {}; + + return DynamicOpDefinition::get(name, dialect, std::move(verifyFn), + std::move(parseFn), std::move(printFn), + std::move(foldHookFn), + std::move(getCanonicalizationPatternsFn)); +} + +std::unique_ptr +DynamicOpDefinition::get(StringRef name, Dialect *dialect, + AbstractOperation::VerifyInvariantsFn &&verifyFn, + AbstractOperation::ParseAssemblyFn &&parseFn, + AbstractOperation::PrintAssemblyFn &&printFn, + AbstractOperation::FoldHookFn &&foldHookFn, + AbstractOperation::GetCanonicalizationPatternsFn + &&getCanonicalizationPatternsFn) { + return std::unique_ptr( + new DynamicOpDefinition(name, dialect, std::move(verifyFn), + std::move(parseFn), 
std::move(printFn), + std::move(foldHookFn), + std::move(getCanonicalizationPatternsFn))); +} + +//===----------------------------------------------------------------------===// +// Extensible dialect +//===----------------------------------------------------------------------===// + +/// Interface that can only be implemented by extensible dialects. +/// The interface is used to check if a dialect is extensible or not. +class IsExtensibleDialect : public DialectInterface::Base { +public: + explicit IsExtensibleDialect(Dialect *dialect) : Base(dialect) {} +}; + +ExtensibleDialect::ExtensibleDialect(StringRef name, MLIRContext *ctx, + TypeID typeID) + : Dialect(name, ctx, typeID) { + addInterfaces(); +} + +void ExtensibleDialect::addDynamicType( + std::unique_ptr &&type) { + auto *typePtr = type.get(); + auto typeID = type->getTypeID(); + auto name = type->getName(); + auto *dialect = type->getDialect(); + + assert(dialect == this && + "trying to register a dynamic type in the wrong dialect"); + + // TypeIDs are freshly allocated, so this insertion should never fail. + auto registered = dynTypes.try_emplace(typeID, std::move(type)).second; + assert(registered && "generated TypeID was not unique"); + // If a type with the same name is already defined, fail. + registered = nameToDynTypes.insert({name, typePtr}).second; + assert(registered && + "Trying to create a new dynamic type with an existing name"); + + auto interfaceMap = + detail::InterfaceMap::get>(); + auto abstractType = AbstractType(*dialect, std::move(interfaceMap), typeID); + + // Add the type to the dialect and the type uniquer.
+ addType(typeID, std::move(abstractType)); + typePtr->registerInTypeUniquer(); +} + +void ExtensibleDialect::addDynamicOp( + std::unique_ptr &&op) { + assert(op->dialect == this && + "trying to register a dynamic op in the wrong dialect"); + auto hasTraitFn = [](TypeID traitId) { return false; }; + + AbstractOperation::insert( + op->name, *op->dialect, op->typeID, std::move(op->parseFn), + std::move(op->printFn), std::move(op->verifyFn), + std::move(op->foldHookFn), std::move(op->getCanonicalizationPatternsFn), + detail::InterfaceMap::get<>(), std::move(hasTraitFn)); +} + +bool ExtensibleDialect::classof(Dialect *dialect) { + return dialect->getRegisteredInterface(); +} + +OptionalParseResult ExtensibleDialect::parseOptionalDynamicType( + StringRef typeName, DialectAsmParser &parser, Type &resultType) const { + auto typeDef = lookupTypeDefinition(typeName); + if (succeeded(typeDef)) { + DynamicType dynType; + if (DynamicType::parse(parser, *typeDef, dynType)) + return failure(); + resultType = dynType; + return {success()}; + } + + return {}; +} + +LogicalResult +ExtensibleDialect::printIfDynamicType(Type type, DialectAsmPrinter &printer) { + if (auto dynType = type.dyn_cast()) { + dynType.print(printer); + return success(); + } + return failure(); +} diff --git a/mlir/lib/IR/IsDynamicInterfaces.cpp b/mlir/lib/IR/IsDynamicInterfaces.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/IR/IsDynamicInterfaces.cpp @@ -0,0 +1,15 @@ +//===- IsDynamicInterfaces.cpp - Dynamic objects interfaces -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// Define an interface that is only implemented on dynamic types. +// The interface is used to check if a type is a DynamicType or not.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/IsDynamicInterfaces.h" + +#include "mlir/IR/IsDynamicTypeInterfaces.cpp.inc" diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp --- a/mlir/lib/TableGen/Dialect.cpp +++ b/mlir/lib/TableGen/Dialect.cpp @@ -81,6 +81,10 @@ return def->getValueAsBit("hasOperationInterfaceFallback"); } +bool Dialect::isExtensible() const { + return def->getValueAsBit("isExtensible"); +} + bool Dialect::operator==(const Dialect &other) const { return def == other.def; } diff --git a/mlir/test/IR/dynamic.mlir b/mlir/test/IR/dynamic.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/dynamic.mlir @@ -0,0 +1,84 @@ +// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics | FileCheck %s +// Verify that extensible dialects can register dynamic operations and types. + +//===----------------------------------------------------------------------===// +// Dynamic type +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @succeededDynamicTypeVerifier +func @succeededDynamicTypeVerifier() { + // CHECK: %{{.*}} = "unregistered_op"() : () -> !test.singleton_dyntype + // CHECK-NEXT: "unregistered_op"() : () -> !test.pair_dyntype + // CHECK-NEXT: %{{.*}} = "unregistered_op"() : () -> !test.pair_dyntype, !test.singleton_dyntype> + "unregistered_op"() : () -> !test.singleton_dyntype + "unregistered_op"() : () -> !test.pair_dyntype + "unregistered_op"() : () -> !test.pair_dyntype, !test.singleton_dyntype> + return +} + +// ----- + +func @failedDynamicTypeVerifier() { + // expected-error@+1 {{expected 0 type arguments, but had 1}} + "unregistered_op"() : () -> !test.singleton_dyntype + return +} + +// ----- + +func @failedDynamicTypeVerifier2() { + // expected-error@+1 {{expected 2 type arguments, but had 1}} + "unregistered_op"() : () -> !test.pair_dyntype + return +} + +// ----- + +// 
CHECK-LABEL: func @customTypeParserPrinter +func @customTypeParserPrinter() { + // CHECK: "unregistered_op"() : () -> !test.custom_assembly_format_dyntype + "unregistered_op"() : () -> !test.custom_assembly_format_dyntype + return +} + +// ----- + +//===----------------------------------------------------------------------===// +// Dynamic op +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @succeededDynamicOpVerifier +func @succeededDynamicOpVerifier(%a: f32) { + // CHECK: "test.generic_dynamic_op"() : () -> () + // CHECK-NEXT: %{{.*}} = "test.generic_dynamic_op"(%{{.*}}) : (f32) -> f64 + // CHECK-NEXT: %{{.*}}:2 = "test.one_operand_two_results"(%{{.*}}) : (f32) -> (f64, f64) + "test.generic_dynamic_op"() : () -> () + "test.generic_dynamic_op"(%a) : (f32) -> f64 + "test.one_operand_two_results"(%a) : (f32) -> (f64, f64) + return +} + +// ----- + +func @failedDynamicOpVerifier() { + // expected-error@+1 {{expected 1 operand, but had 0}} + "test.one_operand_two_results"() : () -> (f64, f64) + return +} + +// ----- + +func @failedDynamicOpVerifier2(%a: f32) { + // expected-error@+1 {{expected 2 results, but had 0}} + "test.one_operand_two_results"(%a) : (f32) -> () + return +} + +// ----- + +// CHECK-LABEL: func @customOpParserPrinter +func @customOpParserPrinter() { + // CHECK: test.custom_parser_printer_dynamic_op custom_keyword + test.custom_parser_printer_dynamic_op custom_keyword + return +} diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -21,6 +21,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/RegionKindInterface.h" diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp
b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/FoldUtils.h" @@ -172,6 +173,54 @@ }; } // end anonymous namespace +//===----------------------------------------------------------------------===// +// Dynamic operations (registered at runtime by TestDialect) +//===----------------------------------------------------------------------===// + +static std::unique_ptr getGenericDynamicOp(Dialect *dialect) { + return DynamicOpDefinition::get("generic_dynamic_op", dialect, + [](Operation *op) { return success(); }); +} + +static std::unique_ptr +getOneOperandTwoResultsDynamicOp(Dialect *dialect) { + return DynamicOpDefinition::get( + "one_operand_two_results", dialect, [](Operation *op) { + if (op->getNumOperands() != 1) { + op->emitOpError() + << "expected 1 operand, but had " << op->getNumOperands(); + return failure(); + } + if (op->getNumResults() != 2) { + op->emitOpError() + << "expected 2 results, but had " << op->getNumResults(); + return failure(); + } + return success(); + }); +} + +static std::unique_ptr +getCustomParserPrinterDynamicOp(Dialect *dialect) { + auto verifier = [](Operation *op) { + if (op->getNumOperands() == 0 && op->getNumResults() == 0) + return success(); + op->emitOpError() << "operation should have no operands and no results"; + return failure(); + }; + + auto parser = [](OpAsmParser &parser, OperationState &state) { + return parser.parseKeyword("custom_keyword"); + }; + + auto printer = [](Operation *op, OpAsmPrinter &printer) { + printer << op->getName() << " custom_keyword"; + }; + + return DynamicOpDefinition::get("custom_parser_printer_dynamic_op", dialect, + verifier, parser, printer); +} +
//===----------------------------------------------------------------------===//
 // TestDialect
 //===----------------------------------------------------------------------===//
@@ -206,6 +255,10 @@
 #define GET_OP_LIST
 #include "TestOps.cpp.inc"
       >();
+  addDynamicOp(getGenericDynamicOp(this));
+  addDynamicOp(getOneOperandTwoResultsDynamicOp(this));
+  addDynamicOp(getCustomParserPrinterDynamicOp(this));
+
   addInterfaces();
   allowUnknownOperations();
 
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -30,6 +30,7 @@
   let hasRegionArgAttrVerify = 1;
   let hasRegionResultAttrVerify = 1;
   let hasOperationInterfaceFallback = 1;
+  let isExtensible = 1;
   let dependentDialects = ["::mlir::DLTIDialect"];
 
   let extraClassDeclaration = [{
@@ -52,6 +53,13 @@
 
     // Storage for a custom fallback interface.
     void *fallbackEffectOpInterfaces;
+
+    /// Parse or print a type of this dialect, including its dynamic types.
+    /// `stack` is used to handle recursive `test_rec` types.
+    Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
+                       SetVector<Type> &stack) const;
+    void printTestType(Type type, DialectAsmPrinter &printer,
+                       SetVector<Type> &stack) const;
   }];
 }
 
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -215,6 +215,75 @@
 #define GET_TYPEDEF_CLASSES
 #include "TestTypeDefs.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// Dynamic Types
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Define a singleton dynamic type.
+std::unique_ptr<DynamicTypeDefinition>
+getSingletonDynamicType(Dialect *testDialect) {
+  return DynamicTypeDefinition::get(
+      "singleton_dyntype", testDialect,
+      [](function_ref<InFlightDiagnostic()> emitError,
+         ArrayRef<Attribute> args) {
+        if (!args.empty()) {
+          emitError() << "expected 0 type arguments, but had " << args.size();
+          return failure();
+        }
+        return success();
+      });
+}
+
+/// Define a dynamic type representing a pair.
+std::unique_ptr<DynamicTypeDefinition>
+getPairDynamicType(Dialect *testDialect) {
+  return DynamicTypeDefinition::get(
+      "pair_dyntype", testDialect,
+      [](function_ref<InFlightDiagnostic()> emitError,
+         ArrayRef<Attribute> args) {
+        if (args.size() != 2) {
+          emitError() << "expected 2 type arguments, but had " << args.size();
+          return failure();
+        }
+        return success();
+      });
+}
+
+std::unique_ptr<DynamicTypeDefinition>
+getCustomAssemblyFormatDynamicType(Dialect *testDialect) {
+  auto verifier = [](function_ref<InFlightDiagnostic()> emitError,
+                     ArrayRef<Attribute> args) {
+    if (args.size() != 2) {
+      emitError() << "expected 2 type arguments, but had " << args.size();
+      return failure();
+    }
+    return success();
+  };
+
+  auto parser = [](DialectAsmParser &parser,
+                   std::vector<Attribute> &parsedParams) {
+    Attribute leftAttr, rightAttr;
+    if (parser.parseLess() || parser.parseAttribute(leftAttr) ||
+        parser.parseColon() || parser.parseAttribute(rightAttr) ||
+        parser.parseGreater())
+      return failure();
+    parsedParams.push_back(leftAttr);
+    parsedParams.push_back(rightAttr);
+    return success();
+  };
+
+  auto printer = [](DialectAsmPrinter &printer, ArrayRef<Attribute> params) {
+    printer << "<" << params[0] << ":" << params[1] << ">";
+  };
+
+  return DynamicTypeDefinition::get("custom_assembly_format_dyntype",
+                                    testDialect, std::move(verifier),
+                                    std::move(parser), std::move(printer));
+}
+
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // TestDialect
 //===----------------------------------------------------------------------===//
@@ -224,10 +293,14 @@
 #define GET_TYPEDEF_LIST
 #include "TestTypeDefs.cpp.inc"
       >();
+
+  addDynamicType(getSingletonDynamicType(this));
+
addDynamicType(getPairDynamicType(this));
+  addDynamicType(getCustomAssemblyFormatDynamicType(this));
 }
 
-static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
-                          SetVector<Type> &stack) {
+Type TestDialect::parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
+                                SetVector<Type> &stack) const {
   StringRef typeTag;
   if (failed(parser.parseKeyword(&typeTag)))
     return Type();
@@ -239,6 +312,16 @@
     return genType;
   }
 
+  {
+    Type dynType;
+    auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType);
+    if (parseResult.hasValue()) {
+      if (succeeded(parseResult.getValue()))
+        return dynType;
+      return Type();
+    }
+  }
+
   if (typeTag != "test_rec") {
     parser.emitError(parser.getNameLoc()) << "unknown type!";
     return Type();
@@ -274,11 +357,14 @@
   return parseTestType(getContext(), parser, stack);
 }
 
-static void printTestType(Type type, DialectAsmPrinter &printer,
-                          SetVector<Type> &stack) {
+void TestDialect::printTestType(Type type, DialectAsmPrinter &printer,
+                                SetVector<Type> &stack) const {
   if (succeeded(generatedTypePrinter(type, printer)))
     return;
 
+  if (succeeded(printIfDynamicType(type, printer)))
+    return;
+
   auto rec = type.cast<TestRecursiveType>();
   printer << "test_rec<" << rec.getName();
   if (!stack.contains(rec)) {
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -65,9 +65,10 @@
 /// {2}: initialization code that is emitted in the ctor body before calling
 ///      initialize()
+/// {3}: the dialect base class, either `Dialect` or `ExtensibleDialect`
 static const char *const dialectDeclBeginStr = R"(
-class {0} : public ::mlir::Dialect {
+class {0} : public ::mlir::{3} {
   explicit {0}(::mlir::MLIRContext *context)
-    : ::mlir::Dialect(getDialectNamespace(), context,
+    : ::mlir::{3}(getDialectNamespace(), context,
       ::mlir::TypeID::get<{0}>()) {{
     {2}
     initialize();
@@ -169,8 +170,10 @@
 
   // Emit the start of the decl.
   std::string cppName = dialect.getCppClassName();
+  StringRef superClassName =
+      dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
   os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(),
-                      dependentDialectRegistrations);
+                      dependentDialectRegistrations, superClassName);
 
   // Check for any attributes/types registered to this dialect. If there are,
   // add the hooks for parsing/printing.