diff --git a/mlir/docs/DefiningDialects.md b/mlir/docs/DefiningDialects.md --- a/mlir/docs/DefiningDialects.md +++ b/mlir/docs/DefiningDialects.md @@ -372,6 +372,30 @@ } ``` +### Defining a dynamic dialect + +Dynamic dialects are extensible dialects that can be defined at runtime. They +are only populated with dynamic operations, types, and attributes. They can be +registered in a `DialectRegistry` with `insertDynamic`. + +```c++ +auto populateDialect = [](MLIRContext *ctx, DynamicDialect *dialect) { + // Code that will be run when the dynamic dialect is created and loaded. + // For instance, this is where we register the dynamic operations, types, and + // attributes of the dialect. + ... +}; + +registry.insertDynamic("dialectName", populateDialect); +``` + +Once a dynamic dialect is registered in the `MLIRContext` (through its +`DialectRegistry`), it can be loaded and retrieved with `getOrLoadDialect`. + +```c++ +Dialect *dialect = ctx->getOrLoadDialect("dialectName"); +``` + ### Defining an operation at runtime The `DynamicOpDefinition` class represents the definition of an operation diff --git a/mlir/include/mlir/IR/DialectRegistry.h b/mlir/include/mlir/IR/DialectRegistry.h --- a/mlir/include/mlir/IR/DialectRegistry.h +++ b/mlir/include/mlir/IR/DialectRegistry.h @@ -26,6 +26,8 @@ using DialectAllocatorFunction = std::function; using DialectAllocatorFunctionRef = function_ref; +using DynamicDialectPopulationFunction = + std::function<void(MLIRContext *, DynamicDialect *)>; //===----------------------------------------------------------------------===// // DialectExtension @@ -136,8 +138,14 @@ void insert(TypeID typeID, StringRef name, const DialectAllocatorFunction &ctor); - /// Return an allocation function for constructing the dialect identified by - /// its namespace, or nullptr if the namespace is not in this registry. + /// Add a dynamic dialect named `name` to the registry. When the dialect is + /// loaded, `ctor` is called to populate it with its types, attributes, and + /// ops.
+ void insertDynamic(StringRef name, + const DynamicDialectPopulationFunction &ctor); + + /// Return an allocation function for constructing the dialect identified + /// by its namespace, or nullptr if the namespace is not in this registry. DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const; // Register all dialects available in the current registry with the registry diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h --- a/mlir/include/mlir/IR/ExtensibleDialect.h +++ b/mlir/include/mlir/IR/ExtensibleDialect.h @@ -540,6 +540,35 @@ /// Owns the TypeID generated at runtime for operations. TypeIDAllocator typeIDAllocator; }; + +//===----------------------------------------------------------------------===// +// Dynamic dialect +//===----------------------------------------------------------------------===// + +/// A dialect that can be defined at runtime. +/// It can be extended with new operations, types, and attributes at runtime. +class DynamicDialect : public SelfOwningTypeID, public ExtensibleDialect { +public: + /// Create a dynamic dialect with namespace `name` in the context `ctx`. + DynamicDialect(StringRef name, MLIRContext *ctx); + + /// Return the TypeID owned by this dialect instance. It is generated at + /// runtime, when the dialect is constructed. + TypeID getTypeID() { return SelfOwningTypeID::getTypeID(); } + + /// Check if the dialect is a dynamic dialect. + static bool classof(const Dialect *dialect); + + /// Parse and print types of this dialect; only dynamic types are supported. + Type parseType(DialectAsmParser &parser) const override; + void printType(Type type, DialectAsmPrinter &printer) const override; + + /// Parse and print attributes; only dynamic attributes are supported. 
+ Attribute parseAttribute(DialectAsmParser &parser, + Type type) const override; + void printAttribute(Attribute attr, + DialectAsmPrinter &printer) const override; +}; } // namespace mlir namespace llvm { @@ -551,6 +580,15 @@ return mlir::ExtensibleDialect::classof(&dialect); } }; + +/// Provide isa functionality for DynamicDialect. +/// This is to override the isa functionality for Dialect.
+template <> +struct isa_impl<mlir::DynamicDialect, ::mlir::Dialect> { + static inline bool doit(const ::mlir::Dialect &dialect) { + return mlir::DynamicDialect::classof(&dialect); + } +}; } // namespace llvm #endif // MLIR_IR_EXTENSIBLEDIALECT_H diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -24,6 +24,7 @@ class DiagnosticEngine; class Dialect; class DialectRegistry; +class DynamicDialect; class InFlightDiagnostic; class Location; class MLIRContextImpl; @@ -110,6 +111,14 @@ loadDialect(); } + /// Get (or create) a dynamic dialect for the given namespace. If a dialect + /// with this namespace is already loaded, it is returned when it is a + /// DynamicDialect, and a fatal error is reported otherwise. A newly created + /// dialect is populated by calling `ctor` on it. + DynamicDialect * + getOrLoadDynamicDialect(StringRef dialectNamespace, + function_ref<void(DynamicDialect &)> ctor); + /// Load all dialects available in the registry in this context. void loadAllAvailableDialects(); diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -163,6 +163,24 @@ } } +void DialectRegistry::insertDynamic( + StringRef name, const DynamicDialectPopulationFunction &ctor) { + // Dynamic dialects are registered under the `void` TypeID as a marker: the + // actual TypeID of a dynamic dialect is only defined when the dialect is + // constructed. + auto typeID = TypeID::get<void>(); + + // Create the dialect, and then call `ctor` to populate it.
+ auto constructor = [nameStr = name.str(), ctor](MLIRContext *ctx) { + auto *dynDialect = ctx->getOrLoadDynamicDialect( + nameStr, [ctx, ctor](DynamicDialect &dialect) { ctor(ctx, &dialect); }); + assert(dynDialect && "Dynamic dialect creation unexpectedly failed"); + return static_cast<Dialect *>(dynDialect); + }; + + insert(typeID, name, constructor); +} + void DialectRegistry::applyExtensions(Dialect *dialect) const { MLIRContext *ctx = dialect->getContext(); StringRef dialectName = dialect->getNamespace(); diff --git a/mlir/lib/IR/ExtensibleDialect.cpp b/mlir/lib/IR/ExtensibleDialect.cpp --- a/mlir/lib/IR/ExtensibleDialect.cpp +++ b/mlir/lib/IR/ExtensibleDialect.cpp @@ -498,3 +498,84 @@ } return failure(); } + +//===----------------------------------------------------------------------===// +// Dynamic dialect +//===----------------------------------------------------------------------===// + +namespace { +/// Interface that is only attached to dynamic dialects. +/// The interface is used to check if a dialect is dynamic or not.
+class IsDynamicDialect : public DialectInterface::Base<IsDynamicDialect> { +public: + IsDynamicDialect(Dialect *dialect) : Base(dialect) {} + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IsDynamicDialect) +}; +} // namespace + +// SelfOwningTypeID is the first base class, so it is initialized first and +// its TypeID is available when constructing ExtensibleDialect. +DynamicDialect::DynamicDialect(StringRef name, MLIRContext *ctx) + : SelfOwningTypeID(), + ExtensibleDialect(name, ctx, SelfOwningTypeID::getTypeID()) { + addInterfaces<IsDynamicDialect>(); +} + +// Only dynamic dialects register the IsDynamicDialect interface. +bool DynamicDialect::classof(const Dialect *dialect) { + return const_cast<Dialect *>(dialect) + ->getRegisteredInterface<IsDynamicDialect>(); +} + +Type DynamicDialect::parseType(DialectAsmParser &parser) const { + auto loc = parser.getCurrentLocation(); + StringRef typeTag; + if (failed(parser.parseKeyword(&typeTag))) + return Type(); + + Type dynType; + auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType); + if (parseResult.hasValue()) { + if (succeeded(parseResult.getValue())) + return dynType; + return Type(); + } + + parser.emitError(loc, "expected dynamic type"); + return Type(); +} + +void DynamicDialect::printType(Type type, DialectAsmPrinter &printer) const { + auto wasDynamic = printIfDynamicType(type, printer); + (void)wasDynamic; + assert(succeeded(wasDynamic) && + "non-dynamic type defined in dynamic dialect"); +} + +Attribute DynamicDialect::parseAttribute(DialectAsmParser &parser, + Type type) const { + auto loc = parser.getCurrentLocation(); + StringRef attrTag; + if (failed(parser.parseKeyword(&attrTag))) + return Attribute(); + + Attribute dynAttr; + auto parseResult = parseOptionalDynamicAttr(attrTag, parser, dynAttr); + if (parseResult.hasValue()) { + if (succeeded(parseResult.getValue())) + return dynAttr; + return Attribute(); + } + + parser.emitError(loc, "expected dynamic attribute"); + return Attribute(); +} + +void 
DynamicDialect::printAttribute(Attribute attr, + DialectAsmPrinter &printer) const { + auto wasDynamic = printIfDynamicAttr(attr, printer); + (void)wasDynamic; + assert(succeeded(wasDynamic) && + "non-dynamic attribute defined in dynamic dialect"); +}
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Location.h" #include "mlir/IR/OpImplementation.h" @@ -456,6 +457,44 @@ return dialect.get(); } +DynamicDialect *MLIRContext::getOrLoadDynamicDialect( + StringRef dialectNamespace, function_ref<void(DynamicDialect &)> ctor) { + auto &impl = getImpl(); + // Check whether a dialect with this namespace is already loaded. + auto dialectIt = impl.loadedDialects.find(dialectNamespace); + + if (dialectIt != impl.loadedDialects.end()) { + if (auto *dynDialect = + llvm::dyn_cast<DynamicDialect>(dialectIt->second.get())) + return dynDialect; + llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace + + "' has already been registered"); + } + + LLVM_DEBUG(llvm::dbgs() << "Load new dynamic dialect in Context " + << dialectNamespace << "\n"); +#ifndef NDEBUG + if (impl.multiThreadedExecutionContext != 0) + llvm::report_fatal_error( + "Loading a dynamic dialect (" + dialectNamespace + + ") while in a multi-threaded execution context (maybe " + "the PassManager): this can indicate a " + "missing `dependentDialects` in a pass for example."); +#endif + + // Unique the namespace in the context so that the dialect name is owned by + // the context rather than by the caller. + auto name = StringAttr::get(this, dialectNamespace); + auto dialectOwner = std::make_unique<DynamicDialect>(name, this); + TypeID typeID = dialectOwner->getTypeID(); + Dialect *dialect = getOrLoadDialect(name, typeID, [&dialectOwner, ctor]() { + ctor(*dialectOwner); + return std::move(dialectOwner); + }); + // `dialect` is the one created above, since the namespace was not loaded.
+ return llvm::cast<DynamicDialect>(dialect); +} + void MLIRContext::loadAllAvailableDialects() { for (StringRef name : getAvailableDialects()) getOrLoadDialect(name);
diff --git a/mlir/test/IR/dynamic.mlir b/mlir/test/IR/dynamic.mlir --- a/mlir/test/IR/dynamic.mlir +++ b/mlir/test/IR/dynamic.mlir @@ -124,3 +124,19 @@ test.dynamic_custom_parser_printer custom_keyword return } + +//===----------------------------------------------------------------------===// +// Dynamic dialect +//===----------------------------------------------------------------------===// + +// ----- + +// Check that the verifier of a dynamic operation in a dynamic dialect +// can fail. This shows that the dialect is correctly registered.
+ + +func.func @failedDynamicDialectOpVerifier() { + // expected-error@+1 {{expected a single result, no operands and no regions}} + "test_dyn.one_result"() : () -> () + return +} diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt --- a/mlir/test/lib/Dialect/CMakeLists.txt +++ b/mlir/test/lib/Dialect/CMakeLists.txt @@ -10,6 +10,7 @@ add_subdirectory(SPIRV) add_subdirectory(Tensor) add_subdirectory(Test) +add_subdirectory(TestDyn) add_subdirectory(Tosa) add_subdirectory(Transform) add_subdirectory(Vector) diff --git a/mlir/test/lib/Dialect/TestDyn/CMakeLists.txt b/mlir/test/lib/Dialect/TestDyn/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/TestDyn/CMakeLists.txt @@ -0,0 +1,8 @@ +add_mlir_dialect_library(MLIRTestDynDialect + TestDynDialect.cpp + + EXCLUDE_FROM_LIBMLIR + + LINK_LIBS PUBLIC + MLIRIR +) diff --git a/mlir/test/lib/Dialect/TestDyn/TestDynDialect.cpp b/mlir/test/lib/Dialect/TestDyn/TestDynDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/TestDyn/TestDynDialect.cpp @@ -0,0 +1,37 @@ +//===- TestDynDialect.cpp -------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a fake 'test_dyn' dynamic dialect that is used to test the +// registration of dynamic dialects.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/ExtensibleDialect.h" + +using namespace mlir; + +namespace test { +/// Register the 'test_dyn' dynamic dialect, which has a 'one_result' op. +void registerTestDynDialect(::mlir::DialectRegistry &registry) { + registry.insertDynamic("test_dyn", [](::mlir::MLIRContext *ctx, + ::mlir::DynamicDialect *testDyn) { + auto opVerifier = [](Operation *op) -> LogicalResult { + if (op->getNumOperands() == 0 && op->getNumResults() == 1 && + op->getNumRegions() == 0) + return success(); + return op->emitError( + "expected a single result, no operands and no regions"); + }; + // The region verifier always succeeds: valid 'one_result' ops have no regions. + auto opRegionVerifier = [](Operation *op) { return success(); }; + + testDyn->registerDynamicOp(DynamicOpDefinition::get( + "one_result", testDyn, opVerifier, opRegionVerifier)); + }); +} +} // namespace test diff --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir --- a/mlir/test/mlir-opt/commandline.mlir +++ b/mlir/test/mlir-opt/commandline.mlir @@ -34,6 +34,7 @@ // CHECK-NEXT: spv // CHECK-NEXT: tensor // CHECK-NEXT: test +// CHECK-NEXT: test_dyn // CHECK-NEXT: tosa // CHECK-NEXT: transform // CHECK-NEXT: vector diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -26,6 +26,7 @@ MLIRTensorTestPasses MLIRTestAnalysis MLIRTestDialect + MLIRTestDynDialect MLIRTestIR MLIRTestPass MLIRTestPDLL diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -119,6 +119,7 @@ namespace test { void registerTestDialect(DialectRegistry &); void registerTestTransformDialectExtension(DialectRegistry &); +void registerTestDynDialect(DialectRegistry &); } // namespace test #ifdef MLIR_INCLUDE_TESTS @@ 
-221,6 +222,7 @@ #ifdef MLIR_INCLUDE_TESTS ::test::registerTestDialect(registry); ::test::registerTestTransformDialectExtension(registry);
+ ::test::registerTestDynDialect(registry); #endif return mlir::asMainReturnCode( mlir::MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry,