diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -651,6 +651,16 @@ - `input` must be either an operand or result [variable](#variables), the `operands` directive, or the `results` directive. +* `qualified` ( type_or_attribute ) + + - Wraps a `type` directive or an attribute parameter. + - Used to force printing the type or attribute prefixed with its dialect + and mnemonic. For example, the `vector.multi_reduction` operation has a + `kind` attribute; by default the declarative assembly will print: + `vector.multi_reduction <minf>, ...` but using `qualified($kind)` in the + declarative assembly format will print it instead as: + `vector.multi_reduction #vector.kind<minf>, ...`. + #### Literals A literal is either a keyword or punctuation surrounded by \`\`. diff --git a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md --- a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md +++ b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md @@ -489,9 +489,11 @@ Attribute and type assembly formats have the following directives: -* `params`: capture all parameters of an attribute or type. -* `struct`: generate a "struct-like" parser and printer for a list of key-value - pairs. +* `params`: capture all parameters of an attribute or type. +* `qualified`: mark a parameter to be printed with its leading dialect and + mnemonic. +* `struct`: generate a "struct-like" parser and printer for a list of + key-value pairs. #### `params` Directive @@ -517,6 +519,34 @@ as an argument that refers to all parameters in place of explicitly listing all parameters as variables. +#### `qualified` Directive + +This directive can be used to wrap attribute or type parameters such that they +are printed in a fully qualified form, i.e., they include the dialect name and +mnemonic prefix.
+ +For example: + +```tablegen +def OuterType : TypeDef<My_Dialect, "Outer"> { + let parameters = (ins MyPairType:$inner); + let mnemonic = "outer"; + let assemblyFormat = "`<` pair `:` $inner `>`"; +} +def OuterQualifiedType : TypeDef<My_Dialect, "OuterQualified"> { + let parameters = (ins MyPairType:$inner); + let mnemonic = "outer_qual"; + let assemblyFormat = "`<` pair `:` qualified($inner) `>`"; +} +``` + +In the IR, the types will appear as: + +```mlir +!my_dialect.outer<pair : <1, 2>> +!my_dialect.outer_qual<pair : !my_dialect.pair<1, 2>> +``` + #### `struct` Directive The `struct` directive accepts a list of variables to capture and will generate diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -144,6 +144,14 @@ let assemblyFormat = "`<` `i` $inner `>`"; } +def CompoundNestedOuterQual : Test_Attr<"CompoundNestedOuterQual"> { + let mnemonic = "cmpnd_nested_outer_qual"; + + // List of type parameters. + let parameters = (ins CompoundNestedInner:$inner); + let assemblyFormat = "`<` `i` qualified($inner) `>`"; +} + def TestParamOne : AttrParameter<"int64_t", ""> {} def TestParamTwo : AttrParameter<"std::string", "", "llvm::StringRef"> { diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1955,11 +1955,21 @@ let assemblyFormat = "`nested` $nested attr-dict-with-keyword"; } +def FormatQualifiedCompoundAttr : TEST_Op<"format_qual_cpmd_nested_attr"> { + let arguments = (ins CompoundNestedOuter:$nested); + let assemblyFormat = "`nested` qualified($nested) attr-dict-with-keyword"; +} + def FormatNestedType : TEST_Op<"format_cpmd_nested_type"> { let arguments = (ins CompoundNestedOuterType:$nested); let assemblyFormat = "$nested `nested` type($nested) attr-dict-with-keyword"; } +def FormatQualifiedNestedType : TEST_Op<"format_qual_cpmd_nested_type"> { + let arguments = (ins
CompoundNestedOuterType:$nested); + let assemblyFormat = "$nested `nested` qualified(type($nested)) attr-dict-with-keyword"; +} + //===----------------------------------------------------------------------===// // Custom Directives diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -65,12 +65,20 @@ def CompoundNestedOuterType : Test_Type<"CompoundNestedOuter"> { let mnemonic = "cmpnd_nested_outer"; + // List of type parameters. + let parameters = (ins CompoundNestedInnerType:$inner); + let assemblyFormat = "`<` `i` $inner `>`"; +} + +def CompoundNestedOuterTypeQual : Test_Type<"CompoundNestedOuterQual"> { + let mnemonic = "cmpnd_nested_outer_qual"; + // List of type parameters. let parameters = ( ins CompoundNestedInnerType:$inner ); - let assemblyFormat = "`<` `i` $inner `>`"; + let assemblyFormat = "`<` `i` qualified($inner) `>`"; } // An example of how one could implement a standard integer. diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -301,6 +301,24 @@ // CHECK: test.format_cpmd_nested_attr nested >> test.format_cpmd_nested_attr nested >> +//----- + +// CHECK: test.format_qual_cpmd_nested_attr nested #test.cmpnd_nested_outer>> +test.format_qual_cpmd_nested_attr nested #test.cmpnd_nested_outer>> + +//----- + +// Check that the `qualified` directive in the declarative assembly format prints the type with its dialect prefix.
+// CHECK: @qualifiedCompoundNestedExplicit(%arg0: !test.cmpnd_nested_outer>>) +func @qualifiedCompoundNestedExplicit(%arg0: !test.cmpnd_nested_outer>>) -> () { + // Verify that the `!test.` type prefix is not elided. + // CHECK: format_qual_cpmd_nested_type %arg0 nested !test.cmpnd_nested_outer>> + test.format_qual_cpmd_nested_type %arg0 nested !test.cmpnd_nested_outer>> + return +} + +//----- + //===----------------------------------------------------------------------===// // Format custom directives //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir --- a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir +++ b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir @@ -9,3 +9,7 @@ // CHECK: test.result_has_same_type_as_attr #test<"attr_with_type_builder 10 : i16"> -> i16 %b = test.result_has_same_type_as_attr #test<"attr_with_type_builder 10 : i16"> -> i16 + +// CHECK-LABEL: @qualifiedAttr() +// CHECK-SAME: #test.cmpnd_nested_outer_qual>> +func private @qualifiedAttr() attributes {foo = #test.cmpnd_nested_outer_qual>>} diff --git a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir --- a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir +++ b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir @@ -29,6 +29,10 @@ return } +// CHECK-LABEL: @compoundNestedQual +// CHECK-SAME: !test.cmpnd_nested_outer_qual>> +func private @compoundNestedQual(%arg0: !test.cmpnd_nested_outer_qual>>) -> () + // CHECK: @testInt(%arg0: !test.int, %arg1: !test.int, %arg2: !test.int) func @testInt(%A : !test.int, %B : !test.int, %C : !test.int) { return diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -89,7 +89,15 @@ /// Get the parameter in the element.
const AttrOrTypeParameter &getParam() const { return param; } + /// Indicate if this parameter is printed "qualified", i.e. prefixed with + /// its dialect and mnemonic (e.g. `!dialect.mnemonic`). + bool shouldBeQualified() const { return shouldBeQualifiedFlag; } + void setShouldBeQualified(bool qualified = true) { + shouldBeQualifiedFlag = qualified; + } + private: + bool shouldBeQualifiedFlag = false; AttrOrTypeParameter param; }; @@ -166,6 +174,10 @@ static const char *const defaultParameterPrinter = "$_printer.printStrippedAttrOrType($_self)"; +/// Printer for parameters wrapped in a `qualified` directive: unlike the +/// default printer, it does not elide the dialect and mnemonic prefix. +static const char *const qualifiedParameterPrinter = "$_printer << $_self"; + /// Print an error when failing to parse an element. /// /// $0: The parameter C++ class name. @@ -251,7 +263,7 @@ void genLiteralPrinter(StringRef value, FmtContext &ctx, MethodBody &os); /// Generate the printer code for a variable. void genVariablePrinter(const AttrOrTypeParameter &param, FmtContext &ctx, - MethodBody &os); + MethodBody &os, bool printQualified = false); /// Generate the printer code for a `params` directive. void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os); /// Generate the printer code for a `struct` directive. @@ -435,7 +447,8 @@ if (auto *strct = dyn_cast<StructDirective>(el)) return genStructPrinter(strct, ctx, os); if (auto *var = dyn_cast<ParameterElement>(el)) - return genVariablePrinter(var->getParam(), ctx, os); + return genVariablePrinter(var->getParam(), ctx, os, + var->shouldBeQualified()); llvm_unreachable("unknown format element"); } @@ -455,7 +468,8 @@ } void AttrOrTypeFormat::genVariablePrinter(const AttrOrTypeParameter &param, - FmtContext &ctx, MethodBody &os) { + FmtContext &ctx, MethodBody &os, + bool printQualified) { /// Insert a space before the next parameter, if necessary.
if (shouldEmitSpace || !lastWasPunctuation) os << tgfmt(" $_printer << ' ';\n", &ctx); @@ -464,7 +478,9 @@ ctx.withSelf(getParameterAccessorName(param.getName()) + "()"); os << " "; - if (auto printer = param.getPrinter()) + if (printQualified) + os << tgfmt(qualifiedParameterPrinter, &ctx) << ";\n"; + else if (auto printer = param.getPrinter()) os << tgfmt(*printer, &ctx) << ";\n"; else os << tgfmt(defaultParameterPrinter, &ctx) << ";\n"; @@ -546,6 +562,9 @@ FailureOr<std::unique_ptr<Element>> parseDirective(ParserContext ctx); /// Parse a `params` directive. FailureOr<std::unique_ptr<Element>> parseParamsDirective(); + /// Parse a `qualified` directive. + FailureOr<std::unique_ptr<Element>> + parseQualifiedDirective(ParserContext ctx); /// Parse a `struct` directive. FailureOr<std::unique_ptr<Element>> parseStructDirective(); @@ -643,6 +662,8 @@ FormatParser::parseDirective(ParserContext ctx) { switch (curToken.getKind()) { + case FormatToken::kw_qualified: + return parseQualifiedDirective(ctx); case FormatToken::kw_params: return parseParamsDirective(); case FormatToken::kw_struct: @@ -656,6 +677,24 @@ } } +FailureOr<std::unique_ptr<Element>> +FormatParser::parseQualifiedDirective(ParserContext ctx) { + consumeToken(); + if (failed(parseToken(FormatToken::l_paren, + "expected '(' before argument list"))) + return failure(); + FailureOr<std::unique_ptr<Element>> var = parseElement(ctx); + if (failed(var)) + return var; + if (!isa<ParameterElement>(*var)) + return emitError("`qualified` argument list expected a variable"); + cast<ParameterElement>(var->get())->setShouldBeQualified(); + if (failed( + parseToken(FormatToken::r_paren, "expected ')' after argument list"))) + return failure(); + return var; +} + FailureOr<std::unique_ptr<Element>> FormatParser::parseParamsDirective() { consumeToken(); /// Collect all of the attribute's or type's parameters.
diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h --- a/mlir/tools/mlir-tblgen/FormatGen.h +++ b/mlir/tools/mlir-tblgen/FormatGen.h @@ -59,6 +59,7 @@ kw_functional_type, kw_operands, kw_params, + kw_qualified, kw_ref, kw_regions, kw_results, diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp --- a/mlir/tools/mlir-tblgen/FormatGen.cpp +++ b/mlir/tools/mlir-tblgen/FormatGen.cpp @@ -172,6 +172,7 @@ .Case("struct", FormatToken::kw_struct) .Case("successors", FormatToken::kw_successors) .Case("type", FormatToken::kw_type) + .Case("qualified", FormatToken::kw_qualified) .Default(FormatToken::identifier); return FormatToken(kind, str); } diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -117,6 +117,16 @@ bool isUnitAttr() const { return var->attr.getBaseAttr().getAttrDefName() == "UnitAttr"; } + + /// Indicate if this attribute is printed "qualified", i.e. prefixed with + /// its dialect and mnemonic (e.g. `#dialect.mnemonic`). + bool shouldBeQualified() const { return shouldBeQualifiedFlag; } + void setShouldBeQualified(bool qualified = true) { + shouldBeQualifiedFlag = qualified; + } + +private: + bool shouldBeQualifiedFlag = false; }; /// This class represents a variable that refers to an operand argument. @@ -237,9 +247,18 @@ TypeDirective(std::unique_ptr<Element> arg) : operand(std::move(arg)) {} Element *getOperand() const { return operand.get(); } + /// Indicate if this type is printed "qualified", i.e. prefixed with its + /// dialect and mnemonic (e.g. `!dialect.mnemonic`). + bool shouldBeQualified() const { return shouldBeQualifiedFlag; } + void setShouldBeQualified(bool qualified = true) { + shouldBeQualifiedFlag = qualified; + } + private: /// The operand that is used to format the directive.
std::unique_ptr<Element> operand; + + bool shouldBeQualifiedFlag = false; }; } // namespace @@ -658,6 +677,10 @@ {1}RawTypes[0] = type; } )"; +const char *const qualifiedTypeParserCode = R"( + if (parser.parseType({1}RawTypes[0])) + return ::mlir::failure(); +)"; /// The code snippet used to generate a parser call for a functional type. /// @@ -1296,7 +1319,8 @@ if (var->attr.isOptional()) { body << formatv(optionalAttrParserCode, var->name, attrTypeStr); } else { - if (var->attr.getStorageType() == "::mlir::Attribute") + if (attr->shouldBeQualified() || + var->attr.getStorageType() == "::mlir::Attribute") body << formatv(genericAttrParserCode, var->name, attrTypeStr); else body << formatv(attrParserCode, var->name, attrTypeStr); @@ -1368,14 +1392,16 @@ } else if (lengthKind == ArgumentLengthKind::Optional) { body << llvm::formatv(optionalTypeParserCode, listName); } else { + const char *parserCode = + dir->shouldBeQualified() ? qualifiedTypeParserCode : typeParserCode; TypeSwitch<Element *>(dir->getOperand()) .Case<OperandVariable, ResultVariable>([&](auto operand) { - body << formatv(typeParserCode, + body << formatv(parserCode, operand->getVar()->constraint.getCPPClassName(), listName); }) .Default([&](auto operand) { - body << formatv(typeParserCode, "::mlir::Type", listName); + body << formatv(parserCode, "::mlir::Type", listName); }); } } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) { @@ -2025,7 +2051,8 @@ else if (var->attr.isOptional()) body << "_odsPrinter.printAttribute(" << op.getGetterName(var->name) << "Attr());\n"; - else if (var->attr.getStorageType() == "::mlir::Attribute") + else if (attr->shouldBeQualified() || + var->attr.getStorageType() == "::mlir::Attribute") body << " _odsPrinter.printAttribute(" << op.getGetterName(var->name) << "Attr());\n"; else @@ -2093,6 +2120,11 @@ if (var && !var->isVariadicOfVariadic() && !var->isVariadic() && !var->isOptional()) { std::string cppClass = var->constraint.getCPPClassName(); + if (dir->shouldBeQualified()) { + body << " _odsPrinter << " <<
op.getGetterName(var->name) + << "().getType();\n"; + return; + } body << " {\n" << " auto type = " << op.getGetterName(var->name) << "().getType();\n" @@ -2253,6 +2285,8 @@ ParserContext context); LogicalResult parseOperandsDirective(std::unique_ptr<Element> &element, llvm::SMLoc loc, ParserContext context); + LogicalResult parseQualifiedDirective(std::unique_ptr<Element> &element, + FormatToken tok, ParserContext context); LogicalResult parseReferenceDirective(std::unique_ptr<Element> &element, llvm::SMLoc loc, ParserContext context); LogicalResult parseRegionsDirective(std::unique_ptr<Element> &element, @@ -2762,6 +2796,8 @@ return parseFunctionalTypeDirective(element, dirTok, context); case FormatToken::kw_operands: return parseOperandsDirective(element, dirTok.getLoc(), context); + case FormatToken::kw_qualified: + return parseQualifiedDirective(element, dirTok, context); case FormatToken::kw_regions: return parseRegionsDirective(element, dirTok.getLoc(), context); case FormatToken::kw_results: @@ -3176,6 +3212,27 @@ return ::mlir::success(); } +LogicalResult +FormatParser::parseQualifiedDirective(std::unique_ptr<Element> &element, + FormatToken tok, ParserContext context) { + if (failed(parseToken(FormatToken::l_paren, + "expected '(' before argument list")) || + failed(parseElement(element, context)) || + failed( + parseToken(FormatToken::r_paren, "expected ')' after argument list"))) + return failure(); + if (auto *attr = dyn_cast<AttributeVariable>(element.get())) { + attr->setShouldBeQualified(); + } else if (auto *type = dyn_cast<TypeDirective>(element.get())) { + type->setShouldBeQualified(); + } else { + return emitError( + tok.getLoc(), + "'qualified' directive expects an attribute or a `type` directive"); + } + return success(); +} + LogicalResult FormatParser::parseTypeDirectiveOperand(std::unique_ptr<Element> &element, bool isRefChild) {