diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -453,6 +453,64 @@
   }
 };
 
+/// Creates a custom subclass of mlir.ir.Value, implementing a casting
+/// constructor and type checking methods.
+class mlir_value_subclass : public pure_subclass {
+public:
+  using IsAFunctionTy = bool (*)(MlirValue);
+
+  /// Subclasses by looking up the super-class dynamically.
+  mlir_value_subclass(py::handle scope, const char *valueClassName,
+                      IsAFunctionTy isaFunction)
+      : mlir_value_subclass(
+            scope, valueClassName, isaFunction,
+            py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Value")) {
+  }
+
+  /// Subclasses with a provided mlir.ir.Value super-class. This must
+  /// be used if the subclass is being defined in the same extension module
+  /// as the mlir.ir class (otherwise, it will trigger a recursive
+  /// initialization).
+  mlir_value_subclass(py::handle scope, const char *valueClassName,
+                      IsAFunctionTy isaFunction, const py::object &superCls)
+      : pure_subclass(scope, valueClassName, superCls) {
+    // Casting constructor. Note that it is hard, if not impossible, to properly
+    // call chain to parent `__init__` in pybind11 due to its special handling
+    // for init functions that don't have a fully constructed self-reference,
+    // which makes it impossible to forward it to `__init__` of a superclass.
+    // Instead, provide a custom `__new__` and call that of a superclass, which
+    // eventually calls `__init__` of the superclass. Since value subclasses
+    // have no additional members, we can just return the instance thus created
+    // without amending it.
+    std::string captureValueName(
+        valueClassName); // As string in case valueClassName is not static.
+    // Values rejected by `isaFunction` raise std::invalid_argument, which
+    // pybind11 translates into a Python ValueError.
+    py::cpp_function newCf(
+        [superCls, isaFunction, captureValueName](py::object cls,
+                                                  py::object otherValue) {
+          MlirValue rawValue = py::cast<MlirValue>(otherValue);
+          if (!isaFunction(rawValue)) {
+            auto origRepr = py::repr(otherValue).cast<std::string>();
+            throw std::invalid_argument((llvm::Twine("Cannot cast value to ") +
+                                         captureValueName + " (from " +
+                                         origRepr + ")")
+                                            .str());
+          }
+          py::object self = superCls.attr("__new__")(cls, otherValue);
+          return self;
+        },
+        py::name("__new__"), py::arg("cls"), py::arg("cast_from_value"));
+    thisClass.attr("__new__") = newCf;
+
+    // 'isinstance' method.
+    def_staticmethod(
+        "isinstance",
+        [isaFunction](MlirValue other) { return isaFunction(other); },
+        py::arg("other_value"));
+  }
+};
+
 } // namespace adaptors
 } // namespace python
 } // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3260,6 +3260,7 @@
   // Mapping of Value.
   //----------------------------------------------------------------------------
   py::class_<PyValue>(m, "Value", py::module_local())
+      .def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"))
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
       .def_property_readonly(
diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -3,7 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._python_test_ops_gen import *
-from .._mlir_libs._mlirPythonTest import TestAttr, TestType
+from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue
 
 def register_python_test_dialect(context, load=True):
   from .._mlir_libs import _mlirPythonTest
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -2,6 +2,7 @@
 
 from mlir.ir import *
 import mlir.dialects.python_test as test
+import mlir.dialects.tensor as tensor
 
 def run(f):
   print("\nTEST:", f.__name__)
@@ -302,3 +303,33 @@
       pass
     else:
       raise
+
+
+@run
+# CHECK-LABEL: TEST: testTensorValue
+def testTensorValue():
+  with Context() as ctx, Location.unknown():
+    test.register_python_test_dialect(ctx)
+
+    i8 = IntegerType.get_signless(8)
+
+    class Tensor(test.TestTensorValue):
+      def __str__(self):
+        return super().__str__().replace("Value", "Tensor")
+
+    module = Module.create()
+    with InsertionPoint(module.body):
+      t = tensor.EmptyOp([10, 10], i8).result
+
+      # CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
+      print(Value(t))
+
+      tt = Tensor(t)
+      # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
+      print(tt)
+
+      # CHECK: False
+      print(tt.is_null())
+
+      # CHECK: True
+      print(test.TestTensorValue.isinstance(t))
diff --git a/mlir/test/python/lib/PythonTestCAPI.h b/mlir/test/python/lib/PythonTestCAPI.h
--- a/mlir/test/python/lib/PythonTestCAPI.h
+++ b/mlir/test/python/lib/PythonTestCAPI.h
@@ -27,6 +27,9 @@
 
 MLIR_CAPI_EXPORTED MlirType mlirPythonTestTestTypeGet(MlirContext context);
 
+/// Returns true if the type of the given value is a tensor type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/test/python/lib/PythonTestCAPI.cpp b/mlir/test/python/lib/PythonTestCAPI.cpp
--- a/mlir/test/python/lib/PythonTestCAPI.cpp
+++ b/mlir/test/python/lib/PythonTestCAPI.cpp
@@ -8,6 +8,7 @@
 
 #include "PythonTestCAPI.h"
 #include "PythonTestDialect.h"
+#include "mlir-c/BuiltinTypes.h"
 #include "mlir/CAPI/Registration.h"
 #include "mlir/CAPI/Wrap.h"
 
@@ -29,3 +30,7 @@
 MlirType mlirPythonTestTestTypeGet(MlirContext context) {
   return wrap(python_test::TestTypeType::get(unwrap(context)));
 }
+
+bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value) {
+  return mlirTypeIsATensor(wrap(unwrap(value).getType()));
+}
diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -40,4 +40,7 @@
             return cls(mlirPythonTestTestTypeGet(ctx));
           },
           py::arg("cls"), py::arg("context") = py::none());
+  mlir_value_subclass(m, "TestTensorValue",
+                      mlirTypeIsAPythonTestTestTensorValue)
+      .def("is_null", [](MlirValue &self) { return mlirValueIsNull(self); });
 }