Index: include/clang/ASTMatchers/ASTMatchersInternal.h
===================================================================
--- include/clang/ASTMatchers/ASTMatchersInternal.h
+++ include/clang/ASTMatchers/ASTMatchersInternal.h
@@ -146,11 +146,26 @@
 /// Internal version of BoundNodes. Holds all the bound nodes.
 class BoundNodesMap {
 public:
+  /// A bound node plus the ASTNodeKind supplied when it was bound (the
+  /// RestrictKind of the DynTypedMatcher that produced the binding).
+  struct NodeEntry {
+    ast_type_traits::DynTypedNode DynNode;
+    ast_type_traits::ASTNodeKind NodeKind;
+
+    /// Lexicographic order on (DynNode, NodeKind). This must be a strict
+    /// weak ordering: comparing IDToNodeMaps compares entries with it.
+    bool operator<(const NodeEntry &Other) const {
+      return DynNode < Other.DynNode ||
+             (!(Other.DynNode < DynNode) && NodeKind < Other.NodeKind);
+    }
+  };
+
   /// Adds \c Node to the map with key \c ID.
   ///
   /// The node's base type should be in NodeBaseType or it will be unaccessible.
-  void addNode(StringRef ID, const ast_type_traits::DynTypedNode& DynNode) {
-    NodeMap[ID] = DynNode;
+  void addNode(StringRef ID, const ast_type_traits::DynTypedNode &DynNode,
+               ast_type_traits::ASTNodeKind NodeKind) {
+    NodeMap[ID] = NodeEntry{DynNode, NodeKind};
   }
 
   /// Returns the AST node bound to \c ID.
@@ -163,7 +178,7 @@
     if (It == NodeMap.end()) {
       return nullptr;
     }
-    return It->second.get<T>();
+    return It->second.DynNode.get<T>();
   }
 
   ast_type_traits::DynTypedNode getNode(StringRef ID) const {
@@ -171,7 +186,7 @@
     if (It == NodeMap.end()) {
       return ast_type_traits::DynTypedNode();
     }
-    return It->second;
+    return It->second.DynNode;
   }
 
   /// Imposes an order on BoundNodesMaps.
@@ -184,7 +199,7 @@
   /// Note that we're using std::map here, as for memoization:
   /// - we need a comparison operator
   /// - we need an assignment operator
-  using IDToNodeMap = std::map<std::string, ast_type_traits::DynTypedNode>;
+  using IDToNodeMap = std::map<std::string, NodeEntry>;
 
   const IDToNodeMap &getMap() const {
     return NodeMap;
@@ -194,7 +209,7 @@
   /// stored nodes have memoization data.
   bool isComparable() const {
     for (const auto &IDAndNode : NodeMap) {
-      if (!IDAndNode.second.getMemoizationData())
+      if (!IDAndNode.second.DynNode.getMemoizationData())
         return false;
     }
     return true;
@@ -223,11 +238,12 @@
   };
 
   /// Add a binding from an id to a node.
-  void setBinding(StringRef Id, const ast_type_traits::DynTypedNode &DynNode) {
+  void setBinding(StringRef Id, const ast_type_traits::DynTypedNode &DynNode,
+                  ast_type_traits::ASTNodeKind NodeKind) {
     if (Bindings.empty())
       Bindings.emplace_back();
     for (BoundNodesMap &Binding : Bindings)
-      Binding.addNode(Id, DynNode);
+      Binding.addNode(Id, DynNode, NodeKind);
   }
 
   /// Adds a branch in the tree.
@@ -282,7 +298,8 @@
   /// the AST via \p Finder.
   virtual bool dynMatches(const ast_type_traits::DynTypedNode &DynNode,
                           ASTMatchFinder *Finder,
-                          BoundNodesTreeBuilder *Builder) const = 0;
+                          BoundNodesTreeBuilder *Builder,
+                          ast_type_traits::ASTNodeKind NodeKind) const = 0;
 };
 
 /// Generic interface for matchers on an AST node of type T.
@@ -304,8 +321,8 @@
                        BoundNodesTreeBuilder *Builder) const = 0;
 
   bool dynMatches(const ast_type_traits::DynTypedNode &DynNode,
-                  ASTMatchFinder *Finder,
-                  BoundNodesTreeBuilder *Builder) const override {
+                  ASTMatchFinder *Finder, BoundNodesTreeBuilder *Builder,
+                  ast_type_traits::ASTNodeKind NodeKind) const override {
     return matches(DynNode.getUnchecked<T>(), Finder, Builder);
   }
 };
Index: lib/ASTMatchers/ASTMatchersInternal.cpp
===================================================================
--- lib/ASTMatchers/ASTMatchersInternal.cpp
+++ lib/ASTMatchers/ASTMatchersInternal.cpp
@@ -52,21 +52,25 @@
 
 bool NotUnaryOperator(const ast_type_traits::DynTypedNode &DynNode,
                       ASTMatchFinder *Finder, BoundNodesTreeBuilder *Builder,
+                      ast_type_traits::ASTNodeKind NodeKind,
                       ArrayRef<DynTypedMatcher> InnerMatchers);
 
 bool AllOfVariadicOperator(const ast_type_traits::DynTypedNode &DynNode,
                            ASTMatchFinder *Finder,
                            BoundNodesTreeBuilder *Builder,
+                           ast_type_traits::ASTNodeKind NodeKind,
                            ArrayRef<DynTypedMatcher> InnerMatchers);
 
 bool EachOfVariadicOperator(const ast_type_traits::DynTypedNode &DynNode,
                             ASTMatchFinder *Finder,
                             BoundNodesTreeBuilder *Builder,
+                            ast_type_traits::ASTNodeKind NodeKind,
                             ArrayRef<DynTypedMatcher> InnerMatchers);
 
 bool AnyOfVariadicOperator(const ast_type_traits::DynTypedNode &DynNode,
                            ASTMatchFinder *Finder,
                            BoundNodesTreeBuilder *Builder,
+                           ast_type_traits::ASTNodeKind NodeKind,
                            ArrayRef<DynTypedMatcher> InnerMatchers);
 
 void BoundNodesTreeBuilder::visitMatches(Visitor *ResultVisitor) {
@@ -81,7 +85,8 @@
 
 using VariadicOperatorFunction = bool (*)(
     const ast_type_traits::DynTypedNode &DynNode, ASTMatchFinder *Finder,
-    BoundNodesTreeBuilder *Builder, ArrayRef<DynTypedMatcher> InnerMatchers);
+    BoundNodesTreeBuilder *Builder, ast_type_traits::ASTNodeKind NodeKind,
+    ArrayRef<DynTypedMatcher> InnerMatchers);
 
 template <VariadicOperatorFunction Func>
 class VariadicMatcher : public DynMatcherInterface {
@@ -90,9 +95,9 @@
       : InnerMatchers(std::move(InnerMatchers)) {}
 
   bool dynMatches(const ast_type_traits::DynTypedNode &DynNode,
-                  ASTMatchFinder *Finder,
-                  BoundNodesTreeBuilder *Builder) const override {
-    return Func(DynNode, Finder, Builder, InnerMatchers);
+                  ASTMatchFinder *Finder, BoundNodesTreeBuilder *Builder,
+                  ast_type_traits::ASTNodeKind NodeKind) const override {
+    return Func(DynNode, Finder, Builder, NodeKind, InnerMatchers);
   }
 
 private:
@@ -106,10 +111,11 @@
       : ID(ID), InnerMatcher(std::move(InnerMatcher)) {}
 
   bool dynMatches(const ast_type_traits::DynTypedNode &DynNode,
-                  ASTMatchFinder *Finder,
-                  BoundNodesTreeBuilder *Builder) const override {
-    bool Result = InnerMatcher->dynMatches(DynNode, Finder, Builder);
-    if (Result) Builder->setBinding(ID, DynNode);
+                  ASTMatchFinder *Finder, BoundNodesTreeBuilder *Builder,
+                  ast_type_traits::ASTNodeKind NodeKind) const override {
+    bool Result = InnerMatcher->dynMatches(DynNode, Finder, Builder, NodeKind);
+    if (Result)
+      Builder->setBinding(ID, DynNode, NodeKind);
     return Result;
   }
 
@@ -130,7 +136,8 @@
   }
 
   bool dynMatches(const ast_type_traits::DynTypedNode &, ASTMatchFinder *,
-                  BoundNodesTreeBuilder *) const override {
+                  BoundNodesTreeBuilder *,
+                  ast_type_traits::ASTNodeKind) const override {
     return true;
   }
 };
@@ -213,7 +220,7 @@
                               ASTMatchFinder *Finder,
                               BoundNodesTreeBuilder *Builder) const {
   if (RestrictKind.isBaseOf(DynNode.getNodeKind()) &&
-      Implementation->dynMatches(DynNode, Finder, Builder)) {
+      Implementation->dynMatches(DynNode, Finder, Builder, RestrictKind)) {
     return true;
   }
   // Delete all bindings when a matcher does not match.
@@ -227,7 +234,7 @@
     const ast_type_traits::DynTypedNode &DynNode, ASTMatchFinder *Finder,
     BoundNodesTreeBuilder *Builder) const {
   assert(RestrictKind.isBaseOf(DynNode.getNodeKind()));
-  if (Implementation->dynMatches(DynNode, Finder, Builder)) {
+  if (Implementation->dynMatches(DynNode, Finder, Builder, RestrictKind)) {
     return true;
   }
   // Delete all bindings when a matcher does not match.
@@ -262,6 +269,7 @@
 
 bool NotUnaryOperator(const ast_type_traits::DynTypedNode &DynNode,
                       ASTMatchFinder *Finder, BoundNodesTreeBuilder *Builder,
+                      ast_type_traits::ASTNodeKind NodeKind,
                       ArrayRef<DynTypedMatcher> InnerMatchers) {
   if (InnerMatchers.size() != 1)
     return false;
@@ -283,6 +291,7 @@
 bool AllOfVariadicOperator(const ast_type_traits::DynTypedNode &DynNode,
                            ASTMatchFinder *Finder,
                            BoundNodesTreeBuilder *Builder,
+                           ast_type_traits::ASTNodeKind NodeKind,
                            ArrayRef<DynTypedMatcher> InnerMatchers) {
   // allOf leads to one matcher for each alternative in the first
   // matcher combined with each alternative in the second matcher.
@@ -297,6 +306,7 @@
 bool EachOfVariadicOperator(const ast_type_traits::DynTypedNode &DynNode,
                             ASTMatchFinder *Finder,
                             BoundNodesTreeBuilder *Builder,
+                            ast_type_traits::ASTNodeKind NodeKind,
                             ArrayRef<DynTypedMatcher> InnerMatchers) {
   BoundNodesTreeBuilder Result;
   bool Matched = false;
@@ -314,6 +324,7 @@
 bool AnyOfVariadicOperator(const ast_type_traits::DynTypedNode &DynNode,
                            ASTMatchFinder *Finder,
                            BoundNodesTreeBuilder *Builder,
+                           ast_type_traits::ASTNodeKind NodeKind,
                            ArrayRef<DynTypedMatcher> InnerMatchers) {
   for (const DynTypedMatcher &InnerMatcher : InnerMatchers) {
     BoundNodesTreeBuilder Result = *Builder;
Index: lib/Tooling/RefactoringCallbacks.cpp
===================================================================
--- lib/Tooling/RefactoringCallbacks.cpp
+++ lib/Tooling/RefactoringCallbacks.cpp
@@ -213,8 +213,8 @@
                      << " used in replacement template not bound in Matcher \n";
         llvm::report_fatal_error("Unbound node in replacement template.");
       }
-      CharSourceRange Source =
-          CharSourceRange::getTokenRange(NodeIter->second.getSourceRange());
+      CharSourceRange Source = CharSourceRange::getTokenRange(
+          NodeIter->second.DynNode.getSourceRange());
       ToText += Lexer::getSourceText(Source, *Result.SourceManager,
                                      Result.Context->getLangOpts());
       break;
@@ -227,8 +227,8 @@
     llvm::report_fatal_error("FromId node not bound in MatchResult");
   }
   auto Replacement =
-      tooling::Replacement(*Result.SourceManager, &NodeMap.at(FromId), ToText,
-                           Result.Context->getLangOpts());
+      tooling::Replacement(*Result.SourceManager, &NodeMap.at(FromId).DynNode,
+                           ToText, Result.Context->getLangOpts());
   llvm::Error Err = Replace.add(Replacement);
   if (Err) {
     llvm::errs() << "Query and replace failed in " << Replacement.getFilePath()
Index: unittests/ASTMatchers/ASTMatchersNodeTest.cpp
===================================================================
--- unittests/ASTMatchers/ASTMatchersNodeTest.cpp
+++ unittests/ASTMatchers/ASTMatchersNodeTest.cpp
@@ -1724,12 +1724,19 @@
   std::string SourceCode = "struct A { void f() {} };";
   auto Matcher = functionDecl(isDefinition()).bind("method");
 
+  using namespace ast_type_traits;
+
   auto astUnit = tooling::buildASTFromCode(SourceCode);
 
   auto GlobalBoundNodes = matchDynamic(Matcher, astUnit->getASTContext());
 
   EXPECT_EQ(GlobalBoundNodes.size(), 1u);
   EXPECT_EQ(GlobalBoundNodes[0].getMap().size(), 1u);
+  auto GlobalMapPair = *GlobalBoundNodes[0].getMap().begin();
+  EXPECT_TRUE(GlobalMapPair.second.DynNode.getNodeKind().isSame(
+      ASTNodeKind::getFromNodeKind<CXXMethodDecl>()));
+  EXPECT_TRUE(GlobalMapPair.second.NodeKind.isSame(
+      ASTNodeKind::getFromNodeKind<FunctionDecl>()));
 
   auto GlobalMethodNode = GlobalBoundNodes[0].getNodeAs<FunctionDecl>("method");
   EXPECT_TRUE(GlobalMethodNode != nullptr);
@@ -1738,6 +1745,14 @@
       matchDynamic(Matcher, *GlobalMethodNode, astUnit->getASTContext());
   EXPECT_EQ(MethodBoundNodes.size(), 1u);
   EXPECT_EQ(MethodBoundNodes[0].getMap().size(), 1u);
+  auto MethodMapPair = *MethodBoundNodes[0].getMap().begin();
+  EXPECT_TRUE(MethodMapPair.second.DynNode.getNodeKind().isSame(
+      ASTNodeKind::getFromNodeKind<CXXMethodDecl>()));
+  EXPECT_TRUE(MethodMapPair.second.NodeKind.isSame(
+      ASTNodeKind::getFromNodeKind<FunctionDecl>()));
+  EXPECT_EQ(MethodMapPair.second.DynNode, GlobalMapPair.second.DynNode);
+  EXPECT_TRUE(
+      MethodMapPair.second.NodeKind.isSame(GlobalMapPair.second.NodeKind));
 
   auto MethodNode = MethodBoundNodes[0].getNodeAs<FunctionDecl>("method");
   EXPECT_EQ(MethodNode, GlobalMethodNode);
Index: unittests/ASTMatchers/ASTMatchersTest.h
===================================================================
--- unittests/ASTMatchers/ASTMatchersTest.h
+++ unittests/ASTMatchers/ASTMatchersTest.h
@@ -349,12 +349,12 @@
       BoundNodes::IDToNodeMap::const_iterator I = M.find(Id);
       EXPECT_NE(M.end(), I);
       if (I != M.end()) {
-        EXPECT_EQ(Nodes->getNodeAs<T>(Id), I->second.get<T>());
+        EXPECT_EQ(Nodes->getNodeAs<T>(Id), I->second.DynNode.get<T>());
       }
       return true;
     }
     EXPECT_TRUE(M.count(Id) == 0 ||
-      M.find(Id)->second.template get<T>() == nullptr);
+                M.find(Id)->second.DynNode.get<T>() == nullptr);
     return false;
   }
 