diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -615,6 +615,10 @@
 def PrintOpStats : Pass<"print-op-stats"> {
   let summary = "Print statistics of operations";
   let constructor = "mlir::createPrintOpStatsPass()";
+  let options = [
+    Option<"topKLocs", "top-k-locs", "int64_t", /*default=*/"0",
+           "Print top-k locations by number of operations (0 disables)">
+  ];
 }
 
 def SCCP : Pass<"sccp"> {
diff --git a/mlir/lib/Transforms/OpStats.cpp b/mlir/lib/Transforms/OpStats.cpp
--- a/mlir/lib/Transforms/OpStats.cpp
+++ b/mlir/lib/Transforms/OpStats.cpp
@@ -27,18 +27,41 @@
   // Print summary of op stats.
   void printSummary();
 
+  // Print the top-k locations by number of operations (see `topKLocs`).
+  void printTopLocations();
+
 private:
   llvm::StringMap<int64_t> opCount;
+  // Number of operations per location; only populated when topKLocs > 0.
+  llvm::DenseMap<Location, int64_t> opCountByLoc;
+  // Locations in first-visited order, used to break ties deterministically
+  // when printing the top locations.
+  std::vector<Location> locOrder;
+
   raw_ostream &os;
 };
 } // namespace
 
 void PrintOpStatsPass::runOnOperation() {
   opCount.clear();
+  opCountByLoc.clear();
+  locOrder.clear();
+  bool collectLocs = topKLocs > 0;
 
   // Compute the operation statistics for the currently visited operation.
-  getOperation()->walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; });
+  getOperation()->walk([&](Operation *op) {
+    ++opCount[op->getName().getStringRef()];
+    if (!collectLocs)
+      return;
+    Location loc = op->getLoc();
+    if (++opCountByLoc[loc] == 1)
+      locOrder.push_back(loc);
+  });
+
   printSummary();
+
+  if (collectLocs)
+    printTopLocations();
 }
 
 void PrintOpStatsPass::printSummary() {
@@ -80,6 +103,45 @@
   }
 }
 
+void PrintOpStatsPass::printTopLocations() {
+  os << '\n'
+     << "Top locations by number of operations:\n"
+     << "--------------------------------------\n";
+
+  // Collect (location, count) pairs in first-visited order, then
+  // stable-sort them by count (descending). Iterating `opCountByLoc`
+  // directly would order equal counts by DenseMap hash order, which is
+  // not guaranteed to be the same from one run to the next.
+  using EntryTy = std::pair<Location, int64_t>;
+  std::vector<EntryTy> sorted;
+  sorted.reserve(locOrder.size());
+  for (Location loc : locOrder)
+    sorted.emplace_back(loc, opCountByLoc.lookup(loc));
+  llvm::stable_sort(sorted, [](const EntryTy &lhs, const EntryTy &rhs) {
+    return lhs.second > rhs.second;
+  });
+
+  // Print at most `topKLocs` entries. Locations longer than kMaxLocLen are
+  // truncated from the front so that their end stays visible.
+  constexpr unsigned kMaxLocLen = 50;
+  int64_t k = topKLocs;
+  size_t numLocs = k > 0 ? std::min(sorted.size(), static_cast<size_t>(k)) : 0;
+  std::string buffer;
+  for (size_t i = 0; i < numLocs; ++i) {
+    buffer.clear();
+    llvm::raw_string_ostream strOs(buffer);
+    sorted[i].first.print(strOs);
+    StringRef str = strOs.str();
+
+    if (str.size() <= kMaxLocLen)
+      os << llvm::right_justify(str, kMaxLocLen + 3);
+    else
+      os << "..." << str.take_back(kMaxLocLen);
+
+    os << " , " << sorted[i].second << '\n';
+  }
+}
+
 std::unique_ptr<Pass> mlir::createPrintOpStatsPass() {
   return std::make_unique<PrintOpStatsPass>();
 }
diff --git a/mlir/test/IR/op-stats.mlir b/mlir/test/IR/op-stats.mlir
--- a/mlir/test/IR/op-stats.mlir
+++ b/mlir/test/IR/op-stats.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -print-op-stats %s -o=/dev/null 2>&1 | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -print-op-stats="top-k-locs=5" %s -o=/dev/null 2>&1 | FileCheck %s
 
 func @main(tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> {
 ^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>):
@@ -34,3 +34,9 @@
 // CHECK: std.addf , 6
 // CHECK: std.return , 1
 // CHECK: xla.add , 17
+// CHECK-LABEL: Top locations by number of operations
+// CHECK: IR/op-stats.mlir{{.*}} , 1
+// CHECK: IR/op-stats.mlir{{.*}} , 1
+// CHECK: IR/op-stats.mlir{{.*}} , 1
+// CHECK: IR/op-stats.mlir{{.*}} , 1
+// CHECK: IR/op-stats.mlir{{.*}} , 1