diff --git a/mlir/lib/Tools/lsp-server-support/Protocol.h b/mlir/lib/Tools/lsp-server-support/Protocol.h --- a/mlir/lib/Tools/lsp-server-support/Protocol.h +++ b/mlir/lib/Tools/lsp-server-support/Protocol.h @@ -53,6 +53,7 @@ // Defined by the protocol. RequestCancelled = -32800, ContentModified = -32801, + RequestFailed = -32803, }; /// Defines how the host (editor) should sync document changes to the language @@ -103,8 +104,10 @@ /// Try to build a URIForFile from the given URI string. static llvm::Expected fromURI(StringRef uri); - /// Try to build a URIForFile from the given absolute file path. - static llvm::Expected fromFile(StringRef absoluteFilepath); + /// Try to build a URIForFile from the given absolute file path and optional + /// scheme (`file` by default). + static llvm::Expected fromFile(StringRef absoluteFilepath, + StringRef scheme = "file"); /// Returns the absolute path to the file. StringRef file() const { return filePath; } @@ -112,6 +115,9 @@ /// Returns the original uri of the file. StringRef uri() const { return uriStr; } + /// Return the scheme of the uri, i.e. the text before its first `:`. + StringRef scheme() const; + explicit operator bool() const { return !filePath.empty(); } friend bool operator==(const URIForFile &lhs, const URIForFile &rhs) { @@ -124,6 +130,11 @@ return lhs.filePath < rhs.filePath; } + /// Register an additional supported URI scheme. `file` (and `test`) are + /// supported by default, so this is only necessary for any other schemes that + /// a server wants to support.
+ static void registerSupportedScheme(StringRef scheme); + private: explicit URIForFile(std::string &&filePath, std::string &&uriStr) : filePath(std::move(filePath)), uriStr(uriStr) {} diff --git a/mlir/lib/Tools/lsp-server-support/Protocol.cpp b/mlir/lib/Tools/lsp-server-support/Protocol.cpp --- a/mlir/lib/Tools/lsp-server-support/Protocol.cpp +++ b/mlir/lib/Tools/lsp-server-support/Protocol.cpp @@ -15,6 +15,7 @@ #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/SmallString.h" +#include "llvm/ADT/StringSet.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/Format.h" @@ -116,7 +117,16 @@ return result; } -static bool isValidScheme(StringRef scheme) { +/// Return the set of supported URI schemes (`file` and `test` by default). +static StringSet<> &getSupportedSchemes() { + static StringSet<> schemes({"file", "test"}); + return schemes; +} + +/// Returns true if the given scheme is structurally valid, i.e. it does not +/// contain any invalid scheme characters. This does not check that the scheme +/// is actually supported.
+static bool isStructurallyValidScheme(StringRef scheme) { if (scheme.empty()) return false; if (!llvm::isAlpha(scheme[0])) @@ -126,7 +136,8 @@ }); } -static llvm::Expected uriFromAbsolutePath(StringRef absolutePath) { +static llvm::Expected uriFromAbsolutePath(StringRef absolutePath, + StringRef scheme) { std::string body; StringRef authority; StringRef root = llvm::sys::path::root_name(absolutePath); @@ -140,7 +151,7 @@ } body += llvm::sys::path::convert_to_slash(absolutePath); - std::string uri = "file:"; + std::string uri = scheme.str() + ":"; if (authority.empty() && body.empty()) return uri; @@ -186,7 +197,7 @@ origUri); StringRef schemeStr = uri.substr(0, pos); std::string uriScheme = percentDecode(schemeStr); - if (!isValidScheme(uriScheme)) + if (!isStructurallyValidScheme(uriScheme)) return llvm::createStringError(llvm::inconvertibleErrorCode(), "Invalid scheme: " + schemeStr + " (decoded: " + uriScheme + ")"); @@ -204,10 +215,10 @@ std::string uriBody = percentDecode(uri); // Compute the absolute path for this uri.
- if (uriScheme != "file" && uriScheme != "test") { - return llvm::createStringError( - llvm::inconvertibleErrorCode(), - "mlir-lsp-server only supports 'file' URI scheme for workspace files"); + if (!getSupportedSchemes().contains(uriScheme)) { + return llvm::createStringError(llvm::inconvertibleErrorCode(), + "unsupported URI scheme `" + uriScheme + + "' for workspace files"); } return getAbsolutePath(uriAuthority, uriBody); } @@ -219,13 +230,21 @@ return URIForFile(std::move(*filePath), uri.str()); } -llvm::Expected URIForFile::fromFile(StringRef absoluteFilepath) { - llvm::Expected uri = uriFromAbsolutePath(absoluteFilepath); +llvm::Expected URIForFile::fromFile(StringRef absoluteFilepath, + StringRef scheme) { + llvm::Expected uri = + uriFromAbsolutePath(absoluteFilepath, scheme); if (!uri) return uri.takeError(); return fromURI(*uri); } +StringRef URIForFile::scheme() const { return uri().split(':').first; } + +void URIForFile::registerSupportedScheme(StringRef scheme) { + getSupportedSchemes().insert(scheme); +} + bool mlir::lsp::fromJSON(const llvm::json::Value &value, URIForFile &result, llvm::json::Path path) { if (Optional str = value.getAsString()) { diff --git a/mlir/lib/Tools/mlir-lsp-server/CMakeLists.txt b/mlir/lib/Tools/mlir-lsp-server/CMakeLists.txt --- a/mlir/lib/Tools/mlir-lsp-server/CMakeLists.txt +++ b/mlir/lib/Tools/mlir-lsp-server/CMakeLists.txt @@ -2,11 +2,13 @@ LSPServer.cpp MLIRServer.cpp MlirLspServerMain.cpp + Protocol.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-lsp-server LINK_LIBS PUBLIC + MLIRBytecodeWriter MLIRIR MLIRLspServerSupportLib MLIRParser diff --git a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp --- a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp @@ -8,9 +8,9 @@ #include "LSPServer.h" #include "../lsp-server-support/Logging.h" -#include "../lsp-server-support/Protocol.h" #include
"../lsp-server-support/Transport.h" #include "MLIRServer.h" +#include "Protocol.h" #include "llvm/ADT/FunctionExtras.h" #include "llvm/ADT/StringMap.h" @@ -74,6 +74,14 @@ void onCodeAction(const CodeActionParams ¶ms, Callback reply); + //===--------------------------------------------------------------------===// + // Bytecode + + void onConvertFromBytecode(const MLIRConvertBytecodeParams ¶ms, + Callback reply); + void onConvertToBytecode(const MLIRConvertBytecodeParams ¶ms, + Callback reply); + //===--------------------------------------------------------------------===// // Fields //===--------------------------------------------------------------------===// @@ -254,6 +262,20 @@ reply(std::move(actions)); } +//===----------------------------------------------------------------------===// +// Bytecode + +void LSPServer::onConvertFromBytecode( + const MLIRConvertBytecodeParams ¶ms, + Callback reply) { + reply(server.convertFromBytecode(params.uri)); +} + +void LSPServer::onConvertToBytecode(const MLIRConvertBytecodeParams ¶ms, + Callback reply) { + reply(server.convertToBytecode(params.uri)); +} + //===----------------------------------------------------------------------===// // Entry point //===----------------------------------------------------------------------===// @@ -298,6 +320,12 @@ messageHandler.method("textDocument/codeAction", &lspServer, &LSPServer::onCodeAction); + // Bytecode <-> text conversion requests + messageHandler.method("mlir/convertFromBytecode", &lspServer, + &LSPServer::onConvertFromBytecode); + messageHandler.method("mlir/convertToBytecode", &lspServer, + &LSPServer::onConvertToBytecode); + // Diagnostics lspServer.publishDiagnostics = messageHandler.outgoingNotification( diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h --- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h +++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h @@ -10,6 +10,7 @@ #define LIB_MLIR_TOOLS_MLIRLSPSERVER_SERVER_H_ #include "mlir/Support/LLVM.h"
+#include "llvm/Support/Error.h" #include namespace mlir { @@ -23,6 +24,7 @@ struct DocumentSymbol; struct Hover; struct Location; +struct MLIRConvertBytecodeResult; struct Position; struct Range; class URIForFile; @@ -73,6 +75,14 @@ const CodeActionContext &context, std::vector &actions); + /// Convert the given bytecode file (read from disk) to the textual format. + llvm::Expected + convertFromBytecode(const URIForFile &uri); + + /// Convert the given (open) textual file to base64-encoded bytecode. + llvm::Expected + convertToBytecode(const URIForFile &uri); + private: struct Impl; diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp --- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp @@ -8,21 +8,26 @@ #include "MLIRServer.h" #include "../lsp-server-support/Logging.h" -#include "../lsp-server-support/Protocol.h" #include "../lsp-server-support/SourceMgrUtils.h" +#include "Protocol.h" #include "mlir/AsmParser/AsmParser.h" #include "mlir/AsmParser/AsmParserState.h" #include "mlir/AsmParser/CodeComplete.h" +#include "mlir/Bytecode/BytecodeWriter.h" #include "mlir/IR/FunctionInterfaces.h" #include "mlir/IR/Operation.h" +#include "mlir/Parser/Parser.h" +#include "llvm/Support/Base64.h" #include "llvm/Support/SourceMgr.h" using namespace mlir; /// Returns a language server location from the given MLIR file location. -static Optional getLocationFromLoc(FileLineColLoc loc) { +/// `uriScheme` is the scheme to use when building new uris. +static Optional getLocationFromLoc(StringRef uriScheme, + FileLineColLoc loc) { llvm::Expected sourceURI = - lsp::URIForFile::fromFile(loc.getFilename()); + lsp::URIForFile::fromFile(loc.getFilename(), uriScheme); if (!sourceURI) { lsp::Logger::error("Failed to create URI for file `{0}`: {1}", loc.getFilename(), @@ -37,18 +42,19 @@ } /// Returns a language server location from the given MLIR location, or None if -/// one couldn't be created.
`uri` is an optional additional filter that, when -/// present, is used to filter sub locations that do not share the same uri. +/// one couldn't be created. `uriScheme` is the scheme to use when building new +/// uris. `uri` is an optional additional filter that, when present, is used to +/// filter sub locations that do not share the same uri. static Optional getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc, - const lsp::URIForFile *uri = nullptr) { + StringRef uriScheme, const lsp::URIForFile *uri = nullptr) { Optional location; loc->walk([&](Location nestedLoc) { FileLineColLoc fileLoc = nestedLoc.dyn_cast(); if (!fileLoc) return WalkResult::advance(); - Optional sourceLoc = getLocationFromLoc(fileLoc); + Optional sourceLoc = getLocationFromLoc(uriScheme, fileLoc); if (sourceLoc && (!uri || sourceLoc->uri == *uri)) { location = *sourceLoc; SMLoc loc = sourceMgr.FindLocForLineAndColumn( @@ -80,7 +86,8 @@ if (!fileLoc || !visitedLocs.insert(nestedLoc)) return WalkResult::advance(); - Optional sourceLoc = getLocationFromLoc(fileLoc); + Optional sourceLoc = + getLocationFromLoc(uri.scheme(), fileLoc); if (sourceLoc && sourceLoc->uri != uri) locations.push_back(*sourceLoc); return WalkResult::advance(); @@ -191,8 +198,9 @@ // Try to grab a file location for this diagnostic. // TODO: For simplicity, we just grab the first one. It may be likely that we // will need a more interesting heuristic here.'
+ StringRef uriScheme = uri.scheme(); Optional lspLocation = - getLocationFromLoc(sourceMgr, diag.getLocation(), &uri); + getLocationFromLoc(sourceMgr, diag.getLocation(), uriScheme, &uri); if (lspLocation) lspDiag.range = lspLocation->range; @@ -217,7 +225,7 @@ for (Diagnostic ¬e : diag.getNotes()) { lsp::Location noteLoc; if (Optional loc = - getLocationFromLoc(sourceMgr, note.getLocation())) + getLocationFromLoc(sourceMgr, note.getLocation(), uriScheme)) noteLoc = *loc; else noteLoc.uri = uri; @@ -294,6 +302,12 @@ StringRef message, std::vector &edits); + //===--------------------------------------------------------------------===// + // Bytecode + //===--------------------------------------------------------------------===// + + llvm::Expected convertToBytecode(); + //===--------------------------------------------------------------------===// // Fields //===--------------------------------------------------------------------===// @@ -840,6 +854,35 @@ edits.emplace_back(std::move(edit)); } +//===----------------------------------------------------------------------===// +// MLIRDocument: Bytecode +//===----------------------------------------------------------------------===// + +llvm::Expected +MLIRDocument::convertToBytecode() { + // TODO: We currently require a single top-level operation, but this could + // conceptually be relaxed.
+ if (!llvm::hasSingleElement(parsedIR)) { + if (parsedIR.empty()) { + return llvm::make_error( + "expected a single and valid top-level operation, please ensure " + "there are no errors", + lsp::ErrorCode::RequestFailed); + } + return llvm::make_error( + "expected a single top-level operation", lsp::ErrorCode::RequestFailed); + } + + lsp::MLIRConvertBytecodeResult result; + { + std::string rawBytecodeBuffer; + llvm::raw_string_ostream os(rawBytecodeBuffer); + writeBytecodeToFile(&parsedIR.front(), os); + result.output = llvm::encodeBase64(rawBytecodeBuffer); + } + return result; +} + //===----------------------------------------------------------------------===// // MLIRTextFileChunk //===----------------------------------------------------------------------===// @@ -900,6 +943,7 @@ void getCodeActions(const lsp::URIForFile &uri, const lsp::Range &pos, const lsp::CodeActionContext &context, std::vector &actions); + llvm::Expected convertToBytecode(); private: /// Find the MLIR document that contains the given position, and update the @@ -1115,6 +1159,17 @@ } } +llvm::Expected +MLIRTextFile::convertToBytecode() { + // Bail out if there is more than one chunk, bytecode wants a single module. + if (chunks.size() != 1) { + return llvm::make_error( + "unexpected split file, please remove all `// -----`", + lsp::ErrorCode::RequestFailed); + } + return chunks.front()->document.convertToBytecode(); +} + MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) { if (chunks.size() == 1) return *chunks.front(); @@ -1217,3 +1272,57 @@ if (fileIt != impl->files.end()) fileIt->second->getCodeActions(uri, pos, context, actions); } + +llvm::Expected +lsp::MLIRServer::convertFromBytecode(const URIForFile &uri) { + MLIRContext tempContext(impl->registry); + tempContext.allowUnregisteredDialects(); + + // Collect any errors during parsing.
+ std::string errorMsg; + ScopedDiagnosticHandler diagHandler( + &tempContext, + [&](mlir::Diagnostic &diag) { errorMsg += diag.str() + "\n"; }); + + // Try to parse the given source file. + // TODO: This won't preserve external resources or the producer, we should try + // to fix this. + Block parsedBlock; + if (failed(parseSourceFile(uri.file(), &parsedBlock, &tempContext))) { + return llvm::make_error( + "failed to parse bytecode source file: " + errorMsg, + lsp::ErrorCode::RequestFailed); + } + + // TODO: We currently expect a single top-level operation, but this could + // conceptually be relaxed. + if (!llvm::hasSingleElement(parsedBlock)) { + return llvm::make_error( + "expected bytecode to contain a single top-level operation", + lsp::ErrorCode::RequestFailed); + } + + // Print the top-level operation to the output buffer. + lsp::MLIRConvertBytecodeResult result; + { + // Extract the top-level op so that aliases get printed. + // FIXME: We should be able to enable aliases without having to do this! + OwningOpRef topOp = &parsedBlock.front(); + (*topOp)->remove(); + + llvm::raw_string_ostream os(result.output); + (*topOp)->print(os, OpPrintingFlags().enableDebugInfo().assumeVerified()); + } + return std::move(result); +} + +llvm::Expected +lsp::MLIRServer::convertToBytecode(const URIForFile &uri) { + auto fileIt = impl->files.find(uri.file()); + if (fileIt == impl->files.end()) { + return llvm::make_error( + "language server does not contain an entry for this source file", + lsp::ErrorCode::RequestFailed); + } + return fileIt->second->convertToBytecode(); +} diff --git a/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp b/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp --- a/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp @@ -68,6 +68,9 @@ llvm::sys::ChangeStdinToBinary(); JSONTransport transport(stdin, llvm::outs(), inputStyle, prettyPrint); + // Register the additionally supported URI schemes for the MLIR
server. + URIForFile::registerSupportedScheme("mlir.bytecode-mlir"); + // Configure the servers and start the main language server. MLIRServer server(registry); return runMlirLSPServer(server, transport); diff --git a/mlir/lib/Tools/mlir-lsp-server/Protocol.h b/mlir/lib/Tools/mlir-lsp-server/Protocol.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Tools/mlir-lsp-server/Protocol.h @@ -0,0 +1,59 @@ +//===--- Protocol.h - Language Server Protocol Implementation ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains structs for LSP commands that are specific to the MLIR +// server. +// +// Each struct has a toJSON and/or fromJSON function that converts between +// the struct and a JSON representation. (See JSON.h) +// +// Some structs also have operator<< serialization. This is for debugging and +// tests, and is not generally machine-readable. +// +//===----------------------------------------------------------------------===// + +#ifndef LIB_MLIR_TOOLS_MLIRLSPSERVER_PROTOCOL_H_ +#define LIB_MLIR_TOOLS_MLIRLSPSERVER_PROTOCOL_H_ + +#include "../lsp-server-support/Protocol.h" + +namespace mlir { +namespace lsp { +//===----------------------------------------------------------------------===// +// MLIRConvertBytecodeParams +//===----------------------------------------------------------------------===// + +/// This class represents the parameters used when converting between MLIR's +/// bytecode and textual format. +struct MLIRConvertBytecodeParams { + /// The input file containing the bytecode or textual format. + URIForFile uri; +}; + +/// Add support for JSON deserialization.
+bool fromJSON(const llvm::json::Value &value, MLIRConvertBytecodeParams &result, + llvm::json::Path path); + +//===----------------------------------------------------------------------===// +// MLIRConvertBytecodeResult +//===----------------------------------------------------------------------===// + +/// This class represents the result of converting between MLIR's bytecode and +/// textual format. +struct MLIRConvertBytecodeResult { + /// The output of the conversion (MLIR text, or base64-encoded bytecode). + std::string output; +}; + +/// Add support for JSON serialization. +llvm::json::Value toJSON(const MLIRConvertBytecodeResult &value); + +} // namespace lsp +} // namespace mlir + +#endif // LIB_MLIR_TOOLS_MLIRLSPSERVER_PROTOCOL_H_ diff --git a/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp b/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp @@ -0,0 +1,44 @@ +//===--- Protocol.cpp - Language Server Protocol Implementation -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the serialization code for the MLIR specific LSP structs.
+// +//===----------------------------------------------------------------------===// + +#include "Protocol.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/Format.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/JSON.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace mlir::lsp; + +//===----------------------------------------------------------------------===// +// MLIRConvertBytecodeParams +//===----------------------------------------------------------------------===// + +bool mlir::lsp::fromJSON(const llvm::json::Value &value, + MLIRConvertBytecodeParams &result, + llvm::json::Path path) { + llvm::json::ObjectMapper o(value, path); + return o && o.map("uri", result.uri); +} + +//===----------------------------------------------------------------------===// +// MLIRConvertBytecodeResult +//===----------------------------------------------------------------------===// + +llvm::json::Value mlir::lsp::toJSON(const MLIRConvertBytecodeResult &value) { + return llvm::json::Object{{"output", value.output}}; +} diff --git a/mlir/utils/vscode/package-lock.json b/mlir/utils/vscode/package-lock.json --- a/mlir/utils/vscode/package-lock.json +++ b/mlir/utils/vscode/package-lock.json @@ -1,13 +1,14 @@ { "name": "vscode-mlir", - "version": "0.0.9", + "version": "0.0.10", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "vscode-mlir", - "version": "0.0.9", + "version": "0.0.10", "dependencies": { + "base64-js": "^1.5.1", "chokidar": "3.5.2", "vscode-languageclient": "^8.0.2-next.5" }, @@ -137,7 +138,6 @@ "version": "1.5.1", "resolved": "https://registry.npmjs.org/base64-js/-/base64-js-1.5.1.tgz", "integrity": "sha512-AKpaYlHn8t4SVbOHCy+b5+KKgvR4vrsD8vbvrbiQJps7fKDTkjkDry6ji0rUJjC0kzbNePLwzxq8iypo41qeWA==", - "dev": true, "funding": [ { "type": "github", @@
-2037,8 +2037,7 @@ "base64-js": { "version": "1.5.1", "resolved": "https://registry.npmjs.org/base64-js/-/base64-js-1.5.1.tgz", - "integrity": "sha512-AKpaYlHn8t4SVbOHCy+b5+KKgvR4vrsD8vbvrbiQJps7fKDTkjkDry6ji0rUJjC0kzbNePLwzxq8iypo41qeWA==", - "dev": true + "integrity": "sha512-AKpaYlHn8t4SVbOHCy+b5+KKgvR4vrsD8vbvrbiQJps7fKDTkjkDry6ji0rUJjC0kzbNePLwzxq8iypo41qeWA==" }, "big-integer": { "version": "1.6.48", diff --git a/mlir/utils/vscode/package.json b/mlir/utils/vscode/package.json --- a/mlir/utils/vscode/package.json +++ b/mlir/utils/vscode/package.json @@ -21,6 +21,8 @@ "tablegen" ], "activationEvents": [ + "onFileSystem:mlir.bytecode-mlir", + "onCustomEditor:mlir.bytecode", "onLanguage:mlir", "onLanguage:pdll", "onLanguage:tablegen" @@ -35,6 +37,7 @@ "git-clang-format": "git-clang-format" }, "dependencies": { + "base64-js": "^1.5.1", "chokidar": "3.5.2", "vscode-languageclient": "^8.0.2-next.5" }, @@ -52,6 +55,18 @@ "url": "https://github.com/llvm/vscode-mlir.git" }, "contributes": { + "customEditors": [ + { + "viewType": "mlir.bytecode", + "displayName": "MLIR Bytecode", + "priority": "default", + "selector": [ + { + "filenamePattern": "*.mlirbc" + } + ] + } + ], "languages": [ { "id": "mlir", @@ -60,7 +75,8 @@ "mlir" ], "extensions": [ - ".mlir" + ".mlir", + ".mlirbc" ], "configuration": "./language-configuration.json" }, diff --git a/mlir/utils/vscode/src/MLIR/bytecodeProvider.ts b/mlir/utils/vscode/src/MLIR/bytecodeProvider.ts new file mode 100644 --- /dev/null +++ b/mlir/utils/vscode/src/MLIR/bytecodeProvider.ts @@ -0,0 +1,170 @@ +import * as base64 from 'base64-js' +import * as vscode from 'vscode' + +import {MLIRContext} from '../mlirContext'; + +/** + * The parameters to the mlir/convert(To|From)Bytecode commands. These + * parameters are: + * - `uri`: The URI of the file to convert.
+ */
+type ConvertBytecodeParams = Partial<{uri : string}>;
+
+/**
+ * The output of the mlir/convert(To|From)Bytecode commands:
+ *  - `output`: The output buffer of the command, e.g. a .mlir or bytecode
+ *              buffer.
+ */
+type ConvertBytecodeResult = Partial<{output : string}>;
+
+/**
+ * A custom filesystem that is used to convert MLIR bytecode files to text for
+ * use in the editor, but still use bytecode on disk.
+ */
+class BytecodeFS implements vscode.FileSystemProvider {
+  mlirContext: MLIRContext;
+
+  constructor(mlirContext: MLIRContext) { this.mlirContext = mlirContext; }
+
+  /*
+   * Forward to the `file` filesystem (our own scheme would recurse here) for
+   * the methods that don't need the bytecode <-> text translation.
+   */
+  readDirectory(uri: vscode.Uri): Thenable<[ string, vscode.FileType ][]> {
+    return vscode.workspace.fs.readDirectory(uri.with({scheme : "file"}));
+  }
+  delete(uri: vscode.Uri): Thenable<void> {
+    return vscode.workspace.fs.delete(uri.with({scheme : "file"}));
+  }
+  stat(uri: vscode.Uri): Thenable<vscode.FileStat> {
+    return vscode.workspace.fs.stat(uri.with({scheme : "file"}));
+  }
+  rename(oldUri: vscode.Uri, newUri: vscode.Uri,
+         options: {overwrite: boolean}): Thenable<void> {
+    return vscode.workspace.fs.rename(oldUri.with({scheme : "file"}),
+                                      newUri.with({scheme : "file"}), options);
+  }
+  createDirectory(uri: vscode.Uri): Thenable<void> {
+    return vscode.workspace.fs.createDirectory(uri.with({scheme : "file"}));
+  }
+  watch(_uri: vscode.Uri, _options: {
+    readonly recursive: boolean; readonly excludes : readonly string[]
+  }): vscode.Disposable {
+    return new vscode.Disposable(() => {});
+  }
+
+  private _emitter = new vscode.EventEmitter<vscode.FileChangeEvent[]>();
+  readonly onDidChangeFile: vscode.Event<vscode.FileChangeEvent[]> =
+      this._emitter.event;
+
+  /*
+   * Read in a bytecode file, converting it to text before returning it to the
+   * caller.
+   */
+  async readFile(uri: vscode.Uri): Promise<Uint8Array> {
+    // Try to start a language client for this file so that we can parse
+    // it.
+    const client =
+        await this.mlirContext.getOrActivateLanguageClient(uri, 'mlir');
+    if (!client) {
+      throw new Error(
+          'Failed to activate mlir language server to read bytecode');
+    }
+    // Ask the client to do the conversion.
+    let convertParams: ConvertBytecodeParams = {uri : uri.toString()};
+    try {
+      const result: ConvertBytecodeResult =
+          await client.sendRequest('mlir/convertFromBytecode', convertParams);
+      return new TextEncoder().encode(result.output);
+    } catch (e) {
+      vscode.window.showErrorMessage(e.message);
+      throw new Error(`Failed to read bytecode file: ${e}`);
+    }
+  }
+
+  /*
+   * Save MLIR text as bytecode (converted from the server's copy of the file).
+   */
+  async writeFile(uri: vscode.Uri, content: Uint8Array,
+                  _options: {create: boolean, overwrite: boolean}) {
+    // Get (or start) the language client managing this file.
+    const client =
+        await this.mlirContext.getOrActivateLanguageClient(uri, 'mlir');
+    if (!client) {
+      throw new Error(
+          'Failed to activate mlir language server to write bytecode');
+    }
+    // Ask the client to do the conversion.
+    let convertParams: ConvertBytecodeParams = {
+      uri : uri.toString(),
+    };
+    const result: ConvertBytecodeResult =
+        await client.sendRequest('mlir/convertToBytecode', convertParams);
+    await vscode.workspace.fs.writeFile(uri.with({scheme : "file"}),
+                                        base64.toByteArray(result.output));
+  }
+}
+
+/**
+ * A custom bytecode document for use by the custom editor provider below.
+ */
+class BytecodeDocument implements vscode.CustomDocument {
+  readonly uri: vscode.Uri;
+
+  constructor(uri: vscode.Uri) { this.uri = uri; }
+  dispose(): void {}
+}
+
+/**
+ * A custom editor provider for MLIR bytecode that allows for non-binary
+ * interpretation.
+ */ +class BytecodeEditorProvider implements + vscode.CustomReadonlyEditorProvider { + public async openCustomDocument(uri: vscode.Uri, _openContext: any, + _token: vscode.CancellationToken): + Promise { + return new BytecodeDocument(uri); + } + + public async resolveCustomEditor(document: BytecodeDocument, + _webviewPanel: vscode.WebviewPanel, + _token: vscode.CancellationToken): + Promise { + // Ask the user for the desired view type. + const editType = await vscode.window.showQuickPick( + [ {label : '.mlir', description : "Edit as a .mlir text file"} ], + {title : 'Select an editor for the bytecode.'}, + ); + + // If the user dismissed the picker, close the custom editor and bail. + if (!editType) { + await vscode.commands.executeCommand( + 'workbench.action.closeActiveEditor'); + return; + } + + // TODO: We should also provide a non-`.mlir` way of viewing the + // bytecode, which should also ideally have some support for invalid + // bytecode files. + + // Close the (empty) custom editor; the text view is opened below instead. + await vscode.commands.executeCommand('workbench.action.closeActiveEditor'); + + // Display the file using a .mlir format. + await vscode.window.showTextDocument( + document.uri.with({scheme : "mlir.bytecode-mlir"}), + {preview : true, preserveFocus : false}); + } +} + +/** + * Register the necessary providers for supporting MLIR bytecode.
+ */ +export function registerMLIRBytecodeExtensions(context: vscode.ExtensionContext, + mlirContext: MLIRContext) { + context.subscriptions.push(vscode.workspace.registerFileSystemProvider( + "mlir.bytecode-mlir", new BytecodeFS(mlirContext))); + context.subscriptions.push(vscode.window.registerCustomEditorProvider( + 'mlir.bytecode', new BytecodeEditorProvider())); +} diff --git a/mlir/utils/vscode/src/MLIR/mlir.ts b/mlir/utils/vscode/src/MLIR/mlir.ts new file mode 100644 --- /dev/null +++ b/mlir/utils/vscode/src/MLIR/mlir.ts @@ -0,0 +1,12 @@ +import * as vscode from 'vscode'; + +import {MLIRContext} from '../mlirContext'; +import {registerMLIRBytecodeExtensions} from './bytecodeProvider'; + +/** + * Register the necessary extensions for supporting MLIR. + */ +export function registerMLIRExtensions(context: vscode.ExtensionContext, + mlirContext: MLIRContext) { + registerMLIRBytecodeExtensions(context, mlirContext); +} diff --git a/mlir/utils/vscode/src/PDLL/commands/viewOutput.ts b/mlir/utils/vscode/src/PDLL/commands/viewOutput.ts --- a/mlir/utils/vscode/src/PDLL/commands/viewOutput.ts +++ b/mlir/utils/vscode/src/PDLL/commands/viewOutput.ts @@ -28,9 +28,8 @@ return; // Check to see if a language client is active for this document. - const workspaceFolder = - vscode.workspace.getWorkspaceFolder(editor.document.uri); - const pdllClient = this.context.getLanguageClient(workspaceFolder, "pdll"); + const pdllClient = + this.context.getLanguageClient(editor.document.uri, "pdll"); if (!pdllClient) { return; } diff --git a/mlir/utils/vscode/src/PDLL/pdll.ts b/mlir/utils/vscode/src/PDLL/pdll.ts --- a/mlir/utils/vscode/src/PDLL/pdll.ts +++ b/mlir/utils/vscode/src/PDLL/pdll.ts @@ -4,9 +4,9 @@ import {ViewPDLLCommand} from './commands/viewOutput'; /** - * Register the necessary context and commands for PDLL. + * Register the necessary extensions for supporting PDLL.
*/ -export function registerPDLLCommands(context: vscode.ExtensionContext, - mlirContext: MLIRContext) { +export function registerPDLLExtensions(context: vscode.ExtensionContext, + mlirContext: MLIRContext) { context.subscriptions.push(new ViewPDLLCommand(mlirContext)); } diff --git a/mlir/utils/vscode/src/extension.ts b/mlir/utils/vscode/src/extension.ts --- a/mlir/utils/vscode/src/extension.ts +++ b/mlir/utils/vscode/src/extension.ts @@ -1,7 +1,8 @@ import * as vscode from 'vscode'; +import {registerMLIRExtensions} from './MLIR/mlir'; import {MLIRContext} from './mlirContext'; -import {registerPDLLCommands} from './PDLL/pdll'; +import {registerPDLLExtensions} from './PDLL/pdll'; /** * This method is called when the extension is activated. The extension is @@ -21,7 +22,8 @@ mlirContext.dispose(); await mlirContext.activate(outputChannel); })); - registerPDLLCommands(context, mlirContext); + registerMLIRExtensions(context, mlirContext); + registerPDLLExtensions(context, mlirContext); mlirContext.activate(outputChannel); } diff --git a/mlir/utils/vscode/src/mlirContext.ts b/mlir/utils/vscode/src/mlirContext.ts --- a/mlir/utils/vscode/src/mlirContext.ts +++ b/mlir/utils/vscode/src/mlirContext.ts @@ -25,49 +25,19 @@ export class MLIRContext implements vscode.Disposable { subscriptions: vscode.Disposable[] = []; workspaceFolders: Map = new Map(); + outputChannel: vscode.OutputChannel; /** * Activate the MLIR context, and start the language clients. */ async activate(outputChannel: vscode.OutputChannel) { + this.outputChannel = outputChannel; + // This lambda is used to lazily start language clients for the given // document. It removes the need to pro-actively start language clients for // every folder within the workspace and every language type we provide.
const startClientOnOpenDocument = async (document: vscode.TextDocument) => { - if (document.uri.scheme !== 'file') { - return; - } - let serverSettingName: string; - if (document.languageId === 'mlir') { - serverSettingName = 'server_path'; - } else if (document.languageId === 'pdll') { - serverSettingName = 'pdll_server_path'; - } else if (document.languageId === 'tablegen') { - serverSettingName = 'tablegen_server_path'; - } else { - return; - } - - // Resolve the workspace folder if this document is in one. We use the - // workspace folder when determining if a server needs to be started. - const uri = document.uri; - let workspaceFolder = vscode.workspace.getWorkspaceFolder(uri); - let workspaceFolderStr = - workspaceFolder ? workspaceFolder.uri.toString() : ""; - - // Get or create a client context for this folder. - let folderContext = this.workspaceFolders.get(workspaceFolderStr); - if (!folderContext) { - folderContext = new WorkspaceFolderContext(); - this.workspaceFolders.set(workspaceFolderStr, folderContext); - } - // Start the client for this language if necessary. - if (!folderContext.clients.has(document.languageId)) { - let client = await this.activateWorkspaceFolder( - workspaceFolder, serverSettingName, document.languageId, - outputChannel); - folderContext.clients.set(document.languageId, client); - } + await this.getOrActivateLanguageClient(document.uri, document.languageId); }; // Process any existing documents. for (const textDoc of vscode.workspace.textDocuments) { @@ -89,6 +59,50 @@ })); } + /** + * Return the client for the given uri and language, starting it if needed.
+ */ + async getOrActivateLanguageClient(uri: vscode.Uri, languageId: string): + Promise { + let serverSettingName: string; + if (languageId === 'mlir') { + serverSettingName = 'server_path'; + } else if (languageId === 'pdll') { + serverSettingName = 'pdll_server_path'; + } else if (languageId === 'tablegen') { + serverSettingName = 'tablegen_server_path'; + } else { + return null; + } + + // Only `file` and bytecode-as-text (`mlir.bytecode-mlir`) URIs are served. + let validSchemes = [ 'file', 'mlir.bytecode-mlir' ]; + if (!validSchemes.includes(uri.scheme)) { + return null; + } + + // Resolve the workspace folder if this document is in one. We use the + // workspace folder when determining if a server needs to be started. + let workspaceFolder = vscode.workspace.getWorkspaceFolder(uri); + let workspaceFolderStr = + workspaceFolder ? workspaceFolder.uri.toString() : ""; + + // Get or create a client context for this folder. + let folderContext = this.workspaceFolders.get(workspaceFolderStr); + if (!folderContext) { + folderContext = new WorkspaceFolderContext(); + this.workspaceFolders.set(workspaceFolderStr, folderContext); + } + // Start the client if necessary. NOTE(review): concurrent calls may race. + let client = folderContext.clients.get(languageId); + if (!client) { + client = await this.activateWorkspaceFolder( + workspaceFolder, serverSettingName, languageId, this.outputChannel); + folderContext.clients.set(languageId, client); + } + return client; + } + /** * Prepare a compilation database option for a server. */ @@ -263,7 +277,7 @@ // Configure the client options. const clientOptions: vscodelc.LanguageClientOptions = { documentSelector : [ - {scheme : 'file', language : languageName, pattern : selectorPattern} + {language : languageName, pattern : selectorPattern}, ], synchronize : { // Notify the server about file changes to language files contained in @@ -353,11 +367,12 @@ } /** - * Return the language client for the given language and workspace folder, or - * null if no client is active.
+ * Return the language client for the given language and uri, or null if no + * client is active. */ - getLanguageClient(workspaceFolder: vscode.WorkspaceFolder, + getLanguageClient(uri: vscode.Uri, languageName: string): vscodelc.LanguageClient { + let workspaceFolder = vscode.workspace.getWorkspaceFolder(uri); let workspaceFolderStr = workspaceFolder ? workspaceFolder.uri.toString() : ""; let folderContext = this.workspaceFolders.get(workspaceFolderStr);