diff --git a/mlir/docs/AttributesAndTypes.md b/mlir/docs/AttributesAndTypes.md --- a/mlir/docs/AttributesAndTypes.md +++ b/mlir/docs/AttributesAndTypes.md @@ -660,8 +660,9 @@ Only optional parameters or directives that only capture optional parameters can be used in optional groups. An optional group is a set of elements optionally -printed based on the presence of an anchor. Suppose parameter `a` is an -`IntegerAttr`. +printed based on the presence of an anchor. The group containing the anchor is +printed if the anchor is present; otherwise, the other group is printed. Suppose +parameter `a` is an `IntegerAttr`. ``` ( `(` $a^ `)` ) : (`x`)? diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -856,17 +856,18 @@ information. An optional group is defined as follows: ``` -optional-group: `(` elements `)` (`:` `(` else-elements `)`)? `?` +optional-group: `(` then-elements `)` (`:` `(` else-elements `)`)? `?` ``` -The `elements` of an optional group have the following requirements: +The elements of an optional group have the following requirements: -* The first element of the group must either be a attribute, literal, operand, - or region. +* The first element of `then-elements` must be either an attribute, literal, + operand, or region. - This is because the first element must be optionally parsable. -* Exactly one argument variable or type directive within the group must be - marked as the anchor of the group. - - The anchor is the element whose presence controls whether the group +* Exactly one argument variable or type directive within either + `then-elements` or `else-elements` must be marked as the anchor of the + group. + - The anchor is the element whose presence controls which elements should be printed/parsed. - An element is marked as the anchor by adding a trailing `^`. - The first element is *not* required to be the anchor of the group. 
diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td --- a/mlir/include/mlir/IR/AttrTypeBase.td +++ b/mlir/include/mlir/IR/AttrTypeBase.td @@ -315,7 +315,7 @@ // operator of `AsmPrinter` as necessary to print your type. Or you can // provide a custom printer. string printer = ?; - // Mark a parameter as optional. The C++ type of parameters marked as optional + // Mark a parameter as optional. The C++ type of an optional parameter // must be default constructible and be contextually convertible to `bool`. // Any `Optional` and any attribute type satisfies these requirements. bit isOptional = 0; diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -797,6 +797,11 @@ }]; } +def ElseAnchorOp : TEST_Op<"else_anchor"> { + let arguments = (ins Optional:$a); + let assemblyFormat = "`(` (`?`) : (`` $a^ `:` type($a))? `)` attr-dict"; +} + // This is used to test encoding of a string attribute into an SSA name of a // pretty printed value name. def StringAttrPrettyNameOp diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -332,4 +332,10 @@ custom(ref($foo)) `>` }]; } +def TestTypeElseAnchor : Test_Type<"TestTypeElseAnchor"> { + let parameters = (ins OptionalParameter<"mlir::Optional">:$a); + let mnemonic = "else_anchor"; + let assemblyFormat = "`<` (`?`) : ($a^)? 
`>`"; +} + #endif // TEST_TYPEDEFS diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td --- a/mlir/test/mlir-tblgen/attr-or-type-format.td +++ b/mlir/test/mlir-tblgen/attr-or-type-format.td @@ -571,3 +571,20 @@ let mnemonic = "type_l"; let assemblyFormat = [{ custom($a, "1") }]; } + +// TYPE-LABEL: ::mlir::Type TestOType::parse +// TYPE: if (odsParser.parseOptionalQuestion()) +// TYPE: _result_a = +// TYPE: else + +// TYPE-LABEL: void TestOType::print +// TYPE: if (!(getA())) +// TYPE: odsPrinter << ' ' << "?" +// TYPE: else +// TYPE: odsPrinter.printStrippedAttrOrType(getA()) + +def TypeM : TestType<"TestO"> { + let parameters = (ins OptionalParameter<"int">:$a); + let mnemonic = "type_m"; + let assemblyFormat = "(`?`) : ($a^)?"; +} diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -463,3 +463,18 @@ // CHECK: test.has_str_value test.has_str_value {} + +//===----------------------------------------------------------------------===// +// ElseAnchorOp +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @else_anchor_op +func.func @else_anchor_op(%a: !test.else_anchor, %b: !test.else_anchor<5>) { + // CHECK: test.else_anchor(?) + test.else_anchor(?) 
+ // CHECK: test.else_anchor(%{{.*}} : !test.else_anchor) + test.else_anchor(%a : !test.else_anchor) + // CHECK: test.else_anchor(%{{.*}} : !test.else_anchor<5>) + test.else_anchor(%b : !test.else_anchor<5>) + return +} diff --git a/mlir/test/mlir-tblgen/op-format.td b/mlir/test/mlir-tblgen/op-format.td --- a/mlir/test/mlir-tblgen/op-format.td +++ b/mlir/test/mlir-tblgen/op-format.td @@ -40,3 +40,34 @@ def CustomStringLiteralC : TestFormat_Op<[{ custom("$_builder.getStringAttr(\"foo\")") attr-dict }]>; + +//===----------------------------------------------------------------------===// +// Optional Groups +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: OptionalGroupA::parse +// CHECK: if (::mlir::succeeded(parser.parseOptionalQuestion()) +// CHECK-NEXT: else +// CHECK: parser.parseOptionalOperand +// CHECK-LABEL: OptionalGroupA::print +// CHECK: if (!getA()) +// CHECK-NEXT: odsPrinter << ' ' << "?"; +// CHECK-NEXT: else +// CHECK: odsPrinter << value; +def OptionalGroupA : TestFormat_Op<[{ + (`?`) : ($a^)? attr-dict +}]>, Arguments<(ins Optional:$a)>; + +// CHECK-LABEL: OptionalGroupB::parse +// CHECK: if (::mlir::succeeded(parser.parseOptionalKeyword("foo"))) +// CHECK-NEXT: else +// CHECK-NEXT: result.addAttribute("a", parser.getBuilder().getUnitAttr()) +// CHECK: parser.parseKeyword("bar") +// CHECK-LABEL: OptionalGroupB::print +// CHECK: if (!(*this)->getAttr("a")) +// CHECK-NEXT: odsPrinter << ' ' << "foo" +// CHECK-NEXT: else +// CHECK-NEXT: odsPrinter << ' ' << "bar" +def OptionalGroupB : TestFormat_Op<[{ + (`foo`) : (`bar` $a^)? 
attr-dict +}]>, Arguments<(ins UnitAttr:$a)>; diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -656,10 +656,10 @@ void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx, MethodBody &os) { - ArrayRef elements = - el->getThenElements().drop_front(el->getParseStart()); + ArrayRef thenElements = + el->getThenElements(/*parseable=*/true); - FormatElement *first = elements.front(); + FormatElement *first = thenElements.front(); const auto guardOn = [&](auto params) { os << "if (!("; llvm::interleave( @@ -687,12 +687,12 @@ } os.indent(); - // Generate the parsers for the rest of the elements. - for (FormatElement *element : el->getElseElements()) + // Generate the parsers for the else elements. + for (FormatElement *element : el->getElseElements(/*parseable=*/true)) genElementParser(element, ctx, os); os.unindent() << "} else {\n"; os.indent(); - for (FormatElement *element : elements.drop_front()) + for (FormatElement *element : thenElements.drop_front()) genElementParser(element, ctx, os); os.unindent() << "}\n"; } @@ -781,9 +781,11 @@ /// Generate code to guard printing on the presence of any optional parameters. 
template -static void guardOnAny(FmtContext &ctx, MethodBody &os, - ParameterRange &&params) { +static void guardOnAny(FmtContext &ctx, MethodBody &os, ParameterRange &&params, + bool inverted = false) { os << "if ("; + if (inverted) + os << "!"; llvm::interleave( params, os, [&](ParameterElement *param) { param->genPrintGuard(ctx, os); }, " || "); @@ -860,12 +862,12 @@ MethodBody &os) { FormatElement *anchor = el->getAnchor(); if (auto *param = dyn_cast(anchor)) { - guardOnAny(ctx, os, llvm::makeArrayRef(param)); + guardOnAny(ctx, os, llvm::makeArrayRef(param), el->isInverted()); } else if (auto *params = dyn_cast(anchor)) { - guardOnAny(ctx, os, params->getParams()); + guardOnAny(ctx, os, params->getParams(), el->isInverted()); } else { auto *strct = cast(anchor); - guardOnAny(ctx, os, strct->getParams()); + guardOnAny(ctx, os, strct->getParams(), el->isInverted()); } // Generate the printer for the contained elements. { diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h --- a/mlir/tools/mlir-tblgen/FormatGen.h +++ b/mlir/tools/mlir-tblgen/FormatGen.h @@ -378,33 +378,48 @@ /// Create an optional group with the given child elements. OptionalElement(std::vector &&thenElements, std::vector &&elseElements, - FormatElement *anchor, unsigned parseStart) + unsigned thenParseStart, unsigned elseParseStart, + FormatElement *anchor, bool inverted) : thenElements(std::move(thenElements)), - elseElements(std::move(elseElements)), anchor(anchor), - parseStart(parseStart) {} - - /// Return the `then` elements of the optional group. - ArrayRef getThenElements() const { return thenElements; } + elseElements(std::move(elseElements)), thenParseStart(thenParseStart), + elseParseStart(elseParseStart), anchor(anchor), inverted(inverted) {} + + /// Return the `then` elements of the optional group. Drops the first + /// `thenParseStart` whitespace elements if `parseable` is true. 
+ ArrayRef getThenElements(bool parseable = false) const { + return llvm::makeArrayRef(thenElements) + .drop_front(parseable ? thenParseStart : 0); + } - /// Return the `else` elements of the optional group. - ArrayRef getElseElements() const { return elseElements; } + /// Return the `else` elements of the optional group. Drops the first + /// `elseParseStart` whitespace elements if `parseable` is true. + ArrayRef getElseElements(bool parseable = false) const { + return llvm::makeArrayRef(elseElements) + .drop_front(parseable ? elseParseStart : 0); + } /// Return the anchor of the optional group. FormatElement *getAnchor() const { return anchor; } - /// Return the index of the first element to be parsed. - unsigned getParseStart() const { return parseStart; } + /// Return true if the optional group is inverted. + bool isInverted() const { return inverted; } private: /// The child elements emitted when the anchor is present. std::vector thenElements; /// The child elements emitted when the anchor is not present. std::vector elseElements; - /// The anchor element of the optional group. - FormatElement *anchor; /// The index of the first element that is parsed in `thenElements`. That is, /// the first non-whitespace element. - unsigned parseStart; + unsigned thenParseStart; + /// The index of the first element that is parsed in `elseElements`. That is, + /// the first non-whitespace element. + unsigned elseParseStart; + /// The anchor element of the optional group. + FormatElement *anchor; + /// Whether the optional group condition is inverted and the anchor element is + /// in the else group. + bool inverted; }; //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp --- a/mlir/tools/mlir-tblgen/FormatGen.cpp +++ b/mlir/tools/mlir-tblgen/FormatGen.cpp @@ -321,35 +321,42 @@ // Parse the child elements for this optional group. 
std::vector thenElements, elseElements; FormatElement *anchor = nullptr; - do { - FailureOr element = parseElement(TopLevelContext); - if (failed(element)) - return failure(); - // Check for an anchor. - if (curToken.is(FormatToken::caret)) { - if (anchor) - return emitError(curToken.getLoc(), "only one element can be marked as " - "the anchor of an optional group"); - anchor = *element; - consumeToken(); - } - thenElements.push_back(*element); - } while (!curToken.is(FormatToken::r_paren)); + auto parseChildElements = + [this, &anchor](std::vector &elements) -> LogicalResult { + do { + FailureOr element = parseElement(TopLevelContext); + if (failed(element)) + return failure(); + // Check for an anchor. + if (curToken.is(FormatToken::caret)) { + if (anchor) { + return emitError(curToken.getLoc(), + "only one element can be marked as the anchor of an " + "optional group"); + } + anchor = *element; + consumeToken(); + } + elements.push_back(*element); + } while (!curToken.is(FormatToken::r_paren)); + return success(); + }; + + // Parse the 'then' elements. If the anchor was found in this group, then the + // optional is not inverted. + if (failed(parseChildElements(thenElements))) + return failure(); consumeToken(); + bool inverted = !anchor; // Parse the `else` elements of this optional group. if (curToken.is(FormatToken::colon)) { consumeToken(); - if (failed( - parseToken(FormatToken::l_paren, - "expected '(' to start else branch of optional group"))) + if (failed(parseToken( + FormatToken::l_paren, + "expected '(' to start else branch of optional group")) || + failed(parseChildElements(elseElements))) return failure(); - do { - FailureOr element = parseElement(TopLevelContext); - if (failed(element)) - return failure(); - elseElements.push_back(*element); - } while (!curToken.is(FormatToken::r_paren)); consumeToken(); } if (failed(parseToken(FormatToken::question, @@ -367,18 +374,21 @@ // Get the first parsable element. 
It must be an element that can be // optionally-parsed. - auto parseBegin = llvm::find_if_not(thenElements, [](FormatElement *element) { + auto isWhitespace = [](FormatElement *element) { return isa(element); - }); - if (!isa(*parseBegin)) { + }; + auto thenParseBegin = llvm::find_if_not(thenElements, isWhitespace); + auto elseParseBegin = llvm::find_if_not(elseElements, isWhitespace); + unsigned thenParseStart = std::distance(thenElements.begin(), thenParseBegin); + unsigned elseParseStart = std::distance(elseElements.begin(), elseParseBegin); + + if (!isa(*thenParseBegin)) { return emitError(loc, "first parsable element of an optional group must be " "a literal or variable"); } - - unsigned parseStart = std::distance(thenElements.begin(), parseBegin); return create(std::move(thenElements), - std::move(elseElements), anchor, - parseStart); + std::move(elseElements), thenParseStart, + elseParseStart, anchor, inverted); } FailureOr FormatParser::parseCustomDirective(SMLoc loc, diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -1119,17 +1119,43 @@ GenContext genCtx) { /// Optional Group. if (auto *optional = dyn_cast(element)) { - ArrayRef elements = - optional->getThenElements().drop_front(optional->getParseStart()); + auto genElementParsers = [&](FormatElement *firstElement, + ArrayRef elements, + bool thenGroup) { + // If the anchor is a unit attribute, we don't need to print it. When + // parsing, we will add this attribute if this group is present. + FormatElement *elidedAnchorElement = nullptr; + auto *anchorAttr = dyn_cast(optional->getAnchor()); + if (anchorAttr && anchorAttr != firstElement && + anchorAttr->isUnitAttr()) { + elidedAnchorElement = anchorAttr; + + if (!thenGroup == optional->isInverted()) { + // Add the anchor unit attribute to the operation state. 
+ body << " result.addAttribute(\"" << anchorAttr->getVar()->name + << "\", parser.getBuilder().getUnitAttr());\n"; + } + } + + // Generate the rest of the elements inside an optional group. Elements in + // an optional group after the guard are parsed as required. + for (FormatElement *childElement : elements) + if (childElement != elidedAnchorElement) + genElementParser(childElement, body, attrTypeCtx, + GenContext::Optional); + }; + + ArrayRef thenElements = + optional->getThenElements(/*parseable=*/true); // Generate a special optional parser for the first element to gate the // parsing of the rest of the elements. - FormatElement *firstElement = elements.front(); + FormatElement *firstElement = thenElements.front(); if (auto *attrVar = dyn_cast(firstElement)) { genElementParser(attrVar, body, attrTypeCtx); body << " if (" << attrVar->getVar()->name << "Attr) {\n"; } else if (auto *literal = dyn_cast(firstElement)) { - body << " if (succeeded(parser.parseOptional"; + body << " if (::mlir::succeeded(parser.parseOptional"; genLiteralParser(literal->getSpelling(), body); body << ")) {\n"; } else if (auto *opVar = dyn_cast(firstElement)) { @@ -1151,31 +1177,16 @@ } } - // If the anchor is a unit attribute, we don't need to print it. When - // parsing, we will add this attribute if this group is present. - FormatElement *elidedAnchorElement = nullptr; - auto *anchorAttr = dyn_cast(optional->getAnchor()); - if (anchorAttr && anchorAttr != firstElement && anchorAttr->isUnitAttr()) { - elidedAnchorElement = anchorAttr; - - // Add the anchor unit attribute to the operation state. - body << " result.addAttribute(\"" << anchorAttr->getVar()->name - << "\", parser.getBuilder().getUnitAttr());\n"; - } - - // Generate the rest of the elements inside an optional group. Elements in - // an optional group after the guard are parsed as required. 
- for (FormatElement *childElement : llvm::drop_begin(elements, 1)) - if (childElement != elidedAnchorElement) - genElementParser(childElement, body, attrTypeCtx, GenContext::Optional); + genElementParsers(firstElement, thenElements.drop_front(), true); body << " }"; // Generate the else elements. auto elseElements = optional->getElseElements(); if (!elseElements.empty()) { body << " else {\n"; - for (FormatElement *childElement : elseElements) - genElementParser(childElement, body, attrTypeCtx); + ArrayRef elseElements = + optional->getElseElements(/*parseable=*/true); + genElementParsers(elseElements.front(), elseElements, false); body << " }"; } body << "\n"; @@ -1842,15 +1853,15 @@ const NamedTypeConstraint *var = element->getVar(); std::string name = op.getGetterName(var->name); if (var->isOptional()) - body << " if (" << name << "()) {\n"; + body << name << "()"; else if (var->isVariadic()) - body << " if (!" << name << "().empty()) {\n"; + body << "!" << name << "().empty()"; }) .Case([&](RegionVariable *element) { const NamedRegion *var = element->getVar(); std::string name = op.getGetterName(var->name); // TODO: Add a check for optional regions here when ODS supports it. - body << " if (!" << name << "().empty()) {\n"; + body << "!" << name << "().empty()"; }) .Case([&](TypeDirective *element) { genOptionalGroupPrinterAnchor(element->getArg(), op, body); @@ -1859,8 +1870,7 @@ genOptionalGroupPrinterAnchor(element->getInputs(), op, body); }) .Case([&](AttributeVariable *attr) { - body << " if ((*this)->getAttr(\"" << attr->getVar()->name - << "\")) {\n"; + body << "(*this)->getAttr(\"" << attr->getVar()->name << "\")"; }); } @@ -1912,39 +1922,45 @@ if (OptionalElement *optional = dyn_cast(element)) { // Emit the check for the presence of the anchor element. 
FormatElement *anchor = optional->getAnchor(); + body << " if ("; + if (optional->isInverted()) + body << "!"; genOptionalGroupPrinterAnchor(anchor, op, body); + body << ") {\n"; + body.indent(); // If the anchor is a unit attribute, we don't need to print it. When // parsing, we will add this attribute if this group is present. - auto elements = optional->getThenElements(); + ArrayRef thenElements = optional->getThenElements(); + ArrayRef elseElements = optional->getElseElements(); FormatElement *elidedAnchorElement = nullptr; auto *anchorAttr = dyn_cast(anchor); - if (anchorAttr && anchorAttr != elements.front() && + if (anchorAttr && anchorAttr != thenElements.front() && + (elseElements.empty() || anchorAttr != elseElements.front()) && anchorAttr->isUnitAttr()) { elidedAnchorElement = anchorAttr; } + auto genElementPrinters = [&](ArrayRef elements) { + for (FormatElement *childElement : elements) { + if (childElement != elidedAnchorElement) { + genElementPrinter(childElement, body, op, shouldEmitSpace, + lastWasPunctuation); + } + } + }; // Emit each of the elements. - for (FormatElement *childElement : elements) { - if (childElement != elidedAnchorElement) { - genElementPrinter(childElement, body, op, shouldEmitSpace, - lastWasPunctuation); - } - } - body << " }"; + genElementPrinters(thenElements); + body << "}"; // Emit each of the else elements. - auto elseElements = optional->getElseElements(); if (!elseElements.empty()) { body << " else {\n"; - for (FormatElement *childElement : elseElements) { - genElementPrinter(childElement, body, op, shouldEmitSpace, - lastWasPunctuation); - } - body << " }"; + genElementPrinters(elseElements); + body << "}"; } - body << "\n"; + body.unindent() << "\n"; return; }