diff --git a/clang/docs/LibASTMatchersReference.html b/clang/docs/LibASTMatchersReference.html --- a/clang/docs/LibASTMatchersReference.html +++ b/clang/docs/LibASTMatchersReference.html @@ -611,6 +611,17 @@ +Matcher<ConceptReference>conceptReferenceMatcher<ConceptReference>... +
Matches concept references.
+
+Given
+  template <class> concept C = true;
+  bool X = C<int>;
+conceptReference()
+  matches 'C<int>'.
+
+ + Matcher<Decl>accessSpecDeclMatcher<AccessSpecDecl>...
Matches C++ access specifier declarations.
 
@@ -2531,20 +2542,21 @@
   class array {
     T data[Size];
   };
-dependentSizedArrayType()
+dependentSizedArrayType
   matches "T data[Size]"
 
Matcher<Type>dependentSizedExtVectorTypeMatcher<DependentSizedExtVectorType>... -
Matches C++ extended vector type where either the type or size is dependent.
+
Matches C++ extended vector type where either the type or size is
+dependent.
 
 Given
   template<typename T, int Size>
   class vector {
     typedef T __attribute__((ext_vector_type(Size))) type;
   };
-dependentSizedExtVectorType()
+dependentSizedExtVectorType
   matches "T __attribute__((ext_vector_type(Size)))"
 
diff --git a/clang/include/clang/ASTMatchers/ASTMatchFinder.h b/clang/include/clang/ASTMatchers/ASTMatchFinder.h --- a/clang/include/clang/ASTMatchers/ASTMatchFinder.h +++ b/clang/include/clang/ASTMatchers/ASTMatchFinder.h @@ -169,6 +169,8 @@ void addMatcher(const TemplateArgumentLocMatcher &NodeMatch, MatchCallback *Action); void addMatcher(const AttrMatcher &NodeMatch, MatchCallback *Action); + void addMatcher(const ConceptReferenceMatcher &NodeMatch, + MatchCallback *Action); /// @} /// Adds a matcher to execute when running over the AST. @@ -222,6 +224,8 @@ std::vector<std::pair<TemplateArgumentLocMatcher, MatchCallback *>> TemplateArgumentLoc; std::vector<std::pair<AttrMatcher, MatchCallback *>> Attr; + std::vector<std::pair<ConceptReferenceMatcher, MatchCallback *>> + ConceptReference; /// All the callbacks in one container to simplify iteration. llvm::SmallPtrSet<MatchCallback *, 16> AllCallbacks; }; diff --git a/clang/include/clang/ASTMatchers/ASTMatchers.h b/clang/include/clang/ASTMatchers/ASTMatchers.h --- a/clang/include/clang/ASTMatchers/ASTMatchers.h +++ b/clang/include/clang/ASTMatchers/ASTMatchers.h @@ -152,6 +152,7 @@ using TemplateArgumentLocMatcher = internal::Matcher<TemplateArgumentLoc>; using LambdaCaptureMatcher = internal::Matcher<LambdaCapture>; using AttrMatcher = internal::Matcher<Attr>; +using ConceptReferenceMatcher = internal::Matcher<ConceptReference>; /// @} /// Matches any node. @@ -611,6 +612,17 @@ TemplateTemplateParmDecl> templateTemplateParmDecl; +/// Matches concept references. +/// +/// Given +/// \code +/// template <class> concept C = true; +/// bool X = C<int>; +/// \endcode +/// conceptReference() +/// matches 'C<int>'. +extern const internal::VariadicAllOfMatcher<ConceptReference> conceptReference; + /// Matches public C++ declarations and C++ base specifers that specify public /// inheritance. /// diff --git a/clang/include/clang/ASTMatchers/ASTMatchersInternal.h b/clang/include/clang/ASTMatchers/ASTMatchersInternal.h --- a/clang/include/clang/ASTMatchers/ASTMatchersInternal.h +++ b/clang/include/clang/ASTMatchers/ASTMatchersInternal.h @@ -1163,19 +1163,10 @@ /// IsBaseType<T>::value is true if T is a "base" type in the AST /// node class hierarchies.
template -struct IsBaseType { - static const bool value = - std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value; -}; -template -const bool IsBaseType::value; +using IsBaseType = + llvm::is_one_of; /// A "type list" that contains all types. /// diff --git a/clang/lib/ASTMatchers/ASTMatchFinder.cpp b/clang/lib/ASTMatchers/ASTMatchFinder.cpp --- a/clang/lib/ASTMatchers/ASTMatchFinder.cpp +++ b/clang/lib/ASTMatchers/ASTMatchFinder.cpp @@ -26,6 +26,7 @@ #include #include #include +#include namespace clang { namespace ast_matchers { @@ -136,6 +137,8 @@ traverse(*TALoc); else if (const Attr *A = DynNode.get()) traverse(*A); + else if (const ConceptReference *CR = DynNode.get()) + traverse(*CR); // FIXME: Add other base types after adding tests. // It's OK to always overwrite the bound nodes, as if there was @@ -275,6 +278,12 @@ ScopedIncrement ScopedDepth(&CurrentDepth); return traverse(*A); } + bool TraverseConceptReference(ConceptReference *CR) { + if (CR == nullptr) + return true; + ScopedIncrement ScopedDepth(&CurrentDepth); + return traverse(*CR); + } bool TraverseLambdaExpr(LambdaExpr *Node) { if (!Finder->isTraversalIgnoringImplicitNodes()) return VisitorBase::TraverseLambdaExpr(Node); @@ -360,6 +369,10 @@ bool baseTraverse(const Attr &AttrNode) { return VisitorBase::TraverseAttr(const_cast(&AttrNode)); } + bool baseTraverse(const ConceptReference &CR) { + return VisitorBase::TraverseConceptReference( + const_cast(&CR)); + } // Sets 'Matched' to true if 'Matcher' matches 'Node' and: // 0 < CurrentDepth <= MaxDepth. 
@@ -505,6 +518,7 @@ bool TraverseConstructorInitializer(CXXCtorInitializer *CtorInit); bool TraverseTemplateArgumentLoc(TemplateArgumentLoc TAL); bool TraverseAttr(Attr *AttrNode); + bool TraverseConceptReference(ConceptReference *); bool dataTraverseNode(Stmt *S, DataRecursionQueue *Queue) { if (auto *RF = dyn_cast(S)) { @@ -712,6 +726,8 @@ match(*N); } else if (auto *N = Node.get()) { match(*N); + } else if (auto *N = Node.get()) { + match(*N); } } @@ -766,85 +782,14 @@ bool TraversingASTNodeNotAsIs = false; bool TraversingASTChildrenNotSpelledInSource = false; - class CurMatchData { -// We don't have enough free low bits in 32bit builds to discriminate 8 pointer -// types in PointerUnion. so split the union in 2 using a free bit from the -// callback pointer. -#define CMD_TYPES_0 \ - const QualType *, const TypeLoc *, const NestedNameSpecifier *, \ - const NestedNameSpecifierLoc * -#define CMD_TYPES_1 \ - const CXXCtorInitializer *, const TemplateArgumentLoc *, const Attr *, \ - const DynTypedNode * - -#define IMPL(Index) \ - template \ - std::enable_if_t< \ - llvm::is_one_of::value> \ - SetCallbackAndRawNode(const MatchCallback *CB, const NodeType &N) { \ - assertEmpty(); \ - Callback.setPointerAndInt(CB, Index); \ - Node##Index = &N; \ - } \ - \ - template \ - std::enable_if_t::value, \ - const T *> \ - getNode() const { \ - assertHoldsState(); \ - return Callback.getInt() == (Index) ? 
Node##Index.dyn_cast() \ - : nullptr; \ - } - - public: - CurMatchData() : Node0(nullptr) {} - - IMPL(0) - IMPL(1) - - const MatchCallback *getCallback() const { return Callback.getPointer(); } - - void SetBoundNodes(const BoundNodes &BN) { - assertHoldsState(); - BNodes = &BN; - } - - void clearBoundNodes() { - assertHoldsState(); - BNodes = nullptr; - } - - const BoundNodes *getBoundNodes() const { - assertHoldsState(); - return BNodes; - } - - void reset() { - assertHoldsState(); - Callback.setPointerAndInt(nullptr, 0); - Node0 = nullptr; - } - - private: - void assertHoldsState() const { - assert(Callback.getPointer() != nullptr && !Node0.isNull()); - } - - void assertEmpty() const { - assert(Callback.getPointer() == nullptr && Node0.isNull() && - BNodes == nullptr); - } - - llvm::PointerIntPair Callback; - union { - llvm::PointerUnion Node0; - llvm::PointerUnion Node1; - }; + struct CurMatchData { + const MatchCallback *Callback = nullptr; const BoundNodes *BNodes = nullptr; - -#undef CMD_TYPES_0 -#undef CMD_TYPES_1 -#undef IMPL + std::variant + Node; } CurMatchState; struct CurMatchRAII { @@ -852,10 +797,17 @@ CurMatchRAII(MatchASTVisitor &MV, const MatchCallback *CB, const NodeType &NT) : MV(MV) { - MV.CurMatchState.SetCallbackAndRawNode(CB, NT); + assert(MV.CurMatchState.Callback == nullptr && + std::holds_alternative(MV.CurMatchState.Node) && + MV.CurMatchState.BNodes == nullptr); + MV.CurMatchState.Callback = CB; + MV.CurMatchState.Node = &NT; } - ~CurMatchRAII() { MV.CurMatchState.reset(); } + ~CurMatchRAII() { + MV.CurMatchState.Callback = nullptr; + MV.CurMatchState.Node = std::monostate{}; + } private: MatchASTVisitor &MV; @@ -890,30 +842,22 @@ static void dumpNodeFromState(const ASTContext &Ctx, const CurMatchData &State, raw_ostream &OS) { - if (const DynTypedNode *MatchNode = State.getNode()) { - dumpNode(Ctx, *MatchNode, OS); - } else if (const auto *QT = State.getNode()) { - dumpNode(Ctx, DynTypedNode::create(*QT), OS); - } else if (const auto 
*TL = State.getNode()) { - dumpNode(Ctx, DynTypedNode::create(*TL), OS); - } else if (const auto *NNS = State.getNode()) { - dumpNode(Ctx, DynTypedNode::create(*NNS), OS); - } else if (const auto *NNSL = State.getNode()) { - dumpNode(Ctx, DynTypedNode::create(*NNSL), OS); - } else if (const auto *CtorInit = State.getNode()) { - dumpNode(Ctx, DynTypedNode::create(*CtorInit), OS); - } else if (const auto *TAL = State.getNode()) { - dumpNode(Ctx, DynTypedNode::create(*TAL), OS); - } else if (const auto *At = State.getNode()) { - dumpNode(Ctx, DynTypedNode::create(*At), OS); - } + std::visit(llvm::makeVisitor([&](std::monostate) {}, + [&](const DynTypedNode *Node) { + dumpNode(Ctx, *Node, OS); + }, + [&](const auto *Node) { + dumpNode(Ctx, DynTypedNode::create(*Node), + OS); + }), + State.Node); } public: TraceReporter(const MatchASTVisitor &MV) : MV(MV) {} void print(raw_ostream &OS) const override { const CurMatchData &State = MV.CurMatchState; - const MatchCallback *CB = State.getCallback(); + const MatchCallback *CB = State.Callback; if (!CB) { OS << "ASTMatcher: Not currently matching\n"; return; @@ -924,7 +868,7 @@ ASTContext &Ctx = MV.getASTContext(); - if (const BoundNodes *Nodes = State.getBoundNodes()) { + if (const BoundNodes *Nodes = State.BNodes) { OS << "ASTMatcher: Processing '" << CB->getID() << "' against:\n\t"; dumpNodeFromState(Ctx, State, OS); const BoundNodes::IDToNodeMap &Map = Nodes->getMap(); @@ -1102,6 +1046,9 @@ void matchDispatch(const Attr *Node) { matchWithoutFilter(*Node, Matchers->Attr); } + void matchDispatch(const ConceptReference *Node) { + matchWithoutFilter(*Node, Matchers->ConceptReference); + } void matchDispatch(const void *) { /* Do nothing. 
*/ } /// @} @@ -1240,10 +1187,11 @@ struct CurBoundScope { CurBoundScope(MatchASTVisitor::CurMatchData &State, const BoundNodes &BN) : State(State) { - State.SetBoundNodes(BN); + assert(State.BNodes == nullptr); + State.BNodes = &BN; } - ~CurBoundScope() { State.clearBoundNodes(); } + ~CurBoundScope() { State.BNodes = nullptr; } private: MatchASTVisitor::CurMatchData &State; @@ -1526,6 +1474,11 @@ return RecursiveASTVisitor<MatchASTVisitor>::TraverseAttr(AttrNode); } +bool MatchASTVisitor::TraverseConceptReference(ConceptReference *CR) { + match(*CR); + return RecursiveASTVisitor<MatchASTVisitor>::TraverseConceptReference(CR); +} + class MatchASTConsumer : public ASTConsumer { public: MatchASTConsumer(MatchFinder *Finder, @@ -1626,6 +1579,12 @@ Matchers.AllCallbacks.insert(Action); } +void MatchFinder::addMatcher(const ConceptReferenceMatcher &NodeMatch, + MatchCallback *Action) { + Matchers.ConceptReference.emplace_back(NodeMatch, Action); + Matchers.AllCallbacks.insert(Action); +} + bool MatchFinder::addDynamicMatcher(const internal::DynTypedMatcher &NodeMatch, MatchCallback *Action) { if (NodeMatch.canConvertTo<Decl>()) { @@ -1655,6 +1614,9 @@ } else if (NodeMatch.canConvertTo<Attr>()) { addMatcher(NodeMatch.convertTo<Attr>(), Action); return true; + } else if (NodeMatch.canConvertTo<ConceptReference>()) { + addMatcher(NodeMatch.convertTo<ConceptReference>(), Action); + return true; } return false; } diff --git a/clang/lib/ASTMatchers/ASTMatchersInternal.cpp b/clang/lib/ASTMatchers/ASTMatchersInternal.cpp --- a/clang/lib/ASTMatchers/ASTMatchersInternal.cpp +++ b/clang/lib/ASTMatchers/ASTMatchersInternal.cpp @@ -768,6 +768,7 @@ templateTypeParmDecl; const internal::VariadicDynCastAllOfMatcher<Decl, TemplateTemplateParmDecl> templateTemplateParmDecl; +const internal::VariadicAllOfMatcher<ConceptReference> conceptReference; const internal::VariadicAllOfMatcher<LambdaCapture> lambdaCapture; const internal::VariadicAllOfMatcher<QualType> qualType; diff --git a/clang/lib/ASTMatchers/Dynamic/Registry.cpp b/clang/lib/ASTMatchers/Dynamic/Registry.cpp --- a/clang/lib/ASTMatchers/Dynamic/Registry.cpp +++
b/clang/lib/ASTMatchers/Dynamic/Registry.cpp @@ -173,6 +173,7 @@ REGISTER_MATCHER(compoundStmt); REGISTER_MATCHER(coawaitExpr); REGISTER_MATCHER(conceptDecl); + REGISTER_MATCHER(conceptReference); REGISTER_MATCHER(conditionalOperator); REGISTER_MATCHER(constantArrayType); REGISTER_MATCHER(constantExpr); diff --git a/clang/unittests/AST/ASTTypeTraitsTest.cpp b/clang/unittests/AST/ASTTypeTraitsTest.cpp --- a/clang/unittests/AST/ASTTypeTraitsTest.cpp +++ b/clang/unittests/AST/ASTTypeTraitsTest.cpp @@ -210,7 +210,13 @@ ast_matchers::attr())); } -// FIXME: add tests for ConceptReference once we add an ASTMatcher. +TEST(DynTypedNode, ConceptReferenceSourceRange) { + RangeVerifier Verifier; + Verifier.expectRange(2, 10, 2, 15); + EXPECT_TRUE(Verifier.match("template concept C = true;\n" + "auto X = C;", + conceptReference())); +} TEST(DynTypedNode, DeclDump) { DumpVerifier Verifier; @@ -224,6 +230,14 @@ EXPECT_TRUE(Verifier.match("void f() {}", stmt())); } +TEST(DynTypedNode, ConceptReferenceDump) { + DumpVerifier Verifier; + Verifier.expectSubstring("ConceptReference"); + EXPECT_TRUE(Verifier.match("template concept C = true;\n" + "auto X = C;", + conceptReference())); +} + TEST(DynTypedNode, DeclPrint) { PrintVerifier Verifier; Verifier.expectString("void f() {\n}\n"); @@ -236,6 +250,14 @@ EXPECT_TRUE(Verifier.match("void f() {}", stmt())); } +TEST(DynTypedNode, ConceptReferencePrint) { + PrintVerifier Verifier; + Verifier.expectString("C"); + EXPECT_TRUE(Verifier.match("template concept C = true;\n" + "auto X = C;", + conceptReference())); +} + TEST(DynTypedNode, QualType) { QualType Q; DynTypedNode Node = DynTypedNode::create(Q); diff --git a/clang/unittests/ASTMatchers/ASTMatchersNodeTest.cpp b/clang/unittests/ASTMatchers/ASTMatchersNodeTest.cpp --- a/clang/unittests/ASTMatchers/ASTMatchersNodeTest.cpp +++ b/clang/unittests/ASTMatchers/ASTMatchersNodeTest.cpp @@ -770,6 +770,17 @@ cxxCtorInitializer(forField(hasName("i"))))); } +TEST_P(ASTMatchersTest, 
Match_ConceptReference) { + if (!GetParam().isCXX20OrLater()) { + return; + } + std::string Concept = "template concept C = true;\n"; + EXPECT_TRUE(matches(Concept + "auto X = C;", conceptReference())); + EXPECT_TRUE(matches(Concept + "C auto X = 0;", conceptReference())); + EXPECT_TRUE(matches(Concept + "template int i;", conceptReference())); + EXPECT_TRUE(matches(Concept + "void foo(C auto X) {}", conceptReference())); +} + TEST_P(ASTMatchersTest, Matcher_ThisExpr) { if (!GetParam().isCXX()) { return;