diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h
--- a/mlir/include/mlir/IR/ExtensibleDialect.h
+++ b/mlir/include/mlir/IR/ExtensibleDialect.h
@@ -540,6 +540,39 @@
   /// Owns the TypeID generated at runtime for operations.
   TypeIDAllocator typeIDAllocator;
 };
+
+//===----------------------------------------------------------------------===//
+// Dynamic dialect
+//===----------------------------------------------------------------------===//
+
+/// A dialect that can be defined at runtime.
+/// It can be extended with new operations, types, and attributes at runtime.
+///
+/// `SelfOwningTypeID` is deliberately the first base class: bases are
+/// constructed in declaration order, so the TypeID already exists when it is
+/// handed to the `ExtensibleDialect` constructor.
+class DynamicDialect : public SelfOwningTypeID, public ExtensibleDialect {
+public:
+  DynamicDialect(StringRef name, MLIRContext *ctx)
+      : SelfOwningTypeID(),
+        ExtensibleDialect(name, ctx, SelfOwningTypeID::getTypeID()) {}
+
+  /// Return the TypeID of this dialect. Both base classes provide a
+  /// `getTypeID`, so this declaration makes unqualified calls unambiguous.
+  TypeID getTypeID() const { return SelfOwningTypeID::getTypeID(); }
+
+  /// Parse a type of this dialect; only dynamic types are supported.
+  Type parseType(DialectAsmParser &parser) const override;
+  /// Print a type of this dialect; it must be a dynamic type.
+  void printType(Type type, DialectAsmPrinter &printer) const override;
+
+  /// Parse an attribute of this dialect; only dynamic attributes are
+  /// supported.
+  Attribute parseAttribute(DialectAsmParser &parser, Type type) const override;
+  /// Print an attribute of this dialect; it must be a dynamic attribute.
+  void printAttribute(Attribute attr,
+                      DialectAsmPrinter &printer) const override;
+};
 } // namespace mlir
 
 namespace llvm {
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -24,6 +24,7 @@
 class DiagnosticEngine;
 class Dialect;
 class DialectRegistry;
+class DynamicDialect;
 class InFlightDiagnostic;
 class Location;
 class MLIRContextImpl;
@@ -110,6 +111,12 @@
     loadDialect();
   }
 
+  /// Create and load a new dialect that will be populated with operations,
+  /// types, and attributes at runtime. Abort if a dialect with the same
+  /// namespace is already loaded. The returned dialect is owned by this
+  /// context.
+  DynamicDialect *createDynamicDialect(StringRef dialectNamespace);
+
   /// Load all dialects available in the registry in this context.
   void loadAllAvailableDialects();
 
diff --git a/mlir/lib/IR/ExtensibleDialect.cpp b/mlir/lib/IR/ExtensibleDialect.cpp
--- a/mlir/lib/IR/ExtensibleDialect.cpp
+++ b/mlir/lib/IR/ExtensibleDialect.cpp
@@ -498,3 +498,75 @@
   }
   return failure();
 }
+
+//===----------------------------------------------------------------------===//
+// Dynamic dialect
+//===----------------------------------------------------------------------===//
+
+/// Parse a type of this dialect. Only types registered at runtime are
+/// supported; any other keyword is reported as an error.
+Type DynamicDialect::parseType(DialectAsmParser &parser) const {
+  auto loc = parser.getCurrentLocation();
+  StringRef typeTag;
+  if (failed(parser.parseKeyword(&typeTag)))
+    return Type();
+
+  {
+    Type dynType;
+    auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType);
+    // An empty result means `typeTag` names no dynamic type of this dialect;
+    // otherwise it holds whether parsing that type succeeded.
+    if (parseResult.hasValue()) {
+      if (succeeded(parseResult.getValue()))
+        return dynType;
+      return Type();
+    }
+  }
+
+  parser.emitError(loc, "expected dynamic type");
+  return Type();
+}
+
+/// Print a type of this dialect, which must be a dynamic type.
+void DynamicDialect::printType(Type type, DialectAsmPrinter &printer) const {
+  auto wasDynamic = printIfDynamicType(type, printer);
+  // Keep release builds (where `assert` expands to nothing) warning-free.
+  (void)wasDynamic;
+  assert(succeeded(wasDynamic) &&
+         "non-dynamic type defined in dynamic dialect");
+}
+
+/// Parse an attribute of this dialect. Only attributes registered at runtime
+/// are supported; any other keyword is reported as an error.
+Attribute DynamicDialect::parseAttribute(DialectAsmParser &parser,
+                                         Type type) const {
+  auto loc = parser.getCurrentLocation();
+  StringRef attrTag;
+  if (failed(parser.parseKeyword(&attrTag)))
+    return Attribute();
+
+  {
+    Attribute dynAttr;
+    auto parseResult = parseOptionalDynamicAttr(attrTag, parser, dynAttr);
+    // An empty result means `attrTag` names no dynamic attribute of this
+    // dialect; otherwise it holds whether parsing that attribute succeeded.
+    if (parseResult.hasValue()) {
+      if (succeeded(parseResult.getValue()))
+        return dynAttr;
+      return Attribute();
+    }
+  }
+
+  parser.emitError(loc, "expected dynamic attribute");
+  return Attribute();
+}
+
+/// Print an attribute of this dialect, which must be a dynamic attribute.
+void DynamicDialect::printAttribute(Attribute attr,
+                                    DialectAsmPrinter &printer) const {
+  auto wasDynamic = printIfDynamicAttr(attr, printer);
+  // Keep release builds (where `assert` expands to nothing) warning-free.
+  (void)wasDynamic;
+  assert(succeeded(wasDynamic) &&
+         "non-dynamic attribute defined in dynamic dialect");
+}
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/ExtensibleDialect.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/OpImplementation.h"
@@ -456,6 +457,28 @@
   return dialect.get();
 }
 
+DynamicDialect *MLIRContext::createDynamicDialect(StringRef dialectNamespace) {
+  // Intern the namespace in the context so that the name given to the dialect
+  // stays valid independently of the caller's string.
+  auto name = StringAttr::get(this, dialectNamespace);
+
+  // The dialect has to exist before it is loaded because it owns the TypeID
+  // used to register it. Keep it owned until the context takes it over.
+  auto ownedDialect = std::make_unique<DynamicDialect>(name, this);
+  DynamicDialect *dialect = ownedDialect.get();
+  Dialect *loadedDialect =
+      getOrLoadDialect(name, dialect->getTypeID(),
+                       [&ownedDialect]() -> std::unique_ptr<Dialect> {
+                         return std::move(ownedDialect);
+                       });
+  // The TypeID was freshly allocated, so no other dialect can share it: the
+  // dialect is either loaded here or `getOrLoadDialect` aborts because the
+  // namespace is already taken.
+  (void)loadedDialect;
+  assert(loadedDialect == dialect && "dynamic dialect was not loaded");
+  return dialect;
+}
+
 void MLIRContext::loadAllAvailableDialects() {
   for (StringRef name : getAvailableDialects())
     getOrLoadDialect(name);