diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -453,6 +453,61 @@ } }; +/// Creates a custom subclass of mlir.ir.Value, implementing a casting +/// constructor and type checking methods. +class mlir_value_subclass : public pure_subclass { +public: + using IsAFunctionTy = bool (*)(MlirValue); + + /// Subclasses by looking up the mlir.ir.Value super-class dynamically. + mlir_value_subclass(py::handle scope, const char *valueClassName, + IsAFunctionTy isaFunction) + : mlir_value_subclass( + scope, valueClassName, isaFunction, + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Value")) {} + + /// Subclasses with a provided mlir.ir.Value super-class. This must + /// be used if the subclass is being defined in the same extension module + /// as the mlir.ir class (otherwise, it will trigger a recursive + /// initialization). + mlir_value_subclass(py::handle scope, const char *valueClassName, + IsAFunctionTy isaFunction, const py::object &superCls) + : pure_subclass(scope, valueClassName, superCls) { + // Casting constructor. Note that it is hard, if not impossible, to properly + // call chain to parent `__init__` in pybind11 due to its special handling + // for init functions that don't have a fully constructed self-reference, + // which makes it impossible to forward it to `__init__` of a superclass. + // Instead, provide a custom `__new__` and call that of a superclass, which + // eventually calls `__init__` of the superclass. Since value subclasses + // have no additional members, we can just return the instance thus created + // without amending it. + std::string captureValueName( + valueClassName); // Copied into a std::string in case valueClassName is not static. 
+ py::cpp_function newCf( + [superCls, isaFunction, captureValueName](py::object cls, + py::object otherValue) { + MlirValue rawValue = py::cast(otherValue); + if (!isaFunction(rawValue)) { + auto origRepr = py::repr(otherValue).cast(); + throw std::invalid_argument((llvm::Twine("Cannot cast value to ") + + captureValueName + " (from " + + origRepr + ")") + .str()); + } + py::object self = superCls.attr("__new__")(cls, otherValue); + return self; + }, + py::name("__new__"), py::arg("cls"), py::arg("cast_from_value")); + thisClass.attr("__new__") = newCf; + + // Static 'isinstance' method, delegating to isaFunction. + def_staticmethod( + "isinstance", + [isaFunction](MlirValue other) { return isaFunction(other); }, + py::arg("other_value")); + } +}; + } // namespace adaptors } // namespace python } // namespace mlir diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/include/mlir/Bindings/Python/PybindUtils.h rename from mlir/lib/Bindings/Python/PybindUtils.h rename to mlir/include/mlir/Bindings/Python/PybindUtils.h diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -13,7 +13,7 @@ #include #include -#include "PybindUtils.h" +#include "mlir/Bindings/Python/PybindUtils.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSet.h" diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -10,7 +10,7 @@ #include "IRModule.h" -#include "PybindUtils.h" +#include "mlir/Bindings/Python/PybindUtils.h" #include "mlir-c/AffineMap.h" #include "mlir-c/Bindings/Python/Interop.h" diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -11,7 +11,7 @@ #include "IRModule.h" -#include "PybindUtils.h" +#include 
"mlir/Bindings/Python/PybindUtils.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -9,7 +9,7 @@ #include "IRModule.h" #include "Globals.h" -#include "PybindUtils.h" +#include "mlir/Bindings/Python/PybindUtils.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/BuiltinAttributes.h" @@ -3260,6 +3260,7 @@ // Mapping of Value. //---------------------------------------------------------------------------- py::class_(m, "Value", py::module_local()) + .def(py::init(), py::keep_alive<0, 1>(), py::arg("value")) .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) .def_property_readonly( diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -13,7 +13,7 @@ #include #include -#include "PybindUtils.h" +#include "mlir/Bindings/Python/PybindUtils.h" #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -8,7 +8,7 @@ #include "IRModule.h" #include "Globals.h" -#include "PybindUtils.h" +#include "mlir/Bindings/Python/PybindUtils.h" #include #include diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -8,7 +8,7 @@ #include "IRModule.h" -#include "PybindUtils.h" +#include "mlir/Bindings/Python/PybindUtils.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" diff --git a/mlir/lib/Bindings/Python/MainModule.cpp 
b/mlir/lib/Bindings/Python/MainModule.cpp --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -8,7 +8,7 @@ #include -#include "PybindUtils.h" +#include "mlir/Bindings/Python/PybindUtils.h" #include "Globals.h" #include "IRModule.h" diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h --- a/mlir/lib/Bindings/Python/Pass.h +++ b/mlir/lib/Bindings/Python/Pass.h @@ -9,7 +9,7 @@ #ifndef MLIR_BINDINGS_PYTHON_PASS_H #define MLIR_BINDINGS_PYTHON_PASS_H -#include "PybindUtils.h" +#include "mlir/Bindings/Python/PybindUtils.h" namespace mlir { namespace python { diff --git a/mlir/lib/Bindings/Python/PybindUtils.cpp b/mlir/lib/Bindings/Python/PybindUtils.cpp --- a/mlir/lib/Bindings/Python/PybindUtils.cpp +++ b/mlir/lib/Bindings/Python/PybindUtils.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "PybindUtils.h" +#include "mlir/Bindings/Python/PybindUtils.h" pybind11::error_already_set mlir::python::SetPyError(PyObject *excClass, const llvm::Twine &message) { diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py --- a/mlir/python/mlir/dialects/python_test.py +++ b/mlir/python/mlir/dialects/python_test.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._python_test_ops_gen import * -from .._mlir_libs._mlirPythonTest import TestAttr, TestType +from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue def register_python_test_dialect(context, load=True): from .._mlir_libs import _mlirPythonTest diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -2,6 +2,8 @@ from mlir.ir import * import mlir.dialects.python_test as test +import mlir.dialects.arith as arith +import mlir.dialects.tensor as tensor def run(f): 
print("\nTEST:", f.__name__) @@ -302,3 +304,42 @@ pass else: raise + + +@run +# CHECK-LABEL: TEST: testTensorValue +def testTensorValue(): + with Context() as ctx, Location.unknown(): + test.register_python_test_dialect(ctx) + + i8 = IntegerType.get_signless(8) + + class Tensor(test.TestTensorValue): + def __getitem__(self, dims): + dims = list(dims) + for i, d in enumerate(dims): + if isinstance(d, int): + dims[i] = arith.ConstantOp.create_index(d) + + return tensor.ExtractOp(self, dims).result + + module = Module.create() + with InsertionPoint(module.body): + t = tensor.EmptyOp([10, 10], i8).result + + # CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>) + print(Value(t)) + + tt = Tensor(t) + # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>) + print(tt) + + el = tt[0, 0] + + # CHECK: module { + # CHECK: %[[T0:.*]] = tensor.empty() : tensor<10x10xi8> + # CHECK: %[[C0:.*]] = arith.constant 0 : index + # CHECK: %[[C00:.*]] = arith.constant 0 : index + # CHECK: %[[RES:.*]] = tensor.extract %[[T0]][%[[C0]], %[[C00]]] : tensor<10x10xi8> + # CHECK: } + print(module) \ No newline at end of file diff --git a/mlir/test/python/lib/PythonTestCAPI.h b/mlir/test/python/lib/PythonTestCAPI.h --- a/mlir/test/python/lib/PythonTestCAPI.h +++ b/mlir/test/python/lib/PythonTestCAPI.h @@ -27,6 +27,8 @@ MLIR_CAPI_EXPORTED MlirType mlirPythonTestTestTypeGet(MlirContext context); +MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value); + #ifdef __cplusplus } #endif diff --git a/mlir/test/python/lib/PythonTestCAPI.cpp b/mlir/test/python/lib/PythonTestCAPI.cpp --- a/mlir/test/python/lib/PythonTestCAPI.cpp +++ b/mlir/test/python/lib/PythonTestCAPI.cpp @@ -8,6 +8,7 @@ #include "PythonTestCAPI.h" #include "PythonTestDialect.h" +#include "mlir-c/BuiltinTypes.h" #include "mlir/CAPI/Registration.h" #include "mlir/CAPI/Wrap.h" @@ -29,3 +30,7 @@ MlirType mlirPythonTestTestTypeGet(MlirContext context) { return 
wrap(python_test::TestTypeType::get(unwrap(context))); } + +bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value) { + return mlirTypeIsATensor(wrap(unwrap(value).getType())); +} \ No newline at end of file diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp --- a/mlir/test/python/lib/PythonTestModule.cpp +++ b/mlir/test/python/lib/PythonTestModule.cpp @@ -8,6 +8,7 @@ #include "PythonTestCAPI.h" #include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Bindings/Python/PybindUtils.h" namespace py = pybind11; using namespace mlir::python::adaptors; @@ -40,4 +41,14 @@ return cls(mlirPythonTestTestTypeGet(ctx)); }, py::arg("cls"), py::arg("context") = py::none()); + mlir_value_subclass(m, "TestTensorValue", + mlirTypeIsAPythonTestTestTensorValue) + .def("__str__", [](MlirValue &self) { + mlir::PyPrintAccumulator printAccum; + printAccum.parts.append("Tensor("); + mlirValuePrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }); }