diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake --- a/mlir/cmake/modules/AddMLIRPython.cmake +++ b/mlir/cmake/modules/AddMLIRPython.cmake @@ -272,6 +272,8 @@ # SOURCES: Same as declare_mlir_python_sources(). # SOURCES_GLOB: Same as declare_mlir_python_sources(). # DEPENDS: Additional dependency targets. +# GEN_ENUM_BINDINGS: Generate enum bindings. +# GEN_ENUM_BINDINGS_TD_FILE: Optional Tablegen file to generate enums for (relative to ROOT_DIR). # # TODO: Right now `TD_FILE` can't be the actual dialect tablegen file, since we # use its path to determine where to place the generated python file. If @@ -279,9 +281,9 @@ # need for the separate "wrapper" .td files function(declare_mlir_dialect_python_bindings) cmake_parse_arguments(ARG - "" + "GEN_ENUM_BINDINGS" "ROOT_DIR;ADD_TO_PARENT;TD_FILE;DIALECT_NAME" - "SOURCES;SOURCES_GLOB;DEPENDS" + "SOURCES;SOURCES_GLOB;DEPENDS;GEN_ENUM_BINDINGS_TD_FILE" ${ARGN}) # Sources. set(_dialect_target "${ARG_ADD_TO_PARENT}.${ARG_DIALECT_NAME}") @@ -306,11 +308,22 @@ ) add_public_tablegen_target(${tblgen_target}) + set(_sources ${dialect_filename}) + if(ARG_GEN_ENUM_BINDINGS OR ARG_GEN_ENUM_BINDINGS_TD_FILE) + if(ARG_GEN_ENUM_BINDINGS_TD_FILE) + set(td_file "${ARG_ROOT_DIR}/${ARG_GEN_ENUM_BINDINGS_TD_FILE}") + set(LLVM_TARGET_DEFINITIONS ${td_file}) + endif() + set(enum_filename "${relative_td_directory}/_${ARG_DIALECT_NAME}_enum_gen.py") + mlir_tablegen(${enum_filename} -gen-python-enum-bindings) + list(APPEND _sources ${enum_filename}) + endif() + # Generated. declare_mlir_python_sources("${_dialect_target}.ops_gen" ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" ADD_TO_PARENT "${_dialect_target}" - SOURCES "${dialect_filename}" + SOURCES ${_sources} ) endif() endfunction() @@ -331,11 +344,13 @@ # SOURCES: Same as declare_mlir_python_sources(). # SOURCES_GLOB: Same as declare_mlir_python_sources(). # DEPENDS: Additional dependency targets. +# GEN_ENUM_BINDINGS: Generate enum bindings. +# GEN_ENUM_BINDINGS_TD_FILE: Optional Tablegen file to generate enums for (relative to ROOT_DIR). function(declare_mlir_dialect_extension_python_bindings) cmake_parse_arguments(ARG - "" + "GEN_ENUM_BINDINGS" "ROOT_DIR;ADD_TO_PARENT;TD_FILE;DIALECT_NAME;EXTENSION_NAME" - "SOURCES;SOURCES_GLOB;DEPENDS" + "SOURCES;SOURCES_GLOB;DEPENDS;GEN_ENUM_BINDINGS_TD_FILE" ${ARGN}) # Source files. set(_extension_target "${ARG_ADD_TO_PARENT}.${ARG_EXTENSION_NAME}") @@ -362,10 +377,21 @@ add_dependencies(${tblgen_target} ${ARG_DEPENDS}) endif() + set(_sources ${output_filename}) + if(ARG_GEN_ENUM_BINDINGS OR ARG_GEN_ENUM_BINDINGS_TD_FILE) + if(ARG_GEN_ENUM_BINDINGS_TD_FILE) + set(td_file "${ARG_ROOT_DIR}/${ARG_GEN_ENUM_BINDINGS_TD_FILE}") + set(LLVM_TARGET_DEFINITIONS ${td_file}) + endif() + set(enum_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_enum_gen.py") + mlir_tablegen(${enum_filename} -gen-python-enum-bindings) + list(APPEND _sources ${enum_filename}) + endif() + declare_mlir_python_sources("${_extension_target}.ops_gen" ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" ADD_TO_PARENT "${_extension_target}" - SOURCES "${output_filename}" + SOURCES ${_sources} ) endif() endfunction() diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -58,10 +58,11 @@ void loadDialectModule(llvm::StringRef dialectNamespace); /// Adds a user-friendly Attribute builder. - /// Raises an exception if the mapping already exists. + /// Raises an exception if the mapping already exists and replace == false. /// This is intended to be called by implementation code. void registerAttributeBuilder(const std::string &attributeKind, - pybind11::function pyFunc); + pybind11::function pyFunc, + bool replace = false); /// Adds a user-friendly type caster. Raises an exception if the mapping /// already exists and replace == false. This is intended to be called by diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -242,19 +242,23 @@ static py::function dundeGetItemNamed(const std::string &attributeKind) { auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind); if (!builder) - throw py::key_error(); + throw py::key_error(attributeKind); return *builder; } static void dundeSetItemNamed(const std::string &attributeKind, - py::function func) { - PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func)); + py::function func, bool replace) { + PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func), + replace); } static void bind(py::module &m) { py::class_(m, "AttrBuilder", py::module_local()) .def_static("contains", &PyAttrBuilderMap::dunderContains) .def_static("get", &PyAttrBuilderMap::dundeGetItemNamed) - .def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed); + .def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed, + "attribute_kind"_a, "attr_builder"_a, "replace"_a = false, + "Register an attribute builder for building MLIR " + "attributes from python values."); } }; diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -63,11 +63,13 @@ } void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, - py::function pyFunc) { + py::function pyFunc, bool replace) { py::object &found = attributeBuilderMap[attributeKind]; - if (found) { + if (found && !found.is_none() && !replace) { throw std::runtime_error((llvm::Twine("Attribute builder for '") + - attributeKind + "' is already registered") + attributeKind + + "' is already registered with func: " + + py::str(found).operator std::string()) .str()); } found = std::move(pyFunc); diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -52,7 +52,8 @@ TD_FILE dialects/AMDGPUOps.td SOURCES dialects/amdgpu.py - DIALECT_NAME amdgpu) + DIALECT_NAME amdgpu + GEN_ENUM_BINDINGS) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -109,7 +110,8 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/GPUOps.td SOURCES_GLOB dialects/gpu/*.py - DIALECT_NAME gpu) + DIALECT_NAME gpu + GEN_ENUM_BINDINGS) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -120,7 +122,17 @@ SOURCES_GLOB dialects/linalg/*.py DIALECT_NAME linalg - DEPENDS LinalgOdsGen) + DEPENDS LinalgOdsGen + GEN_ENUM_BINDINGS) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/LLVMOps.td + SOURCES + dialects/llvm.py + DIALECT_NAME llvm + GEN_ENUM_BINDINGS) declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -140,16 +152,8 @@ dialects/_transform_ops_ext.py dialects/transform/__init__.py _mlir_libs/_mlir/dialects/transform/__init__.pyi - DIALECT_NAME transform) - -set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/TransformOps.td") -mlir_tablegen("dialects/_transform_enum_gen.py" -gen-python-enum-bindings) -add_public_tablegen_target(MLIRTransformDialectPyEnumGen) -declare_mlir_python_sources( - MLIRPythonSources.Dialects.transform.enum_gen - ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" - ADD_TO_PARENT MLIRPythonSources.Dialects.transform - SOURCES "dialects/_transform_enum_gen.py") + DIALECT_NAME transform + GEN_ENUM_BINDINGS) declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -159,16 +163,8 @@ dialects/_bufferization_transform_ops_ext.py dialects/transform/bufferization.py DIALECT_NAME transform - EXTENSION_NAME bufferization_transform) - -set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/BufferizationTransformOps.td") -mlir_tablegen("dialects/_bufferization_transform_enum_gen.py" -gen-python-enum-bindings) -add_public_tablegen_target(MLIRBufferizationTransformDialectPyEnumGen) -declare_mlir_python_sources( - MLIRPythonSources.Dialects.bufferization_transform.enum_gen - ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" - ADD_TO_PARENT MLIRPythonSources.Dialects.bufferization_transform - SOURCES "dialects/_bufferization_transform_enum_gen.py") + EXTENSION_NAME bufferization_transform + GEN_ENUM_BINDINGS) declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -208,7 +204,8 @@ dialects/_structured_transform_ops_ext.py dialects/transform/structured.py DIALECT_NAME transform - EXTENSION_NAME structured_transform) + EXTENSION_NAME structured_transform + GEN_ENUM_BINDINGS_TD_FILE dialects/LinalgTransformEnums.td) declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -227,16 +224,8 @@ SOURCES dialects/transform/vector.py DIALECT_NAME transform - EXTENSION_NAME vector_transform) - -set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/VectorTransformOps.td") -mlir_tablegen("dialects/_vector_transform_enum_gen.py" -gen-python-enum-bindings) -add_public_tablegen_target(MLIRVectorTransformPyEnumGen) -declare_mlir_python_sources( - MLIRPythonSources.Dialects.vector_transform.enum_gen - ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" - ADD_TO_PARENT MLIRPythonSources.Dialects.vector_transform - SOURCES "dialects/_vector_transform_enum_gen.py" ) + EXTENSION_NAME vector_transform + GEN_ENUM_BINDINGS) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -252,7 +241,8 @@ SOURCES dialects/arith.py dialects/_arith_ops_ext.py - DIALECT_NAME arith) + DIALECT_NAME arith + GEN_ENUM_BINDINGS) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -278,7 +268,8 @@ TD_FILE dialects/NVGPUOps.td SOURCES dialects/nvgpu.py - DIALECT_NAME nvgpu) + DIALECT_NAME nvgpu + GEN_ENUM_BINDINGS) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -286,7 +277,8 @@ TD_FILE dialects/NVVMOps.td SOURCES dialects/nvvm.py - DIALECT_NAME nvvm) + DIALECT_NAME nvvm + GEN_ENUM_BINDINGS) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -300,6 +292,7 @@ MLIRPythonSources.Dialects.quant ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + GEN_ENUM_BINDINGS SOURCES dialects/quant.py _mlir_libs/_mlir/dialects/quant.pyi) @@ -335,7 +328,8 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/SparseTensorOps.td SOURCES dialects/sparse_tensor.py - DIALECT_NAME sparse_tensor) + DIALECT_NAME sparse_tensor + GEN_ENUM_BINDINGS) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -358,7 +352,8 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/VectorOps.td SOURCES dialects/vector.py - DIALECT_NAME vector) + DIALECT_NAME vector + GEN_ENUM_BINDINGS) ################################################################################ # Python extensions. diff --git a/mlir/python/mlir/dialects/LLVMOps.td b/mlir/python/mlir/dialects/LLVMOps.td new file mode 100644 --- /dev/null +++ b/mlir/python/mlir/dialects/LLVMOps.td @@ -0,0 +1,14 @@ +//===-- LlvmOps.td - Entry point for llvm bind ---------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_LLVM_OPS +#define PYTHON_BINDINGS_LLVM_OPS + +include "mlir/Dialect/LLVMIR/LLVMOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/LinalgTransformEnums.td b/mlir/python/mlir/dialects/LinalgTransformEnums.td new file mode 100644 --- /dev/null +++ b/mlir/python/mlir/dialects/LinalgTransformEnums.td @@ -0,0 +1,14 @@ +//===-- LinalgTransformEnums.td --------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_LINALG_TRANSFORM_ENUMS +#define PYTHON_BINDINGS_LINALG_TRANSFORM_ENUMS + +include "mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td" + +#endif // PYTHON_BINDINGS_LINALG_TRANSFORM_ENUMS diff --git a/mlir/python/mlir/dialects/amdgpu.py b/mlir/python/mlir/dialects/amdgpu.py --- a/mlir/python/mlir/dialects/amdgpu.py +++ b/mlir/python/mlir/dialects/amdgpu.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._amdgpu_ops_gen import * +from ._amdgpu_enum_gen import * diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py --- a/mlir/python/mlir/dialects/arith.py +++ b/mlir/python/mlir/dialects/arith.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._arith_ops_gen import * +from ._arith_enum_gen import * diff --git a/mlir/python/mlir/dialects/gpu/__init__.py b/mlir/python/mlir/dialects/gpu/__init__.py --- a/mlir/python/mlir/dialects/gpu/__init__.py +++ b/mlir/python/mlir/dialects/gpu/__init__.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .._gpu_ops_gen import * +from .._gpu_enum_gen import * diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py --- a/mlir/python/mlir/dialects/linalg/__init__.py +++ b/mlir/python/mlir/dialects/linalg/__init__.py @@ -9,6 +9,7 @@ # definitions following these steps: # DSL -> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py. from .._linalg_ops_gen import * +from .._linalg_enum_gen import * # These are the ground truth functions defined as: # ``` diff --git a/mlir/python/mlir/dialects/amdgpu.py b/mlir/python/mlir/dialects/llvm.py copy from mlir/python/mlir/dialects/amdgpu.py copy to mlir/python/mlir/dialects/llvm.py --- a/mlir/python/mlir/dialects/amdgpu.py +++ b/mlir/python/mlir/dialects/llvm.py @@ -2,4 +2,5 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from ._amdgpu_ops_gen import * +from ._llvm_ops_gen import * +from ._llvm_enum_gen import * diff --git a/mlir/python/mlir/dialects/nvgpu.py b/mlir/python/mlir/dialects/nvgpu.py --- a/mlir/python/mlir/dialects/nvgpu.py +++ b/mlir/python/mlir/dialects/nvgpu.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._nvgpu_ops_gen import * +from ._nvgpu_enum_gen import * diff --git a/mlir/python/mlir/dialects/nvvm.py b/mlir/python/mlir/dialects/nvvm.py --- a/mlir/python/mlir/dialects/nvvm.py +++ b/mlir/python/mlir/dialects/nvvm.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._nvvm_ops_gen import * +from ._nvvm_enum_gen import * diff --git a/mlir/python/mlir/dialects/sparse_tensor.py b/mlir/python/mlir/dialects/sparse_tensor.py --- a/mlir/python/mlir/dialects/sparse_tensor.py +++ b/mlir/python/mlir/dialects/sparse_tensor.py @@ -3,5 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._sparse_tensor_ops_gen import * +from ._sparse_tensor_enum_gen import * from .._mlir_libs._mlirDialectsSparseTensor import * from .._mlir_libs import _mlirSparseTensorPasses as _cextSparseTensorPasses diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py --- a/mlir/python/mlir/dialects/transform/structured.py +++ b/mlir/python/mlir/dialects/transform/structured.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .._structured_transform_ops_gen import * +from .._structured_transform_enum_gen import * diff --git a/mlir/python/mlir/dialects/vector.py b/mlir/python/mlir/dialects/vector.py --- a/mlir/python/mlir/dialects/vector.py +++ b/mlir/python/mlir/dialects/vector.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._vector_ops_gen import * +from ._vector_enum_gen import * diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -8,9 +8,9 @@ # Convenience decorator for registering user-friendly Attribute builders. -def register_attribute_builder(kind): +def register_attribute_builder(kind, replace=False): def decorator_builder(func): - AttrBuilder.insert(kind, func) + AttrBuilder.insert(kind, func, replace=replace) return func return decorator_builder diff --git a/mlir/test/mlir-tblgen/enums-python-bindings.td b/mlir/test/mlir-tblgen/enums-python-bindings.td --- a/mlir/test/mlir-tblgen/enums-python-bindings.td +++ b/mlir/test/mlir-tblgen/enums-python-bindings.td @@ -2,56 +2,63 @@ include "mlir/IR/EnumAttr.td" +def Test_Dialect : Dialect { + let name = "TestDialect"; + let cppNamespace = "::test"; +} + // CHECK: Autogenerated by mlir-tblgen; don't manually edit. -// CHECK: from enum import Enum +// CHECK: from enum import IntEnum, Enum, auto // CHECK: from ._ods_common import _cext as _ods_cext +// CHECK: from ..ir import register_attribute_builder // CHECK: _ods_ir = _ods_cext.ir def One : I32EnumAttrCase<"CaseOne", 1, "one">; def Two : I32EnumAttrCase<"CaseTwo", 2, "two">; def MyEnum : I32EnumAttr<"MyEnum", "An example 32-bit enum", [One, Two]>; -// CHECK: def _register_attribute_builder(kind): -// CHECK: def decorator_builder(func): -// CHECK: _ods_ir.AttrBuilder.insert(kind, func) -// CHECK: return func -// CHECK: return decorator_builder - -// CHECK-LABEL: class MyEnum(Enum): +// CHECK-LABEL: class MyEnum(IntEnum): // CHECK: """An example 32-bit enum""" -// CHECK: CASE_ONE = 1 -// CHECK: CASE_TWO = 2 +// CHECK: CaseOne = 1 +// CHECK: CaseTwo = 2 -// CHECK: def _as_int(self): -// CHECK: if self is MyEnum.CASE_ONE: -// CHECK: return 1 -// CHECK: if self is MyEnum.CASE_TWO: -// CHECK: return 2 +// CHECK: def __str__(self): +// CHECK: if self is MyEnum.CaseOne: +// CHECK: return "one" +// CHECK: if self is MyEnum.CaseTwo: +// CHECK: return "two" // CHECK: assert False, "Unknown MyEnum enum entry." +// CHECK: @register_attribute_builder("MyEnum") +// CHECK: def _my_enum(x, context): +// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x)) + +def TestMyEnum_Attr : EnumAttr; + def One64 : I64EnumAttrCase<"CaseOne64", 1, "one">; def Two64 : I64EnumAttrCase<"CaseTwo64", 2, "two">; def MyEnum64 : I64EnumAttr<"MyEnum64", "An example 64-bit enum", [One64, Two64]>; -// CHECK: @_register_attribute_builder("MyEnum") -// CHECK: def _my_enum(x, context): -// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), x._as_int()) - -// CHECK-LABEL: class MyEnum64(Enum): +// CHECK-LABEL: class MyEnum64(IntEnum): // CHECK: """An example 64-bit enum""" -// CHECK: CASE_ONE64 = 1 -// CHECK: CASE_TWO64 = 2 +// CHECK: CaseOne64 = 1 +// CHECK: CaseTwo64 = 2 -// CHECK: def _as_int(self): -// CHECK: if self is MyEnum64.CASE_ONE64: -// CHECK: return 1 -// CHECK: if self is MyEnum64.CASE_TWO64: -// CHECK: return 2 +// CHECK: def __str__(self): +// CHECK: if self is MyEnum64.CaseOne64: +// CHECK: return "one" +// CHECK: if self is MyEnum64.CaseTwo64: +// CHECK: return "two" // CHECK: assert False, "Unknown MyEnum64 enum entry." -// CHECK: @_register_attribute_builder("MyEnum64") +// CHECK: @register_attribute_builder("MyEnum64") // CHECK: def _my_enum64(x, context): -// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), x._as_int()) +// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), int(x)) + +// CHECK: @register_attribute_builder("TestMyEnum_Attr") +// CHECK: def _test_my_enum_attr(x, context): +// CHECK: return _ods_ir.Attribute.parse(f'#TestDialect', context=context) + diff --git a/mlir/test/python/dialects/gpu.py b/mlir/test/python/dialects/gpu.py --- a/mlir/test/python/dialects/gpu.py +++ b/mlir/test/python/dialects/gpu.py @@ -1,22 +1,32 @@ # RUN: %PYTHON %s | FileCheck %s from mlir.ir import * -import mlir.dialects.gpu +import mlir.dialects.gpu as gpu import mlir.dialects.gpu.passes from mlir.passmanager import * def run(f): print("\nTEST:", f.__name__) - f() + with Context(), Location.unknown(): + f() + return f +# CHECK-LABEL: testGPUPass +# CHECK: SUCCESS +@run def testGPUPass(): - with Context() as context: - PassManager.parse("any(gpu-kernel-outlining)") + PassManager.parse("any(gpu-kernel-outlining)") print("SUCCESS") -# CHECK-LABEL: testGPUPass -# CHECK: SUCCESS -run(testGPUPass) +# CHECK-LABEL: testMMAElementWiseAttr +@run +def testMMAElementWiseAttr(): + module = Module.create() + with InsertionPoint(module.body): + gpu.BlockDimOp(gpu.Dimension.y) + # CHECK: %0 = gpu.block_dim y + print(module) + pass diff --git a/mlir/test/python/dialects/nvvm.py b/mlir/test/python/dialects/llvm.py copy from mlir/test/python/dialects/nvvm.py copy to mlir/test/python/dialects/llvm.py --- a/mlir/test/python/dialects/nvvm.py +++ b/mlir/test/python/dialects/llvm.py @@ -2,7 +2,7 @@ # This is just a smoke test that the dialect is functional. from mlir.ir import * -from mlir.dialects import nvvm +from mlir.dialects import llvm def constructAndPrintInModule(f): @@ -18,5 +18,8 @@ # CHECK-LABEL: testSmoke @constructAndPrintInModule def testSmoke(): - # CHECK: nvvm.cp.async.wait.group 5 - nvvm.CpAsyncWaitGroupOp(5) + mat64f32_t = Type.parse( + "!llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>" + ) + result = llvm.UndefOp(mat64f32_t) + # CHECK: %0 = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> diff --git a/mlir/test/python/dialects/nvvm.py b/mlir/test/python/dialects/nvvm.py --- a/mlir/test/python/dialects/nvvm.py +++ b/mlir/test/python/dialects/nvvm.py @@ -3,6 +3,8 @@ from mlir.ir import * from mlir.dialects import nvvm +from mlir.dialects import llvm +from mlir.dialects import func def constructAndPrintInModule(f): @@ -18,5 +20,30 @@ # CHECK-LABEL: testSmoke @constructAndPrintInModule def testSmoke(): - # CHECK: nvvm.cp.async.wait.group 5 - nvvm.CpAsyncWaitGroupOp(5) + i64 = IntegerType.get_signless(64) + mat64f32_t = Type.parse( + "!llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>" + ) + shape_attr = Attribute.parse("#nvvm.shape") + # CHECK-LABEL: func @wgmma_f32_f16_f16(%arg0: i64, %arg1: i64) + @func.FuncOp.from_py_func(i64, i64) + def wgmma_f32_f16_f16(desc_a, desc_b): + # CHECK: nvvm.cp.async.wait.group 5 + nvvm.CpAsyncWaitGroupOp(5) + # CHECK: %0 = llvm.mlir.undef : [[MAT_T:.*]] + result = llvm.UndefOp(mat64f32_t) + # CHECK: %1 = nvvm.wgmma.mma_async %arg0, %arg1, , D[%0, ], A[, , ], B[, , ] : [[MAT_T]] -> [[MAT_T]] + result1 = nvvm.WgmmaMmaAsyncOp( + results_=mat64f32_t, + inouts=result, + descriptorA=desc_a, + descriptorB=desc_b, + shape=shape_attr, + typeA=nvvm.WGMMATypes.f16, + typeB=nvvm.WGMMATypes.f16, + scaleD=nvvm.WGMMAScaleOut.zero, + scaleA=nvvm.WGMMAScaleIn.neg, + scaleB=nvvm.WGMMAScaleIn.neg, + layoutA=nvvm.MMALayout.col, + layoutB=nvvm.MMALayout.col, + ) diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py --- a/mlir/test/python/dialects/transform.py +++ b/mlir/test/python/dialects/transform.py @@ -32,7 +32,7 @@ @run def testSequenceOp(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, + transform.FailurePropagationMode.Propagate, [transform.AnyOpType.get()], transform.AnyOpType.get(), ) @@ -48,15 +48,15 @@ @run def testNestedSequenceOp(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) with InsertionPoint(sequence.body): nested = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], sequence.bodyTarget + transform.FailurePropagationMode.Propagate, [], sequence.bodyTarget ) with InsertionPoint(nested.body): doubly_nested = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, + transform.FailurePropagationMode.Propagate, [transform.AnyOpType.get()], nested.bodyTarget, ) @@ -80,7 +80,7 @@ @run def testSequenceOpWithExtras(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get(), [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")], @@ -95,14 +95,14 @@ @run def testNestedSequenceOpWithExtras(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get(), [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")], ) with InsertionPoint(sequence.body): nested = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, + transform.FailurePropagationMode.Propagate, [], sequence.bodyTarget, sequence.bodyExtraArgs, @@ -121,7 +121,7 @@ withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get()) with InsertionPoint(withPdl.body): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, + transform.FailurePropagationMode.Propagate, [transform.AnyOpType.get()], withPdl.bodyTarget, ) @@ -144,7 +144,7 @@ @run def testGetParentOp(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) with InsertionPoint(sequence.body): transform.GetParentOp( @@ -160,7 +160,7 @@ @run def testMergeHandlesOp(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) with InsertionPoint(sequence.body): transform.MergeHandlesOp([sequence.bodyTarget]) @@ -174,7 +174,7 @@ @run def testApplyPatternsOpCompact(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) with InsertionPoint(sequence.body): with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns): @@ -189,7 +189,7 @@ @run def testApplyPatternsOpWithType(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], + transform.FailurePropagationMode.Propagate, [], transform.OperationType.get('test.dummy') ) with InsertionPoint(sequence.body): @@ -207,7 +207,7 @@ with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get()) with InsertionPoint(with_pdl.body): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget + transform.FailurePropagationMode.Propagate, [], with_pdl.bodyTarget ) with InsertionPoint(sequence.body): m1 = transform_pdl.PDLMatchOp( diff --git a/mlir/test/python/dialects/transform_bufferization_ext.py b/mlir/test/python/dialects/transform_bufferization_ext.py --- a/mlir/test/python/dialects/transform_bufferization_ext.py +++ b/mlir/test/python/dialects/transform_bufferization_ext.py @@ -18,7 +18,7 @@ @run def testEmptyTensorToAllocTensorOpCompact(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, + transform.FailurePropagationMode.Propagate, [], transform.OperationType.get("tensor.empty"), ) @@ -33,7 +33,7 @@ @run def testEmptyTensorToAllocTensorOpTyped(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, + transform.FailurePropagationMode.Propagate, [], transform.OperationType.get("tensor.empty"), ) @@ -51,7 +51,7 @@ @run def testOneShotBufferizeOpCompact(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) with InsertionPoint(sequence.body): bufferization.OneShotBufferizeOp(sequence.bodyTarget) @@ -64,7 +64,7 @@ @run def testOneShotBufferizeOpTyped(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) with InsertionPoint(sequence.body): bufferization.OneShotBufferizeOp( @@ -80,7 +80,7 @@ @run def testOneShotBufferizeOpAttributes(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) with InsertionPoint(sequence.body): bufferization.OneShotBufferizeOp( @@ -89,7 +89,7 @@ allow_unknown_ops=True, bufferize_function_boundaries=True, create_deallocs=False, - function_boundary_type_conversion=bufferization.LayoutMapOption.IDENTITY_LAYOUT_MAP, + function_boundary_type_conversion=bufferization.LayoutMapOption.IdentityLayoutMap, memcpy_op="linalg.copy", print_conflicts=True, test_analysis_only=True, diff --git a/mlir/test/python/dialects/transform_gpu_ext.py b/mlir/test/python/dialects/transform_gpu_ext.py --- a/mlir/test/python/dialects/transform_gpu_ext.py +++ b/mlir/test/python/dialects/transform_gpu_ext.py @@ -10,7 +10,7 @@ module = Module.create() with InsertionPoint(module.body): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get(), ) diff --git a/mlir/test/python/dialects/transform_loop_ext.py b/mlir/test/python/dialects/transform_loop_ext.py --- a/mlir/test/python/dialects/transform_loop_ext.py +++ b/mlir/test/python/dialects/transform_loop_ext.py @@ -19,7 +19,7 @@ @run def getParentLoop(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get() + transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() ) with InsertionPoint(sequence.body): loop.GetParentForOp( @@ -34,7 +34,7 @@ @run def loopOutline(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, + transform.FailurePropagationMode.Propagate, [], transform.OperationType.get("scf.for"), ) @@ -54,7 +54,7 @@ @run def loopPeel(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, + transform.FailurePropagationMode.Propagate, [], transform.OperationType.get("scf.for"), ) @@ -68,7 +68,7 @@ @run def loopPipeline(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, + transform.FailurePropagationMode.Propagate, [], transform.OperationType.get("scf.for"), ) @@ -86,7 +86,7 @@ @run def loopUnroll(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, + transform.FailurePropagationMode.Propagate, [], transform.OperationType.get("scf.for"), ) diff --git a/mlir/test/python/dialects/transform_memref_ext.py b/mlir/test/python/dialects/transform_memref_ext.py --- a/mlir/test/python/dialects/transform_memref_ext.py +++ b/mlir/test/python/dialects/transform_memref_ext.py @@ -19,7 +19,7 @@ @run def testMemRefMultiBufferOpCompact(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, + transform.FailurePropagationMode.Propagate, [], transform.OperationType.get("memref.alloc"), ) @@ -35,7 +35,7 @@ @run def testMemRefMultiBufferOpTyped(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, + transform.FailurePropagationMode.Propagate, [], transform.OperationType.get("memref.alloc"), ) @@ -53,7 +53,7 @@ @run def testMemRefMultiBufferOpAttributes(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, + transform.FailurePropagationMode.Propagate, [], transform.OperationType.get("memref.alloc"), ) diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -21,7 +21,7 @@ @run def testDecompose(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get() + transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() ) with InsertionPoint(sequence.body): structured.DecomposeOp(sequence.bodyTarget) @@ -34,7 +34,7 @@ @run def testFuseIntoContainingOpTypes(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) with InsertionPoint(sequence.body): fused = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"]) @@ -56,7 +56,7 @@ @run def testFuseIntoContainingOpCompact(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) with InsertionPoint(sequence.body): fused = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"]) @@ -73,7 +73,7 @@ @run def testGeneralize(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get() + transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() ) with InsertionPoint(sequence.body): structured.GeneralizeOp(sequence.bodyTarget) @@ -86,7 +86,7 @@ @run def testInterchange(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get() + transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() ) with InsertionPoint(sequence.body): structured.InterchangeOp(sequence.bodyTarget, iterator_interchange=[1, 0]) @@ -100,7 +100,7 @@ @run def testMapCopyToThreadsOpCompact(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) with InsertionPoint(sequence.body): structured.MapCopyToThreadsOp( @@ -117,7 +117,7 @@ @run def testMapCopyToThreadsOpTypes(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) with InsertionPoint(sequence.body): structured.MapCopyToThreadsOp( @@ -138,7 +138,7 @@ @run def testMatchOpNamesString(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) with InsertionPoint(sequence.body): structured.MatchOp.match_op_names(sequence.bodyTarget, "test.dummy") @@ -152,7 +152,7 @@ @run def testMatchOpNamesList(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) with InsertionPoint(sequence.body): structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"]) @@ -166,7 +166,7 @@ @run def testMatchOpNamesTyped(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) with InsertionPoint(sequence.body): structured.MatchOp.match_op_names( @@ -184,7 +184,7 @@ @run def testMultitileSizes(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get() + transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() ) with InsertionPoint(sequence.body): structured.MultiTileSizesOp( @@ -201,7 +201,7 @@ @run def testPad(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get() + transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() ) with InsertionPoint(sequence.body): structured.PadOp( @@ -223,7 +223,7 @@ @run def testScalarize(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get() + transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() ) with InsertionPoint(sequence.body): structured.ScalarizeOp(sequence.bodyTarget) @@ -235,7 +235,7 @@ @run def testSplit(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get() + transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() ) with InsertionPoint(sequence.body): split = structured.SplitOp(sequence.bodyTarget, dimension=1, split_point=42) @@ -249,7 +249,7 @@ @run def testTileCompact(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get() + transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() ) with InsertionPoint(sequence.body): structured.TileOp(sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1]) @@ -263,7 +263,7 @@ @run def testTileAttributes(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get() + transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() ) attr = DenseI64ArrayAttr.get([4, 8]) ichange = DenseI64ArrayAttr.get([0, 1]) @@ -279,7 +279,7 @@ @run def testTileZero(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get() + transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() ) with InsertionPoint(sequence.body): structured.TileOp( @@ -297,7 +297,7 @@ with_pdl = transform_pdl.WithPDLPatternsOp(pdl.OperationType.get()) with InsertionPoint(with_pdl.body): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget + transform.FailurePropagationMode.Propagate, [], with_pdl.bodyTarget ) with InsertionPoint(sequence.body): m1 = transform_pdl.PDLMatchOp( @@ -317,7 +317,7 @@ @run def testTileExplicitLoopTypeSingle(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) with InsertionPoint(sequence.body): structured.TileOp( @@ -332,7 +332,7 @@ @run def testTileExplicitLoopTypeAll(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) types = [ transform.OperationType.get(x) @@ -350,7 +350,7 @@ @run def testTileToForallCompact(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, + transform.FailurePropagationMode.Propagate, [], transform.OperationType.get("linalg.matmul"), ) @@ -366,7 +366,7 @@ @run def testTileToForallLoopsAndTileOpTypes(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) with InsertionPoint(sequence.body): structured.TileToForallOp( @@ -385,7 +385,7 @@ @run def testTileToForallTileSizes(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) with InsertionPoint(sequence.body): structured.TileToForallOp(sequence.bodyTarget, tile_sizes=[2, 3, 4]) @@ -398,7 +398,7 @@ @run def testTileToForallMixedDynamic(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) with InsertionPoint(sequence.body): n = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"]) @@ -412,7 +412,7 @@ @run def testTileToForallPackedDynamic(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) with InsertionPoint(sequence.body): n = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"]) @@ -426,7 +426,7 @@ @run def testTileToForallMapping(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) with InsertionPoint(sequence.body): mapping = Attribute.parse("[ #gpu.thread, #gpu.thread ]") @@ -442,7 +442,7 @@ @run def testVectorize(): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get() + transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() ) with InsertionPoint(sequence.body): structured.VectorizeOp(sequence.bodyTarget, vectorize_padding=True) @@ -451,3 +451,53 @@ # CHECK: transform.sequence # CHECK: = transform.structured.vectorize # CHECK: {vectorize_padding} + + +@run +def testMatchInterfaceEnum(): + names = ArrayAttr.get([StringAttr.get("test.dummy")]) + result_type = transform.AnyOpType.get() + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() + ) + with InsertionPoint(sequence.body): + fused = structured.MatchOp.__base__( + result_type, + sequence.bodyTarget, + ops=names, + interface=structured.MatchInterfaceEnum.LinalgOp, + ) + transform.YieldOp() + # CHECK-LABEL: TEST: testMatchInterfaceEnum + # CHECK: transform.sequence + # CHECK: = transform.structured.match + # CHECK: interface{LinalgOp} + + +@run +def testMatchInterfaceEnumReplaceAttributeBuilder(): + @register_attribute_builder("MatchInterfaceEnum", replace=True) + def match_interface_enum(x, context): + if x == "LinalgOp": + y = 0 + elif x == "TilingInterface": + y = 1 + return IntegerAttr.get(IntegerType.get_signless(32, context=context), y) + + names = ArrayAttr.get([StringAttr.get("test.dummy")]) + result_type = transform.AnyOpType.get() + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() + ) + with InsertionPoint(sequence.body): + fused = structured.MatchOp.__base__( + result_type, + sequence.bodyTarget, + ops=names, + interface="TilingInterface", + ) + transform.YieldOp() + # CHECK-LABEL: TEST: testMatchInterfaceEnumReplaceAttributeBuilder + # CHECK: transform.sequence + # CHECK: = transform.structured.match + # CHECK: interface{TilingInterface} diff --git a/mlir/test/python/dialects/transform_tensor_ext.py b/mlir/test/python/dialects/transform_tensor_ext.py --- a/mlir/test/python/dialects/transform_tensor_ext.py +++ b/mlir/test/python/dialects/transform_tensor_ext.py @@ -11,7 +11,7 @@ module = Module.create() with InsertionPoint(module.body): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get(), ) diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py --- a/mlir/test/python/dialects/transform_vector_ext.py +++ b/mlir/test/python/dialects/transform_vector_ext.py @@ -10,7 +10,7 @@ module = Module.create() with InsertionPoint(module.body): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get(), ) @@ -72,12 +72,12 @@ # CHECK: transform.apply_patterns.vector.lower_contraction # CHECK-SAME: lowering_strategy = matmulintrinsics vector.ApplyLowerContractionPatternsOp( - lowering_strategy=vector.VectorContractLowering.MATMUL + lowering_strategy=vector.VectorContractLowering.Matmul ) # CHECK: transform.apply_patterns.vector.lower_contraction # CHECK-SAME: lowering_strategy = parallelarith vector.ApplyLowerContractionPatternsOp( - lowering_strategy=vector.VectorContractLowering.PARALLEL_ARITH + lowering_strategy=vector.VectorContractLowering.ParallelArith ) # CHECK: transform.apply_patterns.vector.lower_multi_reduction @@ -85,12 +85,12 @@ # CHECK: transform.apply_patterns.vector.lower_multi_reduction # This is the default mode, not printed. vector.ApplyLowerMultiReductionPatternsOp( - lowering_strategy=vector.VectorMultiReductionLowering.INNER_PARALLEL + lowering_strategy=vector.VectorMultiReductionLowering.InnerParallel ) # CHECK: transform.apply_patterns.vector.lower_multi_reduction # CHECK-SAME: lowering_strategy = innerreduction vector.ApplyLowerMultiReductionPatternsOp( - lowering_strategy=vector.VectorMultiReductionLowering.INNER_REDUCTION + lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction ) # CHECK: transform.apply_patterns.vector.lower_transpose @@ -101,31 +101,31 @@ # CHECK-SAME: lowering_strategy = eltwise # CHECK-SAME: avx2_lowering_strategy = false vector.ApplyLowerTransposePatternsOp( - lowering_strategy=vector.VectorTransposeLowering.ELT_WISE + lowering_strategy=vector.VectorTransposeLowering.EltWise ) # CHECK: transform.apply_patterns.vector.lower_transpose # CHECK-SAME: lowering_strategy = flat_transpose # CHECK-SAME: avx2_lowering_strategy = false vector.ApplyLowerTransposePatternsOp( - lowering_strategy=vector.VectorTransposeLowering.FLAT + lowering_strategy=vector.VectorTransposeLowering.Flat ) # CHECK: transform.apply_patterns.vector.lower_transpose # CHECK-SAME: lowering_strategy = shuffle_1d # CHECK-SAME: avx2_lowering_strategy = false vector.ApplyLowerTransposePatternsOp( - lowering_strategy=vector.VectorTransposeLowering.SHUFFLE1_D + lowering_strategy=vector.VectorTransposeLowering.Shuffle1D ) # CHECK: transform.apply_patterns.vector.lower_transpose # CHECK-SAME: lowering_strategy = shuffle_16x16 # CHECK-SAME: avx2_lowering_strategy = false vector.ApplyLowerTransposePatternsOp( - lowering_strategy=vector.VectorTransposeLowering.SHUFFLE16X16 + lowering_strategy=vector.VectorTransposeLowering.Shuffle16x16 ) # CHECK: transform.apply_patterns.vector.lower_transpose # CHECK-SAME: lowering_strategy = flat_transpose # CHECK-SAME: avx2_lowering_strategy = true vector.ApplyLowerTransposePatternsOp( - lowering_strategy=vector.VectorTransposeLowering.FLAT, + lowering_strategy=vector.VectorTransposeLowering.Flat, avx2_lowering_strategy=True, ) @@ -134,20 +134,20 @@ # CHECK: transform.apply_patterns.vector.split_transfer_full_partial # CHECK-SAME: split_transfer_strategy = none vector.ApplySplitTransferFullPartialPatternsOp( - split_transfer_strategy=vector.VectorTransferSplit.NONE + split_transfer_strategy=vector.VectorTransferSplit.None_ ) # CHECK: transform.apply_patterns.vector.split_transfer_full_partial # CHECK-SAME: split_transfer_strategy = "vector-transfer" vector.ApplySplitTransferFullPartialPatternsOp( - split_transfer_strategy=vector.VectorTransferSplit.VECTOR_TRANSFER + split_transfer_strategy=vector.VectorTransferSplit.VectorTransfer ) # CHECK: transform.apply_patterns.vector.split_transfer_full_partial # This is the default mode, not printed. vector.ApplySplitTransferFullPartialPatternsOp( - split_transfer_strategy=vector.VectorTransferSplit.LINALG_COPY + split_transfer_strategy=vector.VectorTransferSplit.LinalgCopy ) # CHECK: transform.apply_patterns.vector.split_transfer_full_partial # CHECK-SAME: split_transfer_strategy = "force-in-bounds" vector.ApplySplitTransferFullPartialPatternsOp( - split_transfer_strategy=vector.VectorTransferSplit.FORCE_IN_BOUNDS + split_transfer_strategy=vector.VectorTransferSplit.ForceInBounds ) diff --git a/mlir/test/python/dialects/vector.py b/mlir/test/python/dialects/vector.py --- a/mlir/test/python/dialects/vector.py +++ b/mlir/test/python/dialects/vector.py @@ -64,3 +64,21 @@ # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]] # CHECK-NOT: %[[MASK]] print(module) + + +# CHECK-LABEL: TEST: testBitEnumCombiningKind +@run +def testBitEnumCombiningKind(): + module = Module.create() + with InsertionPoint(module.body): + f32 = F32Type.get() + vector_type = VectorType.get([16], f32) + + @func.FuncOp.from_py_func(vector_type) + def reduction(arg): + v = vector.ReductionOp(f32, vector.CombiningKind.ADD, arg) + return v + + # CHECK: func.func @reduction(%[[VEC:.*]]: vector<16xf32>) -> f32 { + # CHECK: %0 = vector.reduction , %[[VEC]] : vector<16xf32> into f32 + print(module) diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp --- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp @@ -17,6 +17,8 @@ #include "llvm/Support/FormatVariadic.h" #include "llvm/TableGen/Record.h" +#include + using namespace mlir; using namespace mlir::tblgen; @@ -24,48 +26,51 @@ constexpr const char *fileHeader = R"Py( # Autogenerated by mlir-tblgen; don't manually edit. -from enum import Enum +from enum import IntEnum, Enum, auto from ._ods_common import _cext as _ods_cext +from ..ir import register_attribute_builder _ods_ir = _ods_cext.ir -# Convenience decorator for registering user-friendly Attribute builders. -def _register_attribute_builder(kind): - def decorator_builder(func): - _ods_ir.AttrBuilder.insert(kind, func) - return func - - return decorator_builder - )Py"; /// Makes enum case name Python-compatible, i.e. UPPER_SNAKE_CASE. static std::string makePythonEnumCaseName(StringRef name) { - return StringRef(llvm::convertToSnakeFromCamelCase(name)).upper(); + std::set keywords = { + "False", "None", "True", "and", "as", "assert", "async", + "await", "break", "class", "continue", "def", "del", "elif", + "else", "except", "finally", "for", "from", "global", "if", + "import", "in", "is", "lambda", "nonlocal", "not", "or", + "pass", "raise", "return", "try", "while", "with", "yield"}; + if (keywords.count(name.str())) + return (name + "_").str(); + return name.str(); } /// Emits the Python class for the given enum. static void emitEnumClass(StringRef enumName, StringRef description, ArrayRef cases, raw_ostream &os) { - os << llvm::formatv("class {0}(Enum):\n", enumName); + os << llvm::formatv("class {0}(IntEnum):\n", enumName); if (!description.empty()) os << llvm::formatv(" \"\"\"{0}\"\"\"\n", description); os << "\n"; for (const EnumAttrCase &enumCase : cases) { - os << llvm::formatv(" {0} = {1}\n", - makePythonEnumCaseName(enumCase.getSymbol()), - enumCase.getValue()); + os << llvm::formatv( + " {0} = {1}\n", makePythonEnumCaseName(enumCase.getSymbol()), + enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue()) + : "auto()"); } os << "\n"; - os << llvm::formatv(" def _as_int(self):\n"); + os << llvm::formatv(" def __str__(self):\n"); for (const EnumAttrCase &enumCase : cases) { os << llvm::formatv(" if self is {0}.{1}:\n", enumName, makePythonEnumCaseName(enumCase.getSymbol())); - os << llvm::formatv(" return {0}\n", enumCase.getValue()); + os << llvm::formatv(" return \"{0}\"\n", enumCase.getStr()); } os << llvm::formatv(" assert False, \"Unknown {0} enum entry.\"\n\n\n", enumName); + os << "\n"; } /// Attempts to extract the bitwidth B from string "uintB_t" describing the @@ -90,7 +95,7 @@ return true; } - os << llvm::formatv("@_register_attribute_builder(\"{0}\")\n", + os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", enumAttr.getAttrDefName()); os << llvm::formatv( "def _{0}(x, context):\n", @@ -98,28 +103,66 @@ os << llvm::formatv( " return " "_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, " - "context=context), x._as_int())\n\n", + "context=context), int(x))\n\n", bitwidth); return false; } +/// Emits an attribute builder for the given dialect enum attribute to support +/// automatic conversion between enum values and attributes in Python. Returns +/// `false` on success, `true` on failure. +static bool emitDialectEnumAttributeBuilder(StringRef attrDefName, + StringRef formatString, + raw_ostream &os) { + os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", attrDefName); + os << llvm::formatv("def _{0}(x, context):\n", + llvm::convertToSnakeFromCamelCase(attrDefName)); + os << llvm::formatv(" return " + "_ods_ir.Attribute.parse(f'{0}', context=context)\n\n", + formatString); + return false; +} + /// Emits Python bindings for all enums in the record keeper. Returns /// `false` on success, `true` on failure. static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) { os << fileHeader; - std::vector defs = - recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo"); - for (const llvm::Record *def : defs) { - EnumAttr enumAttr(*def); - if (enumAttr.isBitEnum()) { - llvm::errs() << "bit enums not supported\n"; - return true; + for (auto &it : recordKeeper.getDefs()) { + if (!it.second->isSubClassOf("EnumAttr") and + !it.second->isSubClassOf("EnumAttrInfo")) + continue; + if (it.second->isSubClassOf("EnumAttrInfo")) { + EnumAttr enumAttr(*it.second); + emitEnumClass(enumAttr.getEnumClassName(), enumAttr.getSummary(), + enumAttr.getAllCases(), os); + emitAttributeBuilder(enumAttr, os); + } else { + Attribute attr(&*it.second); + StringRef dialect = + attr.getDef().getValueAsDef("dialect")->getValueAsString("name"); + StringRef mnemonic = attr.getDef().getValueAsString("mnemonic"); + StringRef assemblyFormat = + attr.getDef().getValueAsString("assemblyFormat"); + + if ((assemblyFormat != "`<` $value `>`" && assemblyFormat != "$value") || + attr.getDef().getValueAsBit("hasCustomAssemblyFormat")) { + llvm::errs() << "unsupported assembly format for enum"; + return true; + } + + if (assemblyFormat == "`<` $value `>`") { + emitDialectEnumAttributeBuilder( + attr.getAttrDefName(), + llvm::formatv("#{0}.{1}<{{x}>", dialect, mnemonic).str(), os); + } else if (assemblyFormat == "$value") { + emitDialectEnumAttributeBuilder( + attr.getAttrDefName(), + llvm::formatv("#{0}<{1} {{x}>", dialect, mnemonic).str(), os); + } } - emitEnumClass(enumAttr.getEnumClassName(), enumAttr.getSummary(), - enumAttr.getAllCases(), os); - emitAttributeBuilder(enumAttr, os); } + return false; }