diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -979,6 +979,10 @@ def StrAttr : StringBasedAttr<CPred<"$_self.isa<StringAttr>()">, "string attribute">; +// A string attribute that represents the name of a symbol. +def SymbolNameAttr : StringBasedAttr<CPred<"$_self.isa<StringAttr>()">, + "string attribute">; + // String attribute that has a specific value type. class TypedStrAttr<Type ty> : StringBasedAttr<CPred<"$_self.isa<StringAttr>()">, diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1298,6 +1298,18 @@ let assemblyFormat = "($opt_attr^)? attr-dict"; } +// Test that we format symbol name attributes properly. +def FormatSymbolNameAttrOp : TEST_Op<"format_symbol_name_attr_op"> { + let arguments = (ins SymbolNameAttr:$attr); + let assemblyFormat = "$attr attr-dict"; +} + +// Test that we format optional symbol name attributes properly. +def FormatOptSymbolNameAttrOp : TEST_Op<"format_opt_symbol_name_attr_op"> { + let arguments = (ins OptionalAttr<SymbolNameAttr>:$opt_attr); + let assemblyFormat = "($opt_attr^)? attr-dict"; +} + // Test that we elide attributes that are within the syntax.
def FormatAttrDictWithKeywordOp : TEST_Op<"format_attr_dict_w_keyword"> { let arguments = (ins I64Attr:$attr, OptionalAttr<I64Attr>:$opt_attr); diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -22,6 +22,15 @@ test.format_opt_attr_op_b 10 test.format_opt_attr_op_b +// CHECK: test.format_symbol_name_attr_op @name +// CHECK-NOT: {attr +test.format_symbol_name_attr_op @name + +// CHECK: test.format_opt_symbol_name_attr_op @opt_name +// CHECK-NOT: {opt_attr +test.format_opt_symbol_name_attr_op @opt_name +test.format_opt_symbol_name_attr_op + // CHECK: test.format_attr_dict_w_keyword attributes {attr = 10 : i64} test.format_attr_dict_w_keyword attributes {attr = 10 : i64} diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -396,6 +396,11 @@ !enumAttr->getConstBuilderTemplate().empty(); } +/// Returns true if the given attribute can be formatted as a SymbolNameAttr. +static bool canFormatSymbolNameAttr(const NamedAttribute *attr) { + return attr->attr.getBaseAttr().getAttrDefName() == "SymbolNameAttr"; +} + /// The code snippet used to generate a parser call for an attribute. /// /// {0}: The name of the attribute. @@ -413,6 +418,19 @@ } )"; +/// The code snippet used to generate a parser call for a symbol name attribute. +/// +/// {0}: The name of the attribute. +const char *const symbolNameAttrParserCode = R"( + if (parser.parseSymbolName({0}Attr, "{0}", result.attributes)) + return failure(); +)"; +const char *const optionalSymbolNameAttrParserCode = R"( + // Parsing an optional symbol name doesn't fail, so no need to check the + // result. + (void)parser.parseOptionalSymbolName({0}Attr, "{0}", result.attributes); +)"; + /// The code snippet used to generate a parser call for an enum attribute. /// /// {0}: The name of the attribute.
@@ -862,6 +880,14 @@ return; } + // Parse symbol name attributes with the (optional) symbol name parser. + if (canFormatSymbolNameAttr(var)) { + body << formatv(var->attr.isOptional() ? optionalSymbolNameAttrParserCode + : symbolNameAttrParserCode, + var->name); + return; + } + // If this attribute has a buildable type, use that when parsing the // attribute. std::string attrTypeStr; @@ -1340,6 +1366,12 @@ return; } + // Print symbol name attributes via printSymbolName on the string value. + if (canFormatSymbolNameAttr(var)) { + body << " p.printSymbolName(" << var->name << "Attr().getValue());\n"; + return; + } + // Elide the attribute type if it is buildable. if (attr->getTypeBuilder()) body << " p.printAttributeWithoutType(" << var->name << "Attr());\n";