diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -688,21 +688,19 @@
   let constructor = "mlir::createSymbolDCEPass()";
 }
 
-def ViewOpGraphPass : Pass<"view-op-graph", "ModuleOp"> {
-  let summary = "Print graphviz view of module";
+def ViewOpGraphPass : Pass<"view-op-graph"> {
+  let summary = "Print Graphviz dataflow visualization of an operation";
   let description = [{
-    This pass prints a graphviz per block of a module.
+    This pass prints a Graphviz dataflow graph of an operation.
 
-    - Op are represented as nodes;
+    - Operations are represented as nodes;
     - Uses as edges;
+    - Regions/blocks as subgraphs.
+
+    Note: See https://www.graphviz.org/doc/info/lang.html for more information
+    about the Graphviz DOT language.
   }];
   let constructor = "mlir::createPrintOpGraphPass()";
-  let options = [
-    Option<"title", "title", "std::string",
-           /*default=*/"", "The prefix of the title of the graph">,
-    Option<"shortNames", "short-names", "bool", /*default=*/"false",
-           "Use short names">
-  ];
 }
 
 #endif // MLIR_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Transforms/ViewOpGraph.h b/mlir/include/mlir/Transforms/ViewOpGraph.h
--- a/mlir/include/mlir/Transforms/ViewOpGraph.h
+++ b/mlir/include/mlir/Transforms/ViewOpGraph.h
@@ -14,27 +14,13 @@
 #define MLIR_TRANSFORMS_VIEWOPGRAPH_H_
 
 #include "mlir/Support/LLVM.h"
-#include "llvm/Support/GraphWriter.h"
 #include "llvm/Support/raw_ostream.h"
 
 namespace mlir {
-class Block;
-class ModuleOp;
-template <typename T> class OperationPass;
-
-/// Displays the graph in a window. This is for use from the debugger and
-/// depends on Graphviz to generate the graph.
-void viewGraph(Block &block, const Twine &name, bool shortNames = false,
-               const Twine &title = "",
-               llvm::GraphProgram::Name program = llvm::GraphProgram::DOT);
-
-raw_ostream &writeGraph(raw_ostream &os, Block &block, bool shortNames = false,
-                        const Twine &title = "");
+class Pass;
 
 /// Creates a pass to print op graphs.
-std::unique_ptr<OperationPass<ModuleOp>>
-createPrintOpGraphPass(raw_ostream &os = llvm::errs(), bool shortNames = false,
-                       const Twine &title = "");
+std::unique_ptr<Pass> createPrintOpGraphPass(raw_ostream &os = llvm::errs());
 
 } // end namespace mlir
 
diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -9,12 +9,16 @@
 #include "mlir/Transforms/ViewOpGraph.h"
 #include "PassDetail.h"
 #include "mlir/IR/Block.h"
-#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Operation.h"
-#include "llvm/Support/CommandLine.h"
+#include "mlir/Support/IndentedOstream.h"
+#include "llvm/Support/Format.h"
 
 using namespace mlir;
 
+static const StringRef kLineStyleDataFlow = "solid";
+static const StringRef kShapeNode = "ellipse";
+static const StringRef kShapeNone = "plain";
+
 /// Return the size limits for eliding large attributes.
 static int64_t getLargeAttributeSizeLimit() {
   // Use the default from the printer flags if possible.
@@ -23,145 +27,268 @@
   return 16;
 }
 
-namespace llvm {
+/// Return all values printed onto a stream as a string.
+static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
+  std::string buf;
+  llvm::raw_string_ostream os(buf);
+  func(os);
+  return os.str();
+}
+
+/// Escape special characters such as '\n' and quotation marks.
+static std::string escapeString(std::string str) {
+  return strFromOs([&](raw_ostream &os) { os.write_escaped(str); });
+}
+
+/// Put quotation marks around a given string.
+static std::string quoteString(std::string str) { return "\"" + str + "\""; }
 
-// Specialize GraphTraits to treat Block as a graph of Operations as nodes and
-// uses as edges.
-template <> struct GraphTraits<Block *> {
-  using GraphType = Block *;
-  using NodeRef = Operation *;
+using AttributeMap = llvm::StringMap<std::string>;
 
-  using ChildIteratorType = Operation::user_iterator;
-  static ChildIteratorType child_begin(NodeRef n) { return n->user_begin(); }
-  static ChildIteratorType child_end(NodeRef n) { return n->user_end(); }
+namespace {
+
+/// This struct represents a node in the DOT language. Each node has an
+/// identifier and an optional identifier for the cluster (subgraph) that
+/// contains the node.
+/// Note: In the DOT language, edges can be drawn only from nodes to nodes, but
+/// not between clusters. However, edges can be clipped to the boundary of a
+/// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new
+/// cluster, an invisible "anchor" node is created.
+struct Node {
+public:
+  Node(int id = 0, Optional<int> clusterId = llvm::None)
+      : id(id), clusterId(clusterId) {}
 
-  // Operation's destructor is private so use Operation* instead and use
-  // mapped iterator.
-  static Operation *AddressOf(Operation &op) { return &op; }
-  using nodes_iterator = mapped_iterator<Block::iterator, decltype(&AddressOf)>;
-  static nodes_iterator nodes_begin(Block *b) {
-    return nodes_iterator(b->begin(), &AddressOf);
+  int id;
+  Optional<int> clusterId;
+};
+
+/// This pass generates a Graphviz dataflow visualization of an MLIR operation.
+/// Note: See https://www.graphviz.org/doc/info/lang.html for more information
+/// about the Graphviz DOT language.
+class PrintOpPass : public ViewOpGraphPassBase<PrintOpPass> {
+public:
+  PrintOpPass(raw_ostream &os) : os(os), plainOs(os) {}
+  PrintOpPass(const PrintOpPass &o) : os(o.plainOs), plainOs(o.plainOs) {}
+
+  void runOnOperation() override {
+    emitGraph([&]() {
+      processOperation(getOperation());
+      emitAllEdgeStmts();
+    });
   }
-  static nodes_iterator nodes_end(Block *b) {
-    return nodes_iterator(b->end(), &AddressOf);
+
+private:
+  /// Emit all edges. This function should be called after all nodes have been
+  /// emitted.
+  void emitAllEdgeStmts() {
+    // Data flow edges are resolved only now: an operand may be defined by an
+    // op that is visited after its user (e.g., when the blocks of a region are
+    // not listed in dominance order).
+    for (const auto &dataFlowEdge : dataFlowEdges) {
+      auto it = valueToNode.find(std::get<0>(dataFlowEdge));
+      // Skip values that are defined outside of the visualized operation.
+      if (it != valueToNode.end())
+        emitEdgeStmt(it->second, std::get<1>(dataFlowEdge),
+                     std::get<2>(dataFlowEdge));
+    }
+    dataFlowEdges.clear();
+
+    for (const std::string &edge : edges)
+      os << edge << ";\n";
+    edges.clear();
   }
-};
 
-// Specialize DOTGraphTraits to produce more readable output.
-template <> struct DOTGraphTraits<Block *> : public DefaultDOTGraphTraits {
-  using DefaultDOTGraphTraits::DefaultDOTGraphTraits;
-  static std::string getNodeLabel(Operation *op, Block *);
-};
+  /// Emit a cluster (subgraph). The specified builder generates the body of the
+  /// cluster. Return the anchor node of the cluster.
+  Node emitClusterStmt(function_ref<void()> builder, std::string label = "") {
+    int clusterId = ++counter;
+    os << "subgraph cluster_" << clusterId << " {\n";
+    os.indent();
+    // Emit invisible anchor node from/to which arrows can be drawn.
+    Node anchorNode = emitNodeStmt(" ", kShapeNone);
+    os << attrStmt("label", quoteString(escapeString(label))) << ";\n";
+    builder();
+    os.unindent();
+    os << "}\n";
+    return Node(anchorNode.id, clusterId);
+  }
 
-std::string DOTGraphTraits<Block *>::getNodeLabel(Operation *op, Block *b) {
-  // Reuse the print output for the node labels.
-  std::string ostr;
-  raw_string_ostream os(ostr);
-  os << op->getName() << "\n";
+  /// Generate an attribute statement.
+  std::string attrStmt(const Twine &key, const Twine &value) {
+    return (key + " = " + value).str();
+  }
 
-  if (!op->getLoc().isa<UnknownLoc>()) {
-    os << op->getLoc() << "\n";
+  /// Emit an attribute list.
+  void emitAttrList(raw_ostream &os, const AttributeMap &map) {
+    os << "[";
+    interleaveComma(map, os, [&](const auto &it) {
+      os << attrStmt(it.getKey(), it.getValue());
+    });
+    os << "]";
   }
 
-  // Print resultant types
-  llvm::interleaveComma(op->getResultTypes(), os);
-  os << "\n";
+  // Print an MLIR attribute to `os`. Large attributes are truncated.
+  void emitMlirAttr(raw_ostream &os, Attribute attr) {
+    // A value used to elide large container attribute.
+    int64_t largeAttrLimit = getLargeAttributeSizeLimit();
 
-  // A value used to elide large container attribute.
-  int64_t largeAttrLimit = getLargeAttributeSizeLimit();
-  for (auto attr : op->getAttrs()) {
-    os << '\n' << attr.first << ": ";
     // Always emit splat attributes.
-    if (attr.second.isa<SplatElementsAttr>()) {
-      attr.second.print(os);
-      continue;
+    if (attr.isa<SplatElementsAttr>()) {
+      attr.print(os);
+      return;
     }
 
     // Elide "big" elements attributes.
-    auto elements = attr.second.dyn_cast<ElementsAttr>();
+    auto elements = attr.dyn_cast<ElementsAttr>();
     if (elements && elements.getNumElements() > largeAttrLimit) {
       os << std::string(elements.getType().getRank(), '[') << "..."
          << std::string(elements.getType().getRank(), ']') << " : "
          << elements.getType();
-      continue;
+      return;
     }
 
-    auto array = attr.second.dyn_cast<ArrayAttr>();
+    auto array = attr.dyn_cast<ArrayAttr>();
     if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
       os << "[...]";
-      continue;
+      return;
     }
 
     // Print all other attributes.
-    attr.second.print(os);
+    attr.print(os);
   }
-  return os.str();
-}
 
-} // end namespace llvm
+  /// Append an edge to the list of edges.
+  /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
+  void emitEdgeStmt(Node n1, Node n2, std::string label,
+                    StringRef style = kLineStyleDataFlow) {
+    AttributeMap attrs;
+    attrs["style"] = style.str();
+    // Do not label edges that start/end at a cluster boundary. Such edges are
+    // clipped at the boundary, but labels are not. This can lead to labels
+    // floating around without any edge next to them.
+    if (!n1.clusterId && !n2.clusterId)
+      attrs["label"] = quoteString(escapeString(label));
+    // Use `ltail` and `lhead` to draw edges between clusters.
+    if (n1.clusterId)
+      attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId);
+    if (n2.clusterId)
+      attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId);
 
-namespace {
-// PrintOpPass is simple pass to write graph per function.
-// Note: this is a module pass only to avoid interleaving on the same ostream
-// due to multi-threading over functions.
-class PrintOpPass : public ViewOpGraphPassBase<PrintOpPass> {
-public:
-  PrintOpPass(raw_ostream &os, bool shortNames, const Twine &title) : os(os) {
-    this->shortNames = shortNames;
-    this->title = title.str();
+    edges.push_back(strFromOs([&](raw_ostream &os) {
+      os << llvm::format("v%i -> v%i ", n1.id, n2.id);
+      emitAttrList(os, attrs);
+    }));
   }
 
-  std::string getOpName(Operation &op) {
-    auto symbolAttr =
-        op.getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
-    if (symbolAttr)
-      return std::string(symbolAttr.getValue());
-    ++unnamedOpCtr;
-    return (op.getName().getStringRef() + llvm::utostr(unnamedOpCtr)).str();
+  /// Emit a graph. The specified builder generates the body of the graph.
+  void emitGraph(function_ref<void()> builder) {
+    os << "digraph G {\n";
+    os.indent();
+    // Edges between clusters are allowed only in compound mode.
+    os << attrStmt("compound", "true") << ";\n";
+    builder();
+    os.unindent();
+    os << "}\n";
   }
 
-  // Print all the ops in a module.
-  void processModule(ModuleOp module) {
-    for (Operation &op : module) {
-      // Modules may actually be nested, recurse on nesting.
-      if (auto nestedModule = dyn_cast<ModuleOp>(op)) {
-        processModule(nestedModule);
-        continue;
-      }
-      auto opName = getOpName(op);
-      for (Region &region : op.getRegions()) {
-        for (auto indexed_block : llvm::enumerate(region)) {
-          // Suffix block number if there are more than 1 block.
-          auto blockName = llvm::hasSingleElement(region)
-                               ? ""
-                               : ("__" + llvm::utostr(indexed_block.index()));
-          llvm::WriteGraph(os, &indexed_block.value(), shortNames,
-                           Twine(title) + opName + blockName);
-        }
+  /// Emit a node statement.
+  Node emitNodeStmt(std::string label, StringRef shape = kShapeNode) {
+    int nodeId = ++counter;
+    AttributeMap attrs;
+    attrs["label"] = quoteString(escapeString(label));
+    attrs["shape"] = shape.str();
+    os << llvm::format("v%i ", nodeId);
+    emitAttrList(os, attrs);
+    os << ";\n";
+    return Node(nodeId);
+  }
+
+  /// Generate a label for an operation.
+  std::string getLabel(Operation *op) {
+    return strFromOs([&](raw_ostream &os) {
+      // Print operation name and type.
+      os << op->getName() << " : (";
+      interleaveComma(op->getResultTypes(), os);
+      os << ")\n";
+
+      // Print attributes.
+      for (const NamedAttribute &attr : op->getAttrs()) {
+        os << '\n' << attr.first << ": ";
+        emitMlirAttr(os, attr.second);
       }
+    });
+  }
+
+  /// Generate a label for a block argument.
+  std::string getLabel(BlockArgument arg) {
+    return "arg" + std::to_string(arg.getArgNumber());
+  }
+
+  /// Process a block. Emit a cluster and one node per block argument and
+  /// operation inside the cluster.
+  void processBlock(Block &block) {
+    emitClusterStmt([&]() {
+      for (BlockArgument &blockArg : block.getArguments())
+        valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
+
+      // Emit a node for each operation.
+      for (Operation &op : block)
+        processOperation(&op);
+    });
+  }
+
+  /// Process an operation. If the operation has regions, emit a cluster.
+  /// Otherwise, emit a node.
+  void processOperation(Operation *op) {
+    Node node;
+    if (op->getNumRegions() > 0) {
+      // Emit cluster for op with regions.
+      node = emitClusterStmt(
+          [&]() {
+            for (Region &region : op->getRegions())
+              processRegion(region);
+          },
+          getLabel(op));
+    } else {
+      node = emitNodeStmt(getLabel(op));
     }
+
+    // Record edges originating from each operand. They are emitted by
+    // `emitAllEdgeStmts` once the nodes of all values are known.
+    unsigned numOperands = op->getNumOperands();
+    for (unsigned i = 0; i < numOperands; i++)
+      dataFlowEdges.emplace_back(
+          op->getOperand(i), node,
+          /*label=*/numOperands == 1 ? "" : std::to_string(i));
+
+    for (Value result : op->getResults())
+      valueToNode[result] = node;
   }
 
-  void runOnOperation() override { processModule(getOperation()); }
+  /// Process a region.
+  void processRegion(Region &region) {
+    for (Block &block : region.getBlocks())
+      processBlock(block);
+  }
 
-private:
-  raw_ostream &os;
-  int unnamedOpCtr = 0;
+  /// Output stream to write DOT file to.
+  raw_indented_ostream os;
+  raw_ostream &plainOs; // Needed for copy constructor.
+  /// A list of edges. For simplicity, should be emitted after all nodes were
+  /// emitted.
+  std::vector<std::string> edges;
+  /// Data flow edges (operand value, user node, label). They are converted to
+  /// edge statements in `emitAllEdgeStmts` when all nodes are known.
+  std::vector<std::tuple<Value, Node, std::string>> dataFlowEdges;
+  /// Mapping of SSA values to Graphviz nodes/clusters.
+  DenseMap<Value, Node> valueToNode;
+  /// Counter for generating unique node/subgraph identifiers.
+  int counter = 0;
 };
-} // namespace
-
-void mlir::viewGraph(Block &block, const Twine &name, bool shortNames,
-                     const Twine &title, llvm::GraphProgram::Name program) {
-  llvm::ViewGraph(&block, name, shortNames, title, program);
-}
 
-raw_ostream &mlir::writeGraph(raw_ostream &os, Block &block, bool shortNames,
-                              const Twine &title) {
-  return llvm::WriteGraph(os, &block, shortNames, title);
-}
+} // namespace
 
-std::unique_ptr<OperationPass<ModuleOp>>
-mlir::createPrintOpGraphPass(raw_ostream &os, bool shortNames,
-                             const Twine &title) {
-  return std::make_unique<PrintOpPass>(os, shortNames, title);
+std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) {
+  return std::make_unique<PrintOpPass>(os);
 }
diff --git a/mlir/test/Transforms/print-op-graph.mlir b/mlir/test/Transforms/print-op-graph.mlir
--- a/mlir/test/Transforms/print-op-graph.mlir
+++ b/mlir/test/Transforms/print-op-graph.mlir
@@ -1,18 +1,36 @@
 // RUN: mlir-opt -allow-unregistered-dialect -mlir-elide-elementsattrs-if-larger=2 -view-op-graph %s -o %t 2>&1 | FileCheck %s
 
-// CHECK-LABEL: digraph "merge_blocks"
-// CHECK{LITERAL}: value: [[...]] : tensor\<2x2xi32\>}
-// CHECK{LITERAL}: value: dense\<1\> : tensor\<5xi32\>}
-// CHECK{LITERAL}: value: dense\<[[0, 1]]\> : tensor\<1x2xi32\>}
+// CHECK-LABEL: digraph G {
+//       CHECK:   subgraph {{.*}} {
+//       CHECK:     subgraph {{.*}}
+//       CHECK:       label = "builtin.func{{.*}}merge_blocks
+//       CHECK:       subgraph {{.*}} {
+//       CHECK:         v[[ARG0:.*]] [label = "arg0"
+//       CHECK:         v[[CONST10:.*]] [label ={{.*}}10 : i32
+//       CHECK:         subgraph [[CLUSTER_MERGE_BLOCKS:.*]] {
+//       CHECK:           v[[ANCHOR:.*]] [label = " ", shape = plain]
+//       CHECK:           label = "test.merge_blocks
+//       CHECK:           subgraph {{.*}} {
+//       CHECK:             v[[TEST_BR:.*]] [label = "test.br
+//       CHECK:           }
+//       CHECK:           subgraph {{.*}} {
+//       CHECK:           }
+//       CHECK:         }
+//       CHECK:         v[[TEST_RET:.*]] [label = "test.return
+//       CHECK:   v[[ARG0]] -> v[[TEST_BR]]
+//       CHECK:   v[[CONST10]] -> v[[TEST_BR]]
+//       CHECK:   v[[ANCHOR]] -> v[[TEST_RET]] [{{.*}}, ltail = [[CLUSTER_MERGE_BLOCKS]]]
+//       CHECK:   v[[ANCHOR]] -> v[[TEST_RET]] [{{.*}}, ltail = [[CLUSTER_MERGE_BLOCKS]]]
 func @merge_blocks(%arg0: i32, %arg1 : i32) -> () {
   %0 = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
   %1 = constant dense<1> : tensor<5xi32>
   %2 = constant dense<[[0, 1]]> : tensor<1x2xi32>
-
+  %a = constant 10 : i32
+  %b = "test.func"() : () -> i32
   %3:2 = "test.merge_blocks"() ({
   ^bb0:
-     "test.br"(%arg0, %arg1)[^bb1] : (i32, i32) -> ()
-  ^bb1(%arg3 : i32, %arg4 : i32):
+     "test.br"(%arg0, %b, %a)[^bb1] : (i32, i32, i32) -> ()
+  ^bb1(%arg3 : i32, %arg4 : i32, %arg5: i32):
      "test.return"(%arg3, %arg4) : (i32, i32) -> ()
   }) : () -> (i32, i32)
   "test.return"(%3#0, %3#1) : (i32, i32) -> ()