diff --git a/clang/include/clang/Interpreter/CodeCompletion.h b/clang/include/clang/Interpreter/CodeCompletion.h
new file mode 100644
--- /dev/null
+++ b/clang/include/clang/Interpreter/CodeCompletion.h
@@ -0,0 +1,72 @@
+//===------ CodeCompletion.h - Code Completion for ClangRepl -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the classes which perform code completion at the REPL.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_INTERPRETER_CODE_COMPLETION_H
+#define LLVM_CLANG_INTERPRETER_CODE_COMPLETION_H
+#include "clang/Sema/CodeCompleteConsumer.h"
+#include "llvm/LineEditor/LineEditor.h"
+
+#include <memory>
+#include <vector>
+
+namespace clang {
+class Interpreter;
+class IncrementalCompilerBuilder;
+
+/// Returns the completion options used by the REPL: code patterns, macros,
+/// globals and brief comments are all included.
+clang::CodeCompleteOptions getClangCompleteOpts();
+
+/// Collects keyword results and results for declarations named by a plain
+/// identifier into a caller-provided vector.
+class ReplCompletionConsumer : public CodeCompleteConsumer {
+public:
+  explicit ReplCompletionConsumer(std::vector<CodeCompletionResult> &Results)
+      : CodeCompleteConsumer(getClangCompleteOpts()),
+        CCAllocator(std::make_shared<GlobalCodeCompletionAllocator>()),
+        CCTUInfo(CCAllocator), Results(Results) {}
+  void ProcessCodeCompleteResults(class Sema &S, CodeCompletionContext Context,
+                                  CodeCompletionResult *InResults,
+                                  unsigned NumResults) final;
+
+  clang::CodeCompletionAllocator &getAllocator() override {
+    return *CCAllocator;
+  }
+
+  clang::CodeCompletionTUInfo &getCodeCompletionTUInfo() override {
+    return CCTUInfo;
+  }
+
+private:
+  std::shared_ptr<GlobalCodeCompletionAllocator> CCAllocator;
+  CodeCompletionTUInfo CCTUInfo;
+  /// Owned by the caller; must outlive this consumer.
+  std::vector<CodeCompletionResult> &Results;
+};
+
+/// List completer for llvm::LineEditor: completes the identifier fragment
+/// that ends at the cursor, with the main interpreter's declarations visible.
+struct ReplListCompleter {
+  IncrementalCompilerBuilder &CB;
+  Interpreter &MainInterp;
+  ReplListCompleter(IncrementalCompilerBuilder &CB, Interpreter &Interp)
+      : CB(CB), MainInterp(Interp) {}
+  std::vector<llvm::LineEditor::Completion> operator()(llvm::StringRef Buffer,
+                                                       size_t Pos) const;
+
+private:
+  std::vector<StringRef>
+  toCodeCompleteStrings(const std::vector<CodeCompletionResult> &Results) const;
+};
+
+} // namespace clang
+#endif
diff --git a/clang/include/clang/Interpreter/Interpreter.h b/clang/include/clang/Interpreter/Interpreter.h
--- a/clang/include/clang/Interpreter/Interpreter.h
+++ b/clang/include/clang/Interpreter/Interpreter.h
@@ -35,9 +35,10 @@
 
 namespace clang {
 
+class CodeCompletionResult;
 class CompilerInstance;
 class IncrementalExecutor;
 class IncrementalParser;
 
 /// Create a pre-configured \c CompilerInstance for incremental processing.
 class IncrementalCompilerBuilder {
@@ -80,8 +81,11 @@
   // An optional parser for CUDA offloading
   std::unique_ptr<IncrementalParser> DeviceParser;
 
   Interpreter(std::unique_ptr<CompilerInstance> CI, llvm::Error &Err);
+  Interpreter(std::unique_ptr<CompilerInstance> CI, llvm::Error &Err,
+              std::vector<CodeCompletionResult> &CompResults,
+              const CompilerInstance *ParentCI = nullptr);
 
   llvm::Error CreateExecutor();
   unsigned InitPTUSize = 0;
 
@@ -93,13 +97,26 @@
 
 public:
   ~Interpreter();
+
   static llvm::Expected<std::unique_ptr<Interpreter>>
   create(std::unique_ptr<CompilerInstance> CI);
+
   static llvm::Expected<std::unique_ptr<Interpreter>>
   createWithCUDA(std::unique_ptr<CompilerInstance> CI,
                  std::unique_ptr<CompilerInstance> DCI);
+
+  /// Creates a throw-away interpreter used to compute code completions.
+  /// Results are appended to \p CompResults; the declarations of \p ParentCI
+  /// (if non-null) are made visible to it.
+  static llvm::Expected<std::unique_ptr<Interpreter>>
+  createForCodeCompletion(IncrementalCompilerBuilder &CB,
+                          const CompilerInstance *ParentCI,
+                          std::vector<CodeCompletionResult> &CompResults);
+
   const ASTContext &getASTContext() const;
   ASTContext &getASTContext();
+  /// Runs code completion on \p Input at the 1-based \p Line and \p Col.
+  void CodeComplete(llvm::StringRef Input, size_t Col, size_t Line = 1);
   const CompilerInstance *getCompilerInstance() const;
   llvm::Expected<llvm::orc::LLJIT &> getExecutionEngine();
 
diff --git a/clang/include/clang/Sema/CodeCompleteConsumer.h b/clang/include/clang/Sema/CodeCompleteConsumer.h
--- a/clang/include/clang/Sema/CodeCompleteConsumer.h
+++ b/clang/include/clang/Sema/CodeCompleteConsumer.h
@@ -336,7 +336,10 @@
     CCC_Recovery,
 
     /// Code completion in a @class forward declaration.
-    CCC_ObjCClassForwardDecl
+    CCC_ObjCClassForwardDecl,
+
+    /// Code completion at a top level in a REPL session.
+    CCC_ReplTopLevel,
   };
 
   using VisitedContextSet = llvm::SmallPtrSet<DeclContext *, 8>;
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -13319,7 +13319,9 @@
     PCC_ParenthesizedExpression,
     /// Code completion occurs within a sequence of declaration
     /// specifiers within a function, method, or block.
-    PCC_LocalDeclarationSpecifiers
+    PCC_LocalDeclarationSpecifiers,
+    /// Code completion occurs at top-level in a REPL session.
+    PCC_ReplTopLevel,
   };
 
   void CodeCompleteModuleImport(SourceLocation ImportLoc, ModuleIdPath Path);
diff --git a/clang/lib/Frontend/ASTUnit.cpp b/clang/lib/Frontend/ASTUnit.cpp
--- a/clang/lib/Frontend/ASTUnit.cpp
+++ b/clang/lib/Frontend/ASTUnit.cpp
@@ -2005,6 +2005,7 @@
   case CodeCompletionContext::CCC_SymbolOrNewName:
   case CodeCompletionContext::CCC_ParenthesizedExpression:
   case CodeCompletionContext::CCC_ObjCInterfaceName:
+  case CodeCompletionContext::CCC_ReplTopLevel:
     break;
 
   case CodeCompletionContext::CCC_EnumTag:
diff --git a/clang/lib/Interpreter/CMakeLists.txt b/clang/lib/Interpreter/CMakeLists.txt
--- a/clang/lib/Interpreter/CMakeLists.txt
+++ b/clang/lib/Interpreter/CMakeLists.txt
@@ -12,7 +12,9 @@
   )
 
 add_clang_library(clangInterpreter
+  CodeCompletion.cpp
   DeviceOffload.cpp
+  ExternalSource.cpp
   IncrementalExecutor.cpp
   IncrementalParser.cpp
   Interpreter.cpp
diff --git a/clang/lib/Interpreter/CodeCompletion.cpp b/clang/lib/Interpreter/CodeCompletion.cpp
new file mode 100644
--- /dev/null
+++ b/clang/lib/Interpreter/CodeCompletion.cpp
@@ -0,0 +1,110 @@
+//===------ CodeCompletion.cpp - Code Completion for ClangRepl -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the classes which perform code completion at the REPL.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Interpreter/CodeCompletion.h"
+#include "clang/Frontend/CompilerInstance.h"
+#include "clang/Interpreter/Interpreter.h"
+#include "clang/Lex/PreprocessorOptions.h"
+#include "clang/Sema/CodeCompleteOptions.h"
+#include "clang/Sema/Sema.h"
+
+namespace clang {
+
+clang::CodeCompleteOptions getClangCompleteOpts() {
+  clang::CodeCompleteOptions Opts;
+  Opts.IncludeCodePatterns = true;
+  Opts.IncludeMacros = true;
+  Opts.IncludeGlobals = true;
+  Opts.IncludeBriefComments = true;
+  return Opts;
+}
+
+/// Keeps keyword results and results for declarations named by a plain
+/// identifier; other kinds of results are dropped.
+void ReplCompletionConsumer::ProcessCodeCompleteResults(
+    class Sema &S, CodeCompletionContext Context,
+    CodeCompletionResult *InResults, unsigned NumResults) {
+  for (unsigned I = 0; I < NumResults; ++I) {
+    const CodeCompletionResult &Result = InResults[I];
+    switch (Result.Kind) {
+    case CodeCompletionResult::RK_Declaration:
+      if (Result.Declaration->getIdentifier())
+        Results.push_back(Result);
+      break;
+    case CodeCompletionResult::RK_Keyword:
+      Results.push_back(Result);
+      break;
+    default:
+      break;
+    }
+  }
+}
+
+/// Maps each collected result to the text it would insert.
+std::vector<StringRef> ReplListCompleter::toCodeCompleteStrings(
+    const std::vector<CodeCompletionResult> &Results) const {
+  std::vector<StringRef> CompletionStrings;
+  CompletionStrings.reserve(Results.size());
+  for (const CodeCompletionResult &Res : Results) {
+    switch (Res.Kind) {
+    case CodeCompletionResult::RK_Declaration:
+      if (const IdentifierInfo *ID = Res.Declaration->getIdentifier())
+        CompletionStrings.push_back(ID->getName());
+      break;
+    case CodeCompletionResult::RK_Keyword:
+      CompletionStrings.push_back(Res.Keyword);
+      break;
+    default:
+      break;
+    }
+  }
+  return CompletionStrings;
+}
+
+std::vector<llvm::LineEditor::Completion>
+ReplListCompleter::operator()(llvm::StringRef Buffer, size_t Pos) const {
+  // Must outlive Interp: the completion consumer inside Interp refers to it.
+  std::vector<CodeCompletionResult> Results;
+  auto Interp = Interpreter::createForCodeCompletion(
+      CB, MainInterp.getCompilerInstance(), Results);
+  if (auto Err = Interp.takeError()) {
+    // Log the error and offer no completions.
+    llvm::logAllUnhandledErrors(std::move(Err), llvm::errs(), "error: ");
+    return {};
+  }
+
+  // Only the text before the cursor matters. Translate the offset Pos into
+  // the 1-based line/column pair expected by CodeComplete.
+  llvm::StringRef BeforeCursor = Buffer.substr(0, Pos);
+  size_t Line = BeforeCursor.count('\n') + 1;
+  size_t LastNewline = BeforeCursor.rfind('\n');
+  size_t Col =
+      LastNewline == llvm::StringRef::npos ? Pos + 1 : Pos - LastNewline;
+  (*Interp)->CodeComplete(Buffer, Col, Line);
+
+  // Candidates must extend the identifier fragment that ends at the cursor,
+  // e.g. "ap" in "foo(ap".
+  size_t FragmentStart = BeforeCursor.find_last_not_of(
+      "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_");
+  llvm::StringRef Prefix = FragmentStart == llvm::StringRef::npos
+                               ? BeforeCursor
+                               : BeforeCursor.substr(FragmentStart + 1);
+
+  std::vector<llvm::LineEditor::Completion> Comps;
+  for (llvm::StringRef Candidate : toCodeCompleteStrings(Results)) {
+    if (Candidate.startswith(Prefix))
+      Comps.emplace_back(Candidate.substr(Prefix.size()).str(),
+                         Candidate.str());
+  }
+  return Comps;
+}
+} // namespace clang
diff --git a/clang/lib/Interpreter/ExternalSource.h b/clang/lib/Interpreter/ExternalSource.h
new file mode 100644
--- /dev/null
+++ b/clang/lib/Interpreter/ExternalSource.h
@@ -0,0 +1,47 @@
+//===----- ExternalSource.h - External AST Source for Code Completion -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines components that make declarations parsed and executed by
+// the interpreter visible to the context where code completion is being
+// triggered.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_LIB_INTERPRETER_EXTERNALSOURCE_H
+#define LLVM_CLANG_LIB_INTERPRETER_EXTERNALSOURCE_H
+
+#include "clang/AST/ExternalASTSource.h"
+
+#include <memory>
+
+namespace clang {
+class ASTContext;
+class FileManager;
+class ASTImporter;
+
+/// Exposes the declarations of a parent translation unit to a child
+/// ASTContext by importing them with an ASTImporter.
+class ExternalSource : public clang::ExternalASTSource {
+  ASTContext &ChildASTCtxt;
+  TranslationUnitDecl *ChildTUDeclCtxt;
+  ASTContext &ParentASTCtxt;
+  TranslationUnitDecl *ParentTUDeclCtxt;
+
+  std::unique_ptr<ASTImporter> Importer;
+
+public:
+  ExternalSource(ASTContext &ChildASTCtxt, FileManager &ChildFM,
+                 ASTContext &ParentASTCtxt, FileManager &ParentFM);
+  bool FindExternalVisibleDeclsByName(const DeclContext *DC,
+                                      DeclarationName Name) override;
+  void
+  completeVisibleDeclsMap(const clang::DeclContext *childDeclContext) override;
+};
+} // namespace clang
+
+#endif // LLVM_CLANG_LIB_INTERPRETER_EXTERNALSOURCE_H
diff --git a/clang/lib/Interpreter/ExternalSource.cpp b/clang/lib/Interpreter/ExternalSource.cpp
new file mode 100644
--- /dev/null
+++ b/clang/lib/Interpreter/ExternalSource.cpp
@@ -0,0 +1,75 @@
+//===--- ExternalSource.cpp - External AST Source for Code Completion ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The file implements classes that make declarations parsed and executed by the
+// interpreter visible to the context where code completion is being triggered.
+//
+//===----------------------------------------------------------------------===//
+
+#include "ExternalSource.h"
+#include "clang/AST/ASTImporter.h"
+#include "clang/AST/DeclarationName.h"
+#include "clang/Basic/IdentifierTable.h"
+
+namespace clang {
+ExternalSource::ExternalSource(ASTContext &ChildASTCtxt, FileManager &ChildFM,
+                               ASTContext &ParentASTCtxt, FileManager &ParentFM)
+    : ChildASTCtxt(ChildASTCtxt),
+      ChildTUDeclCtxt(ChildASTCtxt.getTranslationUnitDecl()),
+      ParentASTCtxt(ParentASTCtxt),
+      ParentTUDeclCtxt(ParentASTCtxt.getTranslationUnitDecl()),
+      Importer(std::make_unique<ASTImporter>(ChildASTCtxt, ChildFM,
+                                             ParentASTCtxt, ParentFM,
+                                             /*MinimalImport=*/true)) {}
+
+/// Reports whether the parent translation unit declares \p Name. No
+/// declarations are added here; completeVisibleDeclsMap imports them.
+bool ExternalSource::FindExternalVisibleDeclsByName(const DeclContext *DC,
+                                                    DeclarationName Name) {
+  // Only plain identifiers can be translated by spelling into the parent's
+  // identifier table; other names (operators, constructors, ...) are not.
+  const IdentifierInfo *II = Name.getAsIdentifierInfo();
+  if (!II)
+    return false;
+
+  DeclarationName ParentDeclName(&ParentASTCtxt.Idents.get(II->getName()));
+  return !ParentTUDeclCtxt->lookup(ParentDeclName).empty();
+}
+
+/// Imports every named declaration of the parent translation unit (and of
+/// its previous redeclarations) and registers it as visible in the child's
+/// translation unit.
+void ExternalSource::completeVisibleDeclsMap(
+    const DeclContext *ChildDeclContext) {
+  assert(ChildDeclContext && ChildDeclContext == ChildTUDeclCtxt &&
+         "No child decl context!");
+
+  if (!ChildDeclContext->hasExternalVisibleStorage())
+    return;
+
+  for (auto *DeclCtxt = ParentTUDeclCtxt; DeclCtxt != nullptr;
+       DeclCtxt = DeclCtxt->getPreviousDecl()) {
+    for (Decl *D : DeclCtxt->decls()) {
+      auto *ND = llvm::dyn_cast<NamedDecl>(D);
+      if (!ND)
+        continue;
+      llvm::Expected<Decl *> Imported = Importer->Import(ND);
+      if (!Imported) {
+        // Declarations that cannot be imported are simply not offered.
+        llvm::consumeError(Imported.takeError());
+        continue;
+      }
+      if (auto *ImportedND = llvm::dyn_cast<NamedDecl>(*Imported))
+        SetExternalVisibleDeclsForName(ChildDeclContext,
+                                       ImportedND->getDeclName(), ImportedND);
+    }
+  }
+  ChildDeclContext->setHasExternalLexicalStorage(false);
+}
+
+} // namespace clang
diff --git a/clang/lib/Interpreter/IncrementalParser.h b/clang/lib/Interpreter/IncrementalParser.h
--- a/clang/lib/Interpreter/IncrementalParser.h
+++ b/clang/lib/Interpreter/IncrementalParser.h
@@ -24,7 +24,7 @@
 #include <memory>
 namespace llvm {
 class LLVMContext;
-}
+} // namespace llvm
 
 namespace clang {
 class ASTConsumer;
@@ -62,7 +62,8 @@
 public:
   IncrementalParser(Interpreter &Interp,
                     std::unique_ptr<CompilerInstance> Instance,
-                    llvm::LLVMContext &LLVMCtx, llvm::Error &Err);
+                    llvm::LLVMContext &LLVMCtx, llvm::Error &Err,
+                    const CompilerInstance *ParentCI = nullptr);
   virtual ~IncrementalParser();
 
   CompilerInstance *getCI() { return CI.get(); }
@@ -72,6 +73,7 @@
   ///\returns a \c PartialTranslationUnit which holds information about the
   /// \c TranslationUnitDecl and \c llvm::Module corresponding to the input.
   virtual llvm::Expected<PartialTranslationUnit &> Parse(llvm::StringRef Input);
+  void ParseForCodeCompletion(llvm::StringRef Input, size_t Col, size_t Line);
 
   /// Uses the CodeGenModule mangled name cache and avoids recomputing.
   ///\returns the mangled name of a \c GD.
@@ -85,6 +87,12 @@
 
 private:
   llvm::Expected<PartialTranslationUnit &> ParseOrWrapTopLevelDecl();
+
+  llvm::Expected<PartialTranslationUnit &> ParseForPTU(FileID FID,
+                                                       SourceLocation SrcLoc);
+
+  std::pair<FileID, SourceLocation> createSourceFile(llvm::StringRef SourceName,
+                                                     llvm::StringRef Input);
 };
 } // end namespace clang
 
diff --git a/clang/lib/Interpreter/IncrementalParser.cpp b/clang/lib/Interpreter/IncrementalParser.cpp
--- a/clang/lib/Interpreter/IncrementalParser.cpp
+++ b/clang/lib/Interpreter/IncrementalParser.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "IncrementalParser.h"
+#include "ExternalSource.h"
 #include "clang/AST/DeclContextInternals.h"
 #include "clang/CodeGen/BackendUtil.h"
 #include "clang/CodeGen/CodeGenAction.h"
@@ -115,10 +116,12 @@
 class IncrementalAction : public WrapperFrontendAction {
 private:
   bool IsTerminating = false;
+  const CompilerInstance *ParentCI;
 
 public:
   IncrementalAction(CompilerInstance &CI, llvm::LLVMContext &LLVMCtx,
-                    llvm::Error &Err)
+                    llvm::Error &Err,
+                    const CompilerInstance *ParentCI = nullptr)
       : WrapperFrontendAction([&]() {
           llvm::ErrorAsOutParameter EAO(&Err);
           std::unique_ptr<FrontendAction> Act;
@@ -152,7 +155,8 @@
             break;
           }
           return Act;
-        }()) {}
+        }()),
+        ParentCI(ParentCI) {}
   FrontendAction *getWrapped() const { return WrappedAction.get(); }
   TranslationUnitKind getTranslationUnitKind() override {
     return TU_Incremental;
@@ -175,6 +179,17 @@
     Preprocessor &PP = CI.getPreprocessor();
     PP.EnterMainSourceFile();
 
+    if (ParentCI) {
+      // Make the declarations of the parent instance (e.g. the main REPL
+      // interpreter) visible to this one through the AST's external source.
+      auto ExtSource = llvm::makeIntrusiveRefCnt<ExternalSource>(
+          CI.getASTContext(), CI.getFileManager(), ParentCI->getASTContext(),
+          ParentCI->getFileManager());
+      CI.getASTContext().setExternalSource(std::move(ExtSource));
+      CI.getASTContext().getTranslationUnitDecl()->setHasExternalVisibleStorage(
+          true);
+    }
+
     if (!CI.hasSema())
       CI.createSema(getTranslationUnitKind(), CompletionConsumer);
   }
@@ -206,10 +221,11 @@
 IncrementalParser::IncrementalParser(Interpreter &Interp,
                                      std::unique_ptr<CompilerInstance> Instance,
                                      llvm::LLVMContext &LLVMCtx,
-                                     llvm::Error &Err)
+                                     llvm::Error &Err,
+                                     const CompilerInstance *ParentCI)
     : CI(std::move(Instance)) {
   llvm::ErrorAsOutParameter EAO(&Err);
-  Act = std::make_unique<IncrementalAction>(*CI, LLVMCtx, Err);
+  Act = std::make_unique<IncrementalAction>(*CI, LLVMCtx, Err, ParentCI);
   if (Err)
     return;
   CI->ExecuteAction(*Act);
@@ -305,22 +321,44 @@
   return LastPTU;
 }
 
-llvm::Expected<PartialTranslationUnit &>
-IncrementalParser::Parse(llvm::StringRef input) {
+void IncrementalParser::ParseForCodeCompletion(llvm::StringRef Input,
+                                               size_t Col, size_t Line) {
   Preprocessor &PP = CI->getPreprocessor();
   assert(PP.isIncrementalProcessingEnabled() && "Not in incremental mode!?");
 
   std::ostringstream SourceName;
-  SourceName << "input_line_" << InputCount++;
+  SourceName << "input_line_[Completion]";
+
+  auto [FID, SrcLoc] = createSourceFile(SourceName.str(), Input);
+  auto FE = CI->getSourceManager().getFileEntryRefForID(FID);
+  // The completion point needs a file entry; SetCodeCompletionPoint returns
+  // true on error.
+  if (!FE || PP.SetCodeCompletionPoint(*FE, Line, Col))
+    return;
+
+  // SrcLoc is only used for diagnostics.
+  if (PP.EnterSourceFile(FID, /*DirLookup=*/nullptr, SrcLoc))
+    return;
+
+  // Completion results reach the CodeCompleteConsumer as a side effect of
+  // parsing; neither the resulting PTU nor a parse error is needed here.
+  if (auto PTU = ParseOrWrapTopLevelDecl(); !PTU)
+    llvm::consumeError(PTU.takeError());
+}
 
+/// Copies \p Input (plus a trailing newline) into a virtual file named
+/// \p SourceName and returns its FileID and the location it is entered from.
+std::pair<FileID, SourceLocation>
+IncrementalParser::createSourceFile(llvm::StringRef SourceName,
+                                    llvm::StringRef Input) {
   // Create an uninitialized memory buffer, copy code in and append "\n"
-  size_t InputSize = input.size(); // don't include trailing 0
+  size_t InputSize = Input.size(); // don't include trailing 0
   // MemBuffer size should *not* include terminating zero
   std::unique_ptr<llvm::MemoryBuffer> MB(
       llvm::WritableMemoryBuffer::getNewUninitMemBuffer(InputSize + 1,
                                                         SourceName.str()));
   char *MBStart = const_cast<char *>(MB->getBufferStart());
-  memcpy(MBStart, input.data(), InputSize);
+  memcpy(MBStart, Input.data(), InputSize);
   MBStart[InputSize] = '\n';
 
   SourceManager &SM = CI->getSourceManager();
@@ -330,18 +368,44 @@
   SourceLocation NewLoc = SM.getLocForStartOfFile(SM.getMainFileID());
 
   // Create FileID for the current buffer.
-  FileID FID = SM.createFileID(std::move(MB), SrcMgr::C_User, /*LoadedID=*/0,
-                               /*LoadedOffset=*/0, NewLoc);
+  // The buffer is backed by a virtual FileEntry (sized to match it) so that a
+  // code completion point can be set in it.
+  const clang::FileEntry *FE = SM.getFileManager().getVirtualFile(
+      SourceName, InputSize + 1, /*ModificationTime=*/0);
+  SM.overrideFileContents(FE, std::move(MB));
+  FileID FID = SM.createFileID(FE, NewLoc, SrcMgr::C_User);
+  return {FID, NewLoc};
+}
+
+/// Enters the file \p FID and parses it into a new PartialTranslationUnit;
+/// \p SrcLoc is the location the file is entered from.
+llvm::Expected<PartialTranslationUnit &>
+IncrementalParser::ParseForPTU(FileID FID, SourceLocation SrcLoc) {
+  Preprocessor &PP = CI->getPreprocessor();
+  assert(PP.isIncrementalProcessingEnabled() && "Not in incremental mode!?");
 
-  // NewLoc only used for diags.
-  if (PP.EnterSourceFile(FID, /*DirLookup=*/nullptr, NewLoc))
+  // SrcLoc is only used for diagnostics.
+  if (PP.EnterSourceFile(FID, /*DirLookup=*/nullptr, SrcLoc))
     return llvm::make_error<llvm::StringError>("Parsing failed. "
                                                "Cannot enter source file.",
                                                std::error_code());
 
   auto PTU = ParseOrWrapTopLevelDecl();
   if (!PTU)
     return PTU.takeError();
+  return *PTU;
+}
+
+llvm::Expected<PartialTranslationUnit &>
+IncrementalParser::Parse(llvm::StringRef input) {
+  Preprocessor &PP = CI->getPreprocessor();
+  std::ostringstream SourceName;
+  SourceName << "input_line_" << InputCount++;
+
+  auto [FID, SrcLoc] = createSourceFile(SourceName.str(), input);
+  auto PTU = ParseForPTU(FID, SrcLoc);
+  if (!PTU)
+    return PTU.takeError();
 
   if (PP.getLangOpts().DelayedTemplateParsing) {
     // Microsoft-specific:
diff --git a/clang/lib/Interpreter/Interpreter.cpp b/clang/lib/Interpreter/Interpreter.cpp
--- a/clang/lib/Interpreter/Interpreter.cpp
+++ b/clang/lib/Interpreter/Interpreter.cpp
@@ -33,6 +33,7 @@
 #include "clang/Driver/Tool.h"
 #include "clang/Frontend/CompilerInstance.h"
 #include "clang/Frontend/TextDiagnosticBuffer.h"
+#include "clang/Interpreter/CodeCompletion.h"
 #include "clang/Interpreter/Value.h"
 #include "clang/Lex/PreprocessorOptions.h"
 #include "clang/Sema/Lookup.h"
@@ -237,6 +238,19 @@
                                                    *TSCtx->getContext(), Err);
 }
 
+Interpreter::Interpreter(std::unique_ptr<CompilerInstance> CI, llvm::Error &Err,
+                         std::vector<CodeCompletionResult> &CompResults,
+                         const CompilerInstance *ParentCI) {
+  llvm::ErrorAsOutParameter EAO(&Err);
+  auto LLVMCtx = std::make_unique<llvm::LLVMContext>();
+  TSCtx = std::make_unique<llvm::orc::ThreadSafeContext>(std::move(LLVMCtx));
+  // Install the consumer before the IncrementalParser runs its frontend
+  // action; the CompilerInstance takes ownership of it.
+  CI->setCodeCompletionConsumer(new ReplCompletionConsumer(CompResults));
+  IncrParser = std::make_unique<IncrementalParser>(
+      *this, std::move(CI), *TSCtx->getContext(), Err, ParentCI);
+}
+
 Interpreter::~Interpreter() {
   if (IncrExecutor) {
     if (llvm::Error Err = IncrExecutor->cleanUp())
@@ -288,6 +302,30 @@
   return std::move(Interp);
 }
 
+llvm::Expected<std::unique_ptr<Interpreter>>
+Interpreter::createForCodeCompletion(
+    IncrementalCompilerBuilder &CB, const CompilerInstance *ParentCI,
+    std::vector<CodeCompletionResult> &CompResults) {
+  auto CI = CB.CreateCpp();
+  if (auto Err = CI.takeError())
+    return std::move(Err);
+
+  // Tune the instance for completion-only parsing.
+  (*CI)->getPreprocessorOpts().SingleFileParseMode = true;
+  (*CI)->getLangOpts().SpellChecking = false;
+  (*CI)->getLangOpts().DelayedTemplateParsing = false;
+  (*CI)->getFrontendOpts().CodeCompleteOpts = getClangCompleteOpts();
+
+  llvm::Error Err = llvm::Error::success();
+  auto Interp = std::unique_ptr<Interpreter>(
+      new Interpreter(std::move(*CI), Err, CompResults, ParentCI));
+  if (Err)
+    return std::move(Err);
+
+  Interp->InitPTUSize = Interp->IncrParser->getPTUs().size();
+  return std::move(Interp);
+}
+
 llvm::Expected<std::unique_ptr<Interpreter>>
 Interpreter::createWithCUDA(std::unique_ptr<CompilerInstance> CI,
                             std::unique_ptr<CompilerInstance> DCI) {
@@ -738,6 +776,10 @@
   return Result.get();
 }
 
+void Interpreter::CodeComplete(llvm::StringRef Input, size_t Col, size_t Line) {
+  IncrParser->ParseForCodeCompletion(Input, Col, Line);
+}
+
 // Temporary rvalue struct that need special care.
 REPL_EXTERNAL_VISIBILITY void *
 __clang_Interpreter_SetValueWithAlloc(void *This, void *OutVal,
diff --git a/clang/lib/Parse/Parser.cpp b/clang/lib/Parse/Parser.cpp
--- a/clang/lib/Parse/Parser.cpp
+++ b/clang/lib/Parse/Parser.cpp
@@ -923,9 +923,14 @@
                                          /*IsInstanceMethod=*/std::nullopt,
                                          /*ReturnType=*/nullptr);
     }
-    Actions.CodeCompleteOrdinaryName(
-        getCurScope(),
-        CurParsedObjCImpl ? Sema::PCC_ObjCImplementation : Sema::PCC_Namespace);
+
+    // Top-level input of an incremental (REPL) session has its own context.
+    Sema::ParserCompletionContext PCC = Sema::PCC_Namespace;
+    if (CurParsedObjCImpl)
+      PCC = Sema::PCC_ObjCImplementation;
+    else if (PP.isIncrementalProcessingEnabled())
+      PCC = Sema::PCC_ReplTopLevel;
+    Actions.CodeCompleteOrdinaryName(getCurScope(), PCC);
     return nullptr;
   case tok::kw_import: {
     Sema::ModuleImportState IS = Sema::ModuleImportState::NotACXX20Module;
diff --git a/clang/lib/Sema/CodeCompleteConsumer.cpp b/clang/lib/Sema/CodeCompleteConsumer.cpp
--- a/clang/lib/Sema/CodeCompleteConsumer.cpp
+++ b/clang/lib/Sema/CodeCompleteConsumer.cpp
@@ -51,6 +51,7 @@
   case CCC_ParenthesizedExpression:
   case CCC_Symbol:
   case CCC_SymbolOrNewName:
+  case CCC_ReplTopLevel:
     return true;
 
   case CCC_TopLevel:
@@ -169,6 +170,8 @@
     return "Recovery";
   case CCKind::CCC_ObjCClassForwardDecl:
     return "ObjCClassForwardDecl";
+  case CCKind::CCC_ReplTopLevel:
+    return "ReplTopLevel";
   }
   llvm_unreachable("Invalid CodeCompletionContext::Kind!");
 }
diff --git a/clang/lib/Sema/SemaCodeComplete.cpp b/clang/lib/Sema/SemaCodeComplete.cpp
--- a/clang/lib/Sema/SemaCodeComplete.cpp
+++ b/clang/lib/Sema/SemaCodeComplete.cpp
@@ -225,6 +225,7 @@
   case CodeCompletionContext::CCC_ObjCMessageReceiver:
   case CodeCompletionContext::CCC_ParenthesizedExpression:
   case CodeCompletionContext::CCC_Statement:
+  case CodeCompletionContext::CCC_ReplTopLevel:
   case CodeCompletionContext::CCC_Recovery:
     if (ObjCMethodDecl *Method = SemaRef.getCurMethodDecl())
       if (Method->isInstanceMethod())
@@ -1850,6 +1851,7 @@
   case Sema::PCC_ObjCInstanceVariableList:
   case Sema::PCC_Expression:
   case Sema::PCC_Statement:
+  case Sema::PCC_ReplTopLevel:
   case Sema::PCC_ForInit:
   case Sema::PCC_Condition:
   case Sema::PCC_RecoveryInFunction:
@@ -1907,6 +1909,7 @@
   case Sema::PCC_Type:
   case Sema::PCC_ParenthesizedExpression:
   case Sema::PCC_LocalDeclarationSpecifiers:
+  case Sema::PCC_ReplTopLevel:
     return true;
 
   case Sema::PCC_Expression:
@@ -2219,6 +2222,7 @@
     break;
 
   case Sema::PCC_RecoveryInFunction:
+  case Sema::PCC_ReplTopLevel:
   case Sema::PCC_Statement: {
     if (SemaRef.getLangOpts().CPlusPlus11)
       AddUsingAliasResult(Builder, Results);
@@ -4208,6 +4212,8 @@
 
   case Sema::PCC_LocalDeclarationSpecifiers:
     return CodeCompletionContext::CCC_Type;
+  case Sema::PCC_ReplTopLevel:
+    return CodeCompletionContext::CCC_ReplTopLevel;
   }
 
   llvm_unreachable("Invalid ParserCompletionContext!");
@@ -4348,6 +4354,7 @@
     break;
 
   case PCC_Statement:
+  case PCC_ReplTopLevel:
   case PCC_ParenthesizedExpression:
   case PCC_Expression:
   case PCC_ForInit:
@@ -4385,6 +4392,7 @@
   case PCC_ParenthesizedExpression:
   case PCC_Expression:
   case PCC_Statement:
+  case PCC_ReplTopLevel:
   case PCC_RecoveryInFunction:
     if (S->getFnParent())
       AddPrettyFunctionResults(getLangOpts(), Results);
diff --git a/clang/tools/clang-repl/ClangRepl.cpp b/clang/tools/clang-repl/ClangRepl.cpp
--- a/clang/tools/clang-repl/ClangRepl.cpp
+++ b/clang/tools/clang-repl/ClangRepl.cpp
@@ -13,6 +13,7 @@
 #include "clang/Basic/Diagnostic.h"
 #include "clang/Frontend/CompilerInstance.h"
 #include "clang/Frontend/FrontendDiagnostic.h"
+#include "clang/Interpreter/CodeCompletion.h"
 #include "clang/Interpreter/Interpreter.h"
 
 #include "llvm/ExecutionEngine/Orc/LLJIT.h"
@@ -155,8 +156,8 @@
 
   if (OptInputs.empty()) {
     llvm::LineEditor LE("clang-repl");
-    // FIXME: Add LE.setListCompleter
     std::string Input;
+    LE.setListCompleter(clang::ReplListCompleter(CB, *Interp));
     while (std::optional<std::string> Line = LE.readLine()) {
       llvm::StringRef L = *Line;
       L = L.trim();
@@ -168,10 +169,10 @@
       }
 
       Input += L;
-
       if (Input == R"(%quit)") {
         break;
-      } else if (Input == R"(%undo)") {
+      }
+      if (Input == R"(%undo)") {
         if (auto Err = Interp->Undo()) {
           llvm::logAllUnhandledErrors(std::move(Err), llvm::errs(), "error: ");
           HasError = true;
diff --git a/clang/unittests/Interpreter/CMakeLists.txt b/clang/unittests/Interpreter/CMakeLists.txt
--- a/clang/unittests/Interpreter/CMakeLists.txt
+++ b/clang/unittests/Interpreter/CMakeLists.txt
@@ -9,6 +9,7 @@
 add_clang_unittest(ClangReplInterpreterTests
   IncrementalProcessingTest.cpp
   InterpreterTest.cpp
+  CodeCompletionTest.cpp
   )
 target_link_libraries(ClangReplInterpreterTests PUBLIC
   clangAST
diff --git a/clang/unittests/Interpreter/CodeCompletionTest.cpp b/clang/unittests/Interpreter/CodeCompletionTest.cpp
new file mode 100644
--- /dev/null
+++ b/clang/unittests/Interpreter/CodeCompletionTest.cpp
@@ -0,0 +1,65 @@
+//===- unittests/Interpreter/CodeCompletionTest.cpp -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Interpreter/CodeCompletion.h"
+#include "clang/Interpreter/Interpreter.h"
+
+#include "llvm/LineEditor/LineEditor.h"
+
+#include "clang/Frontend/CompilerInstance.h"
+#include "llvm/Support/Error.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using namespace clang;
+namespace {
+auto CB = clang::IncrementalCompilerBuilder();
+
+std::unique_ptr<Interpreter> createInterpreter() {
+  auto CI = cantFail(CB.CreateCpp());
+  return cantFail(clang::Interpreter::create(std::move(CI)));
+}
+
+TEST(CodeCompletionTest, Sanity) {
+  auto Interp = createInterpreter();
+  if (auto R = Interp->ParseAndExecute("int foo = 12;")) {
+    GTEST_SKIP() << llvm::toString(std::move(R));
+  }
+  auto Completer = ReplListCompleter(CB, *Interp);
+  std::vector<llvm::LineEditor::Completion> comps =
+      Completer(std::string("f"), 1);
+  ASSERT_EQ((size_t)2, comps.size()); // foo and float
+  EXPECT_EQ(comps[0].TypedText, std::string("oo"));
+}
+
+TEST(CodeCompletionTest, SanityNoneValid) {
+  auto Interp = createInterpreter();
+  if (auto R = Interp->ParseAndExecute("int foo = 12;")) {
+    GTEST_SKIP() << llvm::toString(std::move(R));
+  }
+  auto Completer = ReplListCompleter(CB, *Interp);
+  std::vector<llvm::LineEditor::Completion> comps =
+      Completer(std::string("babanana"), 8);
+  EXPECT_EQ((size_t)0, comps.size()); // nothing starts with "babanana"
+}
+
+TEST(CodeCompletionTest, TwoDecls) {
+  auto Interp = createInterpreter();
+  if (auto R = Interp->ParseAndExecute("int application = 12;")) {
+    GTEST_SKIP() << llvm::toString(std::move(R));
+  }
+  if (auto R = Interp->ParseAndExecute("int apple = 12;")) {
+    GTEST_SKIP() << llvm::toString(std::move(R));
+  }
+  auto Completer = ReplListCompleter(CB, *Interp);
+  std::vector<llvm::LineEditor::Completion> comps =
+      Completer(std::string("app"), 3);
+  EXPECT_EQ((size_t)2, comps.size()); // application and apple
+}
+} // anonymous namespace