diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake --- a/mlir/cmake/modules/AddMLIRPython.cmake +++ b/mlir/cmake/modules/AddMLIRPython.cmake @@ -272,6 +272,7 @@ # SOURCES: Same as declare_mlir_python_sources(). # SOURCES_GLOB: Same as declare_mlir_python_sources(). # DEPENDS: Additional dependency targets. +# GEN_ENUM_BINDINGS: Generate enum bindings. # # TODO: Right now `TD_FILE` can't be the actual dialect tablegen file, since we # use its path to determine where to place the generated python file. If @@ -279,9 +280,9 @@ # need for the separate "wrapper" .td files function(declare_mlir_dialect_python_bindings) cmake_parse_arguments(ARG - "" + "GEN_ENUM_BINDINGS" "ROOT_DIR;ADD_TO_PARENT;TD_FILE;DIALECT_NAME" - "SOURCES;SOURCES_GLOB;DEPENDS" + "SOURCES;SOURCES_GLOB;DEPENDS;GEN_ENUM_BINDINGS_TD_FILE" ${ARGN}) # Sources. set(_dialect_target "${ARG_ADD_TO_PARENT}.${ARG_DIALECT_NAME}") @@ -306,11 +307,22 @@ ) add_public_tablegen_target(${tblgen_target}) + set(_sources ${dialect_filename}) + if(ARG_GEN_ENUM_BINDINGS OR ARG_GEN_ENUM_BINDINGS_TD_FILE) + if(ARG_GEN_ENUM_BINDINGS_TD_FILE) + set(td_file "${ARG_ROOT_DIR}/${ARG_GEN_ENUM_BINDINGS_TD_FILE}") + set(LLVM_TARGET_DEFINITIONS ${td_file}) + endif() + set(enum_filename "${relative_td_directory}/_${ARG_DIALECT_NAME}_enum_gen.py") + mlir_tablegen(${enum_filename} -gen-python-enum-bindings) + list(APPEND _sources ${enum_filename}) + endif() + # Generated. declare_mlir_python_sources("${_dialect_target}.ops_gen" ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" ADD_TO_PARENT "${_dialect_target}" - SOURCES "${dialect_filename}" + SOURCES ${_sources} ) endif() endfunction() @@ -333,9 +345,9 @@ # DEPENDS: Additional dependency targets. 
function(declare_mlir_dialect_extension_python_bindings) cmake_parse_arguments(ARG - "" + "GEN_ENUM_BINDINGS" "ROOT_DIR;ADD_TO_PARENT;TD_FILE;DIALECT_NAME;EXTENSION_NAME" - "SOURCES;SOURCES_GLOB;DEPENDS" + "SOURCES;SOURCES_GLOB;DEPENDS;GEN_ENUM_BINDINGS_TD_FILE" ${ARGN}) # Source files. set(_extension_target "${ARG_ADD_TO_PARENT}.${ARG_EXTENSION_NAME}") @@ -362,10 +374,21 @@ add_dependencies(${tblgen_target} ${ARG_DEPENDS}) endif() + set(_sources ${output_filename}) + if(ARG_GEN_ENUM_BINDINGS OR ARG_GEN_ENUM_BINDINGS_TD_FILE) + if(ARG_GEN_ENUM_BINDINGS_TD_FILE) + set(td_file "${ARG_ROOT_DIR}/${ARG_GEN_ENUM_BINDINGS_TD_FILE}") + set(LLVM_TARGET_DEFINITIONS ${td_file}) + endif() + set(enum_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_enum_gen.py") + mlir_tablegen(${enum_filename} -gen-python-enum-bindings) + list(APPEND _sources ${enum_filename}) + endif() + declare_mlir_python_sources("${_extension_target}.ops_gen" ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" ADD_TO_PARENT "${_extension_target}" - SOURCES "${output_filename}" + SOURCES ${_sources} ) endif() endfunction() diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -58,10 +58,11 @@ void loadDialectModule(llvm::StringRef dialectNamespace); /// Adds a user-friendly Attribute builder. - /// Raises an exception if the mapping already exists. + /// Raises an exception if the mapping already exists and replace == false. /// This is intended to be called by implementation code. void registerAttributeBuilder(const std::string &attributeKind, - pybind11::function pyFunc); + pybind11::function pyFunc, + bool replace = false); /// Adds a user-friendly type caster. Raises an exception if the mapping /// already exists and replace == false. 
This is intended to be called by diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -242,19 +242,23 @@ static py::function dundeGetItemNamed(const std::string &attributeKind) { auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind); if (!builder) - throw py::key_error(); + throw py::key_error(attributeKind); return *builder; } static void dundeSetItemNamed(const std::string &attributeKind, - py::function func) { - PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func)); + py::function func, bool replace) { + PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func), + replace); } static void bind(py::module &m) { py::class_(m, "AttrBuilder", py::module_local()) .def_static("contains", &PyAttrBuilderMap::dunderContains) .def_static("get", &PyAttrBuilderMap::dundeGetItemNamed) - .def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed); + .def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed, + "attribute_kind"_a, "attr_builder"_a, "replace"_a = false, + "Register an attribute builder for building MLIR " + "attributes from python values."); } }; diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -63,11 +63,13 @@ } void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, - py::function pyFunc) { + py::function pyFunc, bool replace) { py::object &found = attributeBuilderMap[attributeKind]; - if (found) { + if (found && !found.is_none() && !replace) { throw std::runtime_error((llvm::Twine("Attribute builder for '") + - attributeKind + "' is already registered") + attributeKind + + "' is already registered with func: " + + py::str(found).operator std::string()) .str()); } found = std::move(pyFunc); diff --git a/mlir/python/CMakeLists.txt 
b/mlir/python/CMakeLists.txt --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -52,7 +52,8 @@ TD_FILE dialects/AMDGPUOps.td SOURCES dialects/amdgpu.py - DIALECT_NAME amdgpu) + DIALECT_NAME amdgpu + GEN_ENUM_BINDINGS) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -109,7 +110,8 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/GPUOps.td SOURCES_GLOB dialects/gpu/*.py - DIALECT_NAME gpu) + DIALECT_NAME gpu + GEN_ENUM_BINDINGS) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -120,7 +122,8 @@ SOURCES_GLOB dialects/linalg/*.py DIALECT_NAME linalg - DEPENDS LinalgOdsGen) + DEPENDS LinalgOdsGen + GEN_ENUM_BINDINGS) declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -140,16 +143,8 @@ dialects/_transform_ops_ext.py dialects/transform/__init__.py _mlir_libs/_mlir/dialects/transform/__init__.pyi - DIALECT_NAME transform) - -set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/TransformOps.td") -mlir_tablegen("dialects/_transform_enum_gen.py" -gen-python-enum-bindings) -add_public_tablegen_target(MLIRTransformDialectPyEnumGen) -declare_mlir_python_sources( - MLIRPythonSources.Dialects.transform.enum_gen - ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" - ADD_TO_PARENT MLIRPythonSources.Dialects.transform - SOURCES "dialects/_transform_enum_gen.py") + DIALECT_NAME transform + GEN_ENUM_BINDINGS) declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -159,16 +154,8 @@ dialects/_bufferization_transform_ops_ext.py dialects/transform/bufferization.py DIALECT_NAME transform - EXTENSION_NAME bufferization_transform) - -set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/BufferizationTransformOps.td") -mlir_tablegen("dialects/_bufferization_transform_enum_gen.py" -gen-python-enum-bindings) -add_public_tablegen_target(MLIRBufferizationTransformDialectPyEnumGen) 
-declare_mlir_python_sources( - MLIRPythonSources.Dialects.bufferization_transform.enum_gen - ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" - ADD_TO_PARENT MLIRPythonSources.Dialects.bufferization_transform - SOURCES "dialects/_bufferization_transform_enum_gen.py") + EXTENSION_NAME bufferization_transform + GEN_ENUM_BINDINGS) declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -208,7 +195,8 @@ dialects/_structured_transform_ops_ext.py dialects/transform/structured.py DIALECT_NAME transform - EXTENSION_NAME structured_transform) + EXTENSION_NAME structured_transform + GEN_ENUM_BINDINGS_TD_FILE dialects/LinalgTransformEnums.td) declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -227,16 +215,8 @@ SOURCES dialects/transform/vector.py DIALECT_NAME transform - EXTENSION_NAME vector_transform) - -set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/VectorTransformOps.td") -mlir_tablegen("dialects/_vector_transform_enum_gen.py" -gen-python-enum-bindings) -add_public_tablegen_target(MLIRVectorTransformPyEnumGen) -declare_mlir_python_sources( - MLIRPythonSources.Dialects.vector_transform.enum_gen - ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" - ADD_TO_PARENT MLIRPythonSources.Dialects.vector_transform - SOURCES "dialects/_vector_transform_enum_gen.py" ) + EXTENSION_NAME vector_transform + GEN_ENUM_BINDINGS) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -252,7 +232,8 @@ SOURCES dialects/arith.py dialects/_arith_ops_ext.py - DIALECT_NAME arith) + DIALECT_NAME arith + GEN_ENUM_BINDINGS) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -278,7 +259,8 @@ TD_FILE dialects/NVGPUOps.td SOURCES dialects/nvgpu.py - DIALECT_NAME nvgpu) + DIALECT_NAME nvgpu + GEN_ENUM_BINDINGS) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -286,7 +268,8 @@ TD_FILE dialects/NVVMOps.td SOURCES dialects/nvvm.py - 
DIALECT_NAME nvvm) + DIALECT_NAME nvvm + GEN_ENUM_BINDINGS) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -300,6 +283,7 @@ MLIRPythonSources.Dialects.quant ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + GEN_ENUM_BINDINGS SOURCES dialects/quant.py _mlir_libs/_mlir/dialects/quant.pyi) @@ -335,7 +319,8 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/SparseTensorOps.td SOURCES dialects/sparse_tensor.py - DIALECT_NAME sparse_tensor) + DIALECT_NAME sparse_tensor + GEN_ENUM_BINDINGS) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -358,7 +343,8 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/VectorOps.td SOURCES dialects/vector.py - DIALECT_NAME vector) + DIALECT_NAME vector + GEN_ENUM_BINDINGS) ################################################################################ # Python extensions. diff --git a/mlir/python/mlir/dialects/LinalgTransformEnums.td b/mlir/python/mlir/dialects/LinalgTransformEnums.td new file mode 100644 --- /dev/null +++ b/mlir/python/mlir/dialects/LinalgTransformEnums.td @@ -0,0 +1,20 @@ +//===-- LinalgTransformEnums.td --------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the Python bindings generator for the structured transform ops +// provided by Linalg (and other dialects). 
+// +//===----------------------------------------------------------------------===// + + +#ifndef PYTHON_BINDINGS_LINALG_TRANSFORM_ENUMS +#define PYTHON_BINDINGS_LINALG_TRANSFORM_ENUMS + +include "mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td" + +#endif // PYTHON_BINDINGS_LINALG_TRANSFORM_ENUMS diff --git a/mlir/python/mlir/dialects/amdgpu.py b/mlir/python/mlir/dialects/amdgpu.py --- a/mlir/python/mlir/dialects/amdgpu.py +++ b/mlir/python/mlir/dialects/amdgpu.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._amdgpu_ops_gen import * +from ._amdgpu_enum_gen import * diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py --- a/mlir/python/mlir/dialects/arith.py +++ b/mlir/python/mlir/dialects/arith.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._arith_ops_gen import * +from ._arith_enum_gen import * diff --git a/mlir/python/mlir/dialects/gpu/__init__.py b/mlir/python/mlir/dialects/gpu/__init__.py --- a/mlir/python/mlir/dialects/gpu/__init__.py +++ b/mlir/python/mlir/dialects/gpu/__init__.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .._gpu_ops_gen import * +from .._gpu_enum_gen import * diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py --- a/mlir/python/mlir/dialects/linalg/__init__.py +++ b/mlir/python/mlir/dialects/linalg/__init__.py @@ -9,6 +9,7 @@ # definitions following these steps: # DSL -> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py. 
from .._linalg_ops_gen import * +from .._linalg_enum_gen import * # These are the ground truth functions defined as: # ``` diff --git a/mlir/python/mlir/dialects/nvgpu.py b/mlir/python/mlir/dialects/nvgpu.py --- a/mlir/python/mlir/dialects/nvgpu.py +++ b/mlir/python/mlir/dialects/nvgpu.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._nvgpu_ops_gen import * +from ._nvgpu_enum_gen import * diff --git a/mlir/python/mlir/dialects/nvvm.py b/mlir/python/mlir/dialects/nvvm.py --- a/mlir/python/mlir/dialects/nvvm.py +++ b/mlir/python/mlir/dialects/nvvm.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._nvvm_ops_gen import * +from ._nvvm_enum_gen import * diff --git a/mlir/python/mlir/dialects/sparse_tensor.py b/mlir/python/mlir/dialects/sparse_tensor.py --- a/mlir/python/mlir/dialects/sparse_tensor.py +++ b/mlir/python/mlir/dialects/sparse_tensor.py @@ -3,5 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._sparse_tensor_ops_gen import * +from ._sparse_tensor_enum_gen import * from .._mlir_libs._mlirDialectsSparseTensor import * from .._mlir_libs import _mlirSparseTensorPasses as _cextSparseTensorPasses diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py --- a/mlir/python/mlir/dialects/transform/structured.py +++ b/mlir/python/mlir/dialects/transform/structured.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .._structured_transform_ops_gen import * +from .._structured_transform_enum_gen import * diff --git a/mlir/python/mlir/dialects/vector.py b/mlir/python/mlir/dialects/vector.py --- a/mlir/python/mlir/dialects/vector.py +++ b/mlir/python/mlir/dialects/vector.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._vector_ops_gen import * +from ._vector_enum_gen import * diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py --- a/mlir/python/mlir/ir.py 
+++ b/mlir/python/mlir/ir.py @@ -8,9 +8,9 @@ # Convenience decorator for registering user-friendly Attribute builders. -def register_attribute_builder(kind): +def register_attribute_builder(kind, replace=False): def decorator_builder(func): - AttrBuilder.insert(kind, func) + AttrBuilder.insert(kind, func, replace=replace) return func return decorator_builder diff --git a/mlir/test/mlir-tblgen/enums-python-bindings.td b/mlir/test/mlir-tblgen/enums-python-bindings.td --- a/mlir/test/mlir-tblgen/enums-python-bindings.td +++ b/mlir/test/mlir-tblgen/enums-python-bindings.td @@ -4,54 +4,35 @@ // CHECK: Autogenerated by mlir-tblgen; don't manually edit. -// CHECK: from enum import Enum +// CHECK: from enum import IntEnum, Enum // CHECK: from ._ods_common import _cext as _ods_cext +// CHECK: from ..ir import register_attribute_builder // CHECK: _ods_ir = _ods_cext.ir def One : I32EnumAttrCase<"CaseOne", 1, "one">; def Two : I32EnumAttrCase<"CaseTwo", 2, "two">; def MyEnum : I32EnumAttr<"MyEnum", "An example 32-bit enum", [One, Two]>; -// CHECK: def _register_attribute_builder(kind): -// CHECK: def decorator_builder(func): -// CHECK: _ods_ir.AttrBuilder.insert(kind, func) -// CHECK: return func -// CHECK: return decorator_builder - -// CHECK-LABEL: class MyEnum(Enum): +// CHECK-LABEL: class MyEnum(IntEnum): // CHECK: """An example 32-bit enum""" // CHECK: CASE_ONE = 1 // CHECK: CASE_TWO = 2 -// CHECK: def _as_int(self): -// CHECK: if self is MyEnum.CASE_ONE: -// CHECK: return 1 -// CHECK: if self is MyEnum.CASE_TWO: -// CHECK: return 2 -// CHECK: assert False, "Unknown MyEnum enum entry." 
- def One64 : I64EnumAttrCase<"CaseOne64", 1, "one">; def Two64 : I64EnumAttrCase<"CaseTwo64", 2, "two">; def MyEnum64 : I64EnumAttr<"MyEnum64", "An example 64-bit enum", [One64, Two64]>; -// CHECK: @_register_attribute_builder("MyEnum") +// CHECK: @register_attribute_builder("MyEnum") // CHECK: def _my_enum(x, context): -// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), x._as_int()) +// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x)) -// CHECK-LABEL: class MyEnum64(Enum): +// CHECK-LABEL: class MyEnum64(IntEnum): // CHECK: """An example 64-bit enum""" // CHECK: CASE_ONE64 = 1 // CHECK: CASE_TWO64 = 2 -// CHECK: def _as_int(self): -// CHECK: if self is MyEnum64.CASE_ONE64: -// CHECK: return 1 -// CHECK: if self is MyEnum64.CASE_TWO64: -// CHECK: return 2 -// CHECK: assert False, "Unknown MyEnum64 enum entry." - -// CHECK: @_register_attribute_builder("MyEnum64") +// CHECK: @register_attribute_builder("MyEnum64") // CHECK: def _my_enum64(x, context): -// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), x._as_int()) +// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), int(x)) diff --git a/mlir/test/python/dialects/gpu.py b/mlir/test/python/dialects/gpu.py --- a/mlir/test/python/dialects/gpu.py +++ b/mlir/test/python/dialects/gpu.py @@ -1,22 +1,31 @@ # RUN: %PYTHON %s | FileCheck %s from mlir.ir import * -import mlir.dialects.gpu -import mlir.dialects.gpu.passes +import mlir.dialects.gpu as gpu from mlir.passmanager import * def run(f): print("\nTEST:", f.__name__) - f() + with Context(), Location.unknown(): + f() + return f +# CHECK-LABEL: testGPUPass +# CHECK: SUCCESS +@run def testGPUPass(): - with Context() as context: - PassManager.parse("any(gpu-kernel-outlining)") + PassManager.parse("any(gpu-kernel-outlining)") print("SUCCESS") -# CHECK-LABEL: testGPUPass -# CHECK: SUCCESS 
-run(testGPUPass) +# CHECK-LABEL: testMMAElementWiseAttr +@run +def testMMAElementWiseAttr(): + module = Module.create() + with InsertionPoint(module.body): + gpu.BlockDimOp(gpu.Dimension.Y) + # CHECK: %0 = gpu.block_dim y + print(module) + pass diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -451,3 +451,53 @@ # CHECK: transform.sequence # CHECK: = transform.structured.vectorize # CHECK: {vectorize_padding} + + +@run +def testMatchInterfaceEnum(): + names = ArrayAttr.get([StringAttr.get("test.dummy")]) + result_type = transform.AnyOpType.get() + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + ) + with InsertionPoint(sequence.body): + fused = structured.MatchOp.__base__( + result_type, + sequence.bodyTarget, + ops=names, + interface=structured.MatchInterfaceEnum.LINALG_OP, + ) + transform.YieldOp() + # CHECK-LABEL: TEST: testMatchInterfaceEnum + # CHECK: transform.sequence + # CHECK: = transform.structured.match + # CHECK: interface{LinalgOp} + + +@run +def testMatchInterfaceEnumReplaceAttributeBuilder(): + @register_attribute_builder("MatchInterfaceEnum", replace=True) + def match_interface_enum(x, context): + if x == "LINALG_OP": + y = 0 + elif x == "TILING_INTERFACE": + y = 1 + return IntegerAttr.get(IntegerType.get_signless(32, context=context), y) + + names = ArrayAttr.get([StringAttr.get("test.dummy")]) + result_type = transform.AnyOpType.get() + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + ) + with InsertionPoint(sequence.body): + fused = structured.MatchOp.__base__( + result_type, + sequence.bodyTarget, + ops=names, + interface="TILING_INTERFACE", + ) + transform.YieldOp() + # CHECK-LABEL: TEST: 
testMatchInterfaceEnumReplaceAttributeBuilder + # CHECK: transform.sequence + # CHECK: = transform.structured.match + # CHECK: interface{TilingInterface} diff --git a/mlir/test/python/dialects/vector.py b/mlir/test/python/dialects/vector.py --- a/mlir/test/python/dialects/vector.py +++ b/mlir/test/python/dialects/vector.py @@ -64,3 +64,21 @@ # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]] # CHECK-NOT: %[[MASK]] print(module) + + +# CHECK-LABEL: TEST: testBitEnumCombiningKind +@run +def testBitEnumCombiningKind(): + module = Module.create() + with InsertionPoint(module.body): + f32 = F32Type.get() + vector_type = VectorType.get([16], f32) + + @func.FuncOp.from_py_func(vector_type) + def reduction(arg): + v = vector.ReductionOp(f32, vector.CombiningKind.ADD, arg) + return v + + # CHECK: func.func @reduction(%[[VEC:.*]]: vector<16xf32>) -> f32 { + # CHECK: %0 = vector.reduction , %[[VEC]] : vector<16xf32> into f32 + print(module) diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp --- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp @@ -24,29 +24,45 @@ constexpr const char *fileHeader = R"Py( # Autogenerated by mlir-tblgen; don't manually edit. -from enum import Enum +from enum import IntEnum, Enum from ._ods_common import _cext as _ods_cext +from ..ir import register_attribute_builder _ods_ir = _ods_cext.ir -# Convenience decorator for registering user-friendly Attribute builders. -def _register_attribute_builder(kind): - def decorator_builder(func): - _ods_ir.AttrBuilder.insert(kind, func) - return func - - return decorator_builder - )Py"; /// Makes enum case name Python-compatible, i.e. UPPER_SNAKE_CASE. 
static std::string makePythonEnumCaseName(StringRef name) { + if (llvm::all_of(name, [](char c) { + return std::isupper(c) || c == '_' || std::isdigit(c); + })) + return name.str(); return StringRef(llvm::convertToSnakeFromCamelCase(name)).upper(); } +/// Emits the Python class for the given bit enum. +static void emitDialectEnumClass(StringRef enumName, StringRef description, + ArrayRef cases, + raw_ostream &os) { + os << llvm::formatv("class {0}(str, Enum):\n", enumName); + if (!description.empty()) + os << llvm::formatv(" \"\"\"{0}\"\"\"\n", description); + os << " def __str__(self): return self.value\n"; + os << "\n"; + + for (const EnumAttrCase &enumCase : cases) { + os << llvm::formatv(" {0} = \"{1}\"\n", + makePythonEnumCaseName(enumCase.getSymbol()), + enumCase.getStr()); + } + + os << "\n"; +} + /// Emits the Python class for the given enum. static void emitEnumClass(StringRef enumName, StringRef description, ArrayRef cases, raw_ostream &os) { - os << llvm::formatv("class {0}(Enum):\n", enumName); + os << llvm::formatv("class {0}(IntEnum):\n", enumName); if (!description.empty()) os << llvm::formatv(" \"\"\"{0}\"\"\"\n", description); os << "\n"; @@ -58,14 +74,6 @@ } os << "\n"; - os << llvm::formatv(" def _as_int(self):\n"); - for (const EnumAttrCase &enumCase : cases) { - os << llvm::formatv(" if self is {0}.{1}:\n", enumName, - makePythonEnumCaseName(enumCase.getSymbol())); - os << llvm::formatv(" return {0}\n", enumCase.getValue()); - } - os << llvm::formatv(" assert False, \"Unknown {0} enum entry.\"\n\n\n", - enumName); } /// Attempts to extract the bitwidth B from string "uintB_t" describing the @@ -82,7 +90,8 @@ /// Emits an attribute builder for the given enum attribute to support automatic /// conversion between enum values and attributes in Python. Returns /// `false` on success, `true` on failure. 
-static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) { +static bool emitAttributeBuilder(StringRef attrDefName, + const EnumAttr &enumAttr, raw_ostream &os) { int64_t bitwidth; if (extractUIntBitwidth(enumAttr.getUnderlyingType(), bitwidth)) { llvm::errs() << "failed to identify bitwidth of " @@ -90,36 +99,99 @@ return true; } - os << llvm::formatv("@_register_attribute_builder(\"{0}\")\n", - enumAttr.getAttrDefName()); + os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", attrDefName); os << llvm::formatv( "def _{0}(x, context):\n", llvm::convertToSnakeFromCamelCase(enumAttr.getAttrDefName())); os << llvm::formatv( " return " "_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, " - "context=context), x._as_int())\n\n", + "context=context), int(x))\n\n", bitwidth); return false; } +/// Emits an attribute builder for the given bit enum attribute to support +/// automatic conversion between enum values and attributes in Python. Returns +/// `false` on success, `true` on failure. +static bool emitDialectEnumAttributeBuilder(StringRef attrDefName, + StringRef formatString, + raw_ostream &os) { + os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", attrDefName); + os << llvm::formatv("def _{0}(x, context):\n", + llvm::convertToSnakeFromCamelCase(attrDefName)); + os << llvm::formatv(" return " + "_ods_ir.Attribute.parse(f'{0}', context=context)\n\n", + formatString); + return false; +} + /// Emits Python bindings for all enums in the record keeper. Returns /// `false` on success, `true` on failure. 
static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
                            raw_ostream &os) {
  os << fileHeader;

  // Name under which the enum behind `record` is registered: the record itself
  // when it is an EnumAttrInfo, otherwise the record stored in the "enum"
  // field of the dialect-level EnumAttr wrapper.
  // NOTE(review): the returned StringRef is assumed to point into storage
  // owned by the record (not by the temporary EnumAttr) — confirm against
  // mlir::tblgen::Attribute::getAttrDefName.
  auto enumKeyOf = [](const llvm::Record &record, bool isInfo) -> StringRef {
    return isInfo ? EnumAttr(record).getAttrDefName()
                  : EnumAttr(record.getValueAsDef("enum")).getAttrDefName();
  };

  // Find the records corresponding to enums with load bearing names that
  // correspond directly to op attribute names. An EnumAttrInfo record only
  // claims a name that is still free, whereas an EnumAttr wrapper always takes
  // the name over. No EnumAttr is heap-allocated here: the wrappers are only
  // needed long enough to read the name.
  // NOTE(review): the map's template arguments were not visible in the
  // reviewed text; <StringRef, const llvm::Record *> is inferred from how the
  // keys and values are used — confirm.
  llvm::DenseMap<StringRef, const llvm::Record *> correctlyNamedEnums;
  for (const auto &it : recordKeeper.getDefs()) {
    const llvm::Record &record = *it.second;
    const bool isInfo = record.isSubClassOf("EnumAttrInfo");
    if (!isInfo && !record.isSubClassOf("EnumAttr"))
      continue;
    const StringRef key = enumKeyOf(record, isInfo);
    if (isInfo)
      correctlyNamedEnums.try_emplace(key, &record);
    else
      correctlyNamedEnums[key] = &record;
  }

  if (correctlyNamedEnums.empty()) {
    llvm::errs() << "no enums to generate.\n";
    return true;
  }

  // Emit the enum classes and registration hooks with the correct names.
  // Walk the definitions again instead of the DenseMap so that the order of
  // the generated Python code does not depend on hashing.
  // NOTE(review): assumes getDefs() iterates in a stable order — confirm.
  for (const auto &it : recordKeeper.getDefs()) {
    const llvm::Record &record = *it.second;
    const bool isInfo = record.isSubClassOf("EnumAttrInfo");
    if (!isInfo && !record.isSubClassOf("EnumAttr"))
      continue;
    // Skip records whose name ended up being claimed by another record.
    if (correctlyNamedEnums.lookup(enumKeyOf(record, isInfo)) != &record)
      continue;

    Attribute attr(&record);
    if (isInfo) {
      EnumAttr enumAttr(record);
      emitEnumClass(enumAttr.getEnumClassName(), enumAttr.getSummary(),
                    enumAttr.getAllCases(), os);
      // A failure here is reported by emitAttributeBuilder itself and, as
      // before, does not abort generation of the remaining enums.
      emitAttributeBuilder(attr.getAttrDefName(), enumAttr, os);
      continue;
    }

    EnumAttr enumAttr(record.getValueAsDef("enum"));
    emitDialectEnumClass(enumAttr.getEnumClassName(), enumAttr.getSummary(),
                         enumAttr.getAllCases(), os);
    StringRef dialect =
        attr.getDef().getValueAsDef("dialect")->getValueAsString("name");
    StringRef mnemonic = attr.getDef().getValueAsString("mnemonic");
    StringRef assemblyFormat =
        attr.getDef().getValueAsString("assemblyFormat");

    // Only the two assembly formats below can be turned into a parseable
    // attribute string by the generated Python builder.
    const bool angleFormat = assemblyFormat == "`<` $value `>`";
    if ((!angleFormat && assemblyFormat != "$value") ||
        attr.getDef().getValueAsBit("hasCustomAssemblyFormat")) {
      llvm::errs() << "unsupported assembly format for enum\n";
      return true;
    }

    if (angleFormat) {
      emitDialectEnumAttributeBuilder(
          attr.getAttrDefName(),
          llvm::formatv("#{0}.{1}<{{x}>", dialect, mnemonic).str(), os);
    } else {
      emitDialectEnumAttributeBuilder(
          attr.getAttrDefName(),
          llvm::formatv("#{0}<{1} {{x}>", dialect, mnemonic).str(), os);
    }
  }

  return false;
}