diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -105,6 +105,14 @@ MLIR_CAPI_EXPORTED void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager, MlirPass pass); +/// Parse a sequence of textual MLIR pass pipeline elements and add them to the +/// provided OpPassManager. If parsing fails, an error message is reported by +/// sending chunks of it to `callback` together with `userData`; the callback +/// may be called several times with consecutive chunks of the message. +MLIR_CAPI_EXPORTED MlirLogicalResult mlirOpPassManagerAddPipeline( + MlirOpPassManager passManager, MlirStringRef pipelineElements, + MlirStringCallback callback, void *userData); + /// Print a textual MLIR pass pipeline by sending chunks of the string /// representation and forwarding `userData to `callback`. Note that the /// callback may be called several times with consecutive chunks of the string. diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -82,15 +82,19 @@ py::arg("enable"), "Enable / disable verify-each.") .def_static( "parse", - [](const std::string pipeline, DefaultingPyMlirContext context) { + [](const std::string &pipeline, DefaultingPyMlirContext context) { MlirPassManager passManager = mlirPassManagerCreate(context->get()); - MlirLogicalResult status = mlirParsePassPipeline( + PyPrintAccumulator errorMsg; + MlirLogicalResult status = mlirOpPassManagerAddPipeline( mlirPassManagerGetAsOpPassManager(passManager), - mlirStringRefCreate(pipeline.data(), pipeline.size())); + mlirStringRefCreate(pipeline.data(), pipeline.size()), + errorMsg.getCallback(), errorMsg.getUserData()); - if (mlirLogicalResultIsFailure(status)) - throw SetPyError(PyExc_ValueError, - llvm::Twine("invalid pass pipeline '") + - pipeline + "'."); + if (mlirLogicalResultIsFailure(status)) { + // The pass manager is not owned by a PyPassManager yet; release it + // before raising so that a failed parse does not leak it. + mlirPassManagerDestroy(passManager); + throw SetPyError(PyExc_ValueError, std::string(errorMsg.join())); + } return new PyPassManager(passManager); }, py::arg("pipeline"), py::arg("context") = py::none(), diff --git
a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -65,6 +65,15 @@ unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass))); } +MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager, + MlirStringRef pipelineElements, + MlirStringCallback callback, + void *userData) { + detail::CallbackOstream stream(callback, userData); + return wrap(parsePassPipeline(unwrap(pipelineElements), *unwrap(passManager), + stream)); +} + void mlirPrintPassPipeline(MlirOpPassManager passManager, MlirStringCallback callback, void *userData) { detail::CallbackOstream stream(callback, userData); diff --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c --- a/mlir/test/CAPI/pass.c +++ b/mlir/test/CAPI/pass.c @@ -176,8 +176,7 @@ MlirLogicalResult status = mlirParsePassPipeline( mlirPassManagerGetAsOpPassManager(pm), mlirStringRefCreateFromCString( - "builtin.module(func.func(print-op-stats{json=false})," - " func.func(print-op-stats{json=false}))")); + "builtin.module(func.func(print-op-stats{json=false}))")); // Expect a failure, we haven't registered the print-op-stats pass yet. if (mlirLogicalResultIsSuccess(status)) { fprintf( @@ -190,8 +189,7 @@ status = mlirParsePassPipeline( mlirPassManagerGetAsOpPassManager(pm), mlirStringRefCreateFromCString( - "builtin.module(func.func(print-op-stats{json=false})," - " func.func(print-op-stats{json=false}))")); + "builtin.module(func.func(print-op-stats{json=false}))")); - // Expect a failure, we haven't registered the print-op-stats pass yet. + // Now that the print-op-stats pass is registered, expect parsing to succeed.
if (mlirLogicalResultIsFailure(status)) { fprintf(stderr, @@ -199,14 +197,32 @@ exit(EXIT_FAILURE); } - // CHECK: Round-trip: builtin.module(builtin.module( - // CHECK-SAME: func.func(print-op-stats{json=false}), - // CHECK-SAME: func.func(print-op-stats{json=false}) - // CHECK-SAME: )) + // CHECK: Round-trip: builtin.module( + // CHECK-SAME: builtin.module(func.func(print-op-stats{json=false})) + // CHECK-SAME: ) fprintf(stderr, "Round-trip: "); mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr, NULL); fprintf(stderr, "\n"); + + // Append a textual pipeline element to the existing pass manager; a parse error, if any, is printed through printToStderr. + status = mlirOpPassManagerAddPipeline( + mlirPassManagerGetAsOpPassManager(pm), + mlirStringRefCreateFromCString("func.func(print-op-stats{json=false})"), + printToStderr, NULL); + if (mlirLogicalResultIsFailure(status)) { + fprintf(stderr, "Unexpected failure appending pipeline\n"); + exit(EXIT_FAILURE); + } + // CHECK: Appended: builtin.module( + // CHECK-SAME: builtin.module(func.func(print-op-stats{json=false})), + // CHECK-SAME: func.func(print-op-stats{json=false}) + // CHECK-SAME: ) + fprintf(stderr, "Appended: "); + mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr, + NULL); + fprintf(stderr, "\n"); + mlirPassManagerDestroy(pm); mlirContextDestroy(ctx); } diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -36,10 +36,8 @@ # An unregistered pass should not parse. try: pm = PassManager.parse("builtin.module(func.func(not-existing-pass{json=false}))") - # TODO: this error should be propagate to Python but the C API does not help right now. - # CHECK: error: 'not-existing-pass' does not refer to a registered pass or pass pipeline except ValueError as e: - # CHECK: ValueError exception: invalid pass pipeline 'builtin.module(func.func(not-existing-pass{json=false}))'.
+ # CHECK: ValueError exception: {{.+}} 'not-existing-pass' does not refer to a registered pass log("ValueError exception:", e) else: log("Exception not produced") @@ -57,7 +55,10 @@ try: pm = PassManager.parse("unknown-pass") except ValueError as e: - # CHECK: ValueError exception: invalid pass pipeline 'unknown-pass'. + # CHECK: ValueError exception: MLIR Textual PassPipeline Parser:1:1: error: + # CHECK-SAME: 'unknown-pass' does not refer to a registered pass or pass pipeline + # CHECK: unknown-pass + # CHECK: ^ log("ValueError exception:", e) else: log("Exception not produced") @@ -71,8 +72,7 @@ try: pm = PassManager.parse("func.func(normalize-memrefs)") except ValueError as e: - # CHECK: Can't add pass 'NormalizeMemRefs' restricted to 'builtin.module' on a PassManager intended to run on 'func.func', did you intend to nest? - # CHECK: ValueError exception: invalid pass pipeline 'func.func(normalize-memrefs)'. + # CHECK: ValueError exception: Can't add pass 'NormalizeMemRefs' restricted to 'builtin.module' on a PassManager intended to run on 'func.func', did you intend to nest? log("ValueError exception:", e) else: log("Exception not produced")