diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -777,8 +777,8 @@ * The first element of the group must either be a attribute, literal, operand, or region. - This is because the first element must be optionally parsable. -* Exactly one argument variable within the group must be marked as the anchor - of the group. +* Exactly one argument variable or type directive within the group must be + marked as the anchor of the group. - The anchor is the element whose presence controls whether the group should be printed/parsed. - An element is marked as the anchor by adding a trailing `^`. @@ -789,11 +789,9 @@ valid elements within the group. - Any attribute variable may be used, but only optional attributes can be marked as the anchor. - - Only variadic or optional operand arguments can be used. + - Only variadic or optional results and operand arguments can be used. - All region variables can be used. When a non-variable length region is used, if the group is not present the region is empty. - - The operands to a type directive must be defined within the optional - group. An example of an operation with an optional group is `std.return`, which has a variadic number of operands. diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1571,6 +1571,25 @@ (`[` $variadic^ `]`)? attr-dict }]>; +// Test optional result type formatting. +class FormatOptionalResultOpBase + : TEST_Op<"format_optional_result_" # suffix # "_op", + [AttrSizedResultSegments]> { + let results = (outs Optional:$optional, Variadic:$variadic); + let assemblyFormat = fmt; +} +def FormatOptionalResultAOp : FormatOptionalResultOpBase<"a", [{ + (`:` type($optional)^ `->` type($variadic))?
attr-dict +}]>; + +def FormatOptionalResultBOp : FormatOptionalResultOpBase<"b", [{ + (`:` type($optional) `->` type($variadic)^)? attr-dict +}]>; + +def FormatOptionalResultCOp : FormatOptionalResultOpBase<"c", [{ + (`:` functional-type($optional, $variadic)^)? attr-dict +}]>; + def FormatTwoVariadicOperandsNoBuildableTypeOp : TEST_Op<"format_two_variadic_operands_no_buildable_type_op", [AttrSizedOperandSegments]> { diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -333,7 +333,7 @@ def OptionalInvalidD : TestFormat_Op<"optional_invalid_d", [{ (type($operand) $operand^)? attr-dict }]>, Arguments<(ins Optional:$operand)>; -// CHECK: error: type directive can only refer to variables within the optional group +// CHECK: error: only literals, types, and variables can be used within an optional group def OptionalInvalidE : TestFormat_Op<"optional_invalid_e", [{ (`,` $attr^ type(operands))? attr-dict }]>, Arguments<(ins OptionalAttr:$attr)>; @@ -349,9 +349,9 @@ def OptionalInvalidH : TestFormat_Op<"optional_invalid_h", [{ ($arg^) attr-dict }]>, Arguments<(ins I64:$arg)>; -// CHECK: error: only variables can be used to anchor an optional group +// CHECK: error: only literals, types, and variables can be used within an optional group def OptionalInvalidI : TestFormat_Op<"optional_invalid_i", [{ - ($arg type($arg)^) attr-dict + (functional-type($arg, results)^)? 
attr-dict }]>, Arguments<(ins Variadic:$arg)>; // CHECK: error: only literals, types, and variables can be used within an optional group def OptionalInvalidJ : TestFormat_Op<"optional_invalid_j", [{ @@ -361,11 +361,11 @@ def OptionalInvalidK : TestFormat_Op<"optional_invalid_k", [{ ($arg^) }]>, Arguments<(ins Variadic:$arg)>; -// CHECK: error: only variables can be used to anchor an optional group +// CHECK: error: only variables and types can be used to anchor an optional group def OptionalInvalidL : TestFormat_Op<"optional_invalid_l", [{ (custom($arg)^)? }]>, Arguments<(ins I64:$arg)>; -// CHECK: error: only variables can be used to anchor an optional group +// CHECK: error: only variables and types can be used to anchor an optional group def OptionalInvalidM : TestFormat_Op<"optional_invalid_m", [{ (` `^)? }]>, Arguments<(ins)>; diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -220,6 +220,25 @@ // CHECK: test.format_optional_operand_result_b_op : i64 test.format_optional_operand_result_b_op : i64 +//===----------------------------------------------------------------------===// +// Format optional results +//===----------------------------------------------------------------------===// + +// CHECK: test.format_optional_result_a_op +test.format_optional_result_a_op + +// CHECK: test.format_optional_result_a_op : i64 -> i64, i64 +test.format_optional_result_a_op : i64 -> i64, i64 + +// CHECK: test.format_optional_result_b_op +test.format_optional_result_b_op + +// CHECK: test.format_optional_result_b_op : i64 -> i64, i64 +test.format_optional_result_b_op : i64 -> i64, i64 + +// CHECK: test.format_optional_result_c_op : (i64) -> (i64, i64) +test.format_optional_result_c_op : (i64) -> (i64, i64) + //===----------------------------------------------------------------------===// // Format custom directives 
//===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -1749,6 +1749,33 @@ " }\n"; } +/// Generate the check for the anchor of an optional group. +static void genOptionalGroupPrinterAnchor(Element *anchor, OpMethodBody &body) { + TypeSwitch(anchor) + .Case([&](auto *element) { + const NamedTypeConstraint *var = element->getVar(); + if (var->isOptional()) + body << " if (" << var->name << "()) {\n"; + else if (var->isVariadic()) + body << " if (!" << var->name << "().empty()) {\n"; + }) + .Case([&](RegionVariable *element) { + const NamedRegion *var = element->getVar(); + // TODO: Add a check for optional regions here when ODS supports it. + body << " if (!" << var->name << "().empty()) {\n"; + }) + .Case([&](TypeDirective *element) { + genOptionalGroupPrinterAnchor(element->getOperand(), body); + }) + .Case([&](FunctionalTypeDirective *element) { + genOptionalGroupPrinterAnchor(element->getInputs(), body); + }) + .Case([&](AttributeVariable *attr) { + body << " if ((*this)->getAttr(\"" << attr->getVar()->name + << "\")) {\n"; + }); +} + void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body, Operator &op, bool &shouldEmitSpace, bool &lastWasPunctuation) { @@ -1769,21 +1796,7 @@ if (OptionalElement *optional = dyn_cast(element)) { // Emit the check for the presence of the anchor element. Element *anchor = optional->getAnchor(); - if (auto *operand = dyn_cast(anchor)) { - const NamedTypeConstraint *var = operand->getVar(); - if (var->isOptional()) - body << " if (" << var->name << "()) {\n"; - else if (var->isVariadic()) - body << " if (!" << var->name << "().empty()) {\n"; - } else if (auto *region = dyn_cast(anchor)) { - const NamedRegion *var = region->getVar(); - // TODO: Add a check for optional here when ODS supports it. 
- body << " if (!" << var->name << "().empty()) {\n"; - - } else { - body << " if ((*this)->getAttr(\"" - << cast(anchor)->getVar()->name << "\")) {\n"; - } + genOptionalGroupPrinterAnchor(anchor, body); // If the anchor is a unit attribute, we don't need to print it. When // parsing, we will add this attribute if this group is present. @@ -2244,8 +2257,9 @@ bool isTopLevel); LogicalResult parseOptionalChildElement( std::vector> &childElements, - SmallPtrSetImpl &seenVariables, Optional &anchorIdx); + LogicalResult verifyOptionalChildElement(Element *element, + llvm::SMLoc childLoc, bool isAnchor); /// Parse the various different directives. LogicalResult parseAttrDictDirective(std::unique_ptr &element, @@ -2315,7 +2329,6 @@ llvm::DenseSet seenOperands; llvm::DenseSet seenRegions; llvm::DenseSet seenSuccessors; - llvm::DenseSet optionalVariables; }; } // end anonymous namespace @@ -2760,10 +2773,9 @@ // Parse the child elements for this optional group. std::vector> elements; - SmallPtrSet seenVariables; Optional anchorIdx; do { - if (failed(parseOptionalChildElement(elements, seenVariables, anchorIdx))) + if (failed(parseOptionalChildElement(elements, anchorIdx))) return ::mlir::failure(); } while (curToken.getKind() != Token::r_paren); consumeToken(); @@ -2787,31 +2799,6 @@ "first parsable element of an operand group must be " "an attribute, literal, operand, or region"); - // After parsing all of the elements, ensure that all type directives refer - // only to elements within the group. - auto checkTypeOperand = [&](Element *typeEle) { - auto *opVar = dyn_cast(typeEle); - const NamedTypeConstraint *var = opVar ? 
opVar->getVar() : nullptr; - if (!seenVariables.count(var)) - return emitError(curLoc, "type directive can only refer to variables " - "within the optional group"); - return ::mlir::success(); - }; - for (auto &ele : elements) { - if (auto *typeEle = dyn_cast(ele.get())) { - if (failed(checkTypeOperand(typeEle->getOperand()))) - return failure(); - } else if (auto *typeEle = dyn_cast(ele.get())) { - if (failed(checkTypeOperand(typeEle->getOperand()))) - return ::mlir::failure(); - } else if (auto *typeEle = dyn_cast(ele.get())) { - if (failed(checkTypeOperand(typeEle->getInputs())) || - failed(checkTypeOperand(typeEle->getResults()))) - return ::mlir::failure(); - } - } - - optionalVariables.insert(seenVariables.begin(), seenVariables.end()); auto parseStart = parseBegin - elements.begin(); element = std::make_unique(std::move(elements), *anchorIdx, parseStart); @@ -2820,7 +2807,6 @@ LogicalResult FormatParser::parseOptionalChildElement( std::vector> &childElements, - SmallPtrSetImpl &seenVariables, Optional &anchorIdx) { llvm::SMLoc childLoc = curToken.getLoc(); childElements.push_back({}); @@ -2837,7 +2823,14 @@ consumeToken(); } - return TypeSwitch(childElements.back().get()) + return verifyOptionalChildElement(childElements.back().get(), childLoc, + isAnchor); +} + +LogicalResult FormatParser::verifyOptionalChildElement(Element *element, + llvm::SMLoc childLoc, + bool isAnchor) { + return TypeSwitch(element) // All attributes can be within the optional group, but only optional // attributes can be the anchor. .Case([&](AttributeVariable *attrEle) { @@ -2852,7 +2845,14 @@ if (!ele->getVar()->isVariableLength()) return emitError(childLoc, "only variable length operands can be " "used within an optional group"); - seenVariables.insert(ele->getVar()); + return ::mlir::success(); + }) + // Only optional-like(i.e. variadic) results can be within an optional + // group. 
+ .Case([&](ResultVariable *ele) { + if (!ele->getVar()->isVariableLength()) + return emitError(childLoc, "only variable length results can be " + "used within an optional group"); return ::mlir::success(); }) .Case([&](RegionVariable *) { @@ -2860,16 +2860,27 @@ // a check here. return ::mlir::success(); }) - // Literals, whitespace, custom directives, and type directives may be - // used, but they can't anchor the group. - .Case([&](Element *) { - if (isAnchor) - return emitError(childLoc, "only variables can be used to anchor " - "an optional group"); - return ::mlir::success(); + .Case([&](TypeDirective *ele) { + return verifyOptionalChildElement(ele->getOperand(), childLoc, + /*isAnchor=*/false); }) + .Case([&](FunctionalTypeDirective *ele) { + if (failed(verifyOptionalChildElement(ele->getInputs(), childLoc, + /*isAnchor=*/false))) + return failure(); + return verifyOptionalChildElement(ele->getResults(), childLoc, + /*isAnchor=*/false); + }) + // Literals, whitespace, and custom directives may be used, but they can't + // anchor the group. + .Case( + [&](Element *) { + if (isAnchor) + return emitError(childLoc, "only variables and types can be used " + "to anchor an optional group"); + return ::mlir::success(); + }) .Default([&](Element *) { return emitError(childLoc, "only literals, types, and variables can be " "used within an optional group");