diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -156,6 +156,12 @@
   /// instances. This should not be used directly.
   StorageUniquer &getAttributeUniquer();
 
+  /// These APIs track whether the context is used in a multithreaded
+  /// environment (e.g. by the PassManager). In debug builds they enable
+  /// assertions on misuses of some APIs; otherwise they have no effect.
+  void enterMultiThreadedExecution();
+  void exitMultiThreadedExecution();
+
 private:
   const std::unique_ptr<MLIRContextImpl> impl;
 
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -264,6 +264,13 @@
   /// Enable support for multi-threading within MLIR.
   bool threadingIsEnabled = true;
 
+  /// Track if we are currently executing in a threaded execution environment
+  /// (like the pass-manager): this is only a debugging feature to help reduce
+  /// the chances of data races on some context APIs.
+#ifndef NDEBUG
+  std::atomic<int> multiThreadedExecutionContext{0};
+#endif
+
   /// If the operation should be attached to diagnostics printed via the
   /// Operation::emit methods.
   bool printOpOnDiagnostic = true;
@@ -487,6 +494,15 @@
   if (!dialect) {
     LLVM_DEBUG(llvm::dbgs()
                << "Load new dialect in Context" << dialectNamespace);
+#ifndef NDEBUG
+    if (impl.multiThreadedExecutionContext != 0) {
+      llvm::errs() << "Loading a dialect (" << dialectNamespace
+                   << ") while in a multi-threaded execution context (maybe "
+                      "the PassManager): this can indicate a "
+                      "missing `dependentDialects` in a pass for example.\n";
+      abort();
+    }
+#endif
     dialect = ctor();
     assert(dialect && "dialect ctor failed");
     return dialect.get();
@@ -527,6 +543,17 @@
   impl->typeUniquer.disableMultithreading(disable);
 }
 
+void MLIRContext::enterMultiThreadedExecution() {
+#ifndef NDEBUG
+  ++impl->multiThreadedExecutionContext;
+#endif
+}
+void MLIRContext::exitMultiThreadedExecution() {
+#ifndef NDEBUG
+  --impl->multiThreadedExecutionContext;
+#endif
+}
+
 /// Return true if we should attach the operation to diagnostics emitted via
 /// Operation::emit.
 bool MLIRContext::shouldPrintOpOnDiagnostic() {
@@ -583,6 +610,9 @@
          "op name doesn't start with dialect namespace");
   assert(&opInfo.dialect == this && "Dialect object mismatch");
   auto &impl = context->getImpl();
+  assert(impl.multiThreadedExecutionContext == 0 &&
+         "Registering a new operation kind while in a multi-threaded execution "
+         "context");
   StringRef opName = opInfo.name;
   if (!impl.registeredOperations.insert({opName, std::move(opInfo)}).second) {
     llvm::errs() << "error: operation named '" << opInfo.name
@@ -593,6 +623,9 @@
 
 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
   auto &impl = context->getImpl();
+  assert(impl.multiThreadedExecutionContext == 0 &&
+         "Registering a new type kind while in a multi-threaded execution "
+         "context");
   auto *newInfo =
       new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>())
           AbstractType(std::move(typeInfo));
@@ -602,6 +635,9 @@
 
 void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
   auto &impl = context->getImpl();
+  assert(impl.multiThreadedExecutionContext == 0 &&
+         "Registering a new attribute kind while in a multi-threaded execution "
+         "context");
   auto *newInfo =
       new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>())
           AbstractAttribute(std::move(attrInfo));
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -751,12 +751,18 @@
   // Construct an analysis manager for the pipeline.
   ModuleAnalysisManager am(module, instrumentor.get());
 
+  // Notify the context that we start running a pipeline for bookkeeping.
+  module.getContext()->enterMultiThreadedExecution();
+
   // If reproducer generation is enabled, run the pass manager with crash
   // handling enabled.
   LogicalResult result = crashReproducerFileName
                              ? runWithCrashRecovery(module, am)
                              : OpPassManager::run(module, am);
 
+  // Notify the context that the run is done.
+  module.getContext()->exitMultiThreadedExecution();
+
   // Dump all of the pass statistics if necessary.
   if (passStatisticsMode)
     dumpStatistics();
diff --git a/mlir/test/lib/Transforms/TestConvertCallOp.cpp b/mlir/test/lib/Transforms/TestConvertCallOp.cpp
--- a/mlir/test/lib/Transforms/TestConvertCallOp.cpp
+++ b/mlir/test/lib/Transforms/TestConvertCallOp.cpp
@@ -34,6 +34,10 @@
 class TestConvertCallOp
     : public PassWrapper<TestConvertCallOp, OperationPass<ModuleOp>> {
 public:
+  void getDependentDialects(DialectRegistry &registry) const final {
+    registry.insert<LLVM::LLVMDialect>();
+  }
+
   void runOnOperation() override {
     ModuleOp m = getOperation();
 