diff --git a/mlir/include/mlir/Support/DebugAction.h b/mlir/include/mlir/Support/DebugAction.h --- a/mlir/include/mlir/Support/DebugAction.h +++ b/mlir/include/mlir/Support/DebugAction.h @@ -8,9 +8,7 @@ // // This file contains definitions for the debug action framework. This framework // allows for external entities to control certain actions taken by the compiler -// by registering handler functions. A debug action handler provides the -// internal implementation for the various queries on a debug action, such as -// whether it should execute or not. +// by registering handler functions. // //===----------------------------------------------------------------------===// @@ -29,6 +27,34 @@ namespace mlir { +/// This class represents the base class of a debug action. +class DebugActionBase { +public: + virtual ~DebugActionBase() = default; + + /// Return the unique action id of this action, use for casting + /// functionality. + TypeID getActionID() const { return actionID; } + + StringRef getTag() const { return tag; } + + StringRef getDescription() const { return desc; } + + virtual void print(raw_ostream &os) const { + os << "Action \"" << tag << "\" : " << desc << "\n"; + } + +protected: + DebugActionBase(TypeID actionID, StringRef tag, StringRef desc) + : actionID(actionID), tag(tag), desc(desc) {} + + /// The type of the derived action class. This allows for detecting the + /// specific handler of a given action type. + TypeID actionID; + StringRef tag; + StringRef desc; +}; + //===----------------------------------------------------------------------===// // DebugActionManager //===----------------------------------------------------------------------===// @@ -74,11 +100,11 @@ public: GenericHandler() : HandlerBase(TypeID::get()) {} - /// This hook allows for controlling whether an action should execute or - /// not. It should return failure if the handler could not process the - /// action, passing it to the next registered handler. - virtual FailureOr shouldExecute(StringRef actionTag, - StringRef description) { + /// This hook allows for controlling the execution of an action. It should + /// return failure if the handler could not process the action, or whether + /// the `transform` was executed or not. + virtual FailureOr execute(function_ref transform, + const DebugActionBase &action) { return failure(); } @@ -90,10 +116,7 @@ /// Register the given action handler with the manager. void registerActionHandler(std::unique_ptr handler) { - // The manager is always disabled if built without debug. -#if LLVM_ENABLE_ABI_BREAKING_CHECKS actionHandlers.emplace_back(std::move(handler)); -#endif } template void registerActionHandler() { @@ -104,31 +127,35 @@ // Action Queries //===--------------------------------------------------------------------===// - /// Returns true if the given action type should be executed, false otherwise. - /// `Args` are a set of parameters used by handlers of `ActionType` to - /// determine if the action should be executed. + /// Dispatch an action represented by the `transform` callback. If no handler + /// is found, the `transform` callback is invoked directly. + /// Return true if the action was executed, false otherwise. template - bool shouldExecute(Args &&...args) { - // The manager is always disabled if built without debug. -#if !LLVM_ENABLE_ABI_BREAKING_CHECKS - return true; -#else - // Invoke the `shouldExecute` method on the provided handler. - auto shouldExecuteFn = [&](auto *handler, auto &&...handlerParams) { - return handler->shouldExecute( - std::forward(handlerParams)...); + bool execute(function_ref transform, Args &&...args) { + if (actionHandlers.empty()) { + transform(); + return true; + } + + // Invoke the `execute` method on the provided handler. + auto executeFn = [&](auto *handler, auto &&...handlerParams) { + return handler->execute( + transform, + ActionType(std::forward(handlerParams)...)); }; FailureOr result = dispatchToHandler( - shouldExecuteFn, std::forward(args)...); + executeFn, std::forward(args)...); + // no handler found, execute the transform directly. + if (failed(result)) { + transform(); + return true; + } - // If the action wasn't handled, execute the action by default. - return succeeded(result) ? *result : true; -#endif + // Return the result of the handler. + return *result; } private: -// The manager is always disabled if built without debug. -#if LLVM_ENABLE_ABI_BREAKING_CHECKS //===--------------------------------------------------------------------===// // Query to Handler Dispatch //===--------------------------------------------------------------------===// @@ -145,16 +172,13 @@ "cannot execute action with the given set of parameters"); // Process any generic or action specific handlers. - // TODO: We currently just pick the first handler that gives us a result, - // but in the future we may want to employ a reduction over all of the - // values returned. - for (std::unique_ptr &it : llvm::reverse(actionHandlers)) { + // The first handler that gives us a result is the one that we will return. + for (std::unique_ptr &it : reverse(actionHandlers)) { FailureOr result = failure(); if (auto *handler = dyn_cast(&*it)) { result = handlerCallback(handler, std::forward(args)...); } else if (auto *genericHandler = dyn_cast(&*it)) { - result = handlerCallback(genericHandler, ActionType::getTag(), - ActionType::getDescription()); + result = handlerCallback(genericHandler, std::forward(args)...); } // If the handler succeeded, return the result. Otherwise, try a new @@ -167,7 +191,6 @@ /// The set of action handlers that have been registered with the manager. SmallVector> actionHandlers; -#endif }; //===----------------------------------------------------------------------===// @@ -191,17 +214,27 @@ /// instances of this action. The parameters to its query methods map 1-1 to the /// types on the action type. template -class DebugAction { +class DebugAction : public DebugActionBase { public: + DebugAction() + : DebugActionBase(TypeID::get(), Derived::getTag(), + Derived::getDescription()) {} + + /// Provide classof to allow casting between action types. + static bool classof(const DebugActionBase *action) { + return action->getActionID() == TypeID::get(); + } + class Handler : public DebugActionManager::HandlerBase { public: Handler() : HandlerBase(TypeID::get()) {} - /// This hook allows for controlling whether an action should execute or - /// not. `parameters` correspond to the set of values provided by the + /// This hook allows for controlling the execution of an action. + /// `parameters` correspond to the set of values provided by the /// action as context. It should return failure if the handler could not /// process the action, passing it to the next registered handler. - virtual FailureOr shouldExecute(ParameterTs... parameters) { + virtual FailureOr execute(function_ref transform, + const Derived &action) { return failure(); } diff --git a/mlir/include/mlir/Support/DebugCounter.h b/mlir/include/mlir/Support/DebugCounter.h --- a/mlir/include/mlir/Support/DebugCounter.h +++ b/mlir/include/mlir/Support/DebugCounter.h @@ -38,7 +38,8 @@ int64_t countToStopAfter); /// Register a counter with the specified name. - FailureOr shouldExecute(StringRef tag, StringRef description) final; + FailureOr execute(llvm::function_ref transform, + const DebugActionBase &action) final; /// Print the counters that have been registered with this instance to the /// provided output stream. diff --git a/mlir/lib/Support/DebugCounter.cpp b/mlir/lib/Support/DebugCounter.cpp --- a/mlir/lib/Support/DebugCounter.cpp +++ b/mlir/lib/Support/DebugCounter.cpp @@ -62,9 +62,9 @@ } // Register a counter with the specified name. -FailureOr DebugCounter::shouldExecute(StringRef tag, - StringRef description) { - auto counterIt = counters.find(tag); +FailureOr DebugCounter::execute(llvm::function_ref transform, + const DebugActionBase &action) { + auto counterIt = counters.find(action.getTag()); if (counterIt == counters.end()) return true; diff --git a/mlir/unittests/Support/DebugActionTest.cpp b/mlir/unittests/Support/DebugActionTest.cpp --- a/mlir/unittests/Support/DebugActionTest.cpp +++ b/mlir/unittests/Support/DebugActionTest.cpp @@ -10,9 +10,6 @@ #include "mlir/Support/TypeID.h" #include "gmock/gmock.h" -// DebugActionManager is only enabled in DEBUG mode. -#if LLVM_ENABLE_ABI_BREAKING_CHECKS - using namespace mlir; namespace { @@ -30,6 +27,8 @@ }; struct ParametricAction : DebugAction { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ParametricAction) + ParametricAction(bool executeParam) : executeParam(executeParam) {} + bool executeParam; static StringRef getTag() { return "param-action"; } static StringRef getDescription() { return "param-action-description"; } }; @@ -40,21 +39,25 @@ // A generic handler that always executes the simple action, but not the // parametric action. struct GenericHandler : DebugActionManager::GenericHandler { - FailureOr shouldExecute(StringRef tag, StringRef desc) final { - if (tag == SimpleAction::getTag()) { + FailureOr execute(llvm::function_ref transform, + const DebugActionBase &action) final { + StringRef desc = action.getDescription(); + if (isa(action)) { EXPECT_EQ(desc, SimpleAction::getDescription()); + transform(); return true; } - EXPECT_EQ(tag, ParametricAction::getTag()); + EXPECT_TRUE(isa(action)); EXPECT_EQ(desc, ParametricAction::getDescription()); return false; } }; manager.registerActionHandler(); - EXPECT_TRUE(manager.shouldExecute()); - EXPECT_FALSE(manager.shouldExecute(true)); + auto noOp = []() { return; }; + EXPECT_TRUE(manager.execute(noOp)); + EXPECT_FALSE(manager.execute(noOp, true)); } TEST(DebugActionTest, ActionSpecificHandler) { @@ -62,17 +65,25 @@ // Handler that simply uses the input as the decider. struct ActionSpecificHandler : ParametricAction::Handler { - FailureOr shouldExecute(bool shouldExecuteParam) final { - return shouldExecuteParam; + FailureOr execute(llvm::function_ref transform, + const ParametricAction &action) final { + if (action.executeParam) + transform(); + return action.executeParam; } }; manager.registerActionHandler(); - EXPECT_TRUE(manager.shouldExecute(true)); - EXPECT_FALSE(manager.shouldExecute(false)); + int count = 0; + auto incCount = [&]() { count++; }; + EXPECT_TRUE(manager.execute(incCount, true)); + EXPECT_EQ(count, 1); + EXPECT_FALSE(manager.execute(incCount, false)); + EXPECT_EQ(count, 1); // There is no handler for the simple action, so it is always executed. - EXPECT_TRUE(manager.shouldExecute()); + EXPECT_TRUE(manager.execute(incCount)); + EXPECT_EQ(count, 2); } TEST(DebugActionTest, DebugCounterHandler) { @@ -80,17 +91,24 @@ // Handler that uses the number of action executions as the decider. struct DebugCounterHandler : SimpleAction::Handler { - FailureOr shouldExecute() final { return numExecutions++ < 3; } + FailureOr execute(llvm::function_ref transform, + const SimpleAction &action) final { + bool shouldExecute = numExecutions++ < 3; + if (shouldExecute) + transform(); + return shouldExecute; + } unsigned numExecutions = 0; }; manager.registerActionHandler(); // Check that the action is executed 3 times, but no more after. - EXPECT_TRUE(manager.shouldExecute()); - EXPECT_TRUE(manager.shouldExecute()); - EXPECT_TRUE(manager.shouldExecute()); - EXPECT_FALSE(manager.shouldExecute()); - EXPECT_FALSE(manager.shouldExecute()); + auto noOp = []() { return; }; + EXPECT_TRUE(manager.execute(noOp)); + EXPECT_TRUE(manager.execute(noOp)); + EXPECT_TRUE(manager.execute(noOp)); + EXPECT_FALSE(manager.execute(noOp)); + EXPECT_FALSE(manager.execute(noOp)); } TEST(DebugActionTest, NonOverlappingActionSpecificHandlers) { @@ -98,17 +116,24 @@ // One handler returns true and another returns false struct SimpleActionHandler : SimpleAction::Handler { - FailureOr shouldExecute() final { return true; } + FailureOr execute(llvm::function_ref transform, + const SimpleAction &action) final { + transform(); + return true; + } }; struct OtherSimpleActionHandler : OtherSimpleAction::Handler { - FailureOr shouldExecute() final { return false; } + FailureOr execute(llvm::function_ref transform, + const OtherSimpleAction &action) final { + transform(); + return false; + } }; manager.registerActionHandler(); manager.registerActionHandler(); - EXPECT_TRUE(manager.shouldExecute()); - EXPECT_FALSE(manager.shouldExecute()); + auto noOp = []() { return; }; + EXPECT_TRUE(manager.execute(noOp)); + EXPECT_FALSE(manager.execute(noOp)); } } // namespace - -#endif diff --git a/mlir/unittests/Support/DebugCounterTest.cpp b/mlir/unittests/Support/DebugCounterTest.cpp --- a/mlir/unittests/Support/DebugCounterTest.cpp +++ b/mlir/unittests/Support/DebugCounterTest.cpp @@ -12,9 +12,6 @@ using namespace mlir; -// DebugActionManager is only enabled in DEBUG mode. -#if LLVM_ENABLE_ABI_BREAKING_CHECKS - namespace { struct CounterAction : public DebugAction { @@ -31,16 +28,16 @@ DebugActionManager manager; manager.registerActionHandler(std::move(counter)); + auto noOp = []() { return; }; + // The first execution is skipped. - EXPECT_FALSE(manager.shouldExecute()); + EXPECT_FALSE(manager.execute(noOp)); // The counter stops after 3 successful executions. - EXPECT_TRUE(manager.shouldExecute()); - EXPECT_TRUE(manager.shouldExecute()); - EXPECT_TRUE(manager.shouldExecute()); - EXPECT_FALSE(manager.shouldExecute()); + EXPECT_TRUE(manager.execute(noOp)); + EXPECT_TRUE(manager.execute(noOp)); + EXPECT_TRUE(manager.execute(noOp)); + EXPECT_FALSE(manager.execute(noOp)); } } // namespace - -#endif