diff --git a/mlir/include/mlir/Debug/Counter.h b/mlir/include/mlir/Debug/Counter.h
--- a/mlir/include/mlir/Debug/Counter.h
+++ b/mlir/include/mlir/Debug/Counter.h
@@ -47,6 +47,8 @@
 
   /// Register the command line options for debug counters.
   static void registerCLOptions();
+  /// Returns true if a debug counter option was passed on the command line.
+  static bool isDebugCounterActivated();
 
 private:
   // Returns true if the next action matching this tag should be executed.
diff --git a/mlir/include/mlir/Debug/Observers/ActionLogging.h b/mlir/include/mlir/Debug/Observers/ActionLogging.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Debug/Observers/ActionLogging.h
@@ -0,0 +1,41 @@
+//===- ActionLogging.h - Logging Actions *- C++ -*-===========================//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRACING_OBSERVERS_ACTIONLOGGING_H
+#define MLIR_TRACING_OBSERVERS_ACTIONLOGGING_H
+
+#include "mlir/Debug/BreakpointManager.h"
+#include "mlir/Debug/ExecutionContext.h"
+#include "mlir/IR/Action.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace tracing {
+
+/// ExecutionContext observer that logs every action it observes to a stream.
+struct ActionLogger : public ExecutionContext::Observer {
+  explicit ActionLogger(raw_ostream &os, bool printActions = true,
+                        bool printBreakpoints = true)
+      : os(os), printActions(printActions), printBreakpoints(printBreakpoints) {
+  }
+
+  void beforeExecute(const ActionActiveStack *action, Breakpoint *breakpoint,
+                     bool willExecute) override;
+  void afterExecute(const ActionActiveStack *action) override;
+
+private:
+  raw_ostream &os;       ///< Log destination (not owned).
+  bool printActions;     ///< Print the full action rather than only its tag.
+  bool printBreakpoints; ///< Also print the matching breakpoint, if any.
+};
+
+} // namespace tracing
+} // namespace mlir
+
+#endif // MLIR_TRACING_OBSERVERS_ACTIONLOGGING_H
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -240,7 +240,7 @@
   template <typename ActionTy, typename... Args>
   void dispatch(function_ref<void()> actionFn, Args &&...args) {
     if (LLVM_UNLIKELY(hasActionHandler()))
-      dispatchInteral(actionFn, ActionTy(std::forward<Args>(args)...));
+      dispatchInternal(actionFn, ActionTy(std::forward<Args>(args)...));
     else
       actionFn();
   }
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -74,6 +74,14 @@
   }
   bool shouldEmitBytecode() const { return emitBytecodeFlag; }
 
+  /// Set the filename to use for logging actions, use "-" for stdout.
+  MlirOptMainConfig &logActionsTo(StringRef filename) {
+    logActionsToFlag = filename;
+    return *this;
+  }
+  /// Get the filename to use for logging actions.
+  StringRef getLogActionsTo() const { return logActionsToFlag; }
+
   /// Set the callback to populate the pass manager.
   MlirOptMainConfig &
   setPassPipelineSetupFn(std::function<LogicalResult(PassManager &)> callback) {
@@ -149,6 +157,9 @@
   /// Emit bytecode instead of textual assembly when generating output.
   bool emitBytecodeFlag = false;
 
+  /// Log action execution to the given file (or "-" for stdout).
+  std::string logActionsToFlag;
+
   /// The callback to populate the pass manager.
   std::function<LogicalResult(PassManager &)> passPipelineCallback;
 
diff --git a/mlir/lib/Debug/CMakeLists.txt b/mlir/lib/Debug/CMakeLists.txt
--- a/mlir/lib/Debug/CMakeLists.txt
+++ b/mlir/lib/Debug/CMakeLists.txt
@@ -1,3 +1,5 @@
+add_subdirectory(Observers)
+
 add_mlir_library(MLIRDebug
   DebugCounter.cpp
   ExecutionContext.cpp
diff --git a/mlir/lib/Debug/DebugCounter.cpp b/mlir/lib/Debug/DebugCounter.cpp
--- a/mlir/lib/Debug/DebugCounter.cpp
+++ b/mlir/lib/Debug/DebugCounter.cpp
@@ -116,6 +116,11 @@
 #endif
 }
 
+bool DebugCounter::isDebugCounterActivated() {
+  return clOptions->counters.getNumOccurrences() ||
+         clOptions->printCounterInfo.getNumOccurrences();
+}
+
 // This is called by the command line parser when it sees a value for the
 // debug-counter option defined above.
 void DebugCounter::applyCLOptions() {
diff --git a/mlir/lib/Debug/Observers/ActionLogging.cpp b/mlir/lib/Debug/Observers/ActionLogging.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Debug/Observers/ActionLogging.cpp
@@ -0,0 +1,49 @@
+//===- ActionLogging.cpp - Logging Actions --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Debug/Observers/ActionLogging.h"
+#include <atomic>
+#include <cstdint>
+
+using namespace mlir;
+using namespace mlir::tracing;
+
+//===----------------------------------------------------------------------===//
+// ActionLogger
+//===----------------------------------------------------------------------===//
+
+/// Return a per-thread id: 0, 1, 2, ... in the order threads first call this.
+static int64_t getThreadId() {
+  static std::atomic<int64_t> threadCounter(0);
+  thread_local int64_t tid = threadCounter++;
+  return tid;
+}
+
+void ActionLogger::beforeExecute(const ActionActiveStack *action,
+                                 Breakpoint *breakpoint, bool willExecute) {
+  os << "[thread " << getThreadId() << "] "
+     << (willExecute ? "begins " : "skipping ");
+  if (printBreakpoints) {
+    if (breakpoint)
+      os << "(on breakpoint: " << *breakpoint << ") ";
+    else
+      os << "(no breakpoint) ";
+  }
+  os << "Action ";
+  if (printActions)
+    action->getAction().print(os);
+  else
+    os << action->getAction().getTag();
+  // Terminate the line here rather than relying on Action::print to do it.
+  os << "\n";
+}
+
+void ActionLogger::afterExecute(const ActionActiveStack *action) {
+  os << "[thread " << getThreadId() << "] completed `"
+     << action->getAction().getTag() << "`\n";
+}
diff --git a/mlir/lib/Debug/Observers/CMakeLists.txt b/mlir/lib/Debug/Observers/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Debug/Observers/CMakeLists.txt
@@ -0,0 +1,10 @@
+add_mlir_library(MLIRObservers
+  ActionLogging.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Debug/Observers
+
+  LINK_LIBS PUBLIC
+  ${LLVM_PTHREAD_LIB}
+  MLIRSupport
+)
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -463,12 +463,19 @@
   if (pi)
     pi->runBeforePass(pass, op);
 
-  // Invoke the virtual runOnOperation method.
-  if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
-    adaptor->runOnOperation(verifyPasses);
-  else
-    pass->runOnOperation();
-  bool passFailed = pass->passState->irAndPassFailed.getInt();
+  // The action handler may decide not to run the pass at all (e.g. a debug
+  // counter skipping it), in which case the pass is treated as successful.
+  bool passFailed = false;
+  op->getContext()->dispatch<PassExecutionAction>(
+      [&]() {
+        // Invoke the virtual runOnOperation method.
+        if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
+          adaptor->runOnOperation(verifyPasses);
+        else
+          pass->runOnOperation();
+        passFailed = pass->passState->irAndPassFailed.getInt();
+      },
+      *pass, op);
 
   // Invalidate any non preserved analyses.
   am.invalidate(pass->passState->preservedAnalyses);
diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h
--- a/mlir/lib/Pass/PassDetail.h
+++ b/mlir/lib/Pass/PassDetail.h
@@ -8,10 +8,29 @@
 #ifndef MLIR_PASS_PASSDETAIL_H_
 #define MLIR_PASS_PASSDETAIL_H_
 
+#include "mlir/IR/Action.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 
 namespace mlir {
+/// Encapsulate the "action" of executing a single pass on a single operation,
+/// used for the MLIR tracing infrastructure.
+struct PassExecutionAction : public tracing::ActionImpl<PassExecutionAction> {
+  PassExecutionAction(const Pass &pass, Operation *op) : pass(pass), op(op) {}
+  static constexpr StringLiteral tag = "pass-execution-action";
+  /// Print a single-line description of the action, without a trailing
+  /// newline: the caller decides how to terminate the line.
+  void print(raw_ostream &os) const override {
+    os << "`" << getTag() << "` running \"" << pass.getName()
+       << "\" on Operation \"" << op->getName() << "\"";
+  }
+
+  /// The pass being executed.
+  const Pass &pass;
+  /// The operation the pass is executed on.
+  Operation *op;
+};
+
 namespace detail {
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/mlir-opt/CMakeLists.txt b/mlir/lib/Tools/mlir-opt/CMakeLists.txt
--- a/mlir/lib/Tools/mlir-opt/CMakeLists.txt
+++ b/mlir/lib/Tools/mlir-opt/CMakeLists.txt
@@ -7,6 +7,7 @@
   LINK_LIBS PUBLIC
   MLIRBytecodeWriter
   MLIRDebug
+  MLIRObservers
   MLIRPass
   MLIRParser
   MLIRSupport
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -14,6 +14,8 @@
 #include "mlir/Tools/mlir-opt/MlirOptMain.h"
 #include "mlir/Bytecode/BytecodeWriter.h"
 #include "mlir/Debug/Counter.h"
+#include "mlir/Debug/ExecutionContext.h"
+#include "mlir/Debug/Observers/ActionLogging.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -37,6 +39,8 @@
 #include "llvm/Support/StringSaver.h"
 #include "llvm/Support/ThreadPool.h"
 #include "llvm/Support/ToolOutputFile.h"
+#include "llvm/Support/raw_ostream.h"
+#include <memory>
 
 using namespace mlir;
 using namespace llvm;
@@ -65,6 +69,12 @@
                  "parsing"),
         cl::location(useExplicitModuleFlag), cl::init(false));
 
+    static cl::opt<std::string, /*ExternalStorage=*/true> logActionsTo{
+        "log-actions-to",
+        cl::desc("Log action execution to a file, or to stdout if "
+                 "'-' is passed"),
+        cl::location(logActionsToFlag)};
+
     static cl::opt<bool, /*ExternalStorage=*/true> showDialects(
         "show-dialects",
         cl::desc("Print the list of registered dialects and exit"),
@@ -120,6 +130,56 @@
   return *this;
 }
 
+namespace {
+/// Installs on an MLIRContext the action handler requested by `config`:
+///  - without `--log-actions-to`, the DebugCounter handler is installed, but
+///    only if a debug counter option was passed on the command line;
+///  - with `--log-actions-to`, an ExecutionContext is installed whose
+///    ActionLogger observer writes every action to the requested file, and
+///    debug counters are disabled (reported with an error diagnostic).
+/// The log file and the logger are owned by this object, so it must outlive
+/// every action dispatched on the context.
+class InstallDebugHandler {
+public:
+  InstallDebugHandler(MLIRContext &context, const MlirOptMainConfig &config) {
+    if (config.getLogActionsTo().empty()) {
+      if (tracing::DebugCounter::isDebugCounterActivated())
+        context.registerActionHandler(tracing::DebugCounter{});
+      return;
+    }
+    if (tracing::DebugCounter::isDebugCounterActivated())
+      emitError(UnknownLoc::get(&context),
+                "debug counters are incompatible with --log-actions-to option "
+                "and are disabled");
+    std::string errorMessage;
+    logActionsFile = openOutputFile(config.getLogActionsTo(), &errorMessage);
+    if (!logActionsFile) {
+      // Best effort: report the failure and continue without action logging.
+      llvm::errs() << errorMessage << "\n";
+      return;
+    }
+    logActionsFile->keep();
+    raw_fd_ostream &logActionsStream = logActionsFile->os();
+    actionLogger = std::make_unique<tracing::ActionLogger>(logActionsStream);
+
+    executionContext.registerObserver(actionLogger.get());
+    context.registerActionHandler(executionContext);
+  }
+
+  // `executionContext` is handed a raw pointer to `actionLogger`, so a copy
+  // of this object would observe a logger it does not own.
+  InstallDebugHandler(const InstallDebugHandler &) = delete;
+  InstallDebugHandler &operator=(const InstallDebugHandler &) = delete;
+
+private:
+  // Destroyed in reverse order: the execution context first, then the logger
+  // it observes, then the file the logger writes to.
+  std::unique_ptr<ToolOutputFile> logActionsFile;
+  std::unique_ptr<tracing::ActionLogger> actionLogger;
+  tracing::ExecutionContext executionContext;
+};
+} // namespace
+
 /// Perform the actions on the input file indicated by the command line flags
 /// within the specified context.
 ///
@@ -207,7 +267,8 @@
   context.allowUnregisteredDialects(config.shouldAllowUnregisteredDialects());
   if (config.shouldVerifyDiagnostics())
     context.printOpOnDiagnostic(false);
-  context.registerActionHandler(tracing::DebugCounter{});
+
+  InstallDebugHandler installDebugHandler{context, config};
 
   // If we are in verify diagnostics mode then we have a lot of work to do,
   // otherwise just perform the actions without worrying about it.
diff --git a/mlir/test/Pass/action-logging.mlir b/mlir/test/Pass/action-logging.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Pass/action-logging.mlir
@@ -0,0 +1,6 @@
+// RUN: mlir-opt %s --log-actions-to=- -canonicalize -test-module-pass | FileCheck %s
+
+// CHECK:      [thread 0] begins (no breakpoint) Action `pass-execution-action` running "Canonicalizer" on Operation "builtin.module"
+// CHECK-NEXT: [thread 0] completed `pass-execution-action`
+// CHECK-NEXT: [thread 0] begins (no breakpoint) Action `pass-execution-action` running "{{.*}}TestModulePass" on Operation "builtin.module"
+// CHECK-NEXT: [thread 0] completed `pass-execution-action`