diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h --- a/mlir/include/mlir-c/Bindings/Python/Interop.h +++ b/mlir/include/mlir-c/Bindings/Python/Interop.h @@ -64,6 +64,8 @@ MAKE_MLIR_PYTHON_QUALNAME("ir.Attribute._CAPIPtr") #define MLIR_PYTHON_CAPSULE_CONTEXT \ MAKE_MLIR_PYTHON_QUALNAME("ir.Context._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_DIALECT_REGISTRY \ + MAKE_MLIR_PYTHON_QUALNAME("ir.DialectRegistry._CAPIPtr") #define MLIR_PYTHON_CAPSULE_EXECUTION_ENGINE \ MAKE_MLIR_PYTHON_QUALNAME("execution_engine.ExecutionEngine._CAPIPtr") #define MLIR_PYTHON_CAPSULE_INTEGER_SET \ @@ -172,6 +174,28 @@ return context; } +/** Creates a capsule object encapsulating the raw C-API MlirDialectRegistry. + * The returned capsule does not extend or affect ownership of any Python + * objects that reference the registry in any way. + */ +static inline PyObject * +mlirPythonDialectRegistryToCapsule(MlirDialectRegistry registry) { + return PyCapsule_New(registry.ptr, MLIR_PYTHON_CAPSULE_DIALECT_REGISTRY, + NULL); +} + +/** Extracts an MlirDialectRegistry from a capsule as produced from + * mlirPythonDialectRegistryToCapsule. If the capsule is not of the right type, + * then a null registry is returned (as checked via mlirDialectRegistryIsNull). + * In such a case, the Python APIs will have already set an error. */ +static inline MlirDialectRegistry +mlirPythonCapsuleToDialectRegistry(PyObject *capsule) { + void *ptr = + PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_DIALECT_REGISTRY); + MlirDialectRegistry registry = {ptr}; + return registry; +} + /** Creates a capsule object encapsulating the raw C-API MlirLocation. * The returned capsule does not extend or affect ownership of any Python * objects that reference the location in any way.
*/ diff --git a/mlir/include/mlir-c/Dialect/Async.h b/mlir/include/mlir-c/Dialect/Async.h --- a/mlir/include/mlir-c/Dialect/Async.h +++ b/mlir/include/mlir-c/Dialect/Async.h @@ -10,7 +10,7 @@ #ifndef MLIR_C_DIALECT_ASYNC_H #define MLIR_C_DIALECT_ASYNC_H -#include "mlir-c/Registration.h" +#include "mlir-c/IR.h" #include "mlir-c/Support.h" #ifdef __cplusplus diff --git a/mlir/include/mlir-c/Dialect/ControlFlow.h b/mlir/include/mlir-c/Dialect/ControlFlow.h --- a/mlir/include/mlir-c/Dialect/ControlFlow.h +++ b/mlir/include/mlir-c/Dialect/ControlFlow.h @@ -10,7 +10,7 @@ #ifndef MLIR_C_DIALECT_CONTROLFLOW_H #define MLIR_C_DIALECT_CONTROLFLOW_H -#include "mlir-c/Registration.h" +#include "mlir-c/IR.h" #ifdef __cplusplus extern "C" { diff --git a/mlir/include/mlir-c/Dialect/Func.h b/mlir/include/mlir-c/Dialect/Func.h --- a/mlir/include/mlir-c/Dialect/Func.h +++ b/mlir/include/mlir-c/Dialect/Func.h @@ -18,7 +18,7 @@ #ifndef MLIR_C_DIALECT_FUNC_H #define MLIR_C_DIALECT_FUNC_H -#include "mlir-c/Registration.h" +#include "mlir-c/IR.h" #ifdef __cplusplus extern "C" { diff --git a/mlir/include/mlir-c/Dialect/GPU.h b/mlir/include/mlir-c/Dialect/GPU.h --- a/mlir/include/mlir-c/Dialect/GPU.h +++ b/mlir/include/mlir-c/Dialect/GPU.h @@ -10,7 +10,7 @@ #ifndef MLIR_C_DIALECT_GPU_H #define MLIR_C_DIALECT_GPU_H -#include "mlir-c/Registration.h" +#include "mlir-c/IR.h" #include "mlir-c/Support.h" #ifdef __cplusplus diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -11,7 +11,6 @@ #define MLIR_C_DIALECT_LLVM_H #include "mlir-c/IR.h" -#include "mlir-c/Registration.h" #ifdef __cplusplus extern "C" { diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h --- a/mlir/include/mlir-c/Dialect/Linalg.h +++ b/mlir/include/mlir-c/Dialect/Linalg.h @@ -10,7 +10,7 @@ #ifndef MLIR_C_DIALECT_LINALG_H #define MLIR_C_DIALECT_LINALG_H -#include
"mlir-c/Registration.h" +#include "mlir-c/IR.h" #include "mlir-c/Support.h" #ifdef __cplusplus diff --git a/mlir/include/mlir-c/Dialect/PDL.h b/mlir/include/mlir-c/Dialect/PDL.h --- a/mlir/include/mlir-c/Dialect/PDL.h +++ b/mlir/include/mlir-c/Dialect/PDL.h @@ -11,7 +11,6 @@ #define MLIR_C_DIALECT_PDL_H #include "mlir-c/IR.h" -#include "mlir-c/Registration.h" #ifdef __cplusplus extern "C" { diff --git a/mlir/include/mlir-c/Dialect/Quant.h b/mlir/include/mlir-c/Dialect/Quant.h --- a/mlir/include/mlir-c/Dialect/Quant.h +++ b/mlir/include/mlir-c/Dialect/Quant.h @@ -11,7 +11,6 @@ #define MLIR_C_DIALECT_QUANT_H #include "mlir-c/IR.h" -#include "mlir-c/Registration.h" #ifdef __cplusplus extern "C" { diff --git a/mlir/include/mlir-c/Dialect/SCF.h b/mlir/include/mlir-c/Dialect/SCF.h --- a/mlir/include/mlir-c/Dialect/SCF.h +++ b/mlir/include/mlir-c/Dialect/SCF.h @@ -10,7 +10,7 @@ #ifndef MLIR_C_DIALECT_SCF_H #define MLIR_C_DIALECT_SCF_H -#include "mlir-c/Registration.h" +#include "mlir-c/IR.h" #ifdef __cplusplus extern "C" { diff --git a/mlir/include/mlir-c/Dialect/Shape.h b/mlir/include/mlir-c/Dialect/Shape.h --- a/mlir/include/mlir-c/Dialect/Shape.h +++ b/mlir/include/mlir-c/Dialect/Shape.h @@ -10,7 +10,7 @@ #ifndef MLIR_C_DIALECT_SHAPE_H #define MLIR_C_DIALECT_SHAPE_H -#include "mlir-c/Registration.h" +#include "mlir-c/IR.h" #ifdef __cplusplus extern "C" { diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -11,7 +11,7 @@ #define MLIR_C_DIALECT_SPARSETENSOR_H #include "mlir-c/AffineMap.h" -#include "mlir-c/Registration.h" +#include "mlir-c/IR.h" #ifdef __cplusplus extern "C" { diff --git a/mlir/include/mlir-c/Dialect/Tensor.h b/mlir/include/mlir-c/Dialect/Tensor.h --- a/mlir/include/mlir-c/Dialect/Tensor.h +++ b/mlir/include/mlir-c/Dialect/Tensor.h @@ -10,7 +10,7 @@ #ifndef MLIR_C_DIALECT_TENSOR_H #define
MLIR_C_DIALECT_TENSOR_H -#include "mlir-c/Registration.h" +#include "mlir-c/IR.h" #ifdef __cplusplus extern "C" { diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -130,6 +130,11 @@ MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context, bool enable); +/// Eagerly loads all available dialects registered with a context, making +/// them available for use for IR construction. +MLIR_CAPI_EXPORTED void +mlirContextLoadAllAvailableDialects(MlirContext context); + /// Returns whether the given fully-qualified operation (i.e. /// 'dialect.operation') is registered with the context. This will return true /// if the dialect is loaded and the operation is registered within the @@ -157,6 +162,48 @@ /// Returns the namespace of the given dialect. MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect); +//===----------------------------------------------------------------------===// +// DialectHandle API. +// Registration entry-points for each dialect are declared using the common +// MLIR_DECLARE_CAPI_DIALECT_REGISTRATION macro, which takes the dialect +// API name (e.g. "Func", "Tensor", "Linalg") and namespace (e.g. "func", +// "tensor", "linalg"). The following declarations are produced: +// +// /// Gets a handle to the registration hooks of a dialect, by namespace. +// /// This is intended to facilitate dynamic lookup and registration of +// /// dialects via a plugin facility based on shared library symbol lookup. +// MlirDialectHandle mlirGetDialectHandle__{NAMESPACE}__(void); +// +// This is done via a common macro to facilitate future expansion to +// registration schemes.
+//===----------------------------------------------------------------------===// + +struct MlirDialectHandle { + const void *ptr; +}; +typedef struct MlirDialectHandle MlirDialectHandle; + +#define MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Name, Namespace) \ + MLIR_CAPI_EXPORTED MlirDialectHandle mlirGetDialectHandle__##Namespace##__( \ + void) + +/// Returns the namespace associated with the provided dialect handle. +MLIR_CAPI_EXPORTED +MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle); + +/// Inserts the dialect associated with the provided dialect handle into the +/// provided dialect registry. +MLIR_CAPI_EXPORTED void mlirDialectHandleInsertDialect(MlirDialectHandle, + MlirDialectRegistry); + +/// Registers the dialect associated with the provided dialect handle. +MLIR_CAPI_EXPORTED void mlirDialectHandleRegisterDialect(MlirDialectHandle, + MlirContext); + +/// Loads the dialect associated with the provided dialect handle. +MLIR_CAPI_EXPORTED MlirDialect mlirDialectHandleLoadDialect(MlirDialectHandle, + MlirContext); + //===----------------------------------------------------------------------===// // DialectRegistry API. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -15,7 +15,6 @@ #define MLIR_C_PASS_H #include "mlir-c/IR.h" -#include "mlir-c/Registration.h" #include "mlir-c/Support.h" #ifdef __cplusplus diff --git a/mlir/include/mlir-c/RegisterEverything.h b/mlir/include/mlir-c/RegisterEverything.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir-c/RegisterEverything.h @@ -0,0 +1,38 @@ +//===-- mlir-c/RegisterEverything.h - Register all MLIR entities --*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// This header contains registration entry points for MLIR upstream dialects +// and passes. Downstream projects typically will not want to use this unless +// they don't care about binary size or build bloat and just want access +// to the entire set of upstream facilities. For those that do care, they +// should use registration functions specific to their project. +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_REGISTER_EVERYTHING_H +#define MLIR_C_REGISTER_EVERYTHING_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// Appends all upstream dialects and extensions to the dialect registry. +MLIR_CAPI_EXPORTED void mlirRegisterAllDialects(MlirDialectRegistry registry); + +/// Registers all translations to LLVM IR for dialects that can support it. +MLIR_CAPI_EXPORTED void mlirRegisterAllLLVMTranslations(MlirContext context); + +/// Registers all compiler passes of MLIR. +MLIR_CAPI_EXPORTED void mlirRegisterAllPasses(void); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_REGISTER_EVERYTHING_H diff --git a/mlir/include/mlir-c/Registration.h b/mlir/include/mlir-c/Registration.h deleted file mode 100644 --- a/mlir/include/mlir-c/Registration.h +++ /dev/null @@ -1,75 +0,0 @@ -//===-- mlir-c/Registration.h - Registration functions for MLIR ---*- C -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM -// Exceptions. -// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_C_REGISTRATION_H -#define MLIR_C_REGISTRATION_H - -#include "mlir-c/IR.h" - -#ifdef __cplusplus -extern "C" { -#endif - -//===----------------------------------------------------------------------===// -// Dialect registration declarations. -// Registration entry-points for each dialect are declared using the common -// MLIR_DECLARE_DIALECT_REGISTRATION_CAPI macro, which takes the dialect -// API name (i.e. "Func", "Tensor", "Linalg") and namespace (i.e. "func", -// "tensor", "linalg"). The following declarations are produced: -// -// /// Gets the above hook methods in struct form for a dialect by namespace. -// /// This is intended to facilitate dynamic lookup and registration of -// /// dialects via a plugin facility based on shared library symbol lookup. -// const MlirDialectHandle *mlirGetDialectHandle__{NAMESPACE}__(); -// -// This is done via a common macro to facilitate future expansion to -// registration schemes. -//===----------------------------------------------------------------------===// - -struct MlirDialectHandle { - const void *ptr; -}; -typedef struct MlirDialectHandle MlirDialectHandle; - -#define MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Name, Namespace) \ - MLIR_CAPI_EXPORTED MlirDialectHandle mlirGetDialectHandle__##Namespace##__() - -/// Returns the namespace associated with the provided dialect handle. -MLIR_CAPI_EXPORTED -MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle); - -/// Inserts the dialect associated with the provided dialect handle into the -/// provided dialect registry -MLIR_CAPI_EXPORTED void mlirDialectHandleInsertDialect(MlirDialectHandle, - MlirDialectRegistry); - -/// Registers the dialect associated with the provided dialect handle.
-MLIR_CAPI_EXPORTED void mlirDialectHandleRegisterDialect(MlirDialectHandle, - MlirContext); - -/// Loads the dialect associated with the provided dialect handle. -MLIR_CAPI_EXPORTED MlirDialect mlirDialectHandleLoadDialect(MlirDialectHandle, - MlirContext); - -/// Registers all dialects known to core MLIR with the provided Context. -/// This is needed before creating IR for these Dialects. -/// TODO: Remove this function once the real registration API is finished. -MLIR_CAPI_EXPORTED void mlirRegisterAllDialects(MlirContext context); - -/// Register all translations to LLVM IR for dialects that can support it. -MLIR_CAPI_EXPORTED void mlirRegisterAllLLVMTranslations(MlirContext context); - -/// Register all compiler passes of MLIR. -MLIR_CAPI_EXPORTED void mlirRegisterAllPasses(); - -#ifdef __cplusplus -} -#endif - -#endif // MLIR_C_REGISTRATION_H diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -124,6 +124,25 @@ } }; +/// Casts object <-> MlirDialectRegistry. +template <> +struct type_caster<MlirDialectRegistry> { + PYBIND11_TYPE_CASTER(MlirDialectRegistry, _("MlirDialectRegistry")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToDialectRegistry(capsule.ptr()); + return !mlirDialectRegistryIsNull(value); + } + static handle cast(MlirDialectRegistry v, return_value_policy, handle) { + py::object capsule = py::reinterpret_steal<py::object>( + mlirPythonDialectRegistryToCapsule(v)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("DialectRegistry") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + } +}; + /// Casts object <-> MlirLocation. template <> struct type_caster<MlirLocation> { @@ -252,7 +271,7 @@ } template <typename Func, typename... Extra> - pure_subclass &def(const char *name, Func &&f, const Extra &...
extra) { + pure_subclass &def(const char *name, Func &&f, const Extra &...extra) { py::cpp_function cf( std::forward<Func>(f), py::name(name), py::is_method(thisClass), py::sibling(py::getattr(thisClass, name, py::none())), extra...); @@ -262,7 +281,7 @@ template <typename Func, typename... Extra> pure_subclass &def_property_readonly(const char *name, Func &&f, - const Extra &... extra) { + const Extra &...extra) { py::cpp_function cf( std::forward<Func>(f), py::name(name), py::is_method(thisClass), py::sibling(py::getattr(thisClass, name, py::none())), extra...); @@ -274,7 +293,7 @@ template <typename Func, typename... Extra> pure_subclass &def_staticmethod(const char *name, Func &&f, - const Extra &... extra) { + const Extra &...extra) { static_assert(!std::is_member_function_pointer<Func>::value, "def_staticmethod(...) called with a non-static member " "function pointer"); @@ -287,7 +306,7 @@ template <typename Func, typename... Extra> pure_subclass &def_classmethod(const char *name, Func &&f, - const Extra &... extra) { + const Extra &...extra) { static_assert(!std::is_member_function_pointer<Func>::value, "def_classmethod(...) called with a non-static member " "function pointer"); diff --git a/mlir/include/mlir/CAPI/Registration.h b/mlir/include/mlir/CAPI/Registration.h --- a/mlir/include/mlir/CAPI/Registration.h +++ b/mlir/include/mlir/CAPI/Registration.h @@ -10,7 +10,6 @@ #define MLIR_CAPI_REGISTRATION_H #include "mlir-c/IR.h" -#include "mlir-c/Registration.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Support.h" diff --git a/mlir/lib/Bindings/Python/AllPassesRegistration.cpp b/mlir/lib/Bindings/Python/AllPassesRegistration.cpp deleted file mode 100644 --- a/mlir/lib/Bindings/Python/AllPassesRegistration.cpp +++ /dev/null @@ -1,22 +0,0 @@ -//===- AllPassesRegistration.cpp - Pybind module to register all passes ---===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir-c/Registration.h" - -#include <pybind11/pybind11.h> - -// ----------------------------------------------------------------------------- -// Module initialization. -// ----------------------------------------------------------------------------- - -PYBIND11_MODULE(_mlirAllPassesRegistration, m) { - m.doc() = "MLIR All Passes Convenience Module"; - - // Register all passes on load. - mlirRegisterAllPasses(); -} diff --git a/mlir/lib/Bindings/Python/Conversions/Conversions.cpp b/mlir/lib/Bindings/Python/Conversions/Conversions.cpp deleted file mode 100644 --- a/mlir/lib/Bindings/Python/Conversions/Conversions.cpp +++ /dev/null @@ -1,22 +0,0 @@ -//===- Conversions.cpp - Pybind module for the Conversionss library -------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir-c/Conversion.h" - -#include <pybind11/pybind11.h> - -// ----------------------------------------------------------------------------- -// Module initialization. -// ----------------------------------------------------------------------------- - -PYBIND11_MODULE(_mlirConversions, m) { - m.doc() = "MLIR Conversions library"; - - // Register all the passes in the Conversions library on load.
- mlirRegisterConversionPasses(); -} diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -16,7 +16,6 @@ #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Debug.h" #include "mlir-c/IR.h" -#include "mlir-c/Registration.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" @@ -474,7 +473,6 @@ PyMlirContext *PyMlirContext::createNewContextForInit() { MlirContext context = mlirContextCreate(); - mlirRegisterAllDialects(context); return new PyMlirContext(context); } @@ -793,7 +791,7 @@ } //------------------------------------------------------------------------------ -// PyDialect, PyDialectDescriptor, PyDialects +// PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry //------------------------------------------------------------------------------ MlirDialect PyDialects::getDialectForKey(const std::string &key, @@ -807,6 +805,19 @@ return dialect; } +py::object PyDialectRegistry::getCapsule() { + return py::reinterpret_steal<py::object>( + mlirPythonDialectRegistryToCapsule(*this)); +} + +PyDialectRegistry PyDialectRegistry::createFromCapsule(py::object capsule) { + MlirDialectRegistry rawRegistry = + mlirPythonCapsuleToDialectRegistry(capsule.ptr()); + if (mlirDialectRegistryIsNull(rawRegistry)) + throw py::error_already_set(); + return PyDialectRegistry(rawRegistry); +} + //------------------------------------------------------------------------------ // PyLocation //------------------------------------------------------------------------------ @@ -2207,8 +2218,11 @@ //---------------------------------------------------------------------------- // Mapping of MlirContext. + // Note that this is exported as _BaseContext. The containing, Python level + // __init__.py will subclass it with site-specific functionality and set a + // "Context" attribute on this module.
//---------------------------------------------------------------------------- - py::class_<PyMlirContext>(m, "Context", py::module_local()) + py::class_<PyMlirContext>(m, "_BaseContext", py::module_local()) .def(py::init<>(&PyMlirContext::createNewContextForInit)) .def_static("_get_live_count", &PyMlirContext::getLiveCount) .def("_get_context_again", @@ -2276,7 +2290,16 @@ return mlirContextIsRegisteredOperation( self.get(), MlirStringRef{name.data(), name.size()}); }, - py::arg("operation_name")); + py::arg("operation_name")) + .def( + "append_dialect_registry", + [](PyMlirContext &self, PyDialectRegistry &registry) { + mlirContextAppendDialectRegistry(self.get(), registry); + }, + py::arg("registry")) + .def("load_all_available_dialects", [](PyMlirContext &self) { + mlirContextLoadAllAvailableDialects(self.get()); + }); //---------------------------------------------------------------------------- // Mapping of PyDialectDescriptor @@ -2331,6 +2354,16 @@ clazz.attr("__name__") + py::str(")>"); }); + //---------------------------------------------------------------------------- + // Mapping of PyDialectRegistry + //---------------------------------------------------------------------------- + py::class_<PyDialectRegistry>(m, "DialectRegistry", py::module_local()) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyDialectRegistry::getCapsule) + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, + &PyDialectRegistry::createFromCapsule) + .def(py::init<>()); + //---------------------------------------------------------------------------- // Mapping of Location //---------------------------------------------------------------------------- diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -390,6 +390,34 @@ pybind11::object descriptor; }; +/// Wrapper around an MlirDialectRegistry. +/// Upon construction, the Python wrapper takes ownership of the +/// underlying MlirDialectRegistry. +/// NOTE: createFromCapsule() also takes ownership of the capsule's registry, +/// so only use it with registries not already owned by another wrapper.
+class PyDialectRegistry { +public: + PyDialectRegistry() : registry(mlirDialectRegistryCreate()) {} + PyDialectRegistry(MlirDialectRegistry registry) : registry(registry) {} + ~PyDialectRegistry() { + if (!mlirDialectRegistryIsNull(registry)) + mlirDialectRegistryDestroy(registry); + } + PyDialectRegistry(const PyDialectRegistry &) = delete; + PyDialectRegistry(PyDialectRegistry &&other) : registry(other.registry) { + other.registry = {nullptr}; + } + + operator MlirDialectRegistry() const { return registry; } + MlirDialectRegistry get() const { return registry; } + + pybind11::object getCapsule(); + static PyDialectRegistry createFromCapsule(pybind11::object capsule); + +private: + MlirDialectRegistry registry; +}; + /// Wrapper around an MlirLocation. class PyLocation : public BaseContextObject { public: diff --git a/mlir/lib/Bindings/Python/RegisterEverything.cpp b/mlir/lib/Bindings/Python/RegisterEverything.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/RegisterEverything.cpp @@ -0,0 +1,26 @@ +//===- RegisterEverything.cpp - API to register all dialects/passes -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/RegisterEverything.h" +#include "mlir-c/Conversion.h" +#include "mlir-c/Transforms.h" + +#include "mlir/Bindings/Python/PybindAdaptors.h" + +PYBIND11_MODULE(_mlirRegisterEverything, m) { + m.doc() = "MLIR All Upstream Dialects and Passes Registration"; + + m.def("register_dialects", [](MlirDialectRegistry registry) { + mlirRegisterAllDialects(registry); + }); + + // Register all passes on load.
+ mlirRegisterAllPasses(); + mlirRegisterConversionPasses(); + mlirRegisterTransformsPasses(); +} diff --git a/mlir/lib/Bindings/Python/Transforms/Transforms.cpp b/mlir/lib/Bindings/Python/Transforms/Transforms.cpp deleted file mode 100644 --- a/mlir/lib/Bindings/Python/Transforms/Transforms.cpp +++ /dev/null @@ -1,22 +0,0 @@ -//===- Transforms.cpp - Pybind module for the Transforms library ----------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir-c/Transforms.h" - -#include <pybind11/pybind11.h> - -// ----------------------------------------------------------------------------- -// Module initialization. -// ----------------------------------------------------------------------------- - -PYBIND11_MODULE(_mlirTransforms, m) { - m.doc() = "MLIR Transforms library"; - - // Register all the passes in the Transforms library on load. - mlirRegisterTransformsPasses(); -} diff --git a/mlir/lib/CAPI/CMakeLists.txt b/mlir/lib/CAPI/CMakeLists.txt --- a/mlir/lib/CAPI/CMakeLists.txt +++ b/mlir/lib/CAPI/CMakeLists.txt @@ -12,7 +12,7 @@ add_subdirectory(Conversion) add_subdirectory(Interfaces) add_subdirectory(IR) -add_subdirectory(Registration) +add_subdirectory(RegisterEverything) add_subdirectory(Transforms) # Only enable the ExecutionEngine if the native target is configured in.
diff --git a/mlir/lib/CAPI/Conversion/CMakeLists.txt b/mlir/lib/CAPI/Conversion/CMakeLists.txt --- a/mlir/lib/CAPI/Conversion/CMakeLists.txt +++ b/mlir/lib/CAPI/Conversion/CMakeLists.txt @@ -2,6 +2,9 @@ add_mlir_upstream_c_api_library(MLIRCAPIConversion Passes.cpp + DEPENDS + MLIRConversionPassIncGen + LINK_LIBS PUBLIC ${conversion_libs} ) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -77,6 +77,10 @@ return unwrap(context)->enableMultithreading(enable); } +void mlirContextLoadAllAvailableDialects(MlirContext context) { + unwrap(context)->loadAllAvailableDialects(); +} + //===----------------------------------------------------------------------===// // Dialect API. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/CAPI/Registration/CMakeLists.txt b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt rename from mlir/lib/CAPI/Registration/CMakeLists.txt rename to mlir/lib/CAPI/RegisterEverything/CMakeLists.txt --- a/mlir/lib/CAPI/Registration/CMakeLists.txt +++ b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt @@ -2,13 +2,15 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS) get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -add_mlir_upstream_c_api_library(MLIRCAPIRegistration - Registration.cpp +add_mlir_upstream_c_api_library(MLIRCAPIRegisterEverything + RegisterEverything.cpp LINK_LIBS PUBLIC - MLIRCAPIIR - MLIRLLVMToLLVMIRTranslation ${dialect_libs} ${translation_libs} ${conversion_libs} + + MLIRCAPIIR + MLIRLLVMToLLVMIRTranslation + MLIRCAPITransforms ) diff --git a/mlir/lib/CAPI/Registration/Registration.cpp b/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp rename from mlir/lib/CAPI/Registration/Registration.cpp rename to mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp --- a/mlir/lib/CAPI/Registration/Registration.cpp
+++ b/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp @@ -1,4 +1,4 @@ -//===- Registration.cpp - C Interface for MLIR Registration ---------------===// +//===- RegisterEverything.cpp - Register all MLIR entities ----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,17 +6,15 @@ // //===----------------------------------------------------------------------===// -#include "mlir-c/Registration.h" +#include "mlir-c/RegisterEverything.h" #include "mlir/CAPI/IR.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" -void mlirRegisterAllDialects(MlirContext context) { - mlir::registerAllDialects(*unwrap(context)); - // TODO: we may not want to eagerly load here. - unwrap(context)->loadAllAvailableDialects(); +void mlirRegisterAllDialects(MlirDialectRegistry registry) { + mlir::registerAllDialects(*unwrap(registry)); } void mlirRegisterAllLLVMTranslations(MlirContext context) { diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -37,15 +37,6 @@ runtime/*.py ) -declare_mlir_python_sources(MLIRPythonSources.Passes - ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" - ADD_TO_PARENT MLIRPythonSources - SOURCES_GLOB - all_passes_registration/*.py - conversions/*.py - transforms/*.py -) - declare_mlir_python_sources(MLIRPythonCAPIHeaderSources ROOT_DIR "${MLIR_SOURCE_DIR}/include" SOURCES_GLOB "mlir-c/*.h" @@ -284,12 +275,31 @@ MLIRCAPIDebug MLIRCAPIIR MLIRCAPIInterfaces - MLIRCAPIRegistration # TODO: See about dis-aggregating # Dialects MLIRCAPIFunc ) +# This extension exposes an API to register all dialects, extensions, and passes +# packaged in upstream MLIR; it is used by the upstream "mlir" Python +# package.
Downstreams will likely want to provide their own and not depend +# on this one, since it links in the world. +# Note that this is not added to any top-level source target for transitive +# inclusion: it must be included explicitly by downstreams if desired. Note that +# this has a very large impact on what gets built/packaged. +declare_mlir_python_extension(MLIRPythonExtension.RegisterEverything + MODULE_NAME _mlirRegisterEverything + ROOT_DIR "${PYTHON_SOURCE_DIR}" + SOURCES + RegisterEverything.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIConversion + MLIRCAPITransforms + MLIRCAPIRegisterEverything +) + declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind MODULE_NAME _mlirDialectsLinalg ADD_TO_PARENT MLIRPythonSources.Dialects.linalg @@ -342,18 +352,6 @@ MLIRCAPISparseTensor ) -declare_mlir_python_extension(MLIRPythonExtension.AllPassesRegistration - MODULE_NAME _mlirAllPassesRegistration - ROOT_DIR "${PYTHON_SOURCE_DIR}" - SOURCES - AllPassesRegistration.cpp - PRIVATE_LINK_LIBS - LLVMSupport - EMBED_CAPI_LINK_LIBS - MLIRCAPIConversion - MLIRCAPITransforms -) - declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses MODULE_NAME _mlirAsyncPasses ADD_TO_PARENT MLIRPythonSources.Dialects.async_dialect @@ -366,18 +364,6 @@ MLIRCAPIAsync ) -declare_mlir_python_extension(MLIRPythonExtension.Conversions - MODULE_NAME _mlirConversions - ADD_TO_PARENT MLIRPythonSources.Passes - ROOT_DIR "${PYTHON_SOURCE_DIR}" - SOURCES - Conversions/Conversions.cpp - PRIVATE_LINK_LIBS - LLVMSupport - EMBED_CAPI_LINK_LIBS - MLIRCAPIConversion -) - # Only enable the ExecutionEngine if the native target is configured in.
if(TARGET ${LLVM_NATIVE_ARCH}) declare_mlir_python_extension(MLIRPythonExtension.ExecutionEngine @@ -429,18 +415,6 @@ MLIRCAPISparseTensor ) -declare_mlir_python_extension(MLIRPythonExtension.Transforms - MODULE_NAME _mlirTransforms - ADD_TO_PARENT MLIRPythonSources.Passes - ROOT_DIR "${PYTHON_SOURCE_DIR}" - SOURCES - Transforms/Transforms.cpp - PRIVATE_LINK_LIBS - LLVMSupport - EMBED_CAPI_LINK_LIBS - MLIRCAPITransforms -) - # TODO: Figure out how to put this in the test tree. # This should not be included in the main Python extension. However, # putting it into MLIRPythonTestSources along with the dialect declaration @@ -505,7 +479,7 @@ RELATIVE_INSTALL_ROOT "../../../.." DECLARED_SOURCES MLIRPythonSources - MLIRPythonExtension.AllPassesRegistration + MLIRPythonExtension.RegisterEverything ${_ADDL_TEST_SOURCES} ) @@ -519,8 +493,8 @@ INSTALL_PREFIX "python_packages/mlir_core/mlir" DECLARED_SOURCES MLIRPythonSources - MLIRPythonExtension.AllPassesRegistration MLIRPythonCAPIHeaderSources + MLIRPythonExtension.RegisterEverything ${_ADDL_TEST_SOURCES} COMMON_CAPI_LINK_LIBS MLIRPythonCAPI diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -9,12 +9,6 @@ _this_dir = os.path.dirname(__file__) -# These submodules have no type stubs and are thus opaque to the type checker. -_mlirConversions: Any -_mlirTransforms: Any -_mlirAllPassesRegistration: Any - - def get_lib_dirs() -> Sequence[str]: """Gets the lib directory for linking to shared libraries. @@ -31,3 +25,80 @@ not be present. """ return [os.path.join(_this_dir, "include")] + + +# Perform Python level site initialization. This involves: +# 1. Attempting to load initializer modules, specific to the distribution. +# 2. Defining the concrete mlir.ir.Context that does site specific +# initialization.
+# +# Aside from just being far more convenient to do this at the Python level, +# it is actually quite hard/impossible to have such __init__ hooks, given +# the pybind memory model (i.e. there is not a Python reference to the object +# in the scope of the base class __init__). +# +# For #1, we: +# a. Probe for modules named '_mlirRegisterEverything' and +# '_site_initialize_{i}', where 'i' is a number starting at zero and +# proceeding so long as a module with the name is found. +# b. If the module has a 'register_dialects' attribute, it will be called +# immediately with a DialectRegistry to populate. +# c. If the module has a 'context_init_hook', it will be added to a list +# of callbacks that are invoked during Context initialization, just before +# dialects are eagerly loaded (and passed the Context under construction). +# +# This facility allows downstreams to customize Context creation to their +# needs. +def _site_initialize(): + import importlib + import itertools + import logging + from ._mlir import ir + registry = ir.DialectRegistry() + post_init_hooks = [] + + def process_initializer_module(module_name): + try: + m = importlib.import_module(f".{module_name}", __name__) + except ModuleNotFoundError as e: + # Only the probed module itself being absent means "not built"; a + # missing dependency inside an existing initializer must surface. + if e.name != f"{__name__}.{module_name}": + raise + return False + + logging.debug("Initializing MLIR with module: %s", module_name) + if hasattr(m, "register_dialects"): + logging.debug("Registering dialects from initializer %r", m) + m.register_dialects(registry) + if hasattr(m, "context_init_hook"): + logging.debug("Adding context init hook from %r", m) + post_init_hooks.append(m.context_init_hook) + return True + + # If _mlirRegisterEverything is built, then include it as an initializer + # module. + process_initializer_module("_mlirRegisterEverything") + + # Load all _site_initialize_{i} modules, where 'i' is a number starting + # at 0.
+ for i in itertools.count(): + module_name = f"_site_initialize_{i}" + if not process_initializer_module(module_name): + break + + class Context(ir._BaseContext): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.append_dialect_registry(registry) + for hook in post_init_hooks: + hook(self) + # TODO: There is some debate about whether we should eagerly load + # all dialects. It is being done here in order to preserve existing + # behavior. See: https://github.com/llvm/llvm-project/issues/56037 + self.load_all_available_dialects() + + ir.Context = Context + + +_site_initialize() diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -479,6 +479,11 @@ def d(self) -> Dialects: ... @property def dialects(self) -> Dialects: ... + def append_dialect_registry(self, registry: "DialectRegistry") -> None: ... + def load_all_available_dialects(self) -> None: ... + +class DialectRegistry: + def __init__(self) -> None: ... # TODO: Auto-generated. Audit and fix. class DenseElementsAttr(Attribute): diff --git a/mlir/python/mlir/all_passes_registration/__init__.py b/mlir/python/mlir/all_passes_registration/__init__.py deleted file mode 100644 --- a/mlir/python/mlir/all_passes_registration/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from .._mlir_libs import _mlirAllPassesRegistration as _cextAllPasses diff --git a/mlir/python/mlir/conversions/__init__.py b/mlir/python/mlir/conversions/__init__.py deleted file mode 100644 --- a/mlir/python/mlir/conversions/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# Expose the corresponding C-Extension module with a well-known name at this -# level. -from .._mlir_libs import _mlirConversions as _cextConversions diff --git a/mlir/python/mlir/transforms/__init__.py b/mlir/python/mlir/transforms/__init__.py deleted file mode 100644 --- a/mlir/python/mlir/transforms/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# Expose the corresponding C-Extension module with a well-known name at this -# level. -from .._mlir_libs import _mlirTransforms as _cextTransforms diff --git a/mlir/test/CAPI/CMakeLists.txt b/mlir/test/CAPI/CMakeLists.txt --- a/mlir/test/CAPI/CMakeLists.txt +++ b/mlir/test/CAPI/CMakeLists.txt @@ -26,8 +26,8 @@ LINK_LIBS PRIVATE MLIRCAPIConversion MLIRCAPIExecutionEngine - MLIRCAPIRegistration - ) + MLIRCAPIRegisterEverything +) endif() _add_capi_test_executable(mlir-capi-ir-test @@ -35,7 +35,7 @@ LINK_LIBS PRIVATE MLIRCAPIIR MLIRCAPIFunc - MLIRCAPIRegistration + MLIRCAPIRegisterEverything ) _add_capi_test_executable(mlir-capi-llvm-test @@ -43,7 +43,7 @@ LINK_LIBS PRIVATE MLIRCAPIIR MLIRCAPILLVM - MLIRCAPIRegistration + MLIRCAPIRegisterEverything ) _add_capi_test_executable(mlir-capi-pass-test @@ -51,7 +51,7 @@ LINK_LIBS PRIVATE MLIRCAPIFunc MLIRCAPIIR - MLIRCAPIRegistration + MLIRCAPIRegisterEverything MLIRCAPITransforms ) @@ -59,7 +59,7 @@ sparse_tensor.c LINK_LIBS PRIVATE MLIRCAPIIR - MLIRCAPIRegistration + MLIRCAPIRegisterEverything MLIRCAPISparseTensor ) @@ -67,7 +67,7 @@ quant.c LINK_LIBS PRIVATE MLIRCAPIIR - MLIRCAPIRegistration + MLIRCAPIRegisterEverything MLIRCAPIQuant ) @@ -75,6 +75,6 @@ pdl.c LINK_LIBS PRIVATE MLIRCAPIIR - MLIRCAPIRegistration + MLIRCAPIRegisterEverything MLIRCAPIPDL ) diff 
--git a/mlir/test/CAPI/execution_engine.c b/mlir/test/CAPI/execution_engine.c --- a/mlir/test/CAPI/execution_engine.c +++ b/mlir/test/CAPI/execution_engine.c @@ -15,7 +15,7 @@ #include "mlir-c/Conversion.h" #include "mlir-c/ExecutionEngine.h" #include "mlir-c/IR.h" -#include "mlir-c/Registration.h" +#include "mlir-c/RegisterEverything.h" #include #include @@ -23,6 +23,13 @@ #include #include +static void registerAllUpstreamDialects(MlirContext ctx) { + MlirDialectRegistry registry = mlirDialectRegistryCreate(); + mlirRegisterAllDialects(registry); + mlirContextAppendDialectRegistry(ctx, registry); + mlirDialectRegistryDestroy(registry); +} + void lowerModuleToLLVM(MlirContext ctx, MlirModule module) { MlirPassManager pm = mlirPassManagerCreate(ctx); MlirOpPassManager opm = mlirPassManagerGetNestedUnder( @@ -41,7 +48,8 @@ // CHECK-LABEL: Running test 'testSimpleExecution' void testSimpleExecution() { MlirContext ctx = mlirContextCreate(); - mlirRegisterAllDialects(ctx); + registerAllUpstreamDialects(ctx); + MlirModule module = mlirModuleCreateParse( ctx, mlirStringRefCreateFromCString( // clang-format off diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -18,7 +18,7 @@ #include "mlir-c/Diagnostics.h" #include "mlir-c/Dialect/Func.h" #include "mlir-c/IntegerSet.h" -#include "mlir-c/Registration.h" +#include "mlir-c/RegisterEverything.h" #include "mlir-c/Support.h" #include @@ -28,6 +28,13 @@ #include #include +static void registerAllUpstreamDialects(MlirContext ctx) { + MlirDialectRegistry registry = mlirDialectRegistryCreate(); + mlirRegisterAllDialects(registry); + mlirContextAppendDialectRegistry(ctx, registry); + mlirDialectRegistryDestroy(registry); +} + void populateLoopBody(MlirContext ctx, MlirBlock loopBody, MlirLocation location, MlirBlock funcBody) { MlirValue iv = mlirBlockGetArgument(loopBody, 0); @@ -1646,7 +1653,9 @@ // CHECK-LABEL: @testOperands MlirContext ctx = mlirContextCreate(); - 
mlirRegisterAllDialects(ctx); + registerAllUpstreamDialects(ctx); + + mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("arith")); mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("test")); MlirLocation loc = mlirLocationUnknownGet(ctx); MlirType indexType = mlirIndexTypeGet(ctx); @@ -1714,7 +1723,8 @@ // CHECK-LABEL: @testClone MlirContext ctx = mlirContextCreate(); - mlirRegisterAllDialects(ctx); + registerAllUpstreamDialects(ctx); + mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("func")); MlirLocation loc = mlirLocationUnknownGet(ctx); MlirType indexType = mlirIndexTypeGet(ctx); @@ -2036,7 +2046,12 @@ int main() { MlirContext ctx = mlirContextCreate(); - mlirRegisterAllDialects(ctx); + registerAllUpstreamDialects(ctx); + mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("func")); + mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("memref")); + mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("shape")); + mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("scf")); + if (constructAndTraverseIr(ctx)) return 1; buildWithInsertionsAndPrint(ctx); diff --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c --- a/mlir/test/CAPI/pass.c +++ b/mlir/test/CAPI/pass.c @@ -13,7 +13,7 @@ #include "mlir-c/Pass.h" #include "mlir-c/Dialect/Func.h" #include "mlir-c/IR.h" -#include "mlir-c/Registration.h" +#include "mlir-c/RegisterEverything.h" #include "mlir-c/Transforms.h" #include @@ -22,9 +22,16 @@ #include #include +static void registerAllUpstreamDialects(MlirContext ctx) { + MlirDialectRegistry registry = mlirDialectRegistryCreate(); + mlirRegisterAllDialects(registry); + mlirContextAppendDialectRegistry(ctx, registry); + mlirDialectRegistryDestroy(registry); +} + void testRunPassOnModule() { MlirContext ctx = mlirContextCreate(); - mlirRegisterAllDialects(ctx); + registerAllUpstreamDialects(ctx); MlirModule module = mlirModuleCreateParse( ctx, @@ -62,7 +69,7 @@ void 
testRunPassOnNestedModule() { MlirContext ctx = mlirContextCreate(); - mlirRegisterAllDialects(ctx); + registerAllUpstreamDialects(ctx); MlirModule module = mlirModuleCreateParse( ctx, @@ -264,7 +271,7 @@ void testExternalPass() { MlirContext ctx = mlirContextCreate(); - mlirRegisterAllDialects(ctx); + registerAllUpstreamDialects(ctx); MlirModule module = mlirModuleCreateParse( ctx, diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c --- a/mlir/test/CAPI/sparse_tensor.c +++ b/mlir/test/CAPI/sparse_tensor.c @@ -11,7 +11,7 @@ #include "mlir-c/Dialect/SparseTensor.h" #include "mlir-c/IR.h" -#include "mlir-c/Registration.h" +#include "mlir-c/RegisterEverything.h" #include #include diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py --- a/mlir/test/python/execution_engine.py +++ b/mlir/test/python/execution_engine.py @@ -62,7 +62,6 @@ def lowerToLLVM(module): - import mlir.conversions pm = PassManager.parse( "convert-complex-to-llvm,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts") pm.run(module) diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py --- a/mlir/test/python/integration/dialects/linalg/opsrun.py +++ b/mlir/test/python/integration/dialects/linalg/opsrun.py @@ -186,10 +186,6 @@ def transform(module, boilerplate): - import mlir.conversions - import mlir.all_passes_registration - import mlir.transforms - # TODO: Allow cloning functions from one module to another. # Atm we have to resort to string concatenation. 
ops = module.operation.regions[0].blocks[0].operations diff --git a/mlir/test/python/lib/CMakeLists.txt b/mlir/test/python/lib/CMakeLists.txt --- a/mlir/test/python/lib/CMakeLists.txt +++ b/mlir/test/python/lib/CMakeLists.txt @@ -27,7 +27,6 @@ LINK_LIBS PUBLIC MLIRCAPIInterfaces MLIRCAPIIR - MLIRCAPIRegistration MLIRPythonTestDialect ) diff --git a/mlir/test/python/lib/PythonTestCAPI.h b/mlir/test/python/lib/PythonTestCAPI.h --- a/mlir/test/python/lib/PythonTestCAPI.h +++ b/mlir/test/python/lib/PythonTestCAPI.h @@ -9,7 +9,7 @@ #ifndef MLIR_TEST_PYTHON_LIB_PYTHONTESTCAPI_H #define MLIR_TEST_PYTHON_LIB_PYTHONTESTCAPI_H -#include "mlir-c/Registration.h" +#include "mlir-c/IR.h" #ifdef __cplusplus extern "C" { diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -33,20 +33,18 @@ # CHECK-LABEL: TEST: testParseSuccess def testParseSuccess(): with Context(): - # A first import is expected to fail because the pass isn't registered - # until we import mlir.transforms + # An unregistered pass should not parse. try: - pm = PassManager.parse("builtin.module(func.func(print-op-stats{json=false}))") + pm = PassManager.parse("builtin.module(func.func(not-existing-pass{json=false}))") # TODO: this error should be propagate to Python but the C API does not help right now. - # CHECK: error: 'print-op-stats' does not refer to a registered pass or pass pipeline + # CHECK: error: 'not-existing-pass' does not refer to a registered pass or pass pipeline except ValueError as e: - # CHECK: ValueError exception: invalid pass pipeline 'builtin.module(func.func(print-op-stats{json=false}))'. + # CHECK: ValueError exception: invalid pass pipeline 'builtin.module(func.func(not-existing-pass{json=false}))'. log("ValueError exception:", e) else: log("Exception not produced") - # This will register the pass and round-trip should be possible now. 
- import mlir.transforms + # A registered pass should parse successfully. pm = PassManager.parse("builtin.module(func.func(print-op-stats{json=false}))") # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false})) log("Roundtrip: ", pm) @@ -71,7 +69,6 @@ def testInvalidNesting(): with Context(): try: - import mlir.all_passes_registration pm = PassManager.parse("func.func(normalize-memrefs)") except ValueError as e: # CHECK: Can't add pass 'NormalizeMemRefs' restricted to 'builtin.module' on a PassManager intended to run on 'func.func', did you intend to nest?