diff --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h
--- a/mlir/include/mlir/Pass/PassRegistry.h
+++ b/mlir/include/mlir/Pass/PassRegistry.h
@@ -197,13 +197,21 @@
   }
 };
 
-/// This function parses the textual representation of a pass pipeline, and adds
-/// the result to 'pm' on success. This function returns failure if the given
-/// pipeline was invalid. 'errorStream' is the output stream used to emit errors
-/// found during parsing.
+/// Parse the textual representation of a pass pipeline, adding the result to
+/// 'pm' on success. Returns failure if the given pipeline was invalid.
+/// 'errorStream' is the output stream used to emit errors found during parsing.
 LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm,
                                 raw_ostream &errorStream = llvm::errs());
 
+/// Parse the given textual representation of a pass pipeline, and return the
+/// parsed pipeline on success. The given pipeline string must be wrapped with
+/// the name of the operation that anchors the created pass manager, e.g.
+/// `builtin.module(cse)` rather than `cse`. Returns failure if the given
+/// pipeline was invalid. 'errorStream' is the output stream used to emit
+/// errors found during parsing.
+FailureOr<OpPassManager>
+parsePassPipeline(StringRef pipeline, raw_ostream &errorStream = llvm::errs());
+
 //===----------------------------------------------------------------------===//
 // PassPipelineCLParser
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp
--- a/mlir/lib/Pass/PassRegistry.cpp
+++ b/mlir/lib/Pass/PassRegistry.cpp
@@ -482,10 +482,6 @@
   return success();
 }
 
-/// This function parses the textual representation of a pass pipeline, and adds
-/// the result to 'pm' on success. This function returns failure if the given
-/// pipeline was invalid. 'errorStream' is an optional parameter that, if
-/// non-null, will be used to emit errors found during parsing.
 LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm,
                                       raw_ostream &errorStream) {
   TextualPipeline pipelineParser;
@@ -500,6 +496,28 @@
   return success();
 }
 
+FailureOr<OpPassManager> mlir::parsePassPipeline(StringRef pipeline,
+                                                 raw_ostream &errorStream) {
+  // Pipelines are expected to be of the form `<op-name>(<pipeline>)`. The
+  // anchor operation name must be non-empty.
+  size_t pipelineStart = pipeline.find_first_of('(');
+  if (pipelineStart == 0 || pipelineStart == StringRef::npos ||
+      !pipeline.consume_back(")")) {
+    errorStream << "expected pass pipeline to be wrapped with the anchor "
+                   "operation type, e.g. `builtin.module(...)`";
+    return failure();
+  }
+
+  StringRef opName = pipeline.take_front(pipelineStart);
+  OpPassManager pm(opName);
+  // Forward 'errorStream' so errors in the nested pipeline are reported to the
+  // same stream as errors in the anchor, not to the default llvm::errs().
+  if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm,
+                               errorStream)))
+    return failure();
+  return pm;
+}
+
 //===----------------------------------------------------------------------===//
 // PassNameParser
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -752,16 +752,10 @@
     // Skip empty pipelines.
     if (pipeline.empty())
       continue;
-
-    // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
-    size_t pipelineStart = pipeline.find_first_of('(');
-    if (pipelineStart == StringRef::npos || !pipeline.consume_back(")"))
-      return failure();
-    StringRef opName = pipeline.take_front(pipelineStart);
-    OpPassManager pm(opName);
-    if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm)))
+    FailureOr<OpPassManager> pm = parsePassPipeline(pipeline);
+    if (failed(pm))
       return failure();
-    pipelines.try_emplace(opName, std::move(pm));
+    pipelines.try_emplace(pm->getOpName(), std::move(*pm));
   }
 
   opPipelines.assign({std::move(pipelines)});