diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -156,6 +156,13 @@
   /// instances. This should not be used directly.
   StorageUniquer &getAttributeUniquer();
 
+  /// These APIs track whether the context is currently used in a
+  /// multi-threaded execution environment (for example the PassManager). They
+  /// have no effect other than enabling debug-build checks that catch misuse
+  /// of some APIs, such as loading a new dialect. Calls must be balanced.
+  void enterMultiThreadedExecution();
+  void exitMultiThreadedExecution();
+
 private:
   const std::unique_ptr<MLIRContextImpl> impl;
 
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -264,6 +264,11 @@
   /// Enable support for multi-threading within MLIR.
   bool threadingIsEnabled = true;
 
+  /// Number of multi-threaded execution environments (like the PassManager)
+  /// currently running on this context. This is only a debugging aid that
+  /// helps reduce the chances of data races on some context APIs.
+  std::atomic<int> multiThreadedExecutionContext{0};
+
   /// If the operation should be attached to diagnostics printed via the
   /// Operation::emit methods.
   bool printOpOnDiagnostic = true;
@@ -487,6 +492,15 @@
   if (!dialect) {
     LLVM_DEBUG(llvm::dbgs()
                << "Load new dialect in Context" << dialectNamespace);
+#ifndef NDEBUG
+    if (impl.multiThreadedExecutionContext != 0) {
+      llvm::errs() << "Loading a dialect (" << dialectNamespace
+                   << ") while in a multi-threaded execution context (maybe "
+                      "the PassManager): this can indicate a "
+                      "missing `dependentDialects` in a pass for example.\n";
+      abort();
+    }
+#endif
     dialect = ctor();
     assert(dialect && "dialect ctor failed");
     return dialect.get();
@@ -527,6 +541,13 @@
   impl->typeUniquer.disableMultithreading(disable);
 }
 
+void MLIRContext::enterMultiThreadedExecution() {
+  ++impl->multiThreadedExecutionContext;
+}
+void MLIRContext::exitMultiThreadedExecution() {
+  --impl->multiThreadedExecutionContext;
+}
+
 /// Return true if we should attach the operation to diagnostics emitted via
 /// Operation::emit.
 bool MLIRContext::shouldPrintOpOnDiagnostic() {
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -751,12 +751,18 @@
   // Construct an analysis manager for the pipeline.
   ModuleAnalysisManager am(module, instrumentor.get());
 
+  // Notify the context that we start running a pipeline, for bookkeeping.
+  module.getContext()->enterMultiThreadedExecution();
+
   // If reproducer generation is enabled, run the pass manager with crash
   // handling enabled.
   LogicalResult result = crashReproducerFileName
                              ? runWithCrashRecovery(module, am)
                              : OpPassManager::run(module, am);
 
+  // Notify the context that the run is done.
+  module.getContext()->exitMultiThreadedExecution();
+
   // Dump all of the pass statistics if necessary.
   if (passStatisticsMode)
     dumpStatistics();
diff --git a/mlir/test/lib/Transforms/TestConvertCallOp.cpp b/mlir/test/lib/Transforms/TestConvertCallOp.cpp
--- a/mlir/test/lib/Transforms/TestConvertCallOp.cpp
+++ b/mlir/test/lib/Transforms/TestConvertCallOp.cpp
@@ -34,6 +34,10 @@
 class TestConvertCallOp
     : public PassWrapper<TestConvertCallOp, OperationPass<ModuleOp>> {
 public:
+  void getDependentDialects(DialectRegistry &registry) const final {
+    registry.insert<LLVM::LLVMDialect>();
+  }
+
   void runOnOperation() override {
     ModuleOp m = getOperation();
 