[mlir][python] Add regex-keyed Value builders (ir.ValueBuilder)

Review notes:
- Reconstructed from a paste that had lost its newlines and every `<...>`
  token. Template arguments are restored from how the declarations are used.
  The pre-image of the IRModule.cpp include hunk (<optional>, <vector>) and
  the indentation of context lines are best guesses; verify before applying.
  Globals.h may also need `#include <map>` for valueBuilderMap.
- registerValueBuilder: name the offending regex in the duplicate-registration
  error. Reject malformed regexes at registration time.
- lookupValueBuilder: drop the negative cache. It stored concrete type strings
  as regex keys; those were then matched as patterns by later lookups, and a
  later registration of the same string threw "already registered".
- Drop the IRCore.cpp include hunk; nothing added there needs a new header.
  ValueBuilder.get now raises a KeyError that names the type.
- ValueMeta/rebuild_with_meta:
  - delegate to the base metaclass __call__ and do a single builder lookup;
  - rebuild classes the way six.add_metaclass does (drop the copied
    __dict__/__weakref__ descriptors, keep __qualname__);
  - rebind zero-argument super() cells by name, instead of assuming
    __init__.__closure__[0].
- ir.py now ends with a newline. value.py: CHECK lines reuse the captured SSA
  names instead of redefining them.

diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -444,6 +444,7 @@
         },
         py::name("__new__"), py::arg("cls"), py::arg("cast_from_type"));
     thisClass.attr("__new__") = newCf;
+    thisClass.attr("__name__") = captureTypeName;
 
     // 'isinstance' method.
     def_staticmethod(
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -64,6 +64,12 @@
   void registerAttributeBuilder(const std::string &attributeKind,
                                 pybind11::function pyFunc);
 
+  /// Adds a user-provided Value builder for every type whose string form
+  /// fully matches `typeRegex` (std::regex, ECMAScript grammar).
+  /// Raises an exception if `typeRegex` is malformed or already registered.
+  void registerValueBuilder(const std::string &typeRegex,
+                            pybind11::function pyFunc);
+
   /// Adds a concrete implementation dialect class.
   /// Raises an exception if the mapping already exists.
   /// This is intended to be called by implementation code.
@@ -80,6 +86,10 @@
   std::optional<pybind11::function>
   lookupAttributeBuilder(const std::string &attributeKind);
 
+  /// Returns the Value builder of the first registered regex (in
+  /// lexicographic order of the regex strings) that fully matches `type`.
+  std::optional<pybind11::function> lookupValueBuilder(const std::string &type);
+
   /// Looks up a registered dialect class by namespace. Note that this may
   /// trigger loading of the defining module and can arbitrarily re-enter.
   std::optional<pybind11::object>
@@ -102,6 +112,10 @@
   /// Map of attribute ODS name to custom builder.
   llvm::StringMap<pybind11::object> attributeBuilderMap;
 
+  /// Map of type regex to custom Value builder. lookupValueBuilder tries the
+  /// regexes in key order, compiling each one it tries on every lookup.
+  std::map<std::string, pybind11::object> valueBuilderMap;
+
   /// Set of dialect namespaces that we have attempted to import implementation
   /// modules for.
   llvm::StringSet<> loadedDialectModulesCache;
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -230,6 +230,35 @@
   }
 };
 
+/// Python-visible static accessor (bound as `ValueBuilder`) over the Value
+/// builder registry held by PyGlobals, which is keyed by type regex.
+struct PyValueBuilderMap {
+  /// Returns whether some registered regex fully matches `str(type)`.
+  static bool dunderContains(const py::object &type) {
+    return PyGlobals::get().lookupValueBuilder(py::str(type)).has_value();
+  }
+  /// Returns the builder matching `str(type)`; raises KeyError otherwise.
+  static py::function dunderGetItemNamed(const py::object &type) {
+    std::string typeStr = py::str(type);
+    auto builder = PyGlobals::get().lookupValueBuilder(typeStr);
+    if (!builder)
+      throw py::key_error(typeStr);
+    return *builder;
+  }
+  /// Registers `func` for every type whose string form matches `typeRegex`.
+  static void dunderSetItemNamed(const py::object &typeRegex,
+                                 py::function func) {
+    PyGlobals::get().registerValueBuilder(py::str(typeRegex), std::move(func));
+  }
+
+  static void bind(py::module &m) {
+    py::class_<PyValueBuilderMap>(m, "ValueBuilder", py::module_local())
+        .def_static("contains", &PyValueBuilderMap::dunderContains)
+        .def_static("get", &PyValueBuilderMap::dunderGetItemNamed)
+        .def_static("insert", &PyValueBuilderMap::dunderSetItemNamed);
+  }
+};
+
 //------------------------------------------------------------------------------
 // Collections.
 //------------------------------------------------------------------------------
@@ -3417,6 +3446,7 @@
 
   // Attribute builder getter.
   PyAttrBuilderMap::bind(m);
+  PyValueBuilderMap::bind(m);
 
   py::register_local_exception_translator([](std::exception_ptr p) {
     // We can't define exceptions with custom fields through pybind, so instead
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -11,6 +11,7 @@
 #include "PybindUtils.h"
 
 #include <optional>
+#include <regex>
 #include <vector>
 
 #include "mlir-c/Bindings/Python/Interop.h"
@@ -72,6 +73,20 @@
   found = std::move(pyFunc);
 }
 
+void PyGlobals::registerValueBuilder(const std::string &typeRegex,
+                                     py::function pyFunc) {
+  // Compile the pattern once up front so a malformed regex surfaces here as
+  // std::regex_error instead of on every later lookupValueBuilder call.
+  static_cast<void>(std::regex(typeRegex));
+  py::object &found = valueBuilderMap[typeRegex];
+  if (found) {
+    throw std::runtime_error((llvm::Twine("Value builder for '") + typeRegex +
+                              "' is already registered")
+                                 .str());
+  }
+  found = std::move(pyFunc);
+}
+
 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
                                     py::object pyClass) {
   py::object &found = dialectClassMap[dialectNamespace];
@@ -110,6 +125,22 @@
   return std::nullopt;
 }
 
+std::optional<py::function>
+PyGlobals::lookupValueBuilder(const std::string &type) {
+  // Patterns are tried in valueBuilderMap's (lexicographic) key order and the
+  // first one that fully matches `type` wins.
+  for (const auto &[typeRegex, builder] : valueBuilderMap) {
+    if (!std::regex_match(type, std::regex(typeRegex)))
+      continue;
+    assert(builder && "py::function is defined");
+    return builder;
+  }
+  // Deliberately no negative cache: `type` is a concrete type string, and
+  // storing it as a key would make it a pattern for later lookups and would
+  // make a later registerValueBuilder with the same string throw.
+  return std::nullopt;
+}
+
 std::optional<py::object>
 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
   loadDialectModule(dialectNamespace);
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -5,6 +5,7 @@
 # Provide a convenient name for sub-packages to resolve the main C-extension
 # with a relative import.
 from .._mlir_libs import _mlir as _cext
+from ..ir import Value, ValueBuilder
 from typing import Sequence as _Sequence, Union as _Union
 
 __all__ = [
@@ -17,6 +18,77 @@
 ]
 
 
+class ValueMeta(type(Value)):
+  """Metaclass that post-processes OpView construction.
+
+  If the newly built op has exactly one result and a builder registered with
+  `ValueBuilder` matches that result's type, the builder is called with the
+  result Value and whatever it returns is handed back instead of the op.
+  """
+
+  def __call__(cls, *args, **kwargs):
+    # Delegate to the base metaclass so construction follows its normal
+    # __new__/__init__ protocol rather than a hand-rolled copy of it.
+    op = super().__call__(*args, **kwargs)
+    if len(op.results) != 1:
+      return op
+    val = get_op_result_or_value(op)
+    # Single lookup; a KeyError raised by the builder itself still propagates.
+    try:
+      builder = ValueBuilder.get(val.type)
+    except KeyError:
+      return op
+    return builder(val)
+
+
+def _rebind_class_cell(func, old_cls, new_cls):
+  """Repoints `func`'s `__class__` cell (zero-arg super()) to `new_cls`.
+
+  Only a cell that currently holds `old_cls` is changed; functions without
+  such a cell (or non-functions) are left alone.
+  """
+  code = getattr(func, "__code__", None)
+  closure = getattr(func, "__closure__", None)
+  if code is None or not closure or "__class__" not in code.co_freevars:
+    return
+  cell = closure[code.co_freevars.index("__class__")]
+  if cell.cell_contents is old_cls:
+    cell.cell_contents = new_cls
+
+
+def rebuild_with_meta(parent_opview_cls, mixin=False):
+  """Recreates `parent_opview_cls` with `ValueMeta` as its metaclass.
+
+  The copy keeps the original name, qualname, bases and namespace. Methods
+  defined in the class body that use zero-argument `super()` close over a
+  `__class__` cell holding the original class. The copy is not a subclass of
+  it, so those cells are rebound to the copy.
+
+  `mixin` is kept for backward compatibility and no longer changes behavior:
+  only functions in the class's own namespace whose `__class__` cell is
+  `parent_opview_cls` are rebound, so inherited methods are never modified.
+  """
+  namespace = dict(parent_opview_cls.__dict__)
+  # As in six.add_metaclass: these descriptors belong to the original class
+  # and would reject instances of the copy (`vars(op)` / `op.__dict__`).
+  namespace.pop("__dict__", None)
+  namespace.pop("__weakref__", None)
+  namespace["__qualname__"] = parent_opview_cls.__qualname__
+  new_cls = ValueMeta(parent_opview_cls.__name__, parent_opview_cls.__bases__,
+                      namespace)
+
+  for member in namespace.values():
+    if isinstance(member, (staticmethod, classmethod)):
+      member = member.__func__
+    if isinstance(member, property):
+      funcs = (member.fget, member.fset, member.fdel)
+    else:
+      funcs = (member,)
+    for func in funcs:
+      _rebind_class_cell(func, parent_opview_cls, new_cls)
+  return new_cls
+
+
 def extend_opview_class(ext_module):
   """Decorator to extend an OpView class from an extension module.
 
@@ -41,7 +113,7 @@
 
   def class_decorator(parent_opview_cls: type):
     if ext_module is None:
-      return parent_opview_cls
+      return rebuild_with_meta(parent_opview_cls)
     mixin_cls = NotImplemented
     # First try to resolve by name.
     try:
@@ -56,7 +128,7 @@
         mixin_cls = select_mixin(parent_opview_cls)
 
     if mixin_cls is NotImplemented or mixin_cls is None:
-      return parent_opview_cls
+      return rebuild_with_meta(parent_opview_cls)
 
     # Have a mixin_cls. Create an appropriate subclass.
     try:
@@ -68,7 +140,7 @@
           f"Could not mixin {mixin_cls} into {parent_opview_cls}") from e
     LocalOpView.__name__ = parent_opview_cls.__name__
     LocalOpView.__qualname__ = parent_opview_cls.__qualname__
-    return LocalOpView
+    return rebuild_with_meta(LocalOpView, mixin=True)
 
   return class_decorator
 
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -94,3 +94,14 @@
 
 except ImportError:
   pass
+
+
+# Convenience decorator for registering user-friendly Value builders: the
+# decorated function is registered for every type matching regex `kind`.
+def register_value_builder(kind):
+
+  def decorator_builder(func):
+    ValueBuilder.insert(kind, func)
+    return func
+
+  return decorator_builder
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -2,7 +2,7 @@
 
 import gc
 from mlir.ir import *
-from mlir.dialects import func
+from mlir.dialects import func, arith, tensor
 
 
 def run(f):
@@ -232,3 +232,53 @@
     value2.owner.detach_from_parent()
     # CHECK: %0
     print(value2.get_name())
+
+
+# CHECK-LABEL: TEST: testValueBuilders
+@run
+def testValueBuilders():
+
+  class ScalarOrTensor:
+    """Wraps an f64 or tensor Value and overloads `+` as arith.addf."""
+
+    def __init__(self, val):
+      self.val = val
+
+    def __add__(self, other):
+      return arith.AddFOp(self.val, other.val)
+
+    def __str__(self):
+      return str(self.val).replace("Value", "ScalarOrTensor")
+
+  @register_value_builder("f64")
+  @register_value_builder("tensor<.*>")
+  def _buildScalarOrTensor(val):
+    return ScalarOrTensor(val)
+
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  with Location.unknown(ctx):
+    f64 = F64Type.get()
+    module = Module.create()
+    with InsertionPoint(module.body):
+      one_f64 = arith.ConstantOp(f64, 1.0)
+      # CHECK: ScalarOrTensor(%[[ONE:.*]] = arith.constant 1.000000e+00 : f64)
+      print(one_f64)
+
+      two_f64 = arith.ConstantOp(f64, 2.0)
+      # CHECK: ScalarOrTensor(%[[TWO:.*]] = arith.constant 2.000000e+00 : f64)
+      print(two_f64)
+
+      three_f64 = one_f64 + two_f64
+      # CHECK: ScalarOrTensor(%{{.*}} = arith.addf %[[ONE]], %[[TWO]] : f64)
+      print(three_f64)
+
+      a_ten_f64 = tensor.EmptyOp([10, 10], f64)
+      # CHECK: ScalarOrTensor(%[[A:.*]] = tensor.empty() : tensor<10x10xf64>)
+      print(a_ten_f64)
+      b_ten_f64 = tensor.EmptyOp([10, 10], f64)
+      # CHECK: ScalarOrTensor(%[[B:.*]] = tensor.empty() : tensor<10x10xf64>)
+      print(b_ten_f64)
+      c_ten_f64 = a_ten_f64 + b_ten_f64
+      # CHECK: ScalarOrTensor(%{{.*}} = arith.addf %[[A]], %[[B]] : tensor<10x10xf64>)
+      print(c_ten_f64)