diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -589,6 +589,12 @@
   /// Parse a '|' token if present.
   virtual ParseResult parseOptionalVerticalBar() = 0;
 
+  /// Parse a '^' token.
+  virtual ParseResult parseCaret() = 0;
+
+  /// Parse a bare identifier, storing its spelling in `string` if non-null.
+  virtual ParseResult parseBareIdentifier(std::string *string) = 0;
+
   /// Parse a quoted string token.
   ParseResult parseString(std::string *string) {
     auto loc = getCurrentLocation();
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -236,6 +236,22 @@
     return success(parser.consumeIf(Token::vertical_bar));
   }
 
+  /// Parse a '^' token.
+  ParseResult parseCaret() override {
+    return parser.parseToken(Token::caret, "expected '^'");
+  }
+
+  /// Parse a bare identifier, storing its spelling in `string` if non-null.
+  ParseResult parseBareIdentifier(std::string *string) override {
+    if (!parser.getToken().is(Token::bare_identifier))
+      return emitError(getCurrentLocation(), "expected bare identifier");
+
+    if (string)
+      *string = parser.getToken().getBareIdentifier();
+    parser.consumeToken();
+    return success();
+  }
+
   /// Parses a quoted string token if present.
   ParseResult parseOptionalString(std::string *string) override {
     if (!parser.getToken().is(Token::string))
diff --git a/mlir/lib/AsmParser/Token.h b/mlir/lib/AsmParser/Token.h
--- a/mlir/lib/AsmParser/Token.h
+++ b/mlir/lib/AsmParser/Token.h
@@ -115,6 +115,10 @@
   /// value.
   std::string getSymbolReference() const;
 
+  /// Given a token containing a bare identifier, return its spelling as a
+  /// string. Unlike string literals, bare identifiers need no unescaping.
+  std::string getBareIdentifier() const;
+
   // Location processing.
   SMLoc getLoc() const;
   SMLoc getEndLoc() const;
diff --git a/mlir/lib/AsmParser/Token.cpp b/mlir/lib/AsmParser/Token.cpp
--- a/mlir/lib/AsmParser/Token.cpp
+++ b/mlir/lib/AsmParser/Token.cpp
@@ -155,6 +155,13 @@
   return std::string(nameStr);
 }
 
+/// Given a token containing a bare identifier, return its spelling as a
+/// string. Unlike string literals, bare identifiers need no unescaping.
+std::string Token::getBareIdentifier() const {
+  assert(is(Token::bare_identifier) && "expected bare identifier");
+  return std::string(getSpelling());
+}
+
 /// Given a hash_identifier token like #123, try to parse the number out of
 /// the identifier, returning std::nullopt if it is a named identifier like #x
 /// or if the integer doesn't fit.
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -72,6 +72,7 @@
 TOK_PUNCTUATION(r_square, "]")
 TOK_PUNCTUATION(star, "*")
 TOK_PUNCTUATION(vertical_bar, "|")
+TOK_PUNCTUATION(caret, "^")
 
 TOK_PUNCTUATION(file_metadata_begin, "{-#")
 TOK_PUNCTUATION(file_metadata_end, "#-}")
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1195,6 +1195,13 @@
   return
 }
 
+// CHECK-LABEL: func @parse_bare_id_test
+func.func @parse_bare_id_test() {
+  // CHECK: test.parse_bare_id<foo>
+  test.parse_bare_id <foo>
+  return
+}
+
 // CHECK-LABEL: func @"\22_string_symbol_reference\22"
 func.func @"\"_string_symbol_reference\""() {
   // CHECK: ref = @"\22_string_symbol_reference\22"
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1108,6 +1108,22 @@
   p << " \"" << llvm::encodeBase64(getB64()) << "\"";
 }
 
+ParseResult ParseBareIdOp::parse(OpAsmParser &parser, OperationState &result) {
+  if (parser.parseLess())
+    return failure();
+
+  std::string identifier;
+  if (parser.parseBareIdentifier(&identifier))
+    return failure();
+  if (parser.parseGreater())
+    return failure();
+
+  result.addAttribute("id", parser.getBuilder().getStringAttr(identifier));
+  return success();
+}
+
+void ParseBareIdOp::print(OpAsmPrinter &p) { p << "<" << getId() << ">"; }
+
 //===----------------------------------------------------------------------===//
 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`.
 
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1949,6 +1949,13 @@
   let hasCustomAssemblyFormat = 1;
 }
 
+// Test op exercising `OpAsmParser::parseBareIdentifier`; its custom syntax is
+// `test.parse_bare_id <identifier>`, with the identifier stored in `id`.
+def ParseBareIdOp : TEST_Op<"parse_bare_id"> {
+  let arguments = (ins StrAttr:$id);
+  let hasCustomAssemblyFormat = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Test region argument list parsing.
 