// Review summary (preamble text placed before the first `diff --git` header).
//
// Purpose: change the C API entry point `mlirPassManagerCreate` so that it
// takes a second argument, `MlirStringRef anchorOp`, and forwards it
// (unwrapped) to the C++ `PassManager` constructor together with the context.
//
// Files touched by this patch:
//   - mlir/include/mlir-c/Pass.h          : declaration gains `anchorOp`.
//   - mlir/lib/CAPI/IR/Pass.cpp           : definition passes `unwrap(anchorOp)`.
//   - mlir/lib/Bindings/Python/Pass.cpp   : `PassManager` init and `parse` now
//                                           pass "builtin.module".
//   - mlir/test/CAPI/execution_engine.c,
//     mlir/test/CAPI/pass.c               : every call site now passes
//                                           "builtin.module".
//
// NOTE(review): the one-argument declaration is removed rather than kept next
// to the new one, and no compatibility wrapper is visible in this patch, so
// existing C callers of `mlirPassManagerCreate(ctx)` would have to be updated.
// Confirm that this source break of an MLIR_CAPI_EXPORTED function is intended.
// NOTE(review): every updated call site hard-codes "builtin.module";
// presumably this preserves the previous anchoring behaviour - TODO confirm.
// The Python bindings do not expose the anchor to callers in this patch.
// NOTE(review): the line breaks of this patch appear to have been collapsed;
// the hunk headers (`@@ -a,b +c,d @@`) still record the intended line counts.
diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -51,8 +51,9 @@ // PassManager/OpPassManager APIs. //===----------------------------------------------------------------------===// -/// Create a new top-level PassManager. -MLIR_CAPI_EXPORTED MlirPassManager mlirPassManagerCreate(MlirContext ctx); +/// Create a new top-level PassManager anchored on `anchorOp`. +MLIR_CAPI_EXPORTED MlirPassManager +mlirPassManagerCreate(MlirContext ctx, MlirStringRef anchorOp); /// Destroy the provided PassManager. MLIR_CAPI_EXPORTED void mlirPassManagerDestroy(MlirPassManager passManager); diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -57,8 +57,9 @@ //---------------------------------------------------------------------------- py::class_(m, "PassManager", py::module_local()) .def(py::init<>([](DefaultingPyMlirContext context) { - MlirPassManager passManager = - mlirPassManagerCreate(context->get()); + MlirPassManager passManager = mlirPassManagerCreate( + context->get(), + mlirStringRefCreateFromCString("builtin.module")); return new PyPassManager(passManager); }), py::arg("context") = py::none(), @@ -83,7 +84,9 @@ .def_static( "parse", [](const std::string &pipeline, DefaultingPyMlirContext context) { - MlirPassManager passManager = mlirPassManagerCreate(context->get()); + MlirPassManager passManager = mlirPassManagerCreate( + context->get(), + mlirStringRefCreateFromCString("builtin.module")); PyPrintAccumulator errorMsg; MlirLogicalResult status = mlirOpPassManagerAddPipeline( mlirPassManagerGetAsOpPassManager(passManager), diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -20,8 +20,8 @@ // PassManager/OpPassManager APIs. 
//===----------------------------------------------------------------------===// -MlirPassManager mlirPassManagerCreate(MlirContext ctx) { - return wrap(new PassManager(unwrap(ctx))); +MlirPassManager mlirPassManagerCreate(MlirContext ctx, MlirStringRef anchorOp) { + return wrap(new PassManager(unwrap(ctx), unwrap(anchorOp))); } void mlirPassManagerDestroy(MlirPassManager passManager) { diff --git a/mlir/test/CAPI/execution_engine.c b/mlir/test/CAPI/execution_engine.c --- a/mlir/test/CAPI/execution_engine.c +++ b/mlir/test/CAPI/execution_engine.c @@ -31,7 +31,8 @@ } void lowerModuleToLLVM(MlirContext ctx, MlirModule module) { - MlirPassManager pm = mlirPassManagerCreate(ctx); + MlirPassManager pm = mlirPassManagerCreate( + ctx, mlirStringRefCreateFromCString("builtin.module")); MlirOpPassManager opm = mlirPassManagerGetNestedUnder( pm, mlirStringRefCreateFromCString("func.func")); mlirPassManagerAddOwnedPass(pm, mlirCreateConversionConvertFuncToLLVM()); diff --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c --- a/mlir/test/CAPI/pass.c +++ b/mlir/test/CAPI/pass.c @@ -53,7 +53,8 @@ // CHECK: func.func , 1 // CHECK: func.return , 1 { - MlirPassManager pm = mlirPassManagerCreate(ctx); + MlirPassManager pm = mlirPassManagerCreate( + ctx, mlirStringRefCreateFromCString("builtin.module")); MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats(); mlirPassManagerAddOwnedPass(pm, printOpStatPass); MlirLogicalResult success = mlirPassManagerRun(pm, module); @@ -95,7 +96,8 @@ // CHECK: func.func , 1 // CHECK: func.return , 1 { - MlirPassManager pm = mlirPassManagerCreate(ctx); + MlirPassManager pm = mlirPassManagerCreate( + ctx, mlirStringRefCreateFromCString("builtin.module")); MlirOpPassManager nestedFuncPm = mlirPassManagerGetNestedUnder( pm, mlirStringRefCreateFromCString("func.func")); MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats(); @@ -111,7 +113,8 @@ // CHECK: func.func , 1 // CHECK: func.return , 1 { - MlirPassManager pm = 
mlirPassManagerCreate(ctx); + MlirPassManager pm = mlirPassManagerCreate( + ctx, mlirStringRefCreateFromCString("builtin.module")); MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder( pm, mlirStringRefCreateFromCString("builtin.module")); MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder( @@ -135,7 +138,8 @@ void testPrintPassPipeline() { MlirContext ctx = mlirContextCreate(); - MlirPassManager pm = mlirPassManagerCreate(ctx); + MlirPassManager pm = mlirPassManagerCreate( + ctx, mlirStringRefCreateFromCString("builtin.module")); // Populate the pass-manager MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder( pm, mlirStringRefCreateFromCString("builtin.module")); @@ -171,7 +175,8 @@ void testParsePassPipeline() { MlirContext ctx = mlirContextCreate(); - MlirPassManager pm = mlirPassManagerCreate(ctx); + MlirPassManager pm = mlirPassManagerCreate( + ctx, mlirStringRefCreateFromCString("builtin.module")); // Try parse a pipeline. MlirLogicalResult status = mlirParsePassPipeline( mlirPassManagerGetAsOpPassManager(pm), @@ -329,7 +334,8 @@ exit(EXIT_FAILURE); } - MlirPassManager pm = mlirPassManagerCreate(ctx); + MlirPassManager pm = mlirPassManagerCreate( + ctx, mlirStringRefCreateFromCString("builtin.module")); mlirPassManagerAddOwnedPass(pm, externalPass); MlirLogicalResult success = mlirPassManagerRun(pm, module); if (mlirLogicalResultIsFailure(success)) { @@ -371,7 +377,8 @@ exit(EXIT_FAILURE); } - MlirPassManager pm = mlirPassManagerCreate(ctx); + MlirPassManager pm = mlirPassManagerCreate( + ctx, mlirStringRefCreateFromCString("builtin.module")); MlirOpPassManager nestedFuncPm = mlirPassManagerGetNestedUnder(pm, funcOpName); mlirOpPassManagerAddOwnedPass(nestedFuncPm, externalPass); @@ -421,7 +428,8 @@ exit(EXIT_FAILURE); } - MlirPassManager pm = mlirPassManagerCreate(ctx); + MlirPassManager pm = mlirPassManagerCreate( + ctx, mlirStringRefCreateFromCString("builtin.module")); mlirPassManagerAddOwnedPass(pm, externalPass); 
MlirLogicalResult success = mlirPassManagerRun(pm, module); if (mlirLogicalResultIsFailure(success)) { @@ -468,7 +476,8 @@ exit(EXIT_FAILURE); } - MlirPassManager pm = mlirPassManagerCreate(ctx); + MlirPassManager pm = mlirPassManagerCreate( + ctx, mlirStringRefCreateFromCString("builtin.module")); mlirPassManagerAddOwnedPass(pm, externalPass); MlirLogicalResult success = mlirPassManagerRun(pm, module); if (mlirLogicalResultIsSuccess(success)) { @@ -516,7 +525,8 @@ exit(EXIT_FAILURE); } - MlirPassManager pm = mlirPassManagerCreate(ctx); + MlirPassManager pm = mlirPassManagerCreate( + ctx, mlirStringRefCreateFromCString("builtin.module")); mlirPassManagerAddOwnedPass(pm, externalPass); MlirLogicalResult success = mlirPassManagerRun(pm, module); if (mlirLogicalResultIsSuccess(success)) {