Make `DebugAction` a CRTP class so each action type gets its own handler ID.

Previously an action-specific handler was identified by
`TypeID::get<typename DebugAction<ParameterTs...>::Handler>()`, i.e. by the
parameter pack alone. Two distinct actions with the same parameter types
(for example two parameterless `DebugAction<>` actions) therefore shared one
handler ID, and a handler registered for one could answer for the other.
The action type is now the first template argument (`Derived`) and is used
directly as the handler ID, which requires each action to define an explicit
TypeID (done in the tests via MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID).
The new NonOverlappingActionSpecificHandlers test covers the overlap case.

diff --git a/mlir/docs/DebugActions.md b/mlir/docs/DebugActions.md
--- a/mlir/docs/DebugActions.md
+++ b/mlir/docs/DebugActions.md
@@ -54,10 +54,12 @@
 /// * The Tag is specified via a static `StringRef getTag()` method.
 /// * The Description is specified via a static `StringRef getDescription()`
 ///   method.
-/// * The parameters for the action are provided via template parameters when
-///   inheriting from `DebugAction`.
+/// * `DebugAction` is a CRTP class, so the first template parameter is the
+///   action type class itself.
+/// * The parameters for the action are provided via additional template
+///   parameters when inheriting from `DebugAction`.
 struct ApplyPatternAction
-    : public DebugAction<Operation *, const Pattern &> {
+    : public DebugAction<ApplyPatternAction, Operation *, const Pattern &> {
   static StringRef getTag() { return "apply-pattern"; }
   static StringRef getDescription() {
     return "Control the application of rewrite patterns";
@@ -95,7 +97,7 @@
 ```c++
 /// A debug action that allows for controlling the application of patterns.
 struct ApplyPatternAction
-    : public DebugAction<Operation *, const Pattern &> {
+    : public DebugAction<ApplyPatternAction, Operation *, const Pattern &> {
   static StringRef getTag() { return "apply-pattern"; }
   static StringRef getDescription() {
     return "Control the application of rewrite patterns";
diff --git a/mlir/include/mlir/Support/DebugAction.h b/mlir/include/mlir/Support/DebugAction.h
--- a/mlir/include/mlir/Support/DebugAction.h
+++ b/mlir/include/mlir/Support/DebugAction.h
@@ -190,14 +190,12 @@
 /// This class provides a handler class that can be derived from to handle
 /// instances of this action. The parameters to its query methods map 1-1 to the
 /// types on the action type.
-template <typename... ParameterTs>
+template <typename Derived, typename... ParameterTs>
 class DebugAction {
 public:
   class Handler : public DebugActionManager::HandlerBase {
   public:
-    Handler()
-        : HandlerBase(
-              TypeID::get<typename DebugAction<ParameterTs...>::Handler>()) {}
+    Handler() : HandlerBase(TypeID::get<Derived>()) {}
 
     /// This hook allows for controlling whether an action should execute or
     /// not. `parameters` correspond to the set of values provided by the
@@ -209,8 +207,7 @@
 
     /// Provide classof to allow casting between handler types.
     static bool classof(const DebugActionManager::HandlerBase *handler) {
-      return handler->getHandlerID() ==
-             TypeID::get<typename DebugAction<ParameterTs...>::Handler>();
+      return handler->getHandlerID() == TypeID::get<Derived>();
     }
   };
 
diff --git a/mlir/unittests/Support/DebugActionTest.cpp b/mlir/unittests/Support/DebugActionTest.cpp
--- a/mlir/unittests/Support/DebugActionTest.cpp
+++ b/mlir/unittests/Support/DebugActionTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Support/DebugAction.h"
+#include "mlir/Support/TypeID.h"
 #include "gmock/gmock.h"
 
 // DebugActionManager is only enabled in DEBUG mode.
@@ -15,11 +16,20 @@
 using namespace mlir;
 
 namespace {
-struct SimpleAction : public DebugAction<> {
+struct SimpleAction : DebugAction<SimpleAction> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SimpleAction)
   static StringRef getTag() { return "simple-action"; }
   static StringRef getDescription() { return "simple-action-description"; }
 };
-struct ParametricAction : public DebugAction<bool> {
+struct OtherSimpleAction : DebugAction<OtherSimpleAction> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OtherSimpleAction)
+  static StringRef getTag() { return "other-simple-action"; }
+  static StringRef getDescription() {
+    return "other-simple-action-description";
+  }
+};
+struct ParametricAction : DebugAction<ParametricAction, bool> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ParametricAction)
   static StringRef getTag() { return "param-action"; }
   static StringRef getDescription() { return "param-action-description"; }
 };
@@ -29,7 +39,7 @@
 
   // A generic handler that always executes the simple action, but not the
   // parametric action.
-  struct GenericHandler : public DebugActionManager::GenericHandler {
+  struct GenericHandler : DebugActionManager::GenericHandler {
     FailureOr<bool> shouldExecute(StringRef tag, StringRef desc) final {
       if (tag == SimpleAction::getTag()) {
         EXPECT_EQ(desc, SimpleAction::getDescription());
@@ -51,7 +61,7 @@
   DebugActionManager manager;
 
   // Handler that simply uses the input as the decider.
-  struct ActionSpecificHandler : public ParametricAction::Handler {
+  struct ActionSpecificHandler : ParametricAction::Handler {
     FailureOr<bool> shouldExecute(bool shouldExecuteParam) final {
       return shouldExecuteParam;
     }
@@ -69,7 +79,7 @@
   DebugActionManager manager;
 
   // Handler that uses the number of action executions as the decider.
-  struct DebugCounterHandler : public SimpleAction::Handler {
+  struct DebugCounterHandler : SimpleAction::Handler {
     FailureOr<bool> shouldExecute() final { return numExecutions++ < 3; }
     unsigned numExecutions = 0;
   };
@@ -83,6 +93,22 @@
   EXPECT_FALSE(manager.shouldExecute<SimpleAction>());
 }
 
+TEST(DebugActionTest, NonOverlappingActionSpecificHandlers) {
+  DebugActionManager manager;
+
+  // One handler returns true and another returns false
+  struct SimpleActionHandler : SimpleAction::Handler {
+    FailureOr<bool> shouldExecute() final { return true; }
+  };
+  struct OtherSimpleActionHandler : OtherSimpleAction::Handler {
+    FailureOr<bool> shouldExecute() final { return false; }
+  };
+  manager.registerActionHandler<SimpleActionHandler>();
+  manager.registerActionHandler<OtherSimpleActionHandler>();
+  EXPECT_TRUE(manager.shouldExecute<SimpleAction>());
+  EXPECT_FALSE(manager.shouldExecute<OtherSimpleAction>());
+}
+
 } // namespace
 
 #endif
diff --git a/mlir/unittests/Support/DebugCounterTest.cpp b/mlir/unittests/Support/DebugCounterTest.cpp
--- a/mlir/unittests/Support/DebugCounterTest.cpp
+++ b/mlir/unittests/Support/DebugCounterTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Support/DebugCounter.h"
+#include "mlir/Support/TypeID.h"
 #include "gmock/gmock.h"
 
 using namespace mlir;
@@ -16,7 +17,8 @@
 
 namespace {
 
-struct CounterAction : public DebugAction<> {
+struct CounterAction : public DebugAction<CounterAction> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CounterAction)
   static StringRef getTag() { return "counter-action"; }
   static StringRef getDescription() { return "Test action for debug counters"; }
 };