diff --git a/mlir/include/mlir/Debug/Counter.h b/mlir/include/mlir/Debug/Counter.h
--- a/mlir/include/mlir/Debug/Counter.h
+++ b/mlir/include/mlir/Debug/Counter.h
@@ -47,6 +47,8 @@
 
   /// Register the command line options for debug counters.
   static void registerCLOptions();
+  /// Returns true if any of the CL options are activated.
+  static bool isActivated();
 
 private:
   // Returns true if the next action matching this tag should be executed.
diff --git a/mlir/include/mlir/Debug/Observers/ActionLogging.h b/mlir/include/mlir/Debug/Observers/ActionLogging.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Debug/Observers/ActionLogging.h
@@ -0,0 +1,42 @@
+//===- ActionLogging.h - Logging Actions ------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DEBUG_OBSERVERS_ACTIONLOGGING_H
+#define MLIR_DEBUG_OBSERVERS_ACTIONLOGGING_H
+
+#include "mlir/Debug/BreakpointManager.h"
+#include "mlir/Debug/ExecutionContext.h"
+#include "mlir/IR/Action.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace tracing {
+
+/// This class defines an observer that prints Actions before and after
+/// execution on the provided stream, one line per event.
+struct ActionLogger : public ExecutionContext::Observer {
+  explicit ActionLogger(raw_ostream &os, bool printActions = true,
+                        bool printBreakpoints = true)
+      : os(os), printActions(printActions), printBreakpoints(printBreakpoints) {
+  }
+
+  void beforeExecute(const ActionActiveStack *action, Breakpoint *breakpoint,
+                     bool willExecute) override;
+  void afterExecute(const ActionActiveStack *action) override;
+
+private:
+  raw_ostream &os;
+  bool printActions;
+  bool printBreakpoints;
+};
+
+} // namespace tracing
+} // namespace mlir
+
+#endif // MLIR_DEBUG_OBSERVERS_ACTIONLOGGING_H
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -74,6 +74,14 @@
   }
   bool shouldEmitBytecode() const { return emitBytecodeFlag; }
 
+  /// Set the filename to use for logging actions, use "-" for stdout.
+  MlirOptMainConfig &logActionsTo(StringRef filename) {
+    logActionsToFlag = filename;
+    return *this;
+  }
+  /// Get the filename to use for logging actions.
+  StringRef getLogActionsTo() const { return logActionsToFlag; }
+
   /// Set the callback to populate the pass manager.
   MlirOptMainConfig &
   setPassPipelineSetupFn(std::function<LogicalResult(PassManager &)> callback) {
@@ -149,6 +157,9 @@
   /// Emit bytecode instead of textual assembly when generating output.
   bool emitBytecodeFlag = false;
 
+  /// Log action execution to the given file (or "-" for stdout).
+  std::string logActionsToFlag;
+
   /// The callback to populate the pass manager.
   std::function<LogicalResult(PassManager &)> passPipelineCallback;
 
diff --git a/mlir/lib/Debug/CMakeLists.txt b/mlir/lib/Debug/CMakeLists.txt
--- a/mlir/lib/Debug/CMakeLists.txt
+++ b/mlir/lib/Debug/CMakeLists.txt
@@ -1,3 +1,5 @@
+add_subdirectory(Observers)
+
 add_mlir_library(MLIRDebug
   DebugCounter.cpp
   ExecutionContext.cpp
diff --git a/mlir/lib/Debug/DebugCounter.cpp b/mlir/lib/Debug/DebugCounter.cpp
--- a/mlir/lib/Debug/DebugCounter.cpp
+++ b/mlir/lib/Debug/DebugCounter.cpp
@@ -116,6 +116,11 @@
 #endif
 }
 
+bool DebugCounter::isActivated() {
+  return clOptions->counters.getNumOccurrences() ||
+         clOptions->printCounterInfo.getNumOccurrences();
+}
+
 // This is called by the command line parser when it sees a value for the
 // debug-counter option defined above.
 void DebugCounter::applyCLOptions() {
diff --git a/mlir/lib/Debug/Observers/ActionLogging.cpp b/mlir/lib/Debug/Observers/ActionLogging.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Debug/Observers/ActionLogging.cpp
@@ -0,0 +1,57 @@
+//===- ActionLogging.cpp - Logging Actions --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Debug/Observers/ActionLogging.h"
+#include "llvm/Support/Threading.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::tracing;
+
+/// Print a "[thread <name>] " prefix identifying the calling thread on `os`.
+/// When the platform reports no name for the thread, the numeric thread id is
+/// printed instead so that interleaved output from different threads stays
+/// distinguishable.
+static void printThreadPrefix(raw_ostream &os) {
+  SmallVector<char> name;
+  llvm::get_thread_name(name);
+  if (name.empty()) {
+    llvm::raw_svector_ostream idStream(name);
+    idStream << llvm::get_threadid();
+  }
+  os << "[thread " << name << "] ";
+}
+
+//===----------------------------------------------------------------------===//
+// ActionLogger
+//===----------------------------------------------------------------------===//
+
+void ActionLogger::beforeExecute(const ActionActiveStack *action,
+                                 Breakpoint *breakpoint, bool willExecute) {
+  printThreadPrefix(os);
+  os << (willExecute ? "begins " : "skipping ");
+  if (printBreakpoints) {
+    if (breakpoint)
+      os << "(on breakpoint: " << *breakpoint << ") ";
+    else
+      os << "(no breakpoint) ";
+  }
+  os << "Action ";
+  if (printActions)
+    action->getAction().print(os);
+  else
+    os << action->getAction().getTag();
+  // Terminate the entry here rather than relying on Action::print(), so that
+  // every log entry is exactly one line regardless of `printActions`.
+  os << "\n";
+}
+
+void ActionLogger::afterExecute(const ActionActiveStack *action) {
+  printThreadPrefix(os);
+  os << "completed `" << action->getAction().getTag() << "`\n";
+}
diff --git a/mlir/lib/Debug/Observers/CMakeLists.txt b/mlir/lib/Debug/Observers/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Debug/Observers/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_library(MLIRObservers
+  ActionLogging.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Debug/Observers
+
+  LINK_LIBS PUBLIC
+  ${LLVM_PTHREAD_LIB}
+  MLIRDebug
+  MLIRIR
+  MLIRSupport
+)
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -31,6 +31,15 @@
 using namespace mlir;
 using namespace mlir::detail;
 
+//===----------------------------------------------------------------------===//
+// PassExecutionAction
+//===----------------------------------------------------------------------===//
+
+void PassExecutionAction::print(raw_ostream &os) const {
+  os << llvm::formatv("`{0}` running `{1}` on Operation `{2}`", tag,
+                      pass.getName(), op->getName());
+}
+
 //===----------------------------------------------------------------------===//
 // Pass
 //===----------------------------------------------------------------------===//
@@ -463,12 +472,17 @@
   if (pi)
     pi->runBeforePass(pass, op);
 
-  // Invoke the virtual runOnOperation method.
-  if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
-    adaptor->runOnOperation(verifyPasses);
-  else
-    pass->runOnOperation();
-  bool passFailed = pass->passState->irAndPassFailed.getInt();
+  bool passFailed = false;
+  op->getContext()->executeAction<PassExecutionAction>(
+      [&]() {
+        // Invoke the virtual runOnOperation method.
+        if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
+          adaptor->runOnOperation(verifyPasses);
+        else
+          pass->runOnOperation();
+        passFailed = pass->passState->irAndPassFailed.getInt();
+      },
+      *pass, op);
 
   // Invalidate any non preserved analyses.
   am.invalidate(pass->passState->preservedAnalyses);
diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h
--- a/mlir/lib/Pass/PassDetail.h
+++ b/mlir/lib/Pass/PassDetail.h
@@ -8,10 +8,26 @@
 #ifndef MLIR_PASS_PASSDETAIL_H_
 #define MLIR_PASS_PASSDETAIL_H_
 
+#include "mlir/IR/Action.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
+#include "llvm/Support/FormatVariadic.h"
 
 namespace mlir {
+/// Encapsulate the "action" of executing a single pass, used for the MLIR
+/// tracing infrastructure.
+struct PassExecutionAction : public tracing::ActionImpl<PassExecutionAction> {
+  PassExecutionAction(const Pass &pass, Operation *op) : pass(pass), op(op) {}
+  static constexpr StringLiteral tag = "pass-execution-action";
+  void print(raw_ostream &os) const override;
+  const Pass &getPass() const { return pass; }
+  Operation *getOp() const { return op; }
+
+  /// The pass being executed and the operation it is executed on.
+  const Pass &pass;
+  Operation *op;
+};
+
 namespace detail {
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/mlir-opt/CMakeLists.txt b/mlir/lib/Tools/mlir-opt/CMakeLists.txt
--- a/mlir/lib/Tools/mlir-opt/CMakeLists.txt
+++ b/mlir/lib/Tools/mlir-opt/CMakeLists.txt
@@ -7,6 +7,7 @@
   LINK_LIBS PUBLIC
   MLIRBytecodeWriter
   MLIRDebug
+  MLIRObservers
   MLIRPass
   MLIRParser
   MLIRSupport
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -14,6 +14,8 @@
 #include "mlir/Tools/mlir-opt/MlirOptMain.h"
 #include "mlir/Bytecode/BytecodeWriter.h"
 #include "mlir/Debug/Counter.h"
+#include "mlir/Debug/ExecutionContext.h"
+#include "mlir/Debug/Observers/ActionLogging.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -71,6 +73,12 @@
                  "parsing"),
         cl::location(useExplicitModuleFlag), cl::init(false));
 
+    static cl::opt<std::string, /*ExternalStorage=*/true> logActionsTo{
+        "log-actions-to",
+        cl::desc("Log action execution to a file, or stdout if "
+                 "'-' is passed"),
+        cl::location(logActionsToFlag)};
+
     static cl::opt<bool, /*ExternalStorage=*/true> showDialects(
         "show-dialects",
         cl::desc("Print the list of registered dialects and exit"),
@@ -126,6 +134,41 @@
   return *this;
 }
 
+/// Set the ExecutionContext on the context and handle the observers.
+class InstallDebugHandler {
+public:
+  InstallDebugHandler(MLIRContext &context, const MlirOptMainConfig &config) {
+    if (config.getLogActionsTo().empty()) {
+      if (tracing::DebugCounter::isActivated())
+        context.registerActionHandler(tracing::DebugCounter());
+      return;
+    }
+    if (tracing::DebugCounter::isActivated())
+      emitError(UnknownLoc::get(&context),
+                "Debug counters are incompatible with --log-actions-to option "
+                "and are disabled");
+    std::string errorMessage;
+    logActionsFile = openOutputFile(config.getLogActionsTo(), &errorMessage);
+    if (!logActionsFile) {
+      emitError(UnknownLoc::get(&context),
+                "Opening file for --log-actions-to failed: ")
+          << errorMessage;
+      return;
+    }
+    logActionsFile->keep();
+    raw_fd_ostream &logActionsStream = logActionsFile->os();
+    actionLogger = std::make_unique<tracing::ActionLogger>(logActionsStream);
+
+    executionContext.registerObserver(actionLogger.get());
+    context.registerActionHandler(executionContext);
+  }
+
+private:
+  std::unique_ptr<ToolOutputFile> logActionsFile;
+  std::unique_ptr<tracing::ActionLogger> actionLogger;
+  tracing::ExecutionContext executionContext;
+};
+
 /// Perform the actions on the input file indicated by the command line flags
 /// within the specified context.
 ///
@@ -213,7 +256,8 @@
   context.allowUnregisteredDialects(config.shouldAllowUnregisteredDialects());
   if (config.shouldVerifyDiagnostics())
     context.printOpOnDiagnostic(false);
-  context.registerActionHandler(tracing::DebugCounter());
+
+  InstallDebugHandler installDebugHandler(context, config);
 
   // If we are in verify diagnostics mode then we have a lot of work to do,
   // otherwise just perform the actions without worrying about it.
diff --git a/mlir/test/Pass/action-logging.mlir b/mlir/test/Pass/action-logging.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Pass/action-logging.mlir
@@ -0,0 +1,6 @@
+// RUN: mlir-opt %s --log-actions-to=- -canonicalize -test-module-pass | FileCheck %s
+
+// CHECK: [thread {{.*}}] begins (no breakpoint) Action `pass-execution-action` running `Canonicalizer` on Operation `builtin.module`
+// CHECK: [thread {{.*}}] completed `pass-execution-action`
+// CHECK: [thread {{.*}}] begins (no breakpoint) Action `pass-execution-action` running `{{.*}}TestModulePass` on Operation `builtin.module`
+// CHECK: [thread {{.*}}] completed `pass-execution-action`