diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -105,6 +105,14 @@
 MLIR_CAPI_EXPORTED void
 mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager, MlirPass pass);
 
+/// Parse a sequence of textual MLIR pass pipeline elements and add them to the
+/// provided OpPassManager. If parsing fails, an error message is reported
+/// through `callback`, forwarding `userData` to it. Note that the callback may
+/// be called several times with consecutive chunks of the message.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirOpPassManagerAddPipeline(
+    MlirOpPassManager passManager, MlirStringRef pipelineElements,
+    MlirStringCallback callback, void *userData);
+
 /// Print a textual MLIR pass pipeline by sending chunks of the string
-/// representation and forwarding `userData to `callback`. Note that the
+/// representation and forwarding `userData` to `callback`. Note that the
 /// callback may be called several times with consecutive chunks of the string.
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -82,15 +82,19 @@
           py::arg("enable"), "Enable / disable verify-each.")
       .def_static(
           "parse",
-          [](const std::string pipeline, DefaultingPyMlirContext context) {
+          [](const std::string &pipeline, DefaultingPyMlirContext context) {
             MlirPassManager passManager = mlirPassManagerCreate(context->get());
-            MlirLogicalResult status = mlirParsePassPipeline(
+            PyPrintAccumulator errorMsg;
+            MlirLogicalResult status = mlirOpPassManagerAddPipeline(
                 mlirPassManagerGetAsOpPassManager(passManager),
-                mlirStringRefCreate(pipeline.data(), pipeline.size()));
-            if (mlirLogicalResultIsFailure(status))
-              throw SetPyError(PyExc_ValueError,
-                               llvm::Twine("invalid pass pipeline '") +
-                                   pipeline + "'.");
+                mlirStringRefCreate(pipeline.data(), pipeline.size()),
+                errorMsg.getCallback(), errorMsg.getUserData());
+            if (mlirLogicalResultIsFailure(status)) {
+              // The pass manager was never handed over to a PyPassManager, so
+              // release it here instead of leaking it on the error path.
+              mlirPassManagerDestroy(passManager);
+              throw SetPyError(PyExc_ValueError, std::string(errorMsg.join()));
+            }
             return new PyPassManager(passManager);
           },
           py::arg("pipeline"), py::arg("context") = py::none(),
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -65,6 +65,15 @@
   unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
 }
 
+MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager,
+                                               MlirStringRef pipelineElements,
+                                               MlirStringCallback callback,
+                                               void *userData) {
+  detail::CallbackOstream stream(callback, userData);
+  return wrap(parsePassPipeline(unwrap(pipelineElements), *unwrap(passManager),
+                                stream));
+}
+
 void mlirPrintPassPipeline(MlirOpPassManager passManager,
                            MlirStringCallback callback, void *userData) {
   detail::CallbackOstream stream(callback, userData);
diff --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c
--- a/mlir/test/CAPI/pass.c
+++ b/mlir/test/CAPI/pass.c
@@ -133,6 +133,11 @@
   fwrite(str.data, 1, str.length, stderr);
 }
 
+static void dontPrint(MlirStringRef str, void *userData) {
+  (void)str;
+  (void)userData;
+}
+
 void testPrintPassPipeline() {
   MlirContext ctx = mlirContextCreate();
   MlirPassManager pm = mlirPassManagerCreate(ctx);
@@ -176,8 +181,7 @@
   MlirLogicalResult status = mlirParsePassPipeline(
       mlirPassManagerGetAsOpPassManager(pm),
       mlirStringRefCreateFromCString(
-          "builtin.module(func.func(print-op-stats{json=false}),"
-          " func.func(print-op-stats{json=false}))"));
+          "builtin.module(func.func(print-op-stats{json=false}))"));
   // Expect a failure, we haven't registered the print-op-stats pass yet.
   if (mlirLogicalResultIsSuccess(status)) {
     fprintf(
@@ -190,8 +194,7 @@
   status = mlirParsePassPipeline(
       mlirPassManagerGetAsOpPassManager(pm),
       mlirStringRefCreateFromCString(
-          "builtin.module(func.func(print-op-stats{json=false}),"
-          " func.func(print-op-stats{json=false}))"));
-  // Expect a failure, we haven't registered the print-op-stats pass yet.
+          "builtin.module(func.func(print-op-stats{json=false}))"));
+  // Expect success this time; failing to parse here is a test error.
   if (mlirLogicalResultIsFailure(status)) {
     fprintf(stderr,
@@ -199,14 +202,61 @@
     exit(EXIT_FAILURE);
   }
 
-  // CHECK: Round-trip: builtin.module(builtin.module(
-  // CHECK-SAME: func.func(print-op-stats{json=false}),
-  // CHECK-SAME: func.func(print-op-stats{json=false})
-  // CHECK-SAME: ))
+  // CHECK: Round-trip: builtin.module(
+  // CHECK-SAME: builtin.module(func.func(print-op-stats{json=false}))
+  // CHECK-SAME: )
   fprintf(stderr, "Round-trip: ");
   mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
                         NULL);
   fprintf(stderr, "\n");
+
+  // Try appending a pass:
+  status = mlirOpPassManagerAddPipeline(
+      mlirPassManagerGetAsOpPassManager(pm),
+      mlirStringRefCreateFromCString("func.func(print-op-stats{json=false})"),
+      printToStderr, NULL);
+  if (mlirLogicalResultIsFailure(status)) {
+    fprintf(stderr, "Unexpected failure appending pipeline\n");
+    exit(EXIT_FAILURE);
+  }
+  // CHECK: Appended: builtin.module(
+  // CHECK-SAME: builtin.module(func.func(print-op-stats{json=false})),
+  // CHECK-SAME: func.func(print-op-stats{json=false})
+  // CHECK-SAME: )
+  fprintf(stderr, "Appended: ");
+  mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
+                        NULL);
+  fprintf(stderr, "\n");
+
+  mlirPassManagerDestroy(pm);
+  mlirContextDestroy(ctx);
+}
+
+void testParseErrorCapture() {
+  // CHECK-LABEL: testParseErrorCapture:
+  fprintf(stderr, "\nTEST: testParseErrorCapture:\n");
+
+  MlirContext ctx = mlirContextCreate();
+  MlirPassManager pm = mlirPassManagerCreate(ctx);
+  MlirOpPassManager opm = mlirPassManagerGetAsOpPassManager(pm);
+  MlirStringRef invalidPipeline = mlirStringRefCreateFromCString("invalid");
+
+  // CHECK: mlirOpPassManagerAddPipeline:
+  // CHECK: 'invalid' does not refer to a registered pass or pass pipeline
+  fprintf(stderr, "mlirOpPassManagerAddPipeline:\n");
+  if (mlirLogicalResultIsSuccess(mlirOpPassManagerAddPipeline(
+          opm, invalidPipeline, printToStderr, NULL)))
+    exit(EXIT_FAILURE);
+  fprintf(stderr, "\n");
+
+  // Make sure all output is going through the callback.
+  // CHECK: dontPrint: <>
+  fprintf(stderr, "dontPrint: <");
+  if (mlirLogicalResultIsSuccess(
+          mlirOpPassManagerAddPipeline(opm, invalidPipeline, dontPrint, NULL)))
+    exit(EXIT_FAILURE);
+  fprintf(stderr, ">\n");
+
   mlirPassManagerDestroy(pm);
   mlirContextDestroy(ctx);
 }
@@ -534,6 +584,7 @@
   testRunPassOnNestedModule();
   testPrintPassPipeline();
   testParsePassPipeline();
+  testParseErrorCapture();
   testExternalPass();
   return 0;
 }
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -36,10 +36,8 @@
     # An unregistered pass should not parse.
     try:
       pm = PassManager.parse("builtin.module(func.func(not-existing-pass{json=false}))")
-      # TODO: this error should be propagate to Python but the C API does not help right now.
-      # CHECK: error: 'not-existing-pass' does not refer to a registered pass or pass pipeline
     except ValueError as e:
-      # CHECK: ValueError exception: invalid pass pipeline 'builtin.module(func.func(not-existing-pass{json=false}))'.
+      # CHECK: ValueError exception: {{.+}} 'not-existing-pass' does not refer to a registered pass
       log("ValueError exception:", e)
     else:
       log("Exception not produced")
@@ -57,7 +55,10 @@
     try:
       pm = PassManager.parse("unknown-pass")
     except ValueError as e:
-      # CHECK: ValueError exception: invalid pass pipeline 'unknown-pass'.
+      # CHECK: ValueError exception: MLIR Textual PassPipeline Parser:1:1: error:
+      # CHECK-SAME: 'unknown-pass' does not refer to a registered pass or pass pipeline
+      # CHECK: unknown-pass
+      # CHECK: ^
       log("ValueError exception:", e)
     else:
       log("Exception not produced")
@@ -71,8 +72,7 @@
     try:
       pm = PassManager.parse("func.func(normalize-memrefs)")
     except ValueError as e:
-      # CHECK: Can't add pass 'NormalizeMemRefs' restricted to 'builtin.module' on a PassManager intended to run on 'func.func', did you intend to nest?
-      # CHECK: ValueError exception: invalid pass pipeline 'func.func(normalize-memrefs)'.
+      # CHECK: ValueError exception: Can't add pass 'NormalizeMemRefs' restricted to 'builtin.module' on a PassManager intended to run on 'func.func', did you intend to nest?
       log("ValueError exception:", e)
     else:
       log("Exception not produced")