diff --git a/mlir/include/mlir/Debug/Counter.h b/mlir/include/mlir/Debug/Counter.h --- a/mlir/include/mlir/Debug/Counter.h +++ b/mlir/include/mlir/Debug/Counter.h @@ -26,10 +26,10 @@ /// a counter for an action with `skip=47` and `count=2`, would skip the /// first 47 executions, then execute twice, and finally prevent any further /// executions. -class DebugCounter : public ActionManager::GenericHandler { +class DebugCounter { public: DebugCounter(); - ~DebugCounter() override; + ~DebugCounter(); /// Add a counter for the given action tag. `countToSkip` is the number /// of counter executions to skip before enabling execution of the action. @@ -38,9 +38,8 @@ void addCounter(StringRef actionTag, int64_t countToSkip, int64_t countToStopAfter); - /// Register a counter with the specified name. - FailureOr execute(llvm::function_ref transform, - const Action &action) final; + /// Action handler entry point: invokes `transform` only if `shouldExecute` allows the action's tag. + void operator()(llvm::function_ref transform, const Action &action); /// Print the counters that have been registered with this instance to the /// provided output stream. @@ -50,6 +49,9 @@ static void registerCLOptions(); private: + /// Returns true if the next action matching this tag should be executed. + bool shouldExecute(StringRef tag); + /// Apply the registered CL options to this debug counter instance. void applyCLOptions(); diff --git a/mlir/include/mlir/IR/Action.h b/mlir/include/mlir/IR/Action.h --- a/mlir/include/mlir/IR/Action.h +++ b/mlir/include/mlir/IR/Action.h @@ -35,9 +35,7 @@ /// An action is a specific action that is to be taken by the compiler, /// that can be toggled and controlled by an external user. There are no /// constraints on the granularity of an action, it could be as simple as -/// "perform this fold" and as complex as "run this pass pipeline". Via template -/// parameters `ParameterTs`, a user may provide the set of argument types that -/// are provided when handling a query on this action. 
+/// "perform this fold" and as complex as "run this pass pipeline". /// /// This class represents the base class of the ActionImpl class (see below). /// This holds the template-invariant elements of the Action class. @@ -72,143 +70,6 @@ TypeID actionID; }; -//===----------------------------------------------------------------------===// -// ActionManager -//===----------------------------------------------------------------------===// - -/// This class represents manages actions, and orchestrates the -/// communication between action queries and action handlers. An action handler -/// is either an action specific handler, i.e. a derived class of -/// `MyActionType::Handler`, or a generic handler, i.e. a derived class of -/// `ActionManager::GenericHandler`. For more details on action specific -/// handlers, see the definition of `Action::Handler` below. For more -/// details on generic handlers, see `ActionManager::GenericHandler` below. -class ActionManager { -public: - //===--------------------------------------------------------------------===// - // Handlers - //===--------------------------------------------------------------------===// - - /// This class represents the base class of an action handler. - class HandlerBase { - public: - virtual ~HandlerBase() = default; - - /// Return the unique handler id of this handler, use for casting - /// functionality. - TypeID getHandlerID() const { return handlerID; } - - protected: - HandlerBase(TypeID handlerID) : handlerID(handlerID) {} - - /// The type of the derived handler class. This allows for detecting if a - /// handler can handle a given action type. - TypeID handlerID; - }; - - /// This class represents a generic action handler. A generic handler allows - /// for handling any action type. Handlers of this type are useful for - /// implementing general functionality that doesn't necessarily need to - /// interpret the exact action parameters, or can rely on an external - /// interpreter (such as the user). 
Given that these handlers are generic, - /// they take a set of opaque parameters that try to map the context of the - /// action type in a generic way. - class GenericHandler : public HandlerBase { - public: - GenericHandler() : HandlerBase(TypeID::get()) {} - - /// This hook allows for controlling the execution of an action. It should - /// return failure if the handler could not process the action, or whether - /// the `transform` was executed or not. - virtual FailureOr execute(llvm::function_ref transform, - const Action &action) { - return failure(); - } - - /// Provide classof to allow casting between handler types. - static bool classof(const ActionManager::HandlerBase *handler) { - return handler->getHandlerID() == TypeID::get(); - } - }; - - /// Register the given action handler with the manager. - void registerActionHandler(std::unique_ptr handler) { - actionHandlers.emplace_back(std::move(handler)); - } - template - void registerActionHandler() { - registerActionHandler(std::make_unique()); - } - - //===--------------------------------------------------------------------===// - // Action Queries - //===--------------------------------------------------------------------===// - - /// Dispatch an action represented by the `transform` callback. If no handler - /// is found, the `transform` callback is invoked directly. - /// Return true if the action was executed, false otherwise. - template - bool execute(llvm::function_ref transform, Args &&...args) { - if (actionHandlers.empty()) { - transform(); - return true; - } - - // Invoke the `execute` method on the provided handler. - auto executeFn = [&](auto *handler, auto &&...handlerParams) { - return handler->execute( - transform, - ActionType(std::forward(handlerParams)...)); - }; - FailureOr result = dispatchToHandler( - executeFn, std::forward(args)...); - if (failed(result)) { - transform(); - return true; - } - - // Return the result of the handler. 
- return *result; - } - -private: - //===--------------------------------------------------------------------===// - // Query to Handler Dispatch - //===--------------------------------------------------------------------===// - - /// Dispath a given callback on any handlers that are able to process queries - /// on the given action type. This method returns failure if no handlers could - /// process the action, or success(with a result) if a handler processed the - /// action. - template - FailureOr dispatchToHandler(HandlerCallbackT &&handlerCallback, - Args &&...args) { - static_assert(ActionType::template canHandleWith(), - "cannot execute action with the given set of parameters"); - - // Process any generic or action specific handlers. - // The first handler that gives us a result is the one that we will return. - for (std::unique_ptr &it : llvm::reverse(actionHandlers)) { - FailureOr result = failure(); - if (auto *handler = dyn_cast(&*it)) { - result = handlerCallback(handler, std::forward(args)...); - } else if (auto *genericHandler = dyn_cast(&*it)) { - result = handlerCallback(genericHandler, std::forward(args)...); - } - - // If the handler succeeded, return the result. Otherwise, try a new - // handler. - if (succeeded(result)) - return result; - } - return failure(); - } - - /// The set of action handlers that have been registered with the manager. - SmallVector> actionHandlers; -}; - /// CRTP Implementation of an action. This class provides a base class for /// implementing specific actions. /// @@ -216,7 +77,7 @@ /// * static constexpr StringLiteral tag = "..."; /// - This method returns a unique string identifier, similar to a command /// line flag or DEBUG_TYPE. -template +template class ActionImpl : public Action { public: @@ -229,40 +90,9 @@ /// Forward tag access to the derived class. 
StringRef getTag() const final { return Derived::tag; } - - class Handler : public ActionManager::HandlerBase { - public: - Handler() : HandlerBase(TypeID::get()) {} - - /// This hook allows for controlling the execution of an action. - /// `parameters` correspond to the set of values provided by the - /// action as context. It should return failure if the handler could not - /// process the action, passing it to the next registered handler. - virtual FailureOr execute(llvm::function_ref transform, - const Derived &action) { - return failure(); - } - - /// Provide classof to allow casting between handler types. - static bool classof(const ActionManager::HandlerBase *handler) { - return handler->getHandlerID() == TypeID::get(); - } - }; - -private: - /// Returns true if the action can be handled within the given set of - /// parameter types. - template - static constexpr bool canHandleWith() { - return std::is_invocable_v, - CallerParameterTs...>; - } - - /// Allow access to `canHandleWith`. - friend class ActionManager; }; } // namespace tracing } // namespace mlir #endif // MLIR_IR_ACTION_H \ No newline at end of file diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -21,7 +21,7 @@ namespace mlir { namespace tracing { -class ActionManager; +class Action; } class DiagnosticEngine; class Dialect; @@ -213,8 +213,33 @@ /// instances. This should not be used directly. StorageUniquer &getAttributeUniquer(); - /// Returns the manager of debug actions within the context. - tracing::ActionManager &getActionManager(); + /// Signature of the action handlers that can be registered with the context. + using HandlerTy = + std::function, const tracing::Action &)>; + + /// Register a handler for handling actions that are dispatched through this + /// context. A nullptr handler can be set to disable a previously set handler. 
+ void registerActionHandler(HandlerTy handler); + + /// Return true if a valid ActionHandler is set. + bool hasActionHandler(); + + /// Dispatch the provided action to the handler if any, or just execute it. + void dispatch(function_ref actionFn, const tracing::Action &action) { + if (LLVM_UNLIKELY(hasActionHandler())) + dispatchInternal(actionFn, action); + else + actionFn(); + } + + /// Dispatch the provided action to the handler if any, or just execute it. + template + void dispatch(function_ref actionFn, Args &&...args) { + if (LLVM_UNLIKELY(hasActionHandler())) + dispatchInternal(actionFn, ActionTy(std::forward(args)...)); + else + actionFn(); + } /// These APIs are tracking whether the context will be used in a /// multithreading environment: this has no effect other than enabling @@ -242,6 +267,10 @@ /// Return true if the given dialect is currently loading. bool isDialectLoading(StringRef dialectNamespace); + /// Internal helper for the dispatch method. + void dispatchInternal(function_ref actionFn, + const tracing::Action &action); + const std::unique_ptr impl; MLIRContext(const MLIRContext &) = delete; diff --git a/mlir/lib/Debug/DebugCounter.cpp b/mlir/lib/Debug/DebugCounter.cpp --- a/mlir/lib/Debug/DebugCounter.cpp +++ b/mlir/lib/Debug/DebugCounter.cpp @@ -62,10 +62,14 @@ counters.try_emplace(actionTag, countToSkip, countToStopAfter); } -// Register a counter with the specified name. 
-FailureOr DebugCounter::execute(llvm::function_ref transform, - const Action &action) { - auto counterIt = counters.find(action.getTag()); +void DebugCounter::operator()(llvm::function_ref transform, + const Action &action) { + if (shouldExecute(action.getTag())) + transform(); +} + +bool DebugCounter::shouldExecute(StringRef tag) { + auto counterIt = counters.find(tag); if (counterIt == counters.end()) return true; diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -31,6 +31,7 @@ #include "llvm/ADT/Twine.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Mutex.h" #include "llvm/Support/RWMutex.h" @@ -123,8 +124,10 @@ // Debugging //===--------------------------------------------------------------------===// - /// An action manager for use within the context. - tracing::ActionManager actionManager; + /// An action handler for handling actions that are dispatched through this + /// context. + std::function, const tracing::Action &)> + actionHandler; //===--------------------------------------------------------------------===// // Diagnostics @@ -345,11 +348,22 @@ } //===----------------------------------------------------------------------===// -// Debugging +// Action Handling //===----------------------------------------------------------------------===// -tracing::ActionManager &MLIRContext::getActionManager() { - return getImpl().actionManager; +void MLIRContext::registerActionHandler(HandlerTy handler) { + getImpl().actionHandler = std::move(handler); +} + +/// Forward the provided action to the registered handler, which must be set. 
+void MLIRContext::dispatchInternal(function_ref actionFn, + const tracing::Action &action) { + assert(getImpl().actionHandler); + getImpl().actionHandler(actionFn, action); +} + +bool MLIRContext::hasActionHandler() { + return static_cast<bool>(getImpl().actionHandler); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -207,7 +207,7 @@ context.allowUnregisteredDialects(config.shouldAllowUnregisteredDialects()); if (config.shouldVerifyDiagnostics()) context.printOpOnDiagnostic(false); - context.getActionManager().registerActionHandler(); + context.registerActionHandler(tracing::DebugCounter{}); // If we are in verify diagnostics mode then we have a lot of work to do, // otherwise just perform the actions without worrying about it. diff --git a/mlir/unittests/Debug/DebugCounterTest.cpp b/mlir/unittests/Debug/DebugCounterTest.cpp --- a/mlir/unittests/Debug/DebugCounterTest.cpp +++ b/mlir/unittests/Debug/DebugCounterTest.cpp @@ -21,23 +21,29 @@ }; TEST(DebugCounterTest, CounterTest) { - std::unique_ptr counter = std::make_unique(); - counter->addCounter(CounterAction{}.getTag(), /*countToSkip=*/1, - /*countToStopAfter=*/3); + DebugCounter counter; + counter.addCounter(CounterAction{}.getTag(), /*countToSkip=*/1, + /*countToStopAfter=*/3); - ActionManager manager; - manager.registerActionHandler(std::move(counter)); - - auto noOp = []() { return; }; + int count = 0; + auto noOp = [&]() { + ++count; + return; + }; // The first execution is skipped. - EXPECT_FALSE(manager.execute(noOp)); + counter(noOp, CounterAction{}); + EXPECT_EQ(count, 0); // The counter stops after 3 successful executions. 
- EXPECT_TRUE(manager.execute(noOp)); - EXPECT_TRUE(manager.execute(noOp)); - EXPECT_TRUE(manager.execute(noOp)); - EXPECT_FALSE(manager.execute(noOp)); + counter(noOp, CounterAction{}); + EXPECT_EQ(count, 1); + counter(noOp, CounterAction{}); + EXPECT_EQ(count, 2); + counter(noOp, CounterAction{}); + EXPECT_EQ(count, 3); + counter(noOp, CounterAction{}); + EXPECT_EQ(count, 3); } } // namespace diff --git a/mlir/unittests/IR/ActionTest.cpp b/mlir/unittests/IR/ActionTest.cpp deleted file mode 100644 --- a/mlir/unittests/IR/ActionTest.cpp +++ /dev/null @@ -1,134 +0,0 @@ -//===- ActionTest.cpp - Debug Action Tests ---------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/IR/Action.h" -#include "mlir/Support/TypeID.h" -#include "gmock/gmock.h" - -using namespace mlir; -using namespace mlir::tracing; - -namespace { -struct SimpleAction : ActionImpl { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SimpleAction) - static constexpr StringLiteral tag = "simple-action"; -}; -struct OtherSimpleAction : ActionImpl { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OtherSimpleAction) - static constexpr StringLiteral tag = "other-simple-action"; -}; -struct ParametricAction : ActionImpl { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ParametricAction) - ParametricAction(bool executeParam) : executeParam(executeParam) {} - bool executeParam; - static constexpr StringLiteral tag = "param-action"; -}; - -TEST(ActionTest, GenericHandler) { - ActionManager manager; - - // A generic handler that always executes the simple action, but not the - // parametric action. 
- struct GenericHandler : ActionManager::GenericHandler { - FailureOr execute(llvm::function_ref transform, - const Action &action) final { - StringRef tag = action.getTag(); - if (isa(action)) { - EXPECT_EQ(tag, SimpleAction{}.getTag()); - transform(); - return true; - } - - EXPECT_TRUE(isa(action)); - return false; - } - }; - manager.registerActionHandler(); - - auto noOp = []() { return; }; - EXPECT_TRUE(manager.execute(noOp)); - EXPECT_FALSE(manager.execute(noOp, true)); -} - -TEST(ActionTest, ActionSpecificHandler) { - ActionManager manager; - - // Handler that simply uses the input as the decider. - struct ActionSpecificHandler : ParametricAction::Handler { - FailureOr execute(llvm::function_ref transform, - const ParametricAction &action) final { - if (action.executeParam) - transform(); - return action.executeParam; - } - }; - manager.registerActionHandler(); - - int count = 0; - auto incCount = [&]() { count++; }; - EXPECT_TRUE(manager.execute(incCount, true)); - EXPECT_EQ(count, 1); - EXPECT_FALSE(manager.execute(incCount, false)); - EXPECT_EQ(count, 1); - - // There is no handler for the simple action, so it is always executed. - EXPECT_TRUE(manager.execute(incCount)); - EXPECT_EQ(count, 2); -} - -TEST(ActionTest, DebugCounterHandler) { - ActionManager manager; - - // Handler that uses the number of action executions as the decider. - struct DebugCounterHandler : SimpleAction::Handler { - FailureOr execute(llvm::function_ref transform, - const SimpleAction &action) final { - bool shouldExecute = numExecutions++ < 3; - if (shouldExecute) - transform(); - return shouldExecute; - } - unsigned numExecutions = 0; - }; - manager.registerActionHandler(); - - // Check that the action is executed 3 times, but no more after. 
- auto noOp = []() { return; }; - EXPECT_TRUE(manager.execute(noOp)); - EXPECT_TRUE(manager.execute(noOp)); - EXPECT_TRUE(manager.execute(noOp)); - EXPECT_FALSE(manager.execute(noOp)); - EXPECT_FALSE(manager.execute(noOp)); -} - -TEST(ActionTest, NonOverlappingActionSpecificHandlers) { - ActionManager manager; - - // One handler returns true and another returns false - struct SimpleActionHandler : SimpleAction::Handler { - FailureOr execute(llvm::function_ref transform, - const SimpleAction &action) final { - transform(); - return true; - } - }; - struct OtherSimpleActionHandler : OtherSimpleAction::Handler { - FailureOr execute(llvm::function_ref transform, - const OtherSimpleAction &action) final { - transform(); - return false; - } - }; - manager.registerActionHandler(); - manager.registerActionHandler(); - auto noOp = []() { return; }; - EXPECT_TRUE(manager.execute(noOp)); - EXPECT_FALSE(manager.execute(noOp)); -} - -} // namespace diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt --- a/mlir/unittests/IR/CMakeLists.txt +++ b/mlir/unittests/IR/CMakeLists.txt @@ -1,5 +1,4 @@ add_mlir_unittest(MLIRIRTests - ActionTest.cpp AdaptorTest.cpp AttributeTest.cpp DialectTest.cpp