[mlir][ods] Allow AttributeSelfTypeParameter in attribute assembly formats

Attributes whose parameters include an AttributeSelfTypeParameter can now
use a declarative `assemblyFormat`:

  * mlir-tblgen does not declare or parse a `_result_` variable for self
    type parameters. The generated builder call receives the `$_type`
    substitution instead (the codegen test expects `odsType`).
  * Format verification skips an unreferenced self type parameter. It
    rejects a format that references one with "unexpected self type
    parameter in assembly format".
  * AsmPrinter no longer returns early after printing a non-builtin dialect
    attribute; execution now continues past the if/else chain for them too.
    The updated tests expect a trailing ` : <type>` on such attributes.
  * The hand-written parse/print of AttrWithSelfTypeParamAttr is replaced
    by `assemblyFormat = ""`. A new TestAttrSelfTypeParameterFormat
    attribute exercises round-tripping.

diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1758,11 +1758,10 @@
   if (succeeded(printAlias(attr)))
     return;
 
-  if (!isa<BuiltinDialect>(attr.getDialect()))
-    return printDialectAttribute(attr);
-
   auto attrType = attr.getType();
-  if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
+  if (!isa<BuiltinDialect>(attr.getDialect())) {
+    printDialectAttribute(attr);
+  } else if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
     printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
                        opaqueAttr.getAttrData());
   } else if (attr.isa<UnitAttr>()) {
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -55,7 +55,7 @@
 def AttrWithSelfTypeParam : Test_Attr<"AttrWithSelfTypeParam"> {
   let mnemonic = "attr_with_self_type_param";
   let parameters = (ins AttributeSelfTypeParameter<"">:$type);
-  let hasCustomAssemblyFormat = 1;
+  let assemblyFormat = "";
 }
 
 // An attribute testing AttributeSelfTypeParameter.
@@ -205,4 +205,13 @@
   let assemblyFormat = "`<` $int_type `,` $any_type `>`";
 }
 
+// Test self type parameter with assembly format.
+def TestAttrSelfTypeParameterFormat
+    : Test_Attr<"TestAttrSelfTypeParameterFormat"> {
+  let parameters = (ins "int":$a, AttributeSelfTypeParameter<"">:$type);
+
+  let mnemonic = "attr_self_type_format";
+  let assemblyFormat = "`<` $a `>`";
+}
+
 #endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -27,21 +27,6 @@
 using namespace mlir;
 using namespace test;
 
-//===----------------------------------------------------------------------===//
-// AttrWithSelfTypeParamAttr
-//===----------------------------------------------------------------------===//
-
-Attribute AttrWithSelfTypeParamAttr::parse(AsmParser &parser, Type type) {
-  Type selfType;
-  if (parser.parseType(selfType))
-    return Attribute();
-  return get(parser.getContext(), selfType);
-}
-
-void AttrWithSelfTypeParamAttr::print(AsmPrinter &printer) const {
-  printer << " " << getType();
-}
-
 //===----------------------------------------------------------------------===//
 // AttrWithTypeBuilderAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
--- a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
@@ -14,7 +14,9 @@
   // CHECK: #test.attr_params<42, 24>
   attr3 = #test.attr_params<42, 24>,
   // CHECK: #test.attr_with_type<i32, vector<4xi32>>
-  attr4 = #test.attr_with_type<i32, vector<4xi32>>
+  attr4 = #test.attr_with_type<i32, vector<4xi32>>,
+  // CHECK: #test.attr_self_type_format<5> : i32
+  attr5 = #test.attr_self_type_format<5> : i32
 }
 
 // CHECK-LABEL: @test_roundtrip_default_parsers_struct
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -162,6 +162,17 @@
   let assemblyFormat = "params";
 }
 
+/// Test attribute with self type parameter
+
+// ATTR: TestGAttr::parse
+// ATTR: return TestGAttr::get
+// ATTR: odsType
+def AttrD : TestAttr<"TestG"> {
+  let parameters = (ins "int":$a, AttributeSelfTypeParameter<"">:$type);
+  let mnemonic = "attr_d";
+  let assemblyFormat = "$a";
+}
+
 /// Test type parser and printer that mix variables and struct are generated
 /// correctly.
 
diff --git a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
--- a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
+++ b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
@@ -4,10 +4,10 @@
 // CHECK-SAME: #test.cmpnd_a<1, !test.smpla, [5, 6]>
 func.func private @compoundA() attributes {foo = #test.cmpnd_a<1, !test.smpla, [5, 6]>}
 
-// CHECK: test.result_has_same_type_as_attr #test<"attr_with_self_type_param i32"> -> i32
-%a = test.result_has_same_type_as_attr #test<"attr_with_self_type_param i32"> -> i32
+// CHECK: test.result_has_same_type_as_attr #test.attr_with_self_type_param : i32 -> i32
+%a = test.result_has_same_type_as_attr #test.attr_with_self_type_param : i32 -> i32
 
-// CHECK: test.result_has_same_type_as_attr #test<"attr_with_type_builder 10 : i16"> -> i16
+// CHECK: test.result_has_same_type_as_attr #test<"attr_with_type_builder 10 : i16"> : i16 -> i16
 %b = test.result_has_same_type_as_attr #test<"attr_with_type_builder 10 : i16"> -> i16
 
 // CHECK-LABEL: @qualifiedAttr()
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -234,7 +234,7 @@
   const AttrOrTypeDef &def;
   /// The list of top-level format elements returned by the assembly format
   /// parser.
-  std::vector<FormatElement *> elements;
+  const std::vector<FormatElement *> elements;
 
   /// Flags for printing spaces.
   bool shouldEmitSpace = false;
@@ -260,6 +260,8 @@
   // a loop (parsers return FailureOr anyways).
   ArrayRef<AttrOrTypeParameter> params = def.getParameters();
   for (const AttrOrTypeParameter &param : params) {
+    if (isa<AttributeSelfTypeParameter>(param))
+      continue;
     os << formatv("::mlir::FailureOr<{0}> _result_{1};\n",
                   param.getCppStorageType(), param.getName());
   }
@@ -277,10 +279,9 @@
   // Emit an assert for each mandatory parameter. Triggering an assert means
   // the generated parser is incorrect (i.e. there is a bug in this code).
   for (const AttrOrTypeParameter &param : params) {
-    if (!param.isOptional()) {
-      os << formatv("assert(::mlir::succeeded(_result_{0}));\n",
-                    param.getName());
-    }
+    if (param.isOptional() || isa<AttributeSelfTypeParameter>(param))
+      continue;
+    os << formatv("assert(::mlir::succeeded(_result_{0}));\n", param.getName());
   }
 
   // Generate call to the attribute or type builder. Use the checked getter
@@ -293,15 +294,18 @@
                 def.getCppClassName());
   }
   for (const AttrOrTypeParameter &param : params) {
+    os << ",\n    ";
     if (param.isOptional()) {
-      os << formatv(",\n    _result_{0}.getValueOr(", param.getName());
+      os << formatv("_result_{0}.getValueOr(", param.getName());
       if (Optional<StringRef> defaultValue = param.getDefaultValue())
         os << tgfmt(*defaultValue, &ctx);
       else
         os << param.getCppStorageType() << "()";
       os << ")";
+    } else if (isa<AttributeSelfTypeParameter>(param)) {
+      os << tgfmt("$_type", &ctx);
     } else {
-      os << formatv(",\n    *_result_{0}", param.getName());
+      os << formatv("*_result_{0}", param.getName());
     }
   }
   os << ");";
@@ -666,7 +670,7 @@
   ctx.addSubst("_ctx", "getContext()");
   os.indent();
 
-  /// Generate printers.
+  // Generate printers.
   shouldEmitSpace = true;
   lastWasPunctuation = false;
   for (FormatElement *el : elements)
@@ -904,10 +908,18 @@
                                       ArrayRef<FormatElement *> elements) {
   // Check that all parameters are referenced in the format.
   for (auto &it : llvm::enumerate(def.getParameters())) {
-    if (!it.value().isOptional() && !seenParams.test(it.index())) {
+    if (it.value().isOptional())
+      continue;
+    if (!seenParams.test(it.index())) {
+      if (isa<AttributeSelfTypeParameter>(it.value()))
+        continue;
       return emitError(loc, "format is missing reference to parameter: " +
                                 it.value().getName());
     }
+    if (isa<AttributeSelfTypeParameter>(it.value())) {
+      return emitError(loc,
+                       "unexpected self type parameter in assembly format");
+    }
   }
   if (elements.empty())
     return success();