diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -733,7 +733,7 @@ An example of an operation with an optional group is `std.return`, which has a variadic number of operands. -``` +```tablegen def ReturnOp : ... { let arguments = (ins Variadic:$operands); @@ -743,6 +743,36 @@ } ``` +##### Unit Attributes + +In MLIR, the [`unit` Attribute](LangRef.md#unit-attribute) is special in that it +only has one possible value, i.e. it derives meaning solely from its presence. +When a unit attribute is used to anchor an optional group and is not the first +element of the group, the presence of the unit attribute directly corresponds to +the presence of the optional group itself. As such, in these situations the unit +attribute is not printed in the output; when parsing, it is automatically +inferred from the presence of the optional group itself. + +For example, the following operation: + +```tablegen +def FooOp : ... { + let arguments = (ins UnitAttr:$is_read_only); + + let assemblyFormat = "attr-dict (`is_read_only` $is_read_only^)?"; +} +``` + +would be formatted as follows: + +```mlir +// When the unit attribute is present: +foo.op is_read_only + +// When the unit attribute is not present: +foo.op +``` + #### Requirements The format specification has a certain set of requirements that must be adhered diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1391,6 +1391,17 @@ let assemblyFormat = "$operands attr-dict `:` type($result)"; } +def FormatOptionalUnitAttr : TEST_Op<"format_optional_unit_attribute"> { + let arguments = (ins UnitAttr:$is_optional); + let assemblyFormat = "(`is_optional` $is_optional^)? 
attr-dict"; +} + +def FormatOptionalUnitAttrNoElide + : TEST_Op<"format_optional_unit_attribute_no_elide"> { + let arguments = (ins UnitAttr:$is_optional); + let assemblyFormat = "($is_optional^)? attr-dict"; +} + //===----------------------------------------------------------------------===// // AllTypesMatch type inference //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -82,6 +82,20 @@ }) { arg_names = ["i", "j", "k"] } : () -> () +//===----------------------------------------------------------------------===// +// Format optional attributes +//===----------------------------------------------------------------------===// + +// CHECK: test.format_optional_unit_attribute is_optional +test.format_optional_unit_attribute is_optional + +// CHECK: test.format_optional_unit_attribute +// CHECK-NOT: is_optional +test.format_optional_unit_attribute + +// CHECK: test.format_optional_unit_attribute_no_elide unit +test.format_optional_unit_attribute_no_elide unit + //===----------------------------------------------------------------------===// // Format optional operands and results //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -107,6 +107,11 @@ Optional attrType = var->attr.getValueType(); return attrType ? attrType->getBuilderCall() : llvm::None; } + + /// Return if this attribute refers to a UnitAttr. + bool isUnitAttr() const { + return var->attr.getBaseAttr().getAttrDefName() == "UnitAttr"; + } }; /// This class represents a variable that refers to an operand argument. @@ -645,9 +650,23 @@ body << " if (!" 
<< opVar->getVar()->name << "Operands.empty()) {\n"; } + // If the anchor is a unit attribute, we don't need to print it. When + // parsing, we will add this attribute if this group is present. + Element *elidedAnchorElement = nullptr; + auto *anchorAttr = dyn_cast(optional->getAnchor()); + if (anchorAttr && anchorAttr != firstElement && anchorAttr->isUnitAttr()) { + elidedAnchorElement = anchorAttr; + + // Add the anchor unit attribute to the operation state. + body << " result.addAttribute(\"" << anchorAttr->getVar()->name + << "\", parser.getBuilder().getUnitAttr());\n"; + } + // Generate the rest of the elements normally. - for (auto &childElement : llvm::drop_begin(elements, 1)) - genElementParser(&childElement, body, attrTypeCtx); + for (Element &childElement : llvm::drop_begin(elements, 1)) { + if (&childElement != elidedAnchorElement) + genElementParser(&childElement, body, attrTypeCtx); + } body << " }\n"; /// Literals. @@ -1058,10 +1077,23 @@ << cast(anchor)->getVar()->name << "\")) {\n"; } + // If the anchor is a unit attribute, we don't need to print it. When + // parsing, we will add this attribute if this group is present. + auto elements = optional->getElements(); + Element *elidedAnchorElement = nullptr; + auto *anchorAttr = dyn_cast(anchor); + if (anchorAttr && anchorAttr != &*elements.begin() && + anchorAttr->isUnitAttr()) { + elidedAnchorElement = anchorAttr; + } + // Emit each of the elements. - for (Element &childElement : optional->getElements()) - genElementPrinter(&childElement, body, fmt, op, shouldEmitSpace, - lastWasPunctuation); + for (Element &childElement : elements) { + if (&childElement != elidedAnchorElement) { + genElementPrinter(&childElement, body, fmt, op, shouldEmitSpace, + lastWasPunctuation); + } + } body << " }\n"; return; }