diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -23,6 +23,7 @@ class InFlightDiagnostic; class Location; class MLIRContextImpl; +class PassManager; class StorageUniquer; DialectRegistry &getGlobalDialectRegistry(); @@ -169,6 +170,12 @@ MLIRContext(const MLIRContext &) = delete; void operator=(const MLIRContext &) = delete; + + /// These APIs track whether a PassManager is running in order to + /// detect misuse of some APIs. + friend class PassManager; + void passManagerStart(); + void passManagerExit(); }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -264,6 +264,11 @@ /// Enable support for multi-threading within MLIR. bool threadingIsEnabled = true; + /// Track if we are currently executing a threaded pass-manager: this is only + /// a debugging feature to help reduce the chances of data races on some + /// context APIs. + std::atomic<int> passManagersRunning{0}; + /// If the operation should be attached to diagnostics printed via the /// Operation::emit methods. bool printOpOnDiagnostic = true; @@ -487,6 +492,8 @@ if (!dialect) { LLVM_DEBUG(llvm::dbgs() << "Load new dialect in Context" << dialectNamespace); + assert(impl.passManagersRunning == 0 && + "Loading a dialect while the PassManager is running"); dialect = ctor(); assert(dialect && "dialect ctor failed"); return dialect.get(); @@ -527,6 +534,9 @@ impl->typeUniquer.disableMultithreading(disable); } +void MLIRContext::passManagerStart() { ++impl->passManagersRunning; } +void MLIRContext::passManagerExit() { --impl->passManagersRunning; } + /// Return true if we should attach the operation to diagnostics emitted via /// Operation::emit.
bool MLIRContext::shouldPrintOpOnDiagnostic() { diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -751,12 +751,18 @@ // Construct an analysis manager for the pipeline. ModuleAnalysisManager am(module, instrumentor.get()); + // Notify the context that we start running a pipeline for bookkeeping. + module.getContext()->passManagerStart(); + // If reproducer generation is enabled, run the pass manager with crash // handling enabled. LogicalResult result = crashReproducerFileName ? runWithCrashRecovery(module, am) : OpPassManager::run(module, am); + // Notify the context that the run is done. + module.getContext()->passManagerExit(); + // Dump all of the pass statistics if necessary. if (passStatisticsMode) dumpStatistics(); diff --git a/mlir/test/lib/Transforms/TestConvertCallOp.cpp b/mlir/test/lib/Transforms/TestConvertCallOp.cpp --- a/mlir/test/lib/Transforms/TestConvertCallOp.cpp +++ b/mlir/test/lib/Transforms/TestConvertCallOp.cpp @@ -34,6 +34,10 @@ class TestConvertCallOp : public PassWrapper<TestConvertCallOp, OperationPass<ModuleOp>> { public: + void getDependentDialects(DialectRegistry &registry) const final { + registry.insert(); + } + void runOnOperation() override { ModuleOp m = getOperation();