diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -713,7 +713,8 @@ information.
 An optional group is defined by wrapping a set of elements within `()` followed
 by a `?` and has the following requirements:
 
-* The first element of the group must either be a literal or an operand.
+* The first element of the group must be a literal, an attribute, or an
+  operand.
   - This is because the first element must be optionally parsable.
 * Exactly one argument variable within the group must be marked as the anchor
   of the group.
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -384,6 +384,17 @@
                                      StringRef attrName,
                                      NamedAttrList &attrs) = 0;
 
+  /// Parse an attribute if present, adding it to 'attrs' under 'attrName'.
+  virtual OptionalParseResult parseOptionalAttribute(Attribute &result,
+                                                     Type type,
+                                                     StringRef attrName,
+                                                     NamedAttrList &attrs) = 0;
+  OptionalParseResult parseOptionalAttribute(Attribute &result,
+                                             StringRef attrName,
+                                             NamedAttrList &attrs) {
+    return parseOptionalAttribute(result, Type(), attrName, attrs);
+  }
+
   /// Parse an attribute of a specific kind and type.
   template <typename AttrType>
   ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -187,6 +187,40 @@
   }
 }
 
+/// Parse an optional attribute with the provided type.
+OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
+                                                   Type type) {
+  switch (getToken().getKind()) {
+  case Token::at_identifier:
+  case Token::floatliteral:
+  case Token::integer:
+  case Token::hash_identifier:
+  case Token::kw_affine_map:
+  case Token::kw_affine_set:
+  case Token::kw_dense:
+  case Token::kw_false:
+  case Token::kw_loc:
+  case Token::kw_opaque:
+  case Token::kw_sparse:
+  case Token::kw_true:
+  case Token::kw_unit:
+  case Token::l_brace:
+  case Token::l_square:
+  case Token::minus:
+  case Token::string:
+    attribute = parseAttribute(type);
+    return success(attribute != nullptr);
+
+  default:
+    // Parse an optional type attribute; the 'type' parameter is unused here.
+    Type parsedType;
+    OptionalParseResult result = parseOptionalType(parsedType);
+    if (result.hasValue() && succeeded(*result))
+      attribute = TypeAttr::get(parsedType);
+    return result;
+  }
+}
+
 /// Attribute dictionary.
 ///
 ///   attribute-dict ::= `{` `}`
diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -184,6 +184,10 @@
   /// Parse an arbitrary attribute with an optional type.
   Attribute parseAttribute(Type type = {});
 
+  /// Parse an optional attribute with the provided type.
+  OptionalParseResult parseOptionalAttribute(Attribute &attribute,
+                                             Type type = {});
+
   /// Parse an attribute dictionary.
   ParseResult parseAttributeDict(NamedAttrList &attributes);
 
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1011,6 +1011,17 @@
     return success();
   }
 
+  /// Parse an optional attribute; on success it is also added to 'attrs'.
+  OptionalParseResult parseOptionalAttribute(Attribute &result, Type type,
+                                             StringRef attrName,
+                                             NamedAttrList &attrs) override {
+    OptionalParseResult parseResult =
+        parser.parseOptionalAttribute(result, type);
+    if (parseResult.hasValue() && succeeded(*parseResult))
+      attrs.push_back(parser.builder.getNamedAttr(attrName, result));
+    return parseResult;
+  }
+
   /// Parse a named dictionary into 'result' if it is present.
   ParseResult parseOptionalAttrDict(NamedAttrList &result) override {
     if (parser.getToken().isNot(Token::l_brace))
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1253,9 +1253,13 @@
 }
 
 // Test that we elide optional attributes that are within the syntax.
-def FormatOptAttrOp : TEST_Op<"format_opt_attr_op"> {
+def FormatOptAttrAOp : TEST_Op<"format_opt_attr_op_a"> {
   let arguments = (ins OptionalAttr<I64Attr>:$opt_attr);
-  let assemblyFormat = "(`(`$opt_attr^`)`)? attr-dict";
+  let assemblyFormat = "(`(` $opt_attr^ `)` )? attr-dict";
+}
+def FormatOptAttrBOp : TEST_Op<"format_opt_attr_op_b"> {
+  let arguments = (ins OptionalAttr<I64Attr>:$opt_attr);
+  let assemblyFormat = "($opt_attr^)? attr-dict";
 }
 
 // Test that we elide attributes that are within the syntax.
diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -206,10 +206,10 @@
 def OptionalInvalidC : TestFormat_Op<"optional_invalid_c", [{
   ($attr)? attr-dict
 }]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
-// CHECK: error: first element of an operand group must be a literal or operand
+// CHECK: error: first element of an operand group must be an attribute, literal, or operand
 def OptionalInvalidD : TestFormat_Op<"optional_invalid_d", [{
-  ($attr^)? attr-dict
-}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
+  (type($operand) $operand^)? attr-dict
+}]>, Arguments<(ins Optional<I64>:$operand)>;
 // CHECK: error: type directive can only refer to variables within the optional group
 def OptionalInvalidE : TestFormat_Op<"optional_invalid_e", [{
   (`,` $attr^ type(operands))? attr-dict
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -12,9 +12,15 @@
 // CHECK-NOT: {attr
 test.format_attr_op 10
 
-// CHECK: test.format_opt_attr_op(10)
+// CHECK: test.format_opt_attr_op_a(10)
 // CHECK-NOT: {opt_attr
-test.format_opt_attr_op(10)
+test.format_opt_attr_op_a(10)
+test.format_opt_attr_op_a
+
+// CHECK: test.format_opt_attr_op_b 10
+// CHECK-NOT: {opt_attr
+test.format_opt_attr_op_b 10
+test.format_opt_attr_op_b
 
 // CHECK: test.format_attr_dict_w_keyword attributes {attr = 10 : i64}
 test.format_attr_dict_w_keyword attributes {attr = 10 : i64}
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -373,6 +373,15 @@
   if (parser.parseAttribute({1}Attr{2}, "{1}", result.attributes))
     return failure();
 )";
+const char *const optionalAttrParserCode = R"(
+  {0} {1}Attr;
+  {
+    ::mlir::OptionalParseResult parseResult =
+      parser.parseOptionalAttribute({1}Attr{2}, "{1}", result.attributes);
+    if (parseResult.hasValue() && failed(*parseResult))
+      return failure();
+  }
+)";
 
 /// The code snippet used to generate a parser call for an enum attribute.
 ///
@@ -397,6 +406,30 @@
     result.addAttribute("{0}", {3});
   }
 )";
+const char *const optionalEnumAttrParserCode = R"(
+  ::mlir::Attribute {0}Attr;
+  {
+    ::mlir::StringAttr attrVal;
+    ::mlir::NamedAttrList attrStorage;
+    auto loc = parser.getCurrentLocation();
+
+    ::mlir::OptionalParseResult parseResult =
+      parser.parseOptionalAttribute(attrVal, parser.getBuilder().getNoneType(),
+                                    "{0}", attrStorage);
+    if (parseResult.hasValue()) {
+      if (failed(*parseResult))
+        return failure();
+
+      auto attrOptional = {1}::{2}(attrVal.getValue());
+      if (!attrOptional)
+        return parser.emitError(loc, "invalid ")
+               << "{0} attribute specification: " << attrVal;
+
+      {0}Attr = {3};
+      result.addAttribute("{0}", {0}Attr);
+    }
+  }
+)";
 
 /// The code snippet used to generate a parser call for an operand.
 ///
@@ -599,11 +632,15 @@
 
     // Generate a special optional parser for the first element to gate the
     // parsing of the rest of the elements.
-    if (auto *literal = dyn_cast<LiteralElement>(&*elements.begin())) {
+    Element *firstElement = &*elements.begin();
+    if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) {
+      genElementParser(attrVar, body, attrTypeCtx);
+      body << " if (" << attrVar->getVar()->name << "Attr) {\n";
+    } else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
       body << " if (succeeded(parser.parseOptional";
       genLiteralParser(literal->getLiteral(), body);
       body << ")) {\n";
-    } else if (auto *opVar = dyn_cast<OperandVariable>(&*elements.begin())) {
+    } else if (auto *opVar = dyn_cast<OperandVariable>(firstElement)) {
       genElementParser(opVar, body, attrTypeCtx);
       body << " if (!" << opVar->getVar()->name << "Operands.empty()) {\n";
     }
@@ -635,7 +672,9 @@
                                       "attrOptional.getValue()");
       }
 
-      body << formatv(enumAttrParserCode, var->name, enumAttr.getCppNamespace(),
+      body << formatv(var->attr.isOptional() ? optionalEnumAttrParserCode
+                                             : enumAttrParserCode,
+                      var->name, enumAttr.getCppNamespace(),
                       enumAttr.getStringToSymbolFnName(), attrBuilderStr);
       return;
     }
@@ -648,8 +687,9 @@
       os << ", " << tgfmt(*typeBuilder, &attrTypeCtx);
     }
 
-    body << formatv(attrParserCode, var->attr.getStorageType(), var->name,
-                    attrTypeStr);
+    body << formatv(var->attr.isOptional() ? optionalAttrParserCode
+                                           : attrParserCode,
+                    var->attr.getStorageType(), var->name, attrTypeStr);
   } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
     ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
     StringRef name = operand->getVar()->name;
@@ -1910,10 +1950,11 @@
 
   // The first element of the group must be one that can be parsed/printed in an
   // optional fashion.
-  if (!isa<LiteralElement>(&*elements.front()) &&
-      !isa<OperandVariable>(&*elements.front()))
-    return emitError(curLoc, "first element of an operand group must be a "
-                             "literal or operand");
+  Element *firstElement = &*elements.front();
+  if (!isa<AttributeVariable>(firstElement) &&
+      !isa<LiteralElement>(firstElement) && !isa<OperandVariable>(firstElement))
+    return emitError(curLoc, "first element of an operand group must be an "
+                             "attribute, literal, or operand");
 
   // After parsing all of the elements, ensure that all type directives refer
   // only to elements within the group.