diff --git a/llvm/include/llvm/TableGen/Parser.h b/llvm/include/llvm/TableGen/Parser.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/TableGen/Parser.h
@@ -0,0 +1,31 @@
+//===- llvm/TableGen/Parser.h - TableGen string parsing ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares an entry point for parsing TableGen source text that is
+// held in memory.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TABLEGEN_PARSER_H
+#define LLVM_TABLEGEN_PARSER_H
+
+namespace llvm {
+
+class Error;
+class StringRef;
+class RecordKeeper;
+
+/// Parses the TableGen source text in \p String into \p Records.
+///
+/// \returns Error::success() if parsing succeeded, or a string error with the
+/// message "parser failure" otherwise.
+Error TableGenParseString(StringRef String, RecordKeeper &Records);
+
+} // namespace llvm
+
+#endif // LLVM_TABLEGEN_PARSER_H
diff --git a/llvm/lib/TableGen/TGParser.cpp b/llvm/lib/TableGen/TGParser.cpp
--- a/llvm/lib/TableGen/TGParser.cpp
+++ b/llvm/lib/TableGen/TGParser.cpp
@@ -19,9 +19,11 @@
 #include "llvm/Config/llvm-config.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Compiler.h"
+#include "llvm/Support/Error.h"
 #include "llvm/Support/ErrorHandling.h"
-#include "llvm/Support/raw_ostream.h"
 #include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/TableGen/Parser.h"
 #include
 #include
 #include
@@ -3406,3 +3408,20 @@
     E.dump();
 }
 #endif
+
+/// Parses the TableGen source text in \p String into \p Records. Returns a
+/// string error if the parser reports a failure.
+Error llvm::TableGenParseString(StringRef String, RecordKeeper &Records) {
+  // The source manager and its private copy of the text are local, so both are
+  // destroyed on return.
+  // NOTE(review): if records keep source locations pointing into this buffer,
+  // diagnostics issued after this call would refer to freed memory. Confirm,
+  // and consider letting the caller own the SourceMgr.
+  SourceMgr SrcMgr;
+  SrcMgr.AddNewSourceBuffer(MemoryBuffer::getMemBufferCopy(String),
+                            llvm::SMLoc());
+  TGParser Parser(SrcMgr, None, Records);
+  if (Parser.ParseFile())
+    return createStringError(inconvertibleErrorCode(), "parser failure");
+  return Error::success();
+}
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -42,8 +42,8 @@
     }
   }]>;
 
-class LLVM_TwoBuilders {
-  list builders = [b1, b2];
+class LLVM_TwoBuilders {
+  list builders = [b1, b2];
 }
 
 // Base class for LLVM operations with one result.
diff --git a/mlir/tools/CMakeLists.txt b/mlir/tools/CMakeLists.txt
--- a/mlir/tools/CMakeLists.txt
+++ b/mlir/tools/CMakeLists.txt
@@ -5,5 +5,6 @@
 add_subdirectory(mlir-reduce)
 add_subdirectory(mlir-rocm-runner)
 add_subdirectory(mlir-shlib)
+add_subdirectory(mlir-tblgen-modernize)
 add_subdirectory(mlir-translate)
-add_subdirectory(mlir-vulkan-runner)
\ No newline at end of file
+add_subdirectory(mlir-vulkan-runner)
diff --git a/mlir/tools/mlir-tblgen-modernize/CMakeLists.txt b/mlir/tools/mlir-tblgen-modernize/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/tools/mlir-tblgen-modernize/CMakeLists.txt
@@ -0,0 +1,17 @@
+set(LLVM_LINK_COMPONENTS
+  Support
+  TableGen
+)
+
+add_llvm_tool(mlir-tblgen-modernize
+  mlir-tblgen-modernize.cpp
+
+  DEPENDS
+  MLIRSupport
+)
+target_link_libraries(mlir-tblgen-modernize
+  PRIVATE
+  MLIRSupport
+)
+llvm_update_compile_flags(mlir-tblgen-modernize)
+
diff --git a/mlir/tools/mlir-tblgen-modernize/mlir-tblgen-modernize.cpp b/mlir/tools/mlir-tblgen-modernize/mlir-tblgen-modernize.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/tools/mlir-tblgen-modernize/mlir-tblgen-modernize.cpp
@@ -0,0 +1,400 @@
+//===- mlir-tblgen-modernize.cpp - Update OpBuilder syntax ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This tool rewrites every `OpBuilder<"params", [{body}]>` found in its input
+// as an `OpBuilderDAG<(ins ...), [{body}]>` and prints the result to the
+// standard output. The text around the builders is copied through, except that
+// whitespace directly before a rewritten builder is dropped and the builder is
+// placed on a new, indented line.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/FileUtilities.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Parser.h"
+#include "llvm/TableGen/Record.h"
+
+#include <cstdlib>
+#include <iterator>
+#include <string>
+#include <utility>
+
+/// The parts of one C++ parameter declaration, e.g. `Type *name = init`.
+struct ArgumentSignature {
+  /// Type part, including trailing qualifiers and `*`/`&` declarators.
+  llvm::StringRef type;
+  /// Parameter name; empty when the parameter is unnamed.
+  llvm::StringRef name;
+  /// Text after the first `=`; empty when there is no default value.
+  llvm::StringRef initializer;
+};
+
+/// Returns true if `c` can appear inside a C++ identifier.
+static bool isIdentifierChar(char c) { return llvm::isAlnum(c) || c == '_'; }
+
+/// Returns true if `word` is a keyword that belongs to a type and therefore
+/// cannot be a parameter name, e.g. the `const` in `Type *const`.
+static bool isTypeKeyword(llvm::StringRef word) {
+  static constexpr llvm::StringLiteral keywords[] = {
+      "const", "volatile", "bool",     "char",  "short",  "int",
+      "long",  "signed",   "unsigned", "float", "double", "void"};
+  return llvm::is_contained(keywords, word);
+}
+
+/// Splits one parameter declaration into its type, name and default value.
+///
+/// The default value is whatever follows the first `=`. The name is the
+/// trailing identifier of the remaining text, unless that identifier is a type
+/// keyword or is all there is, in which case the parameter is treated as
+/// unnamed and the whole text is its type.
+static ArgumentSignature splitSignature(llvm::StringRef signature) {
+  signature = signature.trim();
+
+  // First, check if there is an initializer. We find it by looking for the
+  // first '=' because these are unlikely (although not impossible) to find in
+  // the type signature.
+  llvm::StringRef initializer;
+  size_t assignmentPos = signature.find('=');
+  if (assignmentPos != llvm::StringRef::npos) {
+    initializer = signature.drop_front(assignmentPos + 1).ltrim();
+    signature = signature.take_front(assignmentPos).rtrim();
+  }
+
+  // Walk backwards over the trailing identifier, if any.
+  size_t nameStart = signature.size();
+  while (nameStart > 0 && isIdentifierChar(signature[nameStart - 1]))
+    --nameStart;
+  llvm::StringRef maybeName = signature.drop_front(nameStart);
+
+  // A lone identifier is a type rather than a name, and so is a keyword such
+  // as the "const" in "Type *const".
+  if (nameStart == 0 || maybeName.empty() || isTypeKeyword(maybeName))
+    return {signature, llvm::StringRef(), initializer};
+  return {signature.take_front(nameStart).rtrim(), maybeName, initializer};
+}
+
+/// Splits the comma-separated parameter list `signature` into `arguments`.
+///
+/// Commas nested in `<>`, `()`, `[]` or `{}` (template argument lists, default
+/// values) do not separate parameters. Pieces that contain only whitespace are
+/// dropped, so an empty list produces no arguments.
+static void
+splitBuilderSignature(llvm::StringRef signature,
+                      llvm::SmallVectorImpl<llvm::StringRef> &arguments) {
+  auto addArgument = [&arguments](llvm::StringRef argument) {
+    if (!argument.trim().empty())
+      arguments.push_back(argument);
+  };
+
+  size_t numAngleBrackets = 0;
+  size_t numGroups = 0;
+  size_t start = 0;
+  for (size_t pos = 0, e = signature.size(); pos < e; ++pos) {
+    switch (signature[pos]) {
+    case '<':
+      ++numAngleBrackets;
+      break;
+    case '>':
+      // Skip the '>' of "->" and any stray '>' so the counter cannot wrap.
+      if (numAngleBrackets > 0 && (pos == 0 || signature[pos - 1] != '-'))
+        --numAngleBrackets;
+      break;
+    case '(':
+    case '[':
+    case '{':
+      ++numGroups;
+      break;
+    case ')':
+    case ']':
+    case '}':
+      if (numGroups > 0)
+        --numGroups;
+      break;
+    case ',':
+      if (numAngleBrackets == 0 && numGroups == 0) {
+        addArgument(signature.slice(start, pos));
+        start = pos + 1;
+      }
+      break;
+    default:
+      break;
+    }
+  }
+  addArgument(signature.drop_front(start));
+}
+
+/// Prints the first `pos` characters of `body` followed by `value`, and returns
+/// what remains of `body` after skipping a further `length` characters.
+static llvm::StringRef printWithPositionedReplacement(llvm::StringRef body,
+                                                      size_t pos, size_t length,
+                                                      llvm::StringRef value) {
+  llvm::outs() << body.take_front(pos) << value;
+  return body.drop_front(pos + length);
+}
+
+/// Returns the position of the first occurrence of `pattern` in `body` that
+/// does not start in the middle of an identifier, or npos if there is none.
+/// `preceding` is the character located just before `body`. An empty pattern
+/// never matches.
+static size_t findStandaloneUse(llvm::StringRef body, llvm::StringRef pattern,
+                                char preceding) {
+  if (pattern.empty())
+    return llvm::StringRef::npos;
+  for (size_t pos = body.find(pattern); pos != llvm::StringRef::npos;
+       pos = body.find(pattern, pos + 1)) {
+    char before = pos == 0 ? preceding : body[pos - 1];
+    if (!isIdentifierChar(before))
+      return pos;
+  }
+  return llvm::StringRef::npos;
+}
+
+/// Prints `body` with the occurrences of `pattern1` replaced by `value1` and
+/// those of `pattern2` replaced by `value2`.
+///
+/// An occurrence counts only if it does not start in the middle of an
+/// identifier, so replacing "b." leaves "sub." alone. Empty patterns never
+/// match.
+static void printWithReplacements(llvm::StringRef body,
+                                  llvm::StringRef pattern1,
+                                  llvm::StringRef value1,
+                                  llvm::StringRef pattern2,
+                                  llvm::StringRef value2) {
+  // Last character consumed so far; '\0' while nothing has been consumed.
+  char preceding = '\0';
+  while (true) {
+    size_t pos1 = findStandaloneUse(body, pattern1, preceding);
+    size_t pos2 = findStandaloneUse(body, pattern2, preceding);
+    if (pos1 == llvm::StringRef::npos && pos2 == llvm::StringRef::npos) {
+      llvm::outs() << body;
+      return;
+    }
+
+    // npos is the largest size_t, so a pattern that is absent always loses.
+    bool useFirst = pos1 < pos2;
+    llvm::StringRef pattern = useFirst ? pattern1 : pattern2;
+    preceding = pattern.back();
+    body = printWithPositionedReplacement(body, useFirst ? pos1 : pos2,
+                                          pattern.size(),
+                                          useFirst ? value1 : value2);
+  }
+}
+
+/// Renders one parameter as a dag argument: `"type"` or
+/// `CArg<"type", "initializer">`, followed by `:$name` if the parameter is
+/// named.
+///
+/// NOTE(review): the type and initializer are emitted verbatim; a default value
+/// containing '"' would presumably need escaping to stay a valid string.
+static std::string transformSignature(const ArgumentSignature &signature) {
+  std::string result =
+      signature.initializer.empty()
+          ? llvm::formatv("\"{0}\"", signature.type).str()
+          : llvm::formatv("CArg<\"{0}\", \"{1}\">", signature.type,
+                          signature.initializer)
+                .str();
+  // An unnamed parameter must not get a dangling ":$".
+  if (!signature.name.empty()) {
+    result += ":$";
+    result += signature.name.str();
+  }
+  return result;
+}
+
+/// Parses `builder`, the text of one `OpBuilder<...>`, and prints the
+/// corresponding `OpBuilderDAG<...>` on a new line indented by `indent`
+/// columns.
+///
+/// If there are at least two parameters and the type of the first one mentions
+/// "OpBuilder", the first two are taken to be the builder and the operation
+/// state: they are dropped from the parameter list, and `<name>.` in the body
+/// becomes `$_builder.` and `$_state.` respectively.
+static void transformBuilder(llvm::StringRef builder, size_t indent) {
+  // Let TableGen do the parsing: wrap the builder in a small program that
+  // declares the OpBuilder class and instantiates it once.
+  std::string input;
+  llvm::raw_string_ostream stream(input);
+  stream << R"TG(
+    class OpBuilder<string p, code b = ""> {
+      string params = p;
+      code body = b;
+    }
+    def : )TG"
+         << builder << ";";
+  stream.flush();
+
+  llvm::RecordKeeper records;
+  // A failed llvm::Error has to be consumed; testing it is not enough.
+  if (llvm::Error error = llvm::TableGenParseString(input, records))
+    llvm::PrintFatalError("failed to parse builder: " +
+                          llvm::toString(std::move(error)));
+
+  for (const auto &def : records.getDefs()) {
+    if (!def.second->isSubClassOf("OpBuilder"))
+      continue;
+
+    llvm::SmallVector<llvm::StringRef, 8> paramStrings;
+    splitBuilderSignature(def.second->getValueAsString("params"), paramStrings);
+    auto params = llvm::to_vector<8>(llvm::map_range(
+        paramStrings, [](llvm::StringRef p) { return splitSignature(p); }));
+
+    // The builder/state pair is implicit in the new form. Both parameters must
+    // exist before they can be dropped.
+    llvm::StringRef builderOrigName;
+    llvm::StringRef stateOrigName;
+    if (params.size() >= 2 && params[0].type.contains("OpBuilder")) {
+      builderOrigName = params[0].name;
+      stateOrigName = params[1].name;
+      params.erase(params.begin(), std::next(params.begin(), 2));
+    }
+
+    const llvm::StringRef header = "OpBuilderDAG<(ins";
+    llvm::outs() << "\n";
+    llvm::outs().indent(indent) << header;
+    // Columns used on the current output line, plus one column of slack.
+    size_t prefixWidth = indent + header.size() + 1;
+    for (size_t i = 0, e = params.size(); i < e; ++i) {
+      std::string signature = transformSignature(params[i]);
+      if (i == 0)
+        signature = " " + signature;
+      size_t separatorWidth = i == 0 ? 0 : 2;
+
+      if (prefixWidth + signature.size() + separatorWidth > 80) {
+        // Does not fit: continue on a new line, indented two more columns.
+        if (i != 0)
+          llvm::outs() << ",";
+        llvm::outs() << "\n";
+        llvm::outs().indent(indent + 2) << signature;
+        prefixWidth = indent + 2 + signature.size();
+      } else {
+        if (i != 0)
+          llvm::outs() << ", ";
+        llvm::outs() << signature;
+        prefixWidth += signature.size() + separatorWidth;
+      }
+    }
+
+    llvm::outs() << ")";
+    llvm::StringRef body = def.second->getValueAsString("body");
+    if (!body.empty()) {
+      llvm::outs() << ",\n";
+      llvm::outs().indent(indent) << "[{";
+      // Unnamed parameters give empty patterns, which are never replaced.
+      std::string builderPattern =
+          builderOrigName.empty() ? "" : builderOrigName.str() + ".";
+      std::string statePattern =
+          stateOrigName.empty() ? "" : stateOrigName.str() + ".";
+      printWithReplacements(body, builderPattern, "$_builder.", statePattern,
+                            "$_state.");
+      llvm::outs() << "}]";
+    }
+    llvm::outs() << ">";
+  }
+}
+
+/// Returns the indentation, in columns, to use for a builder found at `pos`.
+///
+/// This is the width of the leading whitespace of the line containing `pos`
+/// (tabs count as two columns), plus two if other text precedes the builder on
+/// that line.
+static size_t computeIndentation(llvm::StringRef buffer, size_t pos) {
+  size_t lineStart = buffer.take_front(pos).rfind('\n');
+  lineStart = lineStart == llvm::StringRef::npos ? 0 : lineStart + 1;
+  llvm::StringRef linePrefix = buffer.slice(lineStart, pos);
+
+  size_t indentation = 0;
+  for (char c : linePrefix) {
+    if (c == ' ')
+      indentation += 1;
+    else if (c == '\t')
+      indentation += 2;
+    else
+      break;
+  }
+  if (!linePrefix.trim().empty())
+    indentation += 2;
+  return indentation;
+}
+
+/// Returns the position of the '>' that closes the '<' located just before
+/// `start`, or npos if `buffer` ends first.
+///
+/// Maintains a counter of unclosed <> pairs. Angle brackets inside quoted
+/// strings or [{}] code are ignored, and so is any backslash-escaped character.
+static size_t findClosingAngleBracket(llvm::StringRef buffer, size_t start) {
+  size_t numOpen = 1;
+  bool inQuotes = false;
+  bool inCode = false;
+  for (size_t pos = start, e = buffer.size(); pos < e; ++pos) {
+    char c = buffer[pos];
+    char next = pos + 1 < e ? buffer[pos + 1] : '\0';
+    if (c == '\\') {
+      // Skip the escaped character.
+      ++pos;
+    } else if (c == '<' && !inQuotes && !inCode) {
+      ++numOpen;
+    } else if (c == '>' && !inQuotes && !inCode) {
+      if (--numOpen == 0)
+        return pos;
+    } else if (c == '"' && !inCode) {
+      inQuotes = !inQuotes;
+    } else if (c == '[' && next == '{' && !inQuotes) {
+      inCode = true;
+      ++pos;
+    } else if (c == '}' && next == ']' && !inQuotes) {
+      inCode = false;
+      ++pos;
+    }
+  }
+  return llvm::StringRef::npos;
+}
+
+int main(int argc, char **argv) {
+  llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
+                                           llvm::cl::desc("<input file>"),
+                                           llvm::cl::init("-"));
+  llvm::InitLLVM init(argc, argv);
+  llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR tablegen modernizer");
+
+  std::string errorMessage;
+  auto inputFile = mlir::openInputFile(inputFilename, &errorMessage);
+  if (!inputFile) {
+    llvm::errs() << errorMessage << "\n";
+    return EXIT_FAILURE;
+  }
+
+  const llvm::StringRef marker = "OpBuilder<";
+  llvm::StringRef buffer = inputFile->getBuffer();
+  for (size_t pos = buffer.find(marker); pos != llvm::StringRef::npos;
+       pos = buffer.find(marker)) {
+    size_t endPos = findClosingAngleBracket(buffer, pos + marker.size());
+    if (endPos == llvm::StringRef::npos) {
+      // Without a closing '>' there is no builder text to rewrite.
+      llvm::errs() << "error: unterminated 'OpBuilder<'\n";
+      return EXIT_FAILURE;
+    }
+
+    llvm::outs() << buffer.take_front(pos).rtrim();
+    transformBuilder(buffer.slice(pos, endPos + 1),
+                     computeIndentation(buffer, pos));
+    buffer = buffer.drop_front(endPos + 1);
+  }
+  llvm::outs() << buffer;
+
+  return EXIT_SUCCESS;
+}