diff --git a/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py b/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py --- a/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py @@ -73,10 +73,10 @@ # TODO: This class layering is awkward. if isinstance(value, DefinedOpCallable): try: - linalg_config = LinalgOpConfig.from_linalg_op_def(value.model) + linalg_config = LinalgOpConfig.from_linalg_op_def(value.op_def) except Exception as e: raise ValueError( - f"Could not create LinalgOpConfig from {value.model}") from e + f"Could not create LinalgOpConfig from {value.op_def}") from e configs.extend(linalg_config) # Print. diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py @@ -64,9 +64,6 @@ "SymbolDef", ] -# Type aliases. -SymbolPosMap = Dict[str, int] - class AffineBuildState: """Internal state for the AffineExprDef._create impls. diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -17,9 +17,6 @@ from .types import * from .yaml_helper import * -# Type aliases. -AffineDimList = Dict[str, _ir.AffineExpr] - class TensorExpression: """An expression that can appear on the RHS of a comprehension.""" diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py @@ -421,18 +421,18 @@ @staticmethod def from_linalg_op_def( - tc_op_def: LinalgOpDef, + op_def: LinalgOpDef, context: Optional[_ir.Context] = None) -> Sequence["LinalgOpConfig"]: """Expands a LinalgOpDef into corresponding Linalg configured ops.""" # TODO: Many LinalgOpDef patterns need to expand to multiple generics. assert len( - tc_op_def.comprehensions) == 1, "Only one comprehension supported" + op_def.comprehensions) == 1, "Only one comprehension supported" return [ LinalgOpConfig( - tc_op_def.metadata, + op_def.metadata, structured_op=LinalgStructuredOpConfig( - tc_op_def.comprehensions[0], tc_op_def.domain, - tc_op_def.registered_operands.values(), context)), + op_def.comprehensions[0], op_def.domain, + op_def.registered_operands.values(), context)), ] def __repr__(self): diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -22,12 +22,12 @@ @contextmanager -def bind_op_def(model: LinalgOpDef): +def bind_op_def(op_def: LinalgOpDef): if hasattr(_CONTEXT, "current_op_def"): raise ValueError("Cannot recursively define an operation") - _CONTEXT.current_op_def = model + _CONTEXT.current_op_def = op_def try: - yield model + yield op_def finally: del _CONTEXT.current_op_def @@ -53,9 +53,9 @@ class DefinedOpCallable: """Callable that wraps any defined op function.""" - def __init__(self, op_name: str, model: LinalgOpDef): + def __init__(self, op_name: str, op_def: LinalgOpDef): self.op_name = op_name - self.model = model + self.op_def = op_def def __call__(self, *ins: Union[ir.Operation, ir.OpView, ir.Value], outs: StructuredOpOuts, **kwargs): @@ -73,7 +73,7 @@ f" of type bool but got {type(emit_generic)}") op_configs = LinalgOpConfig.from_linalg_op_def( - self.model, context=ir.Context.current) + self.op_def, context=ir.Context.current) if len(op_configs) != 1: # TODO: Support composite ops. @@ -97,7 +97,7 @@ return emit_named_structured_op( op_config.structured_op, self.op_name, - self.model.metadata.cpp_class_name, + self.op_def.metadata.cpp_class_name, *in_values, outs=out_values, **kwargs) @@ -121,7 +121,7 @@ # Camel case it. op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op" - tc_model = LinalgOpDef( + op_def = LinalgOpDef( name=op_name, cpp_class_name=op_class_name, doc=inspect.getdoc(dsl_func)) # Extract arguments and TensorDefs from the signature. @@ -130,7 +130,7 @@ for param_name, param in sig.parameters.items(): param_default = param.default if isinstance(param_default, (TensorDef, ScalarDef, IndexAttrDef)): - tc_model.add_operand(param_name, param_default.operand_def) + op_def.add_operand(param_name, param_default.operand_def) else: raise ValueError( f"@linalg_structured_op function parameters must be defaulted as " @@ -138,13 +138,13 @@ f"Found {param_name}: {param_default}") dsl_func_args.append(param_default) - # Invoke the DSL func to finish populating the model. - with bind_op_def(tc_model): + # Invoke the DSL func to finish populating the op definition. + with bind_op_def(op_def): dsl_func(*dsl_func_args) # TODO: The returned callable should be an IR emitter but that is not # upstreamed yet. - return DefinedOpCallable(op_name, tc_model) + return DefinedOpCallable(op_name, op_def) def implements(*interfaces: OpInterfaceDef): diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -23,6 +23,7 @@ "ValueList", ] +# Type aliases. ValueList = Union[Sequence[Value], OpResultList]