diff --git a/mlir/include/mlir/Debug/BreakpointManager.h b/mlir/include/mlir/Debug/BreakpointManager.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Debug/BreakpointManager.h @@ -0,0 +1,95 @@ +//===- BreakpointManager.h - Breakpoint Manager Support ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TRACING_BREAKPOINTMANAGER_H +#define MLIR_TRACING_BREAKPOINTMANAGER_H + +#include "mlir/IR/Action.h" +#include "llvm/ADT/MapVector.h" + +namespace mlir { +namespace tracing { + +/// This abstract class represents a breakpoint. +class Breakpoint { +public: + virtual ~Breakpoint() = default; + + /// TypeID for the subclass, used for casting purpose. + TypeID getTypeID() const { return typeID; } + + bool isEnabled() const { return enableStatus; } + void enable() { enableStatus = true; } + void disable() { enableStatus = false; } + virtual void print(raw_ostream &os) const = 0; + +protected: + Breakpoint(TypeID typeID) : enableStatus(true), typeID(typeID) {} + +private: + /// The current state of the breakpoint. A breakpoint can be either enabled + /// or disabled. + bool enableStatus; + TypeID typeID; +}; + +inline raw_ostream &operator<<(raw_ostream &os, const Breakpoint &breakpoint) { + breakpoint.print(os); + return os; +} + +/// This class provides a CRTP wrapper around a base breakpoint class to define +/// a few necessary utility methods. +template +class BreakpointBase : public Breakpoint { +public: + /// Support isa/dyn_cast functionality for the derived pass class. 
+ static bool classof(const Breakpoint *breakpoint) { + return breakpoint->getTypeID() == TypeID::get(); + } + +protected: + BreakpointBase() : Breakpoint(TypeID::get()) {} +}; + +/// A breakpoint manager is responsible for managing a set of breakpoints and +/// matching them to a given action. +class BreakpointManager { +public: + virtual ~BreakpointManager() = default; + + /// TypeID for the subclass, used for casting purpose. + TypeID getTypeID() const { return typeID; } + + /// Try to match a Breakpoint to a given Action. If there is a match and + /// the breakpoint is enabled, return the breakpoint. Otherwise, return + /// nullptr. + virtual Breakpoint *match(const Action &action) const = 0; + +protected: + BreakpointManager(TypeID typeID) : typeID(typeID) {} + + TypeID typeID; +}; + +/// CRTP base class for BreakpointManager implementations. +template +class BreakpointManagerBase : public BreakpointManager { +public: + BreakpointManagerBase() : BreakpointManager(TypeID::get()) {} + + /// Provide classof to allow casting between breakpoint manager types. + static bool classof(const BreakpointManager *breakpointManager) { + return breakpointManager->getTypeID() == TypeID::get(); + } +}; + +} // namespace tracing +} // namespace mlir + +#endif // MLIR_TRACING_BREAKPOINTMANAGER_H diff --git a/mlir/include/mlir/Debug/BreakpointManagers/TagBreakpointManager.h b/mlir/include/mlir/Debug/BreakpointManagers/TagBreakpointManager.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Debug/BreakpointManagers/TagBreakpointManager.h @@ -0,0 +1,65 @@ +//===- TagBreakpointManager.h - Simple breakpoint Support -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DEBUG_BREAKPOINTMANAGERS_TAGBREAKPOINTMANAGER_H +#define MLIR_DEBUG_BREAKPOINTMANAGERS_TAGBREAKPOINTMANAGER_H + +#include "mlir/Debug/BreakpointManager.h" +#include "mlir/Debug/ExecutionContext.h" +#include "mlir/IR/Action.h" +#include "llvm/ADT/MapVector.h" + +namespace mlir { +namespace tracing { + +/// Simple breakpoint matching an action "tag". +class TagBreakpoint : public BreakpointBase { +public: + TagBreakpoint(StringRef tag) : tag(tag) {} + + void print(raw_ostream &os) const override { os << "Tag: `" << tag << '`'; } + +private: + /// A tag to associate the TagBreakpoint with. + std::string tag; + + /// Allow access to `tag`. + friend class TagBreakpointManager; +}; + +/// This is a manager to store a collection of breakpoints that trigger +/// on tags. +class TagBreakpointManager + : public BreakpointManagerBase { +public: + Breakpoint *match(const Action &action) const override { + auto it = breakpoints.find(action.getTag()); + if (it != breakpoints.end() && it->second->isEnabled()) + return it->second.get(); + return {}; + } + + /// Add a breakpoint to the manager for the given tag and return it. + /// If a breakpoint already exists for the given tag, return the existing + /// instance. 
+ TagBreakpoint *addBreakpoint(StringRef tag) { + auto result = breakpoints.insert({tag, nullptr}); + auto &it = result.first; + if (result.second) + it->second = std::make_unique(tag.str()); + return it->second.get(); + } + +private: + llvm::StringMap> breakpoints; +}; + +} // namespace tracing +} // namespace mlir + +#endif // MLIR_DEBUG_BREAKPOINTMANAGERS_TAGBREAKPOINTMANAGER_H diff --git a/mlir/include/mlir/Debug/ExecutionContext.h b/mlir/include/mlir/Debug/ExecutionContext.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Debug/ExecutionContext.h @@ -0,0 +1,132 @@ +//===- ExecutionContext.h - Execution Context Support *- C++ -*-=============// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TRACING_EXECUTIONCONTEXT_H +#define MLIR_TRACING_EXECUTIONCONTEXT_H + +#include "mlir/Debug/BreakpointManager.h" +#include "mlir/IR/Action.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { +namespace tracing { + +/// This class is used to keep track of the active actions in the stack. +/// It provides the current action but also access to the parent entry in the +/// stack. This allows to keep track of the nested nature in which actions may +/// be executed. 
+struct ActionActiveStack { +public: + ActionActiveStack(const ActionActiveStack *parent, const Action &action, + int depth) + : parent(parent), action(action), depth(depth) {} + const ActionActiveStack *getParent() const { return parent; } + const Action &getAction() const { return action; } + int getDepth() const { return depth; } + +private: + const ActionActiveStack *parent; + const Action &action; + int depth; +}; + +/// The ExecutionContext is the main orchestration of the infrastructure, it +/// acts as a handler in the MLIRContext for executing an Action. When an action +/// is dispatched, it'll query its set of Breakpoints managers for a breakpoint +/// matching this action. If a breakpoint is hit, it passes the action and the +/// breakpoint information to a callback. The callback is responsible for +/// controlling the execution of the action through an enum value it returns. +/// Optionally, observers can be registered to be notified before and after the +/// callback is executed. +class ExecutionContext { +public: + /// Enum that allows the client of the context to control the execution of the + /// action. + /// - Apply: The action is executed. + /// - Skip: The action is skipped. + /// - Step: The action is executed and the execution is paused before the next + /// action, including for nested actions encountered before the + /// current action finishes. + /// - Next: The action is executed and the execution is paused after the + /// current action finishes before the next action. + /// - Finish: The action is executed and the execution is paused only when we + /// reach the parent/enclosing operation. If there are no enclosing + /// operation, the execution continues without stopping. + enum Control { Apply = 1, Skip = 2, Step = 3, Next = 4, Finish = 5 }; + + /// The type of the callback that is used to control the execution. + /// The callback is passed the current action. 
+ using CallbackTy = function_ref; + + /// Create an ExecutionContext with a callback that is used to control the + /// execution. + ExecutionContext(CallbackTy callback) { setCallback(callback); } + ExecutionContext() = default; + + /// Set the callback that is used to control the execution. + void setCallback(CallbackTy callback); + + /// This abstract class defines the interface used to observe an Action + /// execution. It allows to be notified before and after the callback is + /// processed, but can't affect the execution. + struct Observer { + virtual ~Observer() = default; + /// This method is called before the Action is executed + /// If a breakpoint was hit, it is passed as an argument to the callback. + /// The `willExecute` argument indicates whether the action will be executed + /// or not. + /// Note that this method will be called from multiple threads concurrently + /// when MLIR multi-threading is enabled. + virtual void beforeExecute(const ActionActiveStack *action, + Breakpoint *breakpoint, bool willExecute) {} + + /// This method is called after the Action is executed, if it was executed. + /// It is not called if the action is skipped. + /// Note that this method will be called from multiple threads concurrently + /// when MLIR multi-threading is enabled. + virtual void afterExecute(const ActionActiveStack *action) {} + }; + + /// Register a new `Observer` on this context. It'll be notified before and + /// after executing an action. Note that this method is not thread-safe: it + /// isn't supported to add a new observer while actions may be executed. + void registerObserver(Observer *observer); + + /// Register a new `BreakpointManager` on this context. It'll have a chance to + /// match an action before it gets executed. Note that this method is not + /// thread-safe: it isn't supported to add a new manager while actions may be + /// executed. 
+ void addBreakpointManager(BreakpointManager *manager) { + breakpoints.push_back(manager); + } + + /// Process the given action. This is the operator called by MLIRContext on + /// `executeAction()`. + void operator()(function_ref transform, const Action &action); + +private: + /// Callback that is executed when a breakpoint is hit and allows the client + /// to control the execution. + CallbackTy onBreakpointControlExecutionCallback; + + /// Next point to stop execution as describe by `Control` enum. + /// This is handle by indicating at which levels of depth the next + /// break should happen. + Optional depthToBreak; + + /// Observers that are notified before and after the callback is executed. + SmallVector observers; + + /// The list of managers that are queried for breakpoints. + SmallVector breakpoints; +}; + +} // namespace tracing +} // namespace mlir + +#endif // MLIR_TRACING_EXECUTIONCONTEXT_H diff --git a/mlir/lib/Debug/CMakeLists.txt b/mlir/lib/Debug/CMakeLists.txt --- a/mlir/lib/Debug/CMakeLists.txt +++ b/mlir/lib/Debug/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_library(MLIRDebug DebugCounter.cpp + ExecutionContext.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Debug diff --git a/mlir/lib/Debug/ExecutionContext.cpp b/mlir/lib/Debug/ExecutionContext.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Debug/ExecutionContext.cpp @@ -0,0 +1,97 @@ +//===- ExecutionContext.cpp - Debug Execution Context Support -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Debug/ExecutionContext.h" + +#include "llvm/ADT/ScopeExit.h" + +#include + +using namespace mlir; +using namespace mlir::tracing; + +//===----------------------------------------------------------------------===// +// ExecutionContext +//===----------------------------------------------------------------------===// + +static const thread_local ActionActiveStack *actionStack = nullptr; + +void ExecutionContext::setCallback(CallbackTy callback) { + onBreakpointControlExecutionCallback = callback; +} + +void ExecutionContext::registerObserver(Observer *observer) { + observers.push_back(observer); +} + +void ExecutionContext::operator()(llvm::function_ref transform, + const Action &action) { + // Update the top of the stack with the current action. + int depth = 0; + if (actionStack) + depth = actionStack->getDepth() + 1; + ActionActiveStack info{actionStack, action, depth}; + actionStack = &info; + auto raii = llvm::make_scope_exit([&]() { actionStack = info.getParent(); }); + Breakpoint *breakpoint = nullptr; + + // Invoke the callback here and handles control requests here. + auto handleUserInput = [&]() -> bool { + if (!onBreakpointControlExecutionCallback) + return true; + auto todoNext = onBreakpointControlExecutionCallback(actionStack); + switch (todoNext) { + case ExecutionContext::Apply: + depthToBreak = std::nullopt; + return true; + case ExecutionContext::Skip: + depthToBreak = std::nullopt; + return false; + case ExecutionContext::Step: + depthToBreak = depth + 1; + return true; + case ExecutionContext::Next: + depthToBreak = depth; + return true; + case ExecutionContext::Finish: + depthToBreak = depth - 1; + return true; + } + llvm::report_fatal_error("Unknown control request"); + }; + + // Try to find a breakpoint that would hit on this action. 
+ // Right now there is no way to collect them all, we stop at the first one. + for (auto *breakpointManager : breakpoints) { + breakpoint = breakpointManager->match(action); + if (breakpoint) + break; + } + + bool shouldExecuteAction = true; + // If we have a breakpoint, or if `depthToBreak` was previously set and the + // current depth matches, we invoke the user-provided callback. + if (breakpoint || (depthToBreak && depth <= depthToBreak)) + shouldExecuteAction = handleUserInput(); + + // Notify the observers about the current action. + for (auto *observer : observers) + observer->beforeExecute(actionStack, breakpoint, shouldExecuteAction); + + if (shouldExecuteAction) { + // Execute the action here. + transform(); + + // Notify the observers about completion of the action. + for (auto *observer : observers) + observer->afterExecute(actionStack); + } + + if (depthToBreak && depth <= depthToBreak) + handleUserInput(); +} diff --git a/mlir/unittests/Debug/CMakeLists.txt b/mlir/unittests/Debug/CMakeLists.txt --- a/mlir/unittests/Debug/CMakeLists.txt +++ b/mlir/unittests/Debug/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_unittest(MLIRDebugTests DebugCounterTest.cpp + ExecutionContextTest.cpp ) target_link_libraries(MLIRDebugTests diff --git a/mlir/unittests/Debug/ExecutionContextTest.cpp b/mlir/unittests/Debug/ExecutionContextTest.cpp new file mode 100644 --- /dev/null +++ b/mlir/unittests/Debug/ExecutionContextTest.cpp @@ -0,0 +1,352 @@ +//===- ExecutionContextTest.cpp - Debug Execution Context first impl ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Debug/ExecutionContext.h" +#include "mlir/Debug/BreakpointManagers/TagBreakpointManager.h" +#include "llvm/ADT/MapVector.h" +#include "gmock/gmock.h" + +using namespace mlir; +using namespace mlir::tracing; + +namespace { +struct DebuggerAction : public ActionImpl { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DebuggerAction) + static constexpr StringLiteral tag = "debugger-action"; +}; +struct OtherAction : public ActionImpl { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OtherAction) + static constexpr StringLiteral tag = "other-action"; +}; +struct ThirdAction : public ActionImpl { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ThirdAction) + static constexpr StringLiteral tag = "third-action"; +}; + +// Simple transform callback that does nothing. +void noOp() { return; } + +/// This test executes a stack of nested actions and checks that the backtrace +/// is as expected. +TEST(ExecutionContext, ActionActiveStackTest) { + + // We'll break three times, once on each action; the backtraces should match + // each of the entries here.
+ std::vector> expectedStacks = { + {DebuggerAction::tag}, + {OtherAction::tag, DebuggerAction::tag}, + {ThirdAction::tag, OtherAction::tag, DebuggerAction::tag}}; + + auto checkStacks = [&](const ActionActiveStack *backtrace, + const std::vector ¤tStack) { + ASSERT_EQ((int)currentStack.size(), backtrace->getDepth() + 1); + for (StringRef stackEntry : currentStack) { + ASSERT_NE(backtrace, nullptr); + ASSERT_EQ(stackEntry, backtrace->getAction().getTag()); + backtrace = backtrace->getParent(); + } + }; + + std::vector controlSequence = { + ExecutionContext::Step, ExecutionContext::Step, ExecutionContext::Apply}; + int idx = 0; + StringRef current; + int currentDepth = -1; + auto onBreakpoint = [&](const ActionActiveStack *backtrace) { + current = backtrace->getAction().getTag(); + currentDepth = backtrace->getDepth(); + checkStacks(backtrace, expectedStacks[idx]); + return controlSequence[idx++]; + }; + + TagBreakpointManager simpleManager; + ExecutionContext executionCtx(onBreakpoint); + executionCtx.addBreakpointManager(&simpleManager); + std::vector breakpoints; + breakpoints.push_back(simpleManager.addBreakpoint(DebuggerAction::tag)); + breakpoints.push_back(simpleManager.addBreakpoint(OtherAction::tag)); + breakpoints.push_back(simpleManager.addBreakpoint(ThirdAction::tag)); + + auto third = [&]() { + EXPECT_EQ(current, ThirdAction::tag); + EXPECT_EQ(currentDepth, 2); + }; + auto nested = [&]() { + EXPECT_EQ(current, OtherAction::tag); + EXPECT_EQ(currentDepth, 1); + executionCtx(third, ThirdAction{}); + }; + auto original = [&]() { + EXPECT_EQ(current, DebuggerAction::tag); + EXPECT_EQ(currentDepth, 0); + executionCtx(nested, OtherAction{}); + return; + }; + + executionCtx(original, DebuggerAction{}); +} + +TEST(ExecutionContext, DebuggerTest) { + // Check matching and non-matching breakpoints, with various enable/disable + // schemes.
+ int match = 0; + auto onBreakpoint = [&match](const ActionActiveStack *backtrace) { + match++; + return ExecutionContext::Skip; + }; + TagBreakpointManager simpleManager; + ExecutionContext executionCtx(onBreakpoint); + executionCtx.addBreakpointManager(&simpleManager); + + executionCtx(noOp, DebuggerAction{}); + EXPECT_EQ(match, 0); + + Breakpoint *dbgBreakpoint = simpleManager.addBreakpoint(DebuggerAction::tag); + executionCtx(noOp, DebuggerAction{}); + EXPECT_EQ(match, 1); + + dbgBreakpoint->disable(); + executionCtx(noOp, DebuggerAction{}); + EXPECT_EQ(match, 1); + + dbgBreakpoint->enable(); + executionCtx(noOp, DebuggerAction{}); + EXPECT_EQ(match, 2); + + executionCtx(noOp, OtherAction{}); + EXPECT_EQ(match, 2); +} + +TEST(ExecutionContext, ApplyTest) { + // Test the "apply" control. + std::vector tagSequence = {DebuggerAction::tag}; + std::vector controlSequence = { + ExecutionContext::Apply}; + int idx = 0, counter = 0; + auto onBreakpoint = [&](const ActionActiveStack *backtrace) { + ++counter; + EXPECT_EQ(tagSequence[idx], backtrace->getAction().getTag()); + return controlSequence[idx++]; + }; + auto callback = [&]() { EXPECT_EQ(counter, 1); }; + TagBreakpointManager simpleManager; + ExecutionContext executionCtx(onBreakpoint); + executionCtx.addBreakpointManager(&simpleManager); + simpleManager.addBreakpoint(DebuggerAction::tag); + + executionCtx(callback, DebuggerAction{}); + EXPECT_EQ(counter, 1); +} + +TEST(ExecutionContext, SkipTest) { + // Test the "skip" control.
+ std::vector tagSequence = {DebuggerAction::tag, + DebuggerAction::tag}; + std::vector controlSequence = { + ExecutionContext::Apply, ExecutionContext::Skip}; + int idx = 0, counter = 0, executionCounter = 0; + auto onBreakpoint = [&](const ActionActiveStack *backtrace) { + ++counter; + EXPECT_EQ(tagSequence[idx], backtrace->getAction().getTag()); + return controlSequence[idx++]; + }; + auto callback = [&]() { ++executionCounter; }; + TagBreakpointManager simpleManager; + ExecutionContext executionCtx(onBreakpoint); + executionCtx.addBreakpointManager(&simpleManager); + simpleManager.addBreakpoint(DebuggerAction::tag); + + executionCtx(callback, DebuggerAction{}); + executionCtx(callback, DebuggerAction{}); + EXPECT_EQ(counter, 2); + EXPECT_EQ(executionCounter, 1); +} + +TEST(ExecutionContext, StepApplyTest) { + // Test the "step" control with a nested action. + std::vector tagSequence = {DebuggerAction::tag, OtherAction::tag}; + std::vector controlSequence = { + ExecutionContext::Step, ExecutionContext::Apply}; + int idx = 0, counter = 0; + auto onBreakpoint = [&](const ActionActiveStack *backtrace) { + ++counter; + EXPECT_EQ(tagSequence[idx], backtrace->getAction().getTag()); + return controlSequence[idx++]; + }; + TagBreakpointManager simpleManager; + ExecutionContext executionCtx(onBreakpoint); + executionCtx.addBreakpointManager(&simpleManager); + simpleManager.addBreakpoint(DebuggerAction::tag); + auto nested = [&]() { EXPECT_EQ(counter, 2); }; + auto original = [&]() { + EXPECT_EQ(counter, 1); + executionCtx(nested, OtherAction{}); + }; + + executionCtx(original, DebuggerAction{}); + EXPECT_EQ(counter, 2); +} + +TEST(ExecutionContext, StepNothingInsideTest) { + // Test the "step" control without a nested action.
+ std::vector tagSequence = {DebuggerAction::tag, + DebuggerAction::tag}; + std::vector controlSequence = { + ExecutionContext::Step, ExecutionContext::Step}; + int idx = 0, counter = 0; + auto onBreakpoint = [&](const ActionActiveStack *backtrace) { + ++counter; + EXPECT_EQ(tagSequence[idx], backtrace->getAction().getTag()); + return controlSequence[idx++]; + }; + auto callback = [&]() { EXPECT_EQ(counter, 1); }; + TagBreakpointManager simpleManager; + ExecutionContext executionCtx(onBreakpoint); + executionCtx.addBreakpointManager(&simpleManager); + simpleManager.addBreakpoint(DebuggerAction::tag); + + executionCtx(callback, DebuggerAction{}); + EXPECT_EQ(counter, 2); +} + +TEST(ExecutionContext, NextTest) { + // Test the "next" control. + std::vector tagSequence = {DebuggerAction::tag, + DebuggerAction::tag}; + std::vector controlSequence = { + ExecutionContext::Next, ExecutionContext::Next}; + int idx = 0, counter = 0; + auto onBreakpoint = [&](const ActionActiveStack *backtrace) { + ++counter; + EXPECT_EQ(tagSequence[idx], backtrace->getAction().getTag()); + return controlSequence[idx++]; + }; + auto callback = [&]() { EXPECT_EQ(counter, 1); }; + TagBreakpointManager simpleManager; + ExecutionContext executionCtx(onBreakpoint); + executionCtx.addBreakpointManager(&simpleManager); + simpleManager.addBreakpoint(DebuggerAction::tag); + + executionCtx(callback, DebuggerAction{}); + EXPECT_EQ(counter, 2); +} + +TEST(ExecutionContext, FinishTest) { + // Test the "finish" control.
+ std::vector tagSequence = {DebuggerAction::tag, OtherAction::tag, + DebuggerAction::tag}; + std::vector controlSequence = { + ExecutionContext::Step, ExecutionContext::Finish, + ExecutionContext::Apply}; + int idx = 0, counter = 0; + auto onBreakpoint = [&](const ActionActiveStack *backtrace) { + ++counter; + EXPECT_EQ(tagSequence[idx], backtrace->getAction().getTag()); + return controlSequence[idx++]; + }; + TagBreakpointManager simpleManager; + ExecutionContext executionCtx(onBreakpoint); + executionCtx.addBreakpointManager(&simpleManager); + simpleManager.addBreakpoint(DebuggerAction::tag); + auto nested = [&]() { EXPECT_EQ(counter, 2); }; + auto original = [&]() { + EXPECT_EQ(counter, 1); + executionCtx(nested, OtherAction{}); + EXPECT_EQ(counter, 2); + }; + + executionCtx(original, DebuggerAction{}); + EXPECT_EQ(counter, 3); +} + +TEST(ExecutionContext, FinishBreakpointInNestedTest) { + // Test the "finish" control with a breakpoint in the nested action. + std::vector tagSequence = {OtherAction::tag, DebuggerAction::tag}; + std::vector controlSequence = { + ExecutionContext::Finish, ExecutionContext::Apply}; + int idx = 0, counter = 0; + auto onBreakpoint = [&](const ActionActiveStack *backtrace) { + ++counter; + EXPECT_EQ(tagSequence[idx], backtrace->getAction().getTag()); + return controlSequence[idx++]; + }; + TagBreakpointManager simpleManager; + ExecutionContext executionCtx(onBreakpoint); + executionCtx.addBreakpointManager(&simpleManager); + simpleManager.addBreakpoint(OtherAction::tag); + + auto nested = [&]() { EXPECT_EQ(counter, 1); }; + auto original = [&]() { + EXPECT_EQ(counter, 0); + executionCtx(nested, OtherAction{}); + EXPECT_EQ(counter, 1); + }; + + executionCtx(original, DebuggerAction{}); + EXPECT_EQ(counter, 2); +} + +TEST(ExecutionContext, FinishNothingBackTest) { + // Test the "finish" control without a nested action.
+ std::vector tagSequence = {DebuggerAction::tag}; + std::vector controlSequence = { + ExecutionContext::Finish}; + int idx = 0, counter = 0; + auto onBreakpoint = [&](const ActionActiveStack *backtrace) { + ++counter; + EXPECT_EQ(tagSequence[idx], backtrace->getAction().getTag()); + return controlSequence[idx++]; + }; + auto callback = [&]() { EXPECT_EQ(counter, 1); }; + TagBreakpointManager simpleManager; + ExecutionContext executionCtx(onBreakpoint); + executionCtx.addBreakpointManager(&simpleManager); + simpleManager.addBreakpoint(DebuggerAction::tag); + + executionCtx(callback, DebuggerAction{}); + EXPECT_EQ(counter, 1); +} + +TEST(ExecutionContext, EnableDisableBreakpointOnCallback) { + // Test enabling and disabling breakpoints while executing the action. + std::vector tagSequence = {DebuggerAction::tag, ThirdAction::tag, + OtherAction::tag, DebuggerAction::tag}; + std::vector controlSequence = { + ExecutionContext::Apply, ExecutionContext::Finish, + ExecutionContext::Finish, ExecutionContext::Apply}; + int idx = 0, counter = 0; + auto onBreakpoint = [&](const ActionActiveStack *backtrace) { + ++counter; + EXPECT_EQ(tagSequence[idx], backtrace->getAction().getTag()); + return controlSequence[idx++]; + }; + + TagBreakpointManager simpleManager; + ExecutionContext executionCtx(onBreakpoint); + executionCtx.addBreakpointManager(&simpleManager); + simpleManager.addBreakpoint(DebuggerAction::tag); + Breakpoint *toBeDisabled = simpleManager.addBreakpoint(OtherAction::tag); + + auto third = [&]() { EXPECT_EQ(counter, 2); }; + auto nested = [&]() { + EXPECT_EQ(counter, 1); + executionCtx(third, ThirdAction{}); + EXPECT_EQ(counter, 2); + }; + auto original = [&]() { + EXPECT_EQ(counter, 1); + toBeDisabled->disable(); + simpleManager.addBreakpoint(ThirdAction::tag); + executionCtx(nested, OtherAction{}); + EXPECT_EQ(counter, 3); + }; + + executionCtx(original, DebuggerAction{}); + EXPECT_EQ(counter, 4); +} +} // namespace