diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -617,7 +617,9 @@ let constructor = "mlir::createPrintOpStatsPass()"; let options = [ Option<"topKLocs", "top-k-locs", "int64_t", - /*default=*/"0", "Print top-k locations by number of operations"> + /*default=*/"0", "Print top-k locations by number of operations">, + Option<"topKLocsOpCounts", "top-k-locs-op-counts", "bool", + /*default=*/"false", "Print op counts for top-k locations"> ]; } diff --git a/mlir/lib/Transforms/OpStats.cpp b/mlir/lib/Transforms/OpStats.cpp --- a/mlir/lib/Transforms/OpStats.cpp +++ b/mlir/lib/Transforms/OpStats.cpp @@ -18,6 +18,16 @@ using namespace mlir; namespace { +struct OpCounter { + void increment(StringRef opName) { + ++map[opName]; + ++total; + } + // Number of ops seen per op name, and the sum over all op names. + llvm::StringMap<int64_t> map; + int64_t total = 0; +}; + struct PrintOpStatsPass : public PrintOpStatsBase<PrintOpStatsPass> { explicit PrintOpStatsPass(raw_ostream &os = llvm::errs()) : os(os) {} @@ -30,9 +40,12 @@ // Print top-k locations by number of operations. void printTopLocations(); + // Print op counts, sorted by op name. + void printOpCounts(const llvm::StringMap<int64_t> &count); + private: llvm::StringMap<int64_t> opCount; - llvm::DenseMap<Location, int64_t> opCountByLoc; + llvm::DenseMap<Location, OpCounter> opCountByLoc; raw_ostream &os; }; @@ -46,7 +59,7 @@ // Compute the operation statistics for the currently visited operation. 
getOperation()->walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; - ++opCountByLoc[op->getLoc()]; + opCountByLoc[op->getLoc()].increment(op->getName().getStringRef()); }); printSummary(); @@ -58,7 +71,11 @@ void PrintOpStatsPass::printSummary() { os << "Operations encountered:\n"; os << "-----------------------\n"; - SmallVector<StringRef, 64> sorted(opCount.keys()); + printOpCounts(opCount); +} + +void PrintOpStatsPass::printOpCounts(const llvm::StringMap<int64_t> &count) { + SmallVector<StringRef, 64> sorted(count.keys()); llvm::sort(sorted); // Split an operation name from its dialect prefix. @@ -89,8 +106,8 @@ os << llvm::right_justify(dialectName, maxLenDialect + 2) << '.'; // Left justify the operation name. - os << llvm::left_justify(opName, maxLenOpName) << " , " << opCount[key] - << '\n'; + os << llvm::left_justify(opName, maxLenOpName) << " , " + << count.find(key)->second << '\n'; } } @@ -100,15 +117,16 @@ << "--------------------------------------\n"; // Sort by value (descending). - using EntryTy = std::pair<Location, int64_t>; + using EntryTy = std::pair<Location, OpCounter>; std::vector<EntryTy> sorted; for (const auto &it : opCountByLoc) sorted.push_back(it); - llvm::sort(sorted, - [](EntryTy e1, EntryTy e2) { return e1.second > e2.second; }); + llvm::sort(sorted, [](const EntryTy &e1, const EntryTy &e2) { + return e1.second.total > e2.second.total; + }); // Take top entries. - int64_t numLocs = sorted.size() < topKLocs ? sorted.size() : topKLocs; + int64_t numLocs = std::min<int64_t>(sorted.size(), topKLocs); for (int i = 0; i < numLocs; ++i) { std::string buffer; llvm::raw_string_ostream strOs(buffer); @@ -122,7 +140,10 @@ else os << "..." 
<< str.substr(locSize - kMaxLocLen); - os << " , " << sorted[i].second << '\n'; + os << " , " << sorted[i].second.total << '\n'; + + if (topKLocsOpCounts) + printOpCounts(sorted[i].second.map); } } diff --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c --- a/mlir/test/CAPI/pass.c +++ b/mlir/test/CAPI/pass.c @@ -137,14 +137,14 @@ mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass); // Print the top level pass manager - // CHECK: Top-level: builtin.module(builtin.func(print-op-stats{top-k-locs=0})) + // CHECK: Top-level: builtin.module(builtin.func(print-op-stats{top-k-locs=0 top-k-locs-op-counts=false})) fprintf(stderr, "Top-level: "); mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr, NULL); fprintf(stderr, "\n"); // Print the pipeline nested one level down - // CHECK: Nested Module: builtin.func(print-op-stats{top-k-locs=0}) + // CHECK: Nested Module: builtin.func(print-op-stats{top-k-locs=0 top-k-locs-op-counts=false}) fprintf(stderr, "Nested Module: "); mlirPrintPassPipeline(nestedModulePm, printToStderr, NULL); fprintf(stderr, "\n"); @@ -184,7 +184,8 @@ exit(EXIT_FAILURE); } - // CHECK: Round-trip: builtin.module(builtin.func(print-op-stats{top-k-locs=0}), builtin.func(print-op-stats{top-k-locs=0})) + // CHECK: Round-trip: builtin.module(builtin.func(print-op-stats{top-k-locs=0 top-k-locs-op-counts=false}), + // CHECK-SAME: builtin.func(print-op-stats{top-k-locs=0 top-k-locs-op-counts=false})) fprintf(stderr, "Round-trip: "); mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr, NULL); diff --git a/mlir/test/IR/op-stats.mlir b/mlir/test/IR/op-stats.mlir --- a/mlir/test/IR/op-stats.mlir +++ b/mlir/test/IR/op-stats.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -allow-unregistered-dialect -print-op-stats="top-k-locs=5" %s -o=/dev/null 2>&1 | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect -print-op-stats="top-k-locs=5 top-k-locs-op-counts" %s -o=/dev/null 2>&1 | FileCheck %s func @main(tensor<4xf32>, 
tensor<4xf32>) -> tensor<4xf32> { ^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>): diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -48,7 +48,7 @@ # This will register the pass and round-trip should be possible now. import mlir.transforms pm = PassManager.parse("builtin.module(builtin.func(print-op-stats))") - # CHECK: Roundtrip: builtin.module(builtin.func(print-op-stats{top-k-locs=0})) + # CHECK: Roundtrip: builtin.module(builtin.func(print-op-stats{top-k-locs=0 top-k-locs-op-counts=false})) log("Roundtrip: ", pm) run(testParseSuccess)