diff --git a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md --- a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md +++ b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md @@ -558,6 +558,8 @@ mnemonic. * `struct`: generate a "struct-like" parser and printer for a list of key-value pairs. +* `custom`: dispatch a call to user-defined parser and printer functions +* `ref`: in a custom directive, reference a previously bound variable #### `params` Directive @@ -649,3 +651,44 @@ The order in which the parameters are printed is the order in which they are declared in the attribute's or type's `parameter` list. + +#### `custom` and `ref` directives + +The `custom` directive is used to dispatch calls to user-defined printer and +parser functions. For example, suppose we had the following type: + +```tablegen +let parameters = (ins "int":$foo, "int":$bar); +let assemblyFormat = "custom<Foo>($foo) custom<Bar>($bar, ref($foo))"; +``` + +The `custom` directive `custom<Foo>($foo)` generates calls to the following +functions in the parser and printer, respectively: + +```c++ +LogicalResult parseFoo(AsmParser &parser, FailureOr<int> &foo); +void printFoo(AsmPrinter &printer, int foo); +``` + +A previously bound variable can be passed as a parameter to a `custom` directive +by wrapping it in a `ref` directive. In the previous example, `$foo` is bound by +the first directive. The second directive references it and expects the +following printer and parser signatures: + +```c++ +LogicalResult parseBar(AsmParser &parser, FailureOr<int> &bar, int foo); +void printBar(AsmPrinter &printer, int bar, int foo); +``` + +More complex C++ types can be used with the `custom` directive. The only caveat +is that the parameter for the parser must use the storage type of the parameter.
+For example, `StringRefParameter` expects the following parser and printer signatures: + +```c++ +LogicalResult parseStringParam(AsmParser &parser, + FailureOr<std::string> &value); +void printStringParam(AsmPrinter &printer, StringRef value); +``` + +The custom parser is considered to have failed if it returns failure or if any +bound non-optional parameters have failure values afterwards. diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -363,4 +363,18 @@ let assemblyFormat = "`<` (`(` $type^ `)`)? `>`"; } +def TestTypeCustom : Test_Type<"TestTypeCustom"> { + let parameters = (ins "int":$a, OptionalParameter<"mlir::Optional<int>">:$b); + let mnemonic = "custom_type"; + let assemblyFormat = [{ `<` custom<CustomTypeA>($a) + custom<CustomTypeB>(ref($a), $b) `>` }]; +} + +def TestTypeCustomString : Test_Type<"TestTypeCustomString"> { + let parameters = (ins StringRefParameter<>:$foo); + let mnemonic = "custom_type_string"; + let assemblyFormat = [{ `<` custom<FooString>($foo) + custom<BarString>(ref($foo)) `>` }]; +} + #endif // TEST_TYPEDEFS diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -208,6 +208,59 @@ return 1; } +//===----------------------------------------------------------------------===// +// TestTypeCustom +//===----------------------------------------------------------------------===// + +static LogicalResult parseCustomTypeA(AsmParser &parser, + FailureOr<int> &a_result) { + a_result.emplace(); + return parser.parseInteger(*a_result); +} + +static void printCustomTypeA(AsmPrinter &printer, int a) { printer << a; } + +static LogicalResult parseCustomTypeB(AsmParser &parser, int a, + FailureOr<Optional<int>> &b_result) { + if (a < 0) + return success(); + for (int i : llvm::seq(0, a)) + if (failed(parser.parseInteger(i))) + return failure(); + b_result.emplace(0);
+ return parser.parseInteger(**b_result); +} + +static void printCustomTypeB(AsmPrinter &printer, int a, Optional<int> b) { + if (a < 0) + return; + printer << ' '; + for (int i : llvm::seq(0, a)) + printer << i << ' '; + printer << *b; +} + +static LogicalResult parseFooString(AsmParser &parser, + FailureOr<std::string> &foo) { + std::string result; + if (parser.parseString(&result)) + return failure(); + foo = std::move(result); + return success(); +} + +static void printFooString(AsmPrinter &printer, StringRef foo) { + printer << '"' << foo << '"'; +} + +static LogicalResult parseBarString(AsmParser &parser, StringRef foo) { + return parser.parseKeyword(foo); +} + +static void printBarString(AsmPrinter &printer, StringRef foo) { + printer << ' ' << foo; +} + //===----------------------------------------------------------------------===// // Tablegen Generated Definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td --- a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td +++ b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td @@ -107,3 +107,27 @@ // CHECK: optional group anchor must be a parameter or directive let assemblyFormat = "(`(` $a `)`^)?"; } + +def InvalidTypeO : InvalidType<"InvalidTypeO", "invalid_o"> { + let parameters = (ins "int":$a); + // CHECK: `ref` is only allowed inside custom directives + let assemblyFormat = "$a ref($a)"; +} + +def InvalidTypeP : InvalidType<"InvalidTypeP", "invalid_p"> { + let parameters = (ins "int":$a); + // CHECK: parameter 'a' must be bound before it is referenced + let assemblyFormat = "custom<Foo>(ref($a)) $a"; +} + +def InvalidTypeQ : InvalidType<"InvalidTypeQ", "invalid_q"> { + let parameters = (ins "int":$a); + // CHECK: `params` can only be used at the top-level context or within a `struct` directive + let assemblyFormat = "custom<Foo>(params)"; +} + +def InvalidTypeR :
InvalidType<"InvalidTypeR", "invalid_r"> { + let parameters = (ins "int":$a); + // CHECK: `struct` can only be used at the top-level context + let assemblyFormat = "custom<Foo>(struct(params))"; +} diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir --- a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir +++ b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir @@ -48,6 +48,10 @@ // CHECK: !test.ap_float<> // CHECK: !test.default_valued_type<(i64)> // CHECK: !test.default_valued_type<> +// CHECK: !test.custom_type<-5> +// CHECK: !test.custom_type<2 0 1 5> +// CHECK: !test.custom_type_string<"foo" foo> +// CHECK: !test.custom_type_string<"bar" bar> func private @test_roundtrip_default_parsers_struct( !test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4> @@ -79,5 +83,9 @@ !test.ap_float<5.0>, !test.ap_float<>, !test.default_valued_type<(i64)>, - !test.default_valued_type<> + !test.default_valued_type<>, + !test.custom_type<-5>, + !test.custom_type<2 9 9 5>, + !test.custom_type_string<"foo" foo>, + !test.custom_type_string<"bar" bar> ) diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td --- a/mlir/test/mlir-tblgen/attr-or-type-format.td +++ b/mlir/test/mlir-tblgen/attr-or-type-format.td @@ -499,3 +499,27 @@ let mnemonic = "type_k"; let assemblyFormat = "$a"; } + +// TYPE: ::mlir::Type TestLType::parse +// TYPE: auto odsCustomLoc = odsParser.getCurrentLocation() +// TYPE: auto odsCustomResult = parseA(odsParser, +// TYPE-NEXT: _result_a +// TYPE: if (::mlir::failed(odsCustomResult)) return {} +// TYPE: if (::mlir::failed(_result_a)) +// TYPE-NEXT: odsParser.emitError(odsCustomLoc, +// TYPE: auto odsCustomResult = parseB(odsParser, +// TYPE-NEXT: _result_b +// TYPE-NEXT: *_result_a + +// TYPE: void TestLType::print +// TYPE: printA(odsPrinter +// TYPE-NEXT: getA() +// TYPE: printB(odsPrinter +// TYPE-NEXT: getB() +// TYPE-NEXT: getA() +
+def TypeL : TestType<"TestL"> { + let parameters = (ins "int":$a, OptionalParameter<"Attribute">:$b); + let mnemonic = "type_l"; + let assemblyFormat = "custom<A>($a) custom<B>($b, ref($a))"; +} diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -199,6 +199,8 @@ void genParamsParser(ParamsDirective *el, FmtContext &ctx, MethodBody &os); /// Generate the parser code for a `struct` directive. void genStructParser(StructDirective *el, FmtContext &ctx, MethodBody &os); + /// Generate the parser code for a `custom` directive. + void genCustomParser(CustomDirective *el, FmtContext &ctx, MethodBody &os); /// Generate the parser code for an optional group. void genOptionalGroupParser(OptionalElement *el, FmtContext &ctx, MethodBody &os); @@ -218,6 +220,8 @@ void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os); /// Generate the printer code for a `struct` directive. void genStructPrinter(StructDirective *el, FmtContext &ctx, MethodBody &os); + /// Generate the printer code for a `custom` directive. + void genCustomPrinter(CustomDirective *el, FmtContext &ctx, MethodBody &os); /// Generate the printer code for an optional group. void genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx, MethodBody &os); @@ -313,6 +317,8 @@ return genParamsParser(params, ctx, os); if (auto *strct = dyn_cast<StructDirective>(el)) return genStructParser(strct, ctx, os); + if (auto *custom = dyn_cast<CustomDirective>(el)) + return genCustomParser(custom, ctx, os); if (auto *optional = dyn_cast<OptionalElement>(el)) return genOptionalGroupParser(optional, ctx, os); if (isa<WhitespaceElement>(el)) @@ -566,6 +572,47 @@ os.unindent() << "}\n"; } +void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx, + MethodBody &os) { + os << "{\n"; + os.indent(); + + // Bound variables are passed directly to the parser as `FailureOr<T> &`.
+ // Referenced variables are passed as `T`. The custom parser fails if it + // returns failure or if any bound, non-optional parameter is still a failure. + os << tgfmt("auto odsCustomLoc = $_parser.getCurrentLocation();\n", &ctx); + os << "(void)odsCustomLoc;\n"; + os << tgfmt("auto odsCustomResult = parse$0($_parser", &ctx, el->getName()); + os.indent(); + for (FormatElement *arg : el->getArguments()) { + os << ",\n"; + FormatElement *param; + if (auto *ref = dyn_cast<RefDirective>(arg)) { + os << "*"; + param = ref->getArg(); + } else { + param = arg; + } + os << "_result_" << cast<ParameterElement>(param)->getName(); + } + os.unindent() << ");\n"; + os << "if (::mlir::failed(odsCustomResult)) return {};\n"; + for (FormatElement *arg : el->getArguments()) { + if (auto *param = dyn_cast<ParameterElement>(arg)) { + if (param->isOptional()) + continue; + os << formatv("if (::mlir::failed(_result_{0})) {{\n", param->getName()); + os.indent() << tgfmt("$_parser.emitError(odsCustomLoc, ", &ctx) + << "\"custom parser failed to parse parameter '" + << param->getName() << "'\");\n"; + os << "return {};\n"; + os.unindent() << "}\n"; + } + } + + os.unindent() << "}\n"; +} + void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx, MethodBody &os) { ArrayRef<FormatElement *> elements = @@ -634,6 +681,8 @@ return genParamsPrinter(params, ctx, os); if (auto *strct = dyn_cast<StructDirective>(el)) return genStructPrinter(strct, ctx, os); + if (auto *custom = dyn_cast<CustomDirective>(el)) + return genCustomPrinter(custom, ctx, os); if (auto *var = dyn_cast<ParameterElement>(el)) return genVariablePrinter(var, ctx, os); if (auto *optional = dyn_cast<OptionalElement>(el)) @@ -746,6 +795,21 @@ }); } +void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx, + MethodBody &os) { + os << tgfmt("print$0($_printer", &ctx, el->getName()); + os.indent(); + for (FormatElement *arg : el->getArguments()) { + FormatElement *param = arg; + if (auto *ref = dyn_cast<RefDirective>(arg)) + param = ref->getArg(); + os << ",\n" + << getParameterAccessorName(cast<ParameterElement>(param)->getName()) + << "()"; + } + os.unindent() << ");\n"; +} +
void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx, MethodBody &os) { FormatElement *anchor = el->getAnchor(); @@ -805,9 +869,7 @@ /// Verify the elements of a custom directive. LogicalResult verifyCustomDirectiveArguments(SMLoc loc, - ArrayRef<FormatElement *> arguments) override { - return emitError(loc, "'custom' not supported (yet)"); - } + ArrayRef<FormatElement *> arguments) override; /// Verify the elements of an optional group. LogicalResult verifyOptionalGroupElements(SMLoc loc, ArrayRef<FormatElement *> elements, @@ -822,11 +884,13 @@ private: /// Parse a `params` directive. - FailureOr<FormatElement *> parseParamsDirective(SMLoc loc); + FailureOr<FormatElement *> parseParamsDirective(SMLoc loc, Context ctx); /// Parse a `qualified` directive. FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc, Context ctx); /// Parse a `struct` directive. - FailureOr<FormatElement *> parseStructDirective(SMLoc loc); + FailureOr<FormatElement *> parseStructDirective(SMLoc loc, Context ctx); + /// Parse a `ref` directive. + FailureOr<FormatElement *> parseRefDirective(SMLoc loc, Context ctx); /// Attribute or type tablegen def. const AttrOrTypeDef &def; @@ -862,6 +926,12 @@ return success(); } +LogicalResult DefFormatParser::verifyCustomDirectiveArguments( + SMLoc loc, ArrayRef<FormatElement *> arguments) { + // Arguments are fully verified by the parser context. + return success(); +} + LogicalResult DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc, ArrayRef<FormatElement *> elements, @@ -915,9 +985,18 @@ def.getName() + " has no parameter named '" + name + "'"); } auto idx = std::distance(params.begin(), it); - if (seenParams.test(idx)) - return emitError(loc, "duplicate parameter '" + name + "'"); - seenParams.set(idx); + + if (ctx != RefDirectiveContext) { + // Check that the variable has not already been bound. + if (seenParams.test(idx)) + return emitError(loc, "duplicate parameter '" + name + "'"); + seenParams.set(idx); + + // Otherwise, to be referenced, a variable must have been bound.
+ } else if (!seenParams.test(idx)) { + return emitError(loc, "parameter '" + name + + "' must be bound before it is referenced"); + } return create<ParameterElement>(*it); } @@ -930,14 +1009,13 @@ case FormatToken::kw_qualified: return parseQualifiedDirective(loc, ctx); case FormatToken::kw_params: - return parseParamsDirective(loc); + return parseParamsDirective(loc, ctx); case FormatToken::kw_struct: - if (ctx != TopLevelContext) { - return emitError( - loc, - "`struct` may only be used in the top-level section of the format"); - } - return parseStructDirective(loc); + return parseStructDirective(loc, ctx); + case FormatToken::kw_ref: + return parseRefDirective(loc, ctx); + case FormatToken::kw_custom: + return parseCustomDirective(loc, ctx); default: return emitError(loc, "unsupported directive kind"); @@ -961,10 +1039,18 @@ return var; } -FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc) { - // Collect all of the attribute's or type's parameters. +FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc, + Context ctx) { + // It doesn't make sense to allow references to all parameters in a custom + // directive because parameters are the only things that can be bound. + if (ctx != TopLevelContext && ctx != StructDirectiveContext) { + return emitError(loc, "`params` can only be used at the top-level context " + "or within a `struct` directive"); + } + + // Collect all of the attribute's or type's parameters and ensure that none of + // the parameters have already been captured. std::vector<ParameterElement *> vars; - // Ensure that none of the parameters have already been captured.
for (const auto &it : llvm::enumerate(def.getParameters())) { if (seenParams.test(it.index())) { return emitError(loc, "`params` captures duplicate parameter: " + @@ -976,7 +1062,11 @@ return create<ParamsDirective>(std::move(vars)); } -FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc) { +FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc, + Context ctx) { + if (ctx != TopLevelContext) + return emitError(loc, "`struct` can only be used at the top-level context"); + if (failed(parseToken(FormatToken::l_paren, "expected '(' before `struct` argument list"))) return failure(); @@ -1012,6 +1102,22 @@ return create<StructDirective>(std::move(vars)); } +FailureOr<FormatElement *> DefFormatParser::parseRefDirective(SMLoc loc, + Context ctx) { + if (ctx != CustomDirectiveContext) + return emitError(loc, "`ref` is only allowed inside custom directives"); + + // Parse the child parameter element. + FailureOr<FormatElement *> child; + if (failed(parseToken(FormatToken::l_paren, "expected '('")) || + failed(child = parseElement(RefDirectiveContext)) || + failed(parseToken(FormatToken::r_paren, "expected ')'"))) + return failure(); + + // Only parameter elements are allowed to be parsed under a `ref` directive. + return create<RefDirective>(*child); +} + //===----------------------------------------------------------------------===// // Interface //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h --- a/mlir/tools/mlir-tblgen/FormatGen.h +++ b/mlir/tools/mlir-tblgen/FormatGen.h @@ -338,6 +338,22 @@ std::vector<FormatElement *> arguments; }; +/// This class represents a reference directive. This directive can be used to +/// reference but not bind a previously bound variable or format object. Its +/// only current use is to pass variables as arguments to a custom directive. +class RefDirective : public DirectiveElementBase<DirectiveElement::Ref> { +public: + /// Create a reference directive with the single referenced child.
+ explicit RefDirective(FormatElement *arg) : arg(arg) {} + + /// Get the reference argument. + FormatElement *getArg() const { return arg; } + +private: + /// The referenced argument. + FormatElement *arg; +}; + /// This class represents a group of elements that are optionally emitted based /// on an optional variable "anchor" and a group of elements that are emitted /// when the anchor element is not present. diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -153,18 +153,6 @@ FormatElement *inputs, *results; }; -/// This class represents the `ref` directive. -class RefDirective : public DirectiveElementBase<DirectiveElement::Ref> { -public: - RefDirective(FormatElement *arg) : arg(arg) {} - - FormatElement *getArg() const { return arg; } - -private: - /// The argument that is used to format the directive. - FormatElement *arg; -}; - /// This class represents the `type` directive. class TypeDirective : public DirectiveElementBase<DirectiveElement::Type> { public: