Index: mlir/include/mlir/Pass/PassManager.h
===================================================================
--- mlir/include/mlir/Pass/PassManager.h
+++ mlir/include/mlir/Pass/PassManager.h
@@ -37,6 +37,7 @@
 struct OpPassManagerImpl;
 class OpToOpPassAdaptor;
 struct PassExecutionState;
+struct PassTiming;
 } // end namespace detail
 
 //===----------------------------------------------------------------------===//
@@ -346,6 +347,35 @@
   /// unintentionally included in the timing results.
   void enableTiming(std::unique_ptr config = nullptr);
 
+  /// A helper to time code paths outside the pass manager, so that they are
+  /// included in the timing statistics.
+  ///
+  /// A timer refers to the pass manager's timing instrumentation and must not
+  /// outlive the pass manager that created it. It must be stopped on the
+  /// thread that started it, and, since stopping a timer stops the most
+  /// recently started timer of that thread, timers must be stopped in the
+  /// reverse order in which they were started.
+  class ExternalTimer {
+  public:
+    virtual ~ExternalTimer() = default;
+
+    /// Start the timer. Has no effect if the timer is already running.
+    virtual void start() = 0;
+
+    /// Stop the timer. Has no effect if the timer is not running. A stopped
+    /// timer can be started again.
+    virtual void stop() = 0;
+  };
+
+  /// Get a timer that can be used to include things outside of the pass manager
+  /// in the timing statistics. If pass timing has not been enabled (see
+  /// `enableTiming`) when this is called, the returned timer does nothing.
+  /// Destroying a running timer stops it.
+  std::unique_ptr<ExternalTimer> getExternalTimer(Identifier name);
+  std::unique_ptr<ExternalTimer> getExternalTimer(StringRef name) {
+    return getExternalTimer(Identifier::get(name, context));
+  }
+
   /// Prompts the pass manager to print the statistics collected for each of the
   /// held passes after each call to 'run'.
   void
@@ -377,8 +407,9 @@
   /// A hash key used to detect when reinitialization is necessary.
   llvm::hash_code initializationKey;
 
-  /// Flag that specifies if pass timing is enabled.
-  bool passTiming : 1;
+  /// Non-owning pointer to the pass timing instrumentation, or null if pass
+  /// timing is not enabled.
+  detail::PassTiming *passTiming;
 
   /// Flag that specifies if the generated crash reproducer should be local.
   bool localReproducer : 1;
Index: mlir/lib/Pass/Pass.cpp
===================================================================
--- mlir/lib/Pass/Pass.cpp
+++ mlir/lib/Pass/Pass.cpp
@@ -854,7 +854,7 @@
                          StringRef operationName)
     : OpPassManager(Identifier::get(operationName, ctx), nesting), context(ctx),
       initializationKey(DenseMapInfo::getTombstoneKey()),
-      passTiming(false), localReproducer(false), verifyPasses(true) {}
+      passTiming(nullptr), localReproducer(false), verifyPasses(true) {}
 
 PassManager::~PassManager() {}
 
Index: mlir/lib/Pass/PassTiming.cpp
===================================================================
--- mlir/lib/Pass/PassTiming.cpp
+++ mlir/lib/Pass/PassTiming.cpp
@@ -158,7 +158,10 @@
   /// The type of timer this instance represents.
   TimerKind kind;
 };
+} // end anonymous namespace
 
+namespace mlir {
+namespace detail {
 struct PassTiming : public PassInstrumentation {
   PassTiming(std::unique_ptr config)
       : config(std::move(config)) {}
@@ -188,6 +191,13 @@
   /// Start a new timer for the given analysis.
   void startAnalysisTimer(StringRef name, TypeID id);
 
+  /// Start a new timer for the given external code path.
+  void startExternalTimer(Identifier name);
+
+  /// Stop an external timer. `name` is not checked: this stops the last active
+  /// timer of the calling thread, which must be the external timer for `name`.
+  void stopExternalTimer(Identifier name);
+
   /// Pop the last active timer for the current thread.
   Timer *popLastActiveTimer() {
     auto tid = llvm::get_threadid();
@@ -240,7 +250,8 @@
   DenseMap> pipelinesToMerge;
 };
 
-} // end anonymous namespace
+} // namespace detail
+} // namespace mlir
 
 void PassTiming::runBeforePipeline(Identifier name,
                                    const PipelineParentInfo &parentInfo) {
@@ -298,6 +309,13 @@
   timer->start();
 }
 
+/// Start a new timer for the given external code path.
+void PassTiming::startExternalTimer(Identifier name) {
+  Timer *timer = getTimer(name.getAsOpaquePointer(), TimerKind::PassOrAnalysis,
+                          [name] { return name.str(); });
+  timer->start();
+}
+
 /// Stop a pass timer.
 void PassTiming::runAfterPass(Pass *pass, Operation *) {
   Timer *timer = popLastActiveTimer();
@@ -319,6 +337,9 @@
   popLastActiveTimer()->stop();
 }
 
+/// Stop an external timer.
+void PassTiming::stopExternalTimer(Identifier) { popLastActiveTimer()->stop(); }
+
 /// Utility to print the timer heading information.
 static void printTimerHeader(raw_ostream &os, TimeRecord total) {
   os << "===" << std::string(73, '-') << "===\n";
@@ -468,6 +489,52 @@
     return;
   if (!config)
     config = std::make_unique();
-  addInstrumentation(std::make_unique(std::move(config)));
-  passTiming = true;
+  auto pt = std::make_unique<PassTiming>(std::move(config));
+  passTiming = pt.get();
+  addInstrumentation(std::move(pt));
+}
+
+namespace {
+/// An opaque handle for an external timer. Returned from `getExternalTimer()`
+/// to allow the caller to include code paths outside the PM in the timing
+/// statistics.
+///
+/// `pt` is a non-owning pointer to the pass manager's timing instrumentation,
+/// or null if timing was not enabled when the handle was created, in which
+/// case all operations are no-ops.
+struct ExternalTimerImpl final : public PassManager::ExternalTimer {
+  PassTiming *pt;
+  Identifier name;
+  /// Whether `start()` was called without a matching `stop()` yet.
+  bool running = false;
+
+  ExternalTimerImpl(PassTiming *pt, Identifier name) : pt(pt), name(name) {}
+
+  /// Stop a still-running timer so that it is removed from the thread's
+  /// active timers before the handle goes away.
+  ~ExternalTimerImpl() override { stop(); }
+
+  void start() override {
+    if (pt && !running) {
+      pt->startExternalTimer(name);
+      running = true;
+    }
+  }
+
+  void stop() override {
+    if (pt && running) {
+      pt->stopExternalTimer(name);
+      // Reset the flag but keep `pt` so that the timer can be restarted.
+      running = false;
+    }
+  }
+};
+} // namespace
+
+/// Get a timer that can be used to include things outside of the pass manager
+/// in the timing statistics. The returned timer is a no-op if pass timing is
+/// not enabled.
+std::unique_ptr<PassManager::ExternalTimer>
+PassManager::getExternalTimer(Identifier name) {
+  return std::make_unique<ExternalTimerImpl>(passTiming, name);
 }