diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/CMakeLists.txt b/mlir/lib/Tools/mlir-pdll-lsp-server/CMakeLists.txt
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/CMakeLists.txt
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/CMakeLists.txt
@@ -2,12 +2,14 @@
   CompilationDatabase.cpp
   LSPServer.cpp
   PDLLServer.cpp
+  Protocol.cpp
   MlirPdllLspServerMain.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-pdll-lsp-server
 
   LINK_LIBS PUBLIC
+  MLIRPDLLCodeGen
   MLIRPDLLParser
   MLIRLspServerSupportLib
   )
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/CompilationDatabase.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/CompilationDatabase.cpp
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/CompilationDatabase.cpp
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/CompilationDatabase.cpp
@@ -8,7 +8,7 @@
 
 #include "CompilationDatabase.h"
 #include "../lsp-server-support/Logging.h"
-#include "../lsp-server-support/Protocol.h"
+#include "Protocol.h"
 #include "mlir/Support/FileUtilities.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/YAMLTraits.h"
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp
@@ -9,9 +9,9 @@
 #include "LSPServer.h"
 
 #include "../lsp-server-support/Logging.h"
-#include "../lsp-server-support/Protocol.h"
 #include "../lsp-server-support/Transport.h"
 #include "PDLLServer.h"
+#include "Protocol.h"
 #include "llvm/ADT/FunctionExtras.h"
 #include "llvm/ADT/StringMap.h"
 
@@ -82,6 +82,12 @@
   void onSignatureHelp(const TextDocumentPositionParams &params,
                        Callback<SignatureHelp> reply);
 
+  //===--------------------------------------------------------------------===//
+  // PDLL View Output
+
+  void onPDLLViewOutput(const PDLLViewOutputParams &params,
+                        Callback<Optional<PDLLViewOutputResult>> reply);
+
   //===--------------------------------------------------------------------===//
   // Fields
   //===--------------------------------------------------------------------===//
@@ -248,6 +254,15 @@
   reply(server.getSignatureHelp(params.textDocument.uri, params.position));
 }
 
+//===----------------------------------------------------------------------===//
+// PDLL ViewOutput
+
+void LSPServer::onPDLLViewOutput(
+    const PDLLViewOutputParams &params,
+    Callback<Optional<PDLLViewOutputResult>> reply) {
+  reply(server.getPDLLViewOutput(params.uri, params.kind));
+}
+
 //===----------------------------------------------------------------------===//
 // Entry Point
 //===----------------------------------------------------------------------===//
@@ -296,6 +311,10 @@
   messageHandler.method("textDocument/signatureHelp", &lspServer,
                         &LSPServer::onSignatureHelp);
 
+  // PDLL ViewOutput
+  messageHandler.method("pdll/viewOutput", &lspServer,
+                        &LSPServer::onPDLLViewOutput);
+
   // Diagnostics
   lspServer.publishDiagnostics =
       messageHandler.outgoingNotification<PublishDiagnosticsParams>(
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h
@@ -18,6 +18,8 @@
 namespace lsp {
 struct Diagnostic;
 class CompilationDatabase;
+struct PDLLViewOutputResult;
+enum class PDLLViewOutputKind;
 struct CompletionList;
 struct DocumentLink;
 struct DocumentSymbol;
@@ -88,6 +90,11 @@
   SignatureHelp getSignatureHelp(const URIForFile &uri,
                                  const Position &helpPos);
 
+  /// Get the output of the given PDLL file, or None if there is no valid
+  /// output.
+  Optional<PDLLViewOutputResult> getPDLLViewOutput(const URIForFile &uri,
+                                                   PDLLViewOutputKind kind);
+
 private:
   struct Impl;
   std::unique_ptr<Impl> impl;
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
@@ -9,11 +9,14 @@
 #include "PDLLServer.h"
 
 #include "../lsp-server-support/Logging.h"
-#include "../lsp-server-support/Protocol.h"
 #include "CompilationDatabase.h"
+#include "Protocol.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/Tools/PDLL/AST/Context.h"
 #include "mlir/Tools/PDLL/AST/Nodes.h"
 #include "mlir/Tools/PDLL/AST/Types.h"
+#include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
+#include "mlir/Tools/PDLL/CodeGen/MLIRGen.h"
 #include "mlir/Tools/PDLL/ODS/Constraint.h"
 #include "mlir/Tools/PDLL/ODS/Context.h"
 #include "mlir/Tools/PDLL/ODS/Dialect.h"
@@ -322,6 +325,12 @@
   lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri,
                                       const lsp::Position &helpPos);
 
+  //===--------------------------------------------------------------------===//
+  // PDLL ViewOutput
+  //===--------------------------------------------------------------------===//
+
+  void getPDLLViewOutput(raw_ostream &os, lsp::PDLLViewOutputKind kind);
+
   //===--------------------------------------------------------------------===//
   // Fields
   //===--------------------------------------------------------------------===//
@@ -1141,6 +1150,39 @@
   return signatureHelp;
 }
 
+//===----------------------------------------------------------------------===//
+// PDLL ViewOutput
+//===----------------------------------------------------------------------===//
+
+void PDLDocument::getPDLLViewOutput(raw_ostream &os,
+                                    lsp::PDLLViewOutputKind kind) {
+  if (failed(astModule))
+    return;
+  if (kind == lsp::PDLLViewOutputKind::AST) {
+    (*astModule)->print(os);
+    return;
+  }
+
+  // Generate the MLIR for the ast module. We also capture diagnostics here to
+  // show to the user, which may be useful if PDLL isn't capturing constraints
+  // expected by PDL.
+  MLIRContext mlirContext;
+  SourceMgrDiagnosticHandler diagHandler(sourceMgr, &mlirContext, os);
+  OwningOpRef<ModuleOp> pdlModule =
+      codegenPDLLToMLIR(&mlirContext, astContext, sourceMgr, **astModule);
+  if (!pdlModule)
+    return;
+  if (kind == lsp::PDLLViewOutputKind::MLIR) {
+    pdlModule->print(os, OpPrintingFlags().enableDebugInfo());
+    return;
+  }
+
+  // Otherwise, generate the output for C++.
+  assert(kind == lsp::PDLLViewOutputKind::CPP &&
+         "unexpected PDLLViewOutputKind");
+  codegenPDLLToCPP(**astModule, *pdlModule, os);
+}
+
 //===----------------------------------------------------------------------===//
 // PDLTextFileChunk
 //===----------------------------------------------------------------------===//
@@ -1204,6 +1246,7 @@
                                         lsp::Position completePos);
   lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri,
                                       lsp::Position helpPos);
+  lsp::PDLLViewOutputResult getPDLLViewOutput(lsp::PDLLViewOutputKind kind);
 
 private:
   /// Find the PDL document that contains the given position, and update the
@@ -1377,6 +1420,21 @@
   return getChunkFor(helpPos).document.getSignatureHelp(uri, helpPos);
 }
 
+lsp::PDLLViewOutputResult
+PDLTextFile::getPDLLViewOutput(lsp::PDLLViewOutputKind kind) {
+  lsp::PDLLViewOutputResult result;
+  {
+    llvm::raw_string_ostream outputOS(result.output);
+    llvm::interleave(
+        llvm::make_pointee_range(chunks),
+        [&](PDLTextFileChunk &chunk) {
+          chunk.document.getPDLLViewOutput(outputOS, kind);
+        },
+        [&] { outputOS << "\n// -----\n\n"; });
+  }
+  return result;
+}
+
 PDLTextFileChunk &PDLTextFile::getChunkFor(lsp::Position &pos) {
   if (chunks.size() == 1)
     return *chunks.front();
@@ -1494,3 +1552,12 @@
     return fileIt->second->getSignatureHelp(uri, helpPos);
   return SignatureHelp();
 }
+
+Optional<lsp::PDLLViewOutputResult>
+lsp::PDLLServer::getPDLLViewOutput(const URIForFile &uri,
+                                   PDLLViewOutputKind kind) {
+  auto fileIt = impl->files.find(uri.file());
+  if (fileIt != impl->files.end())
+    return fileIt->second->getPDLLViewOutput(kind);
+  return llvm::None;
+}
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.h b/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.h
@@ -0,0 +1,69 @@
+//===--- Protocol.h - Language Server Protocol Implementation ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains structs for LSP commands that are specific to the PDLL
+// server.
+//
+// Each struct has a toJSON and fromJSON function, that converts between
+// the struct and a JSON representation. (See JSON.h)
+//
+// Some structs also have operator<< serialization. This is for debugging and
+// tests, and is not generally machine-readable.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LIB_MLIR_TOOLS_MLIRPDLLLSPSERVER_PROTOCOL_H_
+#define LIB_MLIR_TOOLS_MLIRPDLLLSPSERVER_PROTOCOL_H_
+
+#include "../lsp-server-support/Protocol.h"
+
+namespace mlir {
+namespace lsp {
+//===----------------------------------------------------------------------===//
+// PDLLViewOutputParams
+//===----------------------------------------------------------------------===//
+
+/// The type of output to view from PDLL.
+enum class PDLLViewOutputKind {
+  AST,
+  MLIR,
+  CPP,
+};
+
+/// Represents the parameters used when viewing the output of a PDLL file.
+struct PDLLViewOutputParams {
+  /// The URI of the document to view the output of.
+  URIForFile uri;
+
+  /// The kind of output to generate.
+  PDLLViewOutputKind kind;
+};
+
+/// Add support for JSON serialization.
+bool fromJSON(const llvm::json::Value &value, PDLLViewOutputKind &result,
+              llvm::json::Path path);
+bool fromJSON(const llvm::json::Value &value, PDLLViewOutputParams &result,
+              llvm::json::Path path);
+
+//===----------------------------------------------------------------------===//
+// PDLLViewOutputResult
+//===----------------------------------------------------------------------===//
+
+/// Represents the result of viewing the output of a PDLL file.
+struct PDLLViewOutputResult {
+  /// The string representation of the output.
+  std::string output;
+};
+
+/// Add support for JSON serialization.
+llvm::json::Value toJSON(const PDLLViewOutputResult &value);
+
+} // namespace lsp
+} // namespace mlir
+
+#endif
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.cpp
@@ -0,0 +1,65 @@
+//===--- Protocol.cpp - Language Server Protocol Implementation -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the serialization code for the PDLL specific LSP structs.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Protocol.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/Format.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/JSON.h"
+#include "llvm/Support/Path.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::lsp;
+
+//===----------------------------------------------------------------------===//
+// PDLLViewOutputParams
+//===----------------------------------------------------------------------===//
+
+bool mlir::lsp::fromJSON(const llvm::json::Value &value,
+                         PDLLViewOutputKind &result, llvm::json::Path path) {
+  if (Optional<StringRef> str = value.getAsString()) {
+    if (*str == "ast") {
+      result = PDLLViewOutputKind::AST;
+      return true;
+    }
+    if (*str == "mlir") {
+      result = PDLLViewOutputKind::MLIR;
+      return true;
+    }
+    if (*str == "cpp") {
+      result = PDLLViewOutputKind::CPP;
+      return true;
+    }
+  }
+  // Per the llvm::json fromJSON convention, describe the failure through
+  // `path` so the caller can surface a meaningful error for a bad `kind`.
+  path.report("expected one of 'ast', 'mlir', or 'cpp'");
+  return false;
+}
+
+bool mlir::lsp::fromJSON(const llvm::json::Value &value,
+                         PDLLViewOutputParams &result, llvm::json::Path path) {
+  llvm::json::ObjectMapper o(value, path);
+  return o && o.map("uri", result.uri) && o.map("kind", result.kind);
+}
+
+//===----------------------------------------------------------------------===//
+// PDLLViewOutputResult
+//===----------------------------------------------------------------------===//
+
+llvm::json::Value mlir::lsp::toJSON(const PDLLViewOutputResult &value) {
+  return llvm::json::Object{{"output", value.output}};
+}
diff --git a/mlir/test/mlir-pdll-lsp-server/view-output.test b/mlir/test/mlir-pdll-lsp-server/view-output.test
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-pdll-lsp-server/view-output.test
@@ -0,0 +1,43 @@
+// RUN: mlir-pdll-lsp-server -lit-test < %s | FileCheck -strict-whitespace %s
+{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"pdll","capabilities":{},"trace":"off"}}
+// -----
+{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{
+  "uri":"test:///foo.pdll",
+  "languageId":"pdll",
+  "version":1,
+  "text":"Pattern TestPat => erase op;"
+}}}
+// -----
+{"jsonrpc":"2.0","id":1,"method":"pdll/viewOutput","params":{
+  "uri":"test:///foo.pdll",
+  "kind":"ast"
+}}
+// CHECK:  "id": 1
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": {
+// CHECK-NEXT:    "output": "-Module{{.*}}PatternDecl{{.*}}Name{{.*}}\n"
+// CHECK-NEXT:  }
+// -----
+{"jsonrpc":"2.0","id":2,"method":"pdll/viewOutput","params":{
+  "uri":"test:///foo.pdll",
+  "kind":"mlir"
+}}
+// CHECK:  "id": 2
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": {
+// CHECK-NEXT:    "output": "module {\n  pdl.pattern @TestPat {{.*}}\n"
+// CHECK-NEXT:  }
+// -----
+{"jsonrpc":"2.0","id":3,"method":"pdll/viewOutput","params":{
+  "uri":"test:///foo.pdll",
+  "kind":"cpp"
+}}
+// CHECK:  "id": 3
+// CHECK-NEXT:  "jsonrpc": "2.0",
+// CHECK-NEXT:  "result": {
+// CHECK-NEXT:    "output": "{{.*}}struct TestPat : ::mlir::PDLPatternModule{{.*}}\n"
+// CHECK-NEXT:  }
+// -----
+{"jsonrpc":"2.0","id":4,"method":"shutdown"}
+// -----
+{"jsonrpc":"2.0","method":"exit"}
diff --git a/mlir/utils/vscode/package.json b/mlir/utils/vscode/package.json
--- a/mlir/utils/vscode/package.json
+++ b/mlir/utils/vscode/package.json
@@ -175,7 +175,20 @@
       {
         "command": "mlir.restart",
         "title": "mlir: Restart language server"
+      },
+      {
+        "command": "mlir.viewPDLLOutput",
+        "title": "mlir-pdll: View PDLL output"
       }
-    ]
+    ],
+    "menus": {
+      "editor/context": [
+        {
+          "command": "mlir.viewPDLLOutput",
+          "group": "z_commands",
+          "when": "editorLangId == pdll"
+        }
+      ]
+    }
   }
 }
diff --git a/mlir/utils/vscode/src/PDLL/commands/viewOutput.ts b/mlir/utils/vscode/src/PDLL/commands/viewOutput.ts
new file mode 100644
--- /dev/null
+++ b/mlir/utils/vscode/src/PDLL/commands/viewOutput.ts
@@ -0,0 +1,72 @@
+import * as vscode from 'vscode';
+
+import {Command} from '../../command';
+import {MLIRContext} from '../../mlirContext';
+
+/**
+ * The parameters to the pdll/viewOutput command. These parameters are:
+ * - `uri`: The URI of the file to view.
+ * - `kind`: The kind of the output to generate.
+ */
+type ViewOutputParams = Partial<{uri : string, kind : string;}>;
+
+/**
+ * The output of the commands:
+ * - `output`: The output string of the command, e.g. a .mlir PDL string.
+ */
+type ViewOutputResult = Partial<{output : string}>;
+
+/**
+ * A command that displays the output of the current PDLL document.
+ */
+export class ViewPDLLCommand extends Command {
+  constructor(context: MLIRContext) { super('mlir.viewPDLLOutput', context); }
+
+  async execute() {
+    // The command can be invoked from the palette with no (or a non-PDLL)
+    // editor focused; there is nothing to show in that case.
+    const editor = vscode.window.activeTextEditor;
+    if (!editor || editor.document.languageId !== 'pdll')
+      return;
+
+    // Check to see if a language client is active for this document.
+    const workspaceFolder =
+        vscode.workspace.getWorkspaceFolder(editor.document.uri);
+    const pdllClient = this.context.getLanguageClient(workspaceFolder, "pdll");
+    if (!pdllClient) {
+      return;
+    }
+
+    // Ask the user for the desired output type.
+    const outputType =
+        await vscode.window.showQuickPick([ 'ast', 'mlir', 'cpp' ]);
+    if (!outputType) {
+      return;
+    }
+
+    // If we have the language client, ask it to try compiling the document.
+    const outputParams: ViewOutputParams = {
+      uri : editor.document.uri.toString(),
+      kind : outputType,
+    };
+    const result: ViewOutputResult|undefined =
+        await pdllClient.sendRequest('pdll/viewOutput', outputParams);
+    // `output` is optional in the result type, so guard against it being
+    // missing as well as empty.
+    if (!result || !result.output) {
+      return;
+    }
+
+    // Display the output in a new editor. `openTextDocument` only creates the
+    // (untitled) document; `showTextDocument` is what actually opens an editor.
+    let outputFileType = 'plaintext';
+    if (outputType === 'mlir') {
+      outputFileType = 'mlir';
+    } else if (outputType === 'cpp') {
+      outputFileType = 'cpp';
+    }
+    const outputDocument = await vscode.workspace.openTextDocument(
+        {language : outputFileType, content : result.output});
+    await vscode.window.showTextDocument(outputDocument);
+  }
+}
diff --git a/mlir/utils/vscode/src/PDLL/pdll.ts b/mlir/utils/vscode/src/PDLL/pdll.ts
new file mode 100644
--- /dev/null
+++ b/mlir/utils/vscode/src/PDLL/pdll.ts
@@ -0,0 +1,12 @@
+import * as vscode from 'vscode';
+
+import {MLIRContext} from '../mlirContext';
+import {ViewPDLLCommand} from './commands/viewOutput';
+
+/**
+ * Register the necessary context and commands for PDLL.
+ */
+export function registerPDLLCommands(context: vscode.ExtensionContext,
+                                     mlirContext: MLIRContext) {
+  context.subscriptions.push(new ViewPDLLCommand(mlirContext));
+}
diff --git a/mlir/utils/vscode/src/command.ts b/mlir/utils/vscode/src/command.ts
new file mode 100644
--- /dev/null
+++ b/mlir/utils/vscode/src/command.ts
@@ -0,0 +1,25 @@
+import * as vscode from 'vscode';
+import {MLIRContext} from './mlirContext';
+
+/**
+ * This class represents a base vscode command. It handles all of the necessary
+ * command registration and disposal boilerplate.
+ */
+export abstract class Command extends vscode.Disposable {
+  private disposable: vscode.Disposable;
+  protected context: MLIRContext;
+
+  constructor(command: string, context: MLIRContext) {
+    super(() => this.dispose());
+    this.disposable =
+        vscode.commands.registerCommand(command, this.execute, this);
+    this.context = context;
+  }
+
+  dispose() { this.disposable && this.disposable.dispose(); }
+
+  /**
+   * The function executed when this command is invoked.
+   */
+  abstract execute(...args: any[]): any;
+}
diff --git a/mlir/utils/vscode/src/extension.ts b/mlir/utils/vscode/src/extension.ts
--- a/mlir/utils/vscode/src/extension.ts
+++ b/mlir/utils/vscode/src/extension.ts
@@ -1,6 +1,7 @@
 import * as vscode from 'vscode';
 
 import {MLIRContext} from './mlirContext';
+import {registerPDLLCommands} from './PDLL/pdll';
 
 /**
  *  This method is called when the extension is activated. The extension is
@@ -20,6 +21,7 @@
         mlirContext.dispose();
         await mlirContext.activate(outputChannel);
       }));
+  registerPDLLCommands(context, mlirContext);
 
   mlirContext.activate(outputChannel);
 }
diff --git a/mlir/utils/vscode/src/mlirContext.ts b/mlir/utils/vscode/src/mlirContext.ts
--- a/mlir/utils/vscode/src/mlirContext.ts
+++ b/mlir/utils/vscode/src/mlirContext.ts
@@ -327,6 +327,21 @@
     return this.resolvePath(serverPath, defaultPath, workspaceFolder);
   }
 
+  /**
+   * Return the language client for the given language and workspace folder, or
+   * null if no client is active.
+   */
+  getLanguageClient(workspaceFolder: vscode.WorkspaceFolder,
+                    languageName: string): vscodelc.LanguageClient {
+    const workspaceFolderStr =
+        workspaceFolder ? workspaceFolder.uri.toString() : "";
+    const folderContext = this.workspaceFolders.get(workspaceFolderStr);
+    if (!folderContext) {
+      return null;
+    }
+    return folderContext.clients.get(languageName);
+  }
+
   dispose() {
     this.subscriptions.forEach((d) => { d.dispose(); });
     this.subscriptions = [];