diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -70,7 +70,15 @@
   /// (Operations, Types, Attributes), other than dialect that exists in the
   /// input. For example, a pass that converts from Linalg to Affine would
   /// register the Affine dialect but does not need to register Linalg.
   virtual void getDependentDialects(DialectRegistry &registry) const {}
+  /// Variant of the above that also receives the MLIRContext, which may be
+  /// used to identify dependent dialects based on registered objects such as
+  /// interfaces and dialect extensions. Defaults to the context-free overload,
+  /// which in turn does not consult this one.
+  virtual void getDependentDialects(DialectRegistry &registry,
+                                    MLIRContext &context) const {
+    getDependentDialects(registry);
+  }
 
   /// Return the command line argument used when registering this pass. Return
   /// an empty string if one does not exist.
diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -157,7 +157,8 @@
   /// Register dependent dialects for the current pass manager.
   /// This is forwarding to every pass in this PassManager, see the
   /// documentation for the same method on the Pass class.
-  void getDependentDialects(DialectRegistry &dialects) const;
+  void getDependentDialects(DialectRegistry &dialects,
+                            MLIRContext &context) const;
 
   /// Enable or disable the implicit nesting on this particular PassManager.
   /// This will also apply to any newly nested PassManager built from this
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -380,13 +380,15 @@
 }
 
 static void registerDialectsForPipeline(const OpPassManager &pm,
-                                        DialectRegistry &dialects) {
+                                        DialectRegistry &dialects,
+                                        MLIRContext &context) {
   for (const Pass &pass : pm.getPasses())
-    pass.getDependentDialects(dialects);
+    pass.getDependentDialects(dialects, context);
 }
 
-void OpPassManager::getDependentDialects(DialectRegistry &dialects) const {
-  registerDialectsForPipeline(*this, dialects);
+void OpPassManager::getDependentDialects(DialectRegistry &dialects,
+                                         MLIRContext &context) const {
+  registerDialectsForPipeline(*this, dialects, context);
 }
 
 void OpPassManager::setNesting(Nesting nesting) { impl->nesting = nesting; }
@@ -565,9 +567,10 @@
   mgrs.emplace_back(std::move(mgr));
 }
 
-void OpToOpPassAdaptor::getDependentDialects(DialectRegistry &dialects) const {
+void OpToOpPassAdaptor::getDependentDialects(DialectRegistry &dialects,
+                                             MLIRContext &context) const {
   for (auto &pm : mgrs)
-    pm.getDependentDialects(dialects);
+    pm.getDependentDialects(dialects, context);
 }
 
 LogicalResult OpToOpPassAdaptor::tryMergeInto(MLIRContext *ctx,
@@ -788,7 +791,7 @@
 
   // Register all dialects for the current pipeline.
   DialectRegistry dependentDialects;
-  getDependentDialects(dependentDialects);
+  getDependentDialects(dependentDialects, *context);
   context->appendDialectRegistry(dependentDialects);
   for (StringRef name : dependentDialects.getDialectNames())
     context->getOrLoadDialect(name);
diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h
--- a/mlir/lib/Pass/PassDetail.h
+++ b/mlir/lib/Pass/PassDetail.h
@@ -45,7 +45,8 @@
 
   /// Populate the set of dependent dialects for the passes in the current
   /// adaptor.
-  void getDependentDialects(DialectRegistry &dialects) const override;
+  void getDependentDialects(DialectRegistry &dialects,
+                            MLIRContext &context) const override;
 
   /// Return the async pass managers held by this parallel adaptor.
   MutableArrayRef<SmallVector<OpPassManager, 1>> getParallelPassManagers() {
diff --git a/mlir/test/lib/Pass/TestDynamicPipeline.cpp b/mlir/test/lib/Pass/TestDynamicPipeline.cpp
--- a/mlir/test/lib/Pass/TestDynamicPipeline.cpp
+++ b/mlir/test/lib/Pass/TestDynamicPipeline.cpp
@@ -28,11 +28,12 @@
     return "Tests the dynamic pipeline feature by applying "
            "a pipeline on a selected set of functions";
   }
-  void getDependentDialects(DialectRegistry &registry) const override {
+  void getDependentDialects(DialectRegistry &registry,
+                            MLIRContext &context) const override {
     OpPassManager pm(ModuleOp::getOperationName(),
                      OpPassManager::Nesting::Implicit);
     (void)parsePassPipeline(pipeline, pm, llvm::errs());
-    pm.getDependentDialects(registry);
+    pm.getDependentDialects(registry, context);
   }
 
   TestDynamicPipelinePass() = default;
diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp
--- a/mlir/unittests/Pass/PassManagerTest.cpp
+++ b/mlir/unittests/Pass/PassManagerTest.cpp
@@ -84,7 +84,6 @@
   }
 }
 
-namespace {
 struct InvalidPass : Pass {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InvalidPass)
 
@@ -101,16 +100,15 @@
         *static_cast<const InvalidPass *>(this));
   }
 };
-} // namespace
 
 TEST(PassManagerTest, InvalidPass) {
   MLIRContext context;
   context.allowUnregisteredDialects();
 
-  // Create a module
+  // Create a module.
   OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
 
-  // Add a single "invalid_op" operation
+  // Add a single "invalid_op" operation.
   OpBuilder builder(&module->getBodyRegion());
   OperationState state(UnknownLoc::get(&context), "invalid_op");
   builder.insert(Operation::create(state));
@@ -141,4 +139,46 @@
   ASSERT_DEATH(pm.addPass(std::make_unique<InvalidPass>()), "");
 }
 
+/// A dummy pass that lists as dependent all dialects registered with the
+/// context.
+struct ContextuallyDependentPass
+    : PassWrapper<ContextuallyDependentPass, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ContextuallyDependentPass)
+
+  StringRef getName() const override {
+    return "test-contextually-dependent-pass";
+  }
+
+  void runOnOperation() override {}
+  void getDependentDialects(DialectRegistry &registry,
+                            MLIRContext &context) const override {
+    context.getDialectRegistry().appendTo(registry);
+  }
+};
+
+TEST(PassManagerTest, ContextuallyDependentPass) {
+  // Create an empty module.
+  MLIRContext context;
+  OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
+
+  // Run the pass using the pass manager. This loads all dialects declared as
+  // dependent by the pass. The pass declares as dependent every registered
+  // dialect. Since the Func dialect is not registered, it should not be loaded
+  // after the pass manager completes.
+  PassManager pm(&context);
+  pm.addPass(std::make_unique<ContextuallyDependentPass>());
+  LogicalResult result = pm.run(module.get());
+  ASSERT_TRUE(succeeded(result));
+  EXPECT_EQ(context.getLoadedDialect<func::FuncDialect>(), nullptr);
+
+  // Register the Func dialect with the context and run the same pass manager
+  // again. Now the dialect should have been loaded by the pass manager.
+  DialectRegistry registry;
+  registry.insert<func::FuncDialect>();
+  context.appendDialectRegistry(registry);
+  result = pm.run(module.get());
+  ASSERT_TRUE(succeeded(result));
+  EXPECT_NE(context.getLoadedDialect<func::FuncDialect>(), nullptr);
+}
+
 } // namespace