diff --git a/mlir/include/mlir/IR/DialectRegistry.h b/mlir/include/mlir/IR/DialectRegistry.h
--- a/mlir/include/mlir/IR/DialectRegistry.h
+++ b/mlir/include/mlir/IR/DialectRegistry.h
@@ -212,6 +212,12 @@
     addExtension(std::make_unique<Extension>(std::move(extensionFn)));
   }
 
+  /// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs'
+  /// contains all of the components of this registry. The check is
+  /// conservative: a registry holding any extensions is never reported as a
+  /// subset, and only the keys of the dialect map are compared.
+  bool isSubsetOf(const DialectRegistry &rhs) const;
+
 private:
   MapTy registry;
   std::vector<std::unique_ptr<DialectExtensionBase>> extensions;
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -228,3 +228,13 @@
   for (const auto &extension : extensions)
     applyExtension(*extension);
 }
+
+bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
+  // Treat any extensions conservatively: they are not compared against 'rhs'.
+  if (!extensions.empty())
+    return false;
+  // Check that every dialect key of this registry is also present in 'rhs'.
+  return llvm::all_of(registry, [&](const auto &it) {
+    return rhs.registry.count(it.first) != 0;
+  });
+}
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -355,6 +355,12 @@
 //===----------------------------------------------------------------------===//
 
 void MLIRContext::appendDialectRegistry(const DialectRegistry &registry) {
+  // Appending nothing new is a no-op, and stays legal while multi-threaded.
+  if (registry.isSubsetOf(impl->dialectsRegistry))
+    return;
+  assert(impl->multiThreadedExecutionContext == 0 &&
+         "appending to the MLIRContext dialect registry while in a "
+         "multi-threaded execution context");
   registry.appendTo(impl->dialectsRegistry);
 
   // For the already loaded dialects, apply any possible extensions immediately.
@@ -470,6 +476,9 @@
 }
 
 void MLIRContext::allowUnregisteredDialects(bool allowing) {
+  assert(impl->multiThreadedExecutionContext == 0 &&
+         "changing MLIRContext `allow-unregistered-dialects` configuration "
+         "while in a multi-threaded execution context");
   impl->allowUnregisteredDialects = allowing;
 }
 
@@ -484,6 +493,9 @@
   // --mlir-disable-threading
   if (isThreadingGloballyDisabled())
     return;
+  assert(impl->multiThreadedExecutionContext == 0 &&
+         "changing MLIRContext `disable-threading` configuration while "
+         "in a multi-threaded execution context");
 
   impl->threadingIsEnabled = !disable;
 
@@ -557,6 +569,9 @@
 /// Set the flag specifying if we should attach the operation to diagnostics
 /// emitted via Operation::emit.
 void MLIRContext::printOpOnDiagnostic(bool enable) {
+  assert(impl->multiThreadedExecutionContext == 0 &&
+         "changing MLIRContext `print-op-on-diagnostic` configuration while in "
+         "a multi-threaded execution context");
   impl->printOpOnDiagnostic = enable;
 }
 
@@ -569,6 +584,9 @@
 /// Set the flag specifying if we should attach the current stacktrace when
 /// emitting diagnostics.
 void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
+  assert(impl->multiThreadedExecutionContext == 0 &&
+         "changing MLIRContext `print-stacktrace-on-diagnostic` configuration "
+         "while in a multi-threaded execution context");
   impl->printStackTraceOnDiagnostic = enable;
 }
 
diff --git a/mlir/lib/Reducer/OptReductionPass.cpp b/mlir/lib/Reducer/OptReductionPass.cpp
--- a/mlir/lib/Reducer/OptReductionPass.cpp
+++ b/mlir/lib/Reducer/OptReductionPass.cpp
@@ -42,7 +42,7 @@
   ModuleOp module = this->getOperation();
   ModuleOp moduleVariant = module.clone();
 
-  PassManager passManager(module.getContext());
+  OpPassManager passManager("builtin.module");
   if (failed(parsePassPipeline(optPass, passManager))) {
     module.emitError() << "\nfailed to parse pass pipeline";
     return signalPassFailure();
@@ -54,7 +54,16 @@
     return signalPassFailure();
   }
 
-  if (failed(passManager.run(moduleVariant))) {
+  // Temporarily push the variant under the main module and execute the pipeline
+  // on it: a dynamic pipeline may only run on the current operation or on an
+  // operation nested under it.
+  module.getBody()->push_back(moduleVariant);
+  LogicalResult pipelineResult = runPipeline(passManager, moduleVariant);
+  moduleVariant->remove();
+
+  if (failed(pipelineResult)) {
+    // The detached variant is no longer owned by any block; free it.
+    moduleVariant->erase();
     module.emitError() << "\nfailed to run pass pipeline";
     return signalPassFailure();
   }
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -256,13 +256,17 @@
     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
     FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+    // Build the cleanup pipeline once and reuse it on every fusion iteration.
+    OpPassManager pm(FuncOp::getOperationName());
+    pm.addPass(createLoopInvariantCodeMotionPass());
+    pm.addPass(createCanonicalizerPass());
+    pm.addPass(createCSEPass());
     do {
       (void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns);
-      PassManager pm(context);
-      pm.addPass(createLoopInvariantCodeMotionPass());
-      pm.addPass(createCanonicalizerPass());
-      pm.addPass(createCSEPass());
-      LogicalResult res = pm.run(getOperation()->getParentOfType<ModuleOp>());
-      if (failed(res))
-        this->signalPassFailure();
+      // Once the pipeline fails the pass has failed; stop instead of fusing
+      // further and re-running the pipeline on the same IR.
+      if (failed(runPipeline(pm, getOperation()))) {
+        this->signalPassFailure();
+        return;
+      }
     } while (succeeded(fuseLinalgOpsGreedily(getOperation())));
   }