diff --git a/mlir/examples/toy/Ch5/toyc.cpp b/mlir/examples/toy/Ch5/toyc.cpp --- a/mlir/examples/toy/Ch5/toyc.cpp +++ b/mlir/examples/toy/Ch5/toyc.cpp @@ -136,10 +136,10 @@ } if (isLoweringToAffine) { - // Partially lower the toy dialect with a few cleanups afterwards. - pm.addPass(mlir::toy::createLowerToAffinePass()); - mlir::OpPassManager &optPM = pm.nest(); + + // Partially lower the toy dialect with a few cleanups afterwards. + optPM.addPass(mlir::toy::createLowerToAffinePass()); optPM.addPass(mlir::createCanonicalizerPass()); optPM.addPass(mlir::createCSEPass()); diff --git a/mlir/examples/toy/Ch6/toyc.cpp b/mlir/examples/toy/Ch6/toyc.cpp --- a/mlir/examples/toy/Ch6/toyc.cpp +++ b/mlir/examples/toy/Ch6/toyc.cpp @@ -150,10 +150,10 @@ } if (isLoweringToAffine) { - // Partially lower the toy dialect with a few cleanups afterwards. - pm.addPass(mlir::toy::createLowerToAffinePass()); - mlir::OpPassManager &optPM = pm.nest(); + + // Partially lower the toy dialect with a few cleanups afterwards. + optPM.addPass(mlir::toy::createLowerToAffinePass()); optPM.addPass(mlir::createCanonicalizerPass()); optPM.addPass(mlir::createCSEPass()); diff --git a/mlir/examples/toy/Ch7/toyc.cpp b/mlir/examples/toy/Ch7/toyc.cpp --- a/mlir/examples/toy/Ch7/toyc.cpp +++ b/mlir/examples/toy/Ch7/toyc.cpp @@ -151,10 +151,10 @@ } if (isLoweringToAffine) { - // Partially lower the toy dialect with a few cleanups afterwards. - pm.addPass(mlir::toy::createLowerToAffinePass()); - mlir::OpPassManager &optPM = pm.nest(); + + // Partially lower the toy dialect with a few cleanups afterwards. + optPM.addPass(mlir::toy::createLowerToAffinePass()); optPM.addPass(mlir::createCanonicalizerPass()); optPM.addPass(mlir::createCSEPass()); diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -48,8 +48,9 @@ /// other OpPassManagers or the top-level PassManager. class OpPassManager { public: - OpPassManager(Identifier name); - OpPassManager(StringRef name); + enum class Nesting { Implicit, Explicit }; + OpPassManager(Identifier name, Nesting nesting); + OpPassManager(StringRef name, Nesting nesting); OpPassManager(OpPassManager &&rhs); OpPassManager(const OpPassManager &rhs); ~OpPassManager(); @@ -150,7 +151,7 @@ class PassManager : public OpPassManager { public: // If verifyPasses is true, the verifier is run after each pass. - PassManager(MLIRContext *ctx); + PassManager(MLIRContext *ctx, Nesting nesting = Nesting::Explicit); ~PassManager(); /// Run the passes within this manager on the provided module. diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -91,9 +91,10 @@ namespace mlir { namespace detail { struct OpPassManagerImpl { - OpPassManagerImpl(Identifier identifier) - : name(identifier), identifier(identifier) {} - OpPassManagerImpl(StringRef name) : name(name) {} + OpPassManagerImpl(Identifier identifier, OpPassManager::Nesting nesting) + : name(identifier), identifier(identifier), nesting(nesting) {} + OpPassManagerImpl(StringRef name, OpPassManager::Nesting nesting) + : name(name), nesting(nesting) {} /// Merge the passes of this pass manager into the one provided. void mergeInto(OpPassManagerImpl &rhs); @@ -130,6 +131,10 @@ /// The set of passes to run as part of this pass manager. std::vector> passes; + + /// Control the implicit nesting of passes that mismatch the name set for this + /// OpPassManager. + OpPassManager::Nesting nesting; }; } // end namespace detail } // end namespace mlir @@ -142,14 +147,14 @@ } OpPassManager &OpPassManagerImpl::nest(Identifier nestedName) { - OpPassManager nested(nestedName); + OpPassManager nested(nestedName, nesting); auto *adaptor = new OpToOpPassAdaptor(std::move(nested)); addPass(std::unique_ptr(adaptor)); return adaptor->getPassManagers().front(); } OpPassManager &OpPassManagerImpl::nest(StringRef nestedName) { - OpPassManager nested(nestedName); + OpPassManager nested(nestedName, nesting); auto *adaptor = new OpToOpPassAdaptor(std::move(nested)); addPass(std::unique_ptr(adaptor)); return adaptor->getPassManagers().front(); @@ -157,10 +162,16 @@ void OpPassManagerImpl::addPass(std::unique_ptr pass) { // If this pass runs on a different operation than this pass manager, then - // implicitly nest a pass manager for this operation. + // implicitly nest a pass manager for this operation if enabled. auto passOpName = pass->getOpName(); - if (passOpName && passOpName != name) - return nest(*passOpName).addPass(std::move(pass)); + if (passOpName && passOpName != name) { + if (nesting == OpPassManager::Nesting::Implicit) + return nest(*passOpName).addPass(std::move(pass)); + llvm::report_fatal_error(llvm::Twine("Can't add pass '") + pass->getName() + + "' restricted to '" + *passOpName + + "' on a PassManager intended to run on '" + name + + "', did you intend to nest?"); + } passes.emplace_back(std::move(pass)); } @@ -240,14 +251,14 @@ // OpPassManager //===----------------------------------------------------------------------===// -OpPassManager::OpPassManager(Identifier name) - : impl(new OpPassManagerImpl(name)) {} -OpPassManager::OpPassManager(StringRef name) - : impl(new OpPassManagerImpl(name)) {} +OpPassManager::OpPassManager(Identifier name, Nesting nesting) + : impl(new OpPassManagerImpl(name, nesting)) {} +OpPassManager::OpPassManager(StringRef name, Nesting nesting) + : impl(new OpPassManagerImpl(name, nesting)) {} OpPassManager::OpPassManager(OpPassManager &&rhs) : impl(std::move(rhs.impl)) {} OpPassManager::OpPassManager(const OpPassManager &rhs) { *this = rhs; } OpPassManager &OpPassManager::operator=(const OpPassManager &rhs) { - impl.reset(new OpPassManagerImpl(rhs.impl->name)); + impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->nesting)); for (auto &pass : rhs.impl->passes) impl->passes.emplace_back(pass->clone()); return *this; @@ -784,8 +795,9 @@ // PassManager //===----------------------------------------------------------------------===// -PassManager::PassManager(MLIRContext *ctx) - : OpPassManager(Identifier::get(ModuleOp::getOperationName(), ctx)), +PassManager::PassManager(MLIRContext *ctx, Nesting nesting) + : OpPassManager(Identifier::get(ModuleOp::getOperationName(), ctx), + nesting), context(ctx), passTiming(false), localReproducer(false), verifyPasses(true) {} diff --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp --- a/mlir/lib/Support/MlirOptMain.cpp +++ b/mlir/lib/Support/MlirOptMain.cpp @@ -58,7 +58,7 @@ return failure(); // Apply any pass manager command line options. - PassManager pm(context); + PassManager pm(context, OpPassManager::Nesting::Implicit); pm.enableVerifier(verifyPasses); applyPassManagerCLOptions(pm); diff --git a/mlir/test/lib/Transforms/TestConvVectorization.cpp b/mlir/test/lib/Transforms/TestConvVectorization.cpp --- a/mlir/test/lib/Transforms/TestConvVectorization.cpp +++ b/mlir/test/lib/Transforms/TestConvVectorization.cpp @@ -100,7 +100,7 @@ // Programmatic controlled lowering of linalg.copy and linalg.fill. PassManager pm(context); - pm.addPass(createConvertLinalgToLoopsPass()); + pm.addNestedPass(createConvertLinalgToLoopsPass()); if (failed(pm.run(module))) llvm_unreachable("Unexpected failure in linalg to loops pass."); diff --git a/mlir/test/lib/Transforms/TestDynamicPipeline.cpp b/mlir/test/lib/Transforms/TestDynamicPipeline.cpp --- a/mlir/test/lib/Transforms/TestDynamicPipeline.cpp +++ b/mlir/test/lib/Transforms/TestDynamicPipeline.cpp @@ -25,7 +25,8 @@ : public PassWrapper> { public: void getDependentDialects(DialectRegistry ®istry) const override { - OpPassManager pm(ModuleOp::getOperationName()); + OpPassManager pm(ModuleOp::getOperationName(), + OpPassManager::Nesting::Implicit); parsePassPipeline(pipeline, pm, llvm::errs()); pm.getDependentDialects(registry); } @@ -54,7 +55,8 @@ } if (!pm) { pm = std::make_unique( - getOperation()->getName().getIdentifier()); + getOperation()->getName().getIdentifier(), + OpPassManager::Nesting::Implicit); parsePassPipeline(pipeline, *pm, llvm::errs()); } diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp --- a/mlir/unittests/Pass/PassManagerTest.cpp +++ b/mlir/unittests/Pass/PassManagerTest.cpp @@ -108,7 +108,7 @@ // Instantiate and run our pass. PassManager pm(&context); - pm.addPass(std::make_unique()); + pm.nest("invalid_op").addPass(std::make_unique()); LogicalResult result = pm.run(module.get()); EXPECT_TRUE(failed(result)); ASSERT_TRUE(diagnostic.get() != nullptr);