diff --git a/clang-tools-extra/clangd/CodeComplete.cpp b/clang-tools-extra/clangd/CodeComplete.cpp --- a/clang-tools-extra/clangd/CodeComplete.cpp +++ b/clang-tools-extra/clangd/CodeComplete.cpp @@ -214,7 +214,8 @@ // Returns a token identifying the overload set this is part of. // 0 indicates it's not part of any overload set. size_t overloadSet(const CodeCompleteOptions &Opts, llvm::StringRef FileName, - IncludeInserter *Inserter) const { + IncludeInserter *Inserter, + CodeCompletionContext::Kind CCContextKind) const { if (!Opts.BundleOverloads.value_or(false)) return 0; @@ -223,7 +224,7 @@ // bundle those, so we must resolve the header to be included here. std::string HeaderForHash; if (Inserter) { - if (auto Header = headerToInsertIfAllowed(Opts)) { + if (auto Header = headerToInsertIfAllowed(Opts, CCContextKind)) { if (auto HeaderFile = toHeaderFile(*Header, FileName)) { if (auto Spelled = Inserter->calculateIncludePath(*HeaderFile, FileName)) @@ -271,11 +272,21 @@ return 0; } + bool contextAllowsHeaderInsertion(CodeCompletionContext::Kind Kind) const { + // Explicitly disable insertions for forward declarations since they don't + // reference the declaration. + if (Kind == CodeCompletionContext::CCC_ObjCClassForwardDecl) + return false; + return true; + } + // The best header to include if include insertion is allowed. std::optional - headerToInsertIfAllowed(const CodeCompleteOptions &Opts) const { + headerToInsertIfAllowed(const CodeCompleteOptions &Opts, + CodeCompletionContext::Kind ContextKind) const { if (Opts.InsertIncludes == CodeCompleteOptions::NeverInsert || - RankedIncludeHeaders.empty()) + RankedIncludeHeaders.empty() || + !contextAllowsHeaderInsertion(ContextKind)) return std::nullopt; if (SemaResult && SemaResult->Declaration) { // Avoid inserting new #include if the declaration is found in the current @@ -401,7 +412,8 @@ std::move(*Spelled), Includes.shouldInsertInclude(*ResolvedDeclaring, *ResolvedInserted)); }; - bool ShouldInsert = C.headerToInsertIfAllowed(Opts).has_value(); + bool ShouldInsert = + C.headerToInsertIfAllowed(Opts, ContextKind).has_value(); Symbol::IncludeDirective Directive = insertionDirective(Opts); // Calculate include paths and edits for all possible headers. for (const auto &Inc : C.RankedIncludeHeaders) { @@ -780,6 +792,7 @@ case CodeCompletionContext::CCC_ObjCInterfaceName: case CodeCompletionContext::CCC_Symbol: case CodeCompletionContext::CCC_SymbolOrNewName: + case CodeCompletionContext::CCC_ObjCClassForwardDecl: return true; case CodeCompletionContext::CCC_OtherWithMacros: case CodeCompletionContext::CCC_DotMemberAccess: @@ -1422,6 +1435,10 @@ else if (Kind == CodeCompletionContext::CCC_ObjCProtocolName) // Don't show anything else in ObjC protocol completions. return false; + + if (Kind == CodeCompletionContext::CCC_ObjCClassForwardDecl) + return Sym.SymInfo.Kind == index::SymbolKind::Class && + Sym.SymInfo.Lang == index::SymbolLanguage::ObjC; return true; } @@ -1832,8 +1849,8 @@ assert(IdentifierResult); C.Name = IdentifierResult->Name; } - if (auto OverloadSet = - C.overloadSet(Opts, FileName, Inserter ? &*Inserter : nullptr)) { + if (auto OverloadSet = C.overloadSet( + Opts, FileName, Inserter ? &*Inserter : nullptr, CCContextKind)) { auto Ret = BundleLookup.try_emplace(OverloadSet, Bundles.size()); if (Ret.second) Bundles.emplace_back(); diff --git a/clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp b/clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp --- a/clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp +++ b/clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp @@ -3434,6 +3434,20 @@ EXPECT_THAT(Results.Completions, IsEmpty()); } +TEST(CompletionTest, ObjectiveCForwardDeclFromIndex) { + Symbol FoodClass = objcClass("FoodClass"); + FoodClass.IncludeHeaders.emplace_back("\"Foo.h\"", 2, Symbol::Import); + Symbol SymFood = objcProtocol("Food"); + auto Results = completions("@class Foo^", {SymFood, FoodClass}, + /*Opts=*/{}, "Foo.m"); + + // Should only give class names without any include insertion. + EXPECT_THAT(Results.Completions, + UnorderedElementsAre(AllOf(named("FoodClass"), + kind(CompletionItemKind::Class), + Not(insertInclude())))); +} + TEST(CompletionTest, CursorInSnippets) { clangd::CodeCompleteOptions Options; Options.EnableSnippets = true; diff --git a/clang/include/clang/Sema/CodeCompleteConsumer.h b/clang/include/clang/Sema/CodeCompleteConsumer.h --- a/clang/include/clang/Sema/CodeCompleteConsumer.h +++ b/clang/include/clang/Sema/CodeCompleteConsumer.h @@ -333,7 +333,10 @@ /// An unknown context, in which we are recovering from a parsing /// error and don't know which completions we should give. - CCC_Recovery + CCC_Recovery, + + /// Code completion in a @class forward declaration. + CCC_ObjCClassForwardDecl }; using VisitedContextSet = llvm::SmallPtrSet; diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -13429,6 +13429,7 @@ ArrayRef Protocols); void CodeCompleteObjCProtocolDecl(Scope *S); void CodeCompleteObjCInterfaceDecl(Scope *S); + void CodeCompleteObjCClassForwardDecl(Scope *S); void CodeCompleteObjCSuperclass(Scope *S, IdentifierInfo *ClassName, SourceLocation ClassNameLoc); diff --git a/clang/lib/Frontend/ASTUnit.cpp b/clang/lib/Frontend/ASTUnit.cpp --- a/clang/lib/Frontend/ASTUnit.cpp +++ b/clang/lib/Frontend/ASTUnit.cpp @@ -322,6 +322,7 @@ if (ID->getDefinition()) Contexts |= (1LL << CodeCompletionContext::CCC_Expression); Contexts |= (1LL << CodeCompletionContext::CCC_ObjCInterfaceName); + Contexts |= (1LL << CodeCompletionContext::CCC_ObjCClassForwardDecl); } // Deal with tag names. @@ -2028,6 +2029,7 @@ case CodeCompletionContext::CCC_IncludedFile: case CodeCompletionContext::CCC_Attribute: case CodeCompletionContext::CCC_NewName: + case CodeCompletionContext::CCC_ObjCClassForwardDecl: // We're looking for nothing, or we're looking for names that cannot // be hidden. return; diff --git a/clang/lib/Parse/ParseObjc.cpp b/clang/lib/Parse/ParseObjc.cpp --- a/clang/lib/Parse/ParseObjc.cpp +++ b/clang/lib/Parse/ParseObjc.cpp @@ -153,6 +153,11 @@ while (true) { MaybeSkipAttributes(tok::objc_class); + if (Tok.is(tok::code_completion)) { + cutOffParsing(); + Actions.CodeCompleteObjCClassForwardDecl(getCurScope()); + return Actions.ConvertDeclToDeclGroup(nullptr); + } if (expectIdentifier()) { SkipUntil(tok::semi); return Actions.ConvertDeclToDeclGroup(nullptr); diff --git a/clang/lib/Sema/CodeCompleteConsumer.cpp b/clang/lib/Sema/CodeCompleteConsumer.cpp --- a/clang/lib/Sema/CodeCompleteConsumer.cpp +++ b/clang/lib/Sema/CodeCompleteConsumer.cpp @@ -83,6 +83,7 @@ case CCC_ObjCCategoryName: case CCC_IncludedFile: case CCC_Attribute: + case CCC_ObjCClassForwardDecl: return false; } @@ -166,6 +167,8 @@ return "Attribute"; case CCKind::CCC_Recovery: return "Recovery"; + case CCKind::CCC_ObjCClassForwardDecl: + return "ObjCClassForwardDecl"; } llvm_unreachable("Invalid CodeCompletionContext::Kind!"); } diff --git a/clang/lib/Sema/SemaCodeComplete.cpp b/clang/lib/Sema/SemaCodeComplete.cpp --- a/clang/lib/Sema/SemaCodeComplete.cpp +++ b/clang/lib/Sema/SemaCodeComplete.cpp @@ -8460,6 +8460,24 @@ Results.data(), Results.size()); } +void Sema::CodeCompleteObjCClassForwardDecl(Scope *S) { + ResultBuilder Results(*this, CodeCompleter->getAllocator(), + CodeCompleter->getCodeCompletionTUInfo(), + CodeCompletionContext::CCC_ObjCClassForwardDecl); + Results.EnterNewScope(); + + if (CodeCompleter->includeGlobals()) { + // Add all classes. + AddInterfaceResults(Context.getTranslationUnitDecl(), CurContext, false, + false, Results); + } + + Results.ExitScope(); + + HandleCodeCompleteResults(this, CodeCompleter, Results.getCompletionContext(), + Results.data(), Results.size()); +} + void Sema::CodeCompleteObjCSuperclass(Scope *S, IdentifierInfo *ClassName, SourceLocation ClassNameLoc) { ResultBuilder Results(*this, CodeCompleter->getAllocator(), diff --git a/clang/tools/libclang/CIndexCodeCompletion.cpp b/clang/tools/libclang/CIndexCodeCompletion.cpp --- a/clang/tools/libclang/CIndexCodeCompletion.cpp +++ b/clang/tools/libclang/CIndexCodeCompletion.cpp @@ -537,6 +537,7 @@ case CodeCompletionContext::CCC_Other: case CodeCompletionContext::CCC_ObjCInterface: case CodeCompletionContext::CCC_ObjCImplementation: + case CodeCompletionContext::CCC_ObjCClassForwardDecl: case CodeCompletionContext::CCC_NewName: case CodeCompletionContext::CCC_MacroName: case CodeCompletionContext::CCC_PreprocessorExpression: