diff --git a/mlir/include/mlir/IR/DialectRegistry.h b/mlir/include/mlir/IR/DialectRegistry.h
--- a/mlir/include/mlir/IR/DialectRegistry.h
+++ b/mlir/include/mlir/IR/DialectRegistry.h
@@ -26,6 +26,10 @@
 
 using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
 using DialectAllocatorFunctionRef = function_ref<Dialect *(MLIRContext *)>;
+/// Populates a newly created dynamic dialect with its types, attributes, and
+/// operations.
+using DynamicDialectPopulationFunction =
+    std::function<void(MLIRContext *, DynamicDialect *)>;
 
 //===----------------------------------------------------------------------===//
 // DialectExtension
@@ -136,6 +140,12 @@
   void insert(TypeID typeID, StringRef name,
               const DialectAllocatorFunction &ctor);
 
+  /// Register a dynamic dialect under the namespace `name`. Loading it creates
+  /// an empty `DynamicDialect` in the context, then calls `ctor` to populate it
+  /// with its types, attributes, and operations.
+  void insertDynamic(StringRef name,
+                     const DynamicDialectPopulationFunction &ctor);
+
   /// Return an allocation function for constructing the dialect identified by
   /// its namespace, or nullptr if the namespace is not in this registry.
   DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const;
diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h
--- a/mlir/include/mlir/IR/ExtensibleDialect.h
+++ b/mlir/include/mlir/IR/ExtensibleDialect.h
@@ -540,6 +540,39 @@
   /// Owns the TypeID generated at runtime for operations.
   TypeIDAllocator typeIDAllocator;
 };
+
+//===----------------------------------------------------------------------===//
+// Dynamic dialect
+//===----------------------------------------------------------------------===//
+
+/// A dialect that can be defined at runtime.
+/// It can be extended with new operations, types, and attributes at runtime.
+///
+/// `SelfOwningTypeID` is deliberately listed first: base classes are
+/// constructed in declaration order, so the TypeID it owns is available when
+/// it is passed to the `ExtensibleDialect` constructor.
+class DynamicDialect : public SelfOwningTypeID, public ExtensibleDialect {
+public:
+  DynamicDialect(StringRef name, MLIRContext *ctx)
+      : ExtensibleDialect(name, ctx, SelfOwningTypeID::getTypeID()) {}
+
+  /// Return the TypeID owned by this dialect. `SelfOwningTypeID` and `Dialect`
+  /// both provide `getTypeID`, so this resolves the ambiguity. It is `const`
+  /// so that calls through a const pointer still resolve.
+  TypeID getTypeID() const { return SelfOwningTypeID::getTypeID(); }
+
+  /// Parse a dynamic type of this dialect; emit an error for any other type.
+  Type parseType(DialectAsmParser &parser) const override;
+  /// Print a type of this dialect, which is expected to be a dynamic type.
+  void printType(Type type, DialectAsmPrinter &printer) const override;
+
+  /// Parse a dynamic attribute of this dialect; emit an error for any other
+  /// attribute.
+  Attribute parseAttribute(DialectAsmParser &parser, Type type) const override;
+  /// Print an attribute of this dialect, which is expected to be dynamic.
+  void printAttribute(Attribute attr,
+                      DialectAsmPrinter &printer) const override;
+};
 } // namespace mlir
 
 namespace llvm {
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -24,6 +24,7 @@
 class DiagnosticEngine;
 class Dialect;
 class DialectRegistry;
+class DynamicDialect;
 class InFlightDiagnostic;
 class Location;
 class MLIRContextImpl;
@@ -110,6 +111,11 @@
     loadDialect<OtherDialect, MoreDialects...>();
   }
 
+  /// Create and load a new dialect that will be populated with operations,
+  /// types, and attributes at runtime. Return nullptr if a dialect with the
+  /// same namespace is already loaded.
+  DynamicDialect *createDynamicDialect(StringRef dialectNamespace);
+
   /// Load all dialects available in the registry in this context.
   void loadAllAvailableDialects();
 
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -163,6 +163,33 @@
   }
 }
 
+void DialectRegistry::insertDynamic(
+    StringRef name, const DynamicDialectPopulationFunction &ctor) {
+  // A dynamic dialect owns its TypeID, which therefore only exists once the
+  // dialect is constructed. Until then, register it under the `void` TypeID,
+  // which marks dynamic dialects.
+  TypeID typeID = TypeID::get<void>();
+
+  // Create and load the dialect, then call `ctor` to populate it with its
+  // types, attributes, and operations.
+  auto constructor = [nameStr = name.str(),
+                      ctor](MLIRContext *ctx) -> Dialect * {
+    DynamicDialect *dynDialect = ctx->createDynamicDialect(nameStr);
+    assert(dynDialect && "Dynamic dialect creation unexpectedly failed");
+    // Do not populate a null dialect when assertions are disabled.
+    if (!dynDialect)
+      return nullptr;
+    ctor(ctx, dynDialect);
+    // Look the dialect up instead of casting `dynDialect`: converting
+    // `DynamicDialect *` to `Dialect *` needs the complete class (it has
+    // several bases), and a C-style cast on an incomplete type silently
+    // becomes a reinterpret_cast.
+    return ctx->getLoadedDialect(nameStr);
+  };
+
+  insert(typeID, name, constructor);
+}
+
 void DialectRegistry::applyExtensions(Dialect *dialect) const {
   MLIRContext *ctx = dialect->getContext();
   StringRef dialectName = dialect->getNamespace();
diff --git a/mlir/lib/IR/ExtensibleDialect.cpp b/mlir/lib/IR/ExtensibleDialect.cpp
--- a/mlir/lib/IR/ExtensibleDialect.cpp
+++ b/mlir/lib/IR/ExtensibleDialect.cpp
@@ -498,3 +498,58 @@
   }
   return failure();
 }
+
+//===----------------------------------------------------------------------===//
+// Dynamic dialect
+//===----------------------------------------------------------------------===//
+
+Type DynamicDialect::parseType(DialectAsmParser &parser) const {
+  auto loc = parser.getCurrentLocation();
+  StringRef typeTag;
+  if (failed(parser.parseKeyword(&typeTag)))
+    return Type();
+
+  // If `typeTag` is a dynamic type of this dialect, return its parse result.
+  Type dynType;
+  auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType);
+  if (parseResult.hasValue())
+    return succeeded(parseResult.getValue()) ? dynType : Type();
+
+  parser.emitError(loc, "expected dynamic type");
+  return Type();
+}
+
+void DynamicDialect::printType(Type type, DialectAsmPrinter &printer) const {
+  // Every type of a dynamic dialect is expected to be dynamic.
+  LogicalResult wasDynamic = printIfDynamicType(type, printer);
+  (void)wasDynamic;
+  assert(succeeded(wasDynamic) &&
+         "non-dynamic type defined in dynamic dialect");
+}
+
+Attribute DynamicDialect::parseAttribute(DialectAsmParser &parser,
+                                         Type type) const {
+  auto loc = parser.getCurrentLocation();
+  StringRef attrTag;
+  if (failed(parser.parseKeyword(&attrTag)))
+    return Attribute();
+
+  // If `attrTag` is a dynamic attribute of this dialect, return its parse
+  // result.
+  Attribute dynAttr;
+  auto parseResult = parseOptionalDynamicAttr(attrTag, parser, dynAttr);
+  if (parseResult.hasValue())
+    return succeeded(parseResult.getValue()) ? dynAttr : Attribute();
+
+  parser.emitError(loc, "expected dynamic attribute");
+  return Attribute();
+}
+
+void DynamicDialect::printAttribute(Attribute attr,
+                                    DialectAsmPrinter &printer) const {
+  // Every attribute of a dynamic dialect is expected to be dynamic.
+  LogicalResult wasDynamic = printIfDynamicAttr(attr, printer);
+  (void)wasDynamic;
+  assert(succeeded(wasDynamic) &&
+         "non-dynamic attribute defined in dynamic dialect");
+}
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/ExtensibleDialect.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/OpImplementation.h"
@@ -456,6 +457,33 @@
   return dialect.get();
 }
 
+DynamicDialect *MLIRContext::createDynamicDialect(StringRef dialectNamespace) {
+  // Never replace a dialect that is already loaded under this namespace.
+  if (getLoadedDialect(dialectNamespace))
+    return nullptr;
+
+  // The dialect keeps a reference to its namespace, so copy the string into
+  // memory owned by the context. A raw character copy needs no destructor,
+  // unlike a std::string placed in the bump allocator, which would leak its
+  // heap buffer since the allocator never runs destructors.
+  StringRef name =
+      dialectNamespace.copy(getImpl().abstractDialectSymbolAllocator);
+
+  // A dynamic dialect owns its TypeID, so it must exist before it is loaded.
+  // Keep ownership local until the context takes it over.
+  auto dialect = std::make_unique<DynamicDialect>(name, this);
+  DynamicDialect *dialectPtr = dialect.get();
+  (void)getOrLoadDialect(name, dialectPtr->getTypeID(),
+                         [&dialect]() -> std::unique_ptr<Dialect> {
+                           return std::move(dialect);
+                         });
+  // `dialect` is still set only if the context did not take ownership, in
+  // which case it is destroyed here instead of being leaked.
+  if (dialect)
+    return nullptr;
+  return dialectPtr;
+}
+
 void MLIRContext::loadAllAvailableDialects() {
   for (StringRef name : getAvailableDialects())
     getOrLoadDialect(name);
diff --git a/mlir/test/IR/dynamic.mlir b/mlir/test/IR/dynamic.mlir
--- a/mlir/test/IR/dynamic.mlir
+++ b/mlir/test/IR/dynamic.mlir
@@ -124,3 +124,18 @@
   test.dynamic_custom_parser_printer custom_keyword
   return
 }
+
+//===----------------------------------------------------------------------===//
+// Dynamic dialect
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// Check that the verifier of a dynamic operation in a dynamic dialect
+// can fail. This shows that the dialect is correctly registered.
+
+func.func @failedDynamicDialectOpVerifier() {
+  // expected-error@+1 {{expected a single result, no operands and no regions}}
+  "test_dyn.one_result"() : () -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -10,6 +10,7 @@
 add_subdirectory(SPIRV)
 add_subdirectory(Tensor)
 add_subdirectory(Test)
+add_subdirectory(TestDyn)
 add_subdirectory(Tosa)
 add_subdirectory(Transform)
 add_subdirectory(Vector)
diff --git a/mlir/test/lib/Dialect/TestDyn/CMakeLists.txt b/mlir/test/lib/Dialect/TestDyn/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestDyn/CMakeLists.txt
@@ -0,0 +1,8 @@
+add_mlir_dialect_library(MLIRTestDynDialect
+  TestDynDialect.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+)
diff --git a/mlir/test/lib/Dialect/TestDyn/TestDynDialect.h b/mlir/test/lib/Dialect/TestDyn/TestDynDialect.h
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestDyn/TestDynDialect.h
@@ -0,0 +1,24 @@
+//===- TestDynDialect.h ---------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a fake 'test_dyn' dynamic dialect that is used to test the
+// registration of dynamic dialects.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TESTDYNDIALECT_H
+#define MLIR_TESTDYNDIALECT_H
+
+#include "mlir/IR/DialectRegistry.h"
+
+namespace test_dynamic {
+/// Register the 'test_dyn' dynamic dialect in `registry`.
+void registerTestDynDialect(::mlir::DialectRegistry &registry);
+} // namespace test_dynamic
+
+#endif // MLIR_TESTDYNDIALECT_H
diff --git a/mlir/test/lib/Dialect/TestDyn/TestDynDialect.cpp b/mlir/test/lib/Dialect/TestDyn/TestDynDialect.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestDyn/TestDynDialect.cpp
@@ -0,0 +1,41 @@
+//===- TestDynDialect.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a fake 'test_dyn' dynamic dialect that is used to test the
+// registration of dynamic dialects.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDynDialect.h"
+#include "mlir/IR/ExtensibleDialect.h"
+
+using namespace mlir;
+
+namespace test_dynamic {
+
+void registerTestDynDialect(DialectRegistry &registry) {
+  registry.insertDynamic(
+      "test_dyn", [](MLIRContext * /*ctx*/, DynamicDialect *testDyn) {
+        // `test_dyn.one_result` has exactly one result, and no operands or
+        // regions.
+        auto verifier = [](Operation *op) -> LogicalResult {
+          if (op->getNumOperands() != 0 || op->getNumResults() != 1 ||
+              op->getNumRegions() != 0)
+            return op->emitError(
+                "expected a single result, no operands and no regions");
+          return success();
+        };
+        // The op has no regions, so there is nothing more to verify.
+        auto regionVerifier = [](Operation *) { return success(); };
+        testDyn->registerDynamicOp(DynamicOpDefinition::get(
+            "one_result", testDyn, std::move(verifier),
+            std::move(regionVerifier)));
+      });
+}
+
+} // namespace test_dynamic
diff --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir
--- a/mlir/test/mlir-opt/commandline.mlir
+++ b/mlir/test/mlir-opt/commandline.mlir
@@ -34,6 +34,7 @@
 // CHECK-NEXT: spv
 // CHECK-NEXT: tensor
 // CHECK-NEXT: test
+// CHECK-NEXT: test_dyn
 // CHECK-NEXT: tosa
 // CHECK-NEXT: transform
 // CHECK-NEXT: vector
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -26,6 +26,7 @@
   MLIRTensorTestPasses
   MLIRTestAnalysis
   MLIRTestDialect
+  MLIRTestDynDialect
   MLIRTestIR
   MLIRTestPass
   MLIRTestPDLL
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -121,6 +121,10 @@
 void registerTestTransformDialectExtension(DialectRegistry &);
 } // namespace test
 
+namespace test_dynamic {
+void registerTestDynDialect(DialectRegistry &);
+} // namespace test_dynamic
+
 #ifdef MLIR_INCLUDE_TESTS
 void registerTestPasses() {
   registerCloneTestPasses();
@@ -221,6 +225,7 @@
 #ifdef MLIR_INCLUDE_TESTS
   ::test::registerTestDialect(registry);
   ::test::registerTestTransformDialectExtension(registry);
+  ::test_dynamic::registerTestDynDialect(registry);
 #endif
   return mlir::asMainReturnCode(
       mlir::MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry,