diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td @@ -29,6 +29,7 @@ }]; let hasConstantMaterializer = 1; + let useDefaultTypePrinterParser = 1; } #endif // MLIR_DIALECT_EMITC_IR_EMITCBASE diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -302,6 +302,10 @@ // it'll dispatch the parsing to every individual attributes directly. bit useDefaultAttributePrinterParser = 0; + // If this dialect should use the default generated type parser/printer + // boilerplate: it'll dispatch parsing/printing to each individual type. + bit useDefaultTypePrinterParser = 0; + // If this dialect overrides the hook for canonicalization patterns. bit hasCanonicalizer = 0; diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h --- a/mlir/include/mlir/TableGen/Dialect.h +++ b/mlir/include/mlir/TableGen/Dialect.h @@ -78,6 +78,10 @@ /// attribute printing/parsing. bool useDefaultAttributePrinterParser() const; + /// Returns true if this dialect should generate the default dispatch for + /// type printing/parsing. + bool useDefaultTypePrinterParser() const; + // Returns whether two dialects are equal by checking the equality of the // underlying record. 
bool operator==(const Dialect &other) const; diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -227,25 +227,6 @@ return get(parser.getContext(), value); } -Type EmitCDialect::parseType(DialectAsmParser &parser) const { - llvm::SMLoc typeLoc = parser.getCurrentLocation(); - StringRef mnemonic; - if (parser.parseKeyword(&mnemonic)) - return Type(); - Type genType; - OptionalParseResult parseResult = - generatedTypeParser(parser, mnemonic, genType); - if (parseResult.hasValue()) - return genType; - parser.emitError(typeLoc, "unknown type in EmitC dialect"); - return Type(); -} - -void EmitCDialect::printType(Type type, DialectAsmPrinter &os) const { - if (failed(generatedTypePrinter(type, os))) - llvm_unreachable("unexpected 'EmitC' type kind"); -} - void emitc::OpaqueType::print(DialectAsmPrinter &printer) const { printer << "opaque<\""; llvm::printEscapedString(getValue(), printer.getStream()); diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp --- a/mlir/lib/TableGen/Dialect.cpp +++ b/mlir/lib/TableGen/Dialect.cpp @@ -94,6 +94,10 @@ return def->getValueAsBit("useDefaultAttributePrinterParser"); } +bool Dialect::useDefaultTypePrinterParser() const { + return def->getValueAsBit("useDefaultTypePrinterParser"); +} + Dialect::EmitPrefix Dialect::getEmitAccessorPrefix() const { int prefix = def->getValueAsInt("emitAccessorPrefix"); if (prefix < 0 || prefix > static_cast(EmitPrefix::Both)) diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -528,6 +528,32 @@ } )"; +/// The code block for the default type parser/printer dispatch boilerplate. +/// {0}: the fully qualified C++ class name of the dialect. 
+static const char *const dialectDefaultTypePrinterParserDispatch = R"( +/// Parse a type registered to this dialect. +::mlir::Type {0}::parseType(::mlir::DialectAsmParser &parser) const {{ + ::llvm::SMLoc typeLoc = parser.getCurrentLocation(); + ::llvm::StringRef mnemonic; + if (::mlir::failed(parser.parseKeyword(&mnemonic))) + return {{}; + ::mlir::Type genType; + ::mlir::OptionalParseResult parseResult = + generatedTypeParser(parser, mnemonic, genType); + if (parseResult.hasValue()) + return genType; + parser.emitError(typeLoc) << "unknown type `" + << mnemonic << "` in dialect `" << getNamespace() << "`"; + return {{}; +} +/// Print a type registered to this dialect. +void {0}::printType(::mlir::Type type, + ::mlir::DialectAsmPrinter &printer) const {{ + if (::mlir::failed(generatedTypePrinter(type, printer))) + llvm_unreachable("unexpected type kind in {0}"); +} +)"; + /// The code block used to start the auto-generated printer function. /// /// {0}: The name of the base value type, e.g. Attribute or Type. @@ -1020,6 +1046,12 @@ os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch, defs.front().getDialect().getCppClassName()); + // Emit the default parser/printer for Types if the dialect asked for it. + if (valueType == "Type" && + defs.front().getDialect().useDefaultTypePrinterParser()) + os << llvm::formatv(dialectDefaultTypePrinterParserDispatch, + defs.front().getDialect().getCppClassName()); + return false; } diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp --- a/mlir/tools/mlir-tblgen/DialectGen.cpp +++ b/mlir/tools/mlir-tblgen/DialectGen.cpp @@ -210,7 +210,7 @@ // add the hooks for parsing/printing. if (!dialectAttrs.empty() || dialect.useDefaultAttributePrinterParser()) os << attrParserDecl; - if (!dialectTypes.empty()) + if (!dialectTypes.empty() || dialect.useDefaultTypePrinterParser()) os << typeParserDecl; // Add the decls for the various features of the dialect.