Index: include/clang/Parse/Parser.h
===================================================================
--- include/clang/Parse/Parser.h
+++ include/clang/Parse/Parser.h
@@ -220,6 +220,10 @@
   /// function call.
   bool CalledSignatureHelp = false;
 
+  /// Tracks the expected type of the expression currently being parsed.
+  /// Code completion receives it and uses it to rank candidates.
+  PreferredTypeBuilder PreferredType;
+
   /// The "depth" of the template parameters currently being parsed.
   unsigned TemplateParameterDepth;
 
Index: include/clang/Sema/CodeCompleteConsumer.h
===================================================================
--- include/clang/Sema/CodeCompleteConsumer.h
+++ include/clang/Sema/CodeCompleteConsumer.h
@@ -381,6 +381,7 @@
   /// if the expression is a variable initializer or a function argument, the
   /// type of the corresponding variable or function parameter.
   QualType getPreferredType() const { return PreferredType; }
+  void setPreferredType(QualType T) { PreferredType = T; }
 
   /// Retrieve the type of the base object in a member-access
   /// expression.
Index: include/clang/Sema/Sema.h
===================================================================
--- include/clang/Sema/Sema.h
+++ include/clang/Sema/Sema.h
@@ -274,6 +274,59 @@
   }
 };
 
+/// Keeps track of the expected type while an expression is being parsed.
+class PreferredTypeBuilder {
+public:
+  class RestoreRAII;
+
+  PreferredTypeBuilder() = default;
+  explicit PreferredTypeBuilder(QualType Type) : Type(Type) {}
+
+  LLVM_NODISCARD RestoreRAII enterUnknown();
+  LLVM_NODISCARD RestoreRAII enterCondition(Sema &S);
+  LLVM_NODISCARD RestoreRAII enterReturn(Sema &S);
+  LLVM_NODISCARD RestoreRAII enterVariableInit(Decl *D);
+
+  LLVM_NODISCARD RestoreRAII enterUnary(Sema &S, tok::TokenKind Op);
+  LLVM_NODISCARD RestoreRAII enterBinary(Sema &S, Expr *LHS, tok::TokenKind Op);
+  LLVM_NODISCARD RestoreRAII enterSubscript(Sema &S, Expr *LHS);
+  /// Handles all type casts: C-style casts, C++ named casts, functional casts.
+  LLVM_NODISCARD RestoreRAII enterTypeCast(QualType CastType);
+
+  QualType get() const { return Type; }
+
+private:
+  LLVM_NODISCARD RestoreRAII update(llvm::function_ref<void()> Updater);
+
+  QualType Type;
+};
+
+class PreferredTypeBuilder::RestoreRAII {
+public:
+  RestoreRAII(RestoreRAII const &) = delete;
+  RestoreRAII &operator=(RestoreRAII const &) = delete;
+
+  explicit RestoreRAII(PreferredTypeBuilder &Builder)
+      : Old(Builder.Type), Builder(&Builder) {}
+
+  RestoreRAII(RestoreRAII &&Other) noexcept
+      : Old(Other.Old), Builder(Other.Builder) {
+    // Disarm the source: only the moved-to object restores the saved type
+    // when it is destroyed.
+    Other.Builder = nullptr;
+  }
+
+  ~RestoreRAII() {
+    if (!Builder)
+      return;
+    Builder->Type = Old;
+  }
+
+private:
+  QualType Old;
+  PreferredTypeBuilder *Builder;
+};
+
 /// Sema - This implements semantic analysis and AST building for C.
 class Sema {
   Sema(const Sema &) = delete;
@@ -10342,11 +10395,14 @@
   struct CodeCompleteExpressionData;
   void CodeCompleteExpression(Scope *S,
                               const CodeCompleteExpressionData &Data);
-  void CodeCompleteExpression(Scope *S, QualType PreferredType);
+  void CodeCompleteExpression(Scope *S, QualType PreferredType,
+                              bool IsParenthesized = false);
   void CodeCompleteMemberReferenceExpr(Scope *S, Expr *Base, Expr *OtherOpBase,
                                        SourceLocation OpLoc, bool IsArrow,
-                                       bool IsBaseExprStatement);
-  void CodeCompletePostfixExpression(Scope *S, ExprResult LHS);
+                                       bool IsBaseExprStatement,
+                                       QualType PreferredType);
+  void CodeCompletePostfixExpression(Scope *S, ExprResult LHS,
+                                     QualType PreferredType);
   void CodeCompleteTag(Scope *S, unsigned TagSpec);
   void CodeCompleteTypeQualifiers(DeclSpec &DS);
   void CodeCompleteFunctionQualifiers(DeclSpec &DS, Declarator &D,
@@ -10368,9 +10424,7 @@
                                               IdentifierInfo *II,
                                               SourceLocation OpenParLoc);
   void CodeCompleteInitializer(Scope *S, Decl *D);
-  void CodeCompleteReturn(Scope *S);
   void CodeCompleteAfterIf(Scope *S);
-  void CodeCompleteBinaryRHS(Scope *S, Expr *LHS, tok::TokenKind Op);
   void CodeCompleteQualifiedId(Scope *S, CXXScopeSpec &SS,
                                bool EnteringContext, QualType BaseType);
 
Index: lib/Parse/ParseDecl.cpp
===================================================================
--- lib/Parse/ParseDecl.cpp
+++ lib/Parse/ParseDecl.cpp
@@ -45,6 +45,8 @@
                                  AccessSpecifier AS,
                                  Decl **OwnedType,
                                  ParsedAttributes *Attrs) {
+  auto TypeRAII = PreferredType.enterUnknown();
+
   DeclSpecContext DSC = getDeclSpecContextFromDeclaratorContext(Context);
   if (DSC == DeclSpecContext::DSC_normal)
     DSC = DeclSpecContext::DSC_type_specifier;
@@ -2275,7 +2277,11 @@
         return nullptr;
       }
 
-      ExprResult Init(ParseInitializer());
+      ExprResult Init;
+      {
+        auto TypeRAII = PreferredType.enterVariableInit(ThisDecl);
+        Init = ParseInitializer();
+      }
 
       // If this is the only decl in (possibly) range based for statement,
       // our best guess is that the user meant ':' instead of '='.
Index: lib/Parse/ParseDeclCXX.cpp
===================================================================
--- lib/Parse/ParseDeclCXX.cpp
+++ lib/Parse/ParseDeclCXX.cpp
@@ -922,6 +922,8 @@
   assert(Tok.isOneOf(tok::kw_decltype, tok::annot_decltype)
            && "Not a decltype specifier");
 
+  auto TypeRAII = PreferredType.enterUnknown();
+
   ExprResult Result;
   SourceLocation StartLoc = Tok.getLocation();
   SourceLocation EndLoc;
Index: lib/Parse/ParseExpr.cpp
===================================================================
--- lib/Parse/ParseExpr.cpp
+++ lib/Parse/ParseExpr.cpp
@@ -159,7 +159,7 @@
 /// Parse an expr that doesn't include (top-level) commas.
 ExprResult Parser::ParseAssignmentExpression(TypeCastState isTypeCast) {
   if (Tok.is(tok::code_completion)) {
-    Actions.CodeCompleteOrdinaryName(getCurScope(), Sema::PCC_Expression);
+    Actions.CodeCompleteExpression(getCurScope(), PreferredType.get());
     cutOffParsing();
     return ExprError();
   }
@@ -393,15 +393,8 @@
       }
     }
 
-    // Code completion for the right-hand side of a binary expression goes
-    // through a special hook that takes the left-hand side into account.
-    if (Tok.is(tok::code_completion)) {
-      Actions.CodeCompleteBinaryRHS(getCurScope(), LHS.get(),
-                                    OpToken.getKind());
-      cutOffParsing();
-      return ExprError();
-    }
-
+    auto TypeRAII =
+        PreferredType.enterBinary(Actions, LHS.get(), OpToken.getKind());
     // Parse another leaf here for the RHS of the operator.
     // ParseCastExpression works here because all RHS expressions in C have it
     // as a prefix, at least. However, in C++, an assignment-expression could
@@ -1115,6 +1108,8 @@
     //     -- cast-expression
     Token SavedTok = Tok;
     ConsumeToken();
+
+    auto TypeRAII = PreferredType.enterUnary(Actions, SavedTok.getKind());
     // One special case is implicitly handled here: if the preceding tokens are
     // an ambiguous cast expression, such as "(T())++", then we recurse to
     // determine whether the '++' is prefix or postfix.
@@ -1134,6 +1129,7 @@
     return Res;
   }
   case tok::amp: {         // unary-expression: '&' cast-expression
+    auto TypeRAII = PreferredType.enterUnary(Actions, tok::amp);
     // Special treatment because of member pointers
     SourceLocation SavedLoc = ConsumeToken();
     Res = ParseCastExpression(false, true);
@@ -1149,6 +1145,8 @@
   case tok::exclaim:       // unary-expression: '!' cast-expression
   case tok::kw___real:     // unary-expression: '__real' cast-expression [GNU]
   case tok::kw___imag: {   // unary-expression: '__imag' cast-expression [GNU]
+    auto TypeRAII = PreferredType.enterUnary(Actions, Tok.getKind());
+
     SourceLocation SavedLoc = ConsumeToken();
     Res = ParseCastExpression(false);
     if (!Res.isInvalid())
@@ -1184,9 +1182,13 @@
                            // unary-expression: 'sizeof' '(' type-name ')'
   case tok::kw_vec_step:   // unary-expression: OpenCL 'vec_step' expression
   // unary-expression: '__builtin_omp_required_simd_align' '(' type-name ')'
-  case tok::kw___builtin_omp_required_simd_align:
+  case tok::kw___builtin_omp_required_simd_align: {
+    auto TypeRAII = PreferredType.enterUnknown();
     return ParseUnaryExprOrTypeTraitExpression();
+  }
   case tok::ampamp: {      // unary-expression: '&&' identifier
+    auto TypeRAII = PreferredType.enterUnknown();
+
     SourceLocation AmpAmpLoc = ConsumeToken();
     if (Tok.isNot(tok::identifier))
       return ExprError(Diag(Tok, diag::err_expected) << tok::identifier);
@@ -1386,6 +1388,7 @@
     SourceLocation KeyLoc = ConsumeToken();
     BalancedDelimiterTracker T(*this, tok::l_paren);
 
+    auto TypeRAII = PreferredType.enterUnknown();
     if (T.expectAndConsume(diag::err_expected_lparen_after, "noexcept"))
       return ExprError();
     // C++11 [expr.unary.noexcept]p1:
@@ -1424,7 +1427,7 @@
     Res = ParseBlockLiteralExpression();
     break;
   case tok::code_completion: {
-    Actions.CodeCompleteOrdinaryName(getCurScope(), Sema::PCC_Expression);
+    Actions.CodeCompleteExpression(getCurScope(), PreferredType.get());
     cutOffParsing();
     return ExprError();
   }
@@ -1504,7 +1507,8 @@
       if (InMessageExpression)
         return LHS;
 
-      Actions.CodeCompletePostfixExpression(getCurScope(), LHS);
+      Actions.CodeCompletePostfixExpression(getCurScope(), LHS,
+                                            PreferredType.get());
       cutOffParsing();
       return ExprError();
 
@@ -1541,6 +1545,8 @@
         return ExprError();
       }
 
+      auto TypeRAII = PreferredType.enterSubscript(Actions, LHS.get());
+
       BalancedDelimiterTracker T(*this, tok::l_square);
       T.consumeOpen();
       Loc = T.getOpenLocation();
@@ -1773,7 +1779,8 @@
         // Code completion for a member access expression.
         Actions.CodeCompleteMemberReferenceExpr(
             getCurScope(), Base, CorrectedBase, OpLoc, OpKind == tok::arrow,
-            Base && ExprStatementTokLoc == Base->getBeginLoc());
+            Base && ExprStatementTokLoc == Base->getBeginLoc(),
+            PreferredType.get());
 
         cutOffParsing();
         return ExprError();
@@ -2332,9 +2339,9 @@
   CastTy = nullptr;
 
   if (Tok.is(tok::code_completion)) {
-    Actions.CodeCompleteOrdinaryName(getCurScope(),
-                 ExprType >= CompoundLiteral? Sema::PCC_ParenthesizedExpression
-                                            : Sema::PCC_Expression);
+    Actions.CodeCompleteExpression(getCurScope(), PreferredType.get(),
+                                   /*IsParenthesized=*/ExprType >=
+                                       CompoundLiteral);
     cutOffParsing();
     return ExprError();
   }
@@ -2414,6 +2421,9 @@
       TypeResult Ty = ParseTypeName();
       T.consumeClose();
       ColonProtection.restore();
+
+      auto TypeRAII = PreferredType.enterTypeCast(Ty.get().get());
+
       RParenLoc = T.getCloseLocation();
       ExprResult SubExpr = ParseCastExpression(/*isUnaryExpression=*/false);
 
@@ -2545,6 +2555,7 @@
           return ExprError();
         }
 
+        auto TypeRAII = PreferredType.enterTypeCast(CastTy.get());
         // Parse the cast-expression that follows it next.
         // TODO: For cast expression with CastTy.
Result = ParseCastExpression(/*isUnaryExpression=*/false, @@ -2840,13 +2851,15 @@ bool Parser::ParseExpressionList(SmallVectorImpl &Exprs, SmallVectorImpl &CommaLocs, llvm::function_ref Completer) { + auto TypeRAII = PreferredType.enterUnknown(); + bool SawError = false; while (1) { if (Tok.is(tok::code_completion)) { if (Completer) Completer(); else - Actions.CodeCompleteOrdinaryName(getCurScope(), Sema::PCC_Expression); + Actions.CodeCompleteExpression(getCurScope(), PreferredType.get()); cutOffParsing(); return true; } Index: lib/Parse/ParseExprCXX.cpp =================================================================== --- lib/Parse/ParseExprCXX.cpp +++ lib/Parse/ParseExprCXX.cpp @@ -675,6 +675,7 @@ /// trailing-return-type[opt] /// ExprResult Parser::ParseLambdaExpression() { + auto TypeRAII = PreferredType.enterUnknown(); // Parse lambda-introducer. LambdaIntroducer Intro; Optional DiagID = ParseLambdaIntroducer(Intro); @@ -1384,6 +1385,8 @@ ExprResult Parser::ParseCXXTypeid() { assert(Tok.is(tok::kw_typeid) && "Not 'typeid'!"); + auto TypeRAII = PreferredType.enterUnknown(); + SourceLocation OpLoc = ConsumeToken(); SourceLocation LParenLoc, RParenLoc; BalancedDelimiterTracker T(*this, tok::l_paren); @@ -1451,6 +1454,8 @@ ExprResult Parser::ParseCXXUuidof() { assert(Tok.is(tok::kw___uuidof) && "Not '__uuidof'!"); + auto TypeRAII = PreferredType.enterUnknown(); + SourceLocation OpLoc = ConsumeToken(); BalancedDelimiterTracker T(*this, tok::l_paren); @@ -1606,11 +1611,15 @@ case tok::comma: return Actions.ActOnCXXThrow(getCurScope(), ThrowLoc, nullptr); - default: + default: { + auto TypeRAII = PreferredType.enterUnknown(); + ExprResult Expr(ParseAssignmentExpression()); - if (Expr.isInvalid()) return Expr; + if (Expr.isInvalid()) + return Expr; return Actions.ActOnCXXThrow(getCurScope(), ThrowLoc, Expr.get()); } + } } /// Parse the C++ Coroutines co_yield expression. 
@@ -1620,6 +1629,8 @@
 ExprResult Parser::ParseCoyieldExpression() {
   assert(Tok.is(tok::kw_co_yield) && "Not co_yield!");
 
+  auto TypeRAII = PreferredType.enterUnknown();
+
   SourceLocation Loc = ConsumeToken();
   ExprResult Expr = Tok.is(tok::l_brace) ? ParseBraceInitializer()
                                          : ParseAssignmentExpression();
@@ -1657,6 +1668,8 @@
   Declarator DeclaratorInfo(DS, DeclaratorContext::FunctionalCastContext);
   ParsedType TypeRep = Actions.ActOnTypeName(getCurScope(), DeclaratorInfo).get();
 
+  auto TypeRAII = PreferredType.enterTypeCast(TypeRep.get());
+
   assert((Tok.is(tok::l_paren) ||
           (getLangOpts().CPlusPlus11 && Tok.is(tok::l_brace)))
          && "Expected '(' or '{'!");
@@ -1740,6 +1753,7 @@
                                                 Sema::ConditionKind CK,
                                                 ForRangeInfo *FRI) {
   ParenBraceBracketBalancer BalancerRAIIObj(*this);
+  auto TypeRAII = PreferredType.enterCondition(Actions);
 
   if (Tok.is(tok::code_completion)) {
     Actions.CodeCompleteOrdinaryName(getCurScope(), Sema::PCC_Condition);
@@ -1859,6 +1873,7 @@
         diag::warn_cxx98_compat_generalized_initializer_lists);
     InitExpr = ParseBraceInitializer();
   } else if (CopyInitialization) {
+    auto TypeRAII = PreferredType.enterVariableInit(DeclOut);
     InitExpr = ParseAssignmentExpression();
   } else if (Tok.is(tok::l_paren)) {
     // This was probably an attempt to initialize the variable.
@@ -2966,6 +2981,7 @@
   assert(Tok.is(tok::kw_delete) && "Expected 'delete' keyword");
   ConsumeToken(); // Consume 'delete'
 
+  auto TypeRAII = PreferredType.enterUnknown();
   // Array delete?
   bool ArrayDelete = false;
   if (Tok.is(tok::l_square) && NextToken().is(tok::r_square)) {
@@ -3043,6 +3059,8 @@
 ///         type-id ...[opt] type-id-seq[opt]
 ///
 ExprResult Parser::ParseTypeTrait() {
+  auto TypeRAII = PreferredType.enterUnknown();
+
   tok::TokenKind Kind = Tok.getKind();
   unsigned Arity = TypeTraitArity(Kind);
 
@@ -3102,6 +3120,8 @@
 ///         [Embarcadero] '__array_extent' '(' type-id ',' expression ')'
 ///
 ExprResult Parser::ParseArrayTypeTrait() {
+  auto TypeRAII = PreferredType.enterUnknown();
+
   ArrayTypeTrait ATT = ArrayTypeTraitFromTokKind(Tok.getKind());
   SourceLocation Loc = ConsumeToken();
 
@@ -3145,6 +3165,8 @@
 ///         [Embarcadero] expression-trait '(' expression ')'
 ///
 ExprResult Parser::ParseExpressionTrait() {
+  auto TypeRAII = PreferredType.enterUnknown();
+
   ExpressionTrait ET = ExpressionTraitFromTokKind(Tok.getKind());
   SourceLocation Loc = ConsumeToken();
 
Index: lib/Parse/ParseInit.cpp
===================================================================
--- lib/Parse/ParseInit.cpp
+++ lib/Parse/ParseInit.cpp
@@ -386,6 +386,8 @@
 ///         initializer-list ',' designation[opt] initializer ...[opt]
 ///
 ExprResult Parser::ParseBraceInitializer() {
+  auto TypeRAII = PreferredType.enterUnknown();
+
   InMessageExpressionRAIIObject InMessage(*this, false);
 
   BalancedDelimiterTracker T(*this, tok::l_brace);
Index: lib/Parse/ParseStmt.cpp
===================================================================
--- lib/Parse/ParseStmt.cpp
+++ lib/Parse/ParseStmt.cpp
@@ -1971,9 +1971,11 @@
 
   ExprResult R;
   if (Tok.isNot(tok::semi)) {
+    auto TypeRAII = IsCoreturn ? PreferredType.enterUnknown()
+                               : PreferredType.enterReturn(Actions);
     // FIXME: Code completion for co_return.
if (Tok.is(tok::code_completion) && !IsCoreturn) { - Actions.CodeCompleteReturn(getCurScope()); + Actions.CodeCompleteExpression(getCurScope(), PreferredType.get()); cutOffParsing(); return StmtError(); } Index: lib/Parse/ParseTemplate.cpp =================================================================== --- lib/Parse/ParseTemplate.cpp +++ lib/Parse/ParseTemplate.cpp @@ -1304,6 +1304,7 @@ /// template-argument-list ',' template-argument bool Parser::ParseTemplateArgumentList(TemplateArgList &TemplateArgs) { + auto TypeRAII = PreferredType.enterUnknown(); ColonProtectionRAIIObject ColonProtection(*this, false); Index: lib/Sema/SemaCodeComplete.cpp =================================================================== --- lib/Sema/SemaCodeComplete.cpp +++ lib/Sema/SemaCodeComplete.cpp @@ -348,6 +348,180 @@ }; } // namespace +PreferredTypeBuilder::RestoreRAII +PreferredTypeBuilder::update(llvm::function_ref Updater) { + RestoreRAII R(*this); + Updater(); + return R; +} + +PreferredTypeBuilder::RestoreRAII PreferredTypeBuilder::enterReturn(Sema &S) { + return update([&]() { + if (isa(S.CurContext)) { + if (sema::BlockScopeInfo *BSI = S.getCurBlock()) + Type = BSI->ReturnType; + return; + } + if (const auto *Function = dyn_cast(S.CurContext)) { + Type = Function->getReturnType(); + return; + } + if (const auto *Method = dyn_cast(S.CurContext)) { + Type = Method->getReturnType(); + return; + } + Type = QualType(); + }); +} + +PreferredTypeBuilder::RestoreRAII +PreferredTypeBuilder::enterVariableInit(Decl *D) { + return update([&]() { + auto *VD = llvm::dyn_cast_or_null(D); + Type = VD ? 
VD->getType() : QualType(); + }); +} + +static QualType getPreferredTypeOfBinaryRHS(Sema &S, Expr *LHS, + tok::TokenKind Op) { + if (!LHS) + return QualType(); + + QualType LHSType = LHS->getType(); + if (LHSType->isPointerType()) { + if (Op == tok::plus || Op == tok::plusequal || Op == tok::minusequal) + return S.getASTContext().getPointerDiffType(); + // Pointer difference is more common than subtracting an int from a pointer. + if (Op == tok::minus) + return LHSType; + } + + switch (Op) { + // No way to infer the type of RHS from LHS. + case tok::comma: + return QualType(); + // Prefer the type of the left operand for all of these. + // Arithmetic operations. + case tok::plus: + case tok::plusequal: + case tok::minus: + case tok::minusequal: + case tok::percent: + case tok::percentequal: + case tok::slash: + case tok::slashequal: + case tok::star: + case tok::starequal: + // Assignment. + case tok::equal: + // Comparison operators. + case tok::equalequal: + case tok::exclaimequal: + case tok::less: + case tok::lessequal: + case tok::greater: + case tok::greaterequal: + case tok::spaceship: + return LHS->getType(); + // Binary shifts are often overloaded, so don't try to guess those. + case tok::greatergreater: + case tok::greatergreaterequal: + case tok::lessless: + case tok::lesslessequal: + if (LHSType->isIntegralOrEnumerationType()) + return S.getASTContext().IntTy; + return QualType(); + // Logical operators, assume we want bool. + case tok::ampamp: + case tok::pipepipe: + case tok::caretcaret: + return S.getASTContext().BoolTy; + // Operators often used for bit manipulation are typically used with the type + // of the left argument. + case tok::pipe: + case tok::pipeequal: + case tok::caret: + case tok::caretequal: + case tok::amp: + case tok::ampequal: + if (LHSType->isIntegralOrEnumerationType()) + return LHSType; + return QualType(); + // RHS should be a pointer to a member of the 'LHS' type, but we can't give + // any particular type here. 
+  case tok::periodstar:
+  case tok::arrowstar:
+    return QualType();
+  default:
+    // FIXME(ibiryukov): handle the missing op, re-add the assertion.
+    // assert(false && "unhandled binary op");
+    return QualType();
+  }
+}
+
+PreferredTypeBuilder::RestoreRAII
+PreferredTypeBuilder::enterBinary(Sema &S, Expr *LHS, tok::TokenKind Op) {
+  return update([&] { Type = getPreferredTypeOfBinaryRHS(S, LHS, Op); });
+}
+
+PreferredTypeBuilder::RestoreRAII
+PreferredTypeBuilder::enterUnary(Sema &S, tok::TokenKind Op) {
+  return update([&] {
+    switch (Op) {
+    case tok::exclaim:
+      Type = S.getASTContext().BoolTy;
+      break;
+    case tok::amp:
+      if (!Type.isNull() && Type->isPointerType())
+        Type = Type->getPointeeType();
+      else
+        Type = QualType();
+      break;
+    case tok::star:
+      if (Type.isNull())
+        break;
+      Type = S.getASTContext().getPointerType(Type.getNonReferenceType());
+      break;
+    case tok::plus:
+    case tok::minus:
+    case tok::tilde:
+    case tok::minusminus:
+    case tok::plusplus:
+      if (Type.isNull())
+        Type = S.getASTContext().IntTy;
+      // Otherwise keep the outer type; these operators usually preserve it.
+      break;
+    case tok::kw___real:
+    case tok::kw___imag:
+      Type = QualType();
+      break;
+    default:
+      assert(false && "unhandled unary op");
+      Type = QualType();
+      break;
+    }
+  });
+}
+
+PreferredTypeBuilder::RestoreRAII
+PreferredTypeBuilder::enterSubscript(Sema &S, Expr *LHS) {
+  return update([&]() { Type = S.getASTContext().IntTy; });
+}
+
+PreferredTypeBuilder::RestoreRAII
+PreferredTypeBuilder::enterTypeCast(QualType CastType) {
+  return update([&] { Type = CastType; });
+}
+
+PreferredTypeBuilder::RestoreRAII PreferredTypeBuilder::enterUnknown() {
+  return update([&] { Type = QualType(); });
+}
+
+PreferredTypeBuilder::RestoreRAII
+PreferredTypeBuilder::enterCondition(Sema &S) {
+  return update([&] { Type = S.getASTContext().BoolTy; });
+}
+
 class ResultBuilder::ShadowMapEntry::iterator {
   llvm::PointerUnion<const NamedDecl *, const DeclIndexPair *> DeclOrIterator;
   unsigned SingleDeclIndex;
@@ -3856,13 +4030,15 @@
 }
 
 struct Sema::CodeCompleteExpressionData {
-  CodeCompleteExpressionData(QualType PreferredType = QualType())
+  CodeCompleteExpressionData(QualType PreferredType = QualType(),
+                             bool IsParenthesized = false)
       : PreferredType(PreferredType), IntegralConstantExpression(false),
-        ObjCCollection(false) {}
+        ObjCCollection(false), IsParenthesized(IsParenthesized) {}
 
   QualType PreferredType;
   bool IntegralConstantExpression;
   bool ObjCCollection;
+  bool IsParenthesized;
   SmallVector<Decl *, 4> IgnoreDecls;
 };
 
@@ -3873,13 +4049,18 @@
   ResultBuilder Results(
       *this, CodeCompleter->getAllocator(),
       CodeCompleter->getCodeCompletionTUInfo(),
-      CodeCompletionContext(CodeCompletionContext::CCC_Expression,
-                            Data.PreferredType));
+      CodeCompletionContext(
+          Data.IsParenthesized
+              ? CodeCompletionContext::CCC_ParenthesizedExpression
+              : CodeCompletionContext::CCC_Expression,
+          Data.PreferredType));
+  auto PCC =
+      Data.IsParenthesized ? PCC_ParenthesizedExpression : PCC_Expression;
   if (Data.ObjCCollection)
     Results.setFilter(&ResultBuilder::IsObjCCollection);
   else if (Data.IntegralConstantExpression)
     Results.setFilter(&ResultBuilder::IsIntegralConstantValue);
-  else if (WantTypesInContext(PCC_Expression, getLangOpts()))
+  else if (WantTypesInContext(PCC, getLangOpts()))
     Results.setFilter(&ResultBuilder::IsOrdinaryName);
   else
     Results.setFilter(&ResultBuilder::IsOrdinaryNonTypeName);
@@ -3897,7 +4078,7 @@
                                    CodeCompleter->loadExternal());
 
   Results.EnterNewScope();
-  AddOrdinaryNameResults(PCC_Expression, S, *this, Results);
+  AddOrdinaryNameResults(PCC, S, *this, Results);
   Results.ExitScope();
 
   bool PreferredTypeIsPointer = false;
@@ -3917,13 +4098,16 @@
                             Results.data(), Results.size());
 }
 
-void Sema::CodeCompleteExpression(Scope *S, QualType PreferredType) {
-  return CodeCompleteExpression(S, CodeCompleteExpressionData(PreferredType));
+void Sema::CodeCompleteExpression(Scope *S, QualType PreferredType,
+                                  bool IsParenthesized) {
+  return CodeCompleteExpression(
+      S, CodeCompleteExpressionData(PreferredType, IsParenthesized));
 }
 
-void Sema::CodeCompletePostfixExpression(Scope *S, ExprResult E) {
+void Sema::CodeCompletePostfixExpression(Scope *S, ExprResult E,
+                                         QualType PreferredType) {
   if (E.isInvalid())
-    CodeCompleteOrdinaryName(S, PCC_RecoveryInFunction);
+    CodeCompleteExpression(S, PreferredType);
   else if (getLangOpts().ObjC)
     CodeCompleteObjCInstanceMessage(S, E.get(), None, false);
 }
@@ -4211,7 +4395,8 @@
 void Sema::CodeCompleteMemberReferenceExpr(Scope *S, Expr *Base,
                                            Expr *OtherOpBase,
                                            SourceLocation OpLoc, bool IsArrow,
-                                           bool IsBaseExprStatement) {
+                                           bool IsBaseExprStatement,
+                                           QualType PreferredType) {
   if (!Base || !CodeCompleter)
     return;
 
@@ -4239,6 +4424,7 @@
   }
 
   CodeCompletionContext CCContext(contextKind, ConvertedBaseType);
+  CCContext.setPreferredType(PreferredType);
   ResultBuilder Results(*this, CodeCompleter->getAllocator(),
                         CodeCompleter->getCodeCompletionTUInfo(), CCContext,
&ResultBuilder::IsMember); @@ -4800,22 +4986,6 @@ CodeCompleteExpression(S, Data); } -void Sema::CodeCompleteReturn(Scope *S) { - QualType ResultType; - if (isa(CurContext)) { - if (BlockScopeInfo *BSI = getCurBlock()) - ResultType = BSI->ReturnType; - } else if (const auto *Function = dyn_cast(CurContext)) - ResultType = Function->getReturnType(); - else if (const auto *Method = dyn_cast(CurContext)) - ResultType = Method->getReturnType(); - - if (ResultType.isNull()) - CodeCompleteOrdinaryName(S, PCC_Expression); - else - CodeCompleteExpression(S, ResultType); -} - void Sema::CodeCompleteAfterIf(Scope *S) { ResultBuilder Results(*this, CodeCompleter->getAllocator(), CodeCompleter->getCodeCompletionTUInfo(), @@ -4877,91 +5047,6 @@ Results.data(), Results.size()); } -static QualType getPreferredTypeOfBinaryRHS(Sema &S, Expr *LHS, - tok::TokenKind Op) { - if (!LHS) - return QualType(); - - QualType LHSType = LHS->getType(); - if (LHSType->isPointerType()) { - if (Op == tok::plus || Op == tok::plusequal || Op == tok::minusequal) - return S.getASTContext().getPointerDiffType(); - // Pointer difference is more common than subtracting an int from a pointer. - if (Op == tok::minus) - return LHSType; - } - - switch (Op) { - // No way to infer the type of RHS from LHS. - case tok::comma: - return QualType(); - // Prefer the type of the left operand for all of these. - // Arithmetic operations. - case tok::plus: - case tok::plusequal: - case tok::minus: - case tok::minusequal: - case tok::percent: - case tok::percentequal: - case tok::slash: - case tok::slashequal: - case tok::star: - case tok::starequal: - // Assignment. - case tok::equal: - // Comparison operators. - case tok::equalequal: - case tok::exclaimequal: - case tok::less: - case tok::lessequal: - case tok::greater: - case tok::greaterequal: - case tok::spaceship: - return LHS->getType(); - // Binary shifts are often overloaded, so don't try to guess those. 
-  case tok::greatergreater:
-  case tok::greatergreaterequal:
-  case tok::lessless:
-  case tok::lesslessequal:
-    if (LHSType->isIntegralOrEnumerationType())
-      return S.getASTContext().IntTy;
-    return QualType();
-  // Logical operators, assume we want bool.
-  case tok::ampamp:
-  case tok::pipepipe:
-  case tok::caretcaret:
-    return S.getASTContext().BoolTy;
-  // Operators often used for bit manipulation are typically used with the type
-  // of the left argument.
-  case tok::pipe:
-  case tok::pipeequal:
-  case tok::caret:
-  case tok::caretequal:
-  case tok::amp:
-  case tok::ampequal:
-    if (LHSType->isIntegralOrEnumerationType())
-      return LHSType;
-    return QualType();
-  // RHS should be a pointer to a member of the 'LHS' type, but we can't give
-  // any particular type here.
-  case tok::periodstar:
-  case tok::arrowstar:
-    return QualType();
-  default:
-    // FIXME(ibiryukov): handle the missing op, re-add the assertion.
-    // assert(false && "unhandled binary op");
-    return QualType();
-  }
-}
-
-void Sema::CodeCompleteBinaryRHS(Scope *S, Expr *LHS, tok::TokenKind Op) {
-  auto PreferredType = getPreferredTypeOfBinaryRHS(*this, LHS, Op);
-  if (!PreferredType.isNull())
-    CodeCompleteExpression(S, PreferredType);
-  else
-    CodeCompleteOrdinaryName(S, PCC_Expression);
-}
-
 void Sema::CodeCompleteQualifiedId(Scope *S, CXXScopeSpec &SS,
                                    bool EnteringContext, QualType BaseType) {
   if (SS.isEmpty() || !CodeCompleter)
Index: unittests/Sema/CodeCompleteTest.cpp
===================================================================
--- unittests/Sema/CodeCompleteTest.cpp
+++ unittests/Sema/CodeCompleteTest.cpp
@@ -340,4 +340,96 @@
   EXPECT_THAT(collectPreferredTypes(Code), Each("NULL TYPE"));
 }
 
+TEST(PreferredTypeTest, Members) {
+  StringRef Code = R"cpp(
+    struct vector {
+      int *begin();
+      vector clone();
+    };
+
+    void test(int *a) {
+      a = ^vector().^clone().^begin();
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("int *"));
+}
+
+TEST(PreferredTypeTest, Conditions) {
+  StringRef Code = R"cpp(
+    struct vector {
+      bool empty();
+    };
+
+    void test() {
+      if (^vector().^empty()) {}
+      while (^vector().^empty()) {}
+      for (; ^vector().^empty();) {}
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("_Bool"));
+}
+
+TEST(PreferredTypeTest, InitAndAssignment) {
+  StringRef Code = R"cpp(
+    struct vector {
+      int* begin();
+    };
+
+    void test() {
+      const int* x = ^vector().^begin();
+      x = ^vector().^begin();
+
+      if (const int* y = ^vector().^begin()) {}
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("const int *"));
+}
+
+TEST(PreferredTypeTest, UnaryExprs) {
+  StringRef Code = R"cpp(
+    void test(long long a) {
+      a = +^a;
+      a = -^a;
+      a = ++^a;
+      a = --^a;
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("long long"));
+
+  Code = R"cpp(
+    void test(int a, int *ptr) {
+      !^a;
+      !^ptr;
+      !!!^a;
+
+      a = !^a;
+      a = !^ptr;
+      a = !!!^a;
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("_Bool"));
+
+  Code = R"cpp(
+    void test(int a) {
+      const int* x = &^a;
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("const int"));
+
+  Code = R"cpp(
+    void test(int *a) {
+      int x = *^a;
+      int &r = *^a;
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("int *"));
+
+  Code = R"cpp(
+    void test(int a) {
+      *^a;
+      &^a;
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("NULL TYPE"));
+}
 } // namespace