diff --git a/clang/include/clang/Tooling/Transformer/SourceCodeBuilders.h b/clang/include/clang/Tooling/Transformer/SourceCodeBuilders.h
--- a/clang/include/clang/Tooling/Transformer/SourceCodeBuilders.h
+++ b/clang/include/clang/Tooling/Transformer/SourceCodeBuilders.h
@@ -43,6 +43,14 @@
 /// Determines whether printing this expression to the right of a unary operator
 /// requires a parentheses to preserve its meaning.
 bool needParensAfterUnaryOperator(const Expr &E);
+
+/// Heuristic that guesses whether `Ty` is a "smart-pointer" type based on its
+/// name or overloaded operators.
+bool isSmartPointerType(QualType Ty, ASTContext &Context);
+
+/// Identifies use of `operator*` on smart pointers, and returns the underlying
+/// smart-pointer expression; otherwise, returns null.
+const Expr *isSmartDereference(const Expr &E, ASTContext &Context);
 /// @}
 
 /// \name Basic code-string generation utilities.
@@ -79,6 +87,18 @@
 /// `a+b` becomes `(a+b)->`
 llvm::Optional<std::string> buildArrow(const Expr &E,
                                        const ASTContext &Context);
+
+/// Adds an appropriate access operator (`.`, `->` or nothing, in the case of
+/// implicit `this`) to the end of the given expression, but adds parentheses
+/// when needed by the syntax, strips any `operator->` calls and simplifies when
+/// possible. For example:
+///
+///  `x` becomes `x->` or `x.`, depending on `E`'s type
+///  `x.operator->()` becomes `x->`
+///  `a+b` becomes `(a+b)->` or `(a+b).`, depending on `E`'s type
+///  `&a` becomes `a.`
+///  `*a` becomes `a->`
+llvm::Optional<std::string> buildAccess(const Expr &E, ASTContext &Context);
 /// @}
 
 } // namespace tooling
diff --git a/clang/lib/Tooling/Transformer/SourceCodeBuilders.cpp b/clang/lib/Tooling/Transformer/SourceCodeBuilders.cpp
--- a/clang/lib/Tooling/Transformer/SourceCodeBuilders.cpp
+++ b/clang/lib/Tooling/Transformer/SourceCodeBuilders.cpp
@@ -10,6 +10,8 @@
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Expr.h"
 #include "clang/AST/ExprCXX.h"
+#include "clang/ASTMatchers/ASTMatchFinder.h"
+#include "clang/ASTMatchers/ASTMatchers.h"
 #include "clang/Tooling/Transformer/SourceCode.h"
 #include "llvm/ADT/Twine.h"
 #include <string>
@@ -60,6 +62,37 @@
   return false;
 }
 
+// FIXME: Consider memoizing this function using the `ASTContext`.
+bool tooling::isSmartPointerType(QualType Ty, ASTContext &Context) {
+  using namespace ::clang::ast_matchers;
+
+  // Optimization: hard-code common smart-pointer types. This can/should be
+  // removed if we start caching the results of this function.
+  auto KnownSmartPointer =
+      cxxRecordDecl(hasAnyName("::std::unique_ptr", "::std::shared_ptr"));
+  const auto QuacksLikeASmartPointer = cxxRecordDecl(
+      hasMethod(cxxMethodDecl(hasOverloadedOperatorName("->"),
+                              returns(qualType(pointsTo(type()))))),
+      hasMethod(cxxMethodDecl(hasOverloadedOperatorName("*"),
+                              returns(qualType(references(type()))))));
+  const auto SmartPointer = qualType(hasDeclaration(
+      cxxRecordDecl(anyOf(KnownSmartPointer, QuacksLikeASmartPointer))));
+  return match(SmartPointer, Ty, Context).size() > 0;
+}
+
+const Expr *tooling::isSmartDereference(const Expr &E, ASTContext &Context) {
+  using namespace ::clang::ast_matchers;
+
+  const auto HasOverloadedArrow = cxxRecordDecl(hasMethod(cxxMethodDecl(
+      hasOverloadedOperatorName("->"), returns(qualType(pointsTo(type()))))));
+  // Verify it is a smart pointer by finding `operator->` in the class
+  // declaration.
+  auto Deref = cxxOperatorCallExpr(
+      hasOverloadedOperatorName("*"), hasUnaryOperand(expr().bind("arg")),
+      callee(cxxMethodDecl(ofClass(HasOverloadedArrow))));
+  return selectFirst<Expr>("arg", match(Deref, E, Context));
+}
+
 llvm::Optional<std::string> tooling::buildParens(const Expr &E,
                                                  const ASTContext &Context) {
   StringRef Text = getText(E, Context);
@@ -160,3 +193,33 @@
     return ("(" + Text + ")->").str();
   return (Text + "->").str();
 }
+
+llvm::Optional<std::string> tooling::buildAccess(const Expr &E,
+                                                 ASTContext &Context) {
+  // We return the empty string, because `None` signifies some sort of failure.
+  if (E.isImplicitCXXThis())
+    return std::string();
+
+  if (E.getType()->isAnyPointerType() ||
+      isSmartPointerType(E.getType(), Context)) {
+    // Strip off any operator->. This can only occur inside an actual arrow
+    // member access, so we treat it as equivalent to an actual object
+    // expression.
+    const Expr *ENorm = &E;
+    if (const auto *OpCall = dyn_cast<clang::CXXOperatorCallExpr>(ENorm)) {
+      if (OpCall->getOperator() == clang::OO_Arrow &&
+          OpCall->getNumArgs() == 1) {
+        ENorm = OpCall->getArg(0);
+      }
+    }
+    return tooling::buildArrow(*ENorm, Context);
+  }
+
+  if (const auto *Operand = isSmartDereference(E, Context)) {
+    // `buildDot` already handles the built-in dereference operator, so we
+    // only need to catch overloaded `operator*`.
+    return tooling::buildArrow(*Operand, Context);
+  }
+
+  return tooling::buildDot(E, Context);
+}
diff --git a/clang/lib/Tooling/Transformer/Stencil.cpp b/clang/lib/Tooling/Transformer/Stencil.cpp
--- a/clang/lib/Tooling/Transformer/Stencil.cpp
+++ b/clang/lib/Tooling/Transformer/Stencil.cpp
@@ -11,7 +11,6 @@
 #include "clang/AST/ASTTypeTraits.h"
 #include "clang/AST/Expr.h"
 #include "clang/ASTMatchers/ASTMatchFinder.h"
-#include "clang/ASTMatchers/ASTMatchers.h"
 #include "clang/Basic/SourceLocation.h"
 #include "clang/Lex/Lexer.h"
 #include "clang/Tooling/Transformer/SourceCode.h"
@@ -56,39 +55,6 @@
   return Error::success();
 }
 
-// FIXME: Consider memoizing this function using the `ASTContext`.
-static bool isSmartPointerType(QualType Ty, ASTContext &Context) {
-  using namespace ::clang::ast_matchers;
-
-  // Optimization: hard-code common smart-pointer types. This can/should be
-  // removed if we start caching the results of this function.
-  auto KnownSmartPointer =
-      cxxRecordDecl(hasAnyName("::std::unique_ptr", "::std::shared_ptr"));
-  const auto QuacksLikeASmartPointer = cxxRecordDecl(
-      hasMethod(cxxMethodDecl(hasOverloadedOperatorName("->"),
-                              returns(qualType(pointsTo(type()))))),
-      hasMethod(cxxMethodDecl(hasOverloadedOperatorName("*"),
-                              returns(qualType(references(type()))))));
-  const auto SmartPointer = qualType(hasDeclaration(
-      cxxRecordDecl(anyOf(KnownSmartPointer, QuacksLikeASmartPointer))));
-  return match(SmartPointer, Ty, Context).size() > 0;
-}
-
-// Identifies use of `operator*` on smart pointers, and returns the underlying
-// smart-pointer expression; otherwise, returns null.
-static const Expr *isSmartDereference(const Expr &E, ASTContext &Context) {
-  using namespace ::clang::ast_matchers;
-
-  const auto HasOverloadedArrow = cxxRecordDecl(hasMethod(cxxMethodDecl(
-      hasOverloadedOperatorName("->"), returns(qualType(pointsTo(type()))))));
-  // Verify it is a smart pointer by finding `operator->` in the class
-  // declaration.
-  auto Deref = cxxOperatorCallExpr(
-      hasOverloadedOperatorName("*"), hasUnaryOperand(expr().bind("arg")),
-      callee(cxxMethodDecl(ofClass(HasOverloadedArrow))));
-  return selectFirst<Expr>("arg", match(Deref, E, Context));
-}
-
 namespace {
 // An arbitrary fragment of code within a stencil.
 class RawTextStencil : public StencilInterface {
@@ -196,7 +162,7 @@
       break;
     case UnaryNodeOperator::MaybeDeref:
       if (E->getType()->isAnyPointerType() ||
-          isSmartPointerType(E->getType(), *Match.Context)) {
+          tooling::isSmartPointerType(E->getType(), *Match.Context)) {
         // Strip off any operator->. This can only occur inside an actual arrow
         // member access, so we treat it as equivalent to an actual object
         // expression.
@@ -216,7 +182,7 @@
       break;
     case UnaryNodeOperator::MaybeAddressOf:
       if (E->getType()->isAnyPointerType() ||
-          isSmartPointerType(E->getType(), *Match.Context)) {
+          tooling::isSmartPointerType(E->getType(), *Match.Context)) {
         // Strip off any operator->. This can only occur inside an actual arrow
         // member access, so we treat it as equivalent to an actual object
         // expression.
@@ -311,34 +277,12 @@
     if (E == nullptr)
       return llvm::make_error<StringError>(errc::invalid_argument,
                                            "Id not bound: " + BaseId);
-    if (!E->isImplicitCXXThis()) {
-      llvm::Optional<std::string> S;
-      if (E->getType()->isAnyPointerType() ||
-          isSmartPointerType(E->getType(), *Match.Context)) {
-        // Strip off any operator->. This can only occur inside an actual arrow
-        // member access, so we treat it as equivalent to an actual object
-        // expression.
-        if (const auto *OpCall = dyn_cast<clang::CXXOperatorCallExpr>(E)) {
-          if (OpCall->getOperator() == clang::OO_Arrow &&
-              OpCall->getNumArgs() == 1) {
-            E = OpCall->getArg(0);
-          }
-        }
-        S = tooling::buildArrow(*E, *Match.Context);
-      } else if (const auto *Operand = isSmartDereference(*E, *Match.Context)) {
-        // `buildDot` already handles the built-in dereference operator, so we
-        // only need to catch overloaded `operator*`.
-        S = tooling::buildArrow(*Operand, *Match.Context);
-      } else {
-        S = tooling::buildDot(*E, *Match.Context);
-      }
-      if (S.hasValue())
-        *Result += *S;
-      else
-        return llvm::make_error<StringError>(
-            errc::invalid_argument,
-            "Could not construct object text from ID: " + BaseId);
-    }
+    llvm::Optional<std::string> S = tooling::buildAccess(*E, *Match.Context);
+    if (!S.hasValue())
+      return llvm::make_error<StringError>(
+          errc::invalid_argument,
+          "Could not construct object text from ID: " + BaseId);
+    *Result += *S;
     return Member->eval(Match, Result);
   }
 };
diff --git a/clang/unittests/Tooling/SourceCodeBuildersTest.cpp b/clang/unittests/Tooling/SourceCodeBuildersTest.cpp
--- a/clang/unittests/Tooling/SourceCodeBuildersTest.cpp
+++ b/clang/unittests/Tooling/SourceCodeBuildersTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "clang/Tooling/Transformer/SourceCodeBuilders.h"
+#include "clang/AST/Type.h"
 #include "clang/ASTMatchers/ASTMatchFinder.h"
 #include "clang/ASTMatchers/ASTMatchers.h"
 #include "clang/Tooling/Tooling.h"
@@ -24,8 +25,16 @@
 
 // Create a valid translation unit from a statement.
 static std::string wrapSnippet(StringRef StatementCode) {
-  return ("struct S { S(); S(int); int field; };\n"
+  return ("namespace std {\n"
+          "template <typename T> class unique_ptr {};\n"
+          "template <typename T> class shared_ptr {};\n"
+          "}\n"
+          "struct S { S(); S(int); int field; };\n"
           "S operator+(const S &a, const S &b);\n"
+          "struct Smart {\n"
+          "  S* operator->() const;\n"
+          "  S& operator*() const;\n"
+          "};\n"
           "auto test_snippet = []{" +
           StatementCode + "};")
       .str();
@@ -126,6 +135,69 @@
   testPredicateOnArg(mayEverNeedParens, "void f(S); f(3 + 5);", true);
 }
 
+TEST(SourceCodeBuildersTest, isSmartPointerTypeUniquePtr) {
+  std::string Snippet = "std::unique_ptr<int> P; P;";
+  auto StmtMatch = matchStmt(Snippet, expr(hasType(qualType().bind("ty"))));
+  ASSERT_TRUE(StmtMatch) << "Snippet: " << Snippet;
+  EXPECT_TRUE(
+      isSmartPointerType(*StmtMatch->Result.Nodes.getNodeAs<QualType>("ty"),
+                         *StmtMatch->Result.Context))
+      << "Snippet: " << Snippet;
+}
+
+TEST(SourceCodeBuildersTest, isSmartPointerTypeSharedPtr) {
+  std::string Snippet = "std::shared_ptr<int> P; P;";
+  auto StmtMatch = matchStmt(Snippet, expr(hasType(qualType().bind("ty"))));
+  ASSERT_TRUE(StmtMatch) << "Snippet: " << Snippet;
+  EXPECT_TRUE(
+      isSmartPointerType(*StmtMatch->Result.Nodes.getNodeAs<QualType>("ty"),
+                         *StmtMatch->Result.Context))
+      << "Snippet: " << Snippet;
+}
+
+TEST(SourceCodeBuildersTest, isSmartPointerTypeDuckType) {
+  std::string Snippet = "Smart P; P;";
+  auto StmtMatch = matchStmt(Snippet, expr(hasType(qualType().bind("ty"))));
+  ASSERT_TRUE(StmtMatch) << "Snippet: " << Snippet;
+  EXPECT_TRUE(
+      isSmartPointerType(*StmtMatch->Result.Nodes.getNodeAs<QualType>("ty"),
+                         *StmtMatch->Result.Context))
+      << "Snippet: " << Snippet;
+}
+
+TEST(SourceCodeBuildersTest, isSmartPointerTypeNormalTypeFalse) {
+  std::string Snippet = "int *P; P;";
+  auto StmtMatch = matchStmt(Snippet, expr(hasType(qualType().bind("ty"))));
+  ASSERT_TRUE(StmtMatch) << "Snippet: " << Snippet;
+  EXPECT_FALSE(
+      isSmartPointerType(*StmtMatch->Result.Nodes.getNodeAs<QualType>("ty"),
+                         *StmtMatch->Result.Context))
+      << "Snippet: " << Snippet;
+}
+
+TEST(SourceCodeBuildersTest, isSmartDereferenceTrue) {
+  std::string Snippet = "Smart P; *P;";
+  auto StmtMatch = matchStmt(
+      Snippet, expr(cxxOperatorCallExpr(hasUnaryOperand(expr().bind("arg"))))
+                   .bind("expr"));
+  ASSERT_TRUE(StmtMatch) << "Snippet: " << Snippet;
+  const auto *Arg = StmtMatch->Result.Nodes.getNodeAs<Expr>("arg");
+  EXPECT_EQ(Arg,
+            isSmartDereference(*StmtMatch->Result.Nodes.getNodeAs<Expr>("expr"),
+                               *StmtMatch->Result.Context))
+      << "Snippet: " << Snippet;
+}
+
+TEST(SourceCodeBuildersTest, isSmartDereferenceFalse) {
+  std::string Snippet = "int *P; *P;";
+  auto StmtMatch = matchStmt(Snippet, expr().bind("expr"));
+  ASSERT_TRUE(StmtMatch) << "Snippet: " << Snippet;
+  EXPECT_EQ(nullptr,
+            isSmartDereference(*StmtMatch->Result.Nodes.getNodeAs<Expr>("expr"),
+                               *StmtMatch->Result.Context))
+      << "Snippet: " << Snippet;
+}
+
 static void testBuilder(
     llvm::Optional<std::string> (*Builder)(const Expr &, const ASTContext &),
     StringRef Snippet, StringRef Expected) {
@@ -136,6 +208,16 @@
               ValueIs(std::string(Expected)));
 }
 
+static void testBuilder(llvm::Optional<std::string> (*Builder)(const Expr &,
+                                                               ASTContext &),
+                        StringRef Snippet, StringRef Expected) {
+  auto StmtMatch = matchStmt(Snippet, expr().bind("expr"));
+  ASSERT_TRUE(StmtMatch);
+  EXPECT_THAT(Builder(*StmtMatch->Result.Nodes.getNodeAs<Expr>("expr"),
+                      *StmtMatch->Result.Context),
+              ValueIs(std::string(Expected)));
+}
+
 TEST(SourceCodeBuildersTest, BuildParensUnaryOp) {
   testBuilder(buildParens, "-4;", "(-4)");
 }
@@ -245,4 +327,83 @@
 TEST(SourceCodeBuildersTest, BuildArrowValueAddressWithParens) {
   testBuilder(buildArrow, "S x; &(true ? x : x);", "(true ? x : x).");
 }
+
+TEST(SourceCodeBuildersTest, BuildAccessValue) {
+  testBuilder(buildAccess, "S x; x;", "x.");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessPointerDereference) {
+  testBuilder(buildAccess, "S *x; *x;", "x->");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessPointerDereferenceIgnoresParens) {
+  testBuilder(buildAccess, "S *x; *(x);", "x->");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessValueBinaryOperation) {
+  testBuilder(buildAccess, "S x; x + x;", "(x + x).");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessPointerDereferenceExprWithParens) {
+  testBuilder(buildAccess, "S *x; *(x + 1);", "(x + 1)->");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessPointer) {
+  testBuilder(buildAccess, "S *x; x;", "x->");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessValueAddress) {
+  testBuilder(buildAccess, "S x; &x;", "x.");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessValueAddressIgnoresParens) {
+  testBuilder(buildAccess, "S x; &(x);", "x.");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessPointerBinaryOperation) {
+  testBuilder(buildAccess, "S *x; x + 1;", "(x + 1)->");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessValueAddressWithParens) {
+  testBuilder(buildAccess, "S x; &(true ? x : x);", "(true ? x : x).");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessSmartPointer) {
+  testBuilder(buildAccess, "Smart x; x;", "x->");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessSmartPointerDeref) {
+  testBuilder(buildAccess, "Smart x; *x;", "x->");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessSmartPointerMemberCall) {
+  StringRef Snippet = R"cc(
+    Smart x;
+    x->field;
+  )cc";
+  auto StmtMatch =
+      matchStmt(Snippet, memberExpr(hasObjectExpression(expr().bind("expr"))));
+  ASSERT_TRUE(StmtMatch);
+  EXPECT_THAT(buildAccess(*StmtMatch->Result.Nodes.getNodeAs<Expr>("expr"),
+                          *StmtMatch->Result.Context),
+              ValueIs(std::string("x->")));
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessImplicitThis) {
+  StringRef Snippet = R"cc(
+    struct Struct {
+      void foo() {}
+      void bar() {
+        foo();
+      }
+    };
+  )cc";
+  auto StmtMatch = matchStmt(
+      Snippet,
+      cxxMemberCallExpr(onImplicitObjectArgument(cxxThisExpr().bind("expr"))));
+  ASSERT_TRUE(StmtMatch);
+  EXPECT_THAT(buildAccess(*StmtMatch->Result.Nodes.getNodeAs<Expr>("expr"),
+                          *StmtMatch->Result.Context),
+              ValueIs(std::string()));
+}
 } // namespace