diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -48,8 +48,8 @@ /// other OpPassManagers or the top-level PassManager. class OpPassManager { public: - OpPassManager(Identifier name, bool verifyPasses); - OpPassManager(StringRef name, bool verifyPasses); + OpPassManager(Identifier name); + OpPassManager(StringRef name); OpPassManager(OpPassManager &&rhs); OpPassManager(const OpPassManager &rhs); ~OpPassManager(); @@ -149,8 +149,7 @@ /// The main pass manager and pipeline builder. class PassManager : public OpPassManager { public: - // If verifyPasses is true, the verifier is run after each pass. - PassManager(MLIRContext *ctx, bool verifyPasses = true); + PassManager(MLIRContext *ctx); ~PassManager(); /// Run the passes within this manager on the provided module. @@ -168,6 +167,9 @@ void enableCrashReproducerGeneration(StringRef outputFile, bool genLocalReproducer = false); + /// Runs the verifier after each individual pass. + void enableVerifier(bool enabled = true); + //===--------------------------------------------------------------------===// // Instrumentations //===--------------------------------------------------------------------===// @@ -330,6 +332,9 @@ /// Flag that specifies if the generated crash reproducer should be local. bool localReproducer : 1; + + /// A flag that indicates if the IR should be verified in between passes. + bool verifyPasses : 1; }; /// Register a set of useful command-line options that can be used to configure diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -91,10 +91,9 @@ namespace mlir { namespace detail { struct OpPassManagerImpl { - OpPassManagerImpl(Identifier identifier, bool verifyPasses) - : name(identifier), identifier(identifier), verifyPasses(verifyPasses) {} - OpPassManagerImpl(StringRef name, bool verifyPasses) - : name(name), verifyPasses(verifyPasses) {} + OpPassManagerImpl(Identifier identifier) + : name(identifier), identifier(identifier) {} + OpPassManagerImpl(StringRef name) : name(name) {} /// Merge the passes of this pass manager into the one provided. void mergeInto(OpPassManagerImpl &rhs); @@ -129,9 +128,6 @@ /// operation that passes of this pass manager operate on. Optional identifier; - /// Flag that specifies if the IR should be verified after each pass has run. - bool verifyPasses : 1; - /// The set of passes to run as part of this pass manager. std::vector> passes; }; @@ -146,14 +142,14 @@ } OpPassManager &OpPassManagerImpl::nest(Identifier nestedName) { - OpPassManager nested(nestedName, verifyPasses); + OpPassManager nested(nestedName); auto *adaptor = new OpToOpPassAdaptor(std::move(nested)); addPass(std::unique_ptr(adaptor)); return adaptor->getPassManagers().front(); } OpPassManager &OpPassManagerImpl::nest(StringRef nestedName) { - OpPassManager nested(nestedName, verifyPasses); + OpPassManager nested(nestedName); auto *adaptor = new OpToOpPassAdaptor(std::move(nested)); addPass(std::unique_ptr(adaptor)); return adaptor->getPassManagers().front(); @@ -167,8 +163,6 @@ return nest(*passOpName).addPass(std::move(pass)); passes.emplace_back(std::move(pass)); - if (verifyPasses) - passes.emplace_back(std::make_unique()); } void OpPassManagerImpl::coalesceAdjacentAdaptorPasses() { @@ -193,14 +187,6 @@ // Otherwise, merge into the existing adaptor and delete the current one. currentAdaptor->mergeInto(*lastAdaptor); it->reset(); - - // If the verifier is enabled, then next pass is a verifier run so - // drop it. Verifier passes are inserted after every pass, so this one - // would be a duplicate. - if (verifyPasses) { - assert(std::next(it) != e && isa(*std::next(it))); - (++it)->reset(); - } } else if (lastAdaptor && !isa(*it)) { // If this pass is not an adaptor and not a verifier pass, then coalesce // and forget any existing adaptor. @@ -254,14 +240,14 @@ // OpPassManager //===----------------------------------------------------------------------===// -OpPassManager::OpPassManager(Identifier name, bool verifyPasses) - : impl(new OpPassManagerImpl(name, verifyPasses)) {} -OpPassManager::OpPassManager(StringRef name, bool verifyPasses) - : impl(new OpPassManagerImpl(name, verifyPasses)) {} +OpPassManager::OpPassManager(Identifier name) + : impl(new OpPassManagerImpl(name)) {} +OpPassManager::OpPassManager(StringRef name) + : impl(new OpPassManagerImpl(name)) {} OpPassManager::OpPassManager(OpPassManager &&rhs) : impl(std::move(rhs.impl)) {} OpPassManager::OpPassManager(const OpPassManager &rhs) { *this = rhs; } OpPassManager &OpPassManager::operator=(const OpPassManager &rhs) { - impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->verifyPasses)); + impl.reset(new OpPassManagerImpl(rhs.impl->name)); for (auto &pass : rhs.impl->passes) impl->passes.emplace_back(pass->clone()); return *this; @@ -356,7 +342,7 @@ //===----------------------------------------------------------------------===// LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op, - AnalysisManager am) { + AnalysisManager am, bool verifyPasses) { if (!op->getName().getAbstractOperation()) return op->emitOpError() << "trying to schedule a pass on an unregistered operation"; @@ -368,18 +354,18 @@ // Initialize the pass state with a callback for the pass to dynamically // execute a pipeline on the currently visited operation. auto dynamic_pipeline_callback = - [op, &am](OpPassManager &pipeline, Operation *root) { - if (!op->isAncestor(root)) { - root->emitOpError() - << "Trying to schedule a dynamic pipeline on an " - "operation that isn't " - "nested under the current operation the pass is processing"; - return failure(); - } - AnalysisManager nestedAm = am.nest(root); - return OpToOpPassAdaptor::runPipeline(pipeline.getPasses(), root, - nestedAm); - }; + [op, &am, verifyPasses](OpPassManager &pipeline, + Operation *root) -> LogicalResult { + if (!op->isAncestor(root)) + return root->emitOpError() + << "Trying to schedule a dynamic pipeline on an " + "operation that isn't " + "nested under the current operation the pass is processing"; + + AnalysisManager nestedAm = am.nest(root); + return OpToOpPassAdaptor::runPipeline(pipeline.getPasses(), root, nestedAm, + verifyPasses); + }; pass->passState.emplace(op, am, dynamic_pipeline_callback); // Instrument before the pass has run. PassInstrumentor *pi = am.getPassInstrumentor(); @@ -387,13 +373,20 @@ pi->runBeforePass(pass, op); // Invoke the virtual runOnOperation method. - pass->runOnOperation(); + if (auto *adaptor = dyn_cast(pass)) + adaptor->runOnOperation(verifyPasses); + else + pass->runOnOperation(); + bool passFailed = pass->passState->irAndPassFailed.getInt(); // Invalidate any non preserved analyses. am.invalidate(pass->passState->preservedAnalyses); + // Run the verifier if this pass didn't fail already. + if (!passFailed && verifyPasses) + passFailed = failed(verify(op)); + // Instrument after the pass has run. - bool passFailed = pass->passState->irAndPassFailed.getInt(); if (pi) { if (passFailed) pi->runAfterPassFailed(pass, op); @@ -408,7 +401,7 @@ /// Run the given operation and analysis manager on a provided op pass manager. LogicalResult OpToOpPassAdaptor::runPipeline( iterator_range passes, Operation *op, - AnalysisManager am) { + AnalysisManager am, bool verifyPasses) { auto scope_exit = llvm::make_scope_exit([&] { // Clear out any computed operation analyses. These analyses won't be used // any more in this pipeline, and this helps reduce the current working set @@ -419,7 +412,7 @@ // Run the pipeline over the provided operation. for (Pass &pass : passes) - if (failed(run(&pass, op, am))) + if (failed(run(&pass, op, am, verifyPasses))) return failure(); return success(); @@ -485,16 +478,21 @@ return os.str(); } -/// Run the held pipeline over all nested operations. void OpToOpPassAdaptor::runOnOperation() { + llvm_unreachable( + "Unexpected call to Pass::runOnOperation() on OpToOpPassAdaptor"); +} + +/// Run the held pipeline over all nested operations. +void OpToOpPassAdaptor::runOnOperation(bool verifyPasses) { if (getContext().isMultithreadingEnabled()) - runOnOperationAsyncImpl(); + runOnOperationAsyncImpl(verifyPasses); else - runOnOperationImpl(); + runOnOperationImpl(verifyPasses); } /// Run this pass adaptor synchronously. -void OpToOpPassAdaptor::runOnOperationImpl() { +void OpToOpPassAdaptor::runOnOperationImpl(bool verifyPasses) { auto am = getAnalysisManager(); PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(), this}; @@ -511,7 +509,8 @@ // Run the held pipeline over the current operation. if (instrumentor) instrumentor->runBeforePipeline(opName, parentInfo); - auto result = runPipeline(mgr->getPasses(), &op, am.nest(&op)); + LogicalResult result = + runPipeline(mgr->getPasses(), &op, am.nest(&op), verifyPasses); if (instrumentor) instrumentor->runAfterPipeline(opName, parentInfo); @@ -532,7 +531,7 @@ } /// Run this pass adaptor synchronously. -void OpToOpPassAdaptor::runOnOperationAsyncImpl() { +void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) { AnalysisManager am = getAnalysisManager(); // Create the async executors if they haven't been created, or if the main @@ -594,7 +593,7 @@ if (instrumentor) instrumentor->runBeforePipeline(opName, parentInfo); auto pipelineResult = - runPipeline(pm->getPasses(), it.first, it.second); + runPipeline(pm->getPasses(), it.first, it.second, verifyPasses); if (instrumentor) instrumentor->runAfterPipeline(opName, parentInfo); @@ -741,15 +740,11 @@ // isolation. impl->splitAdaptorPasses(); - // If this is a local producer, run each of the passes individually. If the - // verifier is enabled, each pass will have a verifier after. This is included - // in the recovery run. - unsigned stride = impl->verifyPasses ? 2 : 1; + // If this is a local producer, run each of the passes individually. MutableArrayRef> passes = impl->passes; - for (unsigned i = 0, e = passes.size(); i != e; i += stride) { - if (failed(runWithCrashRecovery(passes.slice(i, stride), module, am))) + for (std::unique_ptr &pass : passes) + if (failed(runWithCrashRecovery(pass, module, am))) return failure(); - } return success(); } @@ -759,7 +754,7 @@ ModuleOp module, AnalysisManager am) { RecoveryReproducerContext context(passes, module, *crashReproducerFileName, !getContext()->isMultithreadingEnabled(), - impl->verifyPasses); + verifyPasses); // Safely invoke the passes within a recovery context. llvm::CrashRecoveryContext::Enable(); @@ -767,7 +762,7 @@ llvm::CrashRecoveryContext recoveryContext; recoveryContext.RunSafelyOnThread([&] { for (std::unique_ptr &pass : passes) - if (failed(OpToOpPassAdaptor::run(pass.get(), module, am))) + if (failed(OpToOpPassAdaptor::run(pass.get(), module, am, verifyPasses))) return; passManagerResult = success(); }); @@ -788,13 +783,15 @@ // PassManager //===----------------------------------------------------------------------===// -PassManager::PassManager(MLIRContext *ctx, bool verifyPasses) - : OpPassManager(Identifier::get(ModuleOp::getOperationName(), ctx), - verifyPasses), - context(ctx), passTiming(false), localReproducer(false) {} +PassManager::PassManager(MLIRContext *ctx) + : OpPassManager(Identifier::get(ModuleOp::getOperationName(), ctx)), + context(ctx), passTiming(false), localReproducer(false), + verifyPasses(true) {} PassManager::~PassManager() {} +void PassManager::enableVerifier(bool enabled) { verifyPasses = enabled; } + /// Run the passes within this manager on the provided module. LogicalResult PassManager::run(ModuleOp module) { // Before running, make sure to coalesce any adjacent pass adaptors in the @@ -814,10 +811,10 @@ // If reproducer generation is enabled, run the pass manager with crash // handling enabled. - LogicalResult result = - crashReproducerFileName - ? runWithCrashRecovery(module, am) - : OpToOpPassAdaptor::runPipeline(getPasses(), module, am); + LogicalResult result = crashReproducerFileName + ? runWithCrashRecovery(module, am) + : OpToOpPassAdaptor::runPipeline( + getPasses(), module, am, verifyPasses); // Notify the context that the run is done. module.getContext()->exitMultiThreadedExecution(); diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h --- a/mlir/lib/Pass/PassDetail.h +++ b/mlir/lib/Pass/PassDetail.h @@ -35,6 +35,7 @@ OpToOpPassAdaptor(const OpToOpPassAdaptor &rhs) = default; /// Run the held pipeline over all operations. + void runOnOperation(bool verifyPasses); void runOnOperation() override; /// Merge the current pass adaptor into given 'rhs'. @@ -57,19 +58,20 @@ private: /// Run this pass adaptor synchronously. - void runOnOperationImpl(); + void runOnOperationImpl(bool verifyPasses); /// Run this pass adaptor asynchronously. - void runOnOperationAsyncImpl(); + void runOnOperationAsyncImpl(bool verifyPasses); /// Run the given operation and analysis manager on a single pass. - static LogicalResult run(Pass *pass, Operation *op, AnalysisManager am); + static LogicalResult run(Pass *pass, Operation *op, AnalysisManager am, + bool verifyPasses); /// Run the given operation and analysis manager on a provided op pass /// manager. static LogicalResult runPipeline(iterator_range passes, - Operation *op, AnalysisManager am); + Operation *op, AnalysisManager am, bool verifyPasses); /// A set of adaptors to run. SmallVector mgrs; diff --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp --- a/mlir/lib/Support/MlirOptMain.cpp +++ b/mlir/lib/Support/MlirOptMain.cpp @@ -58,7 +58,8 @@ return failure(); // Apply any pass manager command line options. - PassManager pm(context, verifyPasses); + PassManager pm(context); + pm.enableVerifier(verifyPasses); applyPassManagerCLOptions(pm); // Build the provided pipeline. diff --git a/mlir/test/Pass/pass-timing.mlir b/mlir/test/Pass/pass-timing.mlir --- a/mlir/test/Pass/pass-timing.mlir +++ b/mlir/test/Pass/pass-timing.mlir @@ -8,7 +8,6 @@ // LIST: Total Execution Time: // LIST: Name // LIST-DAG: Canonicalizer -// LIST-DAG: Verifier // LIST-DAG: CSE // LIST-DAG: DominanceInfo // LIST: Total @@ -19,20 +18,15 @@ // PIPELINE-NEXT: 'func' Pipeline // PIPELINE-NEXT: CSE // PIPELINE-NEXT: (A) DominanceInfo -// PIPELINE-NEXT: Verifier // PIPELINE-NEXT: Canonicalizer -// PIPELINE-NEXT: Verifier // PIPELINE-NEXT: CSE // PIPELINE-NEXT: (A) DominanceInfo -// PIPELINE-NEXT: Verifier -// PIPELINE-NEXT: Verifier // PIPELINE-NEXT: Total // MT_LIST: Pass execution timing report // MT_LIST: Total Execution Time: // MT_LIST: Name // MT_LIST-DAG: Canonicalizer -// MT_LIST-DAG: Verifier // MT_LIST-DAG: CSE // MT_LIST-DAG: DominanceInfo // MT_LIST: Total @@ -43,13 +37,9 @@ // MT_PIPELINE-NEXT: 'func' Pipeline // MT_PIPELINE-NEXT: CSE // MT_PIPELINE-NEXT: (A) DominanceInfo -// MT_PIPELINE-NEXT: Verifier // MT_PIPELINE-NEXT: Canonicalizer -// MT_PIPELINE-NEXT: Verifier // MT_PIPELINE-NEXT: CSE // MT_PIPELINE-NEXT: (A) DominanceInfo -// MT_PIPELINE-NEXT: Verifier -// MT_PIPELINE-NEXT: Verifier // MT_PIPELINE-NEXT: Total // NESTED_MT_PIPELINE: Pass execution timing report diff --git a/mlir/test/Pass/pipeline-stats.mlir b/mlir/test/Pass/pipeline-stats.mlir --- a/mlir/test/Pass/pipeline-stats.mlir +++ b/mlir/test/Pass/pipeline-stats.mlir @@ -10,11 +10,8 @@ // PIPELINE: 'func' Pipeline // PIPELINE-NEXT: TestStatisticPass // PIPELINE-NEXT: (S) {{0|4}} num-ops - Number of operations counted -// PIPELINE-NEXT: Verifier // PIPELINE-NEXT: TestStatisticPass // PIPELINE-NEXT: (S) {{0|4}} num-ops - Number of operations counted -// PIPELINE-NEXT: Verifier -// PIPELINE-NEXT: Verifier func @foo() { return diff --git a/mlir/test/lib/Transforms/TestDynamicPipeline.cpp b/mlir/test/lib/Transforms/TestDynamicPipeline.cpp --- a/mlir/test/lib/Transforms/TestDynamicPipeline.cpp +++ b/mlir/test/lib/Transforms/TestDynamicPipeline.cpp @@ -25,7 +25,7 @@ : public PassWrapper> { public: void getDependentDialects(DialectRegistry ®istry) const override { - OpPassManager pm(ModuleOp::getOperationName(), false); + OpPassManager pm(ModuleOp::getOperationName()); parsePassPipeline(pipeline, pm, llvm::errs()); pm.getDependentDialects(registry); } @@ -54,7 +54,7 @@ } if (!pm) { pm = std::make_unique( - getOperation()->getName().getIdentifier(), false); + getOperation()->getName().getIdentifier()); parsePassPipeline(pipeline, *pm, llvm::errs()); }