diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake --- a/mlir/cmake/modules/AddMLIRPython.cmake +++ b/mlir/cmake/modules/AddMLIRPython.cmake @@ -272,6 +272,7 @@ # SOURCES: Same as declare_mlir_python_sources(). # SOURCES_GLOB: Same as declare_mlir_python_sources(). # DEPENDS: Additional dependency targets. +# ENUM_GEN: Generate enum mappings. # # TODO: Right now `TD_FILE` can't be the actual dialect tablegen file, since we # use its path to determine where to place the generated python file. If @@ -279,7 +280,7 @@ # need for the separate "wrapper" .td files function(declare_mlir_dialect_python_bindings) cmake_parse_arguments(ARG - "" + "ENUM_GEN" "ROOT_DIR;ADD_TO_PARENT;TD_FILE;DIALECT_NAME" "SOURCES;SOURCES_GLOB;DEPENDS" ${ARGN}) @@ -306,11 +307,18 @@ ) add_public_tablegen_target(${tblgen_target}) + set(_sources ${dialect_filename}) + if(ARG_ENUM_GEN) + set(enum_filename "${relative_td_directory}/_${ARG_DIALECT_NAME}_enum_gen.py") + mlir_tablegen(${enum_filename} -gen-python-enum-bindings) + list(APPEND _sources ${enum_filename}) + endif() + # Generated. declare_mlir_python_sources("${_dialect_target}.ops_gen" ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" ADD_TO_PARENT "${_dialect_target}" - SOURCES "${dialect_filename}" + SOURCES ${_sources} ) endif() endfunction() @@ -333,7 +341,7 @@ # DEPENDS: Additional dependency targets. function(declare_mlir_dialect_extension_python_bindings) cmake_parse_arguments(ARG - "" + "ENUM_GEN" "ROOT_DIR;ADD_TO_PARENT;TD_FILE;DIALECT_NAME;EXTENSION_NAME" "SOURCES;SOURCES_GLOB;DEPENDS" ${ARGN}) @@ -362,10 +370,17 @@ add_dependencies(${tblgen_target} ${ARG_DEPENDS}) endif() + set(_sources ${output_filename}) + if(ARG_ENUM_GEN) + set(enum_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_enum_gen.py") + mlir_tablegen(${enum_filename} -gen-python-enum-bindings) + list(APPEND _sources ${enum_filename}) + endif() + declare_mlir_python_sources("${_extension_target}.ops_gen" ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" ADD_TO_PARENT "${_extension_target}" - SOURCES "${output_filename}" + SOURCES ${_sources} ) endif() endfunction() diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -242,7 +242,7 @@ static py::function dundeGetItemNamed(const std::string &attributeKind) { auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind); if (!builder) - throw py::key_error(); + throw py::key_error(attributeKind); return *builder; } static void dundeSetItemNamed(const std::string &attributeKind, diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -52,7 +52,8 @@ TD_FILE dialects/AMDGPUOps.td SOURCES dialects/amdgpu.py - DIALECT_NAME amdgpu) + DIALECT_NAME amdgpu + ENUM_GEN) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -109,7 +110,8 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/GPUOps.td SOURCES_GLOB dialects/gpu/*.py - DIALECT_NAME gpu) + DIALECT_NAME gpu + ENUM_GEN) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -120,7 +122,8 @@ SOURCES_GLOB dialects/linalg/*.py DIALECT_NAME linalg - DEPENDS LinalgOdsGen) + DEPENDS LinalgOdsGen + ENUM_GEN) declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -140,16 +143,8 @@ dialects/_transform_ops_ext.py dialects/transform/__init__.py _mlir_libs/_mlir/dialects/transform/__init__.pyi - DIALECT_NAME transform) - -set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/TransformOps.td") -mlir_tablegen("dialects/_transform_enum_gen.py" -gen-python-enum-bindings) -add_public_tablegen_target(MLIRTransformDialectPyEnumGen) -declare_mlir_python_sources( - MLIRPythonSources.Dialects.transform.enum_gen - ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" - ADD_TO_PARENT MLIRPythonSources.Dialects.transform - SOURCES "dialects/_transform_enum_gen.py") + DIALECT_NAME transform + ENUM_GEN) declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -159,16 +154,8 @@ dialects/_bufferization_transform_ops_ext.py dialects/transform/bufferization.py DIALECT_NAME transform - EXTENSION_NAME bufferization_transform) - -set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/BufferizationTransformOps.td") -mlir_tablegen("dialects/_bufferization_transform_enum_gen.py" -gen-python-enum-bindings) -add_public_tablegen_target(MLIRBufferizationTransformDialectPyEnumGen) -declare_mlir_python_sources( - MLIRPythonSources.Dialects.bufferization_transform.enum_gen - ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" - ADD_TO_PARENT MLIRPythonSources.Dialects.bufferization_transform - SOURCES "dialects/_bufferization_transform_enum_gen.py") + EXTENSION_NAME bufferization_transform + ENUM_GEN) declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -208,7 +195,8 @@ dialects/_structured_transform_ops_ext.py dialects/transform/structured.py DIALECT_NAME transform - EXTENSION_NAME structured_transform) + EXTENSION_NAME structured_transform + ENUM_GEN) declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -227,23 +215,16 @@ SOURCES dialects/transform/vector.py DIALECT_NAME transform - EXTENSION_NAME vector_transform) - -set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/VectorTransformOps.td") -mlir_tablegen("dialects/_vector_transform_enum_gen.py" -gen-python-enum-bindings) -add_public_tablegen_target(MLIRVectorTransformPyEnumGen) -declare_mlir_python_sources( - MLIRPythonSources.Dialects.vector_transform.enum_gen - ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" - ADD_TO_PARENT MLIRPythonSources.Dialects.vector_transform - SOURCES "dialects/_vector_transform_enum_gen.py" ) + EXTENSION_NAME vector_transform + ENUM_GEN) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/MathOps.td SOURCES dialects/math.py - DIALECT_NAME math) + DIALECT_NAME math + ENUM_GEN) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -252,7 +233,8 @@ SOURCES dialects/arith.py dialects/_arith_ops_ext.py - DIALECT_NAME arith) + DIALECT_NAME arith + ENUM_GEN) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -261,7 +243,8 @@ SOURCES dialects/memref.py dialects/_memref_ops_ext.py - DIALECT_NAME memref) + DIALECT_NAME memref + ENUM_GEN) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -278,7 +261,8 @@ TD_FILE dialects/NVGPUOps.td SOURCES dialects/nvgpu.py - DIALECT_NAME nvgpu) + DIALECT_NAME nvgpu + ENUM_GEN) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -286,7 +270,8 @@ TD_FILE dialects/NVVMOps.td SOURCES dialects/nvvm.py - DIALECT_NAME nvvm) + DIALECT_NAME nvvm + ENUM_GEN) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -300,6 +285,7 @@ MLIRPythonSources.Dialects.quant ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + ENUM_GEN SOURCES dialects/quant.py _mlir_libs/_mlir/dialects/quant.pyi) @@ -335,7 +321,8 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/SparseTensorOps.td SOURCES dialects/sparse_tensor.py - DIALECT_NAME sparse_tensor) + DIALECT_NAME sparse_tensor + ENUM_GEN) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -358,7 +345,8 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/VectorOps.td SOURCES dialects/vector.py - DIALECT_NAME vector) + DIALECT_NAME vector + ENUM_GEN) ################################################################################ # Python extensions. diff --git a/mlir/python/mlir/dialects/amdgpu.py b/mlir/python/mlir/dialects/amdgpu.py --- a/mlir/python/mlir/dialects/amdgpu.py +++ b/mlir/python/mlir/dialects/amdgpu.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._amdgpu_ops_gen import * +from ._amdgpu_enum_gen import * diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py --- a/mlir/python/mlir/dialects/arith.py +++ b/mlir/python/mlir/dialects/arith.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._arith_ops_gen import * +from ._arith_enum_gen import * diff --git a/mlir/python/mlir/dialects/gpu/__init__.py b/mlir/python/mlir/dialects/gpu/__init__.py --- a/mlir/python/mlir/dialects/gpu/__init__.py +++ b/mlir/python/mlir/dialects/gpu/__init__.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .._gpu_ops_gen import * +from .._gpu_enum_gen import * diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py --- a/mlir/python/mlir/dialects/linalg/__init__.py +++ b/mlir/python/mlir/dialects/linalg/__init__.py @@ -9,6 +9,7 @@ # definitions following these steps: # DSL -> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py. from .._linalg_ops_gen import * +from .._linalg_enum_gen import * # These are the ground truth functions defined as: # ``` diff --git a/mlir/python/mlir/dialects/math.py b/mlir/python/mlir/dialects/math.py --- a/mlir/python/mlir/dialects/math.py +++ b/mlir/python/mlir/dialects/math.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._math_ops_gen import * +from ._math_enum_gen import * diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py --- a/mlir/python/mlir/dialects/memref.py +++ b/mlir/python/mlir/dialects/memref.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._memref_ops_gen import * +from ._memref_enum_gen import * diff --git a/mlir/python/mlir/dialects/nvgpu.py b/mlir/python/mlir/dialects/nvgpu.py --- a/mlir/python/mlir/dialects/nvgpu.py +++ b/mlir/python/mlir/dialects/nvgpu.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._nvgpu_ops_gen import * +from ._nvgpu_enum_gen import * diff --git a/mlir/python/mlir/dialects/nvvm.py b/mlir/python/mlir/dialects/nvvm.py --- a/mlir/python/mlir/dialects/nvvm.py +++ b/mlir/python/mlir/dialects/nvvm.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._nvvm_ops_gen import * +from ._nvvm_enum_gen import * diff --git a/mlir/python/mlir/dialects/sparse_tensor.py b/mlir/python/mlir/dialects/sparse_tensor.py --- a/mlir/python/mlir/dialects/sparse_tensor.py +++ b/mlir/python/mlir/dialects/sparse_tensor.py @@ -3,5 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._sparse_tensor_ops_gen import * +from ._sparse_tensor_enum_gen import * from .._mlir_libs._mlirDialectsSparseTensor import * from .._mlir_libs import _mlirSparseTensorPasses as _cextSparseTensorPasses diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py --- a/mlir/python/mlir/dialects/transform/structured.py +++ b/mlir/python/mlir/dialects/transform/structured.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .._structured_transform_ops_gen import * +from .._structured_transform_enum_gen import * diff --git a/mlir/python/mlir/dialects/vector.py b/mlir/python/mlir/dialects/vector.py --- a/mlir/python/mlir/dialects/vector.py +++ b/mlir/python/mlir/dialects/vector.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._vector_ops_gen import * +from ._vector_enum_gen import * diff --git a/mlir/test/python/dialects/vector.py b/mlir/test/python/dialects/vector.py --- a/mlir/test/python/dialects/vector.py +++ b/mlir/test/python/dialects/vector.py @@ -64,3 +64,21 @@ # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]] # CHECK-NOT: %[[MASK]] print(module) + + +# CHECK-LABEL: TEST: testBitEnumCombiningKind +@run +def testBitEnumCombiningKind(): + module = Module.create() + with InsertionPoint(module.body): + f32 = F32Type.get() + vector_type = VectorType.get([16], f32) + + @func.FuncOp.from_py_func(vector_type) + def reduction(arg): + v = vector.ReductionOp(f32, vector.Vector_CombiningKindAttr.ADD, arg) + return v + + # CHECK: func.func @reduction(%[[VEC:.*]]: vector<16xf32>) -> f32 { + # CHECK: %0 = vector.reduction , %[[VEC]] : vector<16xf32> into f32 + print(module) diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp --- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp @@ -13,6 +13,7 @@ #include "mlir/TableGen/Attribute.h" #include "mlir/TableGen/GenInfo.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/TableGen/Record.h" @@ -24,29 +25,51 @@ constexpr const char *fileHeader = R"Py( # Autogenerated by mlir-tblgen; don't manually edit. -from enum import Enum +from enum import IntEnum, Enum from ._ods_common import _cext as _ods_cext _ods_ir = _ods_cext.ir # Convenience decorator for registering user-friendly Attribute builders. def _register_attribute_builder(kind): def decorator_builder(func): - _ods_ir.AttrBuilder.insert(kind, func) + try: + _ods_ir.AttrBuilder.insert(kind, func) + except: + pass return func return decorator_builder - )Py"; /// Makes enum case name Python-compatible, i.e. UPPER_SNAKE_CASE. static std::string makePythonEnumCaseName(StringRef name) { + if (llvm::all_of(name, [](char c) { return std::isupper(c) || c == '_'; })) + return name.str(); return StringRef(llvm::convertToSnakeFromCamelCase(name)).upper(); } +/// Emits the Python class for the given bit enum. +static void emitBitEnumClass(const llvm::Record *def, raw_ostream &os) { + EnumAttr enumAttr(*def->getValueAsDef("enum")); + os << llvm::formatv("class {0}(str, Enum):\n", def->getName()); + if (!enumAttr.getSummary().empty()) + os << llvm::formatv(" \"\"\"{0}\"\"\"\n", enumAttr.getSummary()); + os << " def __str__(self): return self.value\n"; + os << "\n"; + + for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) { + os << llvm::formatv(" {0} = \"{1}\"\n", + makePythonEnumCaseName(enumCase.getSymbol()), + enumCase.getStr()); + } + + os << "\n"; +} + /// Emits the Python class for the given enum. static void emitEnumClass(StringRef enumName, StringRef description, ArrayRef cases, raw_ostream &os) { - os << llvm::formatv("class {0}(Enum):\n", enumName); + os << llvm::formatv("class {0}(IntEnum):\n", enumName); if (!description.empty()) os << llvm::formatv(" \"\"\"{0}\"\"\"\n", description); os << "\n"; @@ -58,14 +81,6 @@ } os << "\n"; - os << llvm::formatv(" def _as_int(self):\n"); - for (const EnumAttrCase &enumCase : cases) { - os << llvm::formatv(" if self is {0}.{1}:\n", enumName, - makePythonEnumCaseName(enumCase.getSymbol())); - os << llvm::formatv(" return {0}\n", enumCase.getValue()); - } - os << llvm::formatv(" assert False, \"Unknown {0} enum entry.\"\n\n\n", - enumName); } /// Attempts to extract the bitwidth B from string "uintB_t" describing the @@ -98,28 +113,65 @@ os << llvm::formatv( " return " "_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, " - "context=context), x._as_int())\n\n", + "context=context), int(x))\n\n", bitwidth); return false; } +/// Emits an attribute builder for the given bit enum attribute to support +/// automatic conversion between enum values and attributes in Python. Returns +/// `false` on success, `true` on failure. +static bool emitBitEnumAttributeBuilder(const llvm::Record *def, + raw_ostream &os) { + EnumAttr enumAttr(*def->getValueAsDef("enum")); + os << llvm::formatv("@_register_attribute_builder(\"{0}\")\n", + def->getName()); + os << llvm::formatv( + "def _{0}(x, context):\n", + llvm::convertToSnakeFromCamelCase(enumAttr.getAttrDefName())); + os << llvm::formatv( + " return " + "_ods_ir.Attribute.parse(f'#{0}.{1}<{{str(x)}>', context=context)\n\n", + def->getValueAsDef("dialect")->getValueAsString("name"), + def->getValueAsString("mnemonic")); + return false; +} + /// Emits Python bindings for all enums in the record keeper. Returns /// `false` on success, `true` on failure. static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) { os << fileHeader; + llvm::SmallSet seenEnums; + int emitted = 0; std::vector defs = - recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo"); + recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr"); + for (const llvm::Record *def : defs) { + EnumAttr enumAttr(*def->getValueAsDef("enum")); + seenEnums.insert(&enumAttr); + emitBitEnumClass(def, os); + emitBitEnumAttributeBuilder(def, os); + ++emitted; + } + + defs = recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo"); for (const llvm::Record *def : defs) { EnumAttr enumAttr(*def); - if (enumAttr.isBitEnum()) { - llvm::errs() << "bit enums not supported\n"; - return true; - } + if (enumAttr.isBitEnum()) + continue; + if (seenEnums.contains(&enumAttr)) + continue; + emitEnumClass(enumAttr.getEnumClassName(), enumAttr.getSummary(), enumAttr.getAllCases(), os); emitAttributeBuilder(enumAttr, os); + ++emitted; } + if (emitted == 0) { + llvm::errs() << "no enums\n"; + return true; + } + return false; }