diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py
@@ -51,7 +51,7 @@
   True
 """
 
-from typing import Callable, Dict, Optional, Tuple, Union
+from typing import Dict, Optional
 
 from ..... import ir as _ir
 
@@ -77,20 +77,20 @@
 
   def __init__(self,
                *,
-               global_state: "AffineBuildState" = None,
+               global_state: Optional["AffineBuildState"] = None,
               allow_new_symbols: bool = True,
               allow_new_dims: bool = True):
     if not global_state:
-      self.all_symbols = dict()  # type: Dict[str, int]
-      self.all_dims = dict()  # type: Dict[str, int]
+      self.all_symbols: Dict[str, int] = dict()
+      self.all_dims: Dict[str, int] = dict()
     else:
       # Alias the global dict.
       self.all_symbols = global_state.all_symbols
       self.all_dims = global_state.all_dims
 
     # Map of symbols and dims in the current build.
-    self.local_symbols = dict()  # type: Dict[str, int]
-    self.local_dims = dict()  # type: Dict[str, int]
+    self.local_symbols: Dict[str, int] = dict()
+    self.local_dims: Dict[str, int] = dict()
     self.allow_new_symbols = allow_new_symbols
     self.allow_new_dims = allow_new_dims
@@ -228,7 +228,7 @@
   """Represents a named dimension.
   """
 
-  ALL_DIMS = dict()  # type: Dict[str, "DimDef"]
+  ALL_DIMS: Dict[str, "DimDef"] = dict()
 
   def __new__(cls, dimname: str):
     existing = cls.ALL_DIMS.get(dimname)
@@ -271,7 +271,7 @@
   >>> s1 is SymbolDef("s1")
   True
   """
-  ALL_SYMBOLS = dict()  # type: Dict[str, "SymbolDef"]
+  ALL_SYMBOLS: Dict[str, "SymbolDef"] = dict()
 
   def __new__(cls, symname: str):
     existing = cls.ALL_SYMBOLS.get(symname)
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -174,7 +174,7 @@
   def __init__(self, reduce_use: "ReduceFnUse",
                args: Sequence[TensorExpression]):
     self.reduce_use = reduce_use
-    self.lhs = None  # type: Optional[TensorUse]
+    self.lhs: Optional[TensorUse] = None
     self.args = args
 
   def to_scalar_expression(self) -> ScalarExpression:
@@ -431,15 +431,15 @@
     if type_var and not isinstance(type_var, TypeVar):
       raise ValueError(
           f"OperandDef requires a TypeVar but got {repr(type_var)}")
-    self.owner = None  # type: Optional["LinalgOpDef"]
+    self.owner: Optional["LinalgOpDef"] = None
     self.type_var = type_var
     self.size_exprs = size_exprs
     self.index_dims = index_dims
     self.default_indices = default_indices
     self.default_fn = default_fn
     self.kind = kind
-    self.name = None  # type: Optional[str]
-    self.registered_index = -1  # type: int
+    self.name: Optional[str] = None
+    self.registered_index: int = -1
 
   def attach(self, index: int, name: str, owner: "LinalgOpDef"):
     if self.owner:
@@ -525,11 +525,11 @@
     Note that due to the way assignment works in Python, we have to capture
     direct assignment as a setitem on the TensorDef.
     """
+    assert self.operand_def.owner, "TensorDef is not registered with an op"
     if not isinstance(value, TensorExpression):
       raise ValueError(f"Only TensorExpressions can be assigned to TensorDefs. "
" f"Got: {repr(value)}") - use = self[dims] - comp = Comprehension((use, value)) + comp = Comprehension((self[dims], value)) self.operand_def.owner.comprehensions.append(comp) @@ -647,8 +647,8 @@ """Represents a single comprehension.""" def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]): - self.definitions = list() # List[TensorUse] - self.values = list() # List[TensorExpression] + self.definitions: List[TensorUse] = list() + self.values: List[TensorExpression] = list() # Find the lhs to reduction rhs. for assign, value in bindings: @@ -662,7 +662,7 @@ @property def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]: """Gets the reduction dims for the comprehension or None.""" - result = set() + result: Set[Tuple[DimDef, ...]] = set() for use in self.values: if isinstance(use, TensorReduceFn): result.add(use.reduce_use.reduce_dims) @@ -710,10 +710,10 @@ def __init__(self, name: str, cpp_class_name: Optional[str], doc: Optional[str]): self.name = name - self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name + self.cpp_class_name = cpp_class_name if cpp_class_name else name self.doc = doc - self.implements = [] # type: List[OpInterfaceDef] - self.defines = [] # type: List[OpDefinitionsDef] + self.implements: List[OpInterfaceDef] = [] + self.defines: List[OpDefinitionDef] = [] def to_yaml_custom_dict(self): d = dict( @@ -737,9 +737,9 @@ doc: Optional[str] = None): self.metadata = OpMetadataDef( name=name, cpp_class_name=cpp_class_name, doc=doc) - self.registered_operands = dict() # type: Dict[str, OperandDef] - self.domain = list() # type: List[DimDef] - self.comprehensions = list() # type: List[Comprehension] + self.registered_operands: Dict[str, OperandDef] = dict() + self.domain: List[DimDef] = list() + self.comprehensions: List[Comprehension] = list() self._affine_state = AffineBuildState() def add_operand(self, name: str, operand: OperandDef): diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py @@ -47,9 +47,9 @@ shape_map: Optional[_ir.AffineMap] = None, index_attr_map: Optional[_ir.AffineMap] = None): self.operand_def = operand_def - self.shape_map = shape_map # type: Optional[_ir.AffineMap] - self.index_attr_map = index_attr_map # type: Optional[_ir.AffineMap] - self.indexing_map = None # type: Optional[_ir.AffineMap] + self.shape_map: Optional[_ir.AffineMap] = shape_map + self.index_attr_map: Optional[_ir.AffineMap] = index_attr_map + self.indexing_map: Optional[_ir.AffineMap] = None @property def name(self) -> str: @@ -121,9 +121,9 @@ context: Optional[_ir.Context] = None): self.context = context if context is not None else _ir.Context() self.affine_state = AffineBuildState() - self.writes = list() # type: List[Tuple[TensorUse, TensorExpression]] - self.operands = dict() # type: Dict[OperandDef, OperandDefConfig] - self.uses = dict() # type: Dict[TensorUse, TensorUseConfig] + self.writes: List[Tuple[TensorUse, TensorExpression]] = list() + self.operands: Dict[OperandDef, OperandDefConfig] = dict() + self.uses: Dict[TensorUse, TensorUseConfig] = dict() # Compute the ordered set of writes and collect the tensor, capture, dims, # and index uses. @@ -228,7 +228,7 @@ # For each read use, propagate the indexing maps from the use to the # tensor, ensuring that there are not conflicts. 
     for _, read_expr in self.writes:
-      read_uses = set()  # type: Set[TensorUse]
+      read_uses: Set[TensorUse] = set()
       read_expr.collect_tensor_uses(read_uses)
       for read_use in read_uses:
         read_operand_config = self.operands[read_use.operand_def]
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
@@ -2,7 +2,7 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from typing import Dict, List, Sequence, Union
+from typing import Sequence, Union
 
 from contextlib import contextmanager
 import functools
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -11,7 +11,6 @@
 from .... import math
 from .... import arith
 from .... import complex
-from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
 
 from .scalar_expr import *
 from .config import *
@@ -71,8 +70,8 @@
                      f"{len(outs)} for {op_config}")
 
   # Compute a replacement list for all index attribute symbols.
-  expressions = []  # type: Sequence[AffineExpr]
-  replacements = []  # type: Sequence[AffineExpr]
+  expressions: List[AffineExpr] = []
+  replacements: List[AffineExpr] = []
   for index_attr in index_attr_arg_defs:
     index_attr_vals = index_attr.operand_def.default_indices
     if index_attr.name in attrs:
@@ -81,7 +80,7 @@
     if not all(isinstance(value, int) for value in index_attr_vals):
       raise ValueError(f"Attribute {index_attr.name} needs to be of type "
                        f"Sequence[int] but got {type(index_attr_vals)}")
-    results = index_attr.index_attr_map.results  # type: AffineExprList
+    results: AffineExprList = index_attr.index_attr_map.results
     if len(index_attr_vals) != len(results):
       raise ValueError(f"Attribute {index_attr.name} has length {len(results)} "
                        f"but got {len(index_attr_vals)} values")
@@ -91,7 +90,7 @@
 
   # Replace all index attribute symbols by their value.
   # TODO: Add support for shape symbols.
-  indexing_maps = []  # type: Sequence[AffineMap]
+  indexing_maps: List[AffineMap] = []
   for curr in op_config.indexing_maps:
     for expression, replacement in zip(expressions, replacements):
       curr = curr.replace(expression, replacement, curr.n_dims, curr.n_symbols)
@@ -111,14 +110,14 @@
   result_types = [t for t in out_types if isa(RankedTensorType, t)]
 
   # Initialize the type dictionary with the predefined types.
-  type_mapping = dict()  # type: Dict[str, Type]
+  type_mapping: Dict[str, Type] = dict()
   type_mapping["F32"] = F32Type.get()
   type_mapping["F64"] = F64Type.get()
   type_mapping["I32"] = IntegerType.get_signless(32)
   type_mapping["I64"] = IntegerType.get_signless(64)
 
   # Extract type vars for input/output based types.
-  block_arg_types = list()  # type: List[Type]
+  block_arg_types: List[Type] = list()
   for arg_def, arg_element_type in zip(in_arg_defs + out_arg_defs,
                                        _get_types_from_values(*ins, *outs)):
     _add_type_mapping(arg_def, arg_element_type, type_mapping, block_arg_types)
@@ -133,7 +132,7 @@
   ])
 
   # Compute the index attributes used when emitting a named structured op.
-  index_attrs = {}  # type: Dict[str, DenseElementAttr]
+  index_attrs: Dict[str, DenseElementsAttr] = {}
   for index_attr in index_attr_arg_defs:
     index_attr_vals = attrs.get(index_attr.name)
     # Only forward attributes set to a non-default value.
@@ -270,7 +269,7 @@
     self.type_mapping = type_mapping
     self.block_arg_mapping = block_arg_mapping
     self.fn_attr_mapping = fn_attr_mapping
-    self.yield_mapping = dict()  # type: Dict[str, Value]
+    self.yield_mapping: Dict[str, Value] = dict()
 
   def assign(self, assignment: ScalarAssign):
     if assignment.arg in self.yield_mapping:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/types.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/types.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/types.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/types.py
@@ -40,7 +40,7 @@
   Type variables are uniqued by name.
   """
 
-  ALL_TYPEVARS = dict()  # type: Dict[str, "TypeVar"]
+  ALL_TYPEVARS: Dict[str, "TypeVar"] = dict()
 
   def __new__(cls, name: str):
     existing = cls.ALL_TYPEVARS.get(name)
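
Note: every hunk above applies the same mechanical rewrite, replacing PEP 484
"# type:" comments with PEP 526 variable annotations. A minimal sketch of the
pattern follows (illustrative only; the Registry class and its members are
hypothetical, not drawn from the patch):

from typing import Dict, List, Optional

class Registry:
  """Hypothetical class showing the annotation style this patch adopts."""

  # Before: ALL_ITEMS = dict()  # type: Dict[str, "Registry"]
  # After: a class-level PEP 526 annotation that type checkers and IDEs
  # pick up without special-casing comments.
  ALL_ITEMS: Dict[str, "Registry"] = dict()

  def __init__(self, name: str, owner: Optional["Registry"] = None):
    # A parameter defaulting to None must be Optional[...]; the patch makes
    # the same fix for AffineBuildState's global_state parameter.
    self.owner: Optional["Registry"] = owner
    # Containers the code mutates are annotated List/Dict rather than
    # Sequence/Mapping, so appends and item assignment type-check.
    self.items: List[str] = []
    self.name = name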