diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -123,10 +123,12 @@
                                               MlirStringCallback callback,
                                               void *userData);
 
-/// Parse a textual MLIR pass pipeline and add it to the provided OpPassManager.
-
+/// Parse a textual MLIR pass pipeline and assign it to the provided
+/// OpPassManager. If parsing fails an error message is reported using the
+/// provided callback.
 MLIR_CAPI_EXPORTED MlirLogicalResult
-mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline);
+mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline,
+                      MlirStringCallback callback, void *userData);
 
 //===----------------------------------------------------------------------===//
 // External Pass API.
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -86,10 +86,14 @@
 }
 
 MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
-                                        MlirStringRef pipeline) {
-  // TODO: errors are sent to std::errs() at the moment, we should pass in a
-  // stream and redirect to a diagnostic.
-  return wrap(mlir::parsePassPipeline(unwrap(pipeline), *unwrap(passManager)));
+                                        MlirStringRef pipeline,
+                                        MlirStringCallback callback,
+                                        void *userData) {
+  detail::CallbackOstream stream(callback, userData);
+  FailureOr<OpPassManager> pm = parsePassPipeline(unwrap(pipeline), stream);
+  if (succeeded(pm))
+    *unwrap(passManager) = std::move(*pm);
+  return wrap(pm);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c
--- a/mlir/test/CAPI/pass.c
+++ b/mlir/test/CAPI/pass.c
@@ -182,7 +182,8 @@
   MlirLogicalResult status = mlirParsePassPipeline(
       mlirPassManagerGetAsOpPassManager(pm),
       mlirStringRefCreateFromCString(
-          "builtin.module(func.func(print-op-stats{json=false}))"));
+          "builtin.module(func.func(print-op-stats{json=false}))"),
+      printToStderr, NULL);
   // Expect a failure, we haven't registered the print-op-stats pass yet.
   if (mlirLogicalResultIsSuccess(status)) {
     fprintf(
@@ -195,7 +196,8 @@
   status = mlirParsePassPipeline(
       mlirPassManagerGetAsOpPassManager(pm),
       mlirStringRefCreateFromCString(
-          "builtin.module(func.func(print-op-stats{json=false}))"));
-  // Expect a failure, we haven't registered the print-op-stats pass yet.
+          "builtin.module(func.func(print-op-stats{json=false}))"),
+      printToStderr, NULL);
+  // Expect success now that the print-op-stats pass has been registered.
   if (mlirLogicalResultIsFailure(status)) {
     fprintf(stderr,
@@ -203,9 +205,7 @@
     exit(EXIT_FAILURE);
   }
 
-  // CHECK: Round-trip: builtin.module(
-  // CHECK-SAME: builtin.module(func.func(print-op-stats{json=false}))
-  // CHECK-SAME: )
+  // CHECK: Round-trip: builtin.module(func.func(print-op-stats{json=false}))
   fprintf(stderr, "Round-trip: ");
   mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
                         NULL);
@@ -221,7 +221,7 @@
     exit(EXIT_FAILURE);
   }
   // CHECK: Appended: builtin.module(
-  // CHECK-SAME: builtin.module(func.func(print-op-stats{json=false})),
+  // CHECK-SAME: func.func(print-op-stats{json=false}),
   // CHECK-SAME: func.func(print-op-stats{json=false})
   // CHECK-SAME: )
   fprintf(stderr, "Appended: ");
@@ -242,6 +242,14 @@
   MlirOpPassManager opm = mlirPassManagerGetAsOpPassManager(pm);
   MlirStringRef invalidPipeline = mlirStringRefCreateFromCString("invalid");
 
+  // CHECK: mlirParsePassPipeline:
+  // CHECK: expected pass pipeline to be wrapped with the anchor operation type
+  fprintf(stderr, "mlirParsePassPipeline:\n");
+  if (mlirLogicalResultIsSuccess(
+          mlirParsePassPipeline(opm, invalidPipeline, printToStderr, NULL)))
+    exit(EXIT_FAILURE);
+  fprintf(stderr, "\n");
+
   // CHECK: mlirOpPassManagerAddPipeline:
   // CHECK: 'invalid' does not refer to a registered pass or pass pipeline
   fprintf(stderr, "mlirOpPassManagerAddPipeline:\n");
@@ -253,6 +261,9 @@
   // Make sure all output is going through the callback.
   // CHECK: dontPrint: <>
   fprintf(stderr, "dontPrint: <");
+  if (mlirLogicalResultIsSuccess(
+          mlirParsePassPipeline(opm, invalidPipeline, dontPrint, NULL)))
+    exit(EXIT_FAILURE);
   if (mlirLogicalResultIsSuccess(
           mlirOpPassManagerAddPipeline(opm, invalidPipeline, dontPrint, NULL)))
     exit(EXIT_FAILURE);