Assert that MLIRContext configuration is not mutated while in a
multi-threaded execution context.

The MLIRContext setters appendDialectRegistry, allowUnregisteredDialects,
printOpOnDiagnostic, printStackTraceOnDiagnostic and the setter that
updates `threadingIsEnabled` now assert that
`impl->multiThreadedExecutionContext == 0`. In the threading setter the
check is placed after the early `isThreadingGloballyDisabled()` return.

The Linalg test passes touched below used to construct a fresh
PassManager and run it on a ModuleOp from within the pass. They now build
an OpPassManager anchored on the op they operate on and execute it
through the pass's runPipeline() helper.

diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -467,6 +467,9 @@
 //===----------------------------------------------------------------------===//
 
 void MLIRContext::appendDialectRegistry(const DialectRegistry &registry) {
+  assert(impl->multiThreadedExecutionContext == 0 &&
+         "appending to the MLIRContext dialect registry while in a "
+         "multi-threaded execution context");
   registry.appendTo(impl->dialectsRegistry);
 
   // For the already loaded dialects, register the interfaces immediately.
@@ -581,6 +584,9 @@
 }
 
 void MLIRContext::allowUnregisteredDialects(bool allowing) {
+  assert(impl->multiThreadedExecutionContext == 0 &&
+         "changing MLIRContext `allow-unregistered-dialects` configuration "
+         "while in a multi-threaded execution context");
   impl->allowUnregisteredDialects = allowing;
 }
 
@@ -595,6 +601,9 @@
   // --mlir-disable-threading
   if (isThreadingGloballyDisabled())
     return;
+  assert(impl->multiThreadedExecutionContext == 0 &&
+         "changing MLIRContext `disable-threading` configuration while "
+         "in a multi-threaded execution context");
 
   impl->threadingIsEnabled = !disable;
 
@@ -658,6 +667,9 @@
 /// Set the flag specifying if we should attach the operation to diagnostics
 /// emitted via Operation::emit.
 void MLIRContext::printOpOnDiagnostic(bool enable) {
+  assert(impl->multiThreadedExecutionContext == 0 &&
+         "changing MLIRContext `print-op-on-diagnostic` configuration while in "
+         "a multi-threaded execution context");
   impl->printOpOnDiagnostic = enable;
 }
 
@@ -670,6 +682,9 @@
 /// Set the flag specifying if we should attach the current stacktrace when
 /// emitting diagnostics.
 void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
+  assert(impl->multiThreadedExecutionContext == 0 &&
+         "changing MLIRContext `print-stacktrace-on-diagnostic` configuration "
+         "while in a multi-threaded execution context");
   impl->printStackTraceOnDiagnostic = enable;
 }
 
diff --git a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
@@ -73,10 +73,10 @@
       linalg::getLinalgTilingCanonicalizationPatterns(context);
   scf::populateSCFForLoopCanonicalizationPatterns(stage2Patterns);
 
-  auto stage3Transforms = [](Operation *op) {
-    PassManager pm(op->getContext());
+  auto stage3Transforms = [this](Operation *op) {
+    OpPassManager pm(ModuleOp::getOperationName());
     pm.addPass(createLoopInvariantCodeMotionPass());
-    if (failed(pm.run(cast<ModuleOp>(op))))
+    if (failed(runPipeline(pm, op)))
       llvm_unreachable("Unexpected failure in cleanup pass pipeline.");
     op->walk([](FuncOp func) {
       promoteSingleIterationLoops(func);
@@ -105,9 +105,9 @@
   (void)applyPatternsAndFoldGreedily(module, std::move(vectorTransferPatterns));
 
   // Programmatic controlled lowering of linalg.copy and linalg.fill.
-  PassManager pm(context);
+  OpPassManager pm(ModuleOp::getOperationName());
   pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
-  if (failed(pm.run(module)))
+  if (failed(runPipeline(pm, module)))
     llvm_unreachable("Unexpected failure in linalg to loops pass.");
 
   // Programmatic controlled lowering of vector.contract only.
NOTE(review): the first context line of the hunk below reads
`patterns.add(context);`. Its `<...>` template argument list was
presumably stripped when this patch was extracted, as happened to the
template argument of `getParentOfType` in the removed lines (restored
here as `<ModuleOp>`). Restore the argument from the target file and
confirm the pattern type there; otherwise this hunk's context will not
match.

diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -239,14 +239,13 @@
     patterns.add(context);
     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
     FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+    OpPassManager pm(FuncOp::getOperationName());
+    pm.addPass(createLoopInvariantCodeMotionPass());
+    pm.addPass(createCanonicalizerPass());
+    pm.addPass(createCSEPass());
     do {
       (void)applyPatternsAndFoldGreedily(getFunction(), frozenPatterns);
-      PassManager pm(context);
-      pm.addPass(createLoopInvariantCodeMotionPass());
-      pm.addPass(createCanonicalizerPass());
-      pm.addPass(createCSEPass());
-      LogicalResult res = pm.run(getFunction()->getParentOfType<ModuleOp>());
-      if (failed(res))
+      if (failed(runPipeline(pm, getFunction())))
         this->signalPassFailure();
     } while (succeeded(fuseLinalgOpsGreedily(getFunction())));
   }