[mlir][python] Use `__new__` for casting constructors of attribute/type subclasses

pybind11 makes it hard to chain a custom `__init__` to the superclass
`__init__`, so mlir_attribute_subclass and mlir_type_subclass now install a
`__new__` that performs the isa check and delegates to the superclass
`__new__`; the inherited `__init__` then runs on the returned instance.
The minimum pybind11 version is raised to 2.8. The python_test dialect gains
a custom TestAttr/TestType (TableGen, C API and pybind wrappers) together
with tests for a valid cast, an invalid cast (ValueError) and a call with
the wrong number of arguments (TypeError).

diff --git a/mlir/cmake/modules/MLIRDetectPythonEnv.cmake b/mlir/cmake/modules/MLIRDetectPythonEnv.cmake
--- a/mlir/cmake/modules/MLIRDetectPythonEnv.cmake
+++ b/mlir/cmake/modules/MLIRDetectPythonEnv.cmake
@@ -32,7 +32,7 @@
   message(STATUS "Found python libraries: ${Python3_LIBRARIES}")
   message(STATUS "Found numpy v${Python3_NumPy_VERSION}: ${Python3_NumPy_INCLUDE_DIRS}")
   mlir_detect_pybind11_install()
-  find_package(pybind11 2.6 CONFIG REQUIRED)
+  find_package(pybind11 2.8 CONFIG REQUIRED)
   message(STATUS "Found pybind11 v${pybind11_VERSION}: ${pybind11_INCLUDE_DIR}")
   message(STATUS "Python prefix = '${PYTHON_MODULE_PREFIX}', "
                  "suffix = '${PYTHON_MODULE_SUFFIX}', "
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -314,31 +314,34 @@
   /// as the mlir.ir class (otherwise, it will trigger a recursive
   /// initialization).
   mlir_attribute_subclass(py::handle scope, const char *typeClassName,
-                          IsAFunctionTy isaFunction,
-                          const py::object &superClass)
-      : pure_subclass(scope, typeClassName, superClass) {
-    // Casting constructor. Note that defining an __init__ method is special
-    // and not yet generalized on pure_subclass (it requires a somewhat
-    // different cpp_function and other requirements on chaining to super
-    // __init__ make it more awkward to do generally).
+                          IsAFunctionTy isaFunction, const py::object &superCls)
+      : pure_subclass(scope, typeClassName, superCls) {
+    // Casting constructor. Note that it is hard, if not impossible, to properly
+    // chain to the parent `__init__` in pybind11 due to its special handling
+    // for init functions that don't have a fully constructed self-reference,
+    // which makes it impossible to forward it to `__init__` of a superclass.
+    // Instead, provide a custom `__new__` that calls the superclass `__new__`;
+    // Python then invokes the inherited `__init__` on the returned instance.
+    // Since attribute subclasses have no additional members, we can just
+    // return the instance thus created without amending it.
     std::string captureTypeName(
         typeClassName); // As string in case if typeClassName is not static.
-    py::cpp_function initCf(
-        [superClass, isaFunction, captureTypeName](py::object self,
-                                                   py::object otherType) {
-          MlirAttribute rawAttribute = py::cast<MlirAttribute>(otherType);
+    py::cpp_function newCf(
+        [superCls, isaFunction, captureTypeName](py::object cls,
+                                                 py::object otherAttribute) {
+          MlirAttribute rawAttribute = py::cast<MlirAttribute>(otherAttribute);
           if (!isaFunction(rawAttribute)) {
-            auto origRepr = py::repr(otherType).cast<std::string>();
+            auto origRepr = py::repr(otherAttribute).cast<std::string>();
             throw std::invalid_argument(
                 (llvm::Twine("Cannot cast attribute to ") + captureTypeName +
                  " (from " + origRepr + ")")
                     .str());
           }
-          superClass.attr("__init__")(self, otherType);
+          py::object self = superCls.attr("__new__")(cls, otherAttribute);
+          return self;
         },
-        py::arg("cast_from_type"), py::is_method(py::none()),
-        "Casts the passed type to this specific sub-type.");
-    thisClass.attr("__init__") = initCf;
+        py::name("__new__"), py::arg("cls"), py::arg("cast_from_attr"));
+    thisClass.attr("__new__") = newCf;
 
     // 'isinstance' method.
     def_staticmethod(
@@ -366,17 +369,21 @@
   /// as the mlir.ir class (otherwise, it will trigger a recursive
   /// initialization).
   mlir_type_subclass(py::handle scope, const char *typeClassName,
-                     IsAFunctionTy isaFunction, const py::object &superClass)
-      : pure_subclass(scope, typeClassName, superClass) {
-    // Casting constructor. Note that defining an __init__ method is special
-    // and not yet generalized on pure_subclass (it requires a somewhat
-    // different cpp_function and other requirements on chaining to super
-    // __init__ make it more awkward to do generally).
+                     IsAFunctionTy isaFunction, const py::object &superCls)
+      : pure_subclass(scope, typeClassName, superCls) {
+    // Casting constructor. Note that it is hard, if not impossible, to properly
+    // chain to the parent `__init__` in pybind11 due to its special handling
+    // for init functions that don't have a fully constructed self-reference,
+    // which makes it impossible to forward it to `__init__` of a superclass.
+    // Instead, provide a custom `__new__` that calls the superclass `__new__`;
+    // Python then invokes the inherited `__init__` on the returned instance.
+    // Since type subclasses have no additional members, we can just return
+    // the instance thus created without amending it.
     std::string captureTypeName(
         typeClassName); // As string in case if typeClassName is not static.
-    py::cpp_function initCf(
-        [superClass, isaFunction, captureTypeName](py::object self,
-                                                   py::object otherType) {
+    py::cpp_function newCf(
+        [superCls, isaFunction, captureTypeName](py::object cls,
+                                                 py::object otherType) {
           MlirType rawType = py::cast<MlirType>(otherType);
           if (!isaFunction(rawType)) {
             auto origRepr = py::repr(otherType).cast<std::string>();
@@ -385,11 +392,11 @@
                                          origRepr + ")")
                                             .str());
           }
-          superClass.attr("__init__")(self, otherType);
+          py::object self = superCls.attr("__new__")(cls, otherType);
+          return self;
         },
-        py::arg("cast_from_type"), py::is_method(py::none()),
-        "Casts the passed type to this specific sub-type.");
-    thisClass.attr("__init__") = initCf;
+        py::name("__new__"), py::arg("cls"), py::arg("cast_from_type"));
+    thisClass.attr("__new__") = newCf;
 
     // 'isinstance' method.
     def_staticmethod(
diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -3,7 +3,7 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._python_test_ops_gen import *
-
+from .._mlir_libs._mlirPythonTest import TestAttr, TestType
 
 def register_python_test_dialect(context, load=True):
   from .._mlir_libs import _mlirPythonTest
diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt
--- a/mlir/python/requirements.txt
+++ b/mlir/python/requirements.txt
@@ -1,4 +1,3 @@
 numpy
-# Version 2.7.0 excluded: https://github.com/pybind/pybind11/issues/3136
-pybind11>=2.6.0,!=2.7.0
+pybind11>=2.8.0
 PyYAML
diff --git a/mlir/test/python/CMakeLists.txt b/mlir/test/python/CMakeLists.txt
--- a/mlir/test/python/CMakeLists.txt
+++ b/mlir/test/python/CMakeLists.txt
@@ -3,6 +3,10 @@
 mlir_tablegen(lib/PythonTestDialect.cpp.inc -gen-dialect-defs)
 mlir_tablegen(lib/PythonTestOps.h.inc -gen-op-decls)
 mlir_tablegen(lib/PythonTestOps.cpp.inc -gen-op-defs)
+mlir_tablegen(lib/PythonTestAttributes.h.inc -gen-attrdef-decls)
+mlir_tablegen(lib/PythonTestAttributes.cpp.inc -gen-attrdef-defs)
+mlir_tablegen(lib/PythonTestTypes.h.inc -gen-typedef-decls)
+mlir_tablegen(lib/PythonTestTypes.cpp.inc -gen-typedef-defs)
 add_public_tablegen_target(MLIRPythonTestIncGen)
 
 add_subdirectory(lib)
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -225,3 +225,63 @@
       op2 = test.OptionalOperandOp(op1)
       # CHECK: op2.input is None: False
       print(f"op2.input is None: {op2.input is None}")
+
+
+# CHECK-LABEL: TEST: testCustomAttribute
+@run
+def testCustomAttribute():
+  with Context() as ctx:
+    test.register_python_test_dialect(ctx)
+    a = test.TestAttr.get()
+    # CHECK: #python_test.test_attr
+    print(a)
+
+    # The following cast must not assert.
+    b = test.TestAttr(a)
+
+    unit = UnitAttr.get()
+    try:
+      test.TestAttr(unit)
+    except ValueError as e:
+      assert "Cannot cast attribute to TestAttr" in str(e)
+    else:
+      raise AssertionError("expected ValueError")
+
+    # The following must trigger a TypeError from pybind (therefore, not
+    # checking its message) and must not crash.
+    try:
+      test.TestAttr(42, 56)
+    except TypeError:
+      pass
+    else:
+      raise AssertionError("expected TypeError")
+
+
+# CHECK-LABEL: TEST: testCustomType
+@run
+def testCustomType():
+  with Context() as ctx:
+    test.register_python_test_dialect(ctx)
+    a = test.TestType.get()
+    # CHECK: !python_test.test_type
+    print(a)
+
+    # The following cast must not assert.
+    b = test.TestType(a)
+
+    i8 = IntegerType.get_signless(8)
+    try:
+      test.TestType(i8)
+    except ValueError as e:
+      assert "Cannot cast type to TestType" in str(e)
+    else:
+      raise AssertionError("expected ValueError")
+
+    # The following must trigger a TypeError from pybind (therefore, not
+    # checking its message) and must not crash.
+    try:
+      test.TestType(42, 56)
+    except TypeError:
+      pass
+    else:
+      raise AssertionError("expected TypeError")
diff --git a/mlir/test/python/lib/PythonTestCAPI.h b/mlir/test/python/lib/PythonTestCAPI.h
--- a/mlir/test/python/lib/PythonTestCAPI.h
+++ b/mlir/test/python/lib/PythonTestCAPI.h
@@ -17,6 +17,16 @@
 
 MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(PythonTest, python_test);
 
+MLIR_CAPI_EXPORTED bool
+mlirAttributeIsAPythonTestTestAttribute(MlirAttribute attr);
+
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirPythonTestTestAttributeGet(MlirContext context);
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirPythonTestTestTypeGet(MlirContext context);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/test/python/lib/PythonTestCAPI.cpp b/mlir/test/python/lib/PythonTestCAPI.cpp
--- a/mlir/test/python/lib/PythonTestCAPI.cpp
+++ b/mlir/test/python/lib/PythonTestCAPI.cpp
@@ -9,6 +9,23 @@
 #include "PythonTestCAPI.h"
 #include "PythonTestDialect.h"
 #include "mlir/CAPI/Registration.h"
+#include "mlir/CAPI/Wrap.h"
 
 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(PythonTest, python_test,
                                       python_test::PythonTestDialect)
+
+bool mlirAttributeIsAPythonTestTestAttribute(MlirAttribute attr) {
+  return unwrap(attr).isa<python_test::TestAttrAttr>();
+}
+
+MlirAttribute mlirPythonTestTestAttributeGet(MlirContext context) {
+  return wrap(python_test::TestAttrAttr::get(unwrap(context)));
+}
+
+bool mlirTypeIsAPythonTestTestType(MlirType type) {
+  return unwrap(type).isa<python_test::TestTypeType>();
+}
+
+MlirType mlirPythonTestTestTypeGet(MlirContext context) {
+  return wrap(python_test::TestTypeType::get(unwrap(context)));
+}
diff --git a/mlir/test/python/lib/PythonTestDialect.h b/mlir/test/python/lib/PythonTestDialect.h
--- a/mlir/test/python/lib/PythonTestDialect.h
+++ b/mlir/test/python/lib/PythonTestDialect.h
@@ -18,4 +18,10 @@
 #define GET_OP_CLASSES
 #include "PythonTestOps.h.inc"
 
+#define GET_ATTRDEF_CLASSES
+#include "PythonTestAttributes.h.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "PythonTestTypes.h.inc"
+
 #endif // MLIR_TEST_PYTHON_LIB_PYTHONTESTDIALECT_H
diff --git a/mlir/test/python/lib/PythonTestDialect.cpp b/mlir/test/python/lib/PythonTestDialect.cpp
--- a/mlir/test/python/lib/PythonTestDialect.cpp
+++ b/mlir/test/python/lib/PythonTestDialect.cpp
@@ -9,9 +9,16 @@
 #include "PythonTestDialect.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 #include "PythonTestDialect.cpp.inc"
 
+#define GET_ATTRDEF_CLASSES
+#include "PythonTestAttributes.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "PythonTestTypes.cpp.inc"
+
 #define GET_OP_CLASSES
 #include "PythonTestOps.cpp.inc"
 
@@ -21,5 +28,14 @@
 #define GET_OP_LIST
 #include "PythonTestOps.cpp.inc"
       >();
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "PythonTestAttributes.cpp.inc"
+      >();
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include "PythonTestTypes.cpp.inc"
+      >();
 }
+
 } // namespace python_test
diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Bindings/Python/PybindAdaptors.h"
 
 namespace py = pybind11;
+using namespace mlir::python::adaptors;
 
 PYBIND11_MODULE(_mlirPythonTest, m) {
   m.def(
@@ -23,4 +24,20 @@
         }
       },
       py::arg("context"), py::arg("load") = true);
+
+  mlir_attribute_subclass(m, "TestAttr",
+                          mlirAttributeIsAPythonTestTestAttribute)
+      .def_classmethod(
+          "get",
+          [](py::object cls, MlirContext ctx) {
+            return cls(mlirPythonTestTestAttributeGet(ctx));
+          },
+          py::arg("cls"), py::arg("context") = py::none());
+  mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType)
+      .def_classmethod(
+          "get",
+          [](py::object cls, MlirContext ctx) {
+            return cls(mlirPythonTestTestTypeGet(ctx));
+          },
+          py::arg("cls"), py::arg("context") = py::none());
 }
diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -17,9 +17,36 @@
   let name = "python_test";
   let cppNamespace = "python_test";
 }
+
+class TestType<string name, string typeMnemonic>
+    : TypeDef<Python_Test_Dialect, name> {
+  let mnemonic = typeMnemonic;
+}
+
+class TestAttr<string name, string attrMnemonic>
+    : AttrDef<Python_Test_Dialect, name> {
+  let mnemonic = attrMnemonic;
+}
+
 class TestOp<string mnemonic, list<OpTrait> traits = []>
     : Op<Python_Test_Dialect, mnemonic, traits>;
 
+//===----------------------------------------------------------------------===//
+// Type definitions.
+//===----------------------------------------------------------------------===//
+
+def TestType : TestType<"TestType", "test_type">;
+
+//===----------------------------------------------------------------------===//
+// Attribute definitions.
+//===----------------------------------------------------------------------===//
+
+def TestAttr : TestAttr<"TestAttr", "test_attr">;
+
+//===----------------------------------------------------------------------===//
+// Operation definitions.
+//===----------------------------------------------------------------------===//
+
 def AttributedOp : TestOp<"attributed_op"> {
   let arguments = (ins I32Attr:$mandatory_i32,
                    OptionalAttr<I32Attr>:$optional_i32,