diff --git a/mlir/include/mlir/Support/DebugAction.h b/mlir/include/mlir/Support/DebugAction.h --- a/mlir/include/mlir/Support/DebugAction.h +++ b/mlir/include/mlir/Support/DebugAction.h @@ -190,14 +190,12 @@ /// This class provides a handler class that can be derived from to handle /// instances of this action. The parameters to its query methods map 1-1 to the /// types on the action type. -template +template class DebugAction { public: class Handler : public DebugActionManager::HandlerBase { public: - Handler() - : HandlerBase( - TypeID::get::Handler>()) {} + Handler() : HandlerBase(TypeID::get()) {} /// This hook allows for controlling whether an action should execute or /// not. `parameters` correspond to the set of values provided by the @@ -209,8 +207,7 @@ /// Provide classof to allow casting between handler types. static bool classof(const DebugActionManager::HandlerBase *handler) { - return handler->getHandlerID() == - TypeID::get::Handler>(); + return handler->getHandlerID() == TypeID::get(); } }; diff --git a/mlir/unittests/Support/DebugActionTest.cpp b/mlir/unittests/Support/DebugActionTest.cpp --- a/mlir/unittests/Support/DebugActionTest.cpp +++ b/mlir/unittests/Support/DebugActionTest.cpp @@ -15,11 +15,17 @@ using namespace mlir; namespace { -struct SimpleAction : public DebugAction<> { +struct SimpleAction : DebugAction { static StringRef getTag() { return "simple-action"; } static StringRef getDescription() { return "simple-action-description"; } }; -struct ParametricAction : public DebugAction { +struct OtherSimpleAction : DebugAction { + static StringRef getTag() { return "other-simple-action"; } + static StringRef getDescription() { + return "other-simple-action-description"; + } +}; +struct ParametricAction : DebugAction { static StringRef getTag() { return "param-action"; } static StringRef getDescription() { return "param-action-description"; } }; @@ -29,7 +35,7 @@ // A generic handler that always executes the simple action, but not the // parametric action. - struct GenericHandler : public DebugActionManager::GenericHandler { + struct GenericHandler : DebugActionManager::GenericHandler { FailureOr shouldExecute(StringRef tag, StringRef desc) final { if (tag == SimpleAction::getTag()) { EXPECT_EQ(desc, SimpleAction::getDescription()); @@ -51,7 +57,7 @@ DebugActionManager manager; // Handler that simply uses the input as the decider. - struct ActionSpecificHandler : public ParametricAction::Handler { + struct ActionSpecificHandler : ParametricAction::Handler { FailureOr shouldExecute(bool shouldExecuteParam) final { return shouldExecuteParam; } @@ -69,7 +75,7 @@ DebugActionManager manager; // Handler that uses the number of action executions as the decider. - struct DebugCounterHandler : public SimpleAction::Handler { + struct DebugCounterHandler : SimpleAction::Handler { FailureOr shouldExecute() final { return numExecutions++ < 3; } unsigned numExecutions = 0; }; @@ -83,6 +89,22 @@ EXPECT_FALSE(manager.shouldExecute()); } +TEST(DebugActionTest, NonOverlappingActionSpecificHandlers) { + DebugActionManager manager; + + // One handler returns true and another returns false + struct SimpleActionHandler : SimpleAction::Handler { + FailureOr shouldExecute() final { return true; } + }; + struct OtherSimpleActionHandler : OtherSimpleAction::Handler { + FailureOr shouldExecute() final { return false; } + }; + manager.registerActionHandler(); + manager.registerActionHandler(); + EXPECT_TRUE(manager.shouldExecute()); + EXPECT_FALSE(manager.shouldExecute()); +} + } // namespace #endif diff --git a/mlir/unittests/Support/DebugCounterTest.cpp b/mlir/unittests/Support/DebugCounterTest.cpp --- a/mlir/unittests/Support/DebugCounterTest.cpp +++ b/mlir/unittests/Support/DebugCounterTest.cpp @@ -16,7 +16,7 @@ namespace { -struct CounterAction : public DebugAction<> { +struct CounterAction : public DebugAction { static StringRef getTag() { return "counter-action"; } static StringRef getDescription() { return "Test action for debug counters"; } };