diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -291,4 +291,11 @@ def TestArrayOfEnums : ArrayOfAttr; +// Test custom directive as optional group anchor. +def TestCustomAnchor : Test_Attr<"TestCustomAnchor"> { + let parameters = (ins "int":$a, OptionalParameter<"mlir::Optional">:$b); + let mnemonic = "custom_anchor"; + let assemblyFormat = "`<` $a (`>`) : (`,` ` ` custom($b)^ `>`)?"; +} + #endif // TEST_ATTRDEFS diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -174,6 +174,23 @@ return llvm::None; } +//===----------------------------------------------------------------------===// +// TestCustomAnchorAttr +//===----------------------------------------------------------------------===// + +static ParseResult parseTrueFalse(AsmParser &p, + FailureOr> &result) { + bool b; + if (p.parseInteger(b)) + return failure(); + result = Optional(b); + return success(); +} + +static void printTrueFalse(AsmPrinter &p, Optional result) { + p << (*result ? "true" : "false"); +} + //===----------------------------------------------------------------------===// // Tablegen Generated Definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir --- a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir +++ b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir @@ -16,7 +16,11 @@ // CHECK: #test.attr_with_type> attr4 = #test.attr_with_type>, // CHECK: #test.attr_self_type_format<5> : i32 - attr5 = #test.attr_self_type_format<5> : i32 + attr5 = #test.attr_self_type_format<5> : i32, + // CHECK: #test.custom_anchor<5> + attr6 = #test.custom_anchor<5>, + // CHECK: #test.custom_anchor<5, true> + attr7 = #test.custom_anchor<5, true> } // CHECK-LABEL: @test_roundtrip_default_parsers_struct diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -849,9 +849,18 @@ guardOnAny(ctx, os, llvm::makeArrayRef(param), el->isInverted()); } else if (auto *params = dyn_cast(anchor)) { guardOnAny(ctx, os, params->getParams(), el->isInverted()); - } else { - auto *strct = cast(anchor); + } else if (auto *strct = dyn_cast(anchor)) { guardOnAny(ctx, os, strct->getParams(), el->isInverted()); + } else { + auto *custom = cast(anchor); + guardOnAny(ctx, os, + llvm::make_filter_range( + llvm::map_range(custom->getArguments(), + [](FormatElement *el) { + return dyn_cast(el); + }), + [](ParameterElement *param) { return !!param; }), + el->isInverted()); } // Generate the printer for the contained elements. { @@ -994,13 +1003,33 @@ return emitError(loc, "`struct` is only allowed in an optional group " "if all captured parameters are optional"); } + } else if (auto *custom = dyn_cast(el)) { + for (FormatElement *el : custom->getArguments()) { + // If the custom argument is a variable, then it must be optional. + if (auto param = dyn_cast(el); + param && !param->isOptional()) + return emitError(loc, "`custom` is only allowed in an optional group " + "if all captured parameters are optional"); + } } } // The anchor must be a parameter or one of the aforementioned directives. - if (anchor && - !isa(anchor)) { - return emitError(loc, - "optional group anchor must be a parameter or directive"); + if (anchor) { + if (!isa(anchor)) { + return emitError( + loc, "optional group anchor must be a parameter or directive"); + } + // If the anchor is a custom directive, make sure at least one of its + // arguments is a bound parameter. + if (auto custom = dyn_cast(anchor)) { + auto bound = llvm::find_if(custom->getArguments(), [](FormatElement *el) { + return isa(el); + }); + if (bound == custom->getArguments().end()) + return emitError(loc, "`custom` directive with no bound parameters " + "cannot be used as optional group anchor"); + } } return success(); }