diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2183,6 +2183,13 @@ let assemblyFormat = "($attr^)? attr-dict"; } +def FormatOptionalDefaultAttrs : TEST_Op<"format_optional_default_attrs"> { + let arguments = (ins DefaultValuedStrAttr:$str, + DefaultValuedStrAttr:$sym, + DefaultValuedAttr:$e); + let assemblyFormat = "($str^)? ($sym^)? ($e^)? attr-dict"; +} + def FormatOptionalWithElse : TEST_Op<"format_optional_else"> { let arguments = (ins UnitAttr:$isFirstBranchPresent); let assemblyFormat = "(`then` $isFirstBranchPresent^):(`else`)? attr-dict"; diff --git a/mlir/test/mlir-tblgen/op-format-invalid.td b/mlir/test/mlir-tblgen/op-format-invalid.td --- a/mlir/test/mlir-tblgen/op-format-invalid.td +++ b/mlir/test/mlir-tblgen/op-format-invalid.td @@ -369,7 +369,7 @@ def OptionalInvalidF : TestFormat_Op<[{ ($attr^ $attr2^)? attr-dict }]>, Arguments<(ins OptionalAttr:$attr, OptionalAttr:$attr2)>; -// CHECK: error: only optional attributes can be used to anchor an optional group +// CHECK: error: only optional or default-valued attributes can be used to anchor an optional group def OptionalInvalidG : TestFormat_Op<[{ ($attr^)? 
attr-dict }]>, Arguments<(ins I64Attr:$attr)>; diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -197,6 +197,19 @@ // CHECK-NOT: "case5" test.format_optional_enum_attr +// CHECK: test.format_optional_default_attrs "foo" @foo case10 +test.format_optional_default_attrs "foo" @foo case10 + +// CHECK: test.format_optional_default_attrs +// CHECK-NOT: "default" +// CHECK-NOT: @default +// CHECK-NOT: case5 +test.format_optional_default_attrs "default" @default case5 + +// An unset default-valued attribute is treated as absent. +// CHECK: test.format_optional_default_attrs{{$}} +"test.format_optional_default_attrs"() : () -> () + //===----------------------------------------------------------------------===// // Format optional operands and results //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/op-format.td b/mlir/test/mlir-tblgen/op-format.td --- a/mlir/test/mlir-tblgen/op-format.td +++ b/mlir/test/mlir-tblgen/op-format.td @@ -71,3 +71,13 @@ def OptionalGroupB : TestFormat_Op<[{ (`foo`) : (`bar` $a^)? attr-dict }]>, Arguments<(ins UnitAttr:$a)>; + +// Optional group anchored on a default-valued attribute: +// CHECK-LABEL: OptionalGroupC::print +// CHECK: if ((*this)->getAttr("a") && (*this)->getAttr("a") != ::mlir::OpBuilder((*this)->getContext()).getStringAttr("default")) { +// CHECK-NEXT: odsPrinter << ' '; +// CHECK-NEXT: odsPrinter.printAttributeWithoutType(getAAttr()); +// CHECK-NEXT: } +def OptionalGroupC : TestFormat_Op<[{ + ($a^)? attr-dict +}]>, Arguments<(ins DefaultValuedStrAttr:$a)>; diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -1041,7 +1041,7 @@ /// Generate the parser for a enum attribute.
static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body, - FmtContext &attrTypeCtx) { + FmtContext &attrTypeCtx, bool parseAsOptional) { Attribute baseAttr = var->attr.getBaseAttr(); const EnumAttr &enumAttr = cast(baseAttr); std::vector cases = enumAttr.getAllCases(); @@ -1065,7 +1065,7 @@ // If the attribute is not optional, build an error message for the missing // attribute. std::string errorMessage; - if (!var->attr.isOptional()) { + if (!parseAsOptional) { llvm::raw_string_ostream errorMessageOS(errorMessage); errorMessageOS << "return parser.emitError(loc, \"expected string or " @@ -1082,6 +1082,45 @@ validCaseKeywordsStr, errorMessage); } +/// Generate the parser for an attribute. When `parseAsOptional` is set, the +/// optional parser variants are used so that a missing attribute is not an +/// error (e.g. when the attribute anchors an optional group). +static void genAttrParser(AttributeVariable *attr, MethodBody &body, + FmtContext &attrTypeCtx, bool parseAsOptional) { + const NamedAttribute *var = attr->getVar(); + + // Check to see if we can parse this as an enum attribute. + if (canFormatEnumAttr(var)) + return genEnumAttrParser(var, body, attrTypeCtx, parseAsOptional); + + // Check to see if we should parse this as a symbol name attribute. + if (shouldFormatSymbolNameAttr(var)) { + body << formatv(parseAsOptional ? optionalSymbolNameAttrParserCode + : symbolNameAttrParserCode, + var->name); + return; + } + + // If this attribute has a buildable type, use that when parsing the + // attribute.
+ std::string attrTypeStr; + if (Optional typeBuilder = attr->getTypeBuilder()) { + llvm::raw_string_ostream os(attrTypeStr); + os << tgfmt(*typeBuilder, &attrTypeCtx); + } else { + attrTypeStr = "::mlir::Type{}"; + } + if (parseAsOptional) { + body << formatv(optionalAttrParserCode, var->name, attrTypeStr); + } else { + if (attr->shouldBeQualified() || + var->attr.getStorageType() == "::mlir::Attribute") + body << formatv(genericAttrParserCode, var->name, attrTypeStr); + else + body << formatv(attrParserCode, var->name, attrTypeStr); + } +} + void OperationFormat::genParser(Operator &op, OpClass &opClass) { SmallVector paramList; paramList.emplace_back("::mlir::OpAsmParser &", "parser"); @@ -1153,7 +1192,7 @@ // parsing of the rest of the elements. FormatElement *firstElement = thenElements.front(); if (auto *attrVar = dyn_cast(firstElement)) { - genElementParser(attrVar, body, attrTypeCtx); + genAttrParser(attrVar, body, attrTypeCtx, /*parseAsOptional=*/true); body << " if (" << attrVar->getVar()->name << "Attr) {\n"; } else if (auto *literal = dyn_cast(firstElement)) { body << " if (::mlir::succeeded(parser.parseOptional"; @@ -1236,38 +1275,9 @@ /// Arguments. } else if (auto *attr = dyn_cast(element)) { - const NamedAttribute *var = attr->getVar(); - - // Check to see if we can parse this as an enum attribute. - if (canFormatEnumAttr(var)) - return genEnumAttrParser(var, body, attrTypeCtx); - - // Check to see if we should parse this as a symbol name attribute. - if (shouldFormatSymbolNameAttr(var)) { - body << formatv(var->attr.isOptional() ? optionalSymbolNameAttrParserCode - : symbolNameAttrParserCode, - var->name); - return; - } - - // If this attribute has a buildable type, use that when parsing the - // attribute.
- std::string attrTypeStr; - if (Optional typeBuilder = attr->getTypeBuilder()) { - llvm::raw_string_ostream os(attrTypeStr); - os << tgfmt(*typeBuilder, &attrTypeCtx); - } else { - attrTypeStr = "::mlir::Type{}"; - } - if (genCtx == GenContext::Normal && var->attr.isOptional()) { - body << formatv(optionalAttrParserCode, var->name, attrTypeStr); - } else { - if (attr->shouldBeQualified() || - var->attr.getStorageType() == "::mlir::Attribute") - body << formatv(genericAttrParserCode, var->name, attrTypeStr); - else - body << formatv(attrParserCode, var->name, attrTypeStr); - } + bool parseAsOptional = + (genCtx == GenContext::Normal && attr->getVar()->attr.isOptional()); + genAttrParser(attr, body, attrTypeCtx, parseAsOptional); } else if (auto *operand = dyn_cast(element)) { ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar()); @@ -1872,8 +1882,24 @@ .Case([&](FunctionalTypeDirective *element) { genOptionalGroupPrinterAnchor(element->getInputs(), op, body); }) - .Case([&](AttributeVariable *attr) { - body << "(*this)->getAttr(\"" << attr->getVar()->name << "\")"; + .Case([&](AttributeVariable *element) { + Attribute attr = element->getVar()->attr; + body << "(*this)->getAttr(\"" << element->getVar()->name << "\")"; + if (attr.isOptional()) + return; // done + if (attr.hasDefaultValue()) { + // Consider a default-valued attribute as present if it is set and + // its value differs from the default. The null check keeps an + // unset attribute from being printed. + FmtContext fctx; + fctx.withBuilder("::mlir::OpBuilder((*this)->getContext())"); + body << " && (*this)->getAttr(\"" << element->getVar()->name + << "\") != " + << tgfmt(attr.getConstBuilderTemplate(), &fctx, + attr.getDefaultValue()); + return; + } + llvm_unreachable("attribute must be optional or default-valued"); }); } @@ -3185,9 +3211,10 @@ // All attributes can be within the optional group, but only optional // attributes can be the anchor.
.Case([&](AttributeVariable *attrEle) { - if (isAnchor && !attrEle->getVar()->attr.isOptional()) - return emitError(loc, "only optional attributes can be used to " - "anchor an optional group"); + Attribute attr = attrEle->getVar()->attr; + if (isAnchor && !(attr.isOptional() || attr.hasDefaultValue())) + return emitError(loc, "only optional or default-valued attributes " + "can be used to anchor an optional group"); return success(); }) // Only optional-like(i.e. variadic) operands can be within an optional