diff --git a/mlir/test/mlir-tblgen/op-format-invalid.td b/mlir/test/mlir-tblgen/op-format-invalid.td --- a/mlir/test/mlir-tblgen/op-format-invalid.td +++ b/mlir/test/mlir-tblgen/op-format-invalid.td @@ -389,7 +389,7 @@ def OptionalInvalidK : TestFormat_Op<[{ ($arg^) }]>, Arguments<(ins Variadic:$arg)>; -// CHECK: error: only variables and types can be used to anchor an optional group +// CHECK: error: only variable length operands can be used within an optional group def OptionalInvalidL : TestFormat_Op<[{ (custom($arg)^)? }]>, Arguments<(ins I64:$arg)>; diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -41,6 +41,9 @@ def DirectiveCustomValidC : TestFormat_Op<[{ custom($attr) attr-dict }]>, Arguments<(ins I64Attr:$attr)>; +def DirectiveCustomValidD : TestFormat_Op<[{ + (`(` custom($operand)^ `)`)? attr-dict +}]>, Arguments<(ins Optional:$operand)>; //===----------------------------------------------------------------------===// // functional-type diff --git a/mlir/test/mlir-tblgen/op-format.td b/mlir/test/mlir-tblgen/op-format.td --- a/mlir/test/mlir-tblgen/op-format.td +++ b/mlir/test/mlir-tblgen/op-format.td @@ -83,3 +83,10 @@ def OptionalGroupC : TestFormat_Op<[{ ($a^)? attr-dict }]>, Arguments<(ins DefaultValuedStrAttr:$a)>; + +// CHECK-LABEL: OptionalGroupD::print +// CHECK-NEXT: if (((getA()) || (getB()))) { +// CHECK-NEXT: odsPrinter << "(" +def OptionalGroupD : TestFormat_Op<[{ + (`(` custom($a, $b)^ `)`)? attr-dict +}], [AttrSizedOperandSegments]>, Arguments<(ins Optional:$a, Optional:$b)>; diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -2008,19 +2008,19 @@ else if (var->isVariadic()) body << "!" 
<< name << "().empty()"; }) - .Case([&](RegionVariable *element) { + .Case([&](RegionVariable *element) { const NamedRegion *var = element->getVar(); std::string name = op.getGetterName(var->name); // TODO: Add a check for optional regions here when ODS supports it. body << "!" << name << "().empty()"; }) - .Case([&](TypeDirective *element) { + .Case([&](TypeDirective *element) { genOptionalGroupPrinterAnchor(element->getArg(), op, body); }) - .Case([&](FunctionalTypeDirective *element) { + .Case([&](FunctionalTypeDirective *element) { genOptionalGroupPrinterAnchor(element->getInputs(), op, body); }) - .Case([&](AttributeVariable *element) { + .Case([&](AttributeVariable *element) { Attribute attr = element->getVar()->attr; body << op.getGetterName(element->getVar()->name) << "Attr()"; if (attr.isOptional()) @@ -2037,6 +2037,18 @@ return; } llvm_unreachable("attribute must be optional or default-valued"); + }) + .Case([&](CustomDirective *ele) { + body << '('; + llvm::interleave( + ele->getArguments(), body, + [&](FormatElement *child) { + body << '('; + genOptionalGroupPrinterAnchor(child, op, body); + body << ')'; + }, + " || "); + body << ')'; }); } @@ -3434,15 +3446,28 @@ return verifyOptionalGroupElement(loc, ele->getResults(), /*isAnchor=*/false); }) - // Literals, whitespace, and custom directives may be used, but they can't - // anchor the group. - .Case([&](FormatElement *) { - if (isAnchor) - return emitError(loc, "only variables and types can be used " - "to anchor an optional group"); + .Case([&](CustomDirective *ele) { + if (!isAnchor) + return success(); + // Verify each child as being valid in an optional group. They are all + // potential anchors if the custom directive was marked as one. 
+ for (FormatElement *child : ele->getArguments()) { + if (isa(child)) + continue; + if (failed(verifyOptionalGroupElement(loc, child, /*isAnchor=*/true))) + return failure(); + } return success(); }) + // Literals and whitespace may be used, but they can't anchor the group; + // custom directives are verified separately by the case above. + .Case( + [&](FormatElement *) { + if (isAnchor) + return emitError(loc, "only variables and types can be used " + "to anchor an optional group"); + return success(); + }) .Default([&](FormatElement *) { return emitError(loc, "only literals, types, and variables can be " "used within an optional group");