diff --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h
--- a/mlir/include/mlir/Pass/PassRegistry.h
+++ b/mlir/include/mlir/Pass/PassRegistry.h
@@ -197,13 +197,21 @@
   }
 };
 
-/// This function parses the textual representation of a pass pipeline, and adds
-/// the result to 'pm' on success. This function returns failure if the given
-/// pipeline was invalid. 'errorStream' is the output stream used to emit errors
-/// found during parsing.
+/// Parse the textual representation of a pass pipeline, adding the result to
+/// 'pm' on success. Returns failure if the given pipeline was invalid.
+/// 'errorStream' is the output stream used to emit errors found during parsing.
 LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm,
                                 raw_ostream &errorStream = llvm::errs());
 
+/// Parse the given textual representation of a pass pipeline, and return the
+/// parsed pipeline on success. The given pipeline string must be wrapped with
+/// the name of the operation type the pipeline is anchored on, e.g.
+/// `builtin.module(cse)` rather than `cse`. Returns failure if the given
+/// pipeline was invalid. 'errorStream' is the output stream used to emit errors
+/// found during parsing, including errors within the nested pipeline.
+FailureOr<OpPassManager>
+parsePassPipeline(StringRef pipeline, raw_ostream &errorStream = llvm::errs());
+
 //===----------------------------------------------------------------------===//
 // PassPipelineCLParser
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp
--- a/mlir/lib/Pass/PassRegistry.cpp
+++ b/mlir/lib/Pass/PassRegistry.cpp
@@ -482,10 +482,6 @@
   return success();
 }
 
-/// This function parses the textual representation of a pass pipeline, and adds
-/// the result to 'pm' on success. This function returns failure if the given
-/// pipeline was invalid. 'errorStream' is an optional parameter that, if
-/// non-null, will be used to emit errors found during parsing.
 LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm,
                                       raw_ostream &errorStream) {
   TextualPipeline pipelineParser;
@@ -500,6 +496,27 @@
   return success();
 }
 
+FailureOr<OpPassManager> mlir::parsePassPipeline(StringRef pipeline,
+                                                 raw_ostream &errorStream) {
+  // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
+  size_t pipelineStart = pipeline.find_first_of('(');
+  if (pipelineStart == 0 || pipelineStart == StringRef::npos ||
+      !pipeline.consume_back(")")) {
+    errorStream << "expected pass pipeline to be wrapped with the anchor "
+                   "operation type, e.g. `builtin.module(...)`";
+    return failure();
+  }
+
+  StringRef opName = pipeline.take_front(pipelineStart);
+  OpPassManager pm(opName);
+  // Forward 'errorStream' so that errors in the nested pipeline are reported
+  // to the caller-provided stream rather than the default (llvm::errs()).
+  if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm,
+                               errorStream)))
+    return failure();
+  return pm;
+}
+
 //===----------------------------------------------------------------------===//
 // PassNameParser
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -754,16 +754,10 @@
     // Skip empty pipelines.
     if (pipeline.empty())
       continue;
-
-    // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
-    size_t pipelineStart = pipeline.find_first_of('(');
-    if (pipelineStart == StringRef::npos || !pipeline.consume_back(")"))
-      return failure();
-    StringRef opName = pipeline.take_front(pipelineStart);
-    OpPassManager pm(opName);
-    if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm)))
+    FailureOr<OpPassManager> pm = parsePassPipeline(pipeline);
+    if (failed(pm))
       return failure();
-    pipelines.try_emplace(opName, std::move(pm));
+    pipelines.try_emplace(pm->getOpName(), std::move(*pm));
   }
   opPipelines.assign({std::move(pipelines)});
 
diff --git a/mlir/unittests/Pass/CMakeLists.txt b/mlir/unittests/Pass/CMakeLists.txt
--- a/mlir/unittests/Pass/CMakeLists.txt
+++ b/mlir/unittests/Pass/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_unittest(MLIRPassTests
   AnalysisManagerTest.cpp
   PassManagerTest.cpp
+  PassPipelineParserTest.cpp
 )
 target_link_libraries(MLIRPassTests
   PRIVATE
diff --git a/mlir/unittests/Pass/PassPipelineParserTest.cpp b/mlir/unittests/Pass/PassPipelineParserTest.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Pass/PassPipelineParserTest.cpp
@@ -0,0 +1,53 @@
+//===- PassPipelineParserTest.cpp - Pass Parser unit tests ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "llvm/Support/raw_ostream.h"
+#include "gtest/gtest.h"
+
+#include <string>
+
+using namespace mlir;
+
+namespace {
+/// Parse `pipeline` and check that parsing fails with an error message
+/// containing `expectedErrorMsg` written to the provided error stream.
+void checkParseFailure(StringRef pipeline, StringRef expectedErrorMsg) {
+  std::string errorMsg;
+  {
+    // Scope the stream so that it is flushed into 'errorMsg' before checking.
+    llvm::raw_string_ostream os(errorMsg);
+    FailureOr<OpPassManager> result = parsePassPipeline(pipeline, os);
+    EXPECT_TRUE(failed(result)) << "pipeline: " << pipeline.str();
+  }
+  EXPECT_TRUE(StringRef(errorMsg).contains(expectedErrorMsg))
+      << "error output: " << errorMsg;
+}
+
+TEST(PassPipelineParserTest, InvalidOpAnchor) {
+  // Handle parse errors when the anchor is incorrectly structured.
+  StringRef anchorErrorMsg =
+      "expected pass pipeline to be wrapped with the anchor operation type";
+  checkParseFailure("module", anchorErrorMsg);
+  checkParseFailure("()", anchorErrorMsg);
+  checkParseFailure("module(", anchorErrorMsg);
+  checkParseFailure("module)", anchorErrorMsg);
+}
+
+TEST(PassPipelineParserTest, NestedPipelineErrorUsesErrorStream) {
+  // Errors found while parsing the pipeline nested within the anchor must be
+  // reported to the provided error stream, not to llvm::errs().
+  checkParseFailure("builtin.module(unregistered-test-pass)",
+                    "does not refer to a registered pass");
+}
+
+} // namespace