Let TypeDefGen synthesize the parser and printer for parameterless types.

A TypeDef with a mnemonic, no parameters, and no custom parser code is parsed
by calling `get(ctxt)` directly. A TypeDef with a mnemonic, no parameters, and
no custom printer code is printed as just its mnemonic. The parser and printer
decisions are independent: each one looks only at its own custom code. The
generated printer dispatch now returns the LogicalResult produced by the
TypeSwitch instead of going through a captured `found` flag.

diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -21,9 +21,6 @@
 
 def SimpleTypeA : Test_Type<"SimpleA"> {
   let mnemonic = "smpla";
-
-  let printer = [{ $_printer << "smpla"; }];
-  let parser = [{ return get($_ctxt); }];
 }
 
 // A more complex parameterized type.
diff --git a/mlir/tools/mlir-tblgen/TypeDefGen.cpp b/mlir/tools/mlir-tblgen/TypeDefGen.cpp
--- a/mlir/tools/mlir-tblgen/TypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/TypeDefGen.cpp
@@ -537,12 +537,21 @@
   os << "static ::mlir::Type generatedTypeParser(::mlir::MLIRContext* "
         "ctxt, "
         "::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic) {\n";
-  for (const TypeDef &type : types)
-    if (type.getMnemonic())
+  for (const TypeDef &type : types) {
+    if (type.getMnemonic()) {
       os << formatv("  if (mnemonic == {0}::{1}::getMnemonic()) return "
-                    "{0}::{1}::parse(ctxt, parser);\n",
+                    "{0}::{1}::",
                     type.getDialect().getCppNamespace(),
                     type.getCppClassName());
+
+      // If the type has no parameters and no parser code, just invoke a normal
+      // `get`.
+      if (type.getNumParameters() == 0 && !type.getParserCode())
+        os << "get(ctxt);\n";
+      else
+        os << "parse(ctxt, parser);\n";
+    }
+  }
   os << "  return ::mlir::Type();\n";
   os << "}\n\n";
 
@@ -551,17 +560,26 @@
   os << "static ::mlir::LogicalResult generatedTypePrinter(::mlir::Type "
         "type, "
         "::mlir::DialectAsmPrinter& printer) {\n"
-     << "  ::mlir::LogicalResult found = ::mlir::success();\n"
-     << "  ::llvm::TypeSwitch<::mlir::Type>(type)\n";
-  for (const TypeDef &type : types)
-    if (type.getMnemonic())
-      os << formatv("    .Case<{0}::{1}>([&](::mlir::Type t) {{ "
-                    "t.dyn_cast<{0}::{1}>().print(printer); })\n",
-                    type.getDialect().getCppNamespace(),
-                    type.getCppClassName());
-  os << "    .Default([&found](::mlir::Type) { found = ::mlir::failure(); "
-        "});\n"
-     << "  return found;\n"
+     << "  return ::llvm::TypeSwitch<::mlir::Type, "
+        "::mlir::LogicalResult>(type)\n";
+  for (const TypeDef &type : types) {
+    if (Optional<StringRef> mnemonic = type.getMnemonic()) {
+      StringRef cppNamespace = type.getDialect().getCppNamespace();
+      StringRef cppClassName = type.getCppClassName();
+      os << formatv("    .Case<{0}::{1}>([&]({0}::{1} t) {{\n      ",
+                    cppNamespace, cppClassName);
+
+      // If the type has no parameters and no printer code, just print the
+      // mnemonic.
+      if (type.getNumParameters() == 0 && !type.getPrinterCode())
+        os << formatv("printer << {0}::{1}::getMnemonic();", cppNamespace,
+                      cppClassName);
+      else
+        os << "t.print(printer);";
+      os << "\n      return ::mlir::success();\n    })\n";
+    }
+  }
+  os << "    .Default([](::mlir::Type) { return ::mlir::failure(); });\n"
      << "}\n\n";
 }
 