diff --git a/mlir/docs/AttributesAndTypes.md b/mlir/docs/AttributesAndTypes.md --- a/mlir/docs/AttributesAndTypes.md +++ b/mlir/docs/AttributesAndTypes.md @@ -656,11 +656,16 @@ will be set to `llvm::None` and `Attribute` will be set to `nullptr`. The presence of these parameters is tested by comparing them to their "null" values. -Only optional parameters or directives that only capture optional parameters can -be used in optional groups. An optional group is a set of elements optionally -printed based on the presence of an anchor. The group in which the anchor is -placed is printed if it is present, otherwise the other one is printed. Suppose -parameter `a` is an `IntegerAttr`. +An optional group is a set of elements optionally printed based on the presence +of an anchor. Only optional parameters, or directives whose captured parameters +are all optional, can be used in optional groups. The group containing the +anchor is printed if the anchor is present; otherwise, the other group (if any) +is printed. If a directive that captures more than one optional parameter is +used as the anchor, the optional group is printed if any of the captured +parameters is present. A `custom` directive may only be used as an anchor if it +captures at least one optional parameter; `ref` arguments do not count. + +Suppose parameter `a` is an `IntegerAttr`. ``` ( `(` $a^ `)` ) : (`x`)? diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -291,4 +291,11 @@ def TestArrayOfEnums : ArrayOfAttr; +// Test custom directive as optional group anchor.
+def TestCustomAnchor : Test_Attr<"TestCustomAnchor"> {
+  let parameters = (ins "int":$a, OptionalParameter<"mlir::Optional<int>">:$b);
+  let mnemonic = "custom_anchor";
+  let assemblyFormat = "`<` $a (`>`) : (`,` ` ` custom<TrueFalse>($b)^ `>`)?";
+}
+
 #endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -174,6 +174,27 @@
   return llvm::None;
 }
 
+//===----------------------------------------------------------------------===//
+// TestCustomAnchorAttr
+//===----------------------------------------------------------------------===//
+
+/// Parses the optional boolean parameter of TestCustomAnchorAttr. The value is
+/// read with `parseInteger`, which also accepts the `true`/`false` keywords.
+static ParseResult parseTrueFalse(AsmParser &p,
+                                  FailureOr<Optional<int>> &result) {
+  bool b = false;
+  if (p.parseInteger(b))
+    return failure();
+  result = Optional<int>(b);
+  return success();
+}
+
+/// Prints the boolean parameter as `true`/`false`. Only invoked when the
+/// optional group anchored on the parameter is printed, i.e. it has a value.
+static void printTrueFalse(AsmPrinter &p, Optional<int> result) {
+  p << (*result ? "true" : "false");
+}
+
 //===----------------------------------------------------------------------===//
 // Tablegen Generated Definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
--- a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
@@ -132,3 +132,15 @@
   // CHECK: `struct` can only be used at the top-level context
   let assemblyFormat = "custom<Foo>(struct(params))";
 }
+
+def InvalidTypeS : InvalidType<"InvalidTypeS", "invalid_s"> {
+  let parameters = (ins OptionalParameter<"int">:$a, "int":$b);
+  // CHECK: `custom` is only allowed in an optional group if all captured parameters are optional
+  let assemblyFormat = "(`(` custom<Foo>($a, $b)^ `)`)?";
+}
+
+def InvalidTypeT : InvalidType<"InvalidTypeT", "invalid_t"> {
+  let parameters = (ins OptionalParameter<"int">:$a);
+  // CHECK: `custom` directive with no bound parameters cannot be used as optional group anchor
+  let assemblyFormat = "$a (`(` custom<Foo>(ref($a))^ `)`)?";
+}
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
--- a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
@@ -16,7 +16,11 @@
   // CHECK: #test.attr_with_type<i32, vector<4xi32>>
   attr4 = #test.attr_with_type<i32, vector<4xi32>>,
   // CHECK: #test.attr_self_type_format<5> : i32
-  attr5 = #test.attr_self_type_format<5> : i32
+  attr5 = #test.attr_self_type_format<5> : i32,
+  // CHECK: #test.custom_anchor<5>
+  attr6 = #test.custom_anchor<5>,
+  // CHECK: #test.custom_anchor<5, true>
+  attr7 = #test.custom_anchor<5, true>
 }
 
 // CHECK-LABEL: @test_roundtrip_default_parsers_struct
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -849,9 +849,20 @@
     guardOnAny(ctx, os, llvm::makeArrayRef(param), el->isInverted());
   } else if (auto *params = dyn_cast<ParamsDirective>(anchor)) {
     guardOnAny(ctx, os, params->getParams(), el->isInverted());
-  } else {
-    auto *strct = cast<StructDirective>(anchor);
+  } else if (auto *strct = dyn_cast<StructDirective>(anchor)) {
     guardOnAny(ctx, os, strct->getParams(), el->isInverted());
+  } else {
+    // A `custom` anchor is guarded only on its bound parameters; arguments
+    // that are not parameters (e.g. `ref(...)`) are filtered out.
+    auto *custom = cast<CustomDirective>(anchor);
+    guardOnAny(ctx, os,
+               llvm::make_filter_range(
+                   llvm::map_range(custom->getArguments(),
+                                   [](FormatElement *arg) {
+                                     return dyn_cast<ParameterElement>(arg);
+                                   }),
+                   [](ParameterElement *param) { return !!param; }),
+               el->isInverted());
   }
   // Generate the printer for the contained elements.
   {
@@ -994,13 +1005,34 @@
         return emitError(loc, "`struct` is only allowed in an optional group "
                               "if all captured parameters are optional");
       }
+    } else if (auto *custom = dyn_cast<CustomDirective>(el)) {
+      for (FormatElement *arg : custom->getArguments()) {
+        // If the custom argument is a variable, then it must be optional.
+        if (auto *param = dyn_cast<ParameterElement>(arg))
+          if (!param->isOptional())
+            return emitError(loc,
+                             "`custom` is only allowed in an optional group if "
+                             "all captured parameters are optional");
+      }
     }
   }
   // The anchor must be a parameter or one of the aforementioned directives.
-  if (anchor &&
-      !isa<ParameterElement, ParamsDirective, StructDirective>(anchor)) {
-    return emitError(loc,
-                     "optional group anchor must be a parameter or directive");
+  if (anchor) {
+    if (!isa<ParameterElement, ParamsDirective, StructDirective,
+             CustomDirective>(anchor)) {
+      return emitError(
+          loc, "optional group anchor must be a parameter or directive");
+    }
+    // If the anchor is a custom directive, make sure at least one of its
+    // arguments is a bound parameter.
+    if (auto *custom = dyn_cast<CustomDirective>(anchor)) {
+      auto isBound = [](FormatElement *arg) {
+        return isa<ParameterElement>(arg);
+      };
+      if (llvm::none_of(custom->getArguments(), isBound))
+        return emitError(loc, "`custom` directive with no bound parameters "
+                              "cannot be used as optional group anchor");
+    }
   }
   return success();
 }