diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp --- a/mlir/lib/Pass/PassRegistry.cpp +++ b/mlir/lib/Pass/PassRegistry.cpp @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#include #include +#include #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" @@ -588,6 +588,9 @@ pipeline.back().options = text.substr(0, close); text = text.substr(close + 1); + // Consume space characters that a user might add for readability. + text = text.ltrim(); + // Skip checking for '(' because nested pipelines cannot have options. } else if (sep == '(') { text = text.substr(1); @@ -607,6 +610,8 @@ "parentheses while parsing pipeline"); pipelineStack.pop_back(); + // Consume space characters that a user might add for readability. + text = text.ltrim(); } // Check if we've finished parsing. @@ -703,6 +708,7 @@ FailureOr<OpPassManager> mlir::parsePassPipeline(StringRef pipeline, raw_ostream &errorStream) { + pipeline = pipeline.trim(); // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
size_t pipelineStart = pipeline.find_first_of('('); if (pipelineStart == 0 || pipelineStart == StringRef::npos || @@ -712,7 +718,7 @@ return failure(); } - StringRef opName = pipeline.take_front(pipelineStart); + StringRef opName = pipeline.take_front(pipelineStart).rtrim(); OpPassManager pm(opName); if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm, errorStream))) diff --git a/mlir/test/Pass/pipeline-parsing.mlir b/mlir/test/Pass/pipeline-parsing.mlir --- a/mlir/test/Pass/pipeline-parsing.mlir +++ b/mlir/test/Pass/pipeline-parsing.mlir @@ -1,6 +1,7 @@ // RUN: mlir-opt %s -mlir-disable-threading -pass-pipeline='builtin.module(builtin.module(test-module-pass,func.func(test-function-pass)),func.func(test-function-pass),func.func(cse,canonicalize))' -verify-each=false -mlir-timing -mlir-timing-display=tree 2>&1 | FileCheck %s // RUN: mlir-opt %s -mlir-disable-threading -test-textual-pm-nested-pipeline -verify-each=false -mlir-timing -mlir-timing-display=tree 2>&1 | FileCheck %s --check-prefix=TEXTUAL_CHECK // RUN: mlir-opt %s -mlir-disable-threading -pass-pipeline='builtin.module(builtin.module(test-module-pass),any(test-interface-pass),any(test-interface-pass),func.func(test-function-pass),any(canonicalize),func.func(cse))' -verify-each=false -mlir-timing -mlir-timing-display=tree 2>&1 | FileCheck %s --check-prefix=GENERIC_MERGE_CHECK +// RUN: mlir-opt %s -mlir-disable-threading -pass-pipeline=' builtin.module ( builtin.module( func.func( test-function-pass, print-op-stats{ json=false } ) ) ) ' -verify-each=false -mlir-timing -mlir-timing-display=tree 2>&1 | FileCheck %s --check-prefix=PIPELINE_STR_WITH_SPACES_CHECK // RUN: not mlir-opt %s -pass-pipeline='any(builtin.module(test-module-pass)' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_1 %s // RUN: not mlir-opt %s -pass-pipeline='builtin.module(test-module-pass))' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_2 %s // RUN: not mlir-opt %s -pass-pipeline='any(builtin.module()()' 2>&1 | FileCheck 
--check-prefix=CHECK_ERROR_3 %s @@ -51,6 +52,11 @@ // TEXTUAL_CHECK-NEXT: 'func.func' Pipeline // TEXTUAL_CHECK-NEXT: TestFunctionPass +// PIPELINE_STR_WITH_SPACES_CHECK: 'builtin.module' Pipeline +// PIPELINE_STR_WITH_SPACES_CHECK-NEXT: 'func.func' Pipeline +// PIPELINE_STR_WITH_SPACES_CHECK-NEXT: TestFunctionPass +// PIPELINE_STR_WITH_SPACES_CHECK-NEXT: PrintOpStats + // Check that generic pass pipelines are only merged when they aren't // going to overlap with op-specific pipelines. // GENERIC_MERGE_CHECK: Pipeline Collection : ['builtin.module', 'any'] diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -59,6 +59,18 @@ log("Roundtrip: ", pm) run(testParseSuccess) +# Verify successful round-trip of a pipeline string containing extra whitespace. +# CHECK-LABEL: TEST: testParseSpacedPipeline +def testParseSpacedPipeline(): + with Context(): + # A registered pass should parse successfully even if it has extra spaces for readability + pm = PassManager.parse("""builtin.module( + func.func( print-op-stats{ json=false } ) + )""") + # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false})) + log("Roundtrip: ", pm) +run(testParseSpacedPipeline) + # Verify failure on unregistered pass. # CHECK-LABEL: TEST: testParseFail def testParseFail():