diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -695,13 +695,20 @@
 
     - Op are represented as nodes;
     - Uses as edges;
+    - Regions/blocks as subgraphs;
   }];
   let constructor = "mlir::createPrintOpGraphPass()";
   let options = [
-    Option<"title", "title", "std::string",
-           /*default=*/"", "The prefix of the title of the graph">,
-    Option<"shortNames", "short-names", "bool", /*default=*/"false",
-           "Use short names">
+    Option<"color", "color", "bool",
+           /*default=*/"true", "Use colors">,
+    Option<"controlFlowEdges", "control-flow-edges", "bool",
+           /*default=*/"false", "Draw control flow edges">,
+    Option<"duplicateConstNodes", "duplicate-const-nodes", "bool",
+           /*default=*/"true", "Duplicate nodes for constant ops/block args">,
+    Option<"maxLocationLen", "max-location-len", "int",
+           /*default=*/"0", "After how many chars to truncate location attrs">,
+    Option<"printTypes", "print-types", "bool",
+           /*default=*/"true", "Print operation types">
   ];
 }
 
diff --git a/mlir/include/mlir/Transforms/ViewOpGraph.h b/mlir/include/mlir/Transforms/ViewOpGraph.h
--- a/mlir/include/mlir/Transforms/ViewOpGraph.h
+++ b/mlir/include/mlir/Transforms/ViewOpGraph.h
@@ -14,27 +14,15 @@
 #define MLIR_TRANSFORMS_VIEWOPGRAPH_H_
 
 #include "mlir/Support/LLVM.h"
-#include "llvm/Support/GraphWriter.h"
 #include "llvm/Support/raw_ostream.h"
 
 namespace mlir {
-class Block;
 class ModuleOp;
 template <typename T> class OperationPass;
 
-/// Displays the graph in a window. This is for use from the debugger and
-/// depends on Graphviz to generate the graph.
-void viewGraph(Block &block, const Twine &name, bool shortNames = false,
-               const Twine &title = "",
-               llvm::GraphProgram::Name program = llvm::GraphProgram::DOT);
-
-raw_ostream &writeGraph(raw_ostream &os, Block &block, bool shortNames = false,
-                        const Twine &title = "");
-
 /// Creates a pass to print op graphs.
 std::unique_ptr<OperationPass<ModuleOp>>
-createPrintOpGraphPass(raw_ostream &os = llvm::errs(), bool shortNames = false,
-                       const Twine &title = "");
+createPrintOpGraphPass(raw_ostream &os = llvm::errs());
 
 } // end namespace mlir
 
diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -10,11 +10,26 @@
 #include "PassDetail.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/Operation.h"
-#include "llvm/Support/CommandLine.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/Support/Format.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <map>
 
 using namespace mlir;
 
+namespace {
+
+static const char kSolidLineStyle[] = "solid";
+static const char kDashedLineStyle[] = "dashed";
+static const char kEllipseNodeShape[] = "ellipse";
+static const char kNoneNodeShape[] = "none";
+static const char kColorWhite[] = "#FFFFFF";
+static const char *kColors[] = {"#FFADAD", "#FFD6A5", "#FDFFB6", "#CAFFBF",
+                                "#9BF6FF", "#A0C4FF", "#BDB2FF", "#FFC6FF"};
+
 /// Return the size limits for eliding large attributes.
 static int64_t getLargeAttributeSizeLimit() {
   // Use the default from the printer flags if possible.
@@ -23,145 +38,413 @@
   return 16;
 }
 
-namespace llvm {
+/// Escape special characters such as '\n' and quotation marks.
+static std::string escapeString(std::string str) {
+  std::string buf;
+  llvm::raw_string_ostream os(buf);
+  os.write_escaped(str);
+  return os.str();
+}
+
+/// Put quotation marks around a given string.
+static std::string quoteString(std::string str) { return "\"" + str + "\""; }
+
+/// Ordered map, so that DOT attributes are emitted in a deterministic order.
+using AttributeMap = std::map<std::string, std::string>;
+
+/// This struct represents a node in the DOT language. Each node has an
+/// identifier and an optional identifier for the cluster (subgraph) that
+/// contains the node.
+class Node {
+public:
+  Node() : Node(0) {}
+
+  static Node create() { return Node(getNextId()); }
+
+  static Node createWithCluster() { return Node(getNextId(), getNextId()); }
 
-// Specialize GraphTraits to treat Block as a graph of Operations as nodes and
-// uses as edges.
-template <> struct GraphTraits<Block *> {
-  using GraphType = Block *;
-  using NodeRef = Operation *;
+  static int getNextId() {
+    static int counter = 0;
+    return ++counter;
+  }
 
-  using ChildIteratorType = Operation::user_iterator;
-  static ChildIteratorType child_begin(NodeRef n) { return n->user_begin(); }
-  static ChildIteratorType child_end(NodeRef n) { return n->user_end(); }
+  explicit operator bool() const { return id != 0; }
 
-  // Operation's destructor is private so use Operation* instead and use
-  // mapped iterator.
-  static Operation *AddressOf(Operation &op) { return &op; }
-  using nodes_iterator = mapped_iterator<Block::iterator, decltype(&AddressOf)>;
-  static nodes_iterator nodes_begin(Block *b) {
-    return nodes_iterator(b->begin(), &AddressOf);
+  int getClusterId() const {
+    assert(hasCluster());
+    return *clusterId;
   }
-  static nodes_iterator nodes_end(Block *b) {
-    return nodes_iterator(b->end(), &AddressOf);
+
+  int getId() const {
+    assert(id != 0);
+    return id;
   }
-};
 
-// Specialize DOTGraphTraits to produce more readable output.
-template <> struct DOTGraphTraits<Block *> : public DefaultDOTGraphTraits {
-  using DefaultDOTGraphTraits::DefaultDOTGraphTraits;
-  static std::string getNodeLabel(Operation *op, Block *);
+  bool hasCluster() const { return clusterId.hasValue(); }
+  void setId(int id) { this->id = id; }
+
+private:
+  Node(int id) { this->id = id; }
+  Node(int id, int clusterId) : Node(id) { this->clusterId = clusterId; }
+
+  int id;
+  llvm::Optional<int> clusterId;
 };
 
-std::string DOTGraphTraits<Block *>::getNodeLabel(Operation *op, Block *b) {
-  // Reuse the print output for the node labels.
-  std::string ostr;
-  raw_string_ostream os(ostr);
-  os << op->getName() << "\n";
+static Optional<Attribute> getConstantAttribute(Value val) {
+  Attribute attr;
+  if (matchPattern(val, m_Constant(&attr)))
+    return attr;
+  return llvm::None;
+}
 
-  if (!op->getLoc().isa<UnknownLoc>()) {
-    os << op->getLoc() << "\n";
+class PrintOpPass : public ViewOpGraphPassBase<PrintOpPass> {
+public:
+  PrintOpPass(raw_ostream &os) : os(os) {}
+
+  void runOnOperation() override {
+    emitGraph([&]() {
+      processModule(getOperation());
+      emitAllEdgeStmts();
+    });
   }
 
-  // Print resultant types
-  llvm::interleaveComma(op->getResultTypes(), os);
-  os << "\n";
+private:
+  /// Emit output with the current indentation.
+  raw_ostream &emitIndented() {
+    os.indent(2 * indent);
+    return os;
+  }
+
+  /// Emit all edges. This function should be called after all nodes have been
+  /// emitted.
+  void emitAllEdgeStmts() {
+    // Resolve deferred data flow edges; every value has a node by now.
+    for (const PendingEdge &pending : pendingEdges)
+      emitEdgeStmt(valueToNode.lookup(pending.value), pending.node,
+                   pending.label, pending.style);
+    pendingEdges.clear();
+    for (const auto &edge : edges)
+      emitIndented() << edge << ";\n";
+    edges.clear();
+  }
+
+  /// Emit a cluster (subgraph). The specified builder generates the body of the
+  /// cluster. If the specified node has no cluster, the builder is called
+  /// directly.
+  void emitCluster(Node node, function_ref<void()> builder,
+                   std::string label = "", std::string color = kColorWhite) {
+    if (node.hasCluster()) {
+      emitIndented() << "subgraph cluster_" << node.getClusterId() << " {\n";
+      ++indent;
+      emitIndented() << attrStmt("label", quoteString(escapeString(label)))
+                     << ";\n";
+      if (this->color) {
+        emitIndented() << attrStmt("style", "filled") << ";\n";
+        emitIndented() << attrStmt("fillcolor", quoteString(color)) << ";\n";
+      }
+    }
+    builder();
+    if (node.hasCluster()) {
+      --indent;
+      emitIndented() << "}\n";
+    }
+  }
+
+  /// Emit an edge from an MLIR value to a Graphviz node. In most cases, the
+  /// value's node is looked up in `valueToNode`. In a few special cases, a new
+  /// Graphviz node is created on-the-fly.
+  /// Note: Edges are appended to a list of edges and written to the output
+  /// stream via `emitAllEdgeStmts`.
+  void emitEdgeStmt(Value value, Node node, std::string label,
+                    std::string style = kSolidLineStyle) {
+    // If `duplicateConstNodes` create nodes on-the-fly for each use.
+    if (this->duplicateConstNodes) {
+      // Create a new node for each block argument.
+      if (auto blockArg = value.dyn_cast<BlockArgument>()) {
+        Node src = Node::create();
+        emitNodeStmt(src, getLabel(blockArg), kColorWhite, kNoneNodeShape);
+        emitEdgeStmt(src, node, label, style);
+        return;
+      }
+
+      // Create a new node whenever the constant is used.
+      if (auto attr = getConstantAttribute(value)) {
+        Node src = Node::create();
+        std::string buf;
+        llvm::raw_string_ostream os(buf);
+        emitMlirAttr(os, *attr);
+        emitNodeStmt(src, os.str(), kColorWhite, kNoneNodeShape);
+        emitEdgeStmt(src, node, label, style);
+        return;
+      }
+    }
+
+    // Default case: Lookup node in `valueToNode`. The lookup is deferred to
+    // `emitAllEdgeStmts` because the defining op may not have been visited
+    // yet (e.g. a use in a block that textually precedes its dominator).
+    pendingEdges.push_back({value, node, label, style});
+  }
+
+  /// Generate an attribute statement.
+  std::string attrStmt(std::string key, std::string value) {
+    return key + " = " + value;
+  }
+
+  /// Emit an attribute list.
+  void emitAttrList(raw_ostream &os, const AttributeMap &map) {
+    SmallVector<std::string> attrs;
+    for (auto it = map.begin(); it != map.end(); ++it)
+      attrs.push_back(attrStmt(it->first, it->second));
+    os << "[";
+    llvm::interleaveComma(attrs, os);
+    os << "]";
+  }
+
+  // Print an MLIR attribute to `os`. Large attributes are truncated.
+  void emitMlirAttr(raw_ostream &os, Attribute attr) {
+    // A value used to elide large container attribute.
+    int64_t largeAttrLimit = getLargeAttributeSizeLimit();
 
-  // A value used to elide large container attribute.
-  int64_t largeAttrLimit = getLargeAttributeSizeLimit();
-  for (auto attr : op->getAttrs()) {
-    os << '\n' << attr.first << ": ";
     // Always emit splat attributes.
-    if (attr.second.isa<SplatElementsAttr>()) {
-      attr.second.print(os);
-      continue;
+    if (attr.isa<SplatElementsAttr>()) {
+      attr.print(os);
+      return;
     }
 
     // Elide "big" elements attributes.
-    auto elements = attr.second.dyn_cast<ElementsAttr>();
+    auto elements = attr.dyn_cast<ElementsAttr>();
     if (elements && elements.getNumElements() > largeAttrLimit) {
       os << std::string(elements.getType().getRank(), '[') << "..."
          << std::string(elements.getType().getRank(), ']') << " : "
          << elements.getType();
-      continue;
+      return;
     }
 
-    auto array = attr.second.dyn_cast<ArrayAttr>();
+    auto array = attr.dyn_cast<ArrayAttr>();
     if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
       os << "[...]";
-      continue;
+      return;
     }
 
     // Print all other attributes.
-    attr.second.print(os);
+    attr.print(os);
   }
-  return os.str();
-}
 
-} // end namespace llvm
+  /// Append an edge to the list of edges.
+  /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
+  void emitEdgeStmt(Node n1, Node n2, std::string label,
+                    std::string style = kSolidLineStyle) {
+    std::string buf;
+    llvm::raw_string_ostream edge(buf);
+    AttributeMap attrs;
+    attrs["style"] = style;
+    attrs["label"] = quoteString(escapeString(label));
+    if (n1.hasCluster())
+      attrs["ltail"] = "cluster_" + std::to_string(n1.getClusterId());
+    if (n2.hasCluster())
+      attrs["lhead"] = "cluster_" + std::to_string(n2.getClusterId());
+    edge << llvm::format("v%i -> v%i ", n1.getId(), n2.getId());
+    emitAttrList(edge, attrs);
+    edges.push_back(edge.str());
+  }
 
-namespace {
-// PrintOpPass is simple pass to write graph per function.
-// Note: this is a module pass only to avoid interleaving on the same ostream
-// due to multi-threading over functions.
-class PrintOpPass : public ViewOpGraphPassBase<PrintOpPass> {
-public:
-  PrintOpPass(raw_ostream &os, bool shortNames, const Twine &title) : os(os) {
-    this->shortNames = shortNames;
-    this->title = title.str();
+  /// Emit a graph. The specified builder generates the body of the graph.
+  void emitGraph(function_ref<void()> builder) {
+    emitIndented() << "digraph G {\n";
+    ++indent;
+    emitIndented() << attrStmt("compound", "true") << ";\n";
+    emitIndented() << "node [style = filled];\n";
+    builder();
+    --indent;
+    emitIndented() << "}\n";
   }
 
-  std::string getOpName(Operation &op) {
-    auto symbolAttr =
-        op.getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
-    if (symbolAttr)
-      return std::string(symbolAttr.getValue());
-    ++unnamedOpCtr;
-    return (op.getName().getStringRef() + llvm::utostr(unnamedOpCtr)).str();
+  /// Emit a node statement.
+  void emitNodeStmt(Node node, std::string label,
+                    std::string color = kColorWhite,
+                    std::string shape = kEllipseNodeShape) {
+    AttributeMap attrs;
+    attrs["label"] = quoteString(escapeString(label));
+    attrs["shape"] = shape;
+    if (this->color)
+      attrs["fillcolor"] = quoteString(color);
+    emitIndented() << llvm::format("v%i ", node.getId());
+    emitAttrList(os, attrs);
+    os << ";\n";
   }
 
-  // Print all the ops in a module.
-  void processModule(ModuleOp module) {
-    for (Operation &op : module) {
-      // Modules may actually be nested, recurse on nesting.
-      if (auto nestedModule = dyn_cast<ModuleOp>(op)) {
-        processModule(nestedModule);
-        continue;
+  /// Generate a label for an operation.
+  std::string getLabel(Operation *op) {
+    // Reuse the print output for the node labels.
+    std::string buf;
+    llvm::raw_string_ostream os(buf);
+
+    // Print operation name and type.
+    os << op->getName();
+    if (this->printTypes) {
+      os << " : (";
+      // Print result types
+      llvm::interleaveComma(op->getResultTypes(), os);
+      os << ")";
+    }
+    os << "\n";
+
+    // Print location.
+    int maxLocLen = this->maxLocationLen;
+    if (maxLocLen > 0 && !op->getLoc().isa<UnknownLoc>()) {
+      std::string locStr = llvm::formatv("{0}", op->getLoc()).str();
+      os << "loc: ";
+      // Shorten large loc values to keep the graph readable.
+      if (locStr.length() > static_cast<size_t>(maxLocLen)) {
+        os << "..." << locStr.substr(locStr.length() - maxLocLen, maxLocLen);
+      } else {
+        os << locStr;
       }
-      auto opName = getOpName(op);
-      for (Region &region : op.getRegions()) {
-        for (auto indexed_block : llvm::enumerate(region)) {
-          // Suffix block number if there are more than 1 block.
-          auto blockName = llvm::hasSingleElement(region)
-                               ? ""
-                               : ("__" + llvm::utostr(indexed_block.index()));
-          llvm::WriteGraph(os, &indexed_block.value(), shortNames,
-                           Twine(title) + opName + blockName);
+      os << "\n";
+    }
+
+    // Print attributes.
+    for (auto attr : op->getAttrs()) {
+      os << '\n' << attr.first << ": ";
+      emitMlirAttr(os, attr.second);
+    }
+
+    return os.str();
+  }
+
+  /// Generate a label for a block argument.
+  std::string getLabel(BlockArgument arg) {
+    return "arg" + std::to_string(arg.getArgNumber());
+  }
+
+  /// Return the node color for a given operation. Operations with the same name
+  /// have the same color.
+  std::string getColor(Operation *op) {
+    if (!this->color)
+      return kColorWhite;
+
+    auto opName = op->getName().getStringRef();
+    if (opColors.find(opName) == opColors.end()) {
+      opColors[opName] = kColors[nextColor];
+      nextColor = (nextColor + 1) % (sizeof(kColors) / sizeof(char *));
+    }
+    return opColors[opName];
+  }
+
+  /// Process a block. Emit a cluster and one node per block argument and
+  /// operation inside the cluster.
+  Optional<Node> processBlock(Block &block) {
+    Node node = Node::createWithCluster();
+
+    emitCluster(node, [&]() {
+      // Emit block argument nodes here unless they are duplicated per use.
+      if (!this->duplicateConstNodes) {
+        for (auto &blockArg : block.getArguments()) {
+          Node argNode = Node::create();
+          valueToNode[blockArg] = argNode;
+          emitNodeStmt(argNode, getLabel(blockArg));
         }
       }
+
+      // Emit a node for each operation and insert control flow edges between
+      // them.
+      Optional<Node> prevNode;
+      for (auto &op : block.getOperations()) {
+        if (Optional<Node> opNode = processOperation(&op)) {
+          node.setId(opNode->getId());
+          if (this->controlFlowEdges && prevNode && opNode)
+            emitEdgeStmt(*prevNode, *opNode, "", kDashedLineStyle);
+          prevNode = opNode;
+        }
+      }
+    });
+
+    return node ? Optional<Node>(node) : llvm::None;
+  }
+
+  // Process a module. The contents of the module is wrapped in a cluster.
+  void processModule(ModuleOp module) {
+    emitCluster(Node::createWithCluster(), [&]() {
+      for (Operation &op : module) {
+        if (auto nestedModule = dyn_cast<ModuleOp>(op))
+          processModule(nestedModule);
+        else
+          processOperation(&op);
+      }
+    });
+  }
+
+  // Process an operation. If the operation has regions, emit a cluster.
+  // Otherwise, emit a node.
+  Optional<Node> processOperation(Operation *op) {
+    // If `duplicateConstNodes`, do not emit nodes for constants. Instead,
+    // generate a new node, each time it is used in an edge.
+    if (this->duplicateConstNodes && op->getNumResults() == 1 &&
+        getConstantAttribute(op->getResult(0)))
+      return llvm::None;
+
+    Node node;
+    if (op->getNumRegions() > 0) {
+      // Emit cluster for op with regions.
+      node = Node::createWithCluster();
+      emitCluster(
+          node,
+          [&]() {
+            for (Region &region : op->getRegions())
+              if (auto nestedNode = processRegion(region))
+                node.setId(nestedNode->getId());
+          },
+          getLabel(op), getColor(op));
+    } else {
+      node = Node::create();
+      emitNodeStmt(node, getLabel(op), getColor(op));
     }
+
+    // Insert edges originating from each operand.
+    unsigned numOperands = op->getNumOperands();
+    for (unsigned i = 0; i < numOperands; i++)
+      emitEdgeStmt(op->getOperand(i), node,
+                   /*label=*/numOperands == 1 ? "" : std::to_string(i));
+
+    for (Value result : op->getResults())
+      valueToNode[result] = node;
+
+    return node;
   }
 
-  void runOnOperation() override { processModule(getOperation()); }
+  // Process a region.
+  Optional<Node> processRegion(Region &region) {
+    Optional<Node> result;
+    for (auto &block : region.getBlocks())
+      if (auto node = processBlock(block))
+        result = node;
+    return result;
+  }
 
-private:
+  unsigned indent = 0;
   raw_ostream &os;
-  int unnamedOpCtr = 0;
+  std::vector<std::string> edges;
+  llvm::DenseMap<Value, Node> valueToNode;
+
+  /// A data flow edge whose source node is looked up in `emitAllEdgeStmts`.
+  struct PendingEdge {
+    Value value;
+    Node node;
+    std::string label;
+    std::string style;
+  };
+  std::vector<PendingEdge> pendingEdges;
+
+  /// Per-pass color state, so colors do not depend on previously printed
+  /// graphs.
+  unsigned nextColor = 0;
+  llvm::StringMap<std::string> opColors;
 };
-} // namespace
-
-void mlir::viewGraph(Block &block, const Twine &name, bool shortNames,
-                     const Twine &title, llvm::GraphProgram::Name program) {
-  llvm::ViewGraph(&block, name, shortNames, title, program);
-}
 
-raw_ostream &mlir::writeGraph(raw_ostream &os, Block &block, bool shortNames,
-                              const Twine &title) {
-  return llvm::WriteGraph(os, &block, shortNames, title);
-}
+} // namespace
 
 std::unique_ptr<OperationPass<ModuleOp>>
-mlir::createPrintOpGraphPass(raw_ostream &os, bool shortNames,
-                             const Twine &title) {
-  return std::make_unique<PrintOpPass>(os, shortNames, title);
+mlir::createPrintOpGraphPass(raw_ostream &os) {
+  return std::make_unique<PrintOpPass>(os);
 }
diff --git a/mlir/test/Transforms/print-op-graph.mlir b/mlir/test/Transforms/print-op-graph.mlir
--- a/mlir/test/Transforms/print-op-graph.mlir
+++ b/mlir/test/Transforms/print-op-graph.mlir
@@ -1,18 +1,35 @@
-// RUN: mlir-opt -allow-unregistered-dialect -mlir-elide-elementsattrs-if-larger=2 -view-op-graph %s -o %t 2>&1 | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -mlir-elide-elementsattrs-if-larger=2 -view-op-graph='color=true' %s -o %t 2>&1 | FileCheck %s
 
-// CHECK-LABEL: digraph "merge_blocks"
-// CHECK{LITERAL}: value: [[...]] : tensor\<2x2xi32\>}
-// CHECK{LITERAL}: value: dense\<1\> : tensor\<5xi32\>}
-// CHECK{LITERAL}: value: dense\<[[0, 1]]\> : tensor\<1x2xi32\>}
+// CHECK-LABEL: digraph G {
+// CHECK: subgraph {{.*}} {
+// CHECK: subgraph {{.*}}
+// CHECK: label = "func{{.*}}merge_blocks
+// CHECK: subgraph {{.*}} {
+// CHECK: subgraph [[CLUSTER_MERGE_BLOCKS:.*]] {
+// CHECK: label = "test.merge_blocks
+// CHECK: subgraph {{.*}} {
+// CHECK: v[[TEST_BR:.*]] [fillcolor = {{.*}}, label = "test.br
+// CHECK: v[[ARG0:.*]] [fillcolor = {{.*}}, label = "arg0"
+// CHECK: v[[CONST10:.*]] [fillcolor = {{.*}}, label = "10 : i32"
+// CHECK: }
+// CHECK: subgraph {{.*}} {
+// CHECK: }
+// CHECK: }
+// CHECK: v[[TEST_RET:.*]] [fillcolor = {{.*}}, label = "test.return
+// CHECK: v[[ARG0]] -> v[[TEST_BR]]
+// CHECK: v[[CONST10]] -> v[[TEST_BR]]
+// CHECK: v{{.*}} -> v[[TEST_RET]] [{{.*}}ltail = [[CLUSTER_MERGE_BLOCKS]], style = solid]
+// CHECK: v{{.*}} -> v[[TEST_RET]] [{{.*}}ltail = [[CLUSTER_MERGE_BLOCKS]], style = solid]
 func @merge_blocks(%arg0: i32, %arg1 : i32) -> () {
   %0 = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
   %1 = constant dense<1> : tensor<5xi32>
   %2 = constant dense<[[0, 1]]> : tensor<1x2xi32>
-
+  %a = constant 10 : i32
+  %b = "test.func"() : () -> i32
   %3:2 = "test.merge_blocks"() ({
   ^bb0:
-    "test.br"(%arg0, %arg1)[^bb1] : (i32, i32) -> ()
-  ^bb1(%arg3 : i32, %arg4 : i32):
+    "test.br"(%arg0, %b, %a)[^bb1] : (i32, i32, i32) -> ()
+  ^bb1(%arg3 : i32, %arg4 : i32, %arg5: i32):
     "test.return"(%arg3, %arg4) : (i32, i32) -> ()
   }) : () -> (i32, i32)
   "test.return"(%3#0, %3#1) : (i32, i32) -> ()