diff --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py --- a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py @@ -10,66 +10,74 @@ from typing import Optional, Union - -def _get_int64_attr(arg: Optional[Union[int, IntegerAttr]], - default_value: int = None): - if isinstance(arg, IntegerAttr): - return arg - - if arg is None: - assert default_value is not None, "must provide default value" - arg = default_value - - return IntegerAttr.get(IntegerType.get_signless(64), arg) +# def _get_int64_attr(arg: Optional[Union[int, IntegerAttr]], default_value: int = None): +# if isinstance(arg, IntegerAttr): +# return arg +# +# if arg is None: +# assert default_value is not None, "must provide default value" +# arg = default_value +# +# return IntegerAttr.get(IntegerType.get_signless(64), arg) class GetParentForOp: """Extension for GetParentForOp.""" - def __init__(self, - result_type: Type, - target: Union[Operation, Value], - *, - num_loops: int = 1, - ip=None, - loc=None): + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + num_loops: Optional[int] = None, + ip=None, + loc=None, + ): + if num_loops is None: + num_loops = 1 super().__init__( result_type, _get_op_result_or_value(target), - num_loops=_get_int64_attr(num_loops, default_value=1), + num_loops=num_loops, ip=ip, - loc=loc) + loc=loc, + ) class LoopOutlineOp: """Extension for LoopOutlineOp.""" - def __init__(self, - result_type: Type, - target: Union[Operation, Value], - *, - func_name: Union[str, StringAttr], - ip=None, - loc=None): + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + func_name: Union[str, StringAttr], + ip=None, + loc=None, + ): super().__init__( result_type, _get_op_result_or_value(target), func_name=(func_name if isinstance(func_name, StringAttr) else StringAttr.get(func_name)), ip=ip, - loc=loc) + loc=loc, + ) class LoopPeelOp: """Extension for LoopPeelOp.""" - def __init__(self, - result_type: Type, - target: Union[Operation, Value], - *, - fail_if_already_divisible: Union[bool, BoolAttr] = False, - ip=None, - loc=None): + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + fail_if_already_divisible: Union[bool, BoolAttr] = False, + ip=None, + loc=None, + ): super().__init__( result_type, _get_op_result_or_value(target), @@ -77,40 +85,51 @@ fail_if_already_divisible, BoolAttr) else BoolAttr.get(fail_if_already_divisible)), ip=ip, - loc=loc) + loc=loc, + ) class LoopPipelineOp: """Extension for LoopPipelineOp.""" - def __init__(self, - result_type: Type, - target: Union[Operation, Value], - *, - iteration_interval: Optional[Union[int, IntegerAttr]] = None, - read_latency: Optional[Union[int, IntegerAttr]] = None, - ip=None, - loc=None): + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + iteration_interval: Optional[Union[int, IntegerAttr]] = None, + read_latency: Optional[Union[int, IntegerAttr]] = None, + ip=None, + loc=None, + ): + if iteration_interval is None: + iteration_interval = 1 + if read_latency is None: + read_latency = 10 super().__init__( result_type, _get_op_result_or_value(target), - iteration_interval=_get_int64_attr(iteration_interval, default_value=1), - read_latency=_get_int64_attr(read_latency, default_value=10), + iteration_interval=iteration_interval, + read_latency=read_latency, ip=ip, - loc=loc) + loc=loc, + ) class LoopUnrollOp: """Extension for LoopUnrollOp.""" - def __init__(self, - target: Union[Operation, Value], - *, - factor: Union[int, IntegerAttr], - ip=None, - loc=None): + def __init__( + self, + target: Union[Operation, Value], + *, + factor: Union[int, IntegerAttr], + ip=None, + loc=None, + ): super().__init__( _get_op_result_or_value(target), - factor=_get_int64_attr(factor), + factor=factor, ip=ip, - loc=loc) + loc=loc, + ) diff --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py --- a/mlir/python/mlir/dialects/_pdl_ops_ext.py +++ b/mlir/python/mlir/dialects/_pdl_ops_ext.py @@ -8,61 +8,27 @@ except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e -from typing import Union, Optional, Sequence, List, Mapping -from ._ods_common import get_op_result_or_value as _get_value, get_op_results_or_values as _get_values - - -def _get_int_attr(bits: int, value: Union[IntegerAttr, int]) -> IntegerAttr: - """Converts the given value to signless integer attribute of given bit width.""" - if isinstance(value, int): - ty = IntegerType.get_signless(bits) - return IntegerAttr.get(ty, value) - else: - return value - - -def _get_array_attr(attrs: Union[ArrayAttr, Sequence[Attribute]]) -> ArrayAttr: - """Converts the given value to array attribute.""" - if isinstance(attrs, ArrayAttr): - return attrs - else: - return ArrayAttr.get(list(attrs)) - - -def _get_str_array_attr(attrs: Union[ArrayAttr, Sequence[str]]) -> ArrayAttr: - """Converts the given value to string array attribute.""" - if isinstance(attrs, ArrayAttr): - return attrs - else: - return ArrayAttr.get([StringAttr.get(s) for s in attrs]) - - -def _get_str_attr(name: Union[StringAttr, str]) -> Optional[StringAttr]: - """Converts the given value to string attribute.""" - if isinstance(name, str): - return StringAttr.get(name) - else: - return name - - -def _get_type_attr(type: Union[TypeAttr, Type]) -> TypeAttr: - """Converts the given value to type attribute.""" - if isinstance(type, Type): - return TypeAttr.get(type) - else: - return type +from typing import Union, Optional, Sequence, Mapping +from ._ods_common import ( + get_op_result_or_value as _get_value, + get_op_results_or_values as _get_values, +) class ApplyNativeConstraintOp: """Specialization for PDL apply native constraint op class.""" - def __init__(self, - name: Union[str, StringAttr], - args: Sequence[Union[OpView, Operation, Value]] = [], - *, - loc=None, - ip=None): - name = _get_str_attr(name) + def __init__( + self, + name: Union[str, StringAttr], + args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + *, + loc=None, + ip=None, + ): + if args is None: + args = [] + name = name args = _get_values(args) super().__init__(name, args, loc=loc, ip=ip) @@ -70,14 +36,18 @@ class ApplyNativeRewriteOp: """Specialization for PDL apply native rewrite op class.""" - def __init__(self, - results: Sequence[Type], - name: Union[str, StringAttr], - args: Sequence[Union[OpView, Operation, Value]] = [], - *, - loc=None, - ip=None): - name = _get_str_attr(name) + def __init__( + self, + results: Sequence[Type], + name: Union[str, StringAttr], + args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + *, + loc=None, + ip=None, + ): + if args is None: + args = [] + name = name args = _get_values(args) super().__init__(results, name, args, loc=loc, ip=ip) @@ -85,12 +55,14 @@ class AttributeOp: """Specialization for PDL attribute op class.""" - def __init__(self, - valueType: Optional[Union[OpView, Operation, Value]] = None, - value: Optional[Attribute] = None, - *, - loc=None, - ip=None): + def __init__( + self, + valueType: Optional[Union[OpView, Operation, Value]] = None, + value: Optional[Attribute] = None, + *, + loc=None, + ip=None, + ): valueType = valueType if valueType is None else _get_value(valueType) result = pdl.AttributeType.get() super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip) @@ -99,11 +71,13 @@ class EraseOp: """Specialization for PDL erase op class.""" - def __init__(self, - operation: Optional[Union[OpView, Operation, Value]] = None, - *, - loc=None, - ip=None): + def __init__( + self, + operation: Optional[Union[OpView, Operation, Value]] = None, + *, + loc=None, + ip=None, + ): operation = _get_value(operation) super().__init__(operation, loc=loc, ip=ip) @@ -111,11 +85,13 @@ class OperandOp: """Specialization for PDL operand op class.""" - def __init__(self, - type: Optional[Union[OpView, Operation, Value]] = None, - *, - loc=None, - ip=None): + def __init__( + self, + type: Optional[Union[OpView, Operation, Value]] = None, + *, + loc=None, + ip=None, + ): type = type if type is None else _get_value(type) result = pdl.ValueType.get() super().__init__(result, valueType=type, loc=loc, ip=ip) @@ -124,11 +100,13 @@ class OperandsOp: """Specialization for PDL operands op class.""" - def __init__(self, - types: Optional[Union[OpView, Operation, Value]] = None, - *, - loc=None, - ip=None): + def __init__( + self, + types: Optional[Union[OpView, Operation, Value]] = None, + *, + loc=None, + ip=None, + ): types = types if types is None else _get_value(types) result = pdl.RangeType.get(pdl.ValueType.get()) super().__init__(result, valueType=types, loc=loc, ip=ip) @@ -137,15 +115,24 @@ class OperationOp: """Specialization for PDL operand op class.""" - def __init__(self, - name: Optional[Union[str, StringAttr]] = None, - args: Sequence[Union[OpView, Operation, Value]] = [], - attributes: Mapping[str, Union[OpView, Operation, Value]] = {}, - types: Sequence[Union[OpView, Operation, Value]] = [], - *, - loc=None, - ip=None): - name = name if name is None else _get_str_attr(name) + def __init__( + self, + name: Optional[Union[str, StringAttr]] = None, + args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + attributes: Optional[Mapping[str, Union[OpView, Operation, + Value]]] = None, + types: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + *, + loc=None, + ip=None, + ): + if types is None: + types = [] + if attributes is None: + attributes = {} + if args is None: + args = [] + name = name if name is None else name args = _get_values(args) attrNames = [] attrValues = [] @@ -155,22 +142,29 @@ attrNames = ArrayAttr.get(attrNames) types = _get_values(types) result = pdl.OperationType.get() - super().__init__(result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip) + super().__init__(result, + args, + attrValues, + attrNames, + types, + opName=name, + loc=loc, + ip=ip) class PatternOp: """Specialization for PDL pattern op class.""" - def __init__(self, - benefit: Union[IntegerAttr, int], - name: Optional[Union[StringAttr, str]] = None, - *, - loc=None, - ip=None): + def __init__( + self, + benefit: Union[IntegerAttr, int], + name: Optional[Union[StringAttr, str]] = None, + *, + loc=None, + ip=None, + ): """Creates an PDL `pattern` operation.""" - name_attr = None if name is None else _get_str_attr(name) - benefit_attr = _get_int_attr(16, benefit) - super().__init__(benefit_attr, sym_name=name_attr, loc=loc, ip=ip) + super().__init__(benefit, sym_name=name, loc=loc, ip=ip) self.regions[0].blocks.append() @property @@ -182,13 +176,17 @@ class ReplaceOp: """Specialization for PDL replace op class.""" - def __init__(self, - op: Union[OpView, Operation, Value], - *, - with_op: Optional[Union[OpView, Operation, Value]] = None, - with_values: Sequence[Union[OpView, Operation, Value]] = [], - loc=None, - ip=None): + def __init__( + self, + op: Union[OpView, Operation, Value], + *, + with_op: Optional[Union[OpView, Operation, Value]] = None, + with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + loc=None, + ip=None, + ): + if with_values is None: + with_values = [] op = _get_value(op) with_op = with_op if with_op is None else _get_value(with_op) with_values = _get_values(with_values) @@ -198,13 +196,14 @@ class ResultOp: """Specialization for PDL result op class.""" - def __init__(self, - parent: Union[OpView, Operation, Value], - index: Union[IntegerAttr, int], - *, - loc=None, - ip=None): - index = _get_int_attr(32, index) + def __init__( + self, + parent: Union[OpView, Operation, Value], + index: Union[IntegerAttr, int], + *, + loc=None, + ip=None, + ): parent = _get_value(parent) result = pdl.ValueType.get() super().__init__(result, parent, index, loc=loc, ip=ip) @@ -213,32 +212,37 @@ class ResultsOp: """Specialization for PDL results op class.""" - def __init__(self, - result: Type, - parent: Union[OpView, Operation, Value], - index: Optional[Union[IntegerAttr, int]] = None, - *, - loc=None, - ip=None): + def __init__( + self, + result: Type, + parent: Union[OpView, Operation, Value], + index: Optional[Union[IntegerAttr, int]] = None, + *, + loc=None, + ip=None, + ): parent = _get_value(parent) - index = index if index is None else _get_int_attr(32, index) super().__init__(result, parent, index=index, loc=loc, ip=ip) class RewriteOp: """Specialization for PDL rewrite op class.""" - def __init__(self, - root: Optional[Union[OpView, Operation, Value]] = None, - name: Optional[Union[StringAttr, str]] = None, - args: Sequence[Union[OpView, Operation, Value]] = [], - *, - loc=None, - ip=None): + def __init__( + self, + root: Optional[Union[OpView, Operation, Value]] = None, + name: Optional[Union[StringAttr, str]] = None, + args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + *, + loc=None, + ip=None, + ): + if args is None: + args = [] root = root if root is None else _get_value(root) - name = name if name is None else _get_str_attr(name) + name = name if name is None else name args = _get_values(args) - super().__init__(args, root=root,name=name, loc=loc, ip=ip) + super().__init__(args, root=root, name=name, loc=loc, ip=ip) def add_body(self): """Add body (block) to the rewrite.""" @@ -259,8 +263,6 @@ *, loc=None, ip=None): - constantType = constantType if constantType is None else _get_type_attr( - constantType) result = pdl.TypeType.get() super().__init__(result, constantType=constantType, loc=loc, ip=ip) @@ -268,13 +270,14 @@ class TypesOp: """Specialization for PDL types op class.""" - def __init__(self, - constantTypes: Sequence[Union[TypeAttr, Type]] = [], - *, - loc=None, - ip=None): - constantTypes = _get_array_attr( - [_get_type_attr(ty) for ty in constantTypes]) - constantTypes = None if not constantTypes else constantTypes + def __init__( + self, + constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None, + *, + loc=None, + ip=None, + ): + if constantTypes is None: + constantTypes = [] result = pdl.RangeType.get(pdl.TypeType.get()) super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -15,180 +15,159 @@ OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]] -def _get_int64_attr(value: Union[int, Attribute]) -> IntegerAttr: - if isinstance(value, int): - return IntegerAttr.get(IntegerType.get_signless(64), value) - return value - - -def _get_array_attr( - values: Optional[Union[ArrayAttr, Sequence[Attribute]]]) -> ArrayAttr: - """Creates an array attribute from its operand.""" - if values is None: - return ArrayAttr.get([]) - if isinstance(values, ArrayAttr): - return values - - return ArrayAttr.get(values) - - -def _get_int_array_attr( - values: Optional[Union[ArrayAttr, Sequence[Union[IntegerAttr, int]]]] -) -> ArrayAttr: - """Creates an integer array attribute from its operand. - - If the operand is already an array attribute, forwards it. Otherwise treats - the operand as a list of attributes or integers, possibly intersperced, to - create a new array attribute containing integer attributes. Expects the - thread-local MLIR context to have been set by the context manager. - """ - if values is None: - return ArrayAttr.get([]) - if isinstance(values, ArrayAttr): - return values - - return ArrayAttr.get([_get_int64_attr(v) for v in values]) - -def _get_dense_int64_array_attr( - values: Sequence[int]) -> DenseI64ArrayAttr: - """Creates a dense integer array from a sequence of integers. - Expects the thread-local MLIR context to have been set by the context - manager. - """ - if values is None: - return DenseI64ArrayAttr.get([]) - return DenseI64ArrayAttr.get(values) - def _get_int_int_array_attr( values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]] ) -> ArrayAttr: """Creates an array attribute containing array attributes of integers. - If the operand is already an array attribute, forwards it. Otherwise treats - the operand as a list of attributes or integers, potentially interpserced, to - create a new array-of-array attribute. Expects the thread-local MLIR context - to have been set by the context manager. - """ + If the operand is already an array attribute, forwards it. Otherwise treats + the operand as a list of attributes or integers, potentially interpserced, to + create a new array-of-array attribute. Expects the thread-local MLIR context + to have been set by the context manager. + """ if values is None: return ArrayAttr.get([]) if isinstance(values, ArrayAttr): return values + if isinstance(values, list): + values = [ + ArrayAttr.get( + [IntegerAttr.get(IntegerType.get_signless(64), v) + for v in value]) + for value in values + ] - return ArrayAttr.get([_get_int_array_attr(value) for value in values]) + return ArrayAttr.get(values) class DecomposeOp: """Specialization for DecomposeOp class.""" def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): - super().__init__( - pdl.OperationType.get(), - _get_op_result_or_value(target), - loc=loc, - ip=ip) + super().__init__(pdl.OperationType.get(), + _get_op_result_or_value(target), + loc=loc, + ip=ip) class GeneralizeOp: """Specialization for GeneralizeOp class.""" def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): - super().__init__( - pdl.OperationType.get(), - _get_op_result_or_value(target), - loc=loc, - ip=ip) + super().__init__(pdl.OperationType.get(), + _get_op_result_or_value(target), + loc=loc, + ip=ip) class InterchangeOp: """Specialization for InterchangeOp class.""" - def __init__(self, - target: Union[Operation, Value], - *, - iterator_interchange: OptionalIntList = None, - loc=None, - ip=None): + def __init__( + self, + target: Union[Operation, Value], + *, + iterator_interchange: OptionalIntList = None, + loc=None, + ip=None, + ): pdl_operation_type = pdl.OperationType.get() - interchange_attr = _get_dense_int64_array_attr(iterator_interchange) super().__init__( pdl_operation_type, _get_op_result_or_value(target), - iterator_interchange=interchange_attr, + iterator_interchange=iterator_interchange, loc=loc, - ip=ip) + ip=ip, + ) class MatchOp: """Specialization for MatchOp class.""" @classmethod - def match_op_names(MatchOp, - target: Union[Operation, Value], - names: Sequence[str], - loc=None, - ip=None): + def match_op_names( + MatchOp, + target: Union[Operation, Value], + names: Sequence[str], + loc=None, + ip=None, + ): pdl_operation_type = pdl.OperationType.get() return MatchOp( pdl_operation_type, _get_op_result_or_value(target), ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))), loc=loc, - ip=ip) + ip=ip, + ) class MultiTileSizesOp: """Specialization for MultitileSizesOp class.""" - def __init__(self, - result_type: Type, - target: Union[Operation, Value], - *, - dimension: Union[int, IntegerAttr], - target_size: Union[int, IntegerAttr], - divisor: Optional[Union[int, IntegerAttr]] = None, - loc=None, - ip=None): + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + dimension: Union[int, IntegerAttr], + target_size: Union[int, IntegerAttr], + divisor: Optional[Optional[Union[int, IntegerAttr]]] = None, + loc=None, + ip=None, + ): + if divisor is None: + divisor = 1 super().__init__( result_type, result_type, result_type, _get_op_result_or_value(target), - dimension=_get_int64_attr(dimension), - target_size=_get_int64_attr(target_size), - divisor=_get_int64_attr(divisor if divisor else 1), + dimension=dimension, + target_size=target_size, + divisor=divisor, loc=loc, - ip=ip) + ip=ip, + ) class PadOp: """Specialization for PadOp class.""" - def __init__(self, - target: Union[Operation, Value], - *, - padding_values: Optional[Union[ArrayAttr, - Sequence[Attribute]]] = None, - padding_dimensions: OptionalIntList = None, - pack_paddings: OptionalIntList = None, - transpose_paddings: Optional[Union[ArrayAttr, Sequence[Union[ - ArrayAttr, IntOrAttrList]]]] = None, - loc=None, - ip=None): + def __init__( + self, + target: Union[Operation, Value], + *, + padding_values: Optional[Optional[Union[ArrayAttr, + Sequence[Attribute]]]] = None, + padding_dimensions: OptionalIntList = None, + pack_paddings: OptionalIntList = None, + transpose_paddings: Optional[Union[ArrayAttr, Sequence[Union[ + ArrayAttr, IntOrAttrList]]]] = None, + loc=None, + ip=None, + ): + if transpose_paddings is None: + transpose_paddings = [] + if pack_paddings is None: + pack_paddings = [] + if padding_dimensions is None: + padding_dimensions = [] + if padding_values is None: + padding_values = [] pdl_operation_type = pdl.OperationType.get() - padding_values_attr = _get_array_attr(padding_values) - padding_dimensions_attr = _get_int_array_attr(padding_dimensions) - pack_paddings_attr = _get_int_array_attr(pack_paddings) transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings) super().__init__( pdl_operation_type, _get_op_result_or_value(target), - padding_values=padding_values_attr, - padding_dimensions=padding_dimensions_attr, - pack_paddings=pack_paddings_attr, + padding_values=padding_values, + padding_dimensions=padding_dimensions, + pack_paddings=pack_paddings, transpose_paddings=transpose_paddings_attr, loc=loc, - ip=ip) + ip=ip, + ) class ScalarizeOp: @@ -196,29 +175,29 @@ def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): pdl_operation_type = pdl.OperationType.get() - super().__init__( - pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip) + super().__init__(pdl_operation_type, + _get_op_result_or_value(target), + loc=loc, + ip=ip) class SplitOp: """Specialization for SplitOp class.""" - def __init__(self, - target: Union[Operation, Value], - dimension: Union[int, Attribute], - split_point: Union[int, Operation, Value, Attribute], - *, - loc=None, - ip=None): - dimension = _get_int64_attr(dimension) + def __init__( + self, + target: Union[Operation, Value], + dimension: Union[int, Attribute], + split_point: Union[int, Operation, Value, Attribute], + *, + loc=None, + ip=None, + ): if isinstance(split_point, int): - split_point = _get_int64_attr(split_point) - - if isinstance(split_point, Attribute): static_split_point = split_point dynamic_split_point = None else: - static_split_point = _get_int64_attr(ShapedType.get_dynamic_size()) + static_split_point = ShapedType.get_dynamic_size() dynamic_split_point = _get_op_result_or_value(split_point) target = _get_op_result_or_value(target) @@ -231,44 +210,53 @@ static_split_point=static_split_point, dynamic_split_point=dynamic_split_point, loc=loc, - ip=ip) + ip=ip, + ) class TileOp: """Specialization for TileOp class.""" @overload - def __init__(self, - loop_types: Union[Type, List[Type]], - target: Union[Operation, Value], - *, - sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, - Value]], ArrayAttr]] = None, - interchange: OptionalIntList = None, - loc=None, - ip=None): + def __init__( + self, + loop_types: Union[Type, List[Type]], + target: Union[Operation, Value], + *, + sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]], + ArrayAttr]] = None, + interchange: OptionalIntList = None, + loc=None, + ip=None, + ): ... @overload - def __init__(self, - target: Union[Operation, Value, OpView], - *, - sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, - Value]], ArrayAttr]] = None, - interchange: OptionalIntList = None, - loc=None, - ip=None): + def __init__( + self, + target: Union[Operation, Value, OpView], + *, + sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]], + ArrayAttr]] = None, + interchange: OptionalIntList = None, + loc=None, + ip=None, + ): ... - def __init__(self, - loop_types_or_target: Union[Type, List[Type], Operation, Value], - target_or_none: Optional[Union[Operation, Value, OpView]] = None, - *, - sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, - Value]], ArrayAttr]] = None, - interchange: OptionalIntList = None, - loc=None, - ip=None): + def __init__( + self, + loop_types_or_target: Union[Type, List[Type], Operation, Value], + target_or_none: Optional[Union[Operation, Value, OpView]] = None, + *, + sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]], + ArrayAttr]] = None, + interchange: OptionalIntList = None, + loc=None, + ip=None, + ): + if interchange is None: + interchange = [] if sizes is None: sizes = [] @@ -293,8 +281,8 @@ target = loop_types_or_target assert target_or_none is None, "Cannot construct TileOp with two targets." else: - loop_types = ([loop_types_or_target] * num_loops) if isinstance( - loop_types_or_target, Type) else loop_types_or_target + loop_types = (([loop_types_or_target] * num_loops) if isinstance( + loop_types_or_target, Type) else loop_types_or_target) target = target_or_none target = _get_op_result_or_value(target) @@ -305,10 +293,10 @@ target, dynamic_sizes=dynamic_sizes, static_sizes=sizes_attr, - interchange=_get_dense_int64_array_attr(interchange) - if interchange else None, + interchange=interchange, loc=loc, - ip=ip) + ip=ip, + ) def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]: if not attr: @@ -319,12 +307,14 @@ class VectorizeOp: """Specialization for VectorizeOp class.""" - def __init__(self, - target: Union[Operation, Value], - *, - vectorize_padding: Union[bool, BoolAttr] = False, - loc=None, - ip=None): + def __init__( + self, + target: Union[Operation, Value], + *, + vectorize_padding: Union[bool, BoolAttr] = False, + loc=None, + ip=None, + ): pdl_operation_type = pdl.OperationType.get() if isinstance(vectorize_padding, bool): vectorize_padding = UnitAttr.get() @@ -333,4 +323,5 @@ _get_op_result_or_value(target), vectorize_padding=vectorize_padding, loc=loc, - ip=ip) + ip=ip, + ) diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py --- a/mlir/python/mlir/dialects/_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_transform_ops_ext.py @@ -4,102 +4,119 @@ try: from ..ir import * - from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values + from ._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, + ) except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e -from argparse import SUPPRESS -from typing import Optional, overload, Sequence, Union - - -def _get_symbol_ref_attr(value: Union[Attribute, str]): - if isinstance(value, Attribute): - return value - return FlatSymbolRefAttr.get(value) +from typing import Optional, Sequence, Union class CastOp: - def __init__(self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None): - super().__init__( - result_type, - _get_op_result_or_value(target), - loc=loc, - ip=ip) + def __init__(self, + result_type: Type, + target: Union[Operation, Value], + *, + loc=None, + ip=None): + super().__init__(result_type, + _get_op_result_or_value(target), + loc=loc, + ip=ip) class GetClosestIsolatedParentOp: - def __init__(self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None): - super().__init__( - result_type, - _get_op_result_or_value(target), - loc=loc, - ip=ip) - - -class MergeHandlesOp: - def __init__(self, - handles: Sequence[Union[Operation, Value]], + result_type: Type, + target: Union[Operation, Value], *, - deduplicate: bool = False, loc=None, ip=None): + super().__init__(result_type, + _get_op_result_or_value(target), + loc=loc, + ip=ip) + + +class MergeHandlesOp: + + def __init__( + self, + handles: Sequence[Union[Operation, Value]], + *, + deduplicate: bool = False, + loc=None, + ip=None, + ): super().__init__( [_get_op_result_or_value(h) for h in handles], deduplicate=deduplicate, loc=loc, - ip=ip) + ip=ip, + ) class PDLMatchOp: - def __init__(self, - result_type: Type, - target: Union[Operation, Value], - pattern_name: Union[Attribute, str], - *, - loc=None, - ip=None): + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + pattern_name: Union[Attribute, str], + *, + loc=None, + ip=None, + ): super().__init__( result_type, _get_op_result_or_value(target), - _get_symbol_ref_attr(pattern_name), + pattern_name, loc=loc, - ip=ip) + ip=ip, + ) class ReplicateOp: - def __init__(self, - pattern: Union[Operation, Value], - handles: Sequence[Union[Operation, Value]], - *, - loc=None, - ip=None): + def __init__( + self, + pattern: Union[Operation, Value], + handles: Sequence[Union[Operation, Value]], + *, + loc=None, + ip=None, + ): super().__init__( [_get_op_result_or_value(h).type for h in handles], _get_op_result_or_value(pattern), [_get_op_result_or_value(h) for h in handles], loc=loc, - ip=ip) + ip=ip, + ) class SequenceOp: - def __init__(self, failure_propagation_mode, results: Sequence[Type], - target: Union[Operation, Value, Type], - extra_bindings: Optional[Union[Sequence[Value], Sequence[Type], - Operation, OpView]] = None): - root = _get_op_result_or_value(target) if isinstance( - target, (Operation, Value)) else None + def __init__( + self, + failure_propagation_mode, + results: Sequence[Type], + target: Union[Operation, Value, Type], + extra_bindings: Optional[Union[Sequence[Value], Sequence[Type], Operation, + OpView]] = None, + ): + root = (_get_op_result_or_value(target) if isinstance( + target, (Operation, Value)) else None) root_type = root.type if not isinstance(target, Type) else target if not isinstance(failure_propagation_mode, Attribute): failure_propagation_mode_attr = IntegerAttr.get( IntegerType.get_signless(32), failure_propagation_mode._as_int()) else: - failure_propagation_mode = failure_propagation_mode + failure_propagation_mode_attr = failure_propagation_mode if extra_bindings is None: extra_bindings = [] @@ -114,10 +131,12 @@ else: extra_binding_types = [v.type for v in extra_bindings] - super().__init__(results_=results, - failure_propagation_mode=failure_propagation_mode_attr, - root=root, - extra_bindings=extra_bindings) + super().__init__( + results_=results, + failure_propagation_mode=failure_propagation_mode_attr, + root=root, + extra_bindings=extra_bindings, + ) self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types)) @property @@ -143,10 +162,7 @@ root = _get_op_result_or_value(target) if not isinstance(target, Type) else None root_type = target if isinstance(target, Type) else root.type - super().__init__( - root=root, - loc=loc, - ip=ip) + super().__init__(root=root, loc=loc, ip=ip) self.regions[0].blocks.append(root_type) @property @@ -160,9 +176,13 @@ class YieldOp: - def __init__(self, - operands: Union[Operation, Sequence[Value]] = [], - *, - loc=None, - ip=None): + def __init__( + self, + operands: Optional[Union[Operation, Sequence[Value]]] = None, + *, + loc=None, + ip=None, + ): + if operands is None: + operands = [] super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip) diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -8,9 +8,11 @@ # Convenience decorator for registering user-friendly Attribute builders. def register_attribute_builder(kind): + def decorator_builder(func): AttrBuilder.insert(kind, func) return func + return decorator_builder @@ -18,34 +20,77 @@ def _boolAttr(x, context): return BoolAttr.get(x, context=context) + @register_attribute_builder("IndexAttr") def _indexAttr(x, context): return IntegerAttr.get(IndexType.get(context=context), x) + +@register_attribute_builder("I16Attr") +def _i32Attr(x, context): + return IntegerAttr.get(IntegerType.get_signless(16, context=context), x) + + @register_attribute_builder("I32Attr") def _i32Attr(x, context): - return IntegerAttr.get( - IntegerType.get_signless(32, context=context), x) + return IntegerAttr.get(IntegerType.get_signless(32, context=context), x) + @register_attribute_builder("I64Attr") def _i64Attr(x, context): - return IntegerAttr.get( - IntegerType.get_signless(64, context=context), x) + return IntegerAttr.get(IntegerType.get_signless(64, context=context), x) + @register_attribute_builder("StrAttr") def _stringAttr(x, context): return StringAttr.get(x, context=context) + @register_attribute_builder("SymbolNameAttr") def _symbolNameAttr(x, context): return StringAttr.get(x, context=context) + +@register_attribute_builder("SymbolRefAttr") +def _symbolRefAttr(x, context): + return FlatSymbolRefAttr.get(x, context=context) + + +@register_attribute_builder("ArrayAttr") +def _arrayAttr(x, context): + return ArrayAttr.get(x, context=context) + + +@register_attribute_builder("I64ArrayAttr") +def _i64ArrayAttr(x, context): + return ArrayAttr.get([_i64Attr(v, context) for v in x]) + + +@register_attribute_builder("DenseI64ArrayAttr") +def _denseI64ArrayAttr(x, context): + return DenseI64ArrayAttr.get(x, context=context) + + +@register_attribute_builder("TypeAttr") +def _typeAttr(x, context): + return TypeAttr.get(x, context=context) + + +@register_attribute_builder("TypeArrayAttr") +def _typeArrayAttr(x, context): + return _arrayAttr([TypeAttr.get(t, context=context) for t in x], context) + + try: import numpy as np + @register_attribute_builder("IndexElementsAttr") def _indexElementsAttr(x, context): return DenseElementsAttr.get( - np.array(x, dtype=np.int64), type=IndexType.get(context=context), - context=context) + np.array(x, dtype=np.int64), + type=IndexType.get(context=context), + context=context, + ) + except ImportError: pass