diff --git a/mlir/include/mlir/Debug/DebuggerExecutionContextHook.h b/mlir/include/mlir/Debug/DebuggerExecutionContextHook.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Debug/DebuggerExecutionContextHook.h
@@ -0,0 +1,96 @@
+//===- DebuggerExecutionContextHook.h - Debugger Support --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a set of C API functions that are used by the debugger to
+// interact with the ExecutionContext.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_DEBUGGEREXECUTIONCONTEXTHOOK_H
+#define MLIR_SUPPORT_DEBUGGEREXECUTIONCONTEXTHOOK_H
+
+#include "mlir-c/IR.h"
+#include "mlir/Debug/ExecutionContext.h"
+#include "llvm/Support/Compiler.h"
+
+extern "C" {
+struct MLIRBreakpoint;
+struct MLIRIRunit;
+typedef struct MLIRBreakpoint *BreakpointHandle;
+typedef struct MLIRIRunit *irunitHandle;
+
+/// This is used by the debugger to control what to do after a breakpoint is
+/// hit. See tracing::ExecutionContext::Control for more information.
+void mlirDebuggerSetControl(int controlOption);
+
+/// Print the available context for the current Action.
+void mlirDebuggerPrintContext();
+
+/// Print the current action backtrace.
+void mlirDebuggerPrintActionBacktrace(bool withContext);
+
+//===----------------------------------------------------------------------===//
+// Cursor Management: The cursor is used to select an IRUnit from the context
+// and to navigate through the IRUnit hierarchy.
+//===----------------------------------------------------------------------===//
+
+/// Print the current IR unit cursor.
+void mlirDebuggerCursorPrint(bool withRegion);
+
+/// Select the IR unit from the current context by ID.
+void mlirDebuggerCursorSelectIRUnitFromContext(int index);
+
+/// Select the parent IR unit of the current cursor, or print an error if the
+/// cursor has no parent.
+void mlirDebuggerCursorSelectParentIRUnit();
+
+/// Select the child IR unit at the provided index, print an error if the index
+/// is out of bound. For example if the irunit is an operation, the children IR
+/// units will be the operation's regions.
+void mlirDebuggerCursorSelectChildIRUnit(int index);
+
+/// Select the previous IR unit logically in the IR. For example if the irunit
+/// is a Region, the previous IR unit will be the previous region in the parent
+/// operation; an error is printed if there is no previous region.
+void mlirDebuggerCursorSelectPreviousIRUnit();
+
+/// Select the next IR unit logically in the IR. For example if the irunit is a
+/// Region, the next IR unit will be the next region in the parent operation;
+/// an error is printed if there is no next region.
+void mlirDebuggerCursorSelectNextIRUnit();
+
+//===----------------------------------------------------------------------===//
+// Breakpoint Management
+//===----------------------------------------------------------------------===//
+
+/// Enable the provided breakpoint.
+void mlirDebuggerEnableBreakpoint(BreakpointHandle breakpoint);
+
+/// Disable the provided breakpoint.
+void mlirDebuggerDisableBreakpoint(BreakpointHandle breakpoint);
+
+/// Add a breakpoint matching exactly the provided tag.
+BreakpointHandle mlirDebuggerAddTagBreakpoint(const char *tag);
+
+/// Add a breakpoint matching a pattern by name. Not implemented yet (no-op).
+void mlirDebuggerAddRewritePatternBreakpoint(const char *patternNameInfo);
+
+/// Add a breakpoint matching a file, line and column.
+void mlirDebuggerAddFileLineColLocBreakpoint(const char *file, int line,
+                                             int col);
+
+} // extern "C"
+
+namespace mlir {
+// Setup the debugger hooks as a callback on the provided ExecutionContext.
+void setupDebuggerExecutionContextHook(
+    tracing::ExecutionContext &executionContext);
+
+} // namespace mlir
+
+#endif // MLIR_SUPPORT_DEBUGGEREXECUTIONCONTEXTHOOK_H
diff --git a/mlir/include/mlir/Debug/ExecutionContext.h b/mlir/include/mlir/Debug/ExecutionContext.h
--- a/mlir/include/mlir/Debug/ExecutionContext.h
+++ b/mlir/include/mlir/Debug/ExecutionContext.h
@@ -28,8 +28,19 @@
   const ActionActiveStack *getParent() const { return parent; }
   const Action &getAction() const { return action; }
   int getDepth() const { return depth; }
+  /// Print this action and all its enclosing actions, innermost first. When
+  /// `withContext` is true, also print the IR units of each action's context.
+  void print(raw_ostream &os, bool withContext) const;
+  void dump() const {
+    print(llvm::errs(), /*withContext=*/true);
+    llvm::errs() << "\n";
+  }
+  /// Return the breakpoint that matched this action, or nullptr if none did.
+  Breakpoint *getBreakpoint() const { return breakpoint; }
+  void setBreakpoint(Breakpoint *breakpoint) { this->breakpoint = breakpoint; }
 
 private:
+  Breakpoint *breakpoint = nullptr;
   const ActionActiveStack *parent;
   const Action &action;
   int depth;
@@ -69,7 +80,9 @@
   ExecutionContext() = default;
 
   /// Set the callback that is used to control the execution.
-  void setCallback(CallbackTy callback);
+  void setCallback(CallbackTy callback) {
+    onBreakpointControlExecutionCallback = callback;
+  }
 
   /// This abstract class defines the interface used to observe an Action
   /// execution. It allows to be notified before and after the callback is
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -835,7 +835,7 @@
   OpPrintingFlags &printGenericOpForm();
 
   /// Skip printing regions.
-  OpPrintingFlags &skipRegions();
+  OpPrintingFlags &skipRegions(bool skip = true);
 
   /// Do not verify the operation when using custom operation printers.
   OpPrintingFlags &assumeVerified();
diff --git a/mlir/include/mlir/Rewrite/PatternApplicator.h b/mlir/include/mlir/Rewrite/PatternApplicator.h
--- a/mlir/include/mlir/Rewrite/PatternApplicator.h
+++ b/mlir/include/mlir/Rewrite/PatternApplicator.h
@@ -37,8 +37,7 @@
       "Encapsulate the application of rewrite patterns";
 
   void print(raw_ostream &os) const override {
-    os << "`" << tag << "`\n"
-       << " pattern: " << pattern.getDebugName() << '\n';
+    os << "`" << tag << "` pattern: " << pattern.getDebugName();
   }
 
 private:
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -78,6 +78,18 @@
   }
   bool shouldEmitBytecode() const { return emitBytecodeFlag; }
 
+  /// Enable the debugger action hook: it makes the debugger able to intercept
+  /// MLIR Actions.
+  MlirOptMainConfig &enableDebuggerActionHook(bool enabled = true) {
+    enableDebuggerActionHookFlag = enabled;
+    return *this;
+  }
+
+  /// Return true if the Debugger action hook is enabled.
+  bool isDebuggerActionHookEnabled() const {
+    return enableDebuggerActionHookFlag;
+  }
+
   /// Set the IRDL file to load before processing the input.
   MlirOptMainConfig &setIrdlFile(StringRef file) {
     irdlFileFlag = file;
@@ -180,6 +192,9 @@
   /// Emit bytecode instead of textual assembly when generating output.
   bool emitBytecodeFlag = false;
 
+  /// Enable the Debugger action hook: Debugger can intercept MLIR Actions.
+  bool enableDebuggerActionHookFlag = false;
+
   /// IRDL file to register before processing the input.
   std::string irdlFileFlag = "";
diff --git a/mlir/lib/Debug/CMakeLists.txt b/mlir/lib/Debug/CMakeLists.txt
--- a/mlir/lib/Debug/CMakeLists.txt
+++ b/mlir/lib/Debug/CMakeLists.txt
@@ -4,6 +4,7 @@
   DebugCounter.cpp
   ExecutionContext.cpp
   BreakpointManagers/FileLineColLocBreakpointManager.cpp
+  DebuggerExecutionContextHook.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Debug
diff --git a/mlir/lib/Debug/DebuggerExecutionContextHook.cpp b/mlir/lib/Debug/DebuggerExecutionContextHook.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Debug/DebuggerExecutionContextHook.cpp
@@ -0,0 +1,370 @@
+//===- DebuggerExecutionContextHook.cpp - Debugger Support ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Debug/DebuggerExecutionContextHook.h"
+
+#include "mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h"
+#include "mlir/Debug/BreakpointManagers/TagBreakpointManager.h"
+
+using namespace mlir;
+using namespace mlir::tracing;
+
+namespace {
+/// This structure tracks the state of the interactive debugger.
+struct DebuggerState {
+  /// This variable keeps track of the current control option. This is set by
+  /// the debugger when control is handed over to it.
+  ExecutionContext::Control debuggerControl = ExecutionContext::Apply;
+
+  /// The breakpoint manager that allows the debugger to set breakpoints on
+  /// action tags.
+  TagBreakpointManager tagBreakpointManager;
+
+  /// The breakpoint manager that allows the debugger to set breakpoints on
+  /// FileLineColLoc locations.
+  FileLineColLocBreakpointManager fileLineColLocBreakpointManager;
+
+  /// Map of breakpoint IDs to breakpoint objects.
+  DenseMap<unsigned, Breakpoint *> breakpointIdsMap;
+
+  /// The current stack of active actions.
+  const tracing::ActionActiveStack *actionActiveStack = nullptr;
+
+  /// This is a "cursor" in the IR, it is used for the debugger to navigate the
+  /// IR associated to the actions.
+  IRUnit cursor;
+};
+} // namespace
+
+static DebuggerState &getGlobalDebuggerState() {
+  static LLVM_THREAD_LOCAL DebuggerState debuggerState;
+  return debuggerState;
+}
+
+extern "C" {
+void mlirDebuggerSetControl(int controlOption) {
+  getGlobalDebuggerState().debuggerControl =
+      static_cast<ExecutionContext::Control>(controlOption);
+}
+
+void mlirDebuggerPrintContext() {
+  DebuggerState &state = getGlobalDebuggerState();
+  if (!state.actionActiveStack) {
+    llvm::outs() << "No active action.\n";
+    return;
+  }
+  ArrayRef<IRUnit> units =
+      state.actionActiveStack->getAction().getContextIRUnits();
+  llvm::outs() << units.size() << " available IRUnits:\n";
+  for (const IRUnit &unit : units) {
+    llvm::outs() << "  - ";
+    unit.print(
+        llvm::outs(),
+        OpPrintingFlags().useLocalScope().skipRegions().enableDebugInfo());
+    llvm::outs() << "\n";
+  }
+}
+
+void mlirDebuggerPrintActionBacktrace(bool withContext) {
+  DebuggerState &state = getGlobalDebuggerState();
+  if (!state.actionActiveStack) {
+    llvm::outs() << "No active action.\n";
+    return;
+  }
+  state.actionActiveStack->print(llvm::outs(), withContext);
+}
+
+//===----------------------------------------------------------------------===//
+// Cursor Management
+//===----------------------------------------------------------------------===//
+
+void mlirDebuggerCursorPrint(bool withRegion) {
+  auto &state = getGlobalDebuggerState();
+  if (!state.cursor) {
+    llvm::outs() << "No active MLIR cursor, select from the context first\n";
+    return;
+  }
+  state.cursor.print(llvm::outs(), OpPrintingFlags()
+                                       .skipRegions(!withRegion)
+                                       .useLocalScope()
+                                       .enableDebugInfo());
+  llvm::outs() << "\n";
+}
+
+void mlirDebuggerCursorSelectIRUnitFromContext(int index) {
+  auto &state = getGlobalDebuggerState();
+  if (!state.actionActiveStack) {
+    llvm::outs() << "No active MLIR Action stack\n";
+    return;
+  }
+  ArrayRef<IRUnit> units =
+      state.actionActiveStack->getAction().getContextIRUnits();
+  if (index < 0 || index >= static_cast<int>(units.size())) {
+    llvm::outs() << "Index invalid, bounds: [0, " << units.size()
+                 << ") but got " << index << "\n";
+    return;
+  }
+  state.cursor = units[index];
+  state.cursor.print(llvm::outs());
+  llvm::outs() << "\n";
+}
+
+void mlirDebuggerCursorSelectParentIRUnit() {
+  auto &state = getGlobalDebuggerState();
+  if (!state.cursor) {
+    llvm::outs() << "No active MLIR cursor, select from the context first\n";
+    return;
+  }
+  IRUnit parent;
+  if (auto *op = state.cursor.dyn_cast<Operation *>())
+    parent = op->getBlock();
+  else if (auto *region = state.cursor.dyn_cast<Region *>())
+    parent = region->getParentOp();
+  else if (auto *block = state.cursor.dyn_cast<Block *>())
+    parent = block->getParent();
+  if (!parent) {
+    llvm::outs() << "Current cursor has no parent IRUnit\n";
+    return;
+  }
+  state.cursor = parent;
+  state.cursor.print(llvm::outs());
+  llvm::outs() << "\n";
+}
+
+void mlirDebuggerCursorSelectChildIRUnit(int index) {
+  auto &state = getGlobalDebuggerState();
+  if (!state.cursor) {
+    llvm::outs() << "No active MLIR cursor, select from the context first\n";
+    return;
+  }
+  IRUnit *unit = &state.cursor;
+  if (auto *op = unit->dyn_cast<Operation *>()) {
+    if (index < 0 || index >= static_cast<int>(op->getNumRegions())) {
+      llvm::outs() << "Index invalid, op has " << op->getNumRegions()
+                   << " regions but got " << index << "\n";
+      return;
+    }
+    state.cursor = &op->getRegion(index);
+  } else if (auto *region = unit->dyn_cast<Region *>()) {
+    auto block = region->begin();
+    int count = 0;
+    while (block != region->end() && count != index) {
+      ++block;
+      ++count;
+    }
+
+    if (block == region->end()) {
+      llvm::outs() << "Index invalid, region has " << count
+                   << " blocks but got " << index << "\n";
+      return;
+    }
+    state.cursor = &*block;
+  } else if (auto *block = unit->dyn_cast<Block *>()) {
+    auto op = block->begin();
+    int count = 0;
+    while (op != block->end() && count != index) {
+      ++op;
+      ++count;
+    }
+
+    if (op == block->end()) {
+      llvm::outs() << "Index invalid, block has " << count
+                   << " operations but got " << index << "\n";
+      return;
+    }
+    state.cursor = &*op;
+  } else {
+    llvm::outs() << "Current cursor is not a valid IRUnit\n";
+    return;
+  }
+  state.cursor.print(llvm::outs());
+  llvm::outs() << "\n";
+}
+
+void mlirDebuggerCursorSelectPreviousIRUnit() {
+  auto &state = getGlobalDebuggerState();
+  if (!state.cursor) {
+    llvm::outs() << "No active MLIR cursor, select from the context first\n";
+    return;
+  }
+  IRUnit *unit = &state.cursor;
+  if (auto *op = unit->dyn_cast<Operation *>()) {
+    Operation *previous = op->getPrevNode();
+    if (!previous) {
+      llvm::outs() << "No previous operation in the current block\n";
+      return;
+    }
+    state.cursor = previous;
+  } else if (auto *region = unit->dyn_cast<Region *>()) {
+    Operation *parent = region->getParentOp();
+    if (!parent) {
+      llvm::outs() << "No parent operation for the current region\n";
+      return;
+    }
+    // Regions are addressed by their index in the parent operation.
+    unsigned regionNumber = region->getRegionNumber();
+    if (regionNumber == 0) {
+      llvm::outs() << "No previous region in the current operation\n";
+      return;
+    }
+    state.cursor = &parent->getRegion(regionNumber - 1);
+  } else if (auto *block = unit->dyn_cast<Block *>()) {
+    Block *previous = block->getPrevNode();
+    if (!previous) {
+      llvm::outs() << "No previous block in the current region\n";
+      return;
+    }
+    state.cursor = previous;
+  } else {
+    llvm::outs() << "Current cursor is not a valid IRUnit\n";
+    return;
+  }
+  state.cursor.print(llvm::outs());
+  llvm::outs() << "\n";
+}
+
+void mlirDebuggerCursorSelectNextIRUnit() {
+  auto &state = getGlobalDebuggerState();
+  if (!state.cursor) {
+    llvm::outs() << "No active MLIR cursor, select from the context first\n";
+    return;
+  }
+  IRUnit *unit = &state.cursor;
+  if (auto *op = unit->dyn_cast<Operation *>()) {
+    Operation *next = op->getNextNode();
+    if (!next) {
+      llvm::outs() << "No next operation in the current block\n";
+      return;
+    }
+    state.cursor = next;
+  } else if (auto *region = unit->dyn_cast<Region *>()) {
+    Operation *parent = region->getParentOp();
+    if (!parent) {
+      llvm::outs() << "No parent operation for the current region\n";
+      return;
+    }
+    unsigned regionNumber = region->getRegionNumber();
+    if (regionNumber + 1 == parent->getNumRegions()) {
+      llvm::outs() << "No next region in the current operation\n";
+      return;
+    }
+    state.cursor = &parent->getRegion(regionNumber + 1);
+  } else if (auto *block = unit->dyn_cast<Block *>()) {
+    Block *next = block->getNextNode();
+    if (!next) {
+      llvm::outs() << "No next block in the current region\n";
+      return;
+    }
+    state.cursor = next;
+  } else {
+    llvm::outs() << "Current cursor is not a valid IRUnit\n";
+    return;
+  }
+  state.cursor.print(llvm::outs());
+  llvm::outs() << "\n";
+}
+
+//===----------------------------------------------------------------------===//
+// Breakpoint Management
+//===----------------------------------------------------------------------===//
+
+void mlirDebuggerEnableBreakpoint(BreakpointHandle breakpoint) {
+  reinterpret_cast<Breakpoint *>(breakpoint)->enable();
+}
+
+void mlirDebuggerDisableBreakpoint(BreakpointHandle breakpoint) {
+  reinterpret_cast<Breakpoint *>(breakpoint)->disable();
+}
+
+BreakpointHandle mlirDebuggerAddTagBreakpoint(const char *tag) {
+  DebuggerState &state = getGlobalDebuggerState();
+  Breakpoint *breakpoint =
+      state.tagBreakpointManager.addBreakpoint(StringRef(tag, strlen(tag)));
+  unsigned breakpointId = state.breakpointIdsMap.size() + 1;
+  state.breakpointIdsMap[breakpointId] = breakpoint;
+  return reinterpret_cast<BreakpointHandle>(breakpoint);
+}
+
+void mlirDebuggerAddRewritePatternBreakpoint(const char *patternNameInfo) {}
+
+void mlirDebuggerAddFileLineColLocBreakpoint(const char *file, int line,
+                                             int col) {
+  getGlobalDebuggerState().fileLineColLocBreakpointManager.addBreakpoint(
+      StringRef(file, strlen(file)), line, col);
+}
+
+} // extern "C"
+
+LLVM_ATTRIBUTE_NOINLINE void mlirDebuggerBreakpointHook() {
+  static LLVM_THREAD_LOCAL void *volatile sink;
+  sink = (void *)&sink;
+}
+
+static void preventLinkerDeadCodeElim() {
+  static void *volatile sink;
+  static bool initialized = [&]() {
+    sink = (void *)mlirDebuggerSetControl;
+    sink = (void *)mlirDebuggerEnableBreakpoint;
+    sink = (void *)mlirDebuggerDisableBreakpoint;
+    sink = (void *)mlirDebuggerPrintContext;
+    sink = (void *)mlirDebuggerPrintActionBacktrace;
+    sink = (void *)mlirDebuggerCursorPrint;
+    sink = (void *)mlirDebuggerCursorSelectIRUnitFromContext;
+    sink = (void *)mlirDebuggerCursorSelectParentIRUnit;
+    sink = (void *)mlirDebuggerCursorSelectChildIRUnit;
+    sink = (void *)mlirDebuggerCursorSelectPreviousIRUnit;
+    sink = (void *)mlirDebuggerCursorSelectNextIRUnit;
+    sink = (void *)mlirDebuggerAddTagBreakpoint;
+    sink = (void *)mlirDebuggerAddRewritePatternBreakpoint;
+    sink = (void *)mlirDebuggerAddFileLineColLocBreakpoint;
+    sink = (void *)&sink;
+    return true;
+  }();
+  (void)initialized;
+}
+
+static tracing::ExecutionContext::Control
+debuggerCallBackFunction(const tracing::ActionActiveStack *actionStack) {
+  preventLinkerDeadCodeElim();
+  auto &state = getGlobalDebuggerState();
+  state.actionActiveStack = actionStack;
+  state.debuggerControl = ExecutionContext::Apply;
+  actionStack->getAction().print(llvm::outs());
+  llvm::outs() << "\n";
+  // Invoke the breakpoint hook, the debugger is supposed to trap this.
+  // The debugger controls the execution from there by invoking
+  // `mlirDebuggerSetControl()`.
+  mlirDebuggerBreakpointHook();
+  return state.debuggerControl;
+}
+
+namespace {
+/// Keep the DebuggerState in sync with the stack of currently active actions.
+class DebuggerObserver : public ExecutionContext::Observer {
+  void beforeExecute(const ActionActiveStack *action, Breakpoint *breakpoint,
+                     bool willExecute) override {
+    auto &state = getGlobalDebuggerState();
+    state.actionActiveStack = action;
+  }
+  void afterExecute(const ActionActiveStack *action) override {
+    auto &state = getGlobalDebuggerState();
+    state.actionActiveStack = action->getParent();
+    state.cursor = nullptr;
+  }
+};
+} // namespace
+
+void mlir::setupDebuggerExecutionContextHook(
+    tracing::ExecutionContext &executionContext) {
+  executionContext.setCallback(debuggerCallBackFunction);
+  DebuggerState &state = getGlobalDebuggerState();
+  static DebuggerObserver observer;
+  executionContext.registerObserver(&observer);
+  executionContext.addBreakpointManager(&state.fileLineColLocBreakpointManager);
+  executionContext.addBreakpointManager(&state.tagBreakpointManager);
+}
diff --git a/mlir/lib/Debug/ExecutionContext.cpp b/mlir/lib/Debug/ExecutionContext.cpp
--- a/mlir/lib/Debug/ExecutionContext.cpp
+++ b/mlir/lib/Debug/ExecutionContext.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Debug/ExecutionContext.h"
 
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/Support/FormatVariadic.h"
 
 #include <cstddef>
 
@@ -16,15 +17,37 @@
 using namespace mlir::tracing;
 
 //===----------------------------------------------------------------------===//
-// ExecutionContext
+// ActionActiveStack
 //===----------------------------------------------------------------------===//
 
-static const thread_local ActionActiveStack *actionStack = nullptr;
-
-void ExecutionContext::setCallback(CallbackTy callback) {
-  onBreakpointControlExecutionCallback = callback;
+void ActionActiveStack::print(raw_ostream &os, bool withContext) const {
+  os << "ActionActiveStack depth " << getDepth() << "\n";
+  const ActionActiveStack *current = this;
+  int count = 0;
+  while (current) {
+    os << llvm::formatv("#{0,3}: ", count++);
+    current->action.print(os);
+    os << "\n";
+    ArrayRef<IRUnit> context = current->action.getContextIRUnits();
+    if (withContext && !context.empty()) {
+      os << "Context:\n";
+      for (const IRUnit &unit : context) {
+        os << "  - ";
+        unit.print(os);
+        os << "\n";
+      }
+      os << "\n";
+    }
+    current = current->parent;
+  }
 }
 
+//===----------------------------------------------------------------------===//
+// ExecutionContext
+//===----------------------------------------------------------------------===//
+
+static const LLVM_THREAD_LOCAL ActionActiveStack *actionStack = nullptr;
+
 void ExecutionContext::registerObserver(Observer *observer) {
   observers.push_back(observer);
 }
@@ -72,6 +95,7 @@
     if (breakpoint)
       break;
   }
+  info.setBreakpoint(breakpoint);
 
   bool shouldExecuteAction = true;
   // If we have a breakpoint, or if `depthToBreak` was previously set and the
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -225,8 +225,8 @@
 }
 
-/// Always skip Regions.
+/// Skip printing regions when `skip` is true.
-OpPrintingFlags &OpPrintingFlags::skipRegions() {
-  skipRegionsFlag = true;
+OpPrintingFlags &OpPrintingFlags::skipRegions(bool skip) {
+  skipRegionsFlag = skip;
   return *this;
 }
 
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Tools/mlir-opt/MlirOptMain.h"
 #include "mlir/Bytecode/BytecodeWriter.h"
 #include "mlir/Debug/Counter.h"
+#include "mlir/Debug/DebuggerExecutionContextHook.h"
 #include "mlir/Debug/ExecutionContext.h"
 #include "mlir/Debug/Observers/ActionLogging.h"
 #include "mlir/Dialect/IRDL/IR/IRDL.h"
@@ -77,6 +78,11 @@
         cl::desc("IRDL file to register before processing the input"),
         cl::location(irdlFileFlag), cl::init(""), cl::value_desc("filename"));
 
+    static cl::opt<bool, /*ExternalStorage=*/true> enableDebuggerHook(
+        "mlir-enable-debugger-hook",
+        cl::desc("Enable Debugger hook for debugging MLIR Actions"),
+        cl::location(enableDebuggerActionHookFlag), cl::init(false));
+
     static cl::opt<bool, /*ExternalStorage=*/true> explicitModule(
         "no-implicit-module",
         cl::desc("Disable implicit addition of a top-level module op during "
@@ -217,30 +223,38 @@
 class InstallDebugHandler {
 public:
   InstallDebugHandler(MLIRContext &context, const MlirOptMainConfig &config) {
-    if (config.getLogActionsTo().empty()) {
+    if (config.getLogActionsTo().empty() &&
+        !config.isDebuggerActionHookEnabled()) {
       if (tracing::DebugCounter::isActivated())
         context.registerActionHandler(tracing::DebugCounter());
       return;
     }
+    llvm::errs() << "ExecutionContext registered on the context";
     if (tracing::DebugCounter::isActivated())
       emitError(UnknownLoc::get(&context),
-                "Debug counters are incompatible with --log-actions-to option "
-                "and are disabled");
-    std::string errorMessage;
-    logActionsFile = openOutputFile(config.getLogActionsTo(), &errorMessage);
-    if (!logActionsFile) {
-      emitError(UnknownLoc::get(&context),
-                "Opening file for --log-actions-to failed: ")
-          << errorMessage << "\n";
-      return;
+                "Debug counters are incompatible with --log-actions-to and "
+                "--mlir-enable-debugger-hook options and are disabled");
+    if (!config.getLogActionsTo().empty()) {
+      std::string errorMessage;
+      logActionsFile = openOutputFile(config.getLogActionsTo(), &errorMessage);
+      if (!logActionsFile) {
+        emitError(UnknownLoc::get(&context),
+                  "Opening file for --log-actions-to failed: ")
+            << errorMessage << "\n";
+        return;
+      }
+      logActionsFile->keep();
+      raw_fd_ostream &logActionsStream = logActionsFile->os();
+      actionLogger = std::make_unique<tracing::ActionLogger>(logActionsStream);
+      for (const auto *locationBreakpoint : config.getLogActionsLocFilters())
+        actionLogger->addBreakpointManager(locationBreakpoint);
+      executionContext.registerObserver(actionLogger.get());
     }
-    logActionsFile->keep();
-    raw_fd_ostream &logActionsStream = logActionsFile->os();
-    actionLogger = std::make_unique<tracing::ActionLogger>(logActionsStream);
-    for (const auto *locationBreakpoint : config.getLogActionsLocFilters())
-      actionLogger->addBreakpointManager(locationBreakpoint);
-
-    executionContext.registerObserver(actionLogger.get());
+    if (config.isDebuggerActionHookEnabled()) {
+      llvm::errs() << " (with Debugger hook)";
+      setupDebuggerExecutionContextHook(executionContext);
+    }
+    llvm::errs() << "\n";
     context.registerActionHandler(executionContext);
   }
 
diff --git a/mlir/utils/lldb-scripts/action_debugging.py b/mlir/utils/lldb-scripts/action_debugging.py
new file mode 100644
--- /dev/null
+++ b/mlir/utils/lldb-scripts/action_debugging.py
@@ -0,0 +1,580 @@
+#!/usr/bin/env python
+
+# ---------------------------------------------------------------------
+# Be sure to add the python path that points to the LLDB shared library.
+#
+# To use this in the embedded python interpreter using "lldb" just
+# import it with the full path using the "command script import"
+# command
+#   (lldb) command script import /path/to/action_debugging.py
+# ---------------------------------------------------------------------
+
+import inspect
+import lldb
+import argparse
+import shlex
+import sys
+
+# Each new breakpoint gets a unique ID starting from 1.
+nextid = 1
+# Map of the breakpoints set from python: the key is the ID and the value the
+# actual breakpoint. These are NOT LLDB SBBreakpoint objects.
+breakpoints = dict()
+
+exprOptions = lldb.SBExpressionOptions()
+exprOptions.SetIgnoreBreakpoints()
+exprOptions.SetLanguage(lldb.eLanguageTypeC)
+
+
+class MlirDebug:
+    """MLIR debugger commands
+    This is the class that hooks into LLDB and registers the `mlir` command.
+    Other providers can register subcommands below this one.
+    """
+
+    lldb_command = "mlir"
+    parser = None
+
+    def __init__(self, debugger, unused):
+        super().__init__()
+        self.create_options()
+        self.help_string = MlirDebug.parser.format_help()
+
+    @classmethod
+    def create_options(cls):
+        if MlirDebug.parser:
+            return MlirDebug.parser
+        usage = "usage: %s [options]" % (cls.lldb_command)
+        description = "Commands to interact with the MLIR Action debugger hook."
+
+        # Pass add_help_option = False, since this keeps the command in line
+        # with lldb commands, and we wire up "help command" to work by
+        # providing the long & short help methods below.
+        MlirDebug.parser = argparse.ArgumentParser(
+            prog=cls.lldb_command, usage=usage, description=description, add_help=False
+        )
+        MlirDebug.subparsers = MlirDebug.parser.add_subparsers(dest="command")
+        return MlirDebug.parser
+
+    def get_short_help(self):
+        return "MLIR debugger commands"
+
+    def get_long_help(self):
+        return self.help_string
+
+    def __call__(self, debugger, command, exe_ctx, result):
+        # Use the Shell Lexer to properly parse up command options just like a
+        # shell would
+        command_args = shlex.split(command)
+
+        try:
+            args = MlirDebug.parser.parse_args(command_args)
+        except SystemExit:
+            # argparse reports invalid options by raising SystemExit: report
+            # the failure through the command result instead.
+            result.SetError("option parsing failed")
+            return
+        if not hasattr(args, "func"):
+            result.SetError("missing subcommand")
+            return
+        args.func(args, debugger, command, exe_ctx, result)
+
+    @staticmethod
+    def on_process_start(frame, bp_loc, dict):
+        print("Process started")
+
+
+class SetControl:
+    # Define the subcommands that controls what to do when a breakpoint is hit.
+    # The key is the subcommand name, the value is a tuple of the command ID to
+    # pass to MLIR and the help string.
+    commands = {
+        "apply": (1, "Apply the current action and continue the execution"),
+        "skip": (2, "Skip the current action and continue the execution"),
+        "step": (3, "Step into the current action"),
+        "next": (4, "Step over the current action"),
+        "finish": (5, "Step out of the current action"),
+    }
+
+    @classmethod
+    def register_mlir_subparser(cls):
+        for cmd, (_, cmd_help) in cls.commands.items():
+            parser = MlirDebug.subparsers.add_parser(
+                cmd,
+                help=cmd_help,
+            )
+            parser.set_defaults(func=cls.process_options)
+
+    @classmethod
+    def process_options(cls, options, debugger, command, exe_ctx, result):
+        frame = exe_ctx.GetFrame()
+        if not frame.IsValid():
+            result.SetError("No valid frame (program not running?)")
+            return
+        cmdInt = cls.commands.get(options.command, None)
+        if not cmdInt:
+            result.SetError("Invalid command: %s" % (options.command))
+            return
+
+        value = frame.EvaluateExpression(
+            "((bool (*)(int))mlirDebuggerSetControl)(%d)" % (cmdInt[0]),
+            exprOptions,
+        )
+        if not value.error.Success():
+            result.SetError("Error setting up command: %s" % (value.error))
+            return
+        debugger.SetAsync(True)
+        exe_ctx.GetProcess().Continue()
+        debugger.SetAsync(False)
+
+
+class PrintContext:
+    @classmethod
+    def register_mlir_subparser(cls):
+        cls.parser = MlirDebug.subparsers.add_parser(
+            "context", help="Print the current context"
+        )
+        cls.parser.set_defaults(func=cls.process_options)
+
+    @classmethod
+    def process_options(cls, options, debugger, command, exe_ctx, result):
+        frame = exe_ctx.GetFrame()
+        if not frame.IsValid():
+            result.SetError("Can't print context without a valid frame")
+            return
+        value = frame.EvaluateExpression(
+            "((bool (*)())&mlirDebuggerPrintContext)()", exprOptions
+        )
+        if not value.error.Success():
+            result.SetError("Error printing context: %s" % (value.error))
+            return
+
+
+class Backtrace:
+    @classmethod
+    def register_mlir_subparser(cls):
+        cls.parser = MlirDebug.subparsers.add_parser(
+            "backtrace", aliases=["bt"], help="Print the current backtrace"
+        )
+        cls.parser.set_defaults(func=cls.process_options)
+        cls.parser.add_argument("--context", default=False, action="store_true")
+
+    @classmethod
+    def process_options(cls, options, debugger, command, exe_ctx, result):
+        frame = exe_ctx.GetFrame()
+        if not frame.IsValid():
+            result.SetError(
+                "Can't backtrace without a valid frame (program not running?)"
+            )
+            return
+        value = frame.EvaluateExpression(
+            "((bool(*)(bool))mlirDebuggerPrintActionBacktrace)(%d)" % (options.context),
+            exprOptions,
+        )
+        if not value.error.Success():
+            result.SetError("Error printing backtrace: %s" % (value.error))
+            return
+
+
+###############################################################################
+# Cursor manipulation
+###############################################################################
+
+
+class PrintCursor:
+    @classmethod
+    def register_mlir_subparser(cls):
+        cls.parser = MlirDebug.subparsers.add_parser(
+            "cursor-print", aliases=["cursor-p"], help="Print the current cursor"
+        )
+        cls.parser.add_argument(
+            "--print-region", "--regions", "-r", default=False, action="store_true"
+        )
+        cls.parser.set_defaults(func=cls.process_options)
+
+    @classmethod
+    def process_options(cls, options, debugger, command, exe_ctx, result):
+        frame = exe_ctx.GetFrame()
+        if not frame.IsValid():
+            result.SetError(
+                "Can't print cursor without a valid frame (program not running?)"
+            )
+            return
+        value = frame.EvaluateExpression(
+            "((bool(*)(bool))mlirDebuggerCursorPrint)(%d)" % (options.print_region),
+            exprOptions,
+        )
+        if not value.error.Success():
+            result.SetError("Error printing cursor: %s" % (value.error))
+            return
+
+
+class SelectCursorFromContext:
+    @classmethod
+    def register_mlir_subparser(cls):
+        cls.parser = MlirDebug.subparsers.add_parser(
+            "cursor-select-from-context",
+            aliases=["cursor-s"],
+            help="Select the cursor from the current context",
+        )
+        cls.parser.add_argument("index", type=int, help="Index in the context")
+        cls.parser.set_defaults(func=cls.process_options)
+
+    @classmethod
+    def process_options(cls, options, debugger, command, exe_ctx, result):
+        frame = exe_ctx.GetFrame()
+        if not frame.IsValid():
+            result.SetError(
+                "Can't manipulate cursor without a valid frame (program not running?)"
+            )
+            return
+        value = frame.EvaluateExpression(
+            "((bool(*)(int))mlirDebuggerCursorSelectIRUnitFromContext)(%d)"
+            % options.index,
+            exprOptions,
+        )
+        if not value.error.Success():
+            result.SetError("Error manipulating cursor: %s" % (value.error))
+            return
+
+
+class CursorSelectParent:
+    @classmethod
+    def register_mlir_subparser(cls):
+        cls.parser = MlirDebug.subparsers.add_parser(
+            "cursor-parent", aliases=["cursor-up"], help="Select the cursor parent"
+        )
+        cls.parser.set_defaults(func=cls.process_options)
+
+    @classmethod
+    def process_options(cls, options, debugger, command, exe_ctx, result):
+        frame = exe_ctx.GetFrame()
+        if not frame.IsValid():
+            result.SetError(
+                "Can't manipulate cursor without a valid frame (program not running?)"
+            )
+            return
+        value = frame.EvaluateExpression(
+            "((bool(*)())mlirDebuggerCursorSelectParentIRUnit)()",
+            exprOptions,
+        )
+        if not value.error.Success():
+            result.SetError("Error manipulating cursor: %s" % (value.error))
+            return
+
+
+class SelectCursorChild:
+    @classmethod
+    def register_mlir_subparser(cls):
+        cls.parser = MlirDebug.subparsers.add_parser(
+            "cursor-child", aliases=["cursor-c"], help="Select the nth child"
+        )
+        cls.parser.add_argument("index", type=int, help="Index of the child to select")
+        cls.parser.set_defaults(func=cls.process_options)
+
+    @classmethod
+    def process_options(cls, options, debugger, command, exe_ctx, result):
+        frame = exe_ctx.GetFrame()
+        if not frame.IsValid():
+            result.SetError(
+                "Can't manipulate cursor without a valid frame (program not running?)"
+            )
+            return
+        value = frame.EvaluateExpression(
+            "((bool(*)(int))mlirDebuggerCursorSelectChildIRUnit)(%d)" % options.index,
+            exprOptions,
+        )
+        if not value.error.Success():
+            result.SetError("Error manipulating cursor: %s" % (value.error))
+            return
+
+
+class CursorSelecPrevious:
+    @classmethod
+    def register_mlir_subparser(cls):
+        cls.parser = MlirDebug.subparsers.add_parser(
+            "cursor-previous",
+            aliases=["cursor-prev"],
+            help="Select the cursor previous element",
+        )
+        cls.parser.set_defaults(func=cls.process_options)
+
+    @classmethod
+    def process_options(cls, options, debugger, command, exe_ctx, result):
+        frame = exe_ctx.GetFrame()
+        if not frame.IsValid():
+            result.SetError(
+                "Can't manipulate cursor without a valid frame (program not running?)"
+            )
+            return
+        value = frame.EvaluateExpression(
+            "((bool(*)())mlirDebuggerCursorSelectPreviousIRUnit)()",
+            exprOptions,
+        )
+        if not value.error.Success():
+            result.SetError("Error manipulating cursor: %s" % (value.error))
+            return
+
+
+class CursorSelecNext:
+    @classmethod
+    def register_mlir_subparser(cls):
+        cls.parser = MlirDebug.subparsers.add_parser(
+            "cursor-next", aliases=["cursor-n"], help="Select the cursor next element"
+        )
+        cls.parser.set_defaults(func=cls.process_options)
+
+    @classmethod
+    def process_options(cls, options, debugger, command, exe_ctx, result):
+        frame = exe_ctx.GetFrame()
+        if not frame.IsValid():
+            result.SetError(
+                "Can't manipulate cursor without a valid frame (program not running?)"
+            )
+            return
+        value = frame.EvaluateExpression(
+            "((bool(*)())mlirDebuggerCursorSelectNextIRUnit)()",
+            exprOptions,
+        )
+        if not value.error.Success():
+            result.SetError("Error manipulating cursor: %s" % (value.error))
+            return
+
+
+###############################################################################
+# Breakpoints
+###############################################################################
+
+
+class EnableBreakpoint:
+    @classmethod
+    def register_mlir_subparser(cls):
+        cls.parser = MlirDebug.subparsers.add_parser(
+            "enable", help="Enable a single breakpoint (given its ID)"
+        )
+        cls.parser.add_argument("id", type=int, help="ID of the breakpoint to enable")
+        cls.parser.set_defaults(func=cls.process_options)
+
+    @classmethod
+    def process_options(cls, options, debugger, command, exe_ctx, result):
+        bp = breakpoints.get(options.id)
+        if bp is None:
+            result.SetError("No breakpoint with ID %d" % options.id)
+            return
+        bp.enable(exe_ctx.GetFrame())
+
+
+class DisableBreakpoint:
+    @classmethod
+    def register_mlir_subparser(cls):
+        cls.parser = MlirDebug.subparsers.add_parser(
+            "disable", help="Disable a single breakpoint (given its ID)"
+        )
+        cls.parser.add_argument("id", type=int, help="ID of the breakpoint to disable")
+        cls.parser.set_defaults(func=cls.process_options)
+
+    @classmethod
+    def process_options(cls, options, debugger, command, exe_ctx, result):
+        bp = breakpoints.get(options.id)
+        if bp is None:
+            result.SetError("No breakpoint with ID %d" % options.id)
+            return
+        bp.disable(exe_ctx.GetFrame())
+
+
+class ListBreakpoints:
+    @classmethod
+    def register_mlir_subparser(cls):
+        cls.parser = MlirDebug.subparsers.add_parser(
+            "list", help="List all current breakpoints"
+        )
+        cls.parser.set_defaults(func=cls.process_options)
+
+    @classmethod
+    def process_options(cls, options, debugger, command, exe_ctx, result):
+        for bp_id, bp in sorted(breakpoints.items()):
+            print(bp_id, str(bp), "enabled" if bp.isEnabled else "disabled")
+
+
+class Breakpoint:
+    def __init__(self):
+        global nextid
+        self.id = nextid
+        nextid += 1
+        breakpoints[self.id] = self
+        self.isEnabled = True
+
+    def enable(self, frame=None):
+        self.isEnabled = True
+        if not frame or not frame.IsValid():
+            return
+        # use a C cast to force the type of the breakpoint handle to be void * so
+        # that we don't rely on DWARF. Also add a fake bool return value otherwise
+        # LLDB can't signal any error with the expression evaluation (at least I don't know how).
+ cmd = ( + "((bool (*)(void *))mlirDebuggerEnableBreakpoint)((void *)%s)" % self.handle + ) + result = frame.EvaluateExpression(cmd, exprOptions) + if not result.error.Success(): + print("Error enabling breakpoint: %s" % (result.error)) + return + + def disable(self, frame=None): + self.isEnabled = False + if not frame or not frame.IsValid(): + return + # use a C cast to force the type of the breakpoint handle to be void * so + # that we don't rely on DWARF. Also add a fake bool return value otherwise + # LLDB can't signal any error with the expression evaluation (at least I don't know how). + cmd = ( + "((bool (*)(void *)) mlirDebuggerDisableBreakpoint)((void *)%s)" + % self.handle + ) + result = frame.EvaluateExpression(cmd, exprOptions) + if not result.error.Success(): + print("Error disabling breakpoint: %s" % (result.error)) + return + + +class TagBreakpoint(Breakpoint): + mlir_subcommand = "break-on-tag" + + def __init__(self, tag): + super().__init__() + self.tag = tag + + def __str__(self): + return "[%d] TagBreakpoint(%s)" % (self.id, self.tag) + + @classmethod + def register_mlir_subparser(cls): + cls.parser = MlirDebug.subparsers.add_parser( + cls.mlir_subcommand, help="add a breakpoint on actions' tag matching" + ) + cls.parser.set_defaults(func=cls.process_options) + cls.parser.add_argument("tag", help="tag to match") + + @classmethod + def process_options(cls, options, debugger, command, exe_ctx, result): + breakpoint = TagBreakpoint(options.tag) + print("Added breakpoint %s" % str(breakpoint)) + + frame = exe_ctx.GetFrame() + if frame.IsValid(): + breakpoint.install(frame) + + def install(self, frame): + result = frame.EvaluateExpression( + '((void *(*)(const char *))mlirDebuggerAddTagBreakpoint)("%s")' + % (self.tag), + exprOptions, + ) + if not result.error.Success(): + print("Error installing breakpoint: %s" % (result.error)) + return + # Save the handle, this is necessary to implement enable/disable. 
+ self.handle = result.GetValue() + + +class FileLineBreakpoint(Breakpoint): + mlir_subcommand = "break-on-file" + + def __init__(self, file, line, col): + super().__init__() + self.file = file + self.line = line + self.col = col + + def __str__(self): + return "[%d] FileLineBreakpoint(%s, %d, %d)" % ( + self.id, + self.file, + self.line, + self.col, + ) + + @classmethod + def register_mlir_subparser(cls): + cls.parser = MlirDebug.subparsers.add_parser( + cls.mlir_subcommand, + help="add a breakpoint that filters on location of the IR affected by an action. The syntax is file:line:col where file and col are optional", + ) + cls.parser.set_defaults(func=cls.process_options) + cls.parser.add_argument("location", type=str) + + @classmethod + def process_options(cls, options, debugger, command, exe_ctx, result): + split_loc = options.location.split(":") + file = split_loc[0] + line = int(split_loc[1]) if len(split_loc) > 1 else -1 + col = int(split_loc[2]) if len(split_loc) > 2 else -1 + breakpoint = FileLineBreakpoint(file, line, col) + print("Added breakpoint %s" % str(breakpoint)) + + frame = exe_ctx.GetFrame() + if frame.IsValid(): + breakpoint.install(frame) + + def install(self, frame): + result = frame.EvaluateExpression( + '((void *(*)(const char *, int, int))mlirDebuggerAddFileLineColLocBreakpoint)("%s", %d, %d)' + % (self.file, self.line, self.col), + exprOptions, + ) + if not result.error.Success(): + print("Error installing breakpoint: %s" % (result.error)) + return + # Save the handle, this is necessary to implement enable/disable. 
+ self.handle = result.GetValue() + + +def on_start(frame, bpno, err): + print("MLIR debugger attaching...") + for _, bp in sorted(breakpoints.items()): + if bp.isEnabled: + print("Installing breakpoint %s" % (str(bp))) + bp.install(frame) + else: + print("Skipping disabled breakpoint %s" % (str(bp))) + + return True + + +def __lldb_init_module(debugger, dict): + target = debugger.GetTargetAtIndex(0) + debugger.SetAsync(False) + if not target: + print("No target is loaded, please load a target before loading this script.") + return + if debugger.GetNumTargets() > 1: + print( + "Multiple targets (%s) loaded, attaching MLIR debugging to %s" + % (debugger.GetNumTargets(), target) + ) + + # Register all classes that have a register_lldb_command method + module_name = __name__ + parser = MlirDebug.create_options() + MlirDebug.__doc__ = parser.format_help() + + # Add the MLIR entry point to LLDB as a command. + command = "command script add -o -c %s.%s %s" % ( + module_name, + MlirDebug.__name__, + MlirDebug.lldb_command, + ) + debugger.HandleCommand(command) + + main_bp = target.BreakpointCreateByName("main") + main_bp.SetScriptCallbackFunction("action_debugging.on_start") + main_bp.SetAutoContinue(auto_continue=True) + + on_breackpoint = target.BreakpointCreateByName("mlirDebuggerBreakpointHook") + + print( + 'The "{0}" command has been installed for target `{1}`, type "help {0}" or "{0} ' + '--help" for detailed help.'.format(MlirDebug.lldb_command, target) + ) + for _name, cls in inspect.getmembers(sys.modules[module_name]): + if inspect.isclass(cls) and getattr(cls, "register_mlir_subparser", None): + cls.register_mlir_subparser()