diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -53,6 +53,10 @@
 /** Destroy the provided PassManager. */
 void mlirPassManagerDestroy(MlirPassManager passManager);
 
+/** Cast a top-level PassManager to a generic OpPassManager. */
+MlirOpPassManager
+mlirPassManagerGetAsOpPassManager(MlirPassManager passManager);
+
 /** Run the provided `passManager` on the given `module`. */
 MlirLogicalResult mlirPassManagerRun(MlirPassManager passManager,
                                      MlirModule module);
@@ -83,6 +87,17 @@
 void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager,
                                    MlirPass pass);
 
+/** Print a textual MLIR pass pipeline by sending chunks of the string
+ * representation and forwarding `userData` to `callback`. Note that the
+ * callback may be called several times with consecutive string chunks. */
+void mlirPrintPassPipeline(MlirOpPassManager passManager,
+                           MlirStringCallback callback, void *userData);
+
+/** Parse a textual MLIR pass pipeline and add it to the provided
+ * OpPassManager. Returns failure if the pipeline cannot be parsed. */
+MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
+                                        MlirStringRef pipeline);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -28,6 +28,11 @@
   delete unwrap(passManager);
 }
 
+MlirOpPassManager
+mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) {
+  return wrap(static_cast<OpPassManager *>(unwrap(passManager)));
+}
+
 MlirLogicalResult mlirPassManagerRun(MlirPassManager passManager,
                                      MlirModule module) {
   return wrap(unwrap(passManager)->run(unwrap(module)));
@@ -51,3 +56,16 @@
                                    MlirPass pass) {
   unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
 }
+
+void mlirPrintPassPipeline(MlirOpPassManager passManager,
+                           MlirStringCallback callback, void *userData) {
+  detail::CallbackOstream stream(callback, userData);
+  unwrap(passManager)->printAsTextualPipeline(stream);
+}
+
+MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
+                                        MlirStringRef pipeline) {
+  // TODO: errors are sent to llvm::errs() at the moment, we should pass in a
+  // stream and redirect to a diagnostic.
+  return wrap(mlir::parsePassPipeline(unwrap(pipeline), *unwrap(passManager)));
+}
diff --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c
--- a/mlir/test/CAPI/pass.c
+++ b/mlir/test/CAPI/pass.c
@@ -33,8 +33,10 @@
                          " return %res : i32 \n"
                          "}");
   // clang-format on
-  if (mlirModuleIsNull(module))
+  if (mlirModuleIsNull(module)) {
+    fprintf(stderr, "Unexpected failure parsing module.\n");
     exit(EXIT_FAILURE);
+  }
 
   // Run the print-op-stats pass on the top-level module:
   // CHECK-LABEL: Operations encountered:
@@ -47,8 +49,10 @@
     MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
     mlirPassManagerAddOwnedPass(pm, printOpStatPass);
     MlirLogicalResult success = mlirPassManagerRun(pm, module);
-    if (mlirLogicalResultIsFailure(success))
+    if (mlirLogicalResultIsFailure(success)) {
+      fprintf(stderr, "Unexpected failure running pass manager.\n");
       exit(EXIT_FAILURE);
+    }
     mlirPassManagerDestroy(pm);
   }
   mlirModuleDestroy(module);
@@ -117,8 +121,86 @@
   mlirContextDestroy(ctx);
 }
 
+static void printToStderr(const char *str, intptr_t len, void *userData) {
+  (void)userData;
+  fwrite(str, 1, len, stderr);
+}
+
+void testPrintPassPipeline() {
+  MlirContext ctx = mlirContextCreate();
+  MlirPassManager pm = mlirPassManagerCreate(ctx);
+  // Populate the pass-manager
+  MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder(
+      pm, mlirStringRefCreateFromCString("module"));
+  MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder(
+      nestedModulePm, mlirStringRefCreateFromCString("func"));
+  MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
+  mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
+
+  // Print the top level pass manager
+  // CHECK: Top-level: module(func(print-op-stats))
+  fprintf(stderr, "Top-level: ");
+  mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
+                        NULL);
+  fprintf(stderr, "\n");
+
+  // Print the pipeline nested one level down
+  // CHECK: Nested Module: func(print-op-stats)
+  fprintf(stderr, "Nested Module: ");
+  mlirPrintPassPipeline(nestedModulePm, printToStderr, NULL);
+  fprintf(stderr, "\n");
+
+  // Print the pipeline nested two levels down
+  // CHECK: Nested Module>Func: print-op-stats
+  fprintf(stderr, "Nested Module>Func: ");
+  mlirPrintPassPipeline(nestedFuncPm, printToStderr, NULL);
+  fprintf(stderr, "\n");
+
+  mlirPassManagerDestroy(pm);
+  mlirContextDestroy(ctx);
+}
+
+void testParsePassPipeline() {
+  MlirContext ctx = mlirContextCreate();
+  MlirPassManager pm = mlirPassManagerCreate(ctx);
+  // Try to parse a pipeline.
+  MlirLogicalResult status = mlirParsePassPipeline(
+      mlirPassManagerGetAsOpPassManager(pm),
+      mlirStringRefCreateFromCString(
+          "module(func(print-op-stats), func(print-op-stats))"));
+  // Expect a failure, we haven't registered the print-op-stats pass yet.
+  if (mlirLogicalResultIsSuccess(status)) {
+    fprintf(stderr, "Unexpected success parsing pipeline without "
+                    "registering the pass\n");
+    exit(EXIT_FAILURE);
+  }
+  // Try again after registering the pass.
+  mlirRegisterTransformsPrintOpStats();
+  status = mlirParsePassPipeline(
+      mlirPassManagerGetAsOpPassManager(pm),
+      mlirStringRefCreateFromCString(
+          "module(func(print-op-stats), func(print-op-stats))"));
+  // Expect success now that the print-op-stats pass is registered.
+  if (mlirLogicalResultIsFailure(status)) {
+    fprintf(stderr, "Unexpected failure parsing pipeline after "
+                    "registering the pass\n");
+    exit(EXIT_FAILURE);
+  }
+
+  // CHECK: Round-trip: module(func(print-op-stats), func(print-op-stats))
+  fprintf(stderr, "Round-trip: ");
+  mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
+                        NULL);
+  fprintf(stderr, "\n");
+
+  mlirPassManagerDestroy(pm);
+  mlirContextDestroy(ctx);
+}
+
 int main() {
   testRunPassOnModule();
   testRunPassOnNestedModule();
+  testPrintPassPipeline();
+  testParsePassPipeline();
   return 0;
 }
diff --git a/mlir/tools/mlir-tblgen/PassCAPIGen.cpp b/mlir/tools/mlir-tblgen/PassCAPIGen.cpp
--- a/mlir/tools/mlir-tblgen/PassCAPIGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassCAPIGen.cpp
@@ -33,6 +33,8 @@
 const char *const passDecl = R"(
 /* Create {0} Pass. */
 MlirPass mlirCreate{0}{1}();
+/* Register {0} Pass. */
+void mlirRegister{0}{1}();
 
 )";
 
@@ -70,6 +72,10 @@
 MlirPass mlirCreate{0}{1}() {
   return wrap({2}.release());
 }
+
+void mlirRegister{0}{1}() {
+  register{1}Pass();
+}
 
 )";
 