diff --git a/clang/lib/Interpreter/CodeCompletion.cpp b/clang/lib/Interpreter/CodeCompletion.cpp
--- a/clang/lib/Interpreter/CodeCompletion.cpp
+++ b/clang/lib/Interpreter/CodeCompletion.cpp
@@ -37,12 +37,25 @@
   return Opts;
 }
 
+/// Strategy that turns Sema's raw code-completion results into the plain
+/// strings returned by codeComplete().
+class CodeCompletionSubContext {
+public:
+  virtual ~CodeCompletionSubContext() = default;
+  virtual void
+  HandleCodeCompleteResults(class Sema &S, CodeCompletionResult *InResults,
+                            unsigned NumResults,
+                            std::vector<std::string> &Results) = 0;
+};
+
 class ReplCompletionConsumer : public CodeCompleteConsumer {
 public:
-  ReplCompletionConsumer(std::vector<std::string> &Results)
+  ReplCompletionConsumer(std::vector<std::string> &Results,
+                         std::unique_ptr<CodeCompletionSubContext> SubCtxt)
       : CodeCompleteConsumer(getClangCompleteOpts()),
         CCAllocator(std::make_shared<GlobalCodeCompletionAllocator>()),
-        CCTUInfo(CCAllocator), Results(Results){};
+        CCTUInfo(CCAllocator), Results(Results), SubCtxt(std::move(SubCtxt)) {
+  }
 
   void ProcessCodeCompleteResults(class Sema &S, CodeCompletionContext Context,
                                   CodeCompletionResult *InResults,
@@ -56,26 +69,13 @@
   std::shared_ptr<GlobalCodeCompletionAllocator> CCAllocator;
   CodeCompletionTUInfo CCTUInfo;
   std::vector<std::string> &Results;
+  std::unique_ptr<CodeCompletionSubContext> SubCtxt;
 };
 
 void ReplCompletionConsumer::ProcessCodeCompleteResults(
     class Sema &S, CodeCompletionContext Context,
     CodeCompletionResult *InResults, unsigned NumResults) {
-  for (unsigned I = 0; I < NumResults; ++I) {
-    auto &Result = InResults[I];
-    switch (Result.Kind) {
-    case CodeCompletionResult::RK_Declaration:
-      if (auto *ID = Result.Declaration->getIdentifier()) {
-        Results.push_back(ID->getName().str());
-      }
-      break;
-    case CodeCompletionResult::RK_Keyword:
-      Results.push_back(Result.Keyword);
-      break;
-    default:
-      break;
-    }
-  }
+  SubCtxt->HandleCodeCompleteResults(S, InResults, NumResults, Results);
 }
 
 class IncrementalSyntaxOnlyAction : public SyntaxOnlyAction {
@@ -177,11 +177,176 @@
   }
 }
 
+/// Appends every named declaration and every keyword offered by Sema.
+static void collectAllCompletions(CodeCompletionResult *InResults,
+                                  unsigned NumResults,
+                                  std::vector<std::string> &Results) {
+  for (unsigned I = 0; I < NumResults; ++I) {
+    const CodeCompletionResult &Result = InResults[I];
+    switch (Result.Kind) {
+    case CodeCompletionResult::RK_Declaration:
+      if (const IdentifierInfo *ID = Result.Declaration->getIdentifier())
+        Results.push_back(ID->getName().str());
+      break;
+    case CodeCompletionResult::RK_Keyword:
+      Results.push_back(Result.Keyword);
+      break;
+    default:
+      break;
+    }
+  }
+}
+
+/// Context used when the cursor is not inside an open call: everything Sema
+/// offers is returned unfiltered.
+class CCSubContextRegular : public CodeCompletionSubContext {
+public:
+  void HandleCodeCompleteResults(class Sema &S, CodeCompletionResult *InResults,
+                                 unsigned NumResults,
+                                 std::vector<std::string> &Results) override {
+    collectAllCompletions(InResults, NumResults, Results);
+  }
+};
+
+/// Context used when the cursor is inside the argument list of a call to
+/// CalleeName: only declarations whose name starts with Prefix and whose
+/// type fits the parameter at position ArgIndex are returned.
+class CCSubContextCallSite : public CodeCompletionSubContext {
+  // Owned copies: the input text the context is parsed from need not outlive
+  // this object.
+  std::string CalleeName;
+  std::string Prefix;
+  unsigned ArgIndex;
+
+  const FunctionDecl *lookUp(CodeCompletionResult *InResults,
+                             unsigned NumResults) const;
+
+public:
+  CCSubContextCallSite(StringRef CalleeName, StringRef Prefix,
+                       unsigned ArgIndex = 0)
+      : CalleeName(CalleeName.str()), Prefix(Prefix.str()),
+        ArgIndex(ArgIndex) {}
+  void HandleCodeCompleteResults(class Sema &S, CodeCompletionResult *InResults,
+                                 unsigned NumResults,
+                                 std::vector<std::string> &Results) override;
+};
+
+/// Returns the first non-hidden function (or function template) result named
+/// CalleeName, or nullptr if there is none.
+const FunctionDecl *
+CCSubContextCallSite::lookUp(CodeCompletionResult *InResults,
+                             unsigned NumResults) const {
+  for (unsigned I = 0; I < NumResults; ++I) {
+    const CodeCompletionResult &Result = InResults[I];
+    if (Result.Kind != CodeCompletionResult::RK_Declaration || Result.Hidden)
+      continue;
+    const FunctionDecl *Function = Result.Declaration->getAsFunction();
+    if (!Function)
+      continue;
+    // Constructors, destructors, conversion functions, overloaded/literal
+    // operators and deduction guides have no identifier: getName() would
+    // assert on them, and they can never match CalleeName anyway.
+    const IdentifierInfo *II = Function->getIdentifier();
+    if (II && II->getName() == CalleeName)
+      return Function;
+  }
+  return nullptr;
+}
+/// Returns true if a variable of type ArgType can be passed directly as the
+/// argument for a parameter of type ParamType.
+static bool isSuitableArgument(Sema &S, QualType ParamType, QualType ArgType) {
+  // A variable of reference type names an lvalue of the referenced type.
+  ArgType = ArgType.getNonReferenceType();
+  if (const auto *RefTy = ParamType->getAs<ReferenceType>()) {
+    // Ref_Related alone means the argument is more cv-qualified than the
+    // referenced type (e.g. `const int` for `int &`), which cannot bind.
+    Sema::ReferenceConversions RefConv;
+    return S.CompareReferenceRelationship(SourceLocation(),
+                                          RefTy->getPointeeType(), ArgType,
+                                          &RefConv) == Sema::Ref_Compatible;
+  }
+  // By-value parameter: top-level cv-qualifiers do not affect the copy.
+  return S.Context.hasSameUnqualifiedType(ArgType, ParamType);
+}
+
+void CCSubContextCallSite::HandleCodeCompleteResults(
+    class Sema &S, CodeCompletionResult *InResults, unsigned NumResults,
+    std::vector<std::string> &Results) {
+  const FunctionDecl *Function = lookUp(InResults, NumResults);
+  // Unknown callee (e.g. `if (` or `sizeof(`), or an argument position past
+  // the declared parameters: there is no parameter type to filter by, so
+  // fall back to regular completion rather than offering nothing.
+  if (!Function || ArgIndex >= Function->getNumParams()) {
+    collectAllCompletions(InResults, NumResults, Results);
+    return;
+  }
+
+  QualType ParamType = Function->getParamDecl(ArgIndex)->getType();
+  for (unsigned I = 0; I < NumResults; ++I) {
+    const CodeCompletionResult &Result = InResults[I];
+    if (Result.Kind != CodeCompletionResult::RK_Declaration || Result.Hidden)
+      continue;
+    const auto *DD = dyn_cast<DeclaratorDecl>(Result.Declaration);
+    if (!DD || !DD->getIdentifier() || !DD->getName().startswith(Prefix))
+      continue;
+    if (isSuitableArgument(S, ParamType, DD->getType()))
+      Results.push_back(DD->getName().str());
+  }
+}
+
+/// Returns the trailing run of identifier characters ([A-Za-z0-9_]) of Text;
+/// empty if Text does not end in one.
+static StringRef getTrailingIdentifier(StringRef Text) {
+  size_t Start = Text.find_last_not_of(
+      "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_");
+  return Start == StringRef::npos ? Text : Text.substr(Start + 1);
+}
+
+/// Picks the completion context for CurInput, whose end is taken to be the
+/// cursor position.
+///
+/// Walks backwards to the nearest unmatched '(' and counts the top-level
+/// commas on the way to find which argument is being typed. Brackets,
+/// braces and string/character literals are not taken into account.
+static std::unique_ptr<CodeCompletionSubContext>
+getCodeCompletionSubContext(llvm::StringRef CurInput) {
+  unsigned Depth = 0;
+  unsigned ArgIndex = 0;
+  for (size_t I = CurInput.size(); I > 0; --I) {
+    switch (CurInput[I - 1]) {
+    case ')':
+      ++Depth;
+      break;
+    case '(':
+      if (Depth == 0) {
+        StringRef Callee =
+            getTrailingIdentifier(CurInput.take_front(I - 1).rtrim());
+        StringRef Prefix = getTrailingIdentifier(CurInput.drop_front(I));
+        return std::make_unique<CCSubContextCallSite>(Callee, Prefix,
+                                                      ArgIndex);
+      }
+      --Depth;
+      break;
+    case ',':
+      if (Depth == 0)
+        ++ArgIndex;
+      break;
+    default:
+      break;
+    }
+  }
+  return std::make_unique<CCSubContextRegular>();
+}
+
 void codeComplete(CompilerInstance *InterpCI, llvm::StringRef Content,
                   unsigned Line, unsigned Col, const CompilerInstance *ParentCI,
                   std::vector<std::string> &CCResults) {
   auto DiagOpts = DiagnosticOptions();
-  auto consumer = ReplCompletionConsumer(CCResults);
+  // NOTE: the completion context is derived from the whole of Content, i.e.
+  // the cursor is assumed to sit at its end.
+  auto consumer =
+      ReplCompletionConsumer(CCResults, getCodeCompletionSubContext(Content));
 
   auto diag = InterpCI->getDiagnosticsPtr();
   std::unique_ptr<ASTUnit> AU(ASTUnit::LoadFromCompilerInvocationAction(
diff --git a/clang/unittests/Interpreter/CodeCompletionTest.cpp b/clang/unittests/Interpreter/CodeCompletionTest.cpp
--- a/clang/unittests/Interpreter/CodeCompletionTest.cpp
+++ b/clang/unittests/Interpreter/CodeCompletionTest.cpp
@@ -98,4 +98,120 @@
   EXPECT_EQ((bool)Err, false);
 }
 
+TEST(CodeCompletionTest, TypedDirected) {
+  auto Interp = createInterpreter();
+  if (auto R = Interp->ParseAndExecute("int application = 12;")) {
+    consumeError(std::move(R));
+    return;
+  }
+  if (auto R = Interp->ParseAndExecute("char apple = '2';")) {
+    consumeError(std::move(R));
+    return;
+  }
+  if (auto R = Interp->ParseAndExecute("void add(int &SomeInt){}")) {
+    consumeError(std::move(R));
+    return;
+  }
+  {
+    auto Err = llvm::Error::success();
+    auto comps = runComp(*Interp, std::string("add("), Err);
+    EXPECT_EQ((size_t)1, comps.size());
+  }
+
+  if (auto R = Interp->ParseAndExecute("int banana = 42;")) {
+    consumeError(std::move(R));
+    return;
+  }
+
+  {
+    auto Err = llvm::Error::success();
+    auto comps = runComp(*Interp, std::string("add("), Err);
+    EXPECT_EQ((size_t)2, comps.size());
+    EXPECT_EQ(comps[0], "application");
+    EXPECT_EQ(comps[1], "banana");
+  }
+
+  {
+    auto Err = llvm::Error::success();
+    auto comps = runComp(*Interp, std::string("add(b"), Err);
+    EXPECT_EQ((size_t)1, comps.size());
+    EXPECT_EQ(comps[0], "anana");
+  }
+}
+
+TEST(CodeCompletionTest, SanityClasses) {
+  auto Interp = createInterpreter();
+  if (auto R = Interp->ParseAndExecute("struct Apple{};")) {
+    consumeError(std::move(R));
+    return;
+  }
+  if (auto R = Interp->ParseAndExecute("void takeApple(Apple &a1){}")) {
+    consumeError(std::move(R));
+    return;
+  }
+  if (auto R = Interp->ParseAndExecute("Apple a1;")) {
+    consumeError(std::move(R));
+    return;
+  }
+  if (auto R = Interp->ParseAndExecute("void takeAppleCopy(Apple a1){}")) {
+    consumeError(std::move(R));
+    return;
+  }
+
+  {
+    auto Err = llvm::Error::success();
+    auto comps = runComp(*Interp, "takeApple(", Err);
+    EXPECT_EQ((size_t)1, comps.size());
+    EXPECT_EQ(comps[0], std::string("a1"));
+  }
+  {
+    auto Err = llvm::Error::success();
+    auto comps = runComp(*Interp, std::string("takeAppleCopy("), Err);
+    EXPECT_EQ((size_t)1, comps.size());
+    EXPECT_EQ(comps[0], std::string("a1"));
+  }
+}
+
+TEST(CodeCompletionTest, SubClassing) {
+  auto Interp = createInterpreter();
+  if (auto R = Interp->ParseAndExecute("struct Fruit {};")) {
+    consumeError(std::move(R));
+    return;
+  }
+  if (auto R = Interp->ParseAndExecute("struct Apple : Fruit{};")) {
+    consumeError(std::move(R));
+    return;
+  }
+  if (auto R = Interp->ParseAndExecute("void takeFruit(Fruit &f){}")) {
+    consumeError(std::move(R));
+    return;
+  }
+  if (auto R = Interp->ParseAndExecute("Apple a1;")) {
+    consumeError(std::move(R));
+    return;
+  }
+  if (auto R = Interp->ParseAndExecute("Fruit f1;")) {
+    consumeError(std::move(R));
+    return;
+  }
+  auto Err = llvm::Error::success();
+  auto comps = runComp(*Interp, std::string("takeFruit("), Err);
+  EXPECT_EQ((size_t)2, comps.size());
+  EXPECT_EQ(comps[0], std::string("a1"));
+  EXPECT_EQ(comps[1], std::string("f1"));
+}
+
+TEST(CodeCompletionTest, CallSiteWithoutParameters) {
+  auto Interp = createInterpreter();
+  if (auto R = Interp->ParseAndExecute("void noArgs(){}")) {
+    consumeError(std::move(R));
+    return;
+  }
+  // Completing inside a call to a function without parameters must not
+  // query a non-existent parameter.
+  auto Err = llvm::Error::success();
+  runComp(*Interp, std::string("noArgs("), Err);
+  EXPECT_EQ((bool)Err, false);
+}
+
 } // anonymous namespace