diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -615,5 +615,9 @@ def PrintOpStats : Pass<"print-op-stats"> {
   let summary = "Print statistics of operations";
   let constructor = "mlir::createPrintOpStatsPass()";
+  let options = [
+    Option<"topKLocs", "top-k-locs", "int64_t",
+           /*default=*/"0", "Print top-k locations by number of operations">
+  ];
 }
 
 def SCCP : Pass<"sccp"> {
diff --git a/mlir/lib/Transforms/OpStats.cpp b/mlir/lib/Transforms/OpStats.cpp
--- a/mlir/lib/Transforms/OpStats.cpp
+++ b/mlir/lib/Transforms/OpStats.cpp
@@ -27,18 +27,48 @@
   // Print summary of op stats.
   void printSummary();
 
+  // Print the `topKLocs` locations with the most operations.
+  void printTopLocations();
+
 private:
   llvm::StringMap<int64_t> opCount;
+
+  // Number of operations per location, in the order in which each location
+  // was first visited. Reporting from this vector rather than iterating a
+  // DenseMap (whose iteration order is unspecified) keeps the top-k report
+  // deterministic.
+  std::vector<std::pair<Location, int64_t>> opCountByLoc;
+  // Index of each location's entry in `opCountByLoc`.
+  llvm::DenseMap<Location, size_t> locIndex;
+
   raw_ostream &os;
 };
 } // namespace
 
 void PrintOpStatsPass::runOnOperation() {
   opCount.clear();
+  opCountByLoc.clear();
+  locIndex.clear();
+  // Per-location counts are only needed for the top-k report; skip the extra
+  // map lookup per operation when it is disabled (the default).
+  const bool countLocs = topKLocs > 0;
 
   // Compute the operation statistics for the currently visited operation.
-  getOperation()->walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; });
+  getOperation()->walk([&](Operation *op) {
+    ++opCount[op->getName().getStringRef()];
+    if (!countLocs)
+      return;
+    Location loc = op->getLoc();
+    auto inserted = locIndex.try_emplace(loc, opCountByLoc.size());
+    if (inserted.second)
+      opCountByLoc.emplace_back(loc, 0);
+    ++opCountByLoc[inserted.first->second].second;
+  });
+
   printSummary();
+
+  if (countLocs)
+    printTopLocations();
 }
 
 void PrintOpStatsPass::printSummary() {
@@ -80,6 +110,44 @@
   }
 }
 
+void PrintOpStatsPass::printTopLocations() {
+  os << '\n'
+     << "Top locations by number of operations:\n"
+     << "--------------------------------------\n";
+
+  // Sort by number of operations (descending). The sort is stable so that
+  // locations with equal counts keep their first-visited order, which makes
+  // the report reproducible from run to run.
+  using EntryTy = std::pair<Location, int64_t>;
+  std::vector<EntryTy> sorted(opCountByLoc.begin(), opCountByLoc.end());
+  llvm::stable_sort(sorted, [](const EntryTy &lhs, const EntryTy &rhs) {
+    return lhs.second > rhs.second;
+  });
+
+  // Print at most `topKLocs` entries; the caller only calls this when
+  // `topKLocs > 0`. Compare as unsigned to avoid a signed/unsigned mix.
+  size_t numLocs = sorted.size();
+  if (static_cast<uint64_t>(topKLocs) < numLocs)
+    numLocs = static_cast<size_t>(topKLocs);
+
+  constexpr unsigned kMaxLocLen = 50;
+  for (size_t i = 0; i < numLocs; ++i) {
+    std::string buffer;
+    llvm::raw_string_ostream strOs(buffer);
+    sorted[i].first.print(strOs);
+    llvm::StringRef str = strOs.str();
+
+    // Right-align short locations and keep only the trailing `kMaxLocLen`
+    // characters of long ones, so every row is `kMaxLocLen + 3` wide.
+    if (str.size() <= kMaxLocLen)
+      os << llvm::right_justify(str, kMaxLocLen + 3);
+    else
+      os << "..." << str.take_back(kMaxLocLen);
+
+    os << " , " << sorted[i].second << '\n';
+  }
+}
+
 std::unique_ptr<Pass> mlir::createPrintOpStatsPass() {
   return std::make_unique<PrintOpStatsPass>();
 }
diff --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c
--- a/mlir/test/CAPI/pass.c
+++ b/mlir/test/CAPI/pass.c
@@ -137,14 +137,14 @@
   mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
 
   // Print the top level pass manager
-  // CHECK: Top-level: builtin.module(builtin.func(print-op-stats))
+  // CHECK: Top-level: builtin.module(builtin.func(print-op-stats{top-k-locs=0}))
   fprintf(stderr, "Top-level: ");
   mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
                         NULL);
   fprintf(stderr, "\n");
 
   // Print the pipeline nested one level down
-  // CHECK: Nested Module: builtin.func(print-op-stats)
+  // CHECK: Nested Module: builtin.func(print-op-stats{top-k-locs=0})
   fprintf(stderr, "Nested Module: ");
   mlirPrintPassPipeline(nestedModulePm, printToStderr, NULL);
   fprintf(stderr, "\n");
@@ -184,7 +184,7 @@
     exit(EXIT_FAILURE);
   }
 
-  // CHECK: Round-trip: builtin.module(builtin.func(print-op-stats), builtin.func(print-op-stats))
+  // CHECK: Round-trip: builtin.module(builtin.func(print-op-stats{top-k-locs=0}), builtin.func(print-op-stats{top-k-locs=0}))
   fprintf(stderr, "Round-trip: ");
   mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
                         NULL);
diff --git a/mlir/test/IR/op-stats.mlir b/mlir/test/IR/op-stats.mlir
--- a/mlir/test/IR/op-stats.mlir
+++ b/mlir/test/IR/op-stats.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -print-op-stats %s -o=/dev/null 2>&1 | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -print-op-stats="top-k-locs=5" %s -o=/dev/null 2>&1 | FileCheck %s
 
 func @main(tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> {
 ^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>):
@@ -34,3 +34,9 @@
 // CHECK: std.addf , 6
 // CHECK: std.return , 1
 // CHECK: xla.add , 17
+// CHECK-LABEL: Top locations by number of operations
+// CHECK: op-stats.mlir{{.*}} , 1
+// CHECK: op-stats.mlir{{.*}} , 1
+// CHECK: op-stats.mlir{{.*}} , 1
+// CHECK: op-stats.mlir{{.*}} , 1
+// CHECK: op-stats.mlir{{.*}} , 1
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -48,7 +48,7 @@
     # This will register the pass and round-trip should be possible now.
     import mlir.transforms
     pm = PassManager.parse("builtin.module(builtin.func(print-op-stats))")
-    # CHECK: Roundtrip: builtin.module(builtin.func(print-op-stats))
+    # CHECK: Roundtrip: builtin.module(builtin.func(print-op-stats{top-k-locs=0}))
     log("Roundtrip: ", pm)
 
 run(testParseSuccess)