diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -54,6 +54,11 @@
   /// provide a valid type for the attribute.
   virtual void printAttributeWithoutType(Attribute attr);
 
+  /// Print the given string as an identifier. The identifier is printed bare
+  /// when possible; otherwise (e.g. if it is empty or has any special or
+  /// non-printable characters) it is surrounded with ""s and escaped.
+  virtual void printIdentifier(StringRef identifier);
+
   /// Print the given string as a symbol reference, i.e. a form representable by
   /// a SymbolRefAttr. A symbol reference is represented as a string prefixed
   /// with '@'. The reference is surrounded with ""'s and escaped if it has any
@@ -686,6 +691,18 @@
   // Identifier Parsing
   //===--------------------------------------------------------------------===//
 
+  /// Parse an identifier, either a bare identifier or a quoted string, and
+  /// store it in 'result'. Emits an error if no identifier is present.
+  ParseResult parseIdentifier(StringAttr &result) {
+    if (failed(parseOptionalIdentifier(result)))
+      return emitError(getCurrentLocation()) << "expected valid identifier";
+    return success();
+  }
+
+  /// Parse an optional identifier and store it in 'result'. Returns failure
+  /// if no identifier is present.
+  virtual ParseResult parseOptionalIdentifier(StringAttr &result) = 0;
+
   /// Parse an @-identifier and store it (without the '@' symbol) in a string
   /// attribute named 'attrName'.
   ParseResult parseSymbolName(StringAttr &result, StringRef attrName,
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -518,6 +518,7 @@
     // guaranteed to go unused.
     os << "%";
   }
+  void printIdentifier(StringRef) override {}
   void printSymbolName(StringRef) override {}
   void printSuccessor(Block *) override {}
   void printSuccessorAndUseList(Block *, ValueRange) override {}
@@ -1548,36 +1549,45 @@
 
 /// Returns true if the given string can be represented as a bare identifier.
 static bool isBareIdentifier(StringRef name) {
-  assert(!name.empty() && "invalid name");
+  // Empty strings must be quoted, so they are never bare identifiers.
+  if (name.empty())
+    return false;
 
   // By making this unsigned, the value passed in to isalnum will always be
   // in the range 0-255. This is important when building with MSVC because
   // its implementation will assert. This situation can arise when dealing
   // with UTF-8 multibyte characters.
   unsigned char firstChar = static_cast<unsigned char>(name[0]);
   if (!isalpha(firstChar) && firstChar != '_')
     return false;
   return llvm::all_of(name.drop_front(), [](unsigned char c) {
     return isalnum(c) || c == '_' || c == '$' || c == '.';
   });
 }
 
+/// Print the given string as an identifier. An identifier will be wrapped in
+/// ""'s and escaped if it is empty or if it contains special or non-printable
+/// characters.
+static void printIdentifier(StringRef identifier, raw_ostream &os) {
+  // If it can be represented as a bare identifier, write it directly.
+  if (isBareIdentifier(identifier)) {
+    os << identifier;
+    return;
+  }
+
+  // Otherwise, output the identifier wrapped in quotes with proper escaping.
+  os << '"';
+  printEscapedString(identifier, os);
+  os << '"';
+}
+
 /// Print the given string as a symbol reference. A symbol reference is
 /// represented as a string prefixed with '@'. The reference is surrounded with
 /// ""'s and escaped if it has any special or non-printable characters in it.
 static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
   assert(!symbolRef.empty() && "expected valid symbol reference");
-
-  // If the symbol can be represented as a bare identifier, write it directly.
-  if (isBareIdentifier(symbolRef)) {
-    os << '@' << symbolRef;
-    return;
-  }
-
-  // Otherwise, output the reference wrapped in quotes with proper escaping.
-  os << "@\"";
-  printEscapedString(symbolRef, os);
-  os << '"';
+  os << '@';
+  printIdentifier(symbolRef, os);
 }
 
 // Print out a valid ElementsAttr that is succinct and can represent any
@@ -2037,13 +2047,10 @@
 }
 
 void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
-  if (isBareIdentifier(attr.first)) {
-    os << attr.first;
-  } else {
-    os << '"';
-    printEscapedString(attr.first.strref(), os);
-    os << '"';
-  }
+  assert(attr.first.size() != 0 && "expected valid named attribute");
+
+  // Print the name without quotes if possible.
+  ::printIdentifier(attr.first.strref(), os);
 
   // Pretty printing elides the attribute value for unit attributes.
   if (attr.second.isa<UnitAttr>())
@@ -2114,6 +2121,11 @@
   impl->printAttribute(attr, Impl::AttrTypeElision::Must);
 }
 
+void AsmPrinter::printIdentifier(StringRef identifier) {
+  assert(impl && "expected AsmPrinter::printIdentifier to be overriden");
+  ::printIdentifier(identifier, impl->getStream());
+}
+
 void AsmPrinter::printSymbolName(StringRef symbolRef) {
   assert(impl && "expected AsmPrinter::printSymbolName to be overriden");
   ::printSymbolReference(symbolRef, impl->getStream());
diff --git a/mlir/lib/Parser/AsmParserImpl.h b/mlir/lib/Parser/AsmParserImpl.h
--- a/mlir/lib/Parser/AsmParserImpl.h
+++ b/mlir/lib/Parser/AsmParserImpl.h
@@ -388,6 +388,26 @@
   // Identifier Parsing
   //===--------------------------------------------------------------------===//
 
+  /// Parse an optional identifier, either a bare keyword or a quoted string,
+  /// and store it in 'result'. Returns failure if neither is present.
+  ParseResult parseOptionalIdentifier(StringAttr &result) override {
+    // Try the bare (unquoted) form first.
+    StringRef keyword;
+    if (succeeded(parseOptionalKeyword(&keyword))) {
+      result = getBuilder().getStringAttr(keyword);
+      return success();
+    }
+
+    // Otherwise, try the quoted form used for non-bare identifiers.
+    std::string string;
+    if (succeeded(parseOptionalString(&string))) {
+      result = getBuilder().getStringAttr(string);
+      return success();
+    }
+
+    return failure();
+  }
+
   /// Parse an optional @-identifier and store it (without the '@' symbol) in a
   /// string attribute named 'attrName'.
   ParseResult parseOptionalSymbolName(StringAttr &result, StringRef attrName,