diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -87,6 +87,7 @@
   PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
 
   void runOnOperation() override {
+    initColorMapping(*getOperation());
     emitGraph([&]() {
       processOperation(getOperation());
       emitAllEdgeStmts();
@@ -97,10 +98,31 @@
   void emitRegionCFG(Region &region) {
     printControlFlowEdges = true;
     printDataFlowEdges = false;
+    initColorMapping(region);
     emitGraph([&]() { processRegion(region); });
   }
 
 private:
+  /// Generate a color mapping that will color every operation with the same
+  /// name the same way. It'll interpolate the hue in the HSV color-space,
+  /// attempting to keep the contrast suitable for black text.
+  template <typename T>
+  void initColorMapping(T &irEntity) {
+    backgroundColors.clear();
+    SmallVector<Operation *> ops;
+    irEntity.walk([&](Operation *op) {
+      auto &entry = backgroundColors[op->getName()];
+      if (entry.first == 0)
+        ops.push_back(op);
+      ++entry.first;
+    });
+    for (auto indexedOps : llvm::enumerate(ops)) {
+      double hue = ((double)indexedOps.index()) / ops.size();
+      backgroundColors[indexedOps.value()->getName()].second =
+          std::to_string(hue) + " 1.0 1.0";
+    }
+  }
+
   /// Emit all edges. This function should be called after all nodes have been
   /// emitted.
   void emitAllEdgeStmts() {
@@ -206,11 +228,17 @@
   }
 
   /// Emit a node statement.
-  Node emitNodeStmt(std::string label, StringRef shape = kShapeNode) {
+  /// A non-empty `background` sets the node's `style=filled` and `fillcolor`.
+  Node emitNodeStmt(std::string label, StringRef shape = kShapeNode,
+                    StringRef background = "") {
     int nodeId = ++counter;
     AttributeMap attrs;
     attrs["label"] = quoteString(escapeString(std::move(label)));
     attrs["shape"] = shape.str();
+    if (!background.empty()) {
+      attrs["style"] = "filled";
+      attrs["fillcolor"] = ("\"" + background + "\"").str();
+    }
     os << llvm::format("v%i ", nodeId);
     emitAttrList(os, attrs);
     os << ";\n";
@@ -278,7 +306,8 @@
           },
           getLabel(op));
     } else {
-      node = emitNodeStmt(getLabel(op));
+      node = emitNodeStmt(getLabel(op), kShapeNode,
+                          backgroundColors[op->getName()].second);
     }
 
     // Insert data flow edges originating from each operand.
@@ -318,6 +347,10 @@
   DenseMap<Value, Node> valueToNode;
   /// Counter for generating unique node/subgraph identifiers.
   int counter = 0;
+
+  /// Maps each operation name to its occurrence count and the fill color
+  /// (an "H S V" string) assigned to it by `initColorMapping`.
+  DenseMap<OperationName, std::pair<int, std::string>> backgroundColors;
 };
 
 } // namespace