diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -118,6 +118,14 @@ /// documentation for the same method on the Pass class. void getDependentDialects(DialectRegistry &dialects) const; + /// Enable or disable the implicit nesting on this particular PassManager. + /// This will also apply to any newly nested PassManager built from this + /// instance. + void setNesting(Nesting nesting); + + /// Return the current nesting mode. + Nesting getNesting(); + private: /// A pointer to an internal implementation instance. std::unique_ptr impl; diff --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h --- a/mlir/include/mlir/Pass/PassRegistry.h +++ b/mlir/include/mlir/Pass/PassRegistry.h @@ -28,8 +28,12 @@ /// A registry function that adds passes to the given pass manager. This should /// also parse options and return success() if parsing succeeded. -using PassRegistryFunction = - std::function; +/// `errorHandler` is a functor used to emit errors during parsing. It is +/// passed the error message to report and should always return failure, so +/// builders can simply `return errorHandler(...)`. +using PassRegistryFunction = std::function errorHandler)>; using PassAllocatorFunction = std::function()>; //===----------------------------------------------------------------------===// @@ -43,10 +47,12 @@ /// Adds this pass registry entry to the given pass manager. `options` is /// an opaque string that will be parsed by the builder. The success of /// parsing will be returned. 
- LogicalResult addToPipeline(OpPassManager &pm, StringRef options) const { + LogicalResult + addToPipeline(OpPassManager &pm, StringRef options, + function_ref errorHandler) const { assert(builder && "cannot call addToPipeline on PassRegistryEntry without builder"); - return builder(pm, options); + return builder(pm, options, errorHandler); } /// Returns the command line option that may be passed to 'mlir-opt' that will @@ -163,7 +169,8 @@ std::function builder) { registerPassPipeline( arg, description, - [builder](OpPassManager &pm, StringRef optionsStr) { + [builder](OpPassManager &pm, StringRef optionsStr, + function_ref errorHandler) { Options options; if (failed(options.parseFromString(optionsStr))) return failure(); @@ -183,7 +190,8 @@ std::function builder) { registerPassPipeline( arg, description, - [builder](OpPassManager &pm, StringRef optionsStr) { + [builder](OpPassManager &pm, StringRef optionsStr, + function_ref errorHandler) { if (!optionsStr.empty()) return failure(); builder(pm); @@ -230,7 +238,9 @@ /// Adds the passes defined by this parser entry to the given pass manager. /// Returns failure() if the pass could not be properly constructed due /// to options parsing. 
- LogicalResult addToPipeline(OpPassManager &pm) const; + LogicalResult + addToPipeline(OpPassManager &pm, + function_ref errorHandler) const; private: std::unique_ptr impl; diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -348,6 +348,10 @@ registerDialectsForPipeline(*this, dialects); } +OpPassManager::Nesting OpPassManager::getNesting() { return impl->nesting; } + +void OpPassManager::setNesting(Nesting nesting) { impl->nesting = nesting; } + //===----------------------------------------------------------------------===// // OpToOpPassAdaptor //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp --- a/mlir/lib/Pass/PassRegistry.cpp +++ b/mlir/lib/Pass/PassRegistry.cpp @@ -27,9 +27,16 @@ /// Utility to create a default registry function from a pass instance. static PassRegistryFunction buildDefaultRegistryFn(const PassAllocatorFunction &allocator) { - return [=](OpPassManager &pm, StringRef options) { + return [=](OpPassManager &pm, StringRef options, + function_ref errorHandler) { std::unique_ptr pass = allocator(); LogicalResult result = pass->initializeOptions(options); + if ((pm.getNesting() == OpPassManager::Nesting::Explicit) && + pass->getOpName() && *pass->getOpName() != pm.getOpName()) + return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() + + "' restricted to '" + *pass->getOpName() + + "' on a PassManager intended to run on '" + + pm.getOpName() + "', did you intend to nest?"); pm.addPass(std::move(pass)); return result; }; @@ -229,7 +236,9 @@ LogicalResult initialize(StringRef text, raw_ostream &errorStream); /// Add the internal pipeline elements to the provided pass manager. 
- LogicalResult addToPipeline(OpPassManager &pm) const; + LogicalResult + addToPipeline(OpPassManager &pm, + function_ref errorHandler) const; private: /// A functor used to emit errors found during pipeline handling. The first @@ -269,8 +278,9 @@ ErrorHandlerT errorHandler); /// Add the given pipeline elements to the provided pass manager. - LogicalResult addToPipeline(ArrayRef elements, - OpPassManager &pm) const; + LogicalResult + addToPipeline(ArrayRef elements, OpPassManager &pm, + function_ref errorHandler) const; std::vector pipeline; }; @@ -299,8 +309,10 @@ } /// Add the internal pipeline elements to the provided pass manager. -LogicalResult TextualPipeline::addToPipeline(OpPassManager &pm) const { - return addToPipeline(pipeline, pm); +LogicalResult TextualPipeline::addToPipeline( + OpPassManager &pm, + function_ref errorHandler) const { + return addToPipeline(pipeline, pm, errorHandler); } /// Parse the given pipeline text into the internal pipeline vector. This @@ -422,13 +434,16 @@ } /// Add the given pipeline elements to the provided pass manager. 
-LogicalResult TextualPipeline::addToPipeline(ArrayRef elements, - OpPassManager &pm) const { +LogicalResult TextualPipeline::addToPipeline( + ArrayRef elements, OpPassManager &pm, + function_ref errorHandler) const { for (auto &elt : elements) { if (elt.registryEntry) { - if (failed(elt.registryEntry->addToPipeline(pm, elt.options))) + if (failed( + elt.registryEntry->addToPipeline(pm, elt.options, errorHandler))) return failure(); - } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name)))) { + } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name), + errorHandler))) { return failure(); } } @@ -444,7 +459,11 @@ TextualPipeline pipelineParser; if (failed(pipelineParser.initialize(pipeline, errorStream))) return failure(); - if (failed(pipelineParser.addToPipeline(pm))) + auto errorHandler = [&](const Twine &msg) { + errorStream << msg << "\n"; + return failure(); + }; + if (failed(pipelineParser.addToPipeline(pm, errorHandler))) return failure(); return success(); } @@ -634,13 +653,21 @@ } /// Adds the passes defined by this parser entry to the given pass manager. 
-LogicalResult PassPipelineCLParser::addToPipeline(OpPassManager &pm) const { +LogicalResult PassPipelineCLParser::addToPipeline( + OpPassManager &pm, + function_ref errorHandler) const { for (auto &passIt : impl->passList) { if (passIt.registryEntry) { - if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options))) + if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options, + errorHandler))) + return failure(); + } else { + OpPassManager::Nesting nesting = pm.getNesting(); + pm.setNesting(OpPassManager::Nesting::Explicit); + LogicalResult status = passIt.pipeline.addToPipeline(pm, errorHandler); + pm.setNesting(nesting); + if (failed(status)) return failure(); - } else if (failed(passIt.pipeline.addToPipeline(pm))) { - return failure(); } } return success(); diff --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp --- a/mlir/lib/Support/MlirOptMain.cpp +++ b/mlir/lib/Support/MlirOptMain.cpp @@ -62,8 +62,13 @@ pm.enableVerifier(verifyPasses); applyPassManagerCLOptions(pm); + auto errorHandler = [&](const Twine &msg) { + emitError(UnknownLoc::get(context)) << msg; + return failure(); + }; + // Build the provided pipeline. - if (failed(passPipeline.addToPipeline(pm))) + if (failed(passPipeline.addToPipeline(pm, errorHandler))) return failure(); // Run the pipeline. diff --git a/mlir/test/Bindings/Python/pass_manager.py b/mlir/test/Bindings/Python/pass_manager.py --- a/mlir/test/Bindings/Python/pass_manager.py +++ b/mlir/test/Bindings/Python/pass_manager.py @@ -66,6 +66,21 @@ run(testParseFail) +# Verify failure on incorrect level of nesting. +# CHECK-LABEL: TEST: testInvalidNesting +def testInvalidNesting(): + with Context(): + try: + pm = PassManager.parse("func(print-op-graph)") + except ValueError as e: + # CHECK: Can't add pass 'PrintOp' restricted to 'module' on a PassManager intended to run on 'func', did you intend to nest? + # CHECK: ValueError exception: invalid pass pipeline 'func(print-op-graph)'. 
+ log("ValueError exception:", e) + else: + log("Exception not produced") +run(testInvalidNesting) + + # Verify that a pass manager can execute on IR # CHECK-LABEL: TEST: testRun def testRunPipeline(): diff --git a/mlir/test/Pass/pipeline-options-parsing.mlir b/mlir/test/Pass/pipeline-options-parsing.mlir --- a/mlir/test/Pass/pipeline-options-parsing.mlir +++ b/mlir/test/Pass/pipeline-options-parsing.mlir @@ -1,11 +1,11 @@ // RUN: not mlir-opt %s -pass-pipeline='module(test-module-pass{)' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_1 %s // RUN: not mlir-opt %s -pass-pipeline='module(test-module-pass{test-option=3})' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_2 %s -// RUN: not mlir-opt %s -pass-pipeline='module(test-options-pass{list=3}, test-module-pass{invalid-option=3})' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_3 %s +// RUN: not mlir-opt %s -pass-pipeline='module(func(test-options-pass{list=3}), test-module-pass{invalid-option=3})' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_3 %s // RUN: not mlir-opt %s -pass-pipeline='test-options-pass{list=3 list=notaninteger}' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_4 %s // RUN: not mlir-opt %s -pass-pipeline='func(test-options-pass{list=1,2,3,4 list=5 string=value1 string=value2})' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_5 %s // RUN: mlir-opt %s -verify-each=false -pass-pipeline='func(test-options-pass{string-list=a list=1,2,3,4 string-list=b,c list=5 string-list=d string=some_value})' -test-dump-pipeline 2>&1 | FileCheck --check-prefix=CHECK_1 %s // RUN: mlir-opt %s -verify-each=false -test-options-pass-pipeline='list=1 string-list=a,b' -test-dump-pipeline 2>&1 | FileCheck --check-prefix=CHECK_2 %s -// RUN: mlir-opt %s -verify-each=false -pass-pipeline='module(test-options-pass{list=3}, test-options-pass{list=1,2,3,4})' -test-dump-pipeline 2>&1 | FileCheck --check-prefix=CHECK_3 %s +// RUN: mlir-opt %s -verify-each=false -pass-pipeline='module(func(test-options-pass{list=3}), 
func(test-options-pass{list=1,2,3,4}))' -test-dump-pipeline 2>&1 | FileCheck --check-prefix=CHECK_3 %s // CHECK_ERROR_1: missing closing '}' while processing pass options // CHECK_ERROR_2: no such option test-option diff --git a/mlir/test/Pass/pipeline-parsing.mlir b/mlir/test/Pass/pipeline-parsing.mlir --- a/mlir/test/Pass/pipeline-parsing.mlir +++ b/mlir/test/Pass/pipeline-parsing.mlir @@ -4,12 +4,13 @@ // RUN: not mlir-opt %s -pass-pipeline='module(test-module-pass))' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_2 %s // RUN: not mlir-opt %s -pass-pipeline='module()(' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_3 %s // RUN: not mlir-opt %s -pass-pipeline=',' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_4 %s +// RUN: not mlir-opt %s -pass-pipeline='func(test-module-pass)' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_5 %s // CHECK_ERROR_1: encountered unbalanced parentheses while parsing pipeline // CHECK_ERROR_2: encountered extra closing ')' creating unbalanced parentheses while parsing pipeline // CHECK_ERROR_3: expected ',' after parsing pipeline // CHECK_ERROR_4: does not refer to a registered pass or pass pipeline - +// CHECK_ERROR_5: Can't add pass '{{.*}}TestModulePass' restricted to 'module' on a PassManager intended to run on 'func', did you intend to nest? func @foo() { return }