diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -617,7 +617,9 @@
   let constructor = "mlir::createPrintOpStatsPass()";
   let options = [
     Option<"topKLocs", "top-k-locs", "int64_t",
-           /*default=*/"0", "Print top-k locations by number of operations">
+           /*default=*/"0", "Print top-k locations by number of operations">,
+    Option<"topKLocsHist", "top-k-locs-hist", "bool",
+           /*default=*/"false", "Print op name histogram for top-k locations">
   ];
 }
 
diff --git a/mlir/lib/Transforms/OpStats.cpp b/mlir/lib/Transforms/OpStats.cpp
--- a/mlir/lib/Transforms/OpStats.cpp
+++ b/mlir/lib/Transforms/OpStats.cpp
@@ -18,6 +18,21 @@
 using namespace mlir;
 
 namespace {
+/// Per-location tally of operations: a histogram keyed by op name plus the
+/// total number of operations recorded.
+struct OpCounter {
+  /// Records one occurrence of the operation named `opName`.
+  void increment(StringRef opName) {
+    ++map[opName];
+    ++total;
+  }
+
+  /// Number of occurrences of each operation name.
+  llvm::StringMap<int64_t> map;
+  /// Sum of all counts in `map`.
+  int64_t total = 0;
+};
+
 struct PrintOpStatsPass : public PrintOpStatsBase<PrintOpStatsPass> {
   explicit PrintOpStatsPass(raw_ostream &os = llvm::errs()) : os(os) {}
 
@@ -30,9 +45,12 @@
   // Print top-k locations by number of operations.
   void printTopLocations();
 
+  // Print a histogram of op names, sorted by op name.
+  void printOpNameHistogram(const llvm::StringMap<int64_t> &count);
+
 private:
   llvm::StringMap<int64_t> opCount;
-  llvm::DenseMap<Location, int64_t> opCountByLoc;
+  llvm::DenseMap<Location, OpCounter> opCountByLoc;
   raw_ostream &os;
 };
 
@@ -46,7 +64,7 @@
   // Compute the operation statistics for the currently visited operation.
   getOperation()->walk([&](Operation *op) {
     ++opCount[op->getName().getStringRef()];
-    ++opCountByLoc[op->getLoc()];
+    opCountByLoc[op->getLoc()].increment(op->getName().getStringRef());
   });
 
   printSummary();
@@ -58,7 +76,12 @@
 void PrintOpStatsPass::printSummary() {
   os << "Operations encountered:\n";
   os << "-----------------------\n";
-  SmallVector<StringRef, 64> sorted(opCount.keys());
+  printOpNameHistogram(opCount);
+}
+
+void PrintOpStatsPass::printOpNameHistogram(
+    const llvm::StringMap<int64_t> &count) {
+  SmallVector<StringRef, 64> sorted(count.keys());
   llvm::sort(sorted);
 
   // Split an operation name from its dialect prefix.
@@ -89,8 +112,8 @@
       os << llvm::right_justify(dialectName, maxLenDialect + 2) << '.';
 
     // Left justify the operation name.
-    os << llvm::left_justify(opName, maxLenOpName) << " , " << opCount[key]
-       << '\n';
+    os << llvm::left_justify(opName, maxLenOpName) << " , "
+       << count.lookup(key) << '\n';
   }
 }
 
@@ -100,15 +123,17 @@
      << "--------------------------------------\n";
 
   // Sort by value (descending).
-  using EntryTy = std::pair<Location, int64_t>;
+  using EntryTy = std::pair<Location, OpCounter>;
   std::vector<EntryTy> sorted;
   for (const auto &it : opCountByLoc)
     sorted.push_back(it);
-  llvm::sort(sorted,
-             [](EntryTy e1, EntryTy e2) { return e1.second > e2.second; });
+  // Compare by const reference; copying OpCounter would copy its StringMap.
+  llvm::sort(sorted, [](const EntryTy &e1, const EntryTy &e2) {
+    return e1.second.total > e2.second.total;
+  });
 
   // Take top entries.
-  int64_t numLocs = sorted.size() < topKLocs ? sorted.size() : topKLocs;
+  int64_t numLocs = std::min<int64_t>(sorted.size(), topKLocs);
   for (int i = 0; i < numLocs; ++i) {
     std::string buffer;
     llvm::raw_string_ostream strOs(buffer);
@@ -122,7 +147,10 @@
     else
       os << "..."
          << str.substr(locSize - kMaxLocLen);
-    os << " , " << sorted[i].second << '\n';
+    os << " , " << sorted[i].second.total << '\n';
+
+    if (topKLocsHist)
+      printOpNameHistogram(sorted[i].second.map);
   }
 }
 
diff --git a/mlir/test/IR/op-stats.mlir b/mlir/test/IR/op-stats.mlir
--- a/mlir/test/IR/op-stats.mlir
+++ b/mlir/test/IR/op-stats.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -print-op-stats="top-k-locs=5" %s -o=/dev/null 2>&1 | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -print-op-stats="top-k-locs=5 top-k-locs-hist" %s -o=/dev/null 2>&1 | FileCheck %s
 
 func @main(tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> {
 ^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>):