diff --git a/mlir/include/mlir/Debug/Observers/ActionLogging.h b/mlir/include/mlir/Debug/Observers/ActionLogging.h
--- a/mlir/include/mlir/Debug/Observers/ActionLogging.h
+++ b/mlir/include/mlir/Debug/Observers/ActionLogging.h
@@ -22,9 +22,9 @@
 /// on the provided stream.
 struct ActionLogger : public ExecutionContext::Observer {
   ActionLogger(raw_ostream &os, bool printActions = true,
-               bool printBreakpoints = true)
-      : os(os), printActions(printActions), printBreakpoints(printBreakpoints) {
-  }
+               bool printBreakpoints = true, bool printIRUnits = true)
+      : os(os), printActions(printActions), printBreakpoints(printBreakpoints),
+        printIRUnits(printIRUnits) {}
 
   void beforeExecute(const ActionActiveStack *action, Breakpoint *breakpoint,
                      bool willExecute) override;
@@ -34,6 +34,8 @@
   raw_ostream &os;
   bool printActions;
   bool printBreakpoints;
+  /// Whether to print the IR units associated with each action.
+  bool printIRUnits;
 };
 
 } // namespace tracing
diff --git a/mlir/include/mlir/IR/Action.h b/mlir/include/mlir/IR/Action.h
--- a/mlir/include/mlir/IR/Action.h
+++ b/mlir/include/mlir/IR/Action.h
@@ -15,6 +15,7 @@
 #ifndef MLIR_IR_ACTION_H
 #define MLIR_IR_ACTION_H
 
+#include "mlir/IR/Unit.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/TypeID.h"
 #include "llvm/ADT/ArrayRef.h"
@@ -51,11 +52,20 @@
     os << "Action \"" << getTag() << "\"";
   }
 
+  /// Return the set of IR units that are associated with this action.
+  virtual ArrayRef<IRUnit> getContextIRUnits() const { return irUnits; }
+
 protected:
-  Action(TypeID actionID) : actionID(actionID) {}
+  Action(TypeID actionID, ArrayRef<IRUnit> irUnits)
+      : actionID(actionID), irUnits(irUnits) {}
 
   /// The type of the derived action class, used for `isa`/`dyn_cast`.
   TypeID actionID;
+
+  /// Set of IR units (operations, regions, blocks, values) that are associated
+  /// with this action. This is a non-owning reference: the referenced storage
+  /// must outlive the action.
+  ArrayRef<IRUnit> irUnits;
 };
 
 /// CRTP Implementation of an action. This class provides a base class for
@@ -67,7 +77,8 @@
 template <typename Derived>
 class ActionImpl : public Action {
 public:
-  ActionImpl() : Action(TypeID::get<Derived>()) {}
+  ActionImpl(ArrayRef<IRUnit> irUnits = {})
+      : Action(TypeID::get<Derived>(), irUnits) {}
 
   /// Provide classof to allow casting between action types.
   static bool classof(const Action *action) {
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/TypeID.h"
+#include "llvm/ADT/ArrayRef.h"
 #include <functional>
 #include <memory>
 #include <vector>
@@ -265,9 +266,10 @@
 
   /// Dispatch the provided action to the handler if any, or just execute it.
   template <typename ActionTy, typename... Args>
-  void executeAction(function_ref<void()> actionFn, Args &&...args) {
+  void executeAction(function_ref<void()> actionFn, ArrayRef<IRUnit> irUnits,
+                     Args &&...args) {
     if (LLVM_UNLIKELY(hasActionHandler()))
-      executeActionInternal<ActionTy, Args...>(actionFn,
+      executeActionInternal<ActionTy, Args...>(actionFn, irUnits,
                                                std::forward<Args>(args)...);
     else
       actionFn();
@@ -286,8 +288,10 @@
   /// avoid calling the ctor for the Action unnecessarily.
   template <typename ActionTy, typename... Args>
   LLVM_ATTRIBUTE_NOINLINE void
-  executeActionInternal(function_ref<void()> actionFn, Args &&...args) {
-    executeActionInternal(actionFn, ActionTy(std::forward<Args>(args)...));
+  executeActionInternal(function_ref<void()> actionFn, ArrayRef<IRUnit> irUnits,
+                        Args &&...args) {
+    executeActionInternal(actionFn,
+                          ActionTy(irUnits, std::forward<Args>(args)...));
   }
 
   const std::unique_ptr<MLIRContextImpl> impl;
diff --git a/mlir/include/mlir/IR/Unit.h b/mlir/include/mlir/IR/Unit.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/IR/Unit.h
@@ -0,0 +1,41 @@
+//===- Unit.h - IR Unit definition--------------------*- C++ -*-=============//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_UNIT_H
+#define MLIR_IR_UNIT_H
+
+#include "mlir/IR/OperationSupport.h"
+#include "llvm/ADT/PointerUnion.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+class Operation;
+class Region;
+class Block;
+class Value;
+
+/// IRUnit is a union of the different types of IR objects that constitute the
+/// IR structure (other than Type and Attribute), that is Operation, Region,
+/// Block, and Value.
+class IRUnit : public PointerUnion<Operation *, Region *, Block *, Value> {
+public:
+  using PointerUnion<Operation *, Region *, Block *, Value>::PointerUnion;
+
+  /// Print the IRUnit to the given stream. By default regions are skipped and
+  /// the printer uses a local scope (see OpPrintingFlags).
+  void print(raw_ostream &os,
+             OpPrintingFlags flags =
+                 OpPrintingFlags().skipRegions().useLocalScope()) const;
+};
+
+/// Print the IRUnit to the given stream with the default printing flags.
+raw_ostream &operator<<(raw_ostream &os, const IRUnit &unit);
+
+} // namespace mlir
+
+#endif // MLIR_IR_UNIT_H
diff --git a/mlir/lib/Debug/Observers/ActionLogging.cpp b/mlir/lib/Debug/Observers/ActionLogging.cpp
--- a/mlir/lib/Debug/Observers/ActionLogging.cpp
+++ b/mlir/lib/Debug/Observers/ActionLogging.cpp
@@ -7,9 +7,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Debug/Observers/ActionLogging.h"
+#include "mlir/IR/Action.h"
 #include "llvm/Support/Threading.h"
-#include <sstream>
-#include <thread>
+#include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
 using namespace mlir::tracing;
@@ -22,6 +22,11 @@
                                  Breakpoint *breakpoint, bool willExecute) {
   SmallVector<char> name;
   llvm::get_thread_name(name);
+  // Unnamed threads fall back to their numeric thread id.
+  if (name.empty()) {
+    llvm::raw_svector_ostream nameOS(name);
+    nameOS << llvm::get_threadid();
+  }
   os << "[thread " << name << "] ";
   if (willExecute)
     os << "begins ";
@@ -29,21 +34,35 @@
     os << "skipping ";
   if (printBreakpoints) {
     if (breakpoint)
-      os << " (on breakpoint: " << *breakpoint << ") ";
+      os << "(on breakpoint: " << *breakpoint << ") ";
     else
-      os << " (no breakpoint) ";
+      os << "(no breakpoint) ";
   }
   os << "Action ";
   if (printActions)
     action->getAction().print(os);
   else
     os << action->getAction().getTag();
+  if (printIRUnits) {
+    // Actions without IR units get no empty "()" suffix.
+    ArrayRef<IRUnit> irUnits = action->getAction().getContextIRUnits();
+    if (!irUnits.empty()) {
+      os << " (";
+      interleaveComma(irUnits, os);
+      os << ")";
+    }
+  }
   os << "`\n";
 }
 
 void ActionLogger::afterExecute(const ActionActiveStack *action) {
   SmallVector<char> name;
   llvm::get_thread_name(name);
+  // Unnamed threads fall back to their numeric thread id.
+  if (name.empty()) {
+    llvm::raw_svector_ostream nameOS(name);
+    nameOS << llvm::get_threadid();
+  }
   os << "[thread " << name << "] completed `"
      << action->getAction().getTag() << "`\n";
 }
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -32,6 +32,7 @@
   Types.cpp
   TypeRange.cpp
   TypeUtilities.cpp
+  Unit.cpp
   Value.cpp
   ValueRange.cpp
   Verifier.cpp
diff --git a/mlir/lib/IR/Unit.cpp b/mlir/lib/IR/Unit.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/IR/Unit.cpp
@@ -0,0 +1,103 @@
+//===- Unit.cpp - Support for manipulating IR Unit ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Unit.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/Region.h"
+#include "mlir/IR/Value.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include <iterator>
+
+using namespace mlir;
+
+/// Print `op` using `flags`, or a placeholder if `op` is null.
+static void printOp(llvm::raw_ostream &os, Operation *op,
+                    OpPrintingFlags &flags) {
+  if (!op) {
+    os << "<Operation:nullptr>";
+    return;
+  }
+  op->print(os, flags);
+}
+
+/// Print "Region #<index> for op <parent op>", or a placeholder if `region` is
+/// null or not attached to an operation.
+static void printRegion(llvm::raw_ostream &os, Region *region,
+                        OpPrintingFlags &flags) {
+  if (!region) {
+    os << "<Region:nullptr>";
+    return;
+  }
+  // Region numbers are only meaningful relative to a parent operation, so a
+  // detached region is reported without one.
+  Operation *parentOp = region->getParentOp();
+  if (!parentOp) {
+    os << "<Region:detached>";
+    return;
+  }
+  os << "Region #" << region->getRegionNumber() << " for op ";
+  printOp(os, parentOp, flags);
+}
+
+/// Print "Block #<index> for <parent region>", followed by the block itself
+/// unless `flags` asks for regions to be skipped.
+static void printBlock(llvm::raw_ostream &os, Block *block,
+                       OpPrintingFlags &flags) {
+  if (!block) {
+    os << "<Block:nullptr>";
+    return;
+  }
+  // Read this before `flags.skipRegions()` below mutates `flags`.
+  bool shouldSkipRegions = flags.shouldSkipRegions();
+  Region *region = block->getParent();
+  if (!region) {
+    // A block that is not in a region has no index to report.
+    os << "<Block:detached>";
+  } else {
+    Block *entry = &region->front();
+    auto blockId = std::distance(entry->getIterator(), block->getIterator());
+    os << "Block #" << blockId << " for ";
+    // Only the enclosing op's header is printed here; the block body is
+    // printed separately below.
+    printRegion(os, region, flags.skipRegions());
+  }
+  if (!shouldSkipRegions)
+    block->print(os);
+}
+
+/// Print `value` using `flags`, or a placeholder if `value` is null.
+static void printValue(llvm::raw_ostream &os, Value value,
+                       OpPrintingFlags &flags) {
+  if (!value) {
+    os << "<Value:nullptr>";
+    return;
+  }
+  value.print(os, flags);
+}
+
+void mlir::IRUnit::print(llvm::raw_ostream &os, OpPrintingFlags flags) const {
+  // Dispatch on the active member rather than on the pointer being non-null,
+  // so that a null member reaches its placeholder instead of falling through
+  // to llvm_unreachable.
+  if (this->is<Operation *>())
+    return printOp(os, this->get<Operation *>(), flags);
+  if (this->is<Region *>())
+    return printRegion(os, this->get<Region *>(), flags);
+  if (this->is<Block *>())
+    return printBlock(os, this->get<Block *>(), flags);
+  if (this->is<Value>())
+    return printValue(os, this->get<Value>(), flags);
+  llvm_unreachable("unknown IRUnit");
+}
+
+llvm::raw_ostream &mlir::operator<<(llvm::raw_ostream &os, const IRUnit &unit) {
+  unit.print(os);
+  return os;
+}
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -482,7 +482,7 @@
           pass->runOnOperation();
         passFailed = pass->passState->irAndPassFailed.getInt();
       },
-      *pass, op);
+      {op}, *pass);
 
   // Invalidate any non preserved analyses.
   am.invalidate(pass->passState->preservedAnalyses);
diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h
--- a/mlir/lib/Pass/PassDetail.h
+++ b/mlir/lib/Pass/PassDetail.h
@@ -11,17 +11,25 @@
 #include "mlir/IR/Action.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/FormatVariadic.h"
 
 namespace mlir {
 /// Encapsulate the "action" of executing a single pass, used for the MLIR
 /// tracing infrastructure.
 struct PassExecutionAction : public tracing::ActionImpl<PassExecutionAction> {
-  PassExecutionAction(const Pass &pass, Operation *op) : pass(pass), op(op) {}
+  using Base = tracing::ActionImpl<PassExecutionAction>;
+  PassExecutionAction(ArrayRef<IRUnit> irUnits, const Pass &pass)
+      : Base(irUnits), pass(pass) {}
   static constexpr StringLiteral tag = "pass-execution-action";
   void print(raw_ostream &os) const override;
   const Pass &getPass() const { return pass; }
-  Operation *getOp() const { return op; }
+  /// Return the operation the pass runs on: the first IR unit attached to this
+  /// action, or nullptr if there is none or it is not an Operation.
+  Operation *getOp() const {
+    ArrayRef<IRUnit> irUnits = getContextIRUnits();
+    return irUnits.empty() ? nullptr : irUnits[0].dyn_cast<Operation *>();
+  }
 
 public:
   const Pass &pass;
diff --git a/mlir/test/Pass/action-logging.mlir b/mlir/test/Pass/action-logging.mlir
--- a/mlir/test/Pass/action-logging.mlir
+++ b/mlir/test/Pass/action-logging.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt %s --log-actions-to=- -canonicalize -test-module-pass | FileCheck %s
 
-// CHECK: [thread {{.*}}] begins (no breakpoint) Action `pass-execution-action` running `Canonicalizer` on Operation `builtin.module`
-// CHECK: [thread {{.*}}] completed `pass-execution-action`
-// CHECK: [thread {{.*}}] begins (no breakpoint) Action `pass-execution-action` running `(anonymous namespace)::TestModulePass` on Operation `builtin.module`
-// CHECK: [thread {{.*}}] completed `pass-execution-action`
+// CHECK: [thread {{.*}}] begins (no breakpoint) Action `pass-execution-action` running `Canonicalizer` on Operation `builtin.module` (module {...}
+// CHECK-NEXT: [thread {{.*}}] completed `pass-execution-action`
+// CHECK-NEXT: [thread {{.*}}] begins (no breakpoint) Action `pass-execution-action` running `{{.*}}TestModulePass` on Operation `builtin.module` (module {...}
+// CHECK-NEXT: [thread {{.*}}] completed `pass-execution-action`
+// CHECK-NOT: Action