diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -298,6 +298,10 @@ // If this dialect overrides the hook for op interface fallback. bit hasOperationInterfaceFallback = 0; + // If this dialect should use default generated attribute parser boilerplate: + // it'll dispatch the parsing to each individual attribute directly. + bit useDefaultAttributePrinterParser = 0; + // If this dialect overrides the hook for canonicalization patterns. bit hasCanonicalizer = 0; diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h --- a/mlir/include/mlir/TableGen/Dialect.h +++ b/mlir/include/mlir/TableGen/Dialect.h @@ -73,6 +73,10 @@ /// Returns true if this dialect has fallback interfaces for its operations. bool hasOperationInterfaceFallback() const; + /// Returns true if this dialect should generate the default dispatch for + /// attribute printing/parsing. + bool useDefaultAttributePrinterParser() const; + // Returns whether two dialects are equal by checking the equality of the // underlying record. 
bool operator==(const Dialect &other) const; diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp --- a/mlir/lib/TableGen/Dialect.cpp +++ b/mlir/lib/TableGen/Dialect.cpp @@ -90,6 +90,10 @@ return def->getValueAsBit("hasOperationInterfaceFallback"); } +bool Dialect::useDefaultAttributePrinterParser() const { + return def->getValueAsBit("useDefaultAttributePrinterParser"); +} + Dialect::EmitPrefix Dialect::getEmitAccessorPrefix() const { int prefix = def->getValueAsInt("emitAccessorPrefix"); if (prefix < 0 || prefix > static_cast<int>(EmitPrefix::Both)) diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -197,24 +197,3 @@ #include "TestAttrDefs.cpp.inc" >(); } - -Attribute TestDialect::parseAttribute(DialectAsmParser &parser, - Type type) const { - StringRef attrTag; - if (failed(parser.parseKeyword(&attrTag))) - return Attribute(); - { - Attribute attr; - auto parseResult = generatedAttributeParser(parser, attrTag, type, attr); - if (parseResult.hasValue()) - return attr; - } - parser.emitError(parser.getNameLoc(), "unknown test attribute"); - return Attribute(); -} - -void TestDialect::printAttribute(Attribute attr, - DialectAsmPrinter &printer) const { - if (succeeded(generatedAttributePrinter(attr, printer))) - return; -} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -34,6 +34,7 @@ let hasRegionResultAttrVerify = 1; let hasOperationInterfaceFallback = 1; let hasNonDefaultDestructor = 1; + let useDefaultAttributePrinterParser = 1; let dependentDialects = ["::mlir::DLTIDialect"]; let extraClassDeclaration = [{ diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- 
a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -489,6 +489,34 @@ ::mlir::{0} &value) {{ )"; +/// The code block for default attribute parser/printer dispatch boilerplate. +/// {0}: the dialect fully qualified class name. +static const char *const dialectDefaultAttrPrinterParserDispatch = R"( +/// Parse an attribute registered to this dialect. +::mlir::Attribute {0}::parseAttribute(::mlir::DialectAsmParser &parser, + ::mlir::Type type) const {{ + ::llvm::SMLoc typeLoc = parser.getCurrentLocation(); + ::llvm::StringRef attrTag; + if (failed(parser.parseKeyword(&attrTag))) + return {{}; + {{ + ::mlir::Attribute attr; + auto parseResult = generatedAttributeParser(parser, attrTag, type, attr); + if (parseResult.hasValue()) + return attr; + } + parser.emitError(typeLoc) << "unknown attribute `" + << attrTag << "` in dialect `" << getNamespace() << "`"; + return {{}; +} +/// Print an attribute registered to this dialect. +void {0}::printAttribute(::mlir::Attribute attr, + ::mlir::DialectAsmPrinter &printer) const {{ + if (succeeded(generatedAttributePrinter(attr, printer))) + return; +} +)"; + /// The code block used to start the auto-generated printer function. /// /// {0}: The name of the base value type, e.g. Attribute or Type. @@ -952,6 +980,12 @@ << "::" << def.getCppClassName() << ")\n"; } + // Emit the default parser/printer for Attributes if the dialect asked for it. + if (valueType == "Attribute" && + defs.front().getDialect().useDefaultAttributePrinterParser()) + os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch, + defs.front().getDialect().getCppClassName()); + return false; } diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp --- a/mlir/tools/mlir-tblgen/DialectGen.cpp +++ b/mlir/tools/mlir-tblgen/DialectGen.cpp @@ -208,7 +208,7 @@ // Check for any attributes/types registered to this dialect. If there are, // add the hooks for parsing/printing. 
- if (!dialectAttrs.empty()) + if (!dialectAttrs.empty() || dialect.useDefaultAttributePrinterParser()) os << attrParserDecl; if (!dialectTypes.empty()) os << typeParserDecl;