diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -54,6 +54,10 @@
   /// provide a valid type for the attribute.
   virtual void printAttributeWithoutType(Attribute attr);
 
+  /// Print the given string as a keyword, or a quoted and escaped string if it
+  /// has any special or non-printable characters in it.
+  virtual void printKeywordOrString(StringRef keyword);
+
   /// Print the given string as a symbol reference, i.e. a form representable by
   /// a SymbolRefAttr. A symbol reference is represented as a string prefixed
   /// with '@'. The reference is surrounded with ""'s and escaped if it has any
@@ -461,6 +465,17 @@
   parseOptionalKeyword(StringRef *keyword,
                        ArrayRef<StringRef> allowedValues) = 0;
 
+  /// Parse a keyword or a quoted string.
+  ParseResult parseKeywordOrString(std::string *result) {
+    if (failed(parseOptionalKeywordOrString(result)))
+      return emitError(getCurrentLocation())
+             << "expected valid keyword or string";
+    return success();
+  }
+
+  /// Parse an optional keyword or string.
+  virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0;
+
   /// Parse a `(` token.
   virtual ParseResult parseLParen() = 0;
 
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -518,6 +518,7 @@
     // guaranteed to go unused.
     os << "%";
   }
+  void printKeywordOrString(StringRef) override {}
   void printSymbolName(StringRef) override {}
   void printSuccessor(Block *) override {}
   void printSuccessorAndUseList(Block *, ValueRange) override {}
@@ -1548,36 +1549,44 @@
 
 /// Returns true if the given string can be represented as a bare identifier.
 static bool isBareIdentifier(StringRef name) {
-  assert(!name.empty() && "invalid name");
+  // The empty string is never a bare identifier (it must be quoted as "").
+  if (name.empty())
+    return false;
 
   // By making this unsigned, the value passed in to isalnum will always be
   // in the range 0-255. This is important when building with MSVC because
   // its implementation will assert. This situation can arise when dealing
   // with UTF-8 multibyte characters.
   unsigned char firstChar = static_cast<unsigned char>(name[0]);
   if (!isalpha(firstChar) && firstChar != '_')
     return false;
   return llvm::all_of(name.drop_front(), [](unsigned char c) {
     return isalnum(c) || c == '_' || c == '$' || c == '.';
   });
 }
 
+/// Print the given string as a keyword, or a quoted and escaped string if it
+/// has any special or non-printable characters in it.
+static void printKeywordOrString(StringRef keyword, raw_ostream &os) {
+  // If it can be represented as a bare identifier, write it directly.
+  if (isBareIdentifier(keyword)) {
+    os << keyword;
+    return;
+  }
+
+  // Otherwise, output the keyword wrapped in quotes with proper escaping.
+  os << '"';
+  printEscapedString(keyword, os);
+  os << '"';
+}
+
 /// Print the given string as a symbol reference. A symbol reference is
 /// represented as a string prefixed with '@'. The reference is surrounded with
 /// ""'s and escaped if it has any special or non-printable characters in it.
 static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
   assert(!symbolRef.empty() && "expected valid symbol reference");
-
-  // If the symbol can be represented as a bare identifier, write it directly.
-  if (isBareIdentifier(symbolRef)) {
-    os << '@' << symbolRef;
-    return;
-  }
-
-  // Otherwise, output the reference wrapped in quotes with proper escaping.
-  os << "@\"";
-  printEscapedString(symbolRef, os);
-  os << '"';
+  os << '@';
+  printKeywordOrString(symbolRef, os);
 }
 
 // Print out a valid ElementsAttr that is succinct and can represent any
@@ -2038,13 +2047,10 @@
 }
 
 void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
-  if (isBareIdentifier(attr.first)) {
-    os << attr.first;
-  } else {
-    os << '"';
-    printEscapedString(attr.first.strref(), os);
-    os << '"';
-  }
+  assert(attr.first.size() != 0 && "expected valid named attribute");
+
+  // Print the name without quotes if possible.
+  ::printKeywordOrString(attr.first.strref(), os);
 
   // Pretty printing elides the attribute value for unit attributes.
   if (attr.second.isa<UnitAttr>())
@@ -2115,6 +2121,11 @@
   impl->printAttribute(attr, Impl::AttrTypeElision::Must);
 }
 
+void AsmPrinter::printKeywordOrString(StringRef keyword) {
+  assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden");
+  ::printKeywordOrString(keyword, impl->getStream());
+}
+
 void AsmPrinter::printSymbolName(StringRef symbolRef) {
   assert(impl && "expected AsmPrinter::printSymbolName to be overriden");
   ::printSymbolReference(symbolRef, impl->getStream());
diff --git a/mlir/lib/Parser/AsmParserImpl.h b/mlir/lib/Parser/AsmParserImpl.h
--- a/mlir/lib/Parser/AsmParserImpl.h
+++ b/mlir/lib/Parser/AsmParserImpl.h
@@ -276,6 +276,17 @@
     return failure();
   }
 
+  /// Parse an optional keyword or string and store it in 'result'.
+  ParseResult parseOptionalKeywordOrString(std::string *result) override {
+    StringRef keyword;
+    if (succeeded(parseOptionalKeyword(&keyword))) {
+      *result = keyword.str();
+      return success();
+    }
+
+    return parseOptionalString(result);
+  }
+
   /// Parse a floating point value from the stream.
   ParseResult parseFloat(double &result) override {
     bool isNegative = parser.consumeIf(Token::minus);