diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2174,6 +2174,11 @@
   let assemblyFormat = "($attr^)? attr-dict";
 }
 
+def FormatOptionalDefaultAttr : TEST_Op<"format_optional_default_attr"> {
+  let arguments = (ins DefaultValuedAttr<StrAttr, "\"default\"">:$attr);
+  let assemblyFormat = "($attr^)? attr-dict";
+}
+
 def FormatOptionalWithElse : TEST_Op<"format_optional_else"> {
   let arguments = (ins UnitAttr:$isFirstBranchPresent);
   let assemblyFormat = "(`then` $isFirstBranchPresent^):(`else`)? attr-dict";
diff --git a/mlir/test/mlir-tblgen/op-format-invalid.td b/mlir/test/mlir-tblgen/op-format-invalid.td
--- a/mlir/test/mlir-tblgen/op-format-invalid.td
+++ b/mlir/test/mlir-tblgen/op-format-invalid.td
@@ -369,7 +369,7 @@
 def OptionalInvalidF : TestFormat_Op<[{
   ($attr^ $attr2^)? attr-dict
 }]>, Arguments<(ins OptionalAttr<I64Attr>:$attr, OptionalAttr<I64Attr>:$attr2)>;
-// CHECK: error: only optional attributes can be used to anchor an optional group
+// CHECK: error: only optional or default-valued attributes can be used to anchor an optional group
 def OptionalInvalidG : TestFormat_Op<[{
   ($attr^)? attr-dict
 }]>, Arguments<(ins I64Attr:$attr)>;
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -197,6 +197,16 @@
 // CHECK-NOT: "case5"
 test.format_optional_enum_attr
 
+// CHECK: test.format_optional_default_attr "foo"
+test.format_optional_default_attr "foo"
+
+// CHECK-NEXT: test.format_optional_default_attr{{$}}
+test.format_optional_default_attr
+
+// CHECK: test.format_optional_default_attr
+// CHECK-NOT: default
+test.format_optional_default_attr "default"
+
 //===----------------------------------------------------------------------===//
 // Format optional operands and results
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/op-format.td b/mlir/test/mlir-tblgen/op-format.td
--- a/mlir/test/mlir-tblgen/op-format.td
+++ b/mlir/test/mlir-tblgen/op-format.td
@@ -71,3 +71,13 @@
 def OptionalGroupB : TestFormat_Op<[{
   (`foo`) : (`bar` $a^)? attr-dict
 }]>, Arguments<(ins UnitAttr:$a)>;
+
+// Optional group anchored on a default-valued attribute:
+// CHECK-LABEL: OptionalGroupC::print
+// CHECK: if (((*this)->getAttr("a") && (*this)->getAttr("a") != ::mlir::OpBuilder(getContext()).getStringAttr("default"))) {
+// CHECK-NEXT: odsPrinter << ' ';
+// CHECK-NEXT: odsPrinter.printAttributeWithoutType(getAAttr());
+// CHECK-NEXT: }
+def OptionalGroupC : TestFormat_Op<[{
+  ($a^)? attr-dict
+}]>, Arguments<(ins DefaultValuedAttr<StrAttr, "\"default\"">:$a)>;
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -1142,8 +1142,7 @@
       // an optional group after the guard are parsed as required.
       for (FormatElement *childElement : elements)
         if (childElement != elidedAnchorElement)
-          genElementParser(childElement, body, attrTypeCtx,
-                           GenContext::Optional);
+          genElementParser(childElement, body, attrTypeCtx);
     };
 
     ArrayRef<FormatElement *> thenElements =
@@ -1153,7 +1152,7 @@
     // parsing of the rest of the elements.
     FormatElement *firstElement = thenElements.front();
     if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) {
-      genElementParser(attrVar, body, attrTypeCtx);
+      genElementParser(attrVar, body, attrTypeCtx, GenContext::Optional);
       body << "  if (" << attrVar->getVar()->name << "Attr) {\n";
     } else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
       body << "  if (::mlir::succeeded(parser.parseOptional";
@@ -1259,7 +1258,13 @@
     } else {
       attrTypeStr = "::mlir::Type{}";
     }
-    if (genCtx == GenContext::Normal && var->attr.isOptional()) {
+    // Attributes generated in the `Optional` context (the leading element of
+    // an optional group) use the optional parser so that their absence can
+    // gate the rest of the group; all others take the non-optional path.
+    // NOTE(review): an optional attribute outside of any optional group was
+    // previously parsed optionally (`Normal && isOptional()`); with this
+    // condition it is not. Confirm no format relies on that.
+    if (genCtx == GenContext::Optional) {
       body << formatv(optionalAttrParserCode, var->name, attrTypeStr);
     } else {
       if (attr->shouldBeQualified() ||
@@ -1872,8 +1877,29 @@
       .Case<FunctionalTypeDirective>([&](FunctionalTypeDirective *element) {
         genOptionalGroupPrinterAnchor(element->getInputs(), op, body);
       })
-      .Case<AttributeVariable>([&](AttributeVariable *attr) {
-        body << "(*this)->getAttr(\"" << attr->getVar()->name << "\")";
+      .Case<AttributeVariable>([&](AttributeVariable *element) {
+        Attribute attr = element->getVar()->attr;
+        // `getAttr` yields a null attribute when the attribute is not set.
+        std::string getAttr =
+            formatv("(*this)->getAttr(\"{0}\")", element->getVar()->name)
+                .str();
+        if (attr.isOptional()) {
+          body << getAttr;
+          return;
+        }
+        if (attr.hasDefaultValue()) {
+          // A default-valued attribute counts as present only when it is set
+          // and differs from its default value; the null check keeps an unset
+          // attribute from being printed as a null attribute.
+          FmtContext fctx;
+          fctx.withBuilder("::mlir::OpBuilder(getContext())");
+          body << "(" << getAttr << " && " << getAttr << " != "
+               << tgfmt(attr.getConstBuilderTemplate(), &fctx,
+                        attr.getDefaultValue())
+               << ")";
+          return;
+        }
+        llvm_unreachable("attribute must be optional or default-valued");
       });
 }
 
@@ -3185,9 +3211,10 @@
-      // All attributes can be within the optional group, but only optional
-      // attributes can be the anchor.
+      // All attributes can be within the optional group, but only optional or
+      // default-valued attributes can be the anchor.
       .Case<AttributeVariable>([&](AttributeVariable *attrEle) {
-        if (isAnchor && !attrEle->getVar()->attr.isOptional())
-          return emitError(loc, "only optional attributes can be used to "
-                                "anchor an optional group");
+        Attribute attr = attrEle->getVar()->attr;
+        if (isAnchor && !(attr.isOptional() || attr.hasDefaultValue()))
+          return emitError(loc, "only optional or default-valued attributes "
+                                "can be used to anchor an optional group");
         return success();
       })
       // Only optional-like(i.e. variadic) operands can be within an optional