diff --git a/mlir/test/mlir-tblgen/default-type-attr-print-parser.td b/mlir/test/mlir-tblgen/default-type-attr-print-parser.td new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/default-type-attr-print-parser.td @@ -0,0 +1,76 @@ +// RUN: mlir-tblgen -gen-attrdef-defs -I %S/../../include %s | FileCheck %s --check-prefix=ATTR +// RUN: mlir-tblgen -gen-typedef-defs -I %S/../../include %s | FileCheck %s --check-prefix=TYPE + +include "mlir/IR/OpBase.td" + +/// Test that attribute and type printers and parsers are correctly generated. +def Test_Dialect : Dialect { + let name = "TestDialect"; + let cppNamespace = "::test"; + + let useDefaultAttributePrinterParser = 1; + + let useDefaultTypePrinterParser = 1; +} + +class TestAttr<string name> : AttrDef<Test_Dialect, name>; +class TestType<string name> : TypeDef<Test_Dialect, name>; + +def AttrA : TestAttr<"AttrA"> { + let mnemonic = "attr_a"; +} + +// ATTR: namespace test { + +// ATTR: ::mlir::Attribute TestDialect::parseAttribute(::mlir::DialectAsmParser &parser, +// ATTR: ::mlir::Type type) const { +// ATTR: ::llvm::SMLoc typeLoc = parser.getCurrentLocation(); +// ATTR: ::llvm::StringRef attrTag; +// ATTR: if (::mlir::failed(parser.parseKeyword(&attrTag))) +// ATTR: return {}; +// ATTR: { +// ATTR: ::mlir::Attribute attr; +// ATTR: auto parseResult = generatedAttributeParser(parser, attrTag, type, attr); +// ATTR: if (parseResult.hasValue()) +// ATTR: return attr; +// ATTR: } +// ATTR: parser.emitError(typeLoc) << "unknown attribute `" +// ATTR: << attrTag << "` in dialect `" << getNamespace() << "`"; +// ATTR: return {}; +// ATTR: } + +// ATTR: void TestDialect::printAttribute(::mlir::Attribute attr, +// ATTR: ::mlir::DialectAsmPrinter &printer) const { +// ATTR: if (::mlir::succeeded(generatedAttributePrinter(attr, printer))) +// ATTR: return; +// ATTR: } + +// ATTR: } // namespace test + +def TypeA : TestType<"TypeA"> { + let mnemonic = "type_a"; +} + +// TYPE: namespace test { + +// TYPE: ::mlir::Type TestDialect::parseType(::mlir::DialectAsmParser &parser) const {
+// TYPE: ::llvm::SMLoc typeLoc = parser.getCurrentLocation(); +// TYPE: ::llvm::StringRef mnemonic; +// TYPE: if (parser.parseKeyword(&mnemonic)) +// TYPE: return ::mlir::Type(); +// TYPE: ::mlir::Type genType; +// TYPE: auto parseResult = generatedTypeParser(parser, mnemonic, genType); +// TYPE: if (parseResult.hasValue()) +// TYPE: return genType; +// TYPE: parser.emitError(typeLoc) << "unknown type `" +// TYPE: << mnemonic << "` in dialect `" << getNamespace() << "`"; +// TYPE: return {}; +// TYPE: } + +// TYPE: void TestDialect::printType(::mlir::Type type, +// TYPE: ::mlir::DialectAsmPrinter &printer) const { +// TYPE: if (::mlir::succeeded(generatedTypePrinter(type, printer))) +// TYPE: return; +// TYPE: } + +// TYPE: } // namespace test diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -510,7 +510,7 @@ ::mlir::Type type) const {{ ::llvm::SMLoc typeLoc = parser.getCurrentLocation(); ::llvm::StringRef attrTag; - if (failed(parser.parseKeyword(&attrTag))) + if (::mlir::failed(parser.parseKeyword(&attrTag))) return {{}; {{ ::mlir::Attribute attr; @@ -525,7 +525,7 @@ /// Print an attribute registered to this dialect. void {0}::printAttribute(::mlir::Attribute attr, ::mlir::DialectAsmPrinter &printer) const {{ - if (succeeded(generatedAttributePrinter(attr, printer))) + if (::mlir::succeeded(generatedAttributePrinter(attr, printer))) return; } )"; @@ -535,13 +535,12 @@ static const char *const dialectDefaultTypePrinterParserDispatch = R"( /// Parse a type registered to this dialect.
::mlir::Type {0}::parseType(::mlir::DialectAsmParser &parser) const {{ - llvm::SMLoc typeLoc = parser.getCurrentLocation(); - StringRef mnemonic; + ::llvm::SMLoc typeLoc = parser.getCurrentLocation(); + ::llvm::StringRef mnemonic; if (parser.parseKeyword(&mnemonic)) - return Type(); - Type genType; - OptionalParseResult parseResult = - generatedTypeParser(parser, mnemonic, genType); + return ::mlir::Type(); + ::mlir::Type genType; + auto parseResult = generatedTypeParser(parser, mnemonic, genType); if (parseResult.hasValue()) return genType; parser.emitError(typeLoc) << "unknown type `" @@ -551,7 +550,7 @@ /// Print a type registered to this dialect. void {0}::printType(::mlir::Type type, ::mlir::DialectAsmPrinter &printer) const {{ - if (succeeded(generatedTypePrinter(type, printer))) + if (::mlir::succeeded(generatedTypePrinter(type, printer))) return; } )"; @@ -1040,17 +1039,22 @@ << "::" << def.getCppClassName() << ")\n"; } - // Emit the default parser/printer for Attributes if the dialect asked for it. + Dialect firstDialect = defs.front().getDialect(); + // Emit the default parser/printer for Attributes if the dialect asked for + // it. if (valueType == "Attribute" && - defs.front().getDialect().useDefaultAttributePrinterParser()) + firstDialect.useDefaultAttributePrinterParser()) { + NamespaceEmitter nsEmitter(os, firstDialect); os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch, - defs.front().getDialect().getCppClassName()); + firstDialect.getCppClassName()); + } // Emit the default parser/printer for Types if the dialect asked for it. - if (valueType == "Type" && - defs.front().getDialect().useDefaultTypePrinterParser()) + if (valueType == "Type" && firstDialect.useDefaultTypePrinterParser()) { + NamespaceEmitter nsEmitter(os, firstDialect); os << llvm::formatv(dialectDefaultTypePrinterParserDispatch, - defs.front().getDialect().getCppClassName()); + firstDialect.getCppClassName()); + } return false; }