diff --git a/llvm/include/llvm/ADT/StringExtras.h b/llvm/include/llvm/ADT/StringExtras.h
--- a/llvm/include/llvm/ADT/StringExtras.h
+++ b/llvm/include/llvm/ADT/StringExtras.h
@@ -292,6 +292,18 @@
 /// printLowerCase - Print each character as lowercase if it is uppercase.
 void printLowerCase(StringRef String, raw_ostream &Out);
 
+/// Converts a string from camel-case to snake-case by replacing each ASCII
+/// uppercase letter with '_' followed by the letter in lowercase. No '_' is
+/// inserted at the start of the string or directly after an existing '_'.
+std::string convertToSnakeFromCamelCase(StringRef input);
+
+/// Converts a string from snake-case to camel-case by replacing all occurrences
+/// of '_' followed by an ASCII lowercase letter with the letter in uppercase.
+/// A leading '_' is always preserved. If \p capitalizeFirst is true, a
+/// lowercase first character is also converted to uppercase.
+std::string convertToCamelFromSnakeCase(StringRef input,
+                                        bool capitalizeFirst = false);
+
 namespace detail {
 
 template <typename IteratorT>
diff --git a/llvm/lib/Support/StringExtras.cpp b/llvm/lib/Support/StringExtras.cpp
--- a/llvm/lib/Support/StringExtras.cpp
+++ b/llvm/lib/Support/StringExtras.cpp
@@ -90,3 +90,54 @@
   for (const char C : String)
     Out << toLower(C);
 }
+
+/// Returns true if \p C is an ASCII lowercase letter. Unlike std::islower,
+/// this is locale-independent and well-defined for negative `char` values.
+static bool isASCIILower(char C) { return C >= 'a' && C <= 'z'; }
+
+/// Returns true if \p C is an ASCII uppercase letter. Unlike std::isupper,
+/// this is locale-independent and well-defined for negative `char` values.
+static bool isASCIIUpper(char C) { return C >= 'A' && C <= 'Z'; }
+
+std::string llvm::convertToSnakeFromCamelCase(StringRef input) {
+  if (input.empty())
+    return "";
+
+  std::string snakeCase;
+  snakeCase.reserve(input.size());
+  for (char c : input) {
+    if (!isASCIIUpper(c)) {
+      snakeCase.push_back(c);
+      continue;
+    }
+
+    if (!snakeCase.empty() && snakeCase.back() != '_')
+      snakeCase.push_back('_');
+    snakeCase.push_back(llvm::toLower(c));
+  }
+  return snakeCase;
+}
+
+std::string llvm::convertToCamelFromSnakeCase(StringRef input,
+                                              bool capitalizeFirst) {
+  if (input.empty())
+    return "";
+
+  std::string output;
+  output.reserve(input.size());
+
+  // Push the first character, capitalizing if necessary.
+  if (capitalizeFirst && isASCIILower(input.front()))
+    output.push_back(llvm::toUpper(input.front()));
+  else
+    output.push_back(input.front());
+
+  // Walk the input converting any `*_[a-z]` snake case into `*[A-Z]` camelCase.
+  for (size_t pos = 1, e = input.size(); pos < e; ++pos) {
+    if (input[pos] == '_' && pos != (e - 1) && isASCIILower(input[pos + 1]))
+      output.push_back(llvm::toUpper(input[++pos]));
+    else
+      output.push_back(input[pos]);
+  }
+  return output;
+}
diff --git a/llvm/unittests/ADT/StringExtrasTest.cpp b/llvm/unittests/ADT/StringExtrasTest.cpp
--- a/llvm/unittests/ADT/StringExtrasTest.cpp
+++ b/llvm/unittests/ADT/StringExtrasTest.cpp
@@ -118,3 +118,58 @@
   printHTMLEscaped("ABCdef123&<>\"'", OS);
   EXPECT_EQ("ABCdef123&amp;&lt;&gt;&quot;&apos;", OS.str());
 }
+
+TEST(StringExtras, ConvertToSnakeFromCamelCase) {
+  auto testConvertToSnakeCase = [](llvm::StringRef input,
+                                   llvm::StringRef expected) {
+    EXPECT_EQ(convertToSnakeFromCamelCase(input), expected.str());
+  };
+
+  testConvertToSnakeCase("OpName", "op_name");
+  testConvertToSnakeCase("opName", "op_name");
+  testConvertToSnakeCase("_OpName", "_op_name");
+  testConvertToSnakeCase("Op_Name", "op_name");
+  testConvertToSnakeCase("", "");
+  testConvertToSnakeCase("A", "a");
+  testConvertToSnakeCase("_", "_");
+  testConvertToSnakeCase("a", "a");
+  testConvertToSnakeCase("op_name", "op_name");
+  testConvertToSnakeCase("_op_name", "_op_name");
+  testConvertToSnakeCase("__op_name", "__op_name");
+  testConvertToSnakeCase("op__name", "op__name");
+  testConvertToSnakeCase("\xC3\x89lan", "\xC3\x89lan");
+}
+
+TEST(StringExtras, ConvertToCamelFromSnakeCase) {
+  auto testConvertToCamelCase = [](bool capitalizeFirst, llvm::StringRef input,
+                                   llvm::StringRef expected) {
+    EXPECT_EQ(convertToCamelFromSnakeCase(input, capitalizeFirst),
+              expected.str());
+  };
+
+  testConvertToCamelCase(false, "op_name", "opName");
+  testConvertToCamelCase(false, "_op_name", "_opName");
+  testConvertToCamelCase(false, "__op_name", "_OpName");
+  testConvertToCamelCase(false, "op__name", "op_Name");
+  testConvertToCamelCase(false, "", "");
+  testConvertToCamelCase(false, "A", "A");
+  testConvertToCamelCase(false, "_", "_");
+  testConvertToCamelCase(false, "a", "a");
+  testConvertToCamelCase(false, "OpName", "OpName");
+  testConvertToCamelCase(false, "opName", "opName");
+  testConvertToCamelCase(false, "_OpName", "_OpName");
+  testConvertToCamelCase(false, "Op_Name", "Op_Name");
+  testConvertToCamelCase(true, "op_name", "OpName");
+  testConvertToCamelCase(true, "_op_name", "_opName");
+  testConvertToCamelCase(true, "__op_name", "_OpName");
+  testConvertToCamelCase(true, "op__name", "Op_Name");
+  testConvertToCamelCase(true, "", "");
+  testConvertToCamelCase(true, "A", "A");
+  testConvertToCamelCase(true, "_", "_");
+  testConvertToCamelCase(true, "a", "A");
+  testConvertToCamelCase(true, "OpName", "OpName");
+  testConvertToCamelCase(true, "_OpName", "_OpName");
+  testConvertToCamelCase(true, "Op_Name", "Op_Name");
+  testConvertToCamelCase(true, "opName", "OpName");
+  testConvertToCamelCase(true, "\xC3\xA9t\xC3\xA9_x", "\xC3\xA9t\xC3\xA9X");
+}
diff --git a/mlir/include/mlir/Support/StringExtras.h b/mlir/include/mlir/Support/StringExtras.h
deleted file mode 100644
--- a/mlir/include/mlir/Support/StringExtras.h
+++ /dev/null
@@ -1,74 +0,0 @@
-//===- StringExtras.h - String utilities used by MLIR -----------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file contains string utility functions used within MLIR.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_SUPPORT_STRINGEXTRAS_H
-#define MLIR_SUPPORT_STRINGEXTRAS_H
-
-#include "llvm/ADT/StringExtras.h"
-
-#include <cctype>
-
-namespace mlir {
-/// Converts a string to snake-case from camel-case by replacing all uppercase
-/// letters with '_' followed by the letter in lowercase, except if the
-/// uppercase letter is the first character of the string.
-inline std::string convertToSnakeCase(llvm::StringRef input) {
-  std::string snakeCase;
-  snakeCase.reserve(input.size());
-  for (auto c : input) {
-    if (std::isupper(c)) {
-      if (!snakeCase.empty() && snakeCase.back() != '_') {
-        snakeCase.push_back('_');
-      }
-      snakeCase.push_back(llvm::toLower(c));
-    } else {
-      snakeCase.push_back(c);
-    }
-  }
-  return snakeCase;
-}
-
-/// Converts a string from camel-case to snake_case by replacing all occurrences
-/// of '_' followed by a lowercase letter with the letter in
-/// uppercase. Optionally allow capitalization of the first letter (if it is a
-/// lowercase letter)
-inline std::string convertToCamelCase(llvm::StringRef input,
-                                      bool capitalizeFirst = false) {
-  if (input.empty()) {
-    return "";
-  }
-  std::string output;
-  output.reserve(input.size());
-  size_t pos = 0;
-  if (capitalizeFirst && std::islower(input[pos])) {
-    output.push_back(llvm::toUpper(input[pos]));
-    pos++;
-  }
-  while (pos < input.size()) {
-    auto cur = input[pos];
-    if (cur == '_') {
-      if (pos && (pos + 1 < input.size())) {
-        if (std::islower(input[pos + 1])) {
-          output.push_back(llvm::toUpper(input[pos + 1]));
-          pos += 2;
-          continue;
-        }
-      }
-    }
-    output.push_back(cur);
-    pos++;
-  }
-  return output;
-}
-} // namespace mlir
-
-#endif // MLIR_SUPPORT_STRINGEXTRAS_H
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -19,7 +19,6 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Parser.h"
-#include "mlir/Support/StringExtras.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/Sequence.h"
@@ -133,7 +132,7 @@
 }
 
 std::string SPIRVDialect::getAttributeName(Decoration decoration) {
-  return convertToSnakeCase(stringifyDecoration(decoration));
+  return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration));
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -21,7 +21,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Interfaces/CallInterfaces.h"
-#include "mlir/Support/StringExtras.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/bit.h"
 
 using namespace mlir;
@@ -335,15 +335,15 @@
 
 static ParseResult parseVariableDecorations(OpAsmParser &parser,
                                             OperationState &state) {
-  auto builtInName =
-      convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
+  auto builtInName = llvm::convertToSnakeFromCamelCase(
+      stringifyDecoration(spirv::Decoration::BuiltIn));
   if (succeeded(parser.parseOptionalKeyword("bind"))) {
     Attribute set, binding;
     // Parse optional descriptor binding
-    auto descriptorSetName = convertToSnakeCase(
+    auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
         stringifyDecoration(spirv::Decoration::DescriptorSet));
-    auto bindingName =
-        convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
+    auto bindingName = llvm::convertToSnakeFromCamelCase(
+        stringifyDecoration(spirv::Decoration::Binding));
     Type i32Type = parser.getBuilder().getIntegerType(32);
     if (parser.parseLParen() ||
         parser.parseAttribute(set, i32Type, descriptorSetName,
@@ -373,10 +373,10 @@
 static void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
                                      SmallVectorImpl<StringRef> &elidedAttrs) {
   // Print optional descriptor binding
-  auto descriptorSetName =
-      convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet));
-  auto bindingName =
-      convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
+  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
+      stringifyDecoration(spirv::Decoration::DescriptorSet));
+  auto bindingName = llvm::convertToSnakeFromCamelCase(
+      stringifyDecoration(spirv::Decoration::Binding));
   auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
   auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
   if (descriptorSet && binding) {
@@ -387,8 +387,8 @@
   }
 
   // Print BuiltIn attribute if present
-  auto builtInName =
-      convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
+  auto builtInName = llvm::convertToSnakeFromCamelCase(
+      stringifyDecoration(spirv::Decoration::BuiltIn));
   if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
     printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
     elidedAttrs.push_back(builtInName);
@@ -2625,12 +2625,12 @@
 
   // TODO(antiagainst): generate these strings using ODS.
   auto *op = varOp.getOperation();
-  auto descriptorSetName =
-      convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet));
-  auto bindingName =
-      convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
-  auto builtInName =
-      convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
+  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
+      stringifyDecoration(spirv::Decoration::DescriptorSet));
+  auto bindingName = llvm::convertToSnakeFromCamelCase(
+      stringifyDecoration(spirv::Decoration::Binding));
+  auto builtInName = llvm::convertToSnakeFromCamelCase(
+      stringifyDecoration(spirv::Decoration::BuiltIn));
 
   for (const auto &attr : {descriptorSetName, bindingName, builtInName}) {
     if (op->getAttr(attr))
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -20,11 +20,11 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Location.h"
 #include "mlir/Support/LogicalResult.h"
-#include "mlir/Support/StringExtras.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/bit.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
@@ -647,7 +647,7 @@
   if (decorationName.empty()) {
     return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
   }
-  auto attrName = convertToSnakeCase(decorationName);
+  auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
   auto symbol = opBuilder.getIdentifier(attrName);
   switch (static_cast<spirv::Decoration>(words[1])) {
   case spirv::Decoration::DescriptorSet:
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -20,11 +20,11 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/RegionGraphTraits.h"
 #include "mlir/Support/LogicalResult.h"
-#include "mlir/Support/StringExtras.h"
 #include "llvm/ADT/DepthFirstIterator.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/ADT/bit.h"
 #include "llvm/Support/Debug.h"
@@ -627,7 +627,7 @@
 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
                                             NamedAttribute attr) {
   auto attrName = attr.first.strref();
-  auto decorationName = mlir::convertToCamelCase(attrName, true);
+  auto decorationName = llvm::convertToCamelFromSnakeCase(attrName, true);
   auto decoration = spirv::symbolizeDecoration(decorationName);
   if (!decoration) {
     return emitError(
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -10,7 +10,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Support/StringExtras.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/OpClass.h"
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -12,7 +12,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "OpFormatGen.h"
-#include "mlir/Support/StringExtras.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/OpClass.h"
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -11,7 +11,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Support/StringExtras.h"
 #include "mlir/TableGen/Attribute.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
@@ -1111,7 +1110,7 @@
                 enumName);
   os << "  "
      << formatv("static constexpr const char attrName[] = \"{0}\";\n",
-                mlir::convertToSnakeCase(enumName));
+                llvm::convertToSnakeFromCamelCase(enumName));
   os << "  return attrName;\n";
   os << "}\n";
 }
diff --git a/mlir/tools/mlir-vulkan-runner/VulkanRuntime.h b/mlir/tools/mlir-vulkan-runner/VulkanRuntime.h
--- a/mlir/tools/mlir-vulkan-runner/VulkanRuntime.h
+++ b/mlir/tools/mlir-vulkan-runner/VulkanRuntime.h
@@ -17,8 +17,8 @@
 #include "mlir/Dialect/SPIRV/Serialization.h"
 #include "mlir/IR/Module.h"
 #include "mlir/Support/LogicalResult.h"
-#include "mlir/Support/StringExtras.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/ToolOutputFile.h"
 
 #include <vulkan/vulkan.h>
diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -2,7 +2,6 @@
   AttributeTest.cpp
   DialectTest.cpp
   OperationSupportTest.cpp
-  StringExtrasTest.cpp
 )
 target_link_libraries(MLIRIRTests
   PRIVATE
diff --git a/mlir/unittests/IR/StringExtrasTest.cpp b/mlir/unittests/IR/StringExtrasTest.cpp
deleted file mode 100644
--- a/mlir/unittests/IR/StringExtrasTest.cpp
+++ /dev/null
@@ -1,65 +0,0 @@
-//===- StringExtrasTest.cpp - Tests for utility methods in StringExtras.h -===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Support/StringExtras.h"
-#include "gtest/gtest.h"
-
-using namespace mlir;
-
-static void testConvertToSnakeCase(llvm::StringRef input,
-                                   llvm::StringRef expected) {
-  EXPECT_EQ(convertToSnakeCase(input), expected.str());
-}
-
-TEST(StringExtras, ConvertToSnakeCase) {
-  testConvertToSnakeCase("OpName", "op_name");
-  testConvertToSnakeCase("opName", "op_name");
-  testConvertToSnakeCase("_OpName", "_op_name");
-  testConvertToSnakeCase("Op_Name", "op_name");
-  testConvertToSnakeCase("", "");
-  testConvertToSnakeCase("A", "a");
-  testConvertToSnakeCase("_", "_");
-  testConvertToSnakeCase("a", "a");
-  testConvertToSnakeCase("op_name", "op_name");
-  testConvertToSnakeCase("_op_name", "_op_name");
-  testConvertToSnakeCase("__op_name", "__op_name");
-  testConvertToSnakeCase("op__name", "op__name");
-}
-
-template <bool capitalizeFirst>
-static void testConvertToCamelCase(llvm::StringRef input,
-                                   llvm::StringRef expected) {
-  EXPECT_EQ(convertToCamelCase(input, capitalizeFirst), expected.str());
-}
-
-TEST(StringExtras, ConvertToCamelCase) {
-  testConvertToCamelCase<false>("op_name", "opName");
-  testConvertToCamelCase<false>("_op_name", "_opName");
-  testConvertToCamelCase<false>("__op_name", "_OpName");
-  testConvertToCamelCase<false>("op__name", "op_Name");
-  testConvertToCamelCase<false>("", "");
-  testConvertToCamelCase<false>("A", "A");
-  testConvertToCamelCase<false>("_", "_");
-  testConvertToCamelCase<false>("a", "a");
-  testConvertToCamelCase<false>("OpName", "OpName");
-  testConvertToCamelCase<false>("opName", "opName");
-  testConvertToCamelCase<false>("_OpName", "_OpName");
-  testConvertToCamelCase<false>("Op_Name", "Op_Name");
-  testConvertToCamelCase<true>("op_name", "OpName");
-  testConvertToCamelCase<true>("_op_name", "_opName");
-  testConvertToCamelCase<true>("__op_name", "_OpName");
-  testConvertToCamelCase<true>("op__name", "Op_Name");
-  testConvertToCamelCase<true>("", "");
-  testConvertToCamelCase<true>("A", "A");
-  testConvertToCamelCase<true>("_", "_");
-  testConvertToCamelCase<true>("a", "A");
-  testConvertToCamelCase<true>("OpName", "OpName");
-  testConvertToCamelCase<true>("_OpName", "_OpName");
-  testConvertToCamelCase<true>("Op_Name", "Op_Name");
-  testConvertToCamelCase<true>("opName", "OpName");
-}