Index: include/clang/Sema/Sema.h =================================================================== --- include/clang/Sema/Sema.h +++ include/clang/Sema/Sema.h @@ -10350,7 +10350,7 @@ void CodeCompleteInitializer(Scope *S, Decl *D); void CodeCompleteReturn(Scope *S); void CodeCompleteAfterIf(Scope *S); - void CodeCompleteAssignmentRHS(Scope *S, Expr *LHS); + void CodeCompleteBinaryRHS(Scope *S, Expr *LHS, tok::TokenKind Op); void CodeCompleteQualifiedId(Scope *S, CXXScopeSpec &SS, bool EnteringContext, QualType BaseType); Index: lib/Parse/ParseExpr.cpp =================================================================== --- lib/Parse/ParseExpr.cpp +++ lib/Parse/ParseExpr.cpp @@ -393,10 +393,11 @@ } } - // Code completion for the right-hand side of an assignment expression - // goes through a special hook that takes the left-hand side into account. - if (Tok.is(tok::code_completion) && NextTokPrec == prec::Assignment) { - Actions.CodeCompleteAssignmentRHS(getCurScope(), LHS.get()); + // Code completion for the right-hand side of a binary expression goes + // through a special hook that takes the left-hand side into account.
+ if (Tok.is(tok::code_completion)) { + Actions.CodeCompleteBinaryRHS(getCurScope(), LHS.get(), + OpToken.getKind()); cutOffParsing(); return ExprError(); } Index: lib/Sema/SemaCodeComplete.cpp =================================================================== --- lib/Sema/SemaCodeComplete.cpp +++ lib/Sema/SemaCodeComplete.cpp @@ -4878,9 +4878,91 @@ Results.data(), Results.size()); } -void Sema::CodeCompleteAssignmentRHS(Scope *S, Expr *LHS) { - if (LHS) - CodeCompleteExpression(S, static_cast(LHS)->getType()); +static QualType getPreferredTypeOfBinaryRHS(Sema &S, Expr *LHS, + tok::TokenKind Op) { + if (!LHS) + return QualType(); + + QualType LHSType = LHS->getType(); + if (LHSType->isPointerType()) { + if (Op == tok::plus || Op == tok::plusequal || Op == tok::minusequal) + return S.getASTContext().getPointerDiffType(); + // Pointer difference is more common than subtracting an int from a pointer. + if (Op == tok::minus) + return LHSType; + } + + switch (Op) { + // No way to infer the type of RHS from LHS. + case tok::comma: + return QualType(); + // Prefer the type of the left operand for all of these. + // Arithmetic operations. + case tok::plus: + case tok::plusequal: + case tok::minus: + case tok::minusequal: + case tok::percent: + case tok::percentequal: + case tok::slash: + case tok::slashequal: + case tok::star: + case tok::starequal: + // Assignment. + case tok::equal: + // Comparison operators. + case tok::equalequal: + case tok::exclaimequal: + case tok::less: + case tok::lessequal: + case tok::greater: + case tok::greaterequal: + case tok::spaceship: + return LHSType; + // Binary shifts are often overloaded; only guess 'int' for integral/enum LHS. + case tok::greatergreater: + case tok::greatergreaterequal: + case tok::lessless: + case tok::lesslessequal: + if (LHSType->isIntegralOrEnumerationType()) + return S.getASTContext().IntTy; + return QualType(); + // Logical operators, assume we want bool.
+ case tok::ampamp: + case tok::pipepipe: + case tok::caretcaret: + return S.getASTContext().BoolTy; + // Operators often used for bit manipulation are typically used with the type + // of the left argument. + case tok::pipe: + case tok::pipeequal: + case tok::caret: + case tok::caretequal: + case tok::amp: + case tok::ampequal: + if (LHSType->isIntegralOrEnumerationType()) + return LHSType; + return QualType(); + // RHS should be a pointer to a member of the 'LHS' type, but we can't give + // any particular type here. + case tok::periodstar: + case tok::arrowstar: + return QualType(); + // The hook also fires after the ':' of a conditional operator; LHS is then + // the condition, which says nothing about the type of the false branch. + case tok::question: + return QualType(); + default: + // Completion is best-effort: fall back to no preferred type instead of + // asserting on an operator that is not modeled above. + return QualType(); + } +} + +void Sema::CodeCompleteBinaryRHS(Scope *S, Expr *LHS, tok::TokenKind Op) { + auto PreferredType = getPreferredTypeOfBinaryRHS(*this, LHS, Op); + if (!PreferredType.isNull()) + CodeCompleteExpression(S, PreferredType); else CodeCompleteOrdinaryName(S, PCC_Expression); } Index: test/Index/complete-exprs.c =================================================================== --- test/Index/complete-exprs.c +++ test/Index/complete-exprs.c @@ -33,16 +33,11 @@ // CHECK-CC1: ParmDecl:{ResultType int}{TypedText j} (8) // CHECK-CC1: NotImplemented:{ResultType size_t}{TypedText sizeof}{LeftParen (}{Placeholder expression-or-type}{RightParen )} (40) // RUN: env CINDEXTEST_EDITING=1 CINDEXTEST_COMPLETION_CACHING=1 c-index-test -code-completion-at=%s:7:10 -Xclang -code-completion-patterns %s | FileCheck -check-prefix=CHECK-CC1 %s -// RUN: c-index-test -code-completion-at=%s:7:14 -Xclang -code-completion-patterns %s | FileCheck -check-prefix=CHECK-CC3 %s -// RUN: env CINDEXTEST_EDITING=1 CINDEXTEST_COMPLETION_CACHING=1 c-index-test -code-completion-at=%s:7:14 -Xclang -code-completion-patterns %s | FileCheck -check-prefix=CHECK-CC3 %s -// CHECK-CC3: macro definition:{TypedText __VERSION__} (70) -// CHECK-CC3: FunctionDecl:{ResultType int}{TypedText f}{LeftParen (}{Placeholder int}{RightParen )} (50) -// CHECK-CC3-NOT: 
NotImplemented:{TypedText float} -// CHECK-CC3: ParmDecl:{ResultType int}{TypedText j} (34) -// CHECK-CC3: NotImplemented:{ResultType size_t}{TypedText sizeof}{LeftParen (}{Placeholder expressio +// RUN: c-index-test -code-completion-at=%s:7:14 -Xclang -code-completion-patterns %s | FileCheck -check-prefix=CHECK-CC1 %s +// RUN: env CINDEXTEST_EDITING=1 CINDEXTEST_COMPLETION_CACHING=1 c-index-test -code-completion-at=%s:7:14 -Xclang -code-completion-patterns %s | FileCheck -check-prefix=CHECK-CC1 %s -// RUN: c-index-test -code-completion-at=%s:7:18 -Xclang -code-completion-patterns %s | FileCheck -check-prefix=CHECK-CC3 %s -// RUN: c-index-test -code-completion-at=%s:7:22 -Xclang -code-completion-patterns %s | FileCheck -check-prefix=CHECK-CC3 %s +// RUN: c-index-test -code-completion-at=%s:7:18 -Xclang -code-completion-patterns %s | FileCheck -check-prefix=CHECK-CC1 %s +// RUN: c-index-test -code-completion-at=%s:7:22 -Xclang -code-completion-patterns %s | FileCheck -check-prefix=CHECK-CC1 %s // RUN: c-index-test -code-completion-at=%s:7:2 -Xclang -code-completion-patterns %s | FileCheck -check-prefix=CHECK-CC2 %s // CHECK-CC2: macro definition:{TypedText __VERSION__} (70) // CHECK-CC2: FunctionDecl:{ResultType int}{TypedText f}{LeftParen (}{Placeholder int}{RightParen )} (50) Index: unittests/Sema/CodeCompleteTest.cpp =================================================================== --- unittests/Sema/CodeCompleteTest.cpp +++ unittests/Sema/CodeCompleteTest.cpp @@ -14,31 +14,39 @@ #include "clang/Sema/Sema.h" #include "clang/Sema/SemaDiagnostic.h" #include "clang/Tooling/Tooling.h" -#include "gtest/gtest.h" #include "gmock/gmock.h" +#include "gtest/gtest.h" +#include +#include namespace { using namespace clang; using namespace clang::tooling; +using ::testing::Each; using ::testing::UnorderedElementsAre; const char TestCCName[] = "test.cc"; -using VisitedContextResults = std::vector; -class VisitedContextFinder: public CodeCompleteConsumer { +struct
CompletionContext { + std::vector VisitedNamespaces; + std::string PreferredType; +}; + +class VisitedContextFinder : public CodeCompleteConsumer { public: - VisitedContextFinder(VisitedContextResults &Results) + VisitedContextFinder(CompletionContext &ResultCtx) : CodeCompleteConsumer(/*CodeCompleteOpts=*/{}, /*CodeCompleteConsumer*/ false), - VCResults(Results), + ResultCtx(ResultCtx), CCTUInfo(std::make_shared()) {} void ProcessCodeCompleteResults(Sema &S, CodeCompletionContext Context, CodeCompletionResult *Results, unsigned NumResults) override { - VisitedContexts = Context.getVisitedContexts(); - VCResults = getVisitedNamespace(); + ResultCtx.VisitedNamespaces = + getVisitedNamespace(Context.getVisitedContexts()); + ResultCtx.PreferredType = Context.getPreferredType().getAsString(); } CodeCompletionAllocator &getAllocator() override { @@ -47,7 +55,9 @@ CodeCompletionTUInfo &getCodeCompletionTUInfo() override { return CCTUInfo; } - std::vector getVisitedNamespace() const { +private: + std::vector getVisitedNamespace( + const CodeCompletionContext::VisitedContextSet &VisitedContexts) const { std::vector NSNames; for (const auto *Context : VisitedContexts) if (const auto *NS = llvm::dyn_cast(Context)) @@ -55,27 +65,25 @@ return NSNames; } -private: - VisitedContextResults& VCResults; + CompletionContext &ResultCtx; CodeCompletionTUInfo CCTUInfo; - CodeCompletionContext::VisitedContextSet VisitedContexts; }; class CodeCompleteAction : public SyntaxOnlyAction { public: - CodeCompleteAction(ParsedSourceLocation P, VisitedContextResults &Results) - : CompletePosition(std::move(P)), VCResults(Results) {} + CodeCompleteAction(ParsedSourceLocation P, CompletionContext &ResultCtx) + : CompletePosition(std::move(P)), ResultCtx(ResultCtx) {} bool BeginInvocation(CompilerInstance &CI) override { CI.getFrontendOpts().CodeCompletionAt = CompletePosition; - CI.setCodeCompletionConsumer(new VisitedContextFinder(VCResults)); + CI.setCodeCompletionConsumer(new
VisitedContextFinder(ResultCtx)); return true; } private: // 1-based code complete position ; ParsedSourceLocation CompletePosition; - VisitedContextResults& VCResults; + CompletionContext &ResultCtx; }; ParsedSourceLocation offsetToPosition(llvm::StringRef Code, size_t Offset) { @@ -88,21 +96,49 @@ static_cast(Offset - StartOfLine + 1)}; } -VisitedContextResults runCodeCompleteOnCode(StringRef Code) { - VisitedContextResults Results; - auto TokenOffset = Code.find('^'); - assert(TokenOffset != StringRef::npos && - "Completion token ^ wasn't found in Code."); - std::string WithoutToken = Code.take_front(TokenOffset); - WithoutToken += Code.drop_front(WithoutToken.size() + 1); - assert(StringRef(WithoutToken).find('^') == StringRef::npos && - "expected exactly one completion token ^ inside the code"); - +CompletionContext runCompletion(StringRef Code, size_t Offset) { + CompletionContext ResultCtx; auto Action = llvm::make_unique( - offsetToPosition(WithoutToken, TokenOffset), Results); + offsetToPosition(Code, Offset), ResultCtx); clang::tooling::runToolOnCodeWithArgs(Action.release(), Code, {"-std=c++11"}, TestCCName); - return Results; + return ResultCtx; +} + +struct ParsedAnnotations { + std::vector Points; + std::string Code; +}; + +ParsedAnnotations parseAnnotations(StringRef AnnotatedCode) { + ParsedAnnotations R; + while (!AnnotatedCode.empty()) { + size_t NextPoint = AnnotatedCode.find('^'); + if (NextPoint == StringRef::npos) { + R.Code += AnnotatedCode; + AnnotatedCode = ""; + break; + } + R.Code += AnnotatedCode.substr(0, NextPoint); + R.Points.push_back(R.Code.size()); + + AnnotatedCode = AnnotatedCode.substr(NextPoint + 1); + } + return R; +} + +CompletionContext runCodeCompleteOnCode(StringRef AnnotatedCode) { + ParsedAnnotations P = parseAnnotations(AnnotatedCode); + assert(P.Points.size() == 1 && "expected exactly one annotation point"); + return runCompletion(P.Code, P.Points.front()); +} + +std::vector collectPreferredTypes(StringRef
AnnotatedCode) { + ParsedAnnotations P = parseAnnotations(AnnotatedCode); + std::vector Types; + for (size_t Point : P.Points) + Types.push_back(runCompletion(P.Code, Point).PreferredType); + return Types; } TEST(SemaCodeCompleteTest, VisitedNSForValidQualifiedId) { @@ -119,7 +155,8 @@ inline namespace bar { using namespace ns3::nns3; } } // foo namespace ns { foo::^ } - )cpp"); + )cpp") + .VisitedNamespaces; EXPECT_THAT(VisitedNS, UnorderedElementsAre("foo", "ns1", "ns2", "ns3::nns3", "foo::(anonymous)")); } @@ -127,7 +164,8 @@ TEST(SemaCodeCompleteTest, VisitedNSForInvalideQualifiedId) { auto VisitedNS = runCodeCompleteOnCode(R"cpp( namespace ns { foo::^ } - )cpp"); + )cpp") + .VisitedNamespaces; EXPECT_TRUE(VisitedNS.empty()); } @@ -138,8 +176,159 @@ void f(^) {} } } - )cpp"); + )cpp") + .VisitedNamespaces; EXPECT_THAT(VisitedNS, UnorderedElementsAre("n1", "n1::n2")); } +TEST(PreferredTypeTest, BinaryExpr) { + // Check various operations for arithmetic types. + EXPECT_THAT(collectPreferredTypes(R"cpp( + void test(int x) { + x = ^10; + x += ^10; x -= ^10; x *= ^10; x /= ^10; x %= ^10; + x + ^10; x - ^10; x * ^10; x / ^10; x % ^10; + })cpp"), + Each("int")); + EXPECT_THAT(collectPreferredTypes(R"cpp( + void test(float x) { + x = ^10; + x += ^10; x -= ^10; x *= ^10; x /= ^10; x %= ^10; + x + ^10; x - ^10; x * ^10; x / ^10; x % ^10; + })cpp"), + Each("float")); + + // Pointer types. + EXPECT_THAT(collectPreferredTypes(R"cpp( + void test(int *ptr) { + ptr - ^ptr; + ptr = ^ptr; + })cpp"), + Each("int *")); + + EXPECT_THAT(collectPreferredTypes(R"cpp( + void test(int *ptr) { + ptr + ^10; + ptr += ^10; + ptr -= ^10; + })cpp"), + Each("long")); // long is normalized 'ptrdiff_t' (LP64; host-dependent). + + // Comparison operators.
+ EXPECT_THAT(collectPreferredTypes(R"cpp( + void test(int i) { + i <= ^1; i < ^1; i >= ^1; i > ^1; i == ^1; i != ^1; + } + )cpp"), + Each("int")); + + EXPECT_THAT(collectPreferredTypes(R"cpp( + void test(int *ptr) { + ptr <= ^ptr; ptr < ^ptr; ptr >= ^ptr; ptr > ^ptr; + ptr == ^ptr; ptr != ^ptr; + } + )cpp"), + Each("int *")); + + // Logical operations. + EXPECT_THAT(collectPreferredTypes(R"cpp( + void test(int i, int *ptr) { + i && ^1; i || ^1; + ptr && ^1; ptr || ^1; + } + )cpp"), + Each("_Bool")); + + // Bitwise operations. + EXPECT_THAT(collectPreferredTypes(R"cpp( + void test(long long ll) { + ll | ^1; ll & ^1; + } + )cpp"), + Each("long long")); + + EXPECT_THAT(collectPreferredTypes(R"cpp( + enum A {}; + void test(A a) { + a | ^1; a & ^1; + } + )cpp"), + Each("enum A")); + + EXPECT_THAT(collectPreferredTypes(R"cpp( + enum class A {}; + void test(A a) { + // This is technically illegal with the 'enum class' without overloaded + // operators, but we pretend it's fine. + a | ^a; a & ^a; + } + )cpp"), + Each("enum A")); + + // Binary shifts. + EXPECT_THAT(collectPreferredTypes(R"cpp( + void test(int i, long long ll) { + i << ^1; ll << ^1; + i <<= ^1; ll <<= ^1; + i >> ^1; ll >> ^1; + i >>= ^1; ll >>= ^1; + } + )cpp"), + Each("int")); + + // Comma does not provide any useful information. + EXPECT_THAT(collectPreferredTypes(R"cpp( + class Cls {}; + void test(int i, int* ptr, Cls x) { + (i, ^i); + (ptr, ^ptr); + (x, ^x); + } + )cpp"), + Each("NULL TYPE")); + + // The conditional operator: the condition tells us nothing about the type + // of the branches. + EXPECT_THAT(collectPreferredTypes(R"cpp( + void test(bool b, int i) { + b ? i : ^i; + } + )cpp"), + Each("NULL TYPE")); + + // User-defined types do not take operator overloading into account. + // However, they provide heuristics for some common cases. + EXPECT_THAT(collectPreferredTypes(R"cpp( + class Cls {}; + void test(Cls c) { + // we assume arithmetic and comparison ops take the same type. + c + ^c; c - ^c; c * ^c; c / ^c; c % ^c; + c == ^c; c != ^c; c < ^c; c <= ^c; c > ^c; c >= ^c; + // same for the assignments.
+ c = ^c; c += ^c; c -= ^c; c *= ^c; c /= ^c; c %= ^c; + } + )cpp"), + Each("class Cls")); + + EXPECT_THAT(collectPreferredTypes(R"cpp( + class Cls {}; + void test(Cls c) { + // we assume logical ops operate on bools. + c && ^c; c || ^c; + } + )cpp"), + Each("_Bool")); + + EXPECT_THAT(collectPreferredTypes(R"cpp( + class Cls {}; + void test(Cls c) { + // we make no assumptions about the following operators, since they are + // often overloaded with a non-standard meaning. + c << ^c; c >> ^c; c | ^c; c & ^c; + c <<= ^c; c >>= ^c; c |= ^c; c &= ^c; + } + )cpp"), + Each("NULL TYPE")); +} + } // namespace