diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -651,6 +651,16 @@ - `input` must be either an operand or result [variable](#variables), the `operands` directive, or the `results` directive. +* `qualified` ( type_or_attribute ) + + - Wraps a `type` directive or an attribute parameter. + - Used to force printing the type or attribute prefixed with its dialect + and mnemonic. For example, the `vector.multi_reduction` operation has a + `kind` attribute; by default the declarative assembly will print: + `vector.multi_reduction <minf>, ...` but using `qualified($kind)` in the + declarative assembly format will print it instead as: + `vector.multi_reduction #vector.kind<minf>, ...`. + #### Literals A literal is either a keyword or punctuation surrounded by \`\`. diff --git a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md --- a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md +++ b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md @@ -489,9 +489,10 @@ Attribute and type assembly formats have the following directives: -* `params`: capture all parameters of an attribute or type. -* `struct`: generate a "struct-like" parser and printer for a list of key-value - pairs. +* `params`: capture all parameters of an attribute or type. +* `qualified`: print a parameter with its dialect and mnemonic prefix. +* `struct`: generate a "struct-like" parser and printer for a list of + key-value pairs. #### `params` Directive @@ -517,6 +518,34 @@ as an argument that refers to all parameters in place of explicitly listing all parameters as variables. +#### `qualified` Directive + +This directive can be used to wrap parameters so that nested types and/or +attributes are printed fully qualified, that is, they include the dialect name +and mnemonic prefix.
+ +For example: + +```tablegen +def OuterType : TypeDef<My_Dialect, "MyOuterType"> { + let parameters = (ins MyPairType:$inner); + let mnemonic = "outer"; + let assemblyFormat = "`<` `pair` `:` $inner `>`"; +} +def OuterQualifiedType : TypeDef<My_Dialect, "MyOuterQualifiedType"> { + let parameters = (ins MyPairType:$inner); + let mnemonic = "outer_qual"; + let assemblyFormat = "`<` `pair` `:` qualified($inner) `>`"; +} +``` + +In the IR, the types will appear as: + +```mlir +!my_dialect.outer<pair : <42, 24>> +!my_dialect.outer_qual<pair : !my_dialect.pair<42, 24>> +``` + #### `struct` Directive The `struct` directive accepts a list of variables to capture and will generate diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -144,6 +144,17 @@ let assemblyFormat = "`<` `i` $inner `>`"; } +def CompoundNestedOuterQual : Test_Attr<"CompoundNestedOuterQual"> { + let mnemonic = "cmpnd_nested_outer_qual"; + + // List of type parameters. + let parameters = ( + ins + CompoundNestedInner:$inner + ); + let assemblyFormat = "`<` `i` qualified($inner) `>`"; +} + def TestParamOne : AttrParameter<"int64_t", ""> {} def TestParamTwo : AttrParameter<"std::string", "", "llvm::StringRef"> { diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1955,11 +1955,21 @@ let assemblyFormat = "`nested` $nested attr-dict-with-keyword"; } +def FormatQualifiedCompoundAttr : TEST_Op<"format_qual_cpmd_nested_attr"> { + let arguments = (ins CompoundNestedOuter:$nested); + let assemblyFormat = "`nested` qualified($nested) attr-dict-with-keyword"; +} + def FormatNestedType : TEST_Op<"format_cpmd_nested_type"> { let arguments = (ins CompoundNestedOuterType:$nested); let assemblyFormat = "$nested `nested` type($nested) attr-dict-with-keyword"; } +def FormatQualifiedNestedType : TEST_Op<"format_qual_cpmd_nested_type"> { + let arguments = (ins
CompoundNestedOuterType:$nested); + let assemblyFormat = "$nested `nested` qualified(type($nested)) attr-dict-with-keyword"; +} + //===----------------------------------------------------------------------===// // Custom Directives diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -73,6 +73,17 @@ let assemblyFormat = "`<` `i` $inner `>`"; } +def CompoundNestedOuterTypeQual : Test_Type<"CompoundNestedOuterQual"> { + let mnemonic = "cmpnd_nested_outer_qual"; + + // List of type parameters. + let parameters = ( + ins + CompoundNestedInnerType:$inner + ); + let assemblyFormat = "`<` `i` qualified($inner) `>`"; +} + // An example of how one could implement a standard integer. def IntegerType : Test_Type<"TestInteger"> { let mnemonic = "int"; diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -301,6 +301,24 @@ // CHECK: test.format_cpmd_nested_attr nested >> test.format_cpmd_nested_attr nested >> +//----- + +// CHECK: test.format_qual_cpmd_nested_attr nested #test.cmpnd_nested_outer>> +test.format_qual_cpmd_nested_attr nested #test.cmpnd_nested_outer>> + +//----- + +// Check the `qualified` directive in the declarative assembly format. 
+// CHECK: @qualifiedCompoundNestedExplicit(%arg0: !test.cmpnd_nested_outer>>) +func @qualifiedCompoundNestedExplicit(%arg0: !test.cmpnd_nested_outer>>) -> () { +// Verify that the type prefix is not elided +// CHECK: format_qual_cpmd_nested_type %arg0 nested !test.cmpnd_nested_outer>> + test.format_qual_cpmd_nested_type %arg0 nested !test.cmpnd_nested_outer>> + return +} + +//----- + //===----------------------------------------------------------------------===// // Format custom directives //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir --- a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir +++ b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir @@ -9,3 +9,7 @@ // CHECK: test.result_has_same_type_as_attr #test<"attr_with_type_builder 10 : i16"> -> i16 %b = test.result_has_same_type_as_attr #test<"attr_with_type_builder 10 : i16"> -> i16 + +// CHECK-LABEL: @qualifiedAttr() +// CHECK-SAME: #test.cmpnd_nested_outer_qual>> +func private @qualifiedAttr() attributes {foo = #test.cmpnd_nested_outer_qual>>} diff --git a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir --- a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir +++ b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir @@ -29,6 +29,11 @@ return } +// CHECK: @compoundNestedQual(%arg0: !test.cmpnd_nested_outer_qual>>) +func @compoundNestedQual(%arg0: !test.cmpnd_nested_outer_qual>>) -> () { + return +} + // CHECK: @testInt(%arg0: !test.int, %arg1: !test.int, %arg2: !test.int) func @testInt(%A : !test.int, %B : !test.int, %C : !test.int) { return diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -89,7 +89,13 @@ /// Get the parameter in the element. 
const AttrOrTypeParameter &getParam() const { return param; } + bool isQualified() { return qualified; } + void setQualified(bool qualified = true) { this->qualified = qualified; } + private: + // Flag to indicate if this variable is printed "qualified" (that is it is + // prefixed with the `#dialect.mnemonic`). + bool qualified = false; AttrOrTypeParameter param; }; @@ -166,6 +172,10 @@ static const char *const defaultParameterPrinter = "$_printer.printStrippedAttrOrType($_self)"; +/// Qualified printer for attribute or type parameters: it does not elide +/// dialect and mnemonic. +static const char *const qualifiedParameterPrinter = "$_printer << $_self"; + /// Print an error when failing to parse an element. /// /// $0: The parameter C++ class name. @@ -251,7 +261,7 @@ void genLiteralPrinter(StringRef value, FmtContext &ctx, MethodBody &os); /// Generate the printer code for a variable. void genVariablePrinter(const AttrOrTypeParameter ¶m, FmtContext &ctx, - MethodBody &os); + MethodBody &os, bool printQualified = false); /// Generate the printer code for a `params` directive. void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os); /// Generate the printer code for a `struct` directive. @@ -435,7 +445,7 @@ if (auto *strct = dyn_cast(el)) return genStructPrinter(strct, ctx, os); if (auto *var = dyn_cast(el)) - return genVariablePrinter(var->getParam(), ctx, os); + return genVariablePrinter(var->getParam(), ctx, os, var->isQualified()); llvm_unreachable("unknown format element"); } @@ -455,7 +465,8 @@ } void AttrOrTypeFormat::genVariablePrinter(const AttrOrTypeParameter ¶m, - FmtContext &ctx, MethodBody &os) { + FmtContext &ctx, MethodBody &os, + bool printQualified) { /// Insert a space before the next parameter, if necessary. 
if (shouldEmitSpace || !lastWasPunctuation) os << tgfmt(" $_printer << ' ';\n", &ctx); @@ -464,7 +475,9 @@ ctx.withSelf(getParameterAccessorName(param.getName()) + "()"); os << " "; - if (auto printer = param.getPrinter()) + if (printQualified) + os << tgfmt(qualifiedParameterPrinter, &ctx) << ";\n"; + else if (auto printer = param.getPrinter()) os << tgfmt(*printer, &ctx) << ";\n"; else os << tgfmt(defaultParameterPrinter, &ctx) << ";\n"; @@ -546,6 +559,9 @@ FailureOr> parseDirective(ParserContext ctx); /// Parse a `params` directive. FailureOr> parseParamsDirective(); + /// Parse a `qualified` directive. + FailureOr> + parseQualifiedDirective(ParserContext ctx); /// Parse a `struct` directive. FailureOr> parseStructDirective(); @@ -643,6 +659,8 @@ FormatParser::parseDirective(ParserContext ctx) { switch (curToken.getKind()) { + case FormatToken::kw_qualified: + return parseQualifiedDirective(ctx); case FormatToken::kw_params: return parseParamsDirective(); case FormatToken::kw_struct: @@ -656,6 +674,22 @@ } } +FailureOr> +FormatParser::parseQualifiedDirective(ParserContext ctx) { + consumeToken(); + if (failed(parseToken(FormatToken::l_paren, + "expected '(' before argument list"))) + return failure(); + FailureOr> var = parseElement(ctx); + if (failed(var) || !isa(*var)) + return emitError("`qualified` argument list expected a variable"); + cast(var->get())->setQualified(); + if (failed( + parseToken(FormatToken::r_paren, "expected ')' after argument list"))) + return failure(); + return var; +} + FailureOr> FormatParser::parseParamsDirective() { consumeToken(); /// Collect all of the attribute's or type's parameters. 
diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h --- a/mlir/tools/mlir-tblgen/FormatGen.h +++ b/mlir/tools/mlir-tblgen/FormatGen.h @@ -59,6 +59,7 @@ kw_functional_type, kw_operands, kw_params, + kw_qualified, kw_ref, kw_regions, kw_results, diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp --- a/mlir/tools/mlir-tblgen/FormatGen.cpp +++ b/mlir/tools/mlir-tblgen/FormatGen.cpp @@ -172,6 +172,7 @@ .Case("struct", FormatToken::kw_struct) .Case("successors", FormatToken::kw_successors) .Case("type", FormatToken::kw_type) + .Case("qualified", FormatToken::kw_qualified) .Default(FormatToken::identifier); return FormatToken(kind, str); } diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -45,6 +45,7 @@ CustomDirective, FunctionalTypeDirective, OperandsDirective, + QualifiedDirective, RefDirective, RegionsDirective, ResultsDirective, @@ -117,6 +118,14 @@ bool isUnitAttr() const { return var->attr.getBaseAttr().getAttrDefName() == "UnitAttr"; } + + bool isQualified() { return qualified; } + void setQualified(bool qualified = true) { this->qualified = qualified; } + +private: + // Flag to indicate if this attribute is printed "qualified" (that is it is + // prefixed with the `#dialect.mnemonic`). + bool qualified = false; }; /// This class represents a variable that refers to an operand argument. @@ -236,10 +245,16 @@ public: TypeDirective(std::unique_ptr arg) : operand(std::move(arg)) {} Element *getOperand() const { return operand.get(); } + bool isQualified() { return qualified; } + void setQualified(bool qualified = true) { this->qualified = qualified; } private: /// The operand that is used to format the directive. std::unique_ptr operand; + + // Flag to indicate if this type is printed "qualified" (that is it is + // prefixed with the `!dialect.mnemonic`). 
+ bool qualified = false; }; } // namespace @@ -658,6 +673,10 @@ {1}RawTypes[0] = type; } )"; +const char *const qualifiedTypeParserCode = R"( + if (parser.parseType({1}RawTypes[0])) + return ::mlir::failure(); +)"; /// The code snippet used to generate a parser call for a functional type. /// @@ -1296,7 +1315,8 @@ if (var->attr.isOptional()) { body << formatv(optionalAttrParserCode, var->name, attrTypeStr); } else { - if (var->attr.getStorageType() == "::mlir::Attribute") + if (attr->isQualified() || + var->attr.getStorageType() == "::mlir::Attribute") body << formatv(genericAttrParserCode, var->name, attrTypeStr); else body << formatv(attrParserCode, var->name, attrTypeStr); @@ -1368,14 +1388,16 @@ } else if (lengthKind == ArgumentLengthKind::Optional) { body << llvm::formatv(optionalTypeParserCode, listName); } else { + const char *parserCode = + dir->isQualified() ? qualifiedTypeParserCode : typeParserCode; TypeSwitch(dir->getOperand()) .Case([&](auto operand) { - body << formatv(typeParserCode, + body << formatv(parserCode, operand->getVar()->constraint.getCPPClassName(), listName); }) .Default([&](auto operand) { - body << formatv(typeParserCode, "::mlir::Type", listName); + body << formatv(parserCode, "::mlir::Type", listName); }); } } else if (auto *dir = dyn_cast(element)) { @@ -2025,7 +2047,8 @@ else if (var->attr.isOptional()) body << "_odsPrinter.printAttribute(" << op.getGetterName(var->name) << "Attr());\n"; - else if (var->attr.getStorageType() == "::mlir::Attribute") + else if (attr->isQualified() || + var->attr.getStorageType() == "::mlir::Attribute") body << " _odsPrinter.printAttribute(" << op.getGetterName(var->name) << "Attr());\n"; else @@ -2093,6 +2116,11 @@ if (var && !var->isVariadicOfVariadic() && !var->isVariadic() && !var->isOptional()) { std::string cppClass = var->constraint.getCPPClassName(); + if (dir->isQualified()) { + body << " _odsPrinter << " << op.getGetterName(var->name) + << "().getType();\n"; + return; + } body << " {\n" << 
" auto type = " << op.getGetterName(var->name) << "().getType();\n" @@ -2253,6 +2281,8 @@ ParserContext context); LogicalResult parseOperandsDirective(std::unique_ptr &element, llvm::SMLoc loc, ParserContext context); + LogicalResult parseQualifiedDirective(std::unique_ptr &element, + FormatToken tok, ParserContext context); LogicalResult parseReferenceDirective(std::unique_ptr &element, llvm::SMLoc loc, ParserContext context); LogicalResult parseRegionsDirective(std::unique_ptr &element, @@ -2762,6 +2792,8 @@ return parseFunctionalTypeDirective(element, dirTok, context); case FormatToken::kw_operands: return parseOperandsDirective(element, dirTok.getLoc(), context); + case FormatToken::kw_qualified: + return parseQualifiedDirective(element, dirTok, context); case FormatToken::kw_regions: return parseRegionsDirective(element, dirTok.getLoc(), context); case FormatToken::kw_results: @@ -3176,6 +3208,27 @@ return ::mlir::success(); } +LogicalResult +FormatParser::parseQualifiedDirective(std::unique_ptr &element, + FormatToken tok, ParserContext context) { + if (failed(parseToken(FormatToken::l_paren, + "expected '(' before argument list")) || + failed(parseElement(element, context)) || + failed( + parseToken(FormatToken::r_paren, "expected ')' after argument list"))) + return ::mlir::failure(); + if (auto *attr = dyn_cast(element.get())) + attr->setQualified(); + else if (auto *type = dyn_cast(element.get())) + type->setQualified(); + else + return emitError( + tok.getLoc(), + "'qualified' directive expects an attribute or a `type` directive"); + + return ::mlir::success(); +} + LogicalResult FormatParser::parseTypeDirectiveOperand(std::unique_ptr &element, bool isRefChild) {