diff --git a/mlir/docs/DefiningDialects.md b/mlir/docs/DefiningDialects.md --- a/mlir/docs/DefiningDialects.md +++ b/mlir/docs/DefiningDialects.md @@ -372,6 +372,30 @@ } ``` +### Defining a dynamic dialect + +Dynamic dialects are extensible dialects that can be defined at runtime. They +are only populated with dynamic operations, types, and attributes. They can be +registered in a `DialectRegistry` with `insertDynamic`. + +```c++ +auto populateDialect = [](MLIRContext *ctx, DynamicDialect *dialect) { + // Code that will be run when the dynamic dialect is created and loaded. + // For instance, this is where we register the dynamic operations, types, and + // attributes of the dialect. + ... +}; + +registry.insertDynamic("dialectName", populateDialect); +``` + +Once a dynamic dialect is registered in the `MLIRContext`, it can be retrieved +with `getOrLoadDialect`. + +```c++ +Dialect *dialect = ctx->getOrLoadDialect("dialectName"); +``` + ### Defining an operation at runtime The `DynamicOpDefinition` class represents the definition of an operation diff --git a/mlir/include/mlir/IR/DialectRegistry.h b/mlir/include/mlir/IR/DialectRegistry.h --- a/mlir/include/mlir/IR/DialectRegistry.h +++ b/mlir/include/mlir/IR/DialectRegistry.h @@ -26,6 +26,8 @@ using DialectAllocatorFunction = std::function; using DialectAllocatorFunctionRef = function_ref; +using DynamicDialectPopulationFunction = + std::function<void(MLIRContext *, DynamicDialect *)>; //===----------------------------------------------------------------------===// // DialectExtension @@ -135,8 +137,15 @@ void insert(TypeID typeID, StringRef name, const DialectAllocatorFunction &ctor); - /// Return an allocation function for constructing the dialect identified by - /// its namespace, or nullptr if the namespace is not in this registry. + /// Add a new dynamic dialect constructor to the registry.
The constructor + /// receives the MLIRContext and the newly created dynamic dialect, and is + /// expected to register the dialect types, attributes, and ops, using the + /// methods defined in ExtensibleDialect such as registerDynamicOp. + void insertDynamic(StringRef name, + const DynamicDialectPopulationFunction &ctor); + + /// Return an allocation function for constructing the dialect identified + /// by its namespace, or nullptr if the namespace is not in this registry. DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const; // Register all dialects available in the current registry with the registry diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h --- a/mlir/include/mlir/IR/ExtensibleDialect.h +++ b/mlir/include/mlir/IR/ExtensibleDialect.h @@ -550,6 +550,30 @@ /// Owns the TypeID generated at runtime for operations. TypeIDAllocator typeIDAllocator; }; + +//===----------------------------------------------------------------------===// +// Dynamic dialect +//===----------------------------------------------------------------------===// + +/// A dialect that can be defined at runtime. It can be extended with new +/// operations, types, and attributes at runtime. +class DynamicDialect : public SelfOwningTypeID, public ExtensibleDialect { +public: + DynamicDialect(StringRef name, MLIRContext *ctx); + + TypeID getTypeID() { return SelfOwningTypeID::getTypeID(); } + + /// Check if the dialect is a dynamic dialect.
+ static bool classof(const Dialect *dialect); + + Type parseType(DialectAsmParser &parser) const override; + void printType(Type type, DialectAsmPrinter &printer) const override; + + Attribute parseAttribute(DialectAsmParser &parser, + Type type) const override; + void printAttribute(Attribute attr, + DialectAsmPrinter &printer) const override; +}; } // namespace mlir namespace llvm { @@ -561,6 +585,15 @@ return mlir::ExtensibleDialect::classof(&dialect); } }; + +/// Provide isa functionality for DynamicDialect. +/// This is to override the isa functionality for Dialect. +template <> +struct isa_impl<mlir::DynamicDialect, mlir::Dialect> { + static inline bool doit(const ::mlir::Dialect &dialect) { + return mlir::DynamicDialect::classof(&dialect); + } +}; } // namespace llvm #endif // MLIR_IR_EXTENSIBLEDIALECT_H diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -24,6 +24,7 @@ class DiagnosticEngine; class Dialect; class DialectRegistry; +class DynamicDialect; class InFlightDiagnostic; class Location; class MLIRContextImpl; @@ -110,6 +111,11 @@ loadDialect(); } + /// Get (or create and populate with `ctor`) the named dynamic dialect. + DynamicDialect * + getOrLoadDynamicDialect(StringRef dialectNamespace, + function_ref<void(DynamicDialect *)> ctor); + /// Load all dialects available in the registry in this context.
+ void loadAllAvailableDialects(); diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -11,6 +11,7 @@ #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/DialectInterface.h" +#include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" #include "llvm/ADT/MapVector.h" @@ -167,6 +168,24 @@ } } +void DialectRegistry::insertDynamic( + StringRef name, const DynamicDialectPopulationFunction &ctor) { + // This TypeID marks dynamic dialects. We cannot give a TypeID for the + // dialect yet, since the TypeID of a dynamic dialect is defined at its + // construction. + TypeID typeID = TypeID::get(); + + // Create the dialect, and then call ctor, which registers its components. + auto constructor = [nameStr = name.str(), ctor](MLIRContext *ctx) { + auto *dynDialect = ctx->getOrLoadDynamicDialect( + nameStr, [ctx, ctor](DynamicDialect *dialect) { ctor(ctx, dialect); }); + assert(dynDialect && "Dynamic dialect creation unexpectedly failed"); + return dynDialect; + }; + + insert(typeID, name, constructor); +} + void DialectRegistry::applyExtensions(Dialect *dialect) const { MLIRContext *ctx = dialect->getContext(); StringRef dialectName = dialect->getNamespace(); diff --git a/mlir/lib/IR/ExtensibleDialect.cpp b/mlir/lib/IR/ExtensibleDialect.cpp --- a/mlir/lib/IR/ExtensibleDialect.cpp +++ b/mlir/lib/IR/ExtensibleDialect.cpp @@ -507,3 +507,84 @@ } return failure(); } + +//===----------------------------------------------------------------------===// +// Dynamic dialect +//===----------------------------------------------------------------------===// + +namespace { +/// Interface that can only be implemented by dynamic dialects. +/// The interface is used to check if a dialect is dynamic or not.
+class IsDynamicDialect : public DialectInterface::Base<IsDynamicDialect> { +public: + IsDynamicDialect(Dialect *dialect) : Base(dialect) {} + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IsDynamicDialect) +}; +} // namespace + +DynamicDialect::DynamicDialect(StringRef name, MLIRContext *ctx) + : SelfOwningTypeID(), + ExtensibleDialect(name, ctx, SelfOwningTypeID::getTypeID()) { + addInterfaces<IsDynamicDialect>(); +} + +bool DynamicDialect::classof(const Dialect *dialect) { + return const_cast<Dialect *>(dialect) + ->getRegisteredInterface<IsDynamicDialect>(); +} + +Type DynamicDialect::parseType(DialectAsmParser &parser) const { + auto loc = parser.getCurrentLocation(); + StringRef typeTag; + if (failed(parser.parseKeyword(&typeTag))) + return Type(); + + { + Type dynType; + auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType); + if (parseResult.has_value()) { + if (succeeded(parseResult.value())) + return dynType; + return Type(); + } + } + + parser.emitError(loc, "expected dynamic type"); + return Type(); +} + +void DynamicDialect::printType(Type type, DialectAsmPrinter &printer) const { + auto wasDynamic = printIfDynamicType(type, printer); + (void)wasDynamic; + assert(succeeded(wasDynamic) && + "non-dynamic type defined in dynamic dialect"); +} + +Attribute DynamicDialect::parseAttribute(DialectAsmParser &parser, + Type type) const { + auto loc = parser.getCurrentLocation(); + StringRef attrTag; + if (failed(parser.parseKeyword(&attrTag))) + return Attribute(); + + { + Attribute dynAttr; + auto parseResult = parseOptionalDynamicAttr(attrTag, parser, dynAttr); + if (parseResult.has_value()) { + if (succeeded(parseResult.value())) + return dynAttr; + return Attribute(); + } + } + + parser.emitError(loc, "expected dynamic attribute"); + return Attribute(); +} +void DynamicDialect::printAttribute(Attribute attr, + DialectAsmPrinter &printer) const { + auto wasDynamic = printIfDynamicAttr(attr, printer); + (void)wasDynamic; + assert(succeeded(wasDynamic) && + "non-dynamic attribute defined in dynamic
dialect"); +} diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Location.h" #include "mlir/IR/OpImplementation.h" @@ -455,6 +456,41 @@ return dialect.get(); } +DynamicDialect *MLIRContext::getOrLoadDynamicDialect( + StringRef dialectNamespace, function_ref<void(DynamicDialect *)> ctor) { + auto &impl = getImpl(); + // Check whether a dialect with this namespace is already loaded. + auto dialectIt = impl.loadedDialects.find(dialectNamespace); + + if (dialectIt != impl.loadedDialects.end()) { + if (auto *dynDialect = dyn_cast<DynamicDialect>(dialectIt->second.get())) + return dynDialect; + llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace + + "' has already been registered"); + } + + LLVM_DEBUG(llvm::dbgs() << "Load new dynamic dialect in Context " + << dialectNamespace << "\n"); +#ifndef NDEBUG + if (impl.multiThreadedExecutionContext != 0) + llvm::report_fatal_error( + "Loading a dynamic dialect (" + dialectNamespace + + ") while in a multi-threaded execution context (maybe " + "the PassManager): this can indicate a " + "missing `dependentDialects` in a pass for example."); +#endif + + auto name = StringAttr::get(this, dialectNamespace); + auto *dialect = new DynamicDialect(name, this); + (void)getOrLoadDialect(name, dialect->getTypeID(), [dialect, ctor]() { + ctor(dialect); + return std::unique_ptr<DynamicDialect>(dialect); + }); + // This is the same result as `getOrLoadDialect` (if it didn't fail), + // since it has the same TypeID, and TypeIDs are unique.
+ return dialect; +} + void MLIRContext::loadAllAvailableDialects() { for (StringRef name : getAvailableDialects()) getOrLoadDialect(name); diff --git a/mlir/test/IR/dynamic.mlir b/mlir/test/IR/dynamic.mlir --- a/mlir/test/IR/dynamic.mlir +++ b/mlir/test/IR/dynamic.mlir @@ -124,3 +124,18 @@ test.dynamic_custom_parser_printer custom_keyword return } + +//===----------------------------------------------------------------------===// +// Dynamic dialect +//===----------------------------------------------------------------------===// + +// ----- + +// Check that the verifier of a dynamic operation in a dynamic dialect +// can fail. This shows that the dialect is correctly registered. + +func.func @failedDynamicDialectOpVerifier() { + // expected-error@+1 {{expected a single result, no operands and no regions}} + "test_dyn.one_result"() : () -> () + return +} diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt --- a/mlir/test/lib/Dialect/CMakeLists.txt +++ b/mlir/test/lib/Dialect/CMakeLists.txt @@ -11,6 +11,7 @@ add_subdirectory(SPIRV) add_subdirectory(Tensor) add_subdirectory(Test) +add_subdirectory(TestDyn) add_subdirectory(Tosa) add_subdirectory(Transform) add_subdirectory(Vector) diff --git a/mlir/test/lib/Dialect/TestDyn/CMakeLists.txt b/mlir/test/lib/Dialect/TestDyn/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/TestDyn/CMakeLists.txt @@ -0,0 +1,8 @@ +add_mlir_dialect_library(MLIRTestDynDialect + TestDynDialect.cpp + + EXCLUDE_FROM_LIBMLIR + + LINK_LIBS PUBLIC + MLIRIR +) diff --git a/mlir/test/lib/Dialect/TestDyn/TestDynDialect.cpp b/mlir/test/lib/Dialect/TestDyn/TestDynDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/TestDyn/TestDynDialect.cpp @@ -0,0 +1,36 @@ +//===- TestDynDialect.cpp -------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a fake 'test_dyn' dynamic dialect, with a single +// 'one_result' operation, that is used to test dynamic dialect registration. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/ExtensibleDialect.h" + +using namespace mlir; + +namespace test { +void registerTestDynDialect(DialectRegistry &registry) { + registry.insertDynamic( + "test_dyn", [](MLIRContext *ctx, DynamicDialect *testDyn) { + auto opVerifier = [](Operation *op) -> LogicalResult { + if (op->getNumOperands() == 0 && op->getNumResults() == 1 && + op->getNumRegions() == 0) + return success(); + return op->emitError( + "expected a single result, no operands and no regions"); + }; + + auto opRegionVerifier = [](Operation *op) { return success(); }; + + testDyn->registerDynamicOp(DynamicOpDefinition::get( + "one_result", testDyn, opVerifier, opRegionVerifier)); + }); +} +} // namespace test diff --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir --- a/mlir/test/mlir-opt/commandline.mlir +++ b/mlir/test/mlir-opt/commandline.mlir @@ -34,6 +34,7 @@ // CHECK-NEXT: spv // CHECK-NEXT: tensor // CHECK-NEXT: test +// CHECK-NEXT: test_dyn // CHECK-NEXT: tosa // CHECK-NEXT: transform // CHECK-NEXT: vector diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -27,6 +27,7 @@ MLIRTensorTestPasses MLIRTestAnalysis MLIRTestDialect + MLIRTestDynDialect MLIRTestIR MLIRTestPass MLIRTestPDLL diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -121,6 +121,7 @@ namespace test { void registerTestDialect(DialectRegistry &);
void registerTestTransformDialectExtension(DialectRegistry &); +void registerTestDynDialect(DialectRegistry &); } // namespace test #ifdef MLIR_INCLUDE_TESTS @@ -225,6 +226,7 @@ #ifdef MLIR_INCLUDE_TESTS ::test::registerTestDialect(registry); ::test::registerTestTransformDialectExtension(registry); + ::test::registerTestDynDialect(registry); #endif return mlir::asMainReturnCode( mlir::MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry,