diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -75,6 +75,13 @@ }; typedef struct MlirNamedAttribute MlirNamedAttribute; +//===----------------------------------------------------------------------===// +// Global API. +//===----------------------------------------------------------------------===// + +/// Sets the global LLVM debugging flag (llvm::DebugFlag). +MLIR_CAPI_EXPORTED void mlirEnableGlobalDebug(bool enable); + //===----------------------------------------------------------------------===// // Context API. //===----------------------------------------------------------------------===// @@ -119,6 +126,10 @@ MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context, MlirStringRef name); +/// Enables or disables multithreading on `context` (must be disabled to use print-ir-after-all). +MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context, + bool enable); + /// Returns whether the given fully-qualified operation (i.e. /// 'dialect.operation') is registered with the context. This will return true /// if the dialect is loaded and the operation is registered within the diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -65,6 +65,14 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirPassManagerRun(MlirPassManager passManager, MlirModule module); +/// Enables IR printing (print-ir-after-all) using the default printing options. +MLIR_CAPI_EXPORTED void +mlirPassManagerEnableIRPrinting(MlirPassManager passManager); + +/// Enables or disables IR verification in the pass manager (verify-each). +MLIR_CAPI_EXPORTED void +mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable); + /// Nest an OpPassManager under the top-level PassManager, the nested /// passmanager will only run on operations matching the provided name. /// The returned OpPassManager will be destroyed when the parent is destroyed. 
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1712,6 +1712,11 @@ //------------------------------------------------------------------------------ void mlir::python::populateIRCore(py::module &m) { + //---------------------------------------------------------------------------- + // Mapping of Global functions + //---------------------------------------------------------------------------- + m.def("_enable_debug", [](bool enable) { mlirEnableGlobalDebug(enable); }, "Sets the global LLVM debugging flag."); + //---------------------------------------------------------------------------- // Mapping of MlirContext //---------------------------------------------------------------------------- @@ -1766,6 +1771,10 @@ [](PyMlirContext &self, bool value) { mlirContextSetAllowUnregisteredDialects(self.get(), value); }) + .def("enable_multithreading", + [](PyMlirContext &self, bool enable) { + mlirContextEnableMultithreading(self.get(), enable); + }, "Enable / disable multithreading on this context.") .def("is_registered_operation", [](PyMlirContext &self, std::string &name) { return mlirContextIsRegisteredOperation( diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -68,6 +68,18 @@ .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule) .def("_testing_release", &PyPassManager::release, "Releases (leaks) the backing pass manager (testing)") + .def( + "enable_ir_printing", + [](PyPassManager &passManager) { + mlirPassManagerEnableIRPrinting(passManager.get()); + }, + "Enable print-ir-after-all.") + .def( + "enable_verifier", + [](PyPassManager &passManager, bool enable) { + mlirPassManagerEnableVerifier(passManager.get(), enable); + }, + "Enable / disable verify-each.") .def_static( "parse", [](const std::string pipeline, DefaultingPyMlirContext context) { diff --git a/mlir/lib/Bindings/Python/mlir/ir.py 
b/mlir/lib/Bindings/Python/mlir/ir.py --- a/mlir/lib/Bindings/Python/mlir/ir.py +++ b/mlir/lib/Bindings/Python/mlir/ir.py @@ -6,3 +6,8 @@ from ._cext_loader import _reexport_cext _reexport_cext("ir", __name__) del _reexport_cext + +# Extra functions that are not visible to _reexport_cext. +# TODO: is this really necessary? +from _mlir.ir import _enable_debug +_enable_debug = _enable_debug \ No newline at end of file diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -21,8 +21,16 @@ #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Parser.h" +#include "llvm/Support/Debug.h" + using namespace mlir; +//===----------------------------------------------------------------------===// +// Global API. +//===----------------------------------------------------------------------===// + +void mlirEnableGlobalDebug(bool enable) { ::llvm::DebugFlag = enable; } + //===----------------------------------------------------------------------===// // Context API. //===----------------------------------------------------------------------===// @@ -64,6 +72,10 @@ return unwrap(context)->isOperationRegistered(unwrap(name)); } +void mlirContextEnableMultithreading(MlirContext context, bool enable) { + unwrap(context)->enableMultithreading(enable); +} + //===----------------------------------------------------------------------===// // Dialect API. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -38,6 +38,14 @@ return wrap(unwrap(passManager)->run(unwrap(module))); } +void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) { + unwrap(passManager)->enableIRPrinting(); +} + +void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) { + unwrap(passManager)->enableVerifier(enable); +} + MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager, MlirStringRef operationName) { return wrap(&unwrap(passManager)->nest(unwrap(operationName))); diff --git a/mlir/test/Bindings/Python/context_managers.py b/mlir/test/Bindings/Python/context_managers.py --- a/mlir/test/Bindings/Python/context_managers.py +++ b/mlir/test/Bindings/Python/context_managers.py @@ -10,6 +10,13 @@ assert Context._get_live_count() == 0 +# CHECK-LABEL: TEST: testExports +def testExports(): + from mlir.ir import _enable_debug + +run(testExports) + + # CHECK-LABEL: TEST: testContextEnterExit def testContextEnterExit(): with Context() as ctx: