diff --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md --- a/mlir/docs/PassManagement.md +++ b/mlir/docs/PassManagement.md @@ -140,13 +140,15 @@ runtime. In these situations, a pass may override the following hook to initialize this heavy state: -* `void initialize(MLIRContext *context)` +* `LogicalResult initialize(MLIRContext *context)` This hook is executed once per run of a full pass pipeline, meaning that it does not have access to the state available during a `runOnOperation` call. More concretely, all necessary accesses to an `MLIRContext` should be driven via the provided `context` parameter, and methods that utilize "per-run" state such as `getContext`/`getOperation`/`getAnalysis`/etc. must not be used. +In case of error during initialization, the pass is expected to emit an error +and return a `failure()` which will abort the pass pipeline execution. ## Analysis Management diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -170,7 +170,9 @@ /// should not rely on any state accessible during the execution of a pass. /// For example, `getContext`/`getOperation`/`getAnalysis`/etc. should not be /// invoked within this hook. - virtual void initialize(MLIRContext *context) {} + /// Returns a LogicalResult to indicate failure, in which case the pass + /// pipeline won't execute. + virtual LogicalResult initialize(MLIRContext *context) { return success(); } /// Schedule an arbitrary pass pipeline on the provided operation. /// This can be invoke any time in a pass to dynamic schedule more passes. diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -130,7 +130,7 @@ /// Initialize all of the passes within this pass manager with the given /// initialization generation. 
The initialization generation is used to detect /// if a pass manager has already been initialized. - void initialize(MLIRContext *context, unsigned newInitGeneration); + LogicalResult initialize(MLIRContext *context, unsigned newInitGeneration); /// A pointer to an internal implementation instance. std::unique_ptr<detail::OpPassManagerImpl> impl; diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -331,23 +331,26 @@ OpPassManager::Nesting OpPassManager::getNesting() { return impl->nesting; } -void OpPassManager::initialize(MLIRContext *context, - unsigned newInitGeneration) { +LogicalResult OpPassManager::initialize(MLIRContext *context, + unsigned newInitGeneration) { if (impl->initializationGeneration == newInitGeneration) - return; + return success(); impl->initializationGeneration = newInitGeneration; for (Pass &pass : getPasses()) { // If this pass isn't an adaptor, directly initialize it. auto *adaptor = dyn_cast<OpToOpPassAdaptor>(&pass); if (!adaptor) { - pass.initialize(context); + if (failed(pass.initialize(context))) + return failure(); continue; } // Otherwise, initialize each of the adaptors pass managers. for (OpPassManager &adaptorPM : adaptor->getPassManagers()) - adaptorPM.initialize(context, newInitGeneration); + if (failed(adaptorPM.initialize(context, newInitGeneration))) + return failure(); } + return success(); } //===----------------------------------------------------------------------===// @@ -379,7 +382,8 @@ assert(pipeline.getOpName() == root->getName().getStringRef()); // Initialize the user provided pipeline and execute the pipeline. - pipeline.initialize(root->getContext(), parentInitGeneration); + if (failed(pipeline.initialize(root->getContext(), parentInitGeneration))) + return failure(); AnalysisManager nestedAm = root == op ? 
am : am.nest(root); return OpToOpPassAdaptor::runPipeline(pipeline.getPasses(), root, nestedAm, verifyPasses, parentInitGeneration, @@ -872,7 +876,8 @@ // Initialize all of the passes within the pass manager with a new generation. llvm::hash_code newInitKey = context->getRegistryHash(); if (newInitKey != initializationKey) { - initialize(context, impl->initializationGeneration + 1); + if (failed(initialize(context, impl->initializationGeneration + 1))) + return failure(); initializationKey = newInitKey; } diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -23,11 +23,12 @@ struct Canonicalizer : public CanonicalizerBase<Canonicalizer> { /// Initialize the canonicalizer by building the set of patterns used during /// execution. - void initialize(MLIRContext *context) override { + LogicalResult initialize(MLIRContext *context) override { OwningRewritePatternList owningPatterns; for (auto *op : context->getRegisteredOperations()) op->getCanonicalizationPatterns(owningPatterns, context); patterns = std::move(owningPatterns); + return success(); } void runOnOperation() override { (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), patterns);