Generate Python enum bindings from ODS enum definitions.

This patch adds an mlir-tblgen backend, -gen-python-enum-bindings
(mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp). For every EnumAttrInfo
definition it emits a Python `Enum` subclass with an `_as_int` helper, plus
an attribute builder, registered through `_ods_ir.AttrBuilder.insert`, that
turns an enum value into a signless IntegerAttr of the enum's bitwidth.
Bit enums are rejected with an error.

The CMake and Bazel builds use the backend to generate
dialects/_transform_enum_gen.py from TransformOps.td. The hand-written
FailurePropagationMode enum in transform/__init__.py is replaced by a star
import of the generated module, and _transform_ops_ext.py now passes
failure_propagation_mode through instead of building the IntegerAttr by
hand. The new test mlir/test/mlir-tblgen/enums-python-bindings.td checks the
output for a 32-bit and a 64-bit enum.

diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -134,6 +134,15 @@ _mlir_libs/_mlir/dialects/transform/__init__.pyi DIALECT_NAME transform) +set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/TransformOps.td") +mlir_tablegen("dialects/_transform_enum_gen.py" -gen-python-enum-bindings) +add_public_tablegen_target(MLIRTransformDialectPyEnumGen) +declare_mlir_python_sources( + MLIRPythonSources.Dialects.transform.enum_gen + ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" + ADD_TO_PARENT MLIRPythonSources.Dialects.transform + SOURCES "dialects/_transform_enum_gen.py") + declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py --- a/mlir/python/mlir/dialects/_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_transform_ops_ext.py @@ -15,68 +15,66 @@ class CastOp: - - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - loc=None, - ip=None, - ): - super().__init__( - result_type, _get_op_result_or_value(target), loc=loc, ip=ip - ) + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + loc=None, + ip=None, + ): + super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip) class ApplyPatternsOp: + def __init__( + self, + target: Union[Operation, Value, OpView], + *, + loc=None, + ip=None, + ): + operands = [] + operands.append(_get_op_result_or_value(target)) + super().__init__( + self.build_generic( + attributes={}, + results=[], + operands=operands, + successors=None, + regions=None, + loc=loc, + ip=ip, + ) + ) + self.regions[0].blocks.append() - def __init__( - self, - target: Union[Operation, Value, OpView], - *, - loc=None, - ip=None, - ): - operands = [] -
operands.append(_get_op_result_or_value(target)) - super().__init__( - self.build_generic(attributes={}, - results=[], - operands=operands, - successors=None, - regions=None, - loc=loc, - ip=ip)) - self.regions[0].blocks.append() - - @property - def patterns(self) -> Block: - return self.regions[0].blocks[0] + @property + def patterns(self) -> Block: + return self.regions[0].blocks[0] class testGetParentOp: - - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - isolated_from_above: bool = False, - op_name: Optional[str] = None, - deduplicate: bool = False, - loc=None, - ip=None, - ): - super().__init__( - result_type, - _get_op_result_or_value(target), - isolated_from_above=isolated_from_above, - op_name=op_name, - deduplicate=deduplicate, - loc=loc, - ip=ip, - ) + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + isolated_from_above: bool = False, + op_name: Optional[str] = None, + deduplicate: bool = False, + loc=None, + ip=None, + ): + super().__init__( + result_type, + _get_op_result_or_value(target), + isolated_from_above=isolated_from_above, + op_name=op_name, + deduplicate=deduplicate, + loc=loc, + ip=ip, + ) class MergeHandlesOp: @@ -130,12 +128,6 @@ else None ) root_type = root.type if not isinstance(target, Type) else target - if not isinstance(failure_propagation_mode, Attribute): - failure_propagation_mode_attr = IntegerAttr.get( - IntegerType.get_signless(32), failure_propagation_mode._as_int() - ) - else: - failure_propagation_mode_attr = failure_propagation_mode if extra_bindings is None: extra_bindings = [] @@ -152,7 +144,7 @@ super().__init__( results_=results, - failure_propagation_mode=failure_propagation_mode_attr, + failure_propagation_mode=failure_propagation_mode, root=root, extra_bindings=extra_bindings, ) diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py --- a/mlir/python/mlir/dialects/transform/__init__.py +++
b/mlir/python/mlir/dialects/transform/__init__.py @@ -2,22 +2,6 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from enum import Enum - - -class FailurePropagationMode(Enum): - """Propagation mode for silenceable errors.""" - - PROPAGATE = 1 - SUPPRESS = 2 - - def _as_int(self): - if self is FailurePropagationMode.PROPAGATE: - return 1 - - assert self is FailurePropagationMode.SUPPRESS - return 2 - - +from .._transform_enum_gen import * from .._transform_ops_gen import * from ..._mlir_libs._mlirDialectsTransform import * diff --git a/mlir/test/mlir-tblgen/enums-python-bindings.td b/mlir/test/mlir-tblgen/enums-python-bindings.td new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/enums-python-bindings.td @@ -0,0 +1,57 @@ +// RUN: mlir-tblgen -gen-python-enum-bindings %s -I %S/../../include | FileCheck %s + +include "mlir/IR/EnumAttr.td" + +// CHECK: Autogenerated by mlir-tblgen; don't manually edit. + +// CHECK: from enum import Enum +// CHECK: from ._ods_common import _cext as _ods_cext +// CHECK: _ods_ir = _ods_cext.ir + +def One : I32EnumAttrCase<"CaseOne", 1, "one">; +def Two : I32EnumAttrCase<"CaseTwo", 2, "two">; + +def MyEnum : I32EnumAttr<"MyEnum", "An example 32-bit enum", [One, Two]>; +// CHECK: def _register_attribute_builder(kind): +// CHECK: def decorator_builder(func): +// CHECK: _ods_ir.AttrBuilder.insert(kind, func) +// CHECK: return func +// CHECK: return decorator_builder + +// CHECK-LABEL: class MyEnum(Enum): +// CHECK: """An example 32-bit enum""" + +// CHECK: CASE_ONE = 1 +// CHECK: CASE_TWO = 2 + +// CHECK: def _as_int(self): +// CHECK: if self is MyEnum.CASE_ONE: +// CHECK: return 1 +// CHECK: if self is MyEnum.CASE_TWO: +// CHECK: return 2 +// CHECK: assert False, "Unknown MyEnum enum entry."
+ +def One64 : I64EnumAttrCase<"CaseOne64", 1, "one">; +def Two64 : I64EnumAttrCase<"CaseTwo64", 2, "two">; + +def MyEnum64 : I64EnumAttr<"MyEnum64", "An example 64-bit enum", [One64, Two64]>; +// CHECK: @_register_attribute_builder("MyEnum") +// CHECK: def _my_enum(x, context): +// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), x._as_int()) + +// CHECK-LABEL: class MyEnum64(Enum): +// CHECK: """An example 64-bit enum""" + +// CHECK: CASE_ONE64 = 1 +// CHECK: CASE_TWO64 = 2 + +// CHECK: def _as_int(self): +// CHECK: if self is MyEnum64.CASE_ONE64: +// CHECK: return 1 +// CHECK: if self is MyEnum64.CASE_TWO64: +// CHECK: return 2 +// CHECK: assert False, "Unknown MyEnum64 enum entry." + +// CHECK: @_register_attribute_builder("MyEnum64") +// CHECK: def _my_enum64(x, context): +// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), x._as_int()) diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt --- a/mlir/tools/mlir-tblgen/CMakeLists.txt +++ b/mlir/tools/mlir-tblgen/CMakeLists.txt @@ -14,6 +14,7 @@ DialectGen.cpp DirectiveCommonGen.cpp + EnumPythonBindingGen.cpp EnumsGen.cpp FormatGen.cpp LLVMIRConversionGen.cpp LLVMIRIntrinsicGen.cpp diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp @@ -0,0 +1,130 @@ +//===- EnumPythonBindingGen.cpp - Generator of Python API for ODS enums ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// EnumPythonBindingGen uses ODS specification of MLIR enum attributes to
+// generate the corresponding Python binding classes.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+/// Preamble emitted at the top of every generated Python module.
+constexpr const char *fileHeader = R"Py(
+# Autogenerated by mlir-tblgen; don't manually edit.
+
+from enum import Enum
+from ._ods_common import _cext as _ods_cext
+_ods_ir = _ods_cext.ir
+
+# Convenience decorator for registering user-friendly Attribute builders.
+def _register_attribute_builder(kind):
+    def decorator_builder(func):
+        _ods_ir.AttrBuilder.insert(kind, func)
+        return func
+
+    return decorator_builder
+
+)Py";
+
+/// Makes enum case name Python-compatible, i.e. UPPER_SNAKE_CASE.
+static std::string makePythonEnumCaseName(StringRef name) {
+  return StringRef(llvm::convertToSnakeFromCamelCase(name)).upper();
+}
+
+/// Emits the Python class for the given enum.
+static void emitEnumClass(StringRef enumName, StringRef description,
+                          ArrayRef<EnumAttrCase> cases, raw_ostream &os) {
+  os << llvm::formatv("class {0}(Enum):\n", enumName);
+  if (!description.empty())
+    os << llvm::formatv("    \"\"\"{0}\"\"\"\n", description);
+  os << "\n";
+
+  for (const EnumAttrCase &enumCase : cases) {
+    os << llvm::formatv("    {0} = {1}\n",
+                        makePythonEnumCaseName(enumCase.getSymbol()),
+                        enumCase.getValue());
+  }
+
+  os << "\n";
+  os << llvm::formatv("    def _as_int(self):\n");
+  for (const EnumAttrCase &enumCase : cases) {
+    os << llvm::formatv("        if self is {0}.{1}:\n", enumName,
+                        makePythonEnumCaseName(enumCase.getSymbol()));
+    os << llvm::formatv("            return {0}\n", enumCase.getValue());
+  }
+  os << llvm::formatv(
+      "        assert False, \"Unknown {0} enum entry.\"\n\n\n", enumName);
+}
+
+/// Attempts to extract the bitwidth B from string "uintB_t" describing the
+/// type. This bitwidth information is not readily available in ODS. Returns
+/// `false` on success, `true` on failure.
+static bool extractUIntBitwidth(StringRef uintType, int64_t &bitwidth) {
+  if (!uintType.consume_front("uint"))
+    return true;
+  if (!uintType.consume_back("_t"))
+    return true;
+  return uintType.getAsInteger(/*Radix=*/10, bitwidth);
+}
+
+/// Emits an attribute builder for the given enum attribute to support automatic
+/// conversion between enum values and attributes in Python. Returns
+/// `false` on success, `true` on failure.
+static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
+  int64_t bitwidth;
+  if (extractUIntBitwidth(enumAttr.getUnderlyingType(), bitwidth)) {
+    llvm::errs() << "failed to identify bitwidth of "
+                 << enumAttr.getUnderlyingType() << "\n";
+    return true;
+  }
+
+  os << llvm::formatv("@_register_attribute_builder(\"{0}\")\n",
+                      enumAttr.getAttrDefName());
+  os << llvm::formatv(
+      "def _{0}(x, context):\n",
+      llvm::convertToSnakeFromCamelCase(enumAttr.getAttrDefName()));
+  os << llvm::formatv(
+      "    return "
+      "_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
+      "context=context), x._as_int())\n\n",
+      bitwidth);
+  return false;
+}
+
+/// Emits Python bindings for all enums in the record keeper. Returns
+/// `false` on success, `true` (after reporting to llvm::errs()) on failure.
+static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
+                            raw_ostream &os) {
+  os << fileHeader;
+  for (const llvm::Record *def :
+       recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
+    EnumAttr enumAttr(*def);
+    if (enumAttr.isBitEnum()) {
+      llvm::errs() << "bit enums not supported\n";
+      return true;
+    }
+    emitEnumClass(enumAttr.getEnumClassName(), enumAttr.getSummary(),
+                  enumAttr.getAllCases(), os);
+    if (emitAttributeBuilder(enumAttr, os))
+      return true;
+  }
+  return false;
+}
+
+// Registers the Python enum bindings generator with mlir-tblgen.
+static mlir::GenRegistration
+    genPythonEnumBindings("gen-python-enum-bindings",
+                          "Generate Python bindings for enum attributes",
+                          &emitPythonEnums);
diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel @@ -732,6 +732,25 @@ # Transform dialect and extensions.
##---------------------------------------------------------------------------## + +gentbl_filegroup( + name = "TransformEnumPyGen", + tbl_outs = [ + ( + ["-gen-python-enum-bindings"], + "mlir/dialects/_transform_enum_gen.py", + ), + ], + tblgen = "//mlir:mlir-tblgen", + td_file = "mlir/dialects/TransformOps.td", + deps = [ + "//mlir:CallInterfacesTdFiles", + "//mlir:FunctionInterfacesTdFiles", + "//mlir:OpBaseTdFiles", + "//mlir:TransformDialectTdFiles", + ], +) + gentbl_filegroup( name = "TransformOpsPyGen", tbl_outs = [ @@ -898,6 +917,7 @@ ":MemRefTransformOpsPyGen", ":PDLTransformOpsPyGen", ":StructuredTransformOpsPyGen", + ":TransformEnumPyGen", ":TransformOpsPyGen", ], )