[libTooling] Add `buildAccess` and `isKnownPointerLikeType` to SourceCodeBuilders.

Moves the member-access construction logic out of Stencil.cpp and into
SourceCodeBuilders. `buildAccess` appends `.`, `->` or nothing (for an implicit
`this`) to an expression; `buildDot` and `buildArrow` remain as thin,
deprecated wrappers. Pointer-like types are now recognized by name only.

diff --git a/clang/include/clang/Tooling/Transformer/SourceCodeBuilders.h b/clang/include/clang/Tooling/Transformer/SourceCodeBuilders.h
--- a/clang/include/clang/Tooling/Transformer/SourceCodeBuilders.h
+++ b/clang/include/clang/Tooling/Transformer/SourceCodeBuilders.h
@@ -43,6 +43,13 @@
 /// Determines whether printing this expression to the right of a unary operator
 /// requires a parentheses to preserve its meaning.
 bool needParensAfterUnaryOperator(const Expr &E);
+
+/// Recognizes known types (and sugared versions thereof) that overload the `*`
+/// and `->` operators. Currently, included types are exactly:
+///
+/// * std::unique_ptr, std::shared_ptr, std::weak_ptr,
+/// * std::optional, absl::optional.
+bool isKnownPointerLikeType(QualType Ty, ASTContext &Context);
 /// @}
 
 /// \name Basic code-string generation utilities.
@@ -69,6 +76,8 @@
 ///  `x` becomes `x.`
 ///  `*a` becomes `a->`
 ///  `a+b` becomes `(a+b).`
+///
+/// DEPRECATED. Use `buildAccess`.
 llvm::Optional<std::string> buildDot(const Expr &E, const ASTContext &Context);
 
 /// Adds an arrow to the end of the given expression, but adds parentheses
@@ -77,8 +86,32 @@
 ///  `x` becomes `x->`
 ///  `&a` becomes `a.`
 ///  `a+b` becomes `(a+b)->`
+///
+/// DEPRECATED. Use `buildAccess`.
 llvm::Optional<std::string> buildArrow(const Expr &E,
                                        const ASTContext &Context);
+
+/// Specifies how to classify pointer-like types -- like values or like pointers
+/// -- with regard to generating member-access syntax.
+enum class PLTClass : bool {
+  Value,
+  Pointer,
+};
+
+/// Adds an appropriate access operator (`.`, `->` or nothing, in the case of
+/// implicit `this`) to the end of the given expression. Adds parentheses when
+/// needed by the syntax and simplifies when possible. If `Classification` is
+/// `Pointer`, for known pointer-like types (see `isKnownPointerLikeType`),
+/// treats `operator->` and `operator*` like the built-in `->` and `*`
+/// operators.
+///
+///  `x` becomes `x->` or `x.`, depending on `E`'s type
+///  `a+b` becomes `(a+b)->` or `(a+b).`, depending on `E`'s type
+///  `&a` becomes `a.`
+///  `*a` becomes `a->`
+llvm::Optional<std::string>
+buildAccess(const Expr &E, ASTContext &Context,
+            PLTClass Classification = PLTClass::Pointer);
 /// @}
 
 } // namespace tooling
diff --git a/clang/lib/Tooling/Transformer/SourceCodeBuilders.cpp b/clang/lib/Tooling/Transformer/SourceCodeBuilders.cpp
--- a/clang/lib/Tooling/Transformer/SourceCodeBuilders.cpp
+++ b/clang/lib/Tooling/Transformer/SourceCodeBuilders.cpp
@@ -10,6 +10,9 @@
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Expr.h"
 #include "clang/AST/ExprCXX.h"
+#include "clang/ASTMatchers/ASTMatchFinder.h"
+#include "clang/ASTMatchers/ASTMatchers.h"
 #include "clang/Tooling/Transformer/SourceCode.h"
 #include "llvm/ADT/Twine.h"
+#include "llvm/Support/ErrorHandling.h"
 #include <string>
@@ -60,6 +63,15 @@
   return false;
 }
 
+bool tooling::isKnownPointerLikeType(QualType Ty, ASTContext &Context) {
+  using namespace ast_matchers;
+  const auto PointerLikeTy = type(hasUnqualifiedDesugaredType(
+      recordType(hasDeclaration(cxxRecordDecl(hasAnyName(
+          "::std::unique_ptr", "::std::shared_ptr", "::std::weak_ptr",
+          "::std::optional", "::absl::optional"))))));
+  return match(PointerLikeTy, Ty, Context).size() > 0;
+}
+
 llvm::Optional<std::string> tooling::buildParens(const Expr &E,
                                                  const ASTContext &Context) {
   StringRef Text = getText(E, Context);
@@ -114,8 +126,10 @@
   return ("&" + Text).str();
 }
 
-llvm::Optional<std::string> tooling::buildDot(const Expr &E,
-                                              const ASTContext &Context) {
+// Append the appropriate access operation (syntactically) to `E`, assuming `E`
+// is a non-pointer value.
+static llvm::Optional<std::string>
+buildAccessForValue(const Expr &E, const ASTContext &Context) {
   if (const auto *Op = llvm::dyn_cast<UnaryOperator>(&E))
     if (Op->getOpcode() == UO_Deref) {
       // Strip leading '*', add following '->'.
@@ -138,8 +152,10 @@
   return (Text + ".").str();
 }
 
-llvm::Optional<std::string> tooling::buildArrow(const Expr &E,
-                                                const ASTContext &Context) {
+// Append the appropriate access operation (syntactically) to `E`, assuming `E`
+// is a pointer value.
+static llvm::Optional<std::string>
+buildAccessForPointer(const Expr &E, const ASTContext &Context) {
   if (const auto *Op = llvm::dyn_cast<UnaryOperator>(&E))
     if (Op->getOpcode() == UO_AddrOf) {
       // Strip leading '&', add following '.'.
@@ -160,3 +176,66 @@
     return ("(" + Text + ")->").str();
   return (Text + "->").str();
 }
+
+llvm::Optional<std::string> tooling::buildDot(const Expr &E,
+                                              const ASTContext &Context) {
+  return buildAccessForValue(E, Context);
+}
+
+llvm::Optional<std::string> tooling::buildArrow(const Expr &E,
+                                                const ASTContext &Context) {
+  return buildAccessForPointer(E, Context);
+}
+
+// If `E` is an overloaded-operator call of kind `K` on an object `O`, returns
+// `O`. Otherwise, returns `nullptr`.
+static const Expr *maybeGetOperatorObjectArg(const Expr &E,
+                                             OverloadedOperatorKind K) {
+  if (const auto *OpCall = dyn_cast<clang::CXXOperatorCallExpr>(&E)) {
+    if (OpCall->getOperator() == K && OpCall->getNumArgs() == 1)
+      return OpCall->getArg(0);
+  }
+  return nullptr;
+}
+
+// Returns whether `Ty` should get pointer-style (`->`) access under
+// classification `C`.
+static bool treatLikePointer(QualType Ty, PLTClass C, ASTContext &Context) {
+  switch (C) {
+  case PLTClass::Value:
+    return false;
+  case PLTClass::Pointer:
+    return isKnownPointerLikeType(Ty, Context);
+  }
+  llvm_unreachable("Unknown PLTClass enum");
+}
+
+// FIXME: move over the other `maybe` functionality from Stencil. Should all be
+// in one place.
+llvm::Optional<std::string> tooling::buildAccess(const Expr &RawExpression,
+                                                 ASTContext &Context,
+                                                 PLTClass Classification) {
+  if (RawExpression.isImplicitCXXThis())
+    // Return the empty string, because `None` signifies some sort of failure.
+    return std::string();
+
+  const Expr *E = RawExpression.IgnoreImplicitAsWritten();
+
+  if (E->getType()->isAnyPointerType() ||
+      treatLikePointer(E->getType(), Classification, Context)) {
+    // Strip off operator-> calls. They can only occur inside an actual arrow
+    // member access, so we treat them as equivalent to an actual object
+    // expression.
+    if (const auto *Obj = maybeGetOperatorObjectArg(*E, clang::OO_Arrow))
+      E = Obj;
+    return buildAccessForPointer(*E, Context);
+  }
+
+  // Overloaded `*Obj` on a pointer-like `Obj`: drop the `*` and emit `Obj->`.
+  if (const auto *Obj = maybeGetOperatorObjectArg(*E, clang::OO_Star)) {
+    if (treatLikePointer(Obj->getType(), Classification, Context))
+      return buildAccessForPointer(*Obj, Context);
+  }
+
+  return buildAccessForValue(*E, Context);
+}
diff --git a/clang/lib/Tooling/Transformer/Stencil.cpp b/clang/lib/Tooling/Transformer/Stencil.cpp
--- a/clang/lib/Tooling/Transformer/Stencil.cpp
+++ b/clang/lib/Tooling/Transformer/Stencil.cpp
@@ -11,7 +11,6 @@
 #include "clang/AST/ASTTypeTraits.h"
 #include "clang/AST/Expr.h"
 #include "clang/ASTMatchers/ASTMatchFinder.h"
-#include "clang/ASTMatchers/ASTMatchers.h"
 #include "clang/Basic/SourceLocation.h"
 #include "clang/Lex/Lexer.h"
 #include "clang/Tooling/Transformer/SourceCode.h"
@@ -56,39 +55,6 @@
   return Error::success();
 }
 
-// FIXME: Consider memoizing this function using the `ASTContext`.
-static bool isSmartPointerType(QualType Ty, ASTContext &Context) {
-  using namespace ::clang::ast_matchers;
-
-  // Optimization: hard-code common smart-pointer types. This can/should be
-  // removed if we start caching the results of this function.
-  auto KnownSmartPointer =
-      cxxRecordDecl(hasAnyName("::std::unique_ptr", "::std::shared_ptr"));
-  const auto QuacksLikeASmartPointer = cxxRecordDecl(
-      hasMethod(cxxMethodDecl(hasOverloadedOperatorName("->"),
-                              returns(qualType(pointsTo(type()))))),
-      hasMethod(cxxMethodDecl(hasOverloadedOperatorName("*"),
-                              returns(qualType(references(type()))))));
-  const auto SmartPointer = qualType(hasDeclaration(
-      cxxRecordDecl(anyOf(KnownSmartPointer, QuacksLikeASmartPointer))));
-  return match(SmartPointer, Ty, Context).size() > 0;
-}
-
-// Identifies use of `operator*` on smart pointers, and returns the underlying
-// smart-pointer expression; otherwise, returns null.
-static const Expr *isSmartDereference(const Expr &E, ASTContext &Context) {
-  using namespace ::clang::ast_matchers;
-
-  const auto HasOverloadedArrow = cxxRecordDecl(hasMethod(cxxMethodDecl(
-      hasOverloadedOperatorName("->"), returns(qualType(pointsTo(type()))))));
-  // Verify it is a smart pointer by finding `operator->` in the class
-  // declaration.
-  auto Deref = cxxOperatorCallExpr(
-      hasOverloadedOperatorName("*"), hasUnaryOperand(expr().bind("arg")),
-      callee(cxxMethodDecl(ofClass(HasOverloadedArrow))));
-  return selectFirst<Expr>("arg", match(Deref, E, Context));
-}
-
 namespace {
 // An arbitrary fragment of code within a stencil.
 class RawTextStencil : public StencilInterface {
@@ -196,7 +162,7 @@
       break;
     case UnaryNodeOperator::MaybeDeref:
       if (E->getType()->isAnyPointerType() ||
-          isSmartPointerType(E->getType(), *Match.Context)) {
+          tooling::isKnownPointerLikeType(E->getType(), *Match.Context)) {
         // Strip off any operator->. This can only occur inside an actual arrow
         // member access, so we treat it as equivalent to an actual object
         // expression.
@@ -216,7 +182,7 @@
       break;
     case UnaryNodeOperator::MaybeAddressOf:
       if (E->getType()->isAnyPointerType() ||
-          isSmartPointerType(E->getType(), *Match.Context)) {
+          tooling::isKnownPointerLikeType(E->getType(), *Match.Context)) {
         // Strip off any operator->. This can only occur inside an actual arrow
         // member access, so we treat it as equivalent to an actual object
         // expression.
@@ -311,34 +277,12 @@
     if (E == nullptr)
       return llvm::make_error<StringError>(errc::invalid_argument,
                                            "Id not bound: " + BaseId);
-    if (!E->isImplicitCXXThis()) {
-      llvm::Optional<std::string> S;
-      if (E->getType()->isAnyPointerType() ||
-          isSmartPointerType(E->getType(), *Match.Context)) {
-        // Strip off any operator->. This can only occur inside an actual arrow
-        // member access, so we treat it as equivalent to an actual object
-        // expression.
-        if (const auto *OpCall = dyn_cast<clang::CXXOperatorCallExpr>(E)) {
-          if (OpCall->getOperator() == clang::OO_Arrow &&
-              OpCall->getNumArgs() == 1) {
-            E = OpCall->getArg(0);
-          }
-        }
-        S = tooling::buildArrow(*E, *Match.Context);
-      } else if (const auto *Operand = isSmartDereference(*E, *Match.Context)) {
-        // `buildDot` already handles the built-in dereference operator, so we
-        // only need to catch overloaded `operator*`.
-        S = tooling::buildArrow(*Operand, *Match.Context);
-      } else {
-        S = tooling::buildDot(*E, *Match.Context);
-      }
-      if (S.hasValue())
-        *Result += *S;
-      else
-        return llvm::make_error<StringError>(
-            errc::invalid_argument,
-            "Could not construct object text from ID: " + BaseId);
-    }
+    llvm::Optional<std::string> S = tooling::buildAccess(*E, *Match.Context);
+    if (!S.hasValue())
+      return llvm::make_error<StringError>(
+          errc::invalid_argument,
+          "Could not construct object text from ID: " + BaseId);
+    *Result += *S;
     return Member->eval(Match, Result);
   }
 };
diff --git a/clang/unittests/Tooling/SourceCodeBuildersTest.cpp b/clang/unittests/Tooling/SourceCodeBuildersTest.cpp
--- a/clang/unittests/Tooling/SourceCodeBuildersTest.cpp
+++ b/clang/unittests/Tooling/SourceCodeBuildersTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "clang/Tooling/Transformer/SourceCodeBuilders.h"
+#include "clang/AST/Type.h"
 #include "clang/ASTMatchers/ASTMatchFinder.h"
 #include "clang/ASTMatchers/ASTMatchers.h"
 #include "clang/Tooling/Tooling.h"
@@ -24,8 +25,23 @@
 
 // Create a valid translation unit from a statement.
 static std::string wrapSnippet(StringRef StatementCode) {
-  return ("struct S { S(); S(int); int field; };\n"
+  return ("namespace std {\n"
+          "template <typename T> struct unique_ptr {\n"
+          "  T* operator->() const;\n"
+          "  T& operator*() const;\n"
+          "};\n"
+          "template <typename T> struct shared_ptr {\n"
+          "  T* operator->() const;\n"
+          "  T& operator*() const;\n"
+          "};\n"
+          "}\n"
+          "struct A { void super(); };\n"
+          "struct S : public A { S(); S(int); int Field; };\n"
           "S operator+(const S &a, const S &b);\n"
+          "struct Smart {\n"
+          "  S* operator->() const;\n"
+          "  S& operator*() const;\n"
+          "};\n"
           "auto test_snippet = []{" +
           StatementCode + "};")
       .str();
@@ -51,7 +67,8 @@
 // `StatementCode` may contain other statements not described by `Matcher`.
 static llvm::Optional<TestMatch> matchStmt(StringRef StatementCode,
                                            StatementMatcher Matcher) {
-  auto AstUnit = buildASTFromCode(wrapSnippet(StatementCode));
+  auto AstUnit = buildASTFromCodeWithArgs(wrapSnippet(StatementCode),
+                                          {"-Wno-unused-value"});
   if (AstUnit == nullptr) {
     ADD_FAILURE() << "AST construction failed";
     return llvm::None;
@@ -95,7 +112,7 @@
   testPredicate(needParensAfterUnaryOperator, "int(3.0);", false);
   testPredicate(needParensAfterUnaryOperator, "void f(); f();", false);
   testPredicate(needParensAfterUnaryOperator, "int a[3]; a[0];", false);
-  testPredicate(needParensAfterUnaryOperator, "S x; x.field;", false);
+  testPredicate(needParensAfterUnaryOperator, "S x; x.Field;", false);
   testPredicate(needParensAfterUnaryOperator, "int x = 1; --x;", false);
   testPredicate(needParensAfterUnaryOperator, "int x = 1; -x;", false);
 }
@@ -117,7 +134,7 @@
   testPredicate(mayEverNeedParens, "int(3.0);", false);
   testPredicate(mayEverNeedParens, "void f(); f();", false);
   testPredicate(mayEverNeedParens, "int a[3]; a[0];", false);
-  testPredicate(mayEverNeedParens, "S x; x.field;", false);
+  testPredicate(mayEverNeedParens, "S x; x.Field;", false);
 }
 
 TEST(SourceCodeBuildersTest, mayEverNeedParensInImplictConversion) {
@@ -126,6 +143,50 @@
   testPredicateOnArg(mayEverNeedParens, "void f(S); f(3 + 5);", true);
 }
 
+TEST(SourceCodeBuildersTest, isKnownPointerLikeTypeUniquePtr) {
+  std::string Snippet = "std::unique_ptr<int> P; P;";
+  auto StmtMatch =
+      matchStmt(Snippet, declRefExpr(hasType(qualType().bind("ty"))));
+  ASSERT_TRUE(StmtMatch) << "Snippet: " << Snippet;
+  EXPECT_TRUE(
+      isKnownPointerLikeType(*StmtMatch->Result.Nodes.getNodeAs<QualType>("ty"),
+                             *StmtMatch->Result.Context))
+      << "Snippet: " << Snippet;
+}
+
+TEST(SourceCodeBuildersTest, isKnownPointerLikeTypeSharedPtr) {
+  std::string Snippet = "std::shared_ptr<int> P; P;";
+  auto StmtMatch =
+      matchStmt(Snippet, declRefExpr(hasType(qualType().bind("ty"))));
+  ASSERT_TRUE(StmtMatch) << "Snippet: " << Snippet;
+  EXPECT_TRUE(
+      isKnownPointerLikeType(*StmtMatch->Result.Nodes.getNodeAs<QualType>("ty"),
+                             *StmtMatch->Result.Context))
+      << "Snippet: " << Snippet;
+}
+
+TEST(SourceCodeBuildersTest, isKnownPointerLikeTypeUnknownTypeFalse) {
+  std::string Snippet = "Smart P; P;";
+  auto StmtMatch =
+      matchStmt(Snippet, declRefExpr(hasType(qualType().bind("ty"))));
+  ASSERT_TRUE(StmtMatch) << "Snippet: " << Snippet;
+  EXPECT_FALSE(
+      isKnownPointerLikeType(*StmtMatch->Result.Nodes.getNodeAs<QualType>("ty"),
+                             *StmtMatch->Result.Context))
+      << "Snippet: " << Snippet;
+}
+
+TEST(SourceCodeBuildersTest, isKnownPointerLikeTypeNormalTypeFalse) {
+  std::string Snippet = "int *P; P;";
+  auto StmtMatch =
+      matchStmt(Snippet, declRefExpr(hasType(qualType().bind("ty"))));
+  ASSERT_TRUE(StmtMatch) << "Snippet: " << Snippet;
+  EXPECT_FALSE(
+      isKnownPointerLikeType(*StmtMatch->Result.Nodes.getNodeAs<QualType>("ty"),
+                             *StmtMatch->Result.Context))
+      << "Snippet: " << Snippet;
+}
+
 static void testBuilder(
     llvm::Optional<std::string> (*Builder)(const Expr &, const ASTContext &),
     StringRef Snippet, StringRef Expected) {
@@ -136,6 +197,15 @@
               ValueIs(std::string(Expected)));
 }
 
+static void testBuildAccess(StringRef Snippet, StringRef Expected,
+                            PLTClass C = PLTClass::Pointer) {
+  auto StmtMatch = matchStmt(Snippet, expr().bind("expr"));
+  ASSERT_TRUE(StmtMatch);
+  EXPECT_THAT(buildAccess(*StmtMatch->Result.Nodes.getNodeAs<Expr>("expr"),
+                          *StmtMatch->Result.Context, C),
+              ValueIs(std::string(Expected)));
+}
+
 TEST(SourceCodeBuildersTest, BuildParensUnaryOp) {
   testBuilder(buildParens, "-4;", "(-4)");
 }
@@ -245,4 +315,117 @@
 TEST(SourceCodeBuildersTest, BuildArrowValueAddressWithParens) {
   testBuilder(buildArrow, "S x; &(true ? x : x);", "(true ? x : x).");
 }
+
+TEST(SourceCodeBuildersTest, BuildAccessValue) {
+  testBuildAccess("S x; x;", "x.");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessPointerDereference) {
+  testBuildAccess("S *x; *x;", "x->");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessPointerDereferenceIgnoresParens) {
+  testBuildAccess("S *x; *(x);", "x->");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessValueBinaryOperation) {
+  testBuildAccess("S x; x + x;", "(x + x).");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessPointerDereferenceExprWithParens) {
+  testBuildAccess("S *x; *(x + 1);", "(x + 1)->");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessPointer) {
+  testBuildAccess("S *x; x;", "x->");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessValueAddress) {
+  testBuildAccess("S x; &x;", "x.");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessValueAddressIgnoresParens) {
+  testBuildAccess("S x; &(x);", "x.");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessPointerBinaryOperation) {
+  testBuildAccess("S *x; x + 1;", "(x + 1)->");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessValueAddressWithParens) {
+  testBuildAccess("S x; &(true ? x : x);", "(true ? x : x).");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessSmartPointer) {
+  testBuildAccess("std::unique_ptr<S> x; x;", "x->");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessSmartPointerAsValue) {
+  testBuildAccess("std::unique_ptr<S> x; x;", "x.", PLTClass::Value);
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessSmartPointerDeref) {
+  testBuildAccess("std::unique_ptr<S> x; *x;", "x->");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessSmartPointerDerefAsValue) {
+  testBuildAccess("std::unique_ptr<S> x; *x;", "(*x).", PLTClass::Value);
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessSmartPointerMemberCall) {
+  StringRef Snippet = R"cc(
+    Smart x;
+    x->Field;
+  )cc";
+  auto StmtMatch =
+      matchStmt(Snippet, memberExpr(hasObjectExpression(expr().bind("expr"))));
+  ASSERT_TRUE(StmtMatch);
+  EXPECT_THAT(buildAccess(*StmtMatch->Result.Nodes.getNodeAs<Expr>("expr"),
+                          *StmtMatch->Result.Context),
+              ValueIs(std::string("x->")));
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessIgnoreImplicit) {
+  StringRef Snippet = R"cc(
+    S x;
+    A *a;
+    a = &x;
+  )cc";
+  auto StmtMatch =
+      matchStmt(Snippet, binaryOperator(isAssignmentOperator(),
+                                        hasRHS(expr().bind("expr"))));
+  ASSERT_TRUE(StmtMatch);
+  EXPECT_THAT(buildAccess(*StmtMatch->Result.Nodes.getNodeAs<Expr>("expr"),
+                          *StmtMatch->Result.Context),
+              ValueIs(std::string("x.")));
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessImplicitThis) {
+  StringRef Snippet = R"cc(
+    struct Struct {
+      void foo() {}
+      void bar() {
+        foo();
+      }
+    };
+  )cc";
+  auto StmtMatch = matchStmt(
+      Snippet,
+      cxxMemberCallExpr(onImplicitObjectArgument(cxxThisExpr().bind("expr"))));
+  ASSERT_TRUE(StmtMatch);
+  EXPECT_THAT(buildAccess(*StmtMatch->Result.Nodes.getNodeAs<Expr>("expr"),
+                          *StmtMatch->Result.Context),
+              ValueIs(std::string()));
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessImplicitThisIgnoreImplicitCasts) {
+  StringRef Snippet = "struct B : public A { void f() { super(); } };";
+  auto StmtMatch = matchStmt(
+      Snippet,
+      cxxMemberCallExpr(onImplicitObjectArgument(expr().bind("expr"))));
+  ASSERT_TRUE(StmtMatch);
+  EXPECT_THAT(buildAccess(*StmtMatch->Result.Nodes.getNodeAs<Expr>("expr"),
+                          *StmtMatch->Result.Context),
+              ValueIs(std::string()));
+}
 } // namespace