diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h --- a/mlir/include/mlir-c/Bindings/Python/Interop.h +++ b/mlir/include/mlir-c/Bindings/Python/Interop.h @@ -64,6 +64,8 @@ MAKE_MLIR_PYTHON_QUALNAME("ir.Attribute._CAPIPtr") #define MLIR_PYTHON_CAPSULE_CONTEXT \ MAKE_MLIR_PYTHON_QUALNAME("ir.Context._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_DIALECT_REGISTRY \ + MAKE_MLIR_PYTHON_QUALNAME("ir.DialectRegistry._CAPIPtr") #define MLIR_PYTHON_CAPSULE_EXECUTION_ENGINE \ MAKE_MLIR_PYTHON_QUALNAME("execution_engine.ExecutionEngine._CAPIPtr") #define MLIR_PYTHON_CAPSULE_INTEGER_SET \ @@ -172,6 +174,28 @@ return context; } +/** Creates a capsule object encapsulating the raw C-API MlirDialectRegistry. + * The returned capsule does not extend or affect ownership of any Python + * objects that reference the registry in any way. + */ +static inline PyObject * +mlirPythonDialectRegistryToCapsule(MlirDialectRegistry registry) { + return PyCapsule_New(registry.ptr, MLIR_PYTHON_CAPSULE_DIALECT_REGISTRY, + NULL); +} + +/** Extracts an MlirDialectRegistry from a capsule as produced from + * mlirPythonDialectRegistryToCapsule. If the capsule is not of the right type, + * then a null registry is returned (as checked via mlirDialectRegistryIsNull). + * In such a case, the Python APIs will have already set an error. */ +static inline MlirDialectRegistry +mlirPythonCapsuleToDialectRegistry(PyObject *capsule) { + void *ptr = + PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_DIALECT_REGISTRY); + MlirDialectRegistry registry = {ptr}; + return registry; +} + /** Creates a capsule object encapsulating the raw C-API MlirLocation. * The returned capsule does not extend or affect ownership of any Python * objects that reference the location in any way.
*/ diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -130,6 +130,11 @@ MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context, bool enable); +/// Eagerly loads all available dialects registered with a context, making +/// them available for use for IR construction. +MLIR_CAPI_EXPORTED void +mlirContextLoadAllAvailableDialects(MlirContext context); + /// Returns whether the given fully-qualified operation (i.e. /// 'dialect.operation') is registered with the context. This will return true /// if the dialect is loaded and the operation is registered within the @@ -157,6 +162,47 @@ /// Returns the namespace of the given dialect. MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect); +//===----------------------------------------------------------------------===// +// DialectHandle API. +// Registration entry-points for each dialect are declared using the common +// MLIR_DECLARE_CAPI_DIALECT_REGISTRATION macro, which takes the dialect +// API name (e.g. "Func", "Tensor", "Linalg") and namespace (e.g. "func", +// "tensor", "linalg"). The following declarations are produced: +// +// /// Gets the dialect handle for a dialect by namespace. +// /// This is intended to facilitate dynamic lookup and registration of +// /// dialects via a plugin facility based on shared library symbol lookup. +// MlirDialectHandle mlirGetDialectHandle__{NAMESPACE}__(); +// +// This is done via a common macro to facilitate future expansion to +// registration schemes.
+//===----------------------------------------------------------------------===// + +struct MlirDialectHandle { + const void *ptr; +}; +typedef struct MlirDialectHandle MlirDialectHandle; + +#define MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Name, Namespace) \ + MLIR_CAPI_EXPORTED MlirDialectHandle mlirGetDialectHandle__##Namespace##__() + +/// Returns the namespace associated with the provided dialect handle. +MLIR_CAPI_EXPORTED +MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle); + +/// Inserts the dialect associated with the provided dialect handle into the +/// provided dialect registry. +MLIR_CAPI_EXPORTED void mlirDialectHandleInsertDialect(MlirDialectHandle, + MlirDialectRegistry); + +/// Registers the dialect associated with the provided dialect handle. +MLIR_CAPI_EXPORTED void mlirDialectHandleRegisterDialect(MlirDialectHandle, + MlirContext); + +/// Loads the dialect associated with the provided dialect handle. +MLIR_CAPI_EXPORTED MlirDialect mlirDialectHandleLoadDialect(MlirDialectHandle, + MlirContext); + //===----------------------------------------------------------------------===// // DialectRegistry API. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir-c/Registration.h b/mlir/include/mlir-c/Registration.h --- a/mlir/include/mlir-c/Registration.h +++ b/mlir/include/mlir-c/Registration.h @@ -6,6 +6,12 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// +// This header contains registration entry points for MLIR upstream dialects +// and passes. Downstream projects typically will not want to use this unless +// they don't care about binary size or build bloat and just wish access +// to the entire set of upstream facilities. For those that do care, they +// should use registration functions specific to their project.
+//===----------------------------------------------------------------------===// #ifndef MLIR_C_REGISTRATION_H #define MLIR_C_REGISTRATION_H @@ -16,51 +22,8 @@ extern "C" { #endif -//===----------------------------------------------------------------------===// -// Dialect registration declarations. -// Registration entry-points for each dialect are declared using the common -// MLIR_DECLARE_DIALECT_REGISTRATION_CAPI macro, which takes the dialect -// API name (i.e. "Func", "Tensor", "Linalg") and namespace (i.e. "func", -// "tensor", "linalg"). The following declarations are produced: -// -// /// Gets the above hook methods in struct form for a dialect by namespace. -// /// This is intended to facilitate dynamic lookup and registration of -// /// dialects via a plugin facility based on shared library symbol lookup. -// const MlirDialectHandle *mlirGetDialectHandle__{NAMESPACE}__(); -// -// This is done via a common macro to facilitate future expansion to -// registration schemes. -//===----------------------------------------------------------------------===// - -struct MlirDialectHandle { - const void *ptr; -}; -typedef struct MlirDialectHandle MlirDialectHandle; - -#define MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Name, Namespace) \ - MLIR_CAPI_EXPORTED MlirDialectHandle mlirGetDialectHandle__##Namespace##__() - -/// Returns the namespace associated with the provided dialect handle. -MLIR_CAPI_EXPORTED -MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle); - -/// Inserts the dialect associated with the provided dialect handle into the -/// provided dialect registry -MLIR_CAPI_EXPORTED void mlirDialectHandleInsertDialect(MlirDialectHandle, - MlirDialectRegistry); - -/// Registers the dialect associated with the provided dialect handle. -MLIR_CAPI_EXPORTED void mlirDialectHandleRegisterDialect(MlirDialectHandle, - MlirContext); - -/// Loads the dialect associated with the provided dialect handle.
-MLIR_CAPI_EXPORTED MlirDialect mlirDialectHandleLoadDialect(MlirDialectHandle, - MlirContext); - -/// Registers all dialects known to core MLIR with the provided Context. -/// This is needed before creating IR for these Dialects. -/// TODO: Remove this function once the real registration API is finished. -MLIR_CAPI_EXPORTED void mlirRegisterAllDialects(MlirContext context); +/// Appends all upstream dialects and extensions to the dialect registry. +MLIR_CAPI_EXPORTED void mlirRegisterAllDialects(MlirDialectRegistry registry); /// Register all translations to LLVM IR for dialects that can support it. MLIR_CAPI_EXPORTED void mlirRegisterAllLLVMTranslations(MlirContext context); diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -124,6 +124,25 @@ } }; +/// Casts object <-> MlirDialectRegistry. +template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirDialectRegistry, _("MlirDialectRegistry")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToDialectRegistry(capsule.ptr()); + return !mlirDialectRegistryIsNull(value); + } + static handle cast(MlirDialectRegistry v, return_value_policy, handle) { + py::object capsule = py::reinterpret_steal( + mlirPythonDialectRegistryToCapsule(v)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("DialectRegistry") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + } +}; + /// Casts object <-> MlirLocation. template <> struct type_caster { @@ -252,7 +271,7 @@ } template - pure_subclass &def(const char *name, Func &&f, const Extra &... 
extra) { + pure_subclass &def(const char *name, Func &&f, const Extra &...extra) { py::cpp_function cf( std::forward(f), py::name(name), py::is_method(thisClass), py::sibling(py::getattr(thisClass, name, py::none())), extra...); @@ -262,7 +281,7 @@ template pure_subclass &def_property_readonly(const char *name, Func &&f, - const Extra &... extra) { + const Extra &...extra) { py::cpp_function cf( std::forward(f), py::name(name), py::is_method(thisClass), py::sibling(py::getattr(thisClass, name, py::none())), extra...); @@ -274,7 +293,7 @@ template pure_subclass &def_staticmethod(const char *name, Func &&f, - const Extra &... extra) { + const Extra &...extra) { static_assert(!std::is_member_function_pointer::value, "def_staticmethod(...) called with a non-static member " "function pointer"); @@ -287,7 +306,7 @@ template pure_subclass &def_classmethod(const char *name, Func &&f, - const Extra &... extra) { + const Extra &...extra) { static_assert(!std::is_member_function_pointer::value, "def_classmethod(...) called with a non-static member " "function pointer"); diff --git a/mlir/lib/Bindings/Python/AllPassesRegistration.cpp b/mlir/lib/Bindings/Python/AllPassesRegistration.cpp deleted file mode 100644 --- a/mlir/lib/Bindings/Python/AllPassesRegistration.cpp +++ /dev/null @@ -1,22 +0,0 @@ -//===- AllPassesRegistration.cpp - Pybind module to register all passes ---===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir-c/Registration.h" - -#include - -// ----------------------------------------------------------------------------- -// Module initialization.
-// ----------------------------------------------------------------------------- - -PYBIND11_MODULE(_mlirAllPassesRegistration, m) { - m.doc() = "MLIR All Passes Convenience Module"; - - // Register all passes on load. - mlirRegisterAllPasses(); -} diff --git a/mlir/lib/Bindings/Python/AllRegistration.cpp b/mlir/lib/Bindings/Python/AllRegistration.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/AllRegistration.cpp @@ -0,0 +1,28 @@ +//===- AllRegistration.cpp - API to register all dialects/passes ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Conversion.h" +#include "mlir-c/Registration.h" +#include "mlir-c/Transforms.h" + +#include "mlir/Bindings/Python/PybindAdaptors.h" + +PYBIND11_MODULE(_mlirAllRegistration, m) { + m.doc() = "MLIR All Upstream Dialects and Passes Registration"; + + m.def( + "register_dialects", + [](MlirDialectRegistry registry) { mlirRegisterAllDialects(registry); }, + pybind11::arg("registry"), + "Appends all upstream dialects and extensions to the registry."); + + // Register all passes on load. + mlirRegisterAllPasses(); + mlirRegisterConversionPasses(); + mlirRegisterTransformsPasses(); +} diff --git a/mlir/lib/Bindings/Python/Conversions/Conversions.cpp b/mlir/lib/Bindings/Python/Conversions/Conversions.cpp deleted file mode 100644 --- a/mlir/lib/Bindings/Python/Conversions/Conversions.cpp +++ /dev/null @@ -1,22 +0,0 @@ -//===- Conversions.cpp - Pybind module for the Conversionss library -------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir-c/Conversion.h" - -#include - -// ----------------------------------------------------------------------------- -// Module initialization. -// ----------------------------------------------------------------------------- - -PYBIND11_MODULE(_mlirConversions, m) { - m.doc() = "MLIR Conversions library"; - - // Register all the passes in the Conversions library on load. - mlirRegisterConversionPasses(); -} diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -16,7 +16,6 @@ #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Debug.h" #include "mlir-c/IR.h" -#include "mlir-c/Registration.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" @@ -474,7 +473,6 @@ PyMlirContext *PyMlirContext::createNewContextForInit() { MlirContext context = mlirContextCreate(); - mlirRegisterAllDialects(context); return new PyMlirContext(context); } @@ -793,7 +791,7 @@ } //------------------------------------------------------------------------------ -// PyDialect, PyDialectDescriptor, PyDialects +// PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry //------------------------------------------------------------------------------ MlirDialect PyDialects::getDialectForKey(const std::string &key, @@ -807,6 +805,21 @@ return dialect; } +py::object PyDialectRegistry::getCapsule() { + return py::reinterpret_steal( + mlirPythonDialectRegistryToCapsule(*this)); +} + +PyDialectRegistry PyDialectRegistry::createFromCapsule(py::object capsule) { + // NOTE: the capsule is non-owning but the returned wrapper destroys the + // registry, so the capsule must not come from another live wrapper. + MlirDialectRegistry rawRegistry = + mlirPythonCapsuleToDialectRegistry(capsule.ptr()); + if (mlirDialectRegistryIsNull(rawRegistry)) + throw py::error_already_set(); + return PyDialectRegistry(rawRegistry); +} +
+ //------------------------------------------------------------------------------ // PyLocation //------------------------------------------------------------------------------ @@ -2276,7 +2289,16 @@ return mlirContextIsRegisteredOperation( self.get(), MlirStringRef{name.data(), name.size()}); }, - py::arg("operation_name")); + py::arg("operation_name")) + .def( + "append_dialect_registry", + [](PyMlirContext &self, PyDialectRegistry &registry) { + mlirContextAppendDialectRegistry(self.get(), registry); + }, + py::arg("registry")) + .def("load_all_available_dialects", [](PyMlirContext &self) { + mlirContextLoadAllAvailableDialects(self.get()); + }); //---------------------------------------------------------------------------- // Mapping of PyDialectDescriptor @@ -2331,6 +2353,15 @@ clazz.attr("__name__") + py::str(")>"); }); + //---------------------------------------------------------------------------- + // Mapping of PyDialectRegistry + //---------------------------------------------------------------------------- + py::class_(m, "DialectRegistry", py::module_local()) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyDialectRegistry::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule) + .def(py::init<>()); + //---------------------------------------------------------------------------- // Mapping of Location //---------------------------------------------------------------------------- diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -390,6 +390,32 @@ pybind11::object descriptor; }; +/// Wrapper around an MlirDialectRegistry. +/// Upon construction, the Python wrapper takes ownership of the +/// underlying MlirDialectRegistry.
+class PyDialectRegistry { +public: + PyDialectRegistry() : registry(mlirDialectRegistryCreate()) {} + PyDialectRegistry(MlirDialectRegistry registry) : registry(registry) {} + ~PyDialectRegistry() { + if (!mlirDialectRegistryIsNull(registry)) + mlirDialectRegistryDestroy(registry); + } + PyDialectRegistry(const PyDialectRegistry &) = delete; + PyDialectRegistry(PyDialectRegistry &&other) noexcept : registry(other.registry) { + other.registry = {nullptr}; + } + + operator MlirDialectRegistry() const { return registry; } + MlirDialectRegistry get() const { return registry; } + + pybind11::object getCapsule(); + static PyDialectRegistry createFromCapsule(pybind11::object capsule); + +private: + MlirDialectRegistry registry; +}; + /// Wrapper around an MlirLocation. class PyLocation : public BaseContextObject { public: diff --git a/mlir/lib/Bindings/Python/Transforms/Transforms.cpp b/mlir/lib/Bindings/Python/Transforms/Transforms.cpp deleted file mode 100644 --- a/mlir/lib/Bindings/Python/Transforms/Transforms.cpp +++ /dev/null @@ -1,22 +0,0 @@ -//===- Transforms.cpp - Pybind module for the Transforms library ----------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir-c/Transforms.h" - -#include - -// ----------------------------------------------------------------------------- -// Module initialization. -// ----------------------------------------------------------------------------- - -PYBIND11_MODULE(_mlirTransforms, m) { - m.doc() = "MLIR Transforms library"; - - // Register all the passes in the Transforms library on load.
- mlirRegisterTransformsPasses(); -} diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -77,6 +77,10 @@ return unwrap(context)->enableMultithreading(enable); } +void mlirContextLoadAllAvailableDialects(MlirContext context) { + unwrap(context)->loadAllAvailableDialects(); +} + //===----------------------------------------------------------------------===// // Dialect API. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/CAPI/Registration/Registration.cpp b/mlir/lib/CAPI/Registration/Registration.cpp --- a/mlir/lib/CAPI/Registration/Registration.cpp +++ b/mlir/lib/CAPI/Registration/Registration.cpp @@ -13,10 +13,8 @@ #include "mlir/InitAllPasses.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" -void mlirRegisterAllDialects(MlirContext context) { - mlir::registerAllDialects(*unwrap(context)); - // TODO: we may not want to eagerly load here.
- unwrap(context)->loadAllAvailableDialects(); +void mlirRegisterAllDialects(MlirDialectRegistry registry) { + mlir::registerAllDialects(*unwrap(registry)); } void mlirRegisterAllLLVMTranslations(MlirContext context) { diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -37,15 +37,6 @@ runtime/*.py ) -declare_mlir_python_sources(MLIRPythonSources.Passes - ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" - ADD_TO_PARENT MLIRPythonSources - SOURCES_GLOB - all_passes_registration/*.py - conversions/*.py - transforms/*.py -) - declare_mlir_python_sources(MLIRPythonCAPIHeaderSources ROOT_DIR "${MLIR_SOURCE_DIR}/include" SOURCES_GLOB "mlir-c/*.h" @@ -284,12 +275,31 @@ MLIRCAPIDebug MLIRCAPIIR MLIRCAPIInterfaces - MLIRCAPIRegistration # TODO: See about dis-aggregating # Dialects MLIRCAPIFunc ) +# This extension exposes an API to register all dialects, extensions, and passes +# packaged in upstream MLIR and it is used for the upstream "mlir" Python +# package. Downstreams will likely want to provide their own and not depend +# on this one, since it links in the world. +declare_mlir_python_extension(MLIRPythonExtension.AllRegistration + MODULE_NAME _mlirAllRegistration + # TODO: Remove this once downstreams are configured to provide their + # own registration hooks. Retaining here for now to preserve NFC status + # with respect to the transition.
See: + # https://github.com/llvm/llvm-project/issues/56037 + ADD_TO_PARENT MLIRPythonExtension.Core + ROOT_DIR "${PYTHON_SOURCE_DIR}" + SOURCES + AllRegistration.cpp + EMBED_CAPI_LINK_LIBS + MLIRCAPIRegistration + MLIRCAPIConversion + MLIRCAPITransforms +) + declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind MODULE_NAME _mlirDialectsLinalg ADD_TO_PARENT MLIRPythonSources.Dialects.linalg @@ -342,18 +352,6 @@ MLIRCAPISparseTensor ) -declare_mlir_python_extension(MLIRPythonExtension.AllPassesRegistration - MODULE_NAME _mlirAllPassesRegistration - ROOT_DIR "${PYTHON_SOURCE_DIR}" - SOURCES - AllPassesRegistration.cpp - PRIVATE_LINK_LIBS - LLVMSupport - EMBED_CAPI_LINK_LIBS - MLIRCAPIConversion - MLIRCAPITransforms -) - declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses MODULE_NAME _mlirAsyncPasses ADD_TO_PARENT MLIRPythonSources.Dialects.async_dialect @@ -366,18 +364,6 @@ MLIRCAPIAsync ) -declare_mlir_python_extension(MLIRPythonExtension.Conversions - MODULE_NAME _mlirConversions - ADD_TO_PARENT MLIRPythonSources.Passes - ROOT_DIR "${PYTHON_SOURCE_DIR}" - SOURCES - Conversions/Conversions.cpp - PRIVATE_LINK_LIBS - LLVMSupport - EMBED_CAPI_LINK_LIBS - MLIRCAPIConversion -) - # Only enable the ExecutionEngine if the native target is configured in. if(TARGET ${LLVM_NATIVE_ARCH}) declare_mlir_python_extension(MLIRPythonExtension.ExecutionEngine @@ -429,18 +415,6 @@ MLIRCAPISparseTensor ) -declare_mlir_python_extension(MLIRPythonExtension.Transforms - MODULE_NAME _mlirTransforms - ADD_TO_PARENT MLIRPythonSources.Passes - ROOT_DIR "${PYTHON_SOURCE_DIR}" - SOURCES - Transforms/Transforms.cpp - PRIVATE_LINK_LIBS - LLVMSupport - EMBED_CAPI_LINK_LIBS - MLIRCAPITransforms -) - # TODO: Figure out how to put this in the test tree. # This should not be included in the main Python extension. However, # putting it into MLIRPythonTestSources along with the dialect declaration @@ -505,7 +479,6 @@
DECLARED_SOURCES MLIRPythonSources - MLIRPythonExtension.AllPassesRegistration ${_ADDL_TEST_SOURCES} ) @@ -519,7 +492,6 @@ INSTALL_PREFIX "python_packages/mlir_core/mlir" DECLARED_SOURCES MLIRPythonSources - MLIRPythonExtension.AllPassesRegistration MLIRPythonCAPIHeaderSources ${_ADDL_TEST_SOURCES} COMMON_CAPI_LINK_LIBS diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -8,7 +8,6 @@ _this_dir = os.path.dirname(__file__) - # These submodules have no type stubs and are thus opaque to the type checker. _mlirConversions: Any _mlirTransforms: Any @@ -31,3 +30,34 @@ not be present. """ return [os.path.join(_this_dir, "include")] + + +# Do some monkey-patching of the Context initializer to match prior behavior +# from when all dialects were registered *and loaded* in a hard-coded way +# as part of the Context() constructor. +# See: https://github.com/llvm/llvm-project/issues/56037 +# Now the registration hook for all dialect registration is in its own +# (optional) extension (_mlirAllDialectsRegistration). In order to match +# existing behavior during transition, we probe for that extension and patch +# the Context class to register/load from it. In the future, there will be +# a dedicated AllDialectsContext or equivalent and this fallback will be +# removed, but we are retaining it for now in order to sequence upgrades +# to the new way as an NFC. +def _monkey_patch_dialect_registration(): + import logging + from ._mlir import ir + from . 
import _mlirAllRegistration + + registry = ir.DialectRegistry() + _mlirAllRegistration.register_dialects(registry) + + class Context(ir.Context): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.append_dialect_registry(registry) + self.load_all_available_dialects() + + ir.Context = Context + + +_monkey_patch_dialect_registration() diff --git a/mlir/python/mlir/all_passes_registration/__init__.py b/mlir/python/mlir/all_passes_registration/__init__.py deleted file mode 100644 --- a/mlir/python/mlir/all_passes_registration/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from .._mlir_libs import _mlirAllPassesRegistration as _cextAllPasses diff --git a/mlir/python/mlir/conversions/__init__.py b/mlir/python/mlir/conversions/__init__.py deleted file mode 100644 --- a/mlir/python/mlir/conversions/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# Expose the corresponding C-Extension module with a well-known name at this -# level. -from .._mlir_libs import _mlirConversions as _cextConversions diff --git a/mlir/python/mlir/transforms/__init__.py b/mlir/python/mlir/transforms/__init__.py deleted file mode 100644 --- a/mlir/python/mlir/transforms/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# Expose the corresponding C-Extension module with a well-known name at this -# level. 
-from .._mlir_libs import _mlirTransforms as _cextTransforms diff --git a/mlir/test/CAPI/execution_engine.c b/mlir/test/CAPI/execution_engine.c --- a/mlir/test/CAPI/execution_engine.c +++ b/mlir/test/CAPI/execution_engine.c @@ -23,6 +23,14 @@ #include #include +// Registers (but does not load) all upstream dialects with `ctx`. +static void registerAllUpstreamDialects(MlirContext ctx) { + MlirDialectRegistry registry = mlirDialectRegistryCreate(); + mlirRegisterAllDialects(registry); + mlirContextAppendDialectRegistry(ctx, registry); + mlirDialectRegistryDestroy(registry); +} + void lowerModuleToLLVM(MlirContext ctx, MlirModule module) { MlirPassManager pm = mlirPassManagerCreate(ctx); MlirOpPassManager opm = mlirPassManagerGetNestedUnder( @@ -41,7 +49,8 @@ // CHECK-LABEL: Running test 'testSimpleExecution' void testSimpleExecution() { MlirContext ctx = mlirContextCreate(); - mlirRegisterAllDialects(ctx); + registerAllUpstreamDialects(ctx); + MlirModule module = mlirModuleCreateParse( ctx, mlirStringRefCreateFromCString( // clang-format off diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -28,6 +28,14 @@ #include #include +// Registers (but does not load) all upstream dialects with `ctx`. +static void registerAllUpstreamDialects(MlirContext ctx) { + MlirDialectRegistry registry = mlirDialectRegistryCreate(); + mlirRegisterAllDialects(registry); + mlirContextAppendDialectRegistry(ctx, registry); + mlirDialectRegistryDestroy(registry); +} + void populateLoopBody(MlirContext ctx, MlirBlock loopBody, MlirLocation location, MlirBlock funcBody) { MlirValue iv = mlirBlockGetArgument(loopBody, 0); @@ -1646,7 +1654,9 @@ // CHECK-LABEL: @testOperands MlirContext ctx = mlirContextCreate(); - mlirRegisterAllDialects(ctx); + registerAllUpstreamDialects(ctx); + + mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("arith")); mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("test")); MlirLocation loc = mlirLocationUnknownGet(ctx); MlirType indexType = mlirIndexTypeGet(ctx); @@ -1714,7 +1724,8 @@ // CHECK-LABEL: @testClone
MlirContext ctx = mlirContextCreate(); - mlirRegisterAllDialects(ctx); + registerAllUpstreamDialects(ctx); + mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("func")); MlirLocation loc = mlirLocationUnknownGet(ctx); MlirType indexType = mlirIndexTypeGet(ctx); @@ -2036,7 +2047,12 @@ int main() { MlirContext ctx = mlirContextCreate(); - mlirRegisterAllDialects(ctx); + registerAllUpstreamDialects(ctx); + mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("func")); + mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("memref")); + mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("shape")); + mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("scf")); + if (constructAndTraverseIr(ctx)) return 1; buildWithInsertionsAndPrint(ctx); diff --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c --- a/mlir/test/CAPI/pass.c +++ b/mlir/test/CAPI/pass.c @@ -22,9 +22,17 @@ #include #include +// Registers (but does not load) all upstream dialects with `ctx`. +static void registerAllUpstreamDialects(MlirContext ctx) { + MlirDialectRegistry registry = mlirDialectRegistryCreate(); + mlirRegisterAllDialects(registry); + mlirContextAppendDialectRegistry(ctx, registry); + mlirDialectRegistryDestroy(registry); +} + void testRunPassOnModule() { MlirContext ctx = mlirContextCreate(); - mlirRegisterAllDialects(ctx); + registerAllUpstreamDialects(ctx); MlirModule module = mlirModuleCreateParse( ctx, @@ -62,7 +70,7 @@ void testRunPassOnNestedModule() { MlirContext ctx = mlirContextCreate(); - mlirRegisterAllDialects(ctx); + registerAllUpstreamDialects(ctx); MlirModule module = mlirModuleCreateParse( ctx, @@ -264,7 +272,7 @@ void testExternalPass() { MlirContext ctx = mlirContextCreate(); - mlirRegisterAllDialects(ctx); + registerAllUpstreamDialects(ctx); MlirModule module = mlirModuleCreateParse( ctx, diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py --- a/mlir/test/python/execution_engine.py +++ b/mlir/test/python/execution_engine.py @@ -62,7 +62,6 @@ def
lowerToLLVM(module): - import mlir.conversions pm = PassManager.parse( "convert-complex-to-llvm,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts") pm.run(module) diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py --- a/mlir/test/python/integration/dialects/linalg/opsrun.py +++ b/mlir/test/python/integration/dialects/linalg/opsrun.py @@ -186,10 +186,6 @@ def transform(module, boilerplate): - import mlir.conversions - import mlir.all_passes_registration - import mlir.transforms - # TODO: Allow cloning functions from one module to another. # Atm we have to resort to string concatenation. ops = module.operation.regions[0].blocks[0].operations diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -33,20 +33,18 @@ # CHECK-LABEL: TEST: testParseSuccess def testParseSuccess(): with Context(): - # A first import is expected to fail because the pass isn't registered - # until we import mlir.transforms + # An unregistered pass should not parse. try: - pm = PassManager.parse("builtin.module(func.func(print-op-stats{json=false}))") + pm = PassManager.parse("builtin.module(func.func(not-existing-pass{json=false}))") # TODO: this error should be propagate to Python but the C API does not help right now. - # CHECK: error: 'print-op-stats' does not refer to a registered pass or pass pipeline + # CHECK: error: 'not-existing-pass' does not refer to a registered pass or pass pipeline except ValueError as e: - # CHECK: ValueError exception: invalid pass pipeline 'builtin.module(func.func(print-op-stats{json=false}))'. + # CHECK: ValueError exception: invalid pass pipeline 'builtin.module(func.func(not-existing-pass{json=false}))'. log("ValueError exception:", e) else: log("Exception not produced") - # This will register the pass and round-trip should be possible now.
- import mlir.transforms + # A registered pass should parse successfully. pm = PassManager.parse("builtin.module(func.func(print-op-stats{json=false}))") # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false})) log("Roundtrip: ", pm) @@ -71,7 +69,6 @@ def testInvalidNesting(): with Context(): try: - import mlir.all_passes_registration pm = PassManager.parse("func.func(normalize-memrefs)") except ValueError as e: # CHECK: Can't add pass 'NormalizeMemRefs' restricted to 'builtin.module' on a PassManager intended to run on 'func.func', did you intend to nest? diff --git a/mlir/tools/mlir-pdll/mlir-pdll.cpp b/mlir/tools/mlir-pdll/mlir-pdll.cpp --- a/mlir/tools/mlir-pdll/mlir-pdll.cpp +++ b/mlir/tools/mlir-pdll/mlir-pdll.cpp @@ -119,6 +119,11 @@ // options as static variables.. some of which overlap with our options. llvm::cl::ResetCommandLineParser(); + // The tablegen main and CMake rules add this flag so we must accept it. + llvm::cl::opt noWarnOnUnusedTemplateArg( + "no-warn-on-unused-template-args", + llvm::cl::desc("Disable unused template argument warnings.")); + llvm::cl::opt inputFilename( llvm::cl::Positional, llvm::cl::desc(""), llvm::cl::init("-"), llvm::cl::value_desc("filename"));