diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -351,6 +351,9 @@ //===----------------------------------------------------------------------===// void MLIRContext::appendDialectRegistry(const DialectRegistry &registry) { + assert(impl->multiThreadedExecutionContext == 0 && + "appending to the MLIRContext dialect registry while in a " + "multi-threaded execution context"); registry.appendTo(impl->dialectsRegistry); // For the already loaded dialects, register the interfaces immediately. @@ -467,6 +470,9 @@ } void MLIRContext::allowUnregisteredDialects(bool allowing) { + assert(impl->multiThreadedExecutionContext == 0 && + "changing MLIRContext `allow-unregistered-dialects` configuration " + "while in a multi-threaded execution context"); impl->allowUnregisteredDialects = allowing; } @@ -481,6 +487,9 @@ // --mlir-disable-threading if (isThreadingGloballyDisabled()) return; + assert(impl->multiThreadedExecutionContext == 0 && + "changing MLIRContext `disable-threading` configuration while " + "in a multi-threaded execution context"); impl->threadingIsEnabled = !disable; @@ -554,6 +563,9 @@ /// Set the flag specifying if we should attach the operation to diagnostics /// emitted via Operation::emit. void MLIRContext::printOpOnDiagnostic(bool enable) { + assert(impl->multiThreadedExecutionContext == 0 && + "changing MLIRContext `print-op-on-diagnostic` configuration while in " + "a multi-threaded execution context"); impl->printOpOnDiagnostic = enable; } @@ -566,6 +578,9 @@ /// Set the flag specifying if we should attach the current stacktrace when /// emitting diagnostics. 
void MLIRContext::printStackTraceOnDiagnostic(bool enable) { + assert(impl->multiThreadedExecutionContext == 0 && + "changing MLIRContext `print-stacktrace-on-diagnostic` configuration " + "while in a multi-threaded execution context"); impl->printStackTraceOnDiagnostic = enable; } diff --git a/mlir/lib/Reducer/OptReductionPass.cpp b/mlir/lib/Reducer/OptReductionPass.cpp --- a/mlir/lib/Reducer/OptReductionPass.cpp +++ b/mlir/lib/Reducer/OptReductionPass.cpp @@ -42,7 +42,7 @@ ModuleOp module = this->getOperation(); ModuleOp moduleVariant = module.clone(); - PassManager passManager(module.getContext()); + OpPassManager passManager("builtin.module"); if (failed(parsePassPipeline(optPass, passManager))) { module.emitError() << "\nfailed to parse pass pipeline"; return signalPassFailure(); @@ -54,7 +54,13 @@ return signalPassFailure(); } - if (failed(passManager.run(moduleVariant))) { + // Temporarily push the variant under the main module and execute the pipeline + // on it. + module.getBody()->push_back(moduleVariant); + LogicalResult pipelineResult = runPipeline(passManager, moduleVariant); + moduleVariant->remove(); + + if (failed(pipelineResult)) { module.emitError() << "\nfailed to run pass pipeline"; return signalPassFailure(); } diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp @@ -241,14 +241,13 @@ patterns.add(context); scf::populateSCFForLoopCanonicalizationPatterns(patterns); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + OpPassManager pm(FuncOp::getOperationName()); + pm.addPass(createLoopInvariantCodeMotionPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); do { (void)applyPatternsAndFoldGreedily(getFunction(), frozenPatterns); - PassManager pm(context); - 
pm.addPass(createLoopInvariantCodeMotionPass()); - pm.addPass(createCanonicalizerPass()); - pm.addPass(createCSEPass()); - LogicalResult res = pm.run(getFunction()->getParentOfType<ModuleOp>()); - if (failed(res)) + if (failed(runPipeline(pm, getFunction()))) this->signalPassFailure(); } while (succeeded(fuseLinalgOpsGreedily(getFunction()))); }