diff --git a/mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h b/mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h --- a/mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h +++ b/mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h @@ -66,6 +66,9 @@ /// Signal code completion for Pattern metadata. virtual void codeCompletePatternMetadata() {} + /// Signal completion of an include path; `curPath` is the text so far. + virtual void codeCompleteIncludeFilename(StringRef curPath) {} + //===--------------------------------------------------------------------===// // Signature Hooks //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/PDLL/Parser/Lexer.h b/mlir/lib/Tools/PDLL/Parser/Lexer.h --- a/mlir/lib/Tools/PDLL/Parser/Lexer.h +++ b/mlir/lib/Tools/PDLL/Parser/Lexer.h @@ -34,22 +34,25 @@ class Token { public: enum Kind { - // Markers. + /// Markers. eof, error, + /// Token signifying a code completion location. code_complete, + /// Code completion inside a string; spelling omits the opening delimiter. + code_complete_string, - // Keywords. + /// Keywords. KW_BEGIN, - // Dependent keywords, i.e. those that are treated as keywords depending on - // the current parser context. + /// Dependent keywords, i.e. those that are treated as keywords depending on + /// the current parser context. KW_DEPENDENT_BEGIN, kw_attr, kw_op, kw_type, KW_DEPENDENT_END, - // General keywords. + /// General keywords. kw_Attr, kw_erase, kw_let, @@ -68,7 +71,7 @@ kw_with, KW_END, - // Punctuation. + /// Punctuation. arrow, colon, comma, @@ -76,7 +79,7 @@ equal, equal_arrow, semicolon, - // Paired punctuation. + /// Paired punctuation. less, greater, l_brace, @@ -87,7 +90,7 @@ r_square, underscore, - // Tokens. + /// Tokens. 
directive, identifier, integer, diff --git a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp --- a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp @@ -22,11 +22,15 @@ //===----------------------------------------------------------------------===// std::string Token::getStringValue() const { - assert(getKind() == string || getKind() == string_block); + assert(getKind() == string || getKind() == string_block || + getKind() == code_complete_string); // Start by dropping the quotes. - StringRef bytes = getSpelling().drop_front().drop_back(); - if (is(string_block)) bytes = bytes.drop_front().drop_back(); + StringRef bytes = getSpelling(); + if (is(string)) + bytes = bytes.drop_front().drop_back(); + else if (is(string_block)) + bytes = bytes.drop_front(2).drop_back(2); std::string result; result.reserve(bytes.size()); @@ -337,6 +341,16 @@ Token Lexer::lexString(const char *tokStart, bool isStringBlock) { while (true) { + // Check to see if there is a code completion location within the string. In + // these cases we generate a completion location and place the currently + // lexed string within the token (without the quotes). This allows for the + // parser to use the partially lexed string when computing the completion + // results. + if (curPtr == codeCompletionLocation) { + return formToken(Token::code_complete_string, + tokStart + (isStringBlock ? 
2 : 1)); + } + switch (*curPtr++) { case '"': // If this is a string block, we only end the string when we encounter a diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -424,6 +424,7 @@ LogicalResult codeCompleteDialectName(); LogicalResult codeCompleteOperationName(StringRef dialectName); LogicalResult codeCompletePatternMetadata(); + LogicalResult codeCompleteIncludeFilename(StringRef curPath); void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs); void codeCompleteOperationOperandsSignature(Optional opName, @@ -680,6 +681,10 @@ SMRange loc = curToken.getLoc(); consumeToken(Token::directive); + // Handle code completion of the include file path. + if (curToken.is(Token::code_complete_string)) + return codeCompleteIncludeFilename(curToken.getStringValue()); + // Parse the file being included. if (!curToken.isString()) return emitError(loc, @@ -2922,6 +2927,11 @@ return failure(); } +LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) { + codeCompleteContext->codeCompleteIncludeFilename(curPath); + return failure(); +} + void Parser::codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs) { ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent); diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp --- a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp @@ -119,7 +119,8 @@ ">", ":", ";", ",", "+", "-", "/", "*", "%", "^", "&", "#", "?", ".", "=", "\"", "'", "|"}}, {"resolveProvider", false}, - {"triggerCharacters", {".", ">", "(", "{", ",", "<", ":", "[", " "}}, + {"triggerCharacters", + {".", ">", "(", "{", ",", "<", ":", "[", " ", "\"", "/"}}, }}, {"signatureHelpProvider", llvm::json::Object{ diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp --- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp @@ -23,6 +23,8 @@ #include "llvm/ADT/IntervalMap.h" #include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/FileSystem.h" #include "llvm/Support/Path.h" using namespace mlir; @@ -663,9 +665,10 @@ class LSPCodeCompleteContext : public CodeCompleteContext { public: LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList, - ods::Context &odsContext) + ods::Context &odsContext, + ArrayRef<std::string> includeDirs) : CodeCompleteContext(completeLoc), completionList(completionList), - odsContext(odsContext) {} + odsContext(odsContext), includeDirs(includeDirs) {} void codeCompleteTupleMemberAccess(ast::TupleType tupleType) final { ArrayRef<ast::Type> elementTypes = tupleType.getElementTypes(); @@ -899,9 +902,78 @@ "The pattern properly handles recursive application."); } + void codeCompleteIncludeFilename(StringRef curPath) final { + // Normalize the path to allow for interacting with the file system + // utilities. Only the directory part of the path is used to find + // candidates; a partially typed file name is left to client filtering. + SmallString<128> nativeRelDir(llvm::sys::path::convert_to_slash(curPath)); + size_t lastSlash = nativeRelDir.str().rfind('/'); + nativeRelDir.resize(lastSlash == StringRef::npos ? 0 : lastSlash + 1); + llvm::sys::path::native(nativeRelDir); + + // Labels already added, used to deduplicate results across include dirs. + llvm::StringSet<> seenResults; + + // Functor used to add a single include completion item. + auto addIncludeCompletion = [&](StringRef path, bool isDirectory) { + lsp::CompletionItem item; + item.label = (path + (isDirectory ? "/" : "")).str(); + item.kind = isDirectory ? 
lsp::CompletionItemKind::Folder + : lsp::CompletionItemKind::File; + if (seenResults.insert(item.label).second) + completionList.items.emplace_back(std::move(item)); + }; + + // Process the include directories for this file, adding any potential + // nested include files or directories.
+ for (StringRef includeDir : includeDirs) { + llvm::SmallString<128> dir = includeDir; + if (!nativeRelDir.empty()) + llvm::sys::path::append(dir, nativeRelDir); + + std::error_code errorCode; + for (auto it = llvm::sys::fs::directory_iterator(dir, errorCode), + e = llvm::sys::fs::directory_iterator(); + !errorCode && it != e; it.increment(errorCode)) { + StringRef filename = llvm::sys::path::filename(it->path()); + + // To know whether a symlink should be treated as file or a directory, + // we have to stat it. This should be cheap enough as there shouldn't be + // many symlinks. + llvm::sys::fs::file_type fileType = it->type(); + if (fileType == llvm::sys::fs::file_type::symlink_file) { + if (auto fileStatus = it->status()) + fileType = fileStatus->type(); + } + + switch (fileType) { + case llvm::sys::fs::file_type::directory_file: + addIncludeCompletion(filename, /*isDirectory=*/true); + break; + case llvm::sys::fs::file_type::regular_file: { + // Only consider concrete files that can actually be included by PDLL. + if (filename.endswith(".pdll") || filename.endswith(".td")) + addIncludeCompletion(filename, /*isDirectory=*/false); + break; + } + default: + break; + } + } + } + + // Sort the completion results to make sure the output is deterministic in + // the face of different iteration schemes for different platforms. + llvm::sort(completionList.items, [](const lsp::CompletionItem &lhs, + const lsp::CompletionItem &rhs) { + return lhs.label < rhs.label; + }); + } + private: lsp::CompletionList &completionList; ods::Context &odsContext; + ArrayRef<std::string> includeDirs; }; } // namespace @@ -919,8 +991,8 @@ // code completion context provided.
ods::Context tmpODSContext; lsp::CompletionList completionList; - LSPCodeCompleteContext lspCompleteContext(posLoc, completionList, - tmpODSContext); + LSPCodeCompleteContext lspCompleteContext( + posLoc, completionList, tmpODSContext, sourceMgr.getIncludeDirs()); ast::Context tmpContext(tmpODSContext); (void)parsePDLAST(tmpContext, sourceMgr, &lspCompleteContext); diff --git a/mlir/test/mlir-pdll-lsp-server/completion.test b/mlir/test/mlir-pdll-lsp-server/completion.test --- a/mlir/test/mlir-pdll-lsp-server/completion.test +++ b/mlir/test/mlir-pdll-lsp-server/completion.test @@ -1,16 +1,16 @@ -// RUN: mlir-pdll-lsp-server -lit-test < %s | FileCheck -strict-whitespace %s +// RUN: mlir-pdll-lsp-server -pdll-extra-dir %S -pdll-extra-dir %S/../../include -lit-test < %s | FileCheck -strict-whitespace %s {"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"pdll","capabilities":{},"trace":"off"}} // ----- {"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{ "uri":"test:///foo.pdll", "languageId":"pdll", "version":1, - "text":"Constraint ValueCst(value: Value);\nConstraint Cst();\nPattern FooPattern with benefit(1) {\nlet tuple = (value1 = _: Op, _: Op);\nerase tuple.value1;\n}" + "text":"#include \"include/included.pdll\"\nConstraint ValueCst(value: Value);\nConstraint Cst();\nPattern FooPattern with benefit(1) {\nlet tuple = (value1 = _: Op, _: Op);\nerase tuple.value1;\n}" }}} // ----- {"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{ "textDocument":{"uri":"test:///foo.pdll"}, - "position":{"line":4,"character":12} + "position":{"line":5,"character":12} }} // CHECK: "id": 1 // CHECK-NEXT: "jsonrpc": "2.0", @@ -49,7 +49,7 @@ // ----- {"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{ "textDocument":{"uri":"test:///foo.pdll"}, - "position":{"line":2,"character":23} + "position":{"line":3,"character":23} }} // CHECK: "id": 1 // CHECK-NEXT: "jsonrpc": "2.0", @@ -82,7 +82,7 @@ // 
----- {"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{ "textDocument":{"uri":"test:///foo.pdll"}, - "position":{"line":3,"character":24} + "position":{"line":4,"character":24} }} // CHECK: "id": 1 // CHECK-NEXT: "jsonrpc": "2.0", @@ -200,6 +200,26 @@ // CHECK-NEXT: ] // CHECK-NEXT: } // ----- +{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{ + "textDocument":{"uri":"test:///foo.pdll"}, + "position":{"line":0,"character":18} +}} +// CHECK: "id": 1 +// CHECK-NEXT: "jsonrpc": "2.0", +// CHECK-NEXT: "result": { +// CHECK-NEXT: "isIncomplete": false, +// CHECK-NEXT: "items": [ +// CHECK-NEXT: { +// CHECK-NEXT: "kind": 17, +// CHECK-NEXT: "label": "included.pdll" +// CHECK-NEXT: }, +// CHECK-NEXT: { +// CHECK-NEXT: "kind": 17, +// CHECK-NEXT: "label": "included.td" +// CHECK-NEXT: } +// CHECK-NEXT: ] +// CHECK-NEXT: } +// ----- {"jsonrpc":"2.0","id":3,"method":"shutdown"} // ----- {"jsonrpc":"2.0","method":"exit"}