diff --git a/clang/docs/LibASTMatchersReference.html b/clang/docs/LibASTMatchersReference.html
--- a/clang/docs/LibASTMatchersReference.html
+++ b/clang/docs/LibASTMatchersReference.html
@@ -611,6 +611,17 @@
+Matcher<ConceptReference>conceptReferenceMatcher<ConceptReference>...
+Matches concept references.
+
+Given
+  template <class> concept C = true;
+  bool X = C<int>;
+conceptReference()
+  matches 'C<int>'.
+
+
+
 Matcher<Decl>accessSpecDeclMatcher<AccessSpecDecl>...
 Matches C++ access specifier declarations.
 
diff --git a/clang/include/clang/ASTMatchers/ASTMatchFinder.h b/clang/include/clang/ASTMatchers/ASTMatchFinder.h
--- a/clang/include/clang/ASTMatchers/ASTMatchFinder.h
+++ b/clang/include/clang/ASTMatchers/ASTMatchFinder.h
@@ -169,6 +169,8 @@
   void addMatcher(const TemplateArgumentLocMatcher &NodeMatch,
                   MatchCallback *Action);
   void addMatcher(const AttrMatcher &NodeMatch, MatchCallback *Action);
+  void addMatcher(const ConceptReferenceMatcher &NodeMatch,
+                  MatchCallback *Action);
   /// @}
 
   /// Adds a matcher to execute when running over the AST.
@@ -222,6 +224,8 @@
     std::vector>
         TemplateArgumentLoc;
     std::vector> Attr;
+    std::vector<std::pair<ConceptReferenceMatcher, MatchCallback *>>
+        ConceptReference;
     /// All the callbacks in one container to simplify iteration.
     llvm::SmallPtrSet AllCallbacks;
   };
diff --git a/clang/include/clang/ASTMatchers/ASTMatchers.h b/clang/include/clang/ASTMatchers/ASTMatchers.h
--- a/clang/include/clang/ASTMatchers/ASTMatchers.h
+++ b/clang/include/clang/ASTMatchers/ASTMatchers.h
@@ -152,6 +152,7 @@
 using TemplateArgumentLocMatcher = internal::Matcher;
 using LambdaCaptureMatcher = internal::Matcher;
 using AttrMatcher = internal::Matcher;
+using ConceptReferenceMatcher = internal::Matcher<ConceptReference>;
 /// @}
 
 /// Matches any node.
@@ -611,6 +612,18 @@
                                                    TemplateTemplateParmDecl>
     templateTemplateParmDecl;
 
+/// Matches concept references.
+///
+/// Given
+/// \code
+///   template <class> concept C = true;
+///   bool X = C<int>;
+/// \endcode
+/// conceptReference()
+///   matches 'C<int>'.
+extern const internal::VariadicAllOfMatcher<ConceptReference> conceptReference;
+
+
 /// Matches public C++ declarations and C++ base specifers that specify public
 /// inheritance.
 ///
diff --git a/clang/include/clang/ASTMatchers/ASTMatchersInternal.h b/clang/include/clang/ASTMatchers/ASTMatchersInternal.h
--- a/clang/include/clang/ASTMatchers/ASTMatchersInternal.h
+++ b/clang/include/clang/ASTMatchers/ASTMatchersInternal.h
@@ -1163,19 +1163,10 @@
 /// IsBaseType::value is true if T is a "base" type in the AST
 /// node class hierarchies.
 template 
-struct IsBaseType {
-  static const bool value =
-      std::is_same::value || std::is_same::value ||
-      std::is_same::value || std::is_same::value ||
-      std::is_same::value ||
-      std::is_same::value ||
-      std::is_same::value ||
-      std::is_same::value ||
-      std::is_same::value ||
-      std::is_same::value;
-};
-template 
-const bool IsBaseType::value;
+using IsBaseType =
+    llvm::is_one_of;
 
 /// A "type list" that contains all types.
 ///
diff --git a/clang/lib/ASTMatchers/ASTMatchFinder.cpp b/clang/lib/ASTMatchers/ASTMatchFinder.cpp
--- a/clang/lib/ASTMatchers/ASTMatchFinder.cpp
+++ b/clang/lib/ASTMatchers/ASTMatchFinder.cpp
@@ -26,6 +26,7 @@
 #include 
 #include 
 #include 
+#include 
 
 namespace clang {
 namespace ast_matchers {
@@ -136,6 +137,8 @@
       traverse(*TALoc);
     else if (const Attr *A = DynNode.get<Attr>())
       traverse(*A);
+    else if (const ConceptReference *CR = DynNode.get<ConceptReference>())
+      traverse(*CR);
     // FIXME: Add other base types after adding tests.
 
     // It's OK to always overwrite the bound nodes, as if there was
@@ -275,6 +278,12 @@
     ScopedIncrement ScopedDepth(&CurrentDepth);
     return traverse(*A);
   }
+  bool TraverseConceptReference(ConceptReference *CR) {
+    if (CR == nullptr)
+      return true;
+    ScopedIncrement ScopedDepth(&CurrentDepth);
+    return traverse(*CR);
+  }
   bool TraverseLambdaExpr(LambdaExpr *Node) {
     if (!Finder->isTraversalIgnoringImplicitNodes())
       return VisitorBase::TraverseLambdaExpr(Node);
@@ -360,6 +369,10 @@
   bool baseTraverse(const Attr &AttrNode) {
     return VisitorBase::TraverseAttr(const_cast<Attr *>(&AttrNode));
   }
+  bool baseTraverse(const ConceptReference &CR) {
+    return VisitorBase::TraverseConceptReference(
+        const_cast<ConceptReference *>(&CR));
+  }
 
   // Sets 'Matched' to true if 'Matcher' matches 'Node' and:
   //   0 < CurrentDepth <= MaxDepth.
@@ -505,6 +518,7 @@
   bool TraverseConstructorInitializer(CXXCtorInitializer *CtorInit);
   bool TraverseTemplateArgumentLoc(TemplateArgumentLoc TAL);
   bool TraverseAttr(Attr *AttrNode);
+  bool TraverseConceptReference(ConceptReference *);
 
   bool dataTraverseNode(Stmt *S, DataRecursionQueue *Queue) {
     if (auto *RF = dyn_cast(S)) {
@@ -712,6 +726,8 @@
       match(*N);
     } else if (auto *N = Node.get()) {
       match(*N);
+    } else if (auto *N = Node.get<ConceptReference>()) {
+      match(*N);
     }
   }
 
@@ -766,85 +782,14 @@
   bool TraversingASTNodeNotAsIs = false;
   bool TraversingASTChildrenNotSpelledInSource = false;
 
-  class CurMatchData {
-// We don't have enough free low bits in 32bit builds to discriminate 8 pointer
-// types in PointerUnion. so split the union in 2 using a free bit from the
-// callback pointer.
-#define CMD_TYPES_0                                                            \
-  const QualType *, const TypeLoc *, const NestedNameSpecifier *,              \
-      const NestedNameSpecifierLoc *
-#define CMD_TYPES_1                                                            \
-  const CXXCtorInitializer *, const TemplateArgumentLoc *, const Attr *,       \
-      const DynTypedNode *
-
-#define IMPL(Index)                                                            \
-  template                                                  \
-  std::enable_if_t<                                                            \
-      llvm::is_one_of::value>             \
-  SetCallbackAndRawNode(const MatchCallback *CB, const NodeType &N) {          \
-    assertEmpty();                                                             \
-    Callback.setPointerAndInt(CB, Index);                                      \
-    Node##Index = &N;                                                          \
-  }                                                                            \
-                                                                               \
-  template                                                         \
-  std::enable_if_t::value,       \
-                   const T *>                                                  \
-  getNode() const {                                                            \
-    assertHoldsState();                                                        \
-    return Callback.getInt() == (Index) ? Node##Index.dyn_cast()    \
-                                        : nullptr;                             \
-  }
-
-  public:
-    CurMatchData() : Node0(nullptr) {}
-
-    IMPL(0)
-    IMPL(1)
-
-    const MatchCallback *getCallback() const { return Callback.getPointer(); }
-
-    void SetBoundNodes(const BoundNodes &BN) {
-      assertHoldsState();
-      BNodes = &BN;
-    }
-
-    void clearBoundNodes() {
-      assertHoldsState();
-      BNodes = nullptr;
-    }
-
-    const BoundNodes *getBoundNodes() const {
-      assertHoldsState();
-      return BNodes;
-    }
-
-    void reset() {
-      assertHoldsState();
-      Callback.setPointerAndInt(nullptr, 0);
-      Node0 = nullptr;
-    }
-
-  private:
-    void assertHoldsState() const {
-      assert(Callback.getPointer() != nullptr && !Node0.isNull());
-    }
-
-    void assertEmpty() const {
-      assert(Callback.getPointer() == nullptr && Node0.isNull() &&
-             BNodes == nullptr);
-    }
-
-    llvm::PointerIntPair Callback;
-    union {
-      llvm::PointerUnion Node0;
-      llvm::PointerUnion Node1;
-    };
+  struct CurMatchData {
+    const MatchCallback *Callback = nullptr;
     const BoundNodes *BNodes = nullptr;
-
-#undef CMD_TYPES_0
-#undef CMD_TYPES_1
-#undef IMPL
+    std::variant<std::monostate, const QualType *, const TypeLoc *,
+                 const NestedNameSpecifier *, const NestedNameSpecifierLoc *,
+                 const CXXCtorInitializer *, const TemplateArgumentLoc *,
+                 const Attr *, const DynTypedNode *, const ConceptReference *>
+        Node;
   } CurMatchState;
 
   struct CurMatchRAII {
@@ -852,10 +797,17 @@
     CurMatchRAII(MatchASTVisitor &MV, const MatchCallback *CB,
                  const NodeType &NT)
         : MV(MV) {
-      MV.CurMatchState.SetCallbackAndRawNode(CB, NT);
+      assert(MV.CurMatchState.Callback == nullptr &&
+             std::holds_alternative<std::monostate>(MV.CurMatchState.Node) &&
+             MV.CurMatchState.BNodes == nullptr);
+      MV.CurMatchState.Callback = CB;
+      MV.CurMatchState.Node = &NT;
     }
 
-    ~CurMatchRAII() { MV.CurMatchState.reset(); }
+    ~CurMatchRAII() {
+      MV.CurMatchState.Callback = nullptr;
+      MV.CurMatchState.Node = std::monostate{};
+    }
 
   private:
     MatchASTVisitor &MV;
@@ -890,30 +842,22 @@
 
     static void dumpNodeFromState(const ASTContext &Ctx,
                                   const CurMatchData &State, raw_ostream &OS) {
-      if (const DynTypedNode *MatchNode = State.getNode<DynTypedNode>()) {
-        dumpNode(Ctx, *MatchNode, OS);
-      } else if (const auto *QT = State.getNode<QualType>()) {
-        dumpNode(Ctx, DynTypedNode::create(*QT), OS);
-      } else if (const auto *TL = State.getNode<TypeLoc>()) {
-        dumpNode(Ctx, DynTypedNode::create(*TL), OS);
-      } else if (const auto *NNS = State.getNode<NestedNameSpecifier>()) {
-        dumpNode(Ctx, DynTypedNode::create(*NNS), OS);
-      } else if (const auto *NNSL = State.getNode<NestedNameSpecifierLoc>()) {
-        dumpNode(Ctx, DynTypedNode::create(*NNSL), OS);
-      } else if (const auto *CtorInit = State.getNode<CXXCtorInitializer>()) {
-        dumpNode(Ctx, DynTypedNode::create(*CtorInit), OS);
-      } else if (const auto *TAL = State.getNode<TemplateArgumentLoc>()) {
-        dumpNode(Ctx, DynTypedNode::create(*TAL), OS);
-      } else if (const auto *At = State.getNode<Attr>()) {
-        dumpNode(Ctx, DynTypedNode::create(*At), OS);
-      }
+      std::visit(llvm::makeVisitor([&](std::monostate) {},
+                                   [&](const DynTypedNode *Node) {
+                                     dumpNode(Ctx, *Node, OS);
+                                   },
+                                   [&](const auto *Node) {
+                                     dumpNode(Ctx, DynTypedNode::create(*Node),
+                                              OS);
+                                   }),
+                 State.Node);
     }
 
   public:
     TraceReporter(const MatchASTVisitor &MV) : MV(MV) {}
     void print(raw_ostream &OS) const override {
       const CurMatchData &State = MV.CurMatchState;
-      const MatchCallback *CB = State.getCallback();
+      const MatchCallback *CB = State.Callback;
       if (!CB) {
         OS << "ASTMatcher: Not currently matching\n";
         return;
@@ -924,7 +868,7 @@
 
       ASTContext &Ctx = MV.getASTContext();
 
-      if (const BoundNodes *Nodes = State.getBoundNodes()) {
+      if (const BoundNodes *Nodes = State.BNodes) {
         OS << "ASTMatcher: Processing '" << CB->getID() << "' against:\n\t";
         dumpNodeFromState(Ctx, State, OS);
         const BoundNodes::IDToNodeMap &Map = Nodes->getMap();
@@ -1102,6 +1046,9 @@
   void matchDispatch(const Attr *Node) {
     matchWithoutFilter(*Node, Matchers->Attr);
   }
+  void matchDispatch(const ConceptReference *Node) {
+    matchWithoutFilter(*Node, Matchers->ConceptReference);
+  }
   void matchDispatch(const void *) { /* Do nothing. */ }
   /// @}
 
@@ -1240,10 +1187,11 @@
     struct CurBoundScope {
       CurBoundScope(MatchASTVisitor::CurMatchData &State, const BoundNodes &BN)
           : State(State) {
-        State.SetBoundNodes(BN);
+        assert(State.BNodes == nullptr);
+        State.BNodes = &BN;
       }
 
-      ~CurBoundScope() { State.clearBoundNodes(); }
+      ~CurBoundScope() { State.BNodes = nullptr; }
 
     private:
       MatchASTVisitor::CurMatchData &State;
@@ -1526,6 +1474,11 @@
   return RecursiveASTVisitor::TraverseAttr(AttrNode);
 }
 
+bool MatchASTVisitor::TraverseConceptReference(ConceptReference *CR) {
+  match(*CR);
+  return RecursiveASTVisitor::TraverseConceptReference(CR);
+}
+
 class MatchASTConsumer : public ASTConsumer {
 public:
   MatchASTConsumer(MatchFinder *Finder,
@@ -1626,6 +1579,12 @@
   Matchers.AllCallbacks.insert(Action);
 }
 
+void MatchFinder::addMatcher(const ConceptReferenceMatcher &NodeMatch,
+                             MatchCallback *Action) {
+  Matchers.ConceptReference.emplace_back(NodeMatch, Action);
+  Matchers.AllCallbacks.insert(Action);
+}
+
 bool MatchFinder::addDynamicMatcher(const internal::DynTypedMatcher &NodeMatch,
                                     MatchCallback *Action) {
   if (NodeMatch.canConvertTo()) {
@@ -1655,6 +1614,9 @@
   } else if (NodeMatch.canConvertTo()) {
     addMatcher(NodeMatch.convertTo(), Action);
     return true;
+  } else if (NodeMatch.canConvertTo<ConceptReference>()) {
+    addMatcher(NodeMatch.convertTo<ConceptReference>(), Action);
+    return true;
   }
   return false;
 }
diff --git a/clang/lib/ASTMatchers/ASTMatchersInternal.cpp b/clang/lib/ASTMatchers/ASTMatchersInternal.cpp
--- a/clang/lib/ASTMatchers/ASTMatchersInternal.cpp
+++ b/clang/lib/ASTMatchers/ASTMatchersInternal.cpp
@@ -768,6 +768,7 @@
     templateTypeParmDecl;
 const internal::VariadicDynCastAllOfMatcher
     templateTemplateParmDecl;
+const internal::VariadicAllOfMatcher<ConceptReference> conceptReference;
 
 const internal::VariadicAllOfMatcher lambdaCapture;
 const internal::VariadicAllOfMatcher qualType;
diff --git a/clang/lib/ASTMatchers/Dynamic/Registry.cpp b/clang/lib/ASTMatchers/Dynamic/Registry.cpp
--- a/clang/lib/ASTMatchers/Dynamic/Registry.cpp
+++ b/clang/lib/ASTMatchers/Dynamic/Registry.cpp
@@ -173,6 +173,7 @@
   REGISTER_MATCHER(compoundStmt);
   REGISTER_MATCHER(coawaitExpr);
   REGISTER_MATCHER(conceptDecl);
+  REGISTER_MATCHER(conceptReference);
   REGISTER_MATCHER(conditionalOperator);
   REGISTER_MATCHER(constantArrayType);
   REGISTER_MATCHER(constantExpr);
diff --git a/clang/unittests/AST/ASTTypeTraitsTest.cpp b/clang/unittests/AST/ASTTypeTraitsTest.cpp
--- a/clang/unittests/AST/ASTTypeTraitsTest.cpp
+++ b/clang/unittests/AST/ASTTypeTraitsTest.cpp
@@ -210,7 +210,13 @@
                              ast_matchers::attr()));
 }
 
-// FIXME: add tests for ConceptReference once we add an ASTMatcher.
+TEST(DynTypedNode, ConceptReferenceSourceRange) {
+  RangeVerifier<ConceptReference> Verifier;
+  Verifier.expectRange(2, 10, 2, 15);
+  EXPECT_TRUE(Verifier.match("template <class> concept C = true;\n"
+                             "auto X = C<int>;",
+                             conceptReference()));
+}
 
 TEST(DynTypedNode, DeclDump) {
   DumpVerifier Verifier;
@@ -224,6 +230,14 @@
   EXPECT_TRUE(Verifier.match("void f() {}", stmt()));
 }
 
+TEST(DynTypedNode, ConceptReferenceDump) {
+  DumpVerifier Verifier;
+  Verifier.expectSubstring("ConceptReference");
+  EXPECT_TRUE(Verifier.match("template <class> concept C = true;\n"
+                             "auto X = C<int>;",
+                             conceptReference()));
+}
+
 TEST(DynTypedNode, DeclPrint) {
   PrintVerifier Verifier;
   Verifier.expectString("void f() {\n}\n");
@@ -236,6 +250,14 @@
   EXPECT_TRUE(Verifier.match("void f() {}", stmt()));
 }
 
+TEST(DynTypedNode, ConceptReferencePrint) {
+  PrintVerifier Verifier;
+  Verifier.expectString("C<int>");
+  EXPECT_TRUE(Verifier.match("template <class> concept C = true;\n"
+                             "auto X = C<int>;",
+                             conceptReference()));
+}
+
 TEST(DynTypedNode, QualType) {
   QualType Q;
   DynTypedNode Node = DynTypedNode::create(Q);
diff --git a/clang/unittests/ASTMatchers/ASTMatchersNodeTest.cpp b/clang/unittests/ASTMatchers/ASTMatchersNodeTest.cpp
--- a/clang/unittests/ASTMatchers/ASTMatchersNodeTest.cpp
+++ b/clang/unittests/ASTMatchers/ASTMatchersNodeTest.cpp
@@ -770,6 +770,17 @@
                       cxxCtorInitializer(forField(hasName("i")))));
 }
 
+TEST_P(ASTMatchersTest, Match_ConceptReference) {
+  if (!GetParam().isCXX20OrLater()) {
+    return;
+  }
+  std::string Concept = "template <class> concept C = true;\n";
+  EXPECT_TRUE(matches(Concept + "auto X = C<int>;", conceptReference()));
+  EXPECT_TRUE(matches(Concept + "C auto X = 0;", conceptReference()));
+  EXPECT_TRUE(matches(Concept + "template <C> int i;", conceptReference()));
+  EXPECT_TRUE(matches(Concept + "void foo(C auto X) {}", conceptReference()));
+}
+
 TEST_P(ASTMatchersTest, Matcher_ThisExpr) {
   if (!GetParam().isCXX()) {
     return;