[mlir] Add IRUnit and record the IR units an Action is associated with

This patch:
  * adds `mlir::IRUnit`, a PointerUnion of Operation*, Region*, Block* and
    Value, whose `print()` elides regions and uses a local scope by default;
  * lets `tracing::Action` carry a non-owning `ArrayRef<IRUnit>` and threads
    it through `MLIRContext::executeAction`;
  * adds `OpPrintingFlags::skipRegions()`;
  * makes `ActionLogger` print the IR units of each action (`printIRUnits`),
    always terminating each "begins/skipping" record with a newline.

NOTE(review): MLIRContext.h spells `ArrayRef<IRUnit>` but this patch does not
add a declaration of `IRUnit` to that header; confirm one (e.g. a forward
declaration `class IRUnit;` in namespace mlir) is visible there.
NOTE(review): standard-library header names appearing in context/removed lines
were reconstructed; verify them against the tree before applying.

diff --git a/mlir/include/mlir/Debug/Observers/ActionLogging.h b/mlir/include/mlir/Debug/Observers/ActionLogging.h
--- a/mlir/include/mlir/Debug/Observers/ActionLogging.h
+++ b/mlir/include/mlir/Debug/Observers/ActionLogging.h
@@ -22,9 +22,9 @@
 /// on the provided stream.
 struct ActionLogger : public ExecutionContext::Observer {
   ActionLogger(raw_ostream &os, bool printActions = true,
-               bool printBreakpoints = true)
-      : os(os), printActions(printActions), printBreakpoints(printBreakpoints) {
-  }
+               bool printBreakpoints = true, bool printIRUnits = true)
+      : os(os), printActions(printActions), printBreakpoints(printBreakpoints),
+        printIRUnits(printIRUnits) {}
 
   void beforeExecute(const ActionActiveStack *action, Breakpoint *breakpoint,
                      bool willExecute) override;
@@ -34,6 +34,7 @@
   raw_ostream &os;
   bool printActions;
   bool printBreakpoints;
+  bool printIRUnits;
 };
 
 } // namespace tracing
diff --git a/mlir/include/mlir/IR/Action.h b/mlir/include/mlir/IR/Action.h
--- a/mlir/include/mlir/IR/Action.h
+++ b/mlir/include/mlir/IR/Action.h
@@ -15,6 +15,7 @@
 #ifndef MLIR_IR_ACTION_H
 #define MLIR_IR_ACTION_H
 
+#include "mlir/IR/Unit.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/TypeID.h"
 #include "llvm/ADT/ArrayRef.h"
@@ -51,11 +52,19 @@
     os << "Action \"" << getTag() << "\"";
   }
 
+  /// Return the set of IR units that are associated with this action.
+  virtual ArrayRef<IRUnit> getContextIRUnits() const { return irUnits; }
+
 protected:
-  Action(TypeID actionID) : actionID(actionID) {}
+  Action(TypeID actionID, ArrayRef<IRUnit> irUnits)
+      : actionID(actionID), irUnits(irUnits) {}
 
   /// The type of the derived action class, used for `isa`/`dyn_cast`.
   TypeID actionID;
+
+  /// Non-owning view of the IR units (operations, regions, blocks, values)
+  /// associated with this action; the referenced storage must outlive it.
+  ArrayRef<IRUnit> irUnits;
 };
 
 /// CRTP Implementation of an action. This class provides a base class for
@@ -67,7 +76,8 @@
 template <typename Derived>
 class ActionImpl : public Action {
 public:
-  ActionImpl() : Action(TypeID::get<Derived>()) {}
+  ActionImpl(ArrayRef<IRUnit> irUnits = {})
+      : Action(TypeID::get<Derived>(), irUnits) {}
 
   /// Provide classof to allow casting between action types.
   static bool classof(const Action *action) {
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/TypeID.h"
+#include "llvm/ADT/ArrayRef.h"
 #include <functional>
 #include <memory>
 #include <vector>
@@ -265,9 +266,10 @@
 
   /// Dispatch the provided action to the handler if any, or just execute it.
   template <typename ActionTy, typename... Args>
-  void executeAction(function_ref<void()> actionFn, Args &&...args) {
+  void executeAction(function_ref<void()> actionFn, ArrayRef<IRUnit> irUnits,
+                     Args &&...args) {
     if (LLVM_UNLIKELY(hasActionHandler()))
-      executeActionInternal<ActionTy, Args...>(actionFn,
+      executeActionInternal<ActionTy, Args...>(actionFn, irUnits,
                                                std::forward<Args>(args)...);
     else
       actionFn();
@@ -286,8 +288,10 @@
   /// avoid calling the ctor for the Action unnecessarily.
   template <typename ActionTy, typename... Args>
   LLVM_ATTRIBUTE_NOINLINE void
-  executeActionInternal(function_ref<void()> actionFn, Args &&...args) {
-    executeActionInternal(actionFn, ActionTy(std::forward<Args>(args)...));
+  executeActionInternal(function_ref<void()> actionFn, ArrayRef<IRUnit> irUnits,
+                        Args &&...args) {
+    executeActionInternal(actionFn,
+                          ActionTy(irUnits, std::forward<Args>(args)...));
   }
 
   const std::unique_ptr<MLIRContextImpl> impl;
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -834,6 +834,9 @@
   /// Always print operations in the generic form.
   OpPrintingFlags &printGenericOpForm();
 
+  /// Skip printing regions.
+  OpPrintingFlags &skipRegions();
+
   /// Do not verify the operation when using custom operation printers.
   OpPrintingFlags &assumeVerified();
 
@@ -861,6 +864,9 @@
   /// Return if operations should be printed in the generic form.
   bool shouldPrintGenericOpForm() const;
 
+  /// Return if regions should be skipped.
+  bool shouldSkipRegions() const;
+
   /// Return if operation verification should be skipped.
   bool shouldAssumeVerified() const;
 
@@ -882,6 +888,9 @@
   /// Print operations in the generic form.
   bool printGenericOpFormFlag : 1;
 
+  /// Always skip Regions.
+  bool skipRegionsFlag : 1;
+
   /// Skip operation verification.
   bool assumeVerifiedFlag : 1;
 
diff --git a/mlir/include/mlir/IR/Unit.h b/mlir/include/mlir/IR/Unit.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/IR/Unit.h
@@ -0,0 +1,43 @@
+//===- Unit.h - IR Unit definition--------------------*- C++ -*-=============//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_UNIT_H
+#define MLIR_IR_UNIT_H
+
+#include "mlir/IR/OperationSupport.h"
+#include "llvm/ADT/PointerUnion.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace llvm {
+class raw_ostream;
+} // namespace llvm
+namespace mlir {
+class Operation;
+class Region;
+class Block;
+class Value;
+
+/// IRUnit is a union of the different types of IR objects that constitute the
+/// IR structure (other than Type and Attribute), that is Operation, Region,
+/// Block, and Value.
+class IRUnit
+    : public llvm::PointerUnion<Operation *, Region *, Block *, Value> {
+public:
+  using PointerUnion::PointerUnion;
+
+  /// Print the IRUnit to the given stream; regions are elided by default.
+  void print(llvm::raw_ostream &os,
+             OpPrintingFlags flags =
+                 OpPrintingFlags().skipRegions().useLocalScope()) const;
+};
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const IRUnit &unit);
+
+} // end namespace mlir
+
+#endif // MLIR_IR_UNIT_H
diff --git a/mlir/lib/Debug/Observers/ActionLogging.cpp b/mlir/lib/Debug/Observers/ActionLogging.cpp
--- a/mlir/lib/Debug/Observers/ActionLogging.cpp
+++ b/mlir/lib/Debug/Observers/ActionLogging.cpp
@@ -7,9 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Debug/Observers/ActionLogging.h"
+#include "mlir/IR/Action.h"
 #include "llvm/Support/Threading.h"
-#include <sstream>
-#include <thread>
 
 using namespace mlir;
 using namespace mlir::tracing;
@@ -27,15 +26,23 @@
     os << "skipping ";
   if (printBreakpoints) {
     if (breakpoint)
-      os << " (on breakpoint: " << *breakpoint << ") ";
+      os << "(on breakpoint: " << *breakpoint << ") ";
     else
-      os << " (no breakpoint) ";
+      os << "(no breakpoint) ";
   }
   os << "Action ";
   if (printActions)
     action->getAction().print(os);
   else
     os << action->getAction().getTag();
+  ArrayRef<IRUnit> irUnits = action->getAction().getContextIRUnits();
+  if (printIRUnits && !irUnits.empty()) {
+    os << " (";
+    interleaveComma(irUnits, os);
+    os << ")";
+  }
+  // Terminate the log line whether or not any IR unit was printed.
+  os << "\n";
 }
 
 void ActionLogger::afterExecute(const ActionActiveStack *action) {
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -183,8 +183,9 @@
 /// Initialize the printing flags with default supplied by the cl::opts above.
 OpPrintingFlags::OpPrintingFlags()
     : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
-      printGenericOpFormFlag(false), assumeVerifiedFlag(false),
-      printLocalScope(false), printValueUsersFlag(false) {
+      printGenericOpFormFlag(false), skipRegionsFlag(false),
+      assumeVerifiedFlag(false), printLocalScope(false),
+      printValueUsersFlag(false) {
   // Initialize based upon command line options, if they are available.
   if (!clOptions.isConstructed())
     return;
@@ -223,6 +224,12 @@
   return *this;
 }
 
+/// Always skip Regions.
+OpPrintingFlags &OpPrintingFlags::skipRegions() {
+  skipRegionsFlag = true;
+  return *this;
+}
+
 /// Do not verify the operation when using custom operation printers.
 OpPrintingFlags &OpPrintingFlags::assumeVerified() {
   assumeVerifiedFlag = true;
@@ -270,6 +277,9 @@
   return printGenericOpFormFlag;
 }
 
+/// Return if Region should be skipped.
+bool OpPrintingFlags::shouldSkipRegions() const { return skipRegionsFlag; }
+
 /// Return if operation verification should be skipped.
 bool OpPrintingFlags::shouldAssumeVerified() const {
   return assumeVerifiedFlag;
@@ -614,9 +624,11 @@
   /// Print the given operation in the generic form.
   void printGenericOp(Operation *op, bool printOpName = true) override {
     // Consider nested operations for aliases.
-    for (Region &region : op->getRegions())
-      printRegion(region, /*printEntryBlockArgs=*/true,
-                  /*printBlockTerminators=*/true);
+    if (!printerFlags.shouldSkipRegions()) {
+      for (Region &region : op->getRegions())
+        printRegion(region, /*printEntryBlockArgs=*/true,
+                    /*printBlockTerminators=*/true);
+    }
 
     // Visit all the types used in the operation.
     for (Type type : op->getOperandTypes())
@@ -665,6 +677,10 @@
                    bool printEmptyBlock = false) override {
     if (region.empty())
       return;
+    if (printerFlags.shouldSkipRegions()) {
+      os << "{/*skip region*/}";
+      return;
+    }
 
     auto *entryBlock = &region.front();
     print(entryBlock, printEntryBlockArgs, printBlockTerminators);
@@ -3341,10 +3357,14 @@
   // Print regions.
   if (op->getNumRegions() != 0) {
     os << " (";
-    interleaveComma(op->getRegions(), [&](Region &region) {
-      printRegion(region, /*printEntryBlockArgs=*/true,
-                  /*printBlockTerminators=*/true, /*printEmptyBlock=*/true);
-    });
+    if (!printerFlags.shouldSkipRegions()) {
+      interleaveComma(op->getRegions(), [&](Region &region) {
+        printRegion(region, /*printEntryBlockArgs=*/true,
+                    /*printBlockTerminators=*/true, /*printEmptyBlock=*/true);
+      });
+    } else {
+      os << "/*skip " << op->getNumRegions() << " regions*/";
+    }
     os << ')';
   }
 
@@ -3463,6 +3483,10 @@
 void OperationPrinter::printRegion(Region &region, bool printEntryBlockArgs,
                                    bool printBlockTerminators,
                                    bool printEmptyBlock) {
+  if (printerFlags.shouldSkipRegions()) {
+    os << "{/*skip region*/}";
+    return;
+  }
   os << "{" << newLine;
   if (!region.empty()) {
     auto restoreDefaultDialect =
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -32,6 +32,7 @@
   Types.cpp
   TypeRange.cpp
   TypeUtilities.cpp
+  Unit.cpp
   Value.cpp
   ValueRange.cpp
   Verifier.cpp
diff --git a/mlir/lib/IR/Unit.cpp b/mlir/lib/IR/Unit.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/IR/Unit.cpp
@@ -0,0 +1,88 @@
+//===- Unit.cpp - Support for manipulating IR Unit ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Unit.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/Region.h"
+#include "llvm/Support/raw_ostream.h"
+#include <iterator>
+
+using namespace mlir;
+
+/// Print `op` using `flags`, or a placeholder if `op` is null.
+static void printOp(llvm::raw_ostream &os, Operation *op,
+                    OpPrintingFlags &flags) {
+  if (!op) {
+    os << "<Operation:nullptr>";
+    return;
+  }
+  op->print(os, flags);
+}
+
+/// Print the index of `region` within its parent op, followed by that op.
+static void printRegion(llvm::raw_ostream &os, Region *region,
+                        OpPrintingFlags &flags) {
+  if (!region) {
+    os << "<Region:nullptr>";
+    return;
+  }
+  os << "Region #" << region->getRegionNumber() << " for op ";
+  printOp(os, region->getParentOp(), flags);
+}
+
+/// Print the index of `block` within its parent region and that region (whose
+/// parent op is always printed with regions elided), then the block itself
+/// unless `flags` already requested regions to be skipped.
+static void printBlock(llvm::raw_ostream &os, Block *block,
+                       OpPrintingFlags &flags) {
+  if (!block) {
+    os << "<Block:nullptr>";
+    return;
+  }
+  // Record the caller's setting before `skipRegions()` below mutates `flags`.
+  bool shouldSkipRegions = flags.shouldSkipRegions();
+  if (Region *region = block->getParent()) {
+    auto blockId = std::distance(region->begin(), block->getIterator());
+    os << "Block #" << blockId << " for ";
+    printRegion(os, region, flags.skipRegions());
+  } else {
+    os << "<Block:detached>";
+  }
+  if (!shouldSkipRegions)
+    block->print(os);
+}
+
+void mlir::IRUnit::print(llvm::raw_ostream &os, OpPrintingFlags flags) const {
+  if (isNull()) {
+    os << "<IRUnit:nullptr>";
+    return;
+  }
+  if (auto *op = this->dyn_cast<Operation *>()) {
+    printOp(os, op, flags);
+    return;
+  }
+  if (auto *region = this->dyn_cast<Region *>()) {
+    printRegion(os, region, flags);
+    return;
+  }
+  if (auto *block = this->dyn_cast<Block *>()) {
+    printBlock(os, block, flags);
+    return;
+  }
+  if (Value value = this->dyn_cast<Value>()) {
+    value.print(os, flags);
+    return;
+  }
+  llvm_unreachable("unknown IRUnit");
+}
+
+llvm::raw_ostream &mlir::operator<<(llvm::raw_ostream &os, const IRUnit &unit) {
+  unit.print(os);
+  return os;
+}
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -473,7 +473,7 @@
           pass->runOnOperation();
         passFailed = pass->passState->irAndPassFailed.getInt();
       },
-      *pass, op);
+      {op}, *pass);
 
   // Invalidate any non preserved analyses.
   am.invalidate(pass->passState->preservedAnalyses);
diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h
--- a/mlir/lib/Pass/PassDetail.h
+++ b/mlir/lib/Pass/PassDetail.h
@@ -17,11 +17,19 @@
 /// Encapsulate the "action" of executing a single pass, used for the MLIR
 /// tracing infrastructure.
 struct PassExecutionAction : public tracing::ActionImpl<PassExecutionAction> {
-  PassExecutionAction(const Pass &pass, Operation *op) : pass(pass), op(op) {}
+  using Base = tracing::ActionImpl<PassExecutionAction>;
+  /// `op` caches the first IR unit if it is an operation (null otherwise).
+  PassExecutionAction(ArrayRef<IRUnit> irUnits, const Pass &pass)
+      : Base(irUnits), pass(pass),
+        op(irUnits.empty() ? nullptr
+                           : irUnits.front().dyn_cast<Operation *>()) {}
   static constexpr StringLiteral tag = "pass-execution-action";
   void print(raw_ostream &os) const override {
-    os << llvm::formatv("`{0}` running \"{1}\" on Operation \"{2}\"\n", tag,
-                        pass.getName(), op->getName());
+    os << "`" << tag << "` running \"" << pass.getName() << "\" on Operation ";
+    if (op)
+      os << "\"" << op->getName() << "\"";
+    else
+      os << "<unknown>";
   }
   const Pass &getPass() { return pass; }
   Operation *getOp() { return op; }
diff --git a/mlir/test/Pass/action-logging.mlir b/mlir/test/Pass/action-logging.mlir
--- a/mlir/test/Pass/action-logging.mlir
+++ b/mlir/test/Pass/action-logging.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt %s --log-actions-to=- -canonicalize -test-module-pass | FileCheck %s
 
-// CHECK: [thread {{.*}}] begins (no breakpoint) Action `pass-execution-action` running "Canonicalizer" on Operation "builtin.module"
-// CHECK: [thread {{.*}}] completed `pass-execution-action`
-// CHECK: [thread {{.*}}] begins (no breakpoint) Action `pass-execution-action` running "(anonymous namespace)::TestModulePass" on Operation "builtin.module"
-// CHECK: [thread {{.*}}] completed `pass-execution-action`
+// CHECK: [thread {{.*}}] begins (no breakpoint) Action `pass-execution-action` running "Canonicalizer" on Operation "builtin.module" (module {/*skip region*/})
+// CHECK-NEXT: [thread {{.*}}] completed `pass-execution-action`
+// CHECK-NEXT: [thread {{.*}}] begins (no breakpoint) Action `pass-execution-action` running "(anonymous namespace)::TestModulePass" on Operation "builtin.module" (module {/*skip region*/})
+// CHECK-NEXT: [thread {{.*}}] completed `pass-execution-action`
+// CHECK-NOT: Action