Add enum and bit-enum attributes to the MLIR Python test dialect.

* python_test_ops.td: define TestEnum (I32EnumAttr) and TestBitEnum
  (I32BitEnumAttr), both in the python_test C++ namespace with
  genSpecializedAttr = 0, plus the TestEnumAttr / TestBitEnumAttr
  EnumAttr wrappers (mnemonics "test_enum" / "test_bit_enum").
* test/python/CMakeLists.txt: generate PythonTestEnums.h.inc and
  PythonTestEnums.cpp.inc.
* PythonTestCAPI.h/.cpp: C API isa / get / getValue functions for both
  attributes; enum values cross the C boundary as uint32_t.
* PythonTestModule.cpp: pybind11 enums TestEnum and TestBitEnum (the
  latter with explicit __or__/__ror__), and TestEnumAttr / TestBitEnumAttr
  attribute subclasses exposing a `get` classmethod and a `value` property.
* python/CMakeLists.txt: add ${MLIR_BINARY_DIR}/test/python/lib to the
  include path so PythonTestModule.cpp can include the generated enum
  declarations.
* python_test.py (dialect wrapper and lit test): re-export the new names
  and FileCheck enum comparisons and attribute construction from enum
  values and from raw integers.

diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -483,6 +483,7 @@
     EMBED_CAPI_LINK_LIBS
       MLIRCAPIPythonTestDialect
   )
+  include_directories("${MLIR_BINARY_DIR}/test/python/lib")
 endif()
 
 ################################################################################
diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -3,7 +3,8 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._python_test_ops_gen import *
-from .._mlir_libs._mlirPythonTest import TestAttr, TestType
+from .._mlir_libs._mlirPythonTest import (
+    TestEnum, TestBitEnum, TestAttr, TestEnumAttr, TestBitEnumAttr, TestType)
 
 def register_python_test_dialect(context, load=True):
   from .._mlir_libs import _mlirPythonTest
diff --git a/mlir/test/python/CMakeLists.txt b/mlir/test/python/CMakeLists.txt
--- a/mlir/test/python/CMakeLists.txt
+++ b/mlir/test/python/CMakeLists.txt
@@ -3,6 +3,8 @@
 mlir_tablegen(lib/PythonTestDialect.cpp.inc -gen-dialect-defs)
 mlir_tablegen(lib/PythonTestOps.h.inc -gen-op-decls)
 mlir_tablegen(lib/PythonTestOps.cpp.inc -gen-op-defs)
+mlir_tablegen(lib/PythonTestEnums.h.inc -gen-enum-decls)
+mlir_tablegen(lib/PythonTestEnums.cpp.inc -gen-enum-defs)
 mlir_tablegen(lib/PythonTestAttributes.h.inc -gen-attrdef-decls)
 mlir_tablegen(lib/PythonTestAttributes.cpp.inc -gen-attrdef-defs)
 mlir_tablegen(lib/PythonTestTypes.h.inc -gen-typedef-decls)
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -239,6 +239,42 @@
     # The following cast must not assert.
     b = test.TestAttr(a)
 
+    # Verify pybind-generated enum comparisons.
+    # CHECK: True
+    print(test.TestEnum.OneSquared == test.TestEnum.OneSquared)
+
+    # CHECK: False
+    print(test.TestEnum.TwoSquared == test.TestEnum.ThreeSquared)
+
+    # CHECK: True
+    print(test.TestEnum.ThreeSquared.value == 9)
+
+    # Get a TestEnumAttr using the enum value.
+    c = test.TestEnumAttr.get(test.TestEnum.OneSquared, ctx)
+    # CHECK: #python_test<"test_enum OneSquared">
+    print(c)
+
+    # Get a TestEnumAttr using the associated integer value.
+    d = test.TestEnumAttr.get(4, ctx)
+    # CHECK: #python_test<"test_enum TwoSquared">
+    print(d)
+
+    # Get a TestBitEnumAttr using a single enum value.
+    e = test.TestBitEnumAttr.get(test.TestBitEnum.Bit0, ctx)
+    # CHECK: #python_test<"test_bit_enum Bit0">
+    print(e)
+
+    # Get a TestBitEnumAttr using multiple enum values.
+    f = test.TestBitEnumAttr.get(test.TestBitEnum.Bit0 | test.TestBitEnum.Bit3,
+                                 ctx)
+    # CHECK: #python_test<"test_bit_enum Bit0|Bit3">
+    print(f)
+
+    # Get a TestBitEnumAttr using the associated integer value.
+    g = test.TestBitEnumAttr.get(4 | 16, ctx)
+    # CHECK: #python_test<"test_bit_enum Bit2|Bit4">
+    print(g)
+
     unit = UnitAttr.get()
     try:
       test.TestAttr(unit)
diff --git a/mlir/test/python/lib/PythonTestCAPI.h b/mlir/test/python/lib/PythonTestCAPI.h
--- a/mlir/test/python/lib/PythonTestCAPI.h
+++ b/mlir/test/python/lib/PythonTestCAPI.h
@@ -23,6 +23,24 @@
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirPythonTestTestAttributeGet(MlirContext context);
 
+MLIR_CAPI_EXPORTED bool
+mlirAttributeIsAPythonTestTestEnumAttribute(MlirAttribute attr);
+
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirPythonTestTestEnumAttributeGet(MlirContext context, uint32_t value);
+
+MLIR_CAPI_EXPORTED uint32_t
+mlirPythonTestTestEnumAttributeGetValue(MlirAttribute attr);
+
+MLIR_CAPI_EXPORTED bool
+mlirAttributeIsAPythonTestTestBitEnumAttribute(MlirAttribute attr);
+
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirPythonTestTestBitEnumAttributeGet(MlirContext context, uint32_t value);
+
+MLIR_CAPI_EXPORTED uint32_t
+mlirPythonTestTestBitEnumAttributeGetValue(MlirAttribute attr);
+
 MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestType(MlirType type);
 
 MLIR_CAPI_EXPORTED MlirType mlirPythonTestTestTypeGet(MlirContext context);
diff --git a/mlir/test/python/lib/PythonTestCAPI.cpp b/mlir/test/python/lib/PythonTestCAPI.cpp
--- a/mlir/test/python/lib/PythonTestCAPI.cpp
+++ b/mlir/test/python/lib/PythonTestCAPI.cpp
@@ -22,6 +22,36 @@
   return wrap(python_test::TestAttrAttr::get(unwrap(context)));
 }
 
+bool mlirAttributeIsAPythonTestTestEnumAttribute(MlirAttribute attr) {
+  return unwrap(attr).isa<python_test::TestEnumAttr>();
+}
+
+MlirAttribute mlirPythonTestTestEnumAttributeGet(MlirContext context,
+                                                 uint32_t value) {
+  return wrap(python_test::TestEnumAttr::get(
+      unwrap(context), static_cast<python_test::TestEnum>(value)));
+}
+
+uint32_t mlirPythonTestTestEnumAttributeGetValue(MlirAttribute attr) {
+  return static_cast<uint32_t>(
+      unwrap(attr).cast<python_test::TestEnumAttr>().getValue());
+}
+
+bool mlirAttributeIsAPythonTestTestBitEnumAttribute(MlirAttribute attr) {
+  return unwrap(attr).isa<python_test::TestBitEnumAttr>();
+}
+
+MlirAttribute mlirPythonTestTestBitEnumAttributeGet(MlirContext context,
+                                                    uint32_t value) {
+  return wrap(python_test::TestBitEnumAttr::get(
+      unwrap(context), static_cast<python_test::TestBitEnum>(value)));
+}
+
+uint32_t mlirPythonTestTestBitEnumAttributeGetValue(MlirAttribute attr) {
+  return static_cast<uint32_t>(
+      unwrap(attr).cast<python_test::TestBitEnumAttr>().getValue());
+}
+
 bool mlirTypeIsAPythonTestTestType(MlirType type) {
   return unwrap(type).isa<python_test::TestTypeType>();
 }
diff --git a/mlir/test/python/lib/PythonTestDialect.h b/mlir/test/python/lib/PythonTestDialect.h
--- a/mlir/test/python/lib/PythonTestDialect.h
+++ b/mlir/test/python/lib/PythonTestDialect.h
@@ -15,6 +15,8 @@
 
 #include "PythonTestDialect.h.inc"
 
+#include "PythonTestEnums.h.inc"
+
 #define GET_OP_CLASSES
 #include "PythonTestOps.h.inc"
 
diff --git a/mlir/test/python/lib/PythonTestDialect.cpp b/mlir/test/python/lib/PythonTestDialect.cpp
--- a/mlir/test/python/lib/PythonTestDialect.cpp
+++ b/mlir/test/python/lib/PythonTestDialect.cpp
@@ -9,9 +9,11 @@
 #include "PythonTestDialect.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 #include "PythonTestDialect.cpp.inc"
+#include "PythonTestEnums.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES
 #include "PythonTestAttributes.cpp.inc"
diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -8,11 +8,43 @@
 
 #include "PythonTestCAPI.h"
 #include "mlir/Bindings/Python/PybindAdaptors.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/StringRef.h"
+
+#include "PythonTestEnums.h.inc"
 
 namespace py = pybind11;
 using namespace mlir::python::adaptors;
 
+template <typename T> T enum_or(T a, T b) {
+  using utype = typename std::underlying_type<T>::type;
+  utype v = static_cast<utype>(a) | static_cast<utype>(b);
+  return static_cast<T>(v);
+}
+
 PYBIND11_MODULE(_mlirPythonTest, m) {
+  py::enum_<python_test::TestEnum>(m, "TestEnum", py::module_local())
+      .value("OneSquared", python_test::TestEnum::OneSquared)
+      .value("TwoSquared", python_test::TestEnum::TwoSquared)
+      .value("ThreeSquared", python_test::TestEnum::ThreeSquared);
+  py::enum_<python_test::TestBitEnum>(m, "TestBitEnum", py::module_local(),
+                                      py::arithmetic())
+      .value("Bit0", python_test::TestBitEnum::Bit0)
+      .value("Bit1", python_test::TestBitEnum::Bit1)
+      .value("Bit2", python_test::TestBitEnum::Bit2)
+      .value("Bit3", python_test::TestBitEnum::Bit3)
+      .value("Bit4", python_test::TestBitEnum::Bit4)
+      // Pybind does not wish to define the bitwise or operator for
+      // enum classes (because C++ does not by default) - see:
+      // https://github.com/pybind/pybind11/issues/2221
+      .def("__or__",
+           [](python_test::TestBitEnum e1, python_test::TestBitEnum e2) {
+             return enum_or(e1, e2);
+           })
+      .def("__ror__",
+           [](python_test::TestBitEnum e1, python_test::TestBitEnum e2) {
+             return enum_or(e1, e2);
+           });
   m.def(
       "register_python_test_dialect",
       [](MlirContext context, bool load) {
@@ -33,6 +65,30 @@
         return cls(mlirPythonTestTestAttributeGet(ctx));
       },
       py::arg("cls"), py::arg("context") = py::none());
+  mlir_attribute_subclass(m, "TestEnumAttr",
+                          mlirAttributeIsAPythonTestTestEnumAttribute)
+      .def_classmethod(
+          "get",
+          [](py::object cls, uint32_t value, MlirContext ctx) {
+            return cls(mlirPythonTestTestEnumAttributeGet(ctx, value));
+          },
+          py::arg("cls"), py::arg("value"), py::arg("context") = py::none())
+      .def_property_readonly("value", [](MlirAttribute self) {
+        auto value = mlirPythonTestTestEnumAttributeGetValue(self);
+        return static_cast<python_test::TestEnum>(value);
+      });
+  mlir_attribute_subclass(m, "TestBitEnumAttr",
+                          mlirAttributeIsAPythonTestTestBitEnumAttribute)
+      .def_classmethod(
+          "get",
+          [](py::object cls, uint32_t value, MlirContext ctx) {
+            return cls(mlirPythonTestTestBitEnumAttributeGet(ctx, value));
+          },
+          py::arg("cls"), py::arg("value"), py::arg("context") = py::none())
+      .def_property_readonly("value", [](MlirAttribute self) {
+        auto value = mlirPythonTestTestBitEnumAttributeGetValue(self);
+        return static_cast<python_test::TestBitEnum>(value);
+      });
   mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType)
       .def_classmethod(
           "get",
diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -10,6 +10,7 @@
 #define PYTHON_TEST_OPS
 
 include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/EnumAttr.td"
 include "mlir/Bindings/Python/Attributes.td"
 include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -45,7 +46,30 @@
 // Attribute definitions.
 //===----------------------------------------------------------------------===//
 
+def TestEnum :
+  I32EnumAttr<"TestEnum", "A test enum",
+    [I32EnumAttrCase<"OneSquared", 1>,
+     I32EnumAttrCase<"TwoSquared", 4>,
+     I32EnumAttrCase<"ThreeSquared", 9>, ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "python_test";
+}
+
+def TestBitEnum : I32BitEnumAttr<"TestBitEnum", "A test bit enum",
+  [I32BitEnumAttrCaseBit<"Bit0", 0>,
+   I32BitEnumAttrCaseBit<"Bit1", 1>,
+   I32BitEnumAttrCaseBit<"Bit2", 2>,
+   I32BitEnumAttrCaseBit<"Bit3", 3>,
+   I32BitEnumAttrCaseBit<"Bit4", 4>, ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "python_test";
+}
+
+
 def TestAttr : TestAttr<"TestAttr", "test_attr">;
+def TestEnumAttr : EnumAttr<Python_Test_Dialect, TestEnum, "test_enum">;
+def TestBitEnumAttr : EnumAttr<Python_Test_Dialect, TestBitEnum,
+                               "test_bit_enum">;
 
 //===----------------------------------------------------------------------===//
 // Operation definitions.