diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -326,9 +326,27 @@
   /// Parse a '<' token.
   virtual ParseResult parseLess() = 0;
 
+  /// Parse a '<' token if present.
+  virtual ParseResult parseOptionalLess() = 0;
+
   /// Parse a '>' token.
   virtual ParseResult parseGreater() = 0;
 
+  /// Parse a '>' token if present.
+  virtual ParseResult parseOptionalGreater() = 0;
+
+  /// Parse a '+' token.
+  virtual ParseResult parsePlus() = 0;
+
+  /// Parse a '+' token if present.
+  virtual ParseResult parseOptionalPlus() = 0;
+
+  /// Parse a '*' token.
+  virtual ParseResult parseStar() = 0;
+
+  /// Parse a '*' token if present.
+  virtual ParseResult parseOptionalStar() = 0;
+
   /// Parse a given keyword.
   ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") {
     auto loc = getCurrentLocation();
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -994,11 +994,21 @@
     return parser.parseToken(Token::less, "expected '<'");
   }
 
+  /// Parse a '<' token if present.
+  ParseResult parseOptionalLess() override {
+    return success(parser.consumeIf(Token::less));
+  }
+
   /// Parse a '>' token.
   ParseResult parseGreater() override {
     return parser.parseToken(Token::greater, "expected '>'");
   }
 
+  /// Parse a '>' token if present.
+  ParseResult parseOptionalGreater() override {
+    return success(parser.consumeIf(Token::greater));
+  }
+
   /// Parse a `(` token.
   ParseResult parseLParen() override {
     return parser.parseToken(Token::l_paren, "expected '('");
@@ -1044,6 +1054,26 @@
     return success(parser.consumeIf(Token::r_square));
   }
 
+  /// Parse a '+' token.
+  ParseResult parsePlus() override {
+    return parser.parseToken(Token::plus, "expected '+'");
+  }
+
+  /// Parse a '+' token if present.
+  ParseResult parseOptionalPlus() override {
+    return success(parser.consumeIf(Token::plus));
+  }
+
+  /// Parse a '*' token.
+  ParseResult parseStar() override {
+    return parser.parseToken(Token::star, "expected '*'");
+  }
+
+  /// Parse a '*' token if present.
+  ParseResult parseOptionalStar() override {
+    return success(parser.consumeIf(Token::star));
+  }
+
   //===--------------------------------------------------------------------===//
   // Attribute Parsing
   //===--------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1380,7 +1380,7 @@
 
 def FormatLiteralOp : TEST_Op<"format_literal_op"> {
   let assemblyFormat = [{
-    `keyword_$.` `->` `:` `,` `=` `<` `>` `(` `)` `[` `]` `` `(` ` ` `)` attr-dict
+    `keyword_$.` `->` `:` `,` `=` `<` `>` `(` `)` `[` `]` `` `(` ` ` `)` `+` `*` attr-dict
   }];
 }
 
diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -309,7 +309,7 @@
 }]>;
 // CHECK-NOT: error
 def LiteralValid : TestFormat_Op<"literal_valid", [{
-  `_` `:` `,` `=` `<` `>` `(` `)` `[` `]` ` ` `` `->` `abc$._`
+  `_` `:` `,` `=` `<` `>` `(` `)` `[` `]` `+` `*` ` ` `` `->` `abc$._`
   attr-dict
 }]>;
 
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -5,8 +5,8 @@
 // CHECK: %[[MEMREF:.*]] =
 %memref = "foo.op"() : () -> (memref<1xf64>)
 
-// CHECK: test.format_literal_op keyword_$. -> :, = <> () []( ) {foo.some_attr}
-test.format_literal_op keyword_$. -> :, = <> () []( ) {foo.some_attr}
+// CHECK: test.format_literal_op keyword_$. -> :, = <> () []( ) + * {foo.some_attr}
+test.format_literal_op keyword_$. -> :, = <> () []( ) + * {foo.some_attr}
 
 // CHECK: test.format_attr_op 10
 // CHECK-NOT: {attr
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -281,7 +281,7 @@
   // If there is only one character, this must either be punctuation or a
   // single character bare identifier.
   if (value.size() == 1)
-    return isalpha(front) || StringRef("_:,=<>()[]{}?").contains(front);
+    return isalpha(front) || StringRef("_:,=<>()[]{}?+*").contains(front);
 
   // Check the punctuation that are larger than a single character.
   if (value == "->")
@@ -762,7 +762,9 @@
       .Case(")", "RParen()")
       .Case("[", "LSquare()")
       .Case("]", "RSquare()")
-      .Case("?", "Question()");
+      .Case("?", "Question()")
+      .Case("+", "Plus()")
+      .Case("*", "Star()");
 }
 
 /// Generate the storage code required for parsing the given element.