diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -29,7 +29,7 @@
 
 # Perform Python level site initialization. This involves:
 #   1. Attempting to load initializer modules, specific to the distribution.
-#   2. Defining the concrete mlir.ir.Context that does site specific 
+#   2. Defining the concrete mlir.ir.Context that does site specific
 #      initialization.
 #
 # Aside from just being far more convenient to do this at the Python level,
@@ -38,13 +38,13 @@
 #    in the scope of the base class __init__).
 #
 # For #1, we:
-#   a. Probe for modules named '_mlirRegisterEverything' and 
-#      '_site_initialize_{i}', where 'i' is a number starting at zero and 
+#   a. Probe for modules named '_mlirRegisterEverything' and
+#      '_site_initialize_{i}', where 'i' is a number starting at zero and
 #      proceeding so long as a module with the name is found.
 #   b. If the module has a 'register_dialects' attribute, it will be called
 #      immediately with a DialectRegistry to populate.
 #   c. If the module has a 'context_init_hook', it will be added to a list
-#      of callbacks that are invoked as the last step of Context 
+#      of callbacks that are invoked as the last step of Context
 #      initialization (and passed the Context under construction).
 #
 # This facility allows downstreams to customize Context creation to their
@@ -64,9 +64,10 @@
   except ModuleNotFoundError:
     return False
   except ImportError:
-    message = (f"Error importing mlir initializer {module_name}. This may "
-        "happen in unclean incremental builds but is likely a real bug if "
-        "encountered otherwise and the MLIR Python API may not function.")
+    message = (
+        f"Error importing mlir initializer {module_name}. This may "
+        "happen in unclean incremental builds but is likely a real bug if "
+        "encountered otherwise and the MLIR Python API may not function.")
     logger.warning(message, exc_info=True)
 
   logger.debug("Initializing MLIR with module: %s", module_name)
@@ -78,7 +79,6 @@
       post_init_hooks.append(m.context_init_hook)
     return True
 
-
 # If _mlirRegisterEverything is built, then include it as an initializer
 # module.
 process_initializer_module("_mlirRegisterEverything")
@@ -91,6 +91,7 @@
     break
 
 class Context(ir._BaseContext):
+
   def __init__(self, *args, **kwargs):
     super().__init__(*args, **kwargs)
     self.append_dialect_registry(registry)
@@ -100,6 +101,7 @@
       # all dialects. It is being done here in order to preserve existing
       # behavior. See: https://github.com/llvm/llvm-project/issues/56037
       self.load_all_available_dialects()
+
 ir.Context = Context
 
 class MLIRError(Exception):
@@ -108,6 +110,7 @@
     message: str
     error_diagnostics: List[ir.DiagnosticInfo]
   """
+
   def __init__(self, message, error_diagnostics):
     self.message = message
     self.error_diagnostics = error_diagnostics
@@ -118,10 +121,13 @@
     if self.error_diagnostics:
       s += ':'
     for diag in self.error_diagnostics:
-      s += "\nerror: " + str(diag.location)[4:-1] + ": " + diag.message.replace('\n', '\n  ')
+      s += "\nerror: " + str(
+          diag.location)[4:-1] + ": " + diag.message.replace('\n', '\n  ')
       for note in diag.notes:
-        s += "\n note: " + str(note.location)[4:-1] + ": " + note.message.replace('\n', '\n  ')
+        s += "\n note: " + str(
+            note.location)[4:-1] + ": " + note.message.replace('\n', '\n  ')
     return s
+
 ir.MLIRError = MLIRError
diff --git a/mlir/python/mlir/dialects/_arith_ops_ext.py b/mlir/python/mlir/dialects/_arith_ops_ext.py
--- a/mlir/python/mlir/dialects/_arith_ops_ext.py
+++ b/mlir/python/mlir/dialects/_arith_ops_ext.py
@@ -50,11 +50,10 @@
   @classmethod
   def create_index(cls, value: int, *, loc=None, ip=None):
     """Create an index-typed constant."""
-    return cls(
-        IndexType.get(context=_get_default_loc_context(loc)),
-        value,
-        loc=loc,
-        ip=ip)
+    return cls(IndexType.get(context=_get_default_loc_context(loc)),
+               value,
+               loc=loc,
+               ip=ip)
 
   @property
   def type(self):
diff --git a/mlir/python/mlir/dialects/_bufferization_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_ops_ext.py
--- a/mlir/python/mlir/dialects/_bufferization_ops_ext.py
+++ b/mlir/python/mlir/dialects/_bufferization_ops_ext.py
@@ -29,10 +29,9 @@
     attributes = {}
     if escape:
       attributes["escape"] = escape
-    op = self.build_generic(
-        results=[tensor_type],
-        operands=[dynamic_sizes, copy, size_hint],
-        attributes=attributes,
-        loc=loc,
-        ip=ip)
+    op = self.build_generic(results=[tensor_type],
+                            operands=[dynamic_sizes, copy, size_hint],
+                            attributes=attributes,
+                            loc=loc,
+                            ip=ip)
     OpView.__init__(self, op)
diff --git a/mlir/python/mlir/dialects/_builtin_ops_ext.py b/mlir/python/mlir/dialects/_builtin_ops_ext.py
--- a/mlir/python/mlir/dialects/_builtin_ops_ext.py
+++ b/mlir/python/mlir/dialects/_builtin_ops_ext.py
@@ -7,6 +7,7 @@
 except ImportError as e:
   raise RuntimeError("Error loading imports from extension module") from e
 
+
 class ModuleOp:
   """Specialization for the module op class."""
diff --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/_func_ops_ext.py
--- a/mlir/python/mlir/dialects/_func_ops_ext.py
+++ b/mlir/python/mlir/dialects/_func_ops_ext.py
@@ -15,6 +15,7 @@
 ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
 RESULT_ATTRIBUTE_NAME = "res_attrs"
 
+
 class ConstantOp:
   """Specialization for the constant op class."""
 
@@ -58,7 +59,11 @@
       type = TypeAttr.get(type)
     sym_visibility = StringAttr.get(
         str(visibility)) if visibility is not None else None
-    super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
+    super().__init__(sym_name,
+                     type,
+                     sym_visibility=sym_visibility,
+                     loc=loc,
+                     ip=ip)
     if body_builder:
       entry_block = self.add_entry_block()
       with InsertionPoint(entry_block):
@@ -230,6 +235,7 @@
 
   return decorator
 
+
 class CallOp:
   """Specialization for the call op class."""
 
@@ -272,14 +278,13 @@
         raise ValueError("unexpected third argument when constructing a call" +
                          "to a function")
-      super().__init__(
-          calleeOrResults.type.results,
-          FlatSymbolRefAttr.get(
-              calleeOrResults.name.value,
-              context=_get_default_loc_context(loc)),
-          argumentsOrCallee,
-          loc=loc,
-          ip=ip)
+      super().__init__(calleeOrResults.type.results,
+                       FlatSymbolRefAttr.get(
+                           calleeOrResults.name.value,
+                           context=_get_default_loc_context(loc)),
+                       argumentsOrCallee,
+                       loc=loc,
+                       ip=ip)
       return
 
     if isinstance(argumentsOrCallee, list):
@@ -288,13 +293,16 @@
           f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}")
 
     if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
-      super().__init__(
-          calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip)
+      super().__init__(calleeOrResults,
+                       argumentsOrCallee,
+                       arguments,
+                       loc=loc,
+                       ip=ip)
     elif isinstance(argumentsOrCallee, str):
-      super().__init__(
-          calleeOrResults,
-          FlatSymbolRefAttr.get(
-              argumentsOrCallee, context=_get_default_loc_context(loc)),
-          arguments,
-          loc=loc,
-          ip=ip)
+      super().__init__(calleeOrResults,
+                       FlatSymbolRefAttr.get(
+                           argumentsOrCallee,
+                           context=_get_default_loc_context(loc)),
+                       arguments,
+                       loc=loc,
+                       ip=ip)
diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py
--- a/mlir/python/mlir/dialects/_linalg_ops_ext.py
+++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py
@@ -12,6 +12,7 @@
 
 from ._ods_common import get_op_result_or_value as _get_op_result_or_value
 
+
 def isa(cls: Type, ty: Type):
   try:
     cls(ty)
diff --git a/mlir/python/mlir/dialects/_ml_program_ops_ext.py b/mlir/python/mlir/dialects/_ml_program_ops_ext.py
--- a/mlir/python/mlir/dialects/_ml_program_ops_ext.py
+++ b/mlir/python/mlir/dialects/_ml_program_ops_ext.py
@@ -11,7 +11,6 @@
 
 from ._ml_program_ops_gen import *
 
-
 ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
 RESULT_ATTRIBUTE_NAME = "res_attrs"
 
@@ -48,7 +47,11 @@
       type = TypeAttr.get(type)
     sym_visibility = StringAttr.get(
         str(visibility)) if visibility is not None else None
-    super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
+    super().__init__(sym_name,
+                     type,
+                     sym_visibility=sym_visibility,
+                     loc=loc,
+                     ip=ip)
     if body_builder:
       entry_block = self.add_entry_block()
       with InsertionPoint(entry_block):
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -124,7 +124,8 @@
 
 def get_op_result_or_value(
-    arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value, _cext.ir.OpResultList]
+    arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value,
+                _cext.ir.OpResultList]
 ) -> _cext.ir.Value:
   """Returns the given value or the single result of the given op.
 
@@ -145,7 +146,8 @@
 def get_op_results_or_values(
     arg: _Union[_cext.ir.OpView, _cext.ir.Operation,
-                _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]]]
+                _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation,
+                                 _cext.ir.Value]]]
 ) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]:
   """Returns the given sequence of values or the results of the given op.
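The `_mlir_libs/__init__.py` hunks above describe the site-initialization protocol only in comments. For orientation, a downstream initializer module implied by that protocol would look roughly like this; the module name `_site_initialize_0` comes from the probing scheme described above, while the hook bodies are placeholder assumptions and not part of this patch:

```python
# _site_initialize_0.py -- hypothetical downstream initializer module,
# discovered by the probing loop in _mlir_libs/__init__.py.


def register_dialects(registry):
  # Called immediately with the DialectRegistry to populate, before any
  # Context exists (e.g. to register out-of-tree dialects).
  pass


def context_init_hook(context):
  # Appended to post_init_hooks and invoked as the last step of Context
  # initialization, with the Context under construction.
  pass
```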
diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py
--- a/mlir/python/mlir/dialects/_scf_ops_ext.py
+++ b/mlir/python/mlir/dialects/_scf_ops_ext.py
@@ -10,6 +10,7 @@
 
 from typing import Any, Optional, Sequence, Union
 from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
 
+
 class ForOp:
   """Specialization for the SCF for op class."""
 
@@ -36,15 +37,14 @@
       results = [arg.type for arg in iter_args]
     super().__init__(
-        self.build_generic(
-            regions=1,
-            results=results,
-            operands=[
-                _get_op_result_or_value(o)
-                for o in [lower_bound, upper_bound, step]
-            ] + list(iter_args),
-            loc=loc,
-            ip=ip))
+        self.build_generic(regions=1,
+                           results=results,
+                           operands=[
+                               _get_op_result_or_value(o)
+                               for o in [lower_bound, upper_bound, step]
+                           ] + list(iter_args),
+                           loc=loc,
+                           ip=ip))
     self.regions[0].blocks.append(IndexType.get(), *results)
 
   @property
@@ -69,13 +69,7 @@
 class IfOp:
   """Specialization for the SCF if op class."""
 
-  def __init__(self,
-               cond,
-               results_=[],
-               *,
-               hasElse=False,
-               loc=None,
-               ip=None):
+  def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None):
     """Creates an SCF `if` operation.
 
     - `cond` is a MLIR value of 'i1' type to determine which regions of code
       will be executed.
@@ -86,15 +80,14 @@
     results = []
     results.extend(results_)
     super().__init__(
-        self.build_generic(
-            regions=2,
-            results=results,
-            operands=operands,
-            loc=loc,
-            ip=ip))
+        self.build_generic(regions=2,
+                           results=results,
+                           operands=operands,
+                           loc=loc,
+                           ip=ip))
     self.regions[0].blocks.append(*[])
     if hasElse:
-        self.regions[1].blocks.append(*[])
+      self.regions[1].blocks.append(*[])
 
   @property
   def then_block(self):
diff --git a/mlir/python/mlir/dialects/_tensor_ops_ext.py b/mlir/python/mlir/dialects/_tensor_ops_ext.py
--- a/mlir/python/mlir/dialects/_tensor_ops_ext.py
+++ b/mlir/python/mlir/dialects/_tensor_ops_ext.py
@@ -33,10 +33,9 @@
         static_sizes.append(ShapedType.get_dynamic_size())
         dynamic_sizes.append(s)
     result_type = RankedTensorType.get(static_sizes, element_type)
-    op = self.build_generic(
-        results=[result_type],
-        operands=dynamic_sizes,
-        attributes={},
-        loc=loc,
-        ip=ip)
+    op = self.build_generic(results=[result_type],
+                            operands=dynamic_sizes,
+                            attributes={},
+                            loc=loc,
+                            ip=ip)
     OpView.__init__(self, op)
diff --git a/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py b/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py
--- a/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py
+++ b/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py
@@ -13,6 +13,7 @@
 
 from typing import Union
 
+
 class PDLMatchOp:
 
   def __init__(
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -497,14 +497,15 @@
       raise ValueError(f"TensorDef requires index dims of type DimDef but "
                        f"got {index_dims}")
     kind = OperandKind.OUTPUT_TENSOR if output else OperandKind.INPUT_TENSOR
-    self.operand_def = OperandDef(
-        kind, type_var=type_var, size_exprs=shape, index_dims=index_dims)
+    self.operand_def = OperandDef(kind,
+                                  type_var=type_var,
+                                  size_exprs=shape,
+                                  index_dims=index_dims)
 
   def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse:
     assert self.operand_def.owner, "TensorDef is not registered with an op"
-    state = AffineBuildState(
-        global_state=self.operand_def.owner._affine_state,
-        allow_new_symbols=False)
+    state = AffineBuildState(global_state=self.operand_def.owner._affine_state,
+                             allow_new_symbols=False)
     if not isinstance(dims, tuple):
       dims = (dims,)  # Handle single subscript case.
     # Special case: (None) is a 0d-scalar use.
@@ -572,8 +573,9 @@
     if len(sizes) != len(default):
       raise ValueError(f"IndexAttrDef expects {len(sizes)} default values "
                        f"but got {len(default)}")
-    self.operand_def = OperandDef(
-        OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default)
+    self.operand_def = OperandDef(OperandKind.INDEX_ATTR,
+                                  size_exprs=sizes,
+                                  default_indices=default)
 
 
 class UnaryFnAttrDef:
@@ -588,8 +590,8 @@
     if not isinstance(default, UnaryFnType):
       raise ValueError(f"UnaryFnAttrDef requires default of type UnaryFnType "
                        f"but got {default}")
-    self.operand_def = OperandDef(
-        OperandKind.UNARY_FN_ATTR, default_fn=default.fn_name)
+    self.operand_def = OperandDef(OperandKind.UNARY_FN_ATTR,
+                                  default_fn=default.fn_name)
 
   def __call__(self, arg: TensorExpression) -> TensorFn:
     return TensorFn(FunctionKind.UNARY, None, self.operand_def, None, [arg])
@@ -607,8 +609,8 @@
     if not isinstance(default, BinaryFnType):
       raise ValueError(f"BinaryFnAttrDef requires default of type BinaryFnType "
                        f"but got {default}")
-    self.operand_def = OperandDef(
-        OperandKind.BINARY_FN_ATTR, default_fn=default.fn_name)
+    self.operand_def = OperandDef(OperandKind.BINARY_FN_ATTR,
+                                  default_fn=default.fn_name)
 
   def __call__(self, arg0: TensorExpression,
                arg1: TensorExpression) -> TensorFn:
@@ -631,8 +633,8 @@
     if not isinstance(default, TypeFnType):
       raise ValueError(f"TypeFnAttrDef requires default of type TypeFnType "
                        f"but got {default}")
-    self.operand_def = OperandDef(
-        OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name)
+    self.operand_def = OperandDef(OperandKind.TYPE_FN_ATTR,
+                                  default_fn=default.fn_name)
 
   def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn:
     return TensorFn(FunctionKind.TYPE, None, self.operand_def, type_var, [arg])
@@ -735,8 +737,9 @@
                name: str,
                cpp_class_name: Optional[str] = None,
                doc: Optional[str] = None):
-    self.metadata = OpMetadataDef(
-        name=name, cpp_class_name=cpp_class_name, doc=doc)
+    self.metadata = OpMetadataDef(name=name,
+                                  cpp_class_name=cpp_class_name,
+                                  doc=doc)
     self.registered_operands = dict()  # type: Dict[str, OperandDef]
     self.domain = list()  # type: List[DimDef]
     self.comprehensions = list()  # type: List[Comprehension]
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
@@ -156,8 +156,8 @@
     # Instantiate the dimensions in the given order.
     with self.context:
-      local_state = AffineBuildState(
-          global_state=self.affine_state, allow_new_symbols=False)
+      local_state = AffineBuildState(global_state=self.affine_state,
+                                     allow_new_symbols=False)
       for dim in domain:
         dim.build(state=local_state)
@@ -277,9 +277,8 @@
 
   @property
   def ordered_operands(self) -> Sequence[OperandDefConfig]:
-    return sorted(
-        self.operands.values(),
-        key=lambda operand: operand.operand_def.registered_index)
+    return sorted(self.operands.values(),
+                  key=lambda operand: operand.operand_def.registered_index)
 
   @property
   def ordered_dims(self) -> Sequence[Tuple[str, int]]:
@@ -314,25 +313,26 @@
       self.operands[operand_def] = OperandDefConfig(operand_def)
       return
     with self.context:
-      local_state = AffineBuildState(
-          global_state=self.affine_state, allow_new_dims=False)
+      local_state = AffineBuildState(global_state=self.affine_state,
+                                     allow_new_dims=False)
       exprs = []
       for expr in operand_def.size_exprs:
         exprs.append(expr.build(state=local_state))
       assert local_state.local_dim_count == 0
-      affine_map = _ir.AffineMap.get(
-          dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs)
+      affine_map = _ir.AffineMap.get(dim_count=0,
+                                     symbol_count=local_state.symbol_count,
+                                     exprs=exprs)
       if operand_def.kind == OperandKind.INDEX_ATTR:
-        self.operands[operand_def] = OperandDefConfig(
-            operand_def, index_attr_map=affine_map)
+        self.operands[operand_def] = OperandDefConfig(operand_def,
+                                                      index_attr_map=affine_map)
       else:
-        self.operands[operand_def] = OperandDefConfig(
-            operand_def, shape_map=affine_map)
+        self.operands[operand_def] = OperandDefConfig(operand_def,
+                                                      shape_map=affine_map)
 
   def add_indexed_operand(self, operand_def: OperandDef):
     with self.context:
-      local_state = AffineBuildState(
-          global_state=self.affine_state, allow_new_symbols=False)
+      local_state = AffineBuildState(global_state=self.affine_state,
+                                     allow_new_symbols=False)
       exprs = []
       for expr in operand_def.index_dims:
         exprs.append(expr.build(state=local_state))
@@ -345,15 +345,14 @@
     if tensor_use in self.uses:
       return
     with self.context:
-      local_state = AffineBuildState(
-          global_state=self.affine_state, allow_new_symbols=False)
+      local_state = AffineBuildState(global_state=self.affine_state,
+                                     allow_new_symbols=False)
       exprs = []
       for expr in tensor_use.indices:
         exprs.append(expr.build(state=local_state))
-      indexing_map = _ir.AffineMap.get(
-          dim_count=local_state.dim_count,
-          symbol_count=local_state.symbol_count,
-          exprs=exprs)
+      indexing_map = _ir.AffineMap.get(dim_count=local_state.dim_count,
+                                       symbol_count=local_state.symbol_count,
+                                       exprs=exprs)
 
     use_config = TensorUseConfig(tensor_use, indexing_map)
     self.uses[tensor_use] = use_config
@@ -361,10 +360,9 @@
   def _get_scalar_map(self) -> _ir.AffineMap:
     """Create an empty affine map used to index a scalar."""
     with self.context:
-      return _ir.AffineMap.get(
-          dim_count=self.affine_state.dim_count,
-          symbol_count=self.affine_state.symbol_count,
-          exprs=list())
+      return _ir.AffineMap.get(dim_count=self.affine_state.dim_count,
+                               symbol_count=self.affine_state.symbol_count,
+                               exprs=list())
 
   def _normalize_affine_map(self,
                             affine_map: _ir.AffineMap,
@@ -430,11 +428,10 @@
     # TODO: Many LinalgOpDef patterns need to expand to multiple generics.
     assert len(op_def.comprehensions) == 1, "Only one comprehension supported"
     return [
-        LinalgOpConfig(
-            op_def.metadata,
-            structured_op=LinalgStructuredOpConfig(
-                op_def.comprehensions[0], op_def.domain,
-                op_def.registered_operands.values(), context)),
+        LinalgOpConfig(op_def.metadata,
+                       structured_op=LinalgStructuredOpConfig(
+                           op_def.comprehensions[0], op_def.domain,
+                           op_def.registered_operands.values(), context)),
     ]
 
   def __repr__(self):
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
@@ -72,8 +72,8 @@
       raise ValueError(f"The named argument 'emit_generic' needs to be "
                        f" of type bool but got {type(emit_generic)}")
 
-    op_configs = LinalgOpConfig.from_linalg_op_def(
-        self.op_def, context=ir.Context.current)
+    op_configs = LinalgOpConfig.from_linalg_op_def(self.op_def,
+                                                   context=ir.Context.current)
 
     if len(op_configs) != 1:
       # TODO: Support composite ops.
@@ -83,24 +83,25 @@
     ctx = ir.Context.current
     linalgDialect = ctx.get_dialect_descriptor("linalg")
     fully_qualified_name = "linalg." + self.op_name
-    emit_generic = (
-        emit_generic or not ctx.is_registered_operation(fully_qualified_name))
+    emit_generic = (emit_generic or
+                    not ctx.is_registered_operation(fully_qualified_name))
 
     op_config = op_configs[0]
     out_values = _prepare_structured_op_outs(outs)
     in_values = [_get_op_result_or_value(i) for i in ins]
     if op_config.structured_op:
       if emit_generic:
-        return emit_generic_structured_op(
-            op_config.structured_op, *in_values, outs=out_values, **kwargs)
+        return emit_generic_structured_op(op_config.structured_op,
+                                          *in_values,
+                                          outs=out_values,
+                                          **kwargs)
       else:
-        return emit_named_structured_op(
-            op_config.structured_op,
-            self.op_name,
-            self.op_def.metadata.cpp_class_name,
-            *in_values,
-            outs=out_values,
-            **kwargs)
+        return emit_named_structured_op(op_config.structured_op,
+                                        self.op_name,
+                                        self.op_def.metadata.cpp_class_name,
+                                        *in_values,
+                                        outs=out_values,
+                                        **kwargs)
 
     raise NotImplementedError(
         f"Emission of linalg op type not supported: {op_config}")
@@ -112,8 +113,9 @@
                          op_class_name=None) -> DefinedOpCallable:
   if dsl_func is None:
     # Curry the keyword args in for delayed application.
-    return functools.partial(
-        linalg_structured_op, op_name=op_name, op_class_name=op_class_name)
+    return functools.partial(linalg_structured_op,
+                             op_name=op_name,
+                             op_class_name=op_class_name)
   # Determine default names by introspecting the function.
   if op_name is None:
     op_name = dsl_func.__name__
@@ -121,8 +123,9 @@
     # Camel case it.
     op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op"
 
-  op_def = LinalgOpDef(
-      name=op_name, cpp_class_name=op_class_name, doc=inspect.getdoc(dsl_func))
+  op_def = LinalgOpDef(name=op_name,
+                       cpp_class_name=op_class_name,
+                       doc=inspect.getdoc(dsl_func))
 
   # Extract arguments and TensorDefs from the signature.
   dsl_func_args = list()
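For context on the opdsl hunks above and below: a definition written against this DSL combines `@linalg_structured_op`, `TensorDef`, `domain`, and `TypeFn` casts. A minimal matmul-style sketch follows; it is illustrative only, the wildcard import path is the one the in-tree tests use, and `matmul_like` is a made-up name:

```python
from mlir.dialects.linalg.opdsl.lang import *  # assumed import path


@linalg_structured_op
def matmul_like(A=TensorDef(T1, S.M, S.K),
                B=TensorDef(T2, S.K, S.N),
                C=TensorDef(U, S.M, S.N, output=True)):
  # The reduction dimension D.k comes last in the domain; the `+=`
  # comprehension accumulates over it, as in the pooling ops below.
  domain(D.m, D.n, D.k)
  C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
      U, B[D.k, D.n])
```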
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -290,8 +290,8 @@
       value_attr = Attribute.parse(expr.scalar_const.value)
       return arith.ConstantOp(value_attr.type, value_attr).result
     elif expr.scalar_index:
-      dim_attr = IntegerAttr.get(
-          IntegerType.get_signless(64), expr.scalar_index.dim)
+      dim_attr = IntegerAttr.get(IntegerType.get_signless(64),
+                                 expr.scalar_index.dim)
       return linalg.IndexOp(dim_attr).result
     elif expr.scalar_fn:
       kind = expr.scalar_fn.kind.name.lower()
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -975,9 +975,9 @@
          D.c] = ReduceFn.min_unsigned[D.kh, D.kw](TypeFn.cast_unsigned(
              U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
 
+
 @linalg_structured_op
-def pooling_nwc_sum(I=TensorDef(T1, S.N,
-                                S.OW * S.SW + S.KW * S.DW, S.C),
+def pooling_nwc_sum(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
                     K=TensorDef(T2, S.KW, index_dims=[D.kw]),
                     O=TensorDef(U, S.N, S.OW, S.C, output=True),
                     strides=IndexAttrDef(S.SW, default=[1]),
@@ -993,13 +993,12 @@
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.ow, D.c, D.kw)
-  O[D.n, D.ow, D.c] += TypeFn.cast_signed(
-      U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
+  O[D.n, D.ow,
+    D.c] += TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
 
 
 @linalg_structured_op
-def pooling_ncw_sum(I=TensorDef(T1, S.N, S.C,
-                                S.OW * S.SW + S.KW * S.DW),
+def pooling_ncw_sum(I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW),
                     K=TensorDef(T2, S.KW, index_dims=[D.kw]),
                     O=TensorDef(U, S.N, S.C, S.OW, output=True),
                     strides=IndexAttrDef(S.SW, default=[1]),
@@ -1015,13 +1014,12 @@
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.c, D.ow, D.kw)
-  O[D.n, D.c, D.ow] += TypeFn.cast_signed(
-      U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW])
+  O[D.n, D.c,
+    D.ow] += TypeFn.cast_signed(U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW])
 
 
 @linalg_structured_op
-def pooling_nwc_max(I=TensorDef(T1, S.N,
-                                S.OW * S.SW + S.KW * S.DW, S.C),
+def pooling_nwc_max(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
                     K=TensorDef(T2, S.KW, index_dims=[D.kw]),
                     O=TensorDef(U, S.N, S.OW, S.C, output=True),
                     strides=IndexAttrDef(S.SW, default=[1]),
@@ -1038,11 +1036,9 @@
 
 
 @linalg_structured_op
-def pooling_nwc_max_unsigned(I=TensorDef(T1, S.N,
-                                         S.OW * S.SW + S.KW * S.DW, S.C),
-                             K=TensorDef(T2,
-                                         S.KW,
-                                         index_dims=[D.kw]),
+def pooling_nwc_max_unsigned(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW,
+                                         S.C),
+                             K=TensorDef(T2, S.KW, index_dims=[D.kw]),
                              O=TensorDef(U, S.N, S.OW, S.C, output=True),
                              strides=IndexAttrDef(S.SW, default=[1]),
                              dilations=IndexAttrDef(S.DW, default=[1])):
@@ -1053,14 +1049,12 @@
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.ow, D.c, D.kw)
-  O[D.n, D.ow,
-    D.c] = ReduceFn.max_unsigned[[D.kw]](TypeFn.cast_unsigned(
-        U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]))
+  O[D.n, D.ow, D.c] = ReduceFn.max_unsigned[[D.kw]](TypeFn.cast_unsigned(
+      U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]))
 
 
 @linalg_structured_op
-def pooling_ncw_max(I=TensorDef(T1, S.N, S.C,
-                                S.OW * S.SW + S.KW * S.DW),
+def pooling_ncw_max(I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW),
                     K=TensorDef(T2, S.KW, index_dims=[D.kw]),
                     O=TensorDef(U, S.N, S.C, S.OW, output=True),
                     strides=IndexAttrDef(S.SW, default=[1]),
@@ -1077,8 +1071,7 @@
 
 
 @linalg_structured_op
-def pooling_nwc_min(I=TensorDef(T1, S.N,
-                                S.OW * S.SW + S.KW * S.DW, S.C),
+def pooling_nwc_min(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
                     K=TensorDef(T2, S.KW, index_dims=[D.kw]),
                     O=TensorDef(U, S.N, S.OW, S.C, output=True),
                     strides=IndexAttrDef(S.SW, default=[1]),
@@ -1095,11 +1088,9 @@
 
 
 @linalg_structured_op
-def pooling_nwc_min_unsigned(I=TensorDef(T1, S.N,
-                                         S.OW * S.SW + S.KW * S.DW, S.C),
-                             K=TensorDef(T2,
-                                         S.KW,
-                                         index_dims=[D.kw]),
+def pooling_nwc_min_unsigned(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW,
+                                         S.C),
+                             K=TensorDef(T2, S.KW, index_dims=[D.kw]),
                              O=TensorDef(U, S.N, S.OW, S.C, output=True),
                              strides=IndexAttrDef(S.SW, default=[1]),
                              dilations=IndexAttrDef(S.DW, default=[1])):
@@ -1110,10 +1101,8 @@
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.ow, D.c, D.kw)
-  O[D.n, D.ow,
-    D.c] = ReduceFn.min_unsigned[[D.kw]](TypeFn.cast_unsigned(
-        U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]))
-
+  O[D.n, D.ow, D.c] = ReduceFn.min_unsigned[[D.kw]](TypeFn.cast_unsigned(
+      U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]))
 
 
 @linalg_structured_op
diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -5,6 +5,7 @@
 from ._python_test_ops_gen import *
 from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue
 
+
 def register_python_test_dialect(context, load=True):
   from .._mlir_libs import _mlirPythonTest
   _mlirPythonTest.register_python_test_dialect(context, load)
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -17,5 +17,6 @@
     assert self is FailurePropagationMode.SUPPRESS
     return 2
 
+
 from .._transform_ops_gen import *
 from ..._mlir_libs._mlirDialectsTransform import *
diff --git a/mlir/python/mlir/execution_engine.py b/mlir/python/mlir/execution_engine.py
--- a/mlir/python/mlir/execution_engine.py
+++ b/mlir/python/mlir/execution_engine.py
@@ -7,9 +7,10 @@
 import ctypes
 
 __all__ = [
-  "ExecutionEngine",
+    "ExecutionEngine",
 ]
 
+
 class ExecutionEngine(_execution_engine.ExecutionEngine):
 
   def lookup(self, name):
diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -129,8 +129,8 @@
 
 def ranked_memref_to_numpy(ranked_memref):
   """Converts ranked memrefs to numpy arrays."""
-  np_arr = np.ctypeslib.as_array(
-      ranked_memref[0].aligned, shape=ranked_memref[0].shape)
+  np_arr = np.ctypeslib.as_array(ranked_memref[0].aligned,
+                                 shape=ranked_memref[0].shape)
   strided_arr = np.lib.stride_tricks.as_strided(
       np_arr,
       np.ctypeslib.as_array(ranked_memref[0].shape),
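As a reference point for the `np_to_memref` hunk above, the round trip between numpy arrays and ranked memref descriptors can be sanity-checked as below. This is a sketch, not part of the patch: `get_ranked_memref_descriptor` is the companion helper in the same module, and `ranked_memref_to_numpy` expects a ctypes pointer to the descriptor, matching the `ranked_memref[0]` accesses above.

```python
import ctypes

import numpy as np

from mlir.runtime.np_to_memref import (get_ranked_memref_descriptor,
                                       ranked_memref_to_numpy)

arr = np.arange(6, dtype=np.float32).reshape(2, 3)
# numpy array -> ranked memref descriptor (a ctypes struct)...
desc = get_ranked_memref_descriptor(arr)
# ...and back again, via a pointer to the descriptor.
round_tripped = ranked_memref_to_numpy(ctypes.pointer(desc))
assert np.array_equal(arr, round_tripped)
```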