diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -8,7 +8,7 @@ represent actual op definitions (i.e. YAML). """ -from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple from mlir import ir as _ir @@ -50,7 +50,7 @@ self.visit_tensor_exprs(visit_affine_exprs) return results - def collect_uses(self, uses: Set["TensorUse"]): + def collect_tensor_uses(self, uses: Set["TensorUse"]): """Collects all TensorUses reachable through this expression.""" def visit_tensor_use(expr): @@ -68,14 +68,14 @@ self.visit_tensor_exprs(visit_index) - def collect_captures(self, captures: Set["CaptureDef"]): - """Collects all CaptureDefs reachable through this expression.""" + def collect_scalar_uses(self, uses: Set["ScalarDef"]): + """Collects all ScalarDefs reachable through this expression.""" - def visit_capture_def(expr): - if isinstance(expr, CaptureDef): - captures.add(expr) + def visit_scalar_def(expr): + if isinstance(expr, ScalarDef): + uses.add(expr) - self.visit_tensor_exprs(visit_capture_def) + self.visit_tensor_exprs(visit_scalar_def) def __add__(self, rhs: "TensorExpression") -> "TensorExpression": return PrimFn.add(self, rhs) @@ -101,19 +101,19 @@ TensorDef.__setitem__ """ - def __init__(self, tensor_def: "TensorDef", indices: Sequence[AffineExprDef]): - self.tensor_def = tensor_def + def __init__(self, operand_def: "OperandDef", + indices: Sequence[AffineExprDef]): + self.operand_def = operand_def self.indices = tuple(indices) def to_scalar_expression(self) -> ScalarExpression: - assert self.tensor_def.tensor_name is not None - return ScalarArg(self.tensor_def.tensor_name).expr() + return ScalarArg(self.tensor_name).expr() @property def tensor_name(self) -> str: - n = self.tensor_def.tensor_name - assert n is not None, "TensorDef not attached" - return n + name = self.operand_def.name + assert name is not None, "TensorDef not attached" + return name def __iadd__(self, rhs: TensorExpression) -> TensorExpression: return ReduceFn.add(*self._compute_reduce_dims(rhs))(rhs) @@ -133,40 +133,57 @@ return f"{self.tensor_name}[{', '.join([repr(i) for i in self.indices])}]" -class TensorDef: - """Bookkeeping of a single registered tensor, held in dict by name.""" +class OperandDef: + """Definition of a Tensor or Scalar operand passed to an operation.""" - def __init__(self, - type_var: TypeVar, - *shape: AffineExprDef, - indexing_map: Optional[_ir.AffineMap] = None, - output: bool = False): + def __init__(self, type_var: TypeVar, shape: Sequence[AffineExprDef], + scalar: bool, output: bool): if not isinstance(type_var, TypeVar): - raise ValueError(f"TensorDef requires a TypeVar. Got: {repr(type_var)}") + raise ValueError(f"OperandDef requires a TypeVar. Got: {repr(type_var)}") self.owner = None # type: Optional["LinalgOpDef"] self.type_var = type_var self.shape = shape - self.indexing_map = indexing_map + self.scalar = scalar self.output = output - self.tensor_name = None # type: Optional[str] + self.name = None # type: Optional[str] self.registered_index = -1 # type: int - @property - def rank(self) -> int: - """The rank of the tensor.""" - return len(self.shape) - - def attach(self, index: int, tensor_name: str, owner: "LinalgOpDef"): + def attach(self, index: int, name: str, owner: "LinalgOpDef"): if self.owner: - raise ValueError(f"TensorDef already registered with op: {self}") + raise ValueError(f"OperandDef already registered with op: {self}") self.registered_index = index - self.tensor_name = tensor_name + self.name = name self.owner = owner + def __hash__(self): + return hash(id(self)) + + def __repr__(self): + output = "OUTPUT " if self.output else "" + scalar = "SCALAR " if self.scalar else "" + return (f"{self.name}:OperandDef({output}{scalar}" + f"{repr(self.type_var)}, shape={self.shape})") + + +class TensorDef: + """Tensor operand definition. + + Tensor operands are indexed using the associated indexing_map when forwarded + to the body of the structured op. A unique name identifies the tensor operands + and an index determines their position in the operation's parameter list. + """ + + def __init__(self, + type_var: TypeVar, + *shape: AffineExprDef, + output: bool = False): + self.operand_def = OperandDef(type_var, shape, False, output) + def __getitem__(self, dims) -> TensorUse: - assert self.owner, "TensorDef is not attached to an op" + assert self.operand_def.owner, "TensorDef is not attached to an op" state = AffineBuildState( - global_state=self.owner._affine_state, allow_new_symbols=False) + global_state=self.operand_def.owner._affine_state, + allow_new_symbols=False) if not isinstance(dims, tuple): dims = (dims,) # Handle single subscript case. # Special case: (None) is a 0d-scalar use. @@ -179,7 +196,7 @@ raise KeyError( "A TensorDef can only be subscripted by a tuple of affine dims") exprs.append(expr_def) - return TensorUse(self, exprs) + return TensorUse(self.operand_def, exprs) def __setitem__(self, dims, value): """Creates a new 1:1 comprehension by binding this tensor to an expression. @@ -192,46 +209,28 @@ f"Got: {repr(value)}") use = self[dims] comp = Comprehension((use, value)) - self.owner.comprehensions.append(comp) + self.operand_def.owner.comprehensions.append(comp) - def __hash__(self): - return hash(id(self)) - def __repr__(self): - output = "OUTPUT " if self.output else "" - return (f"{self.tensor_name}:TensorDef({output}{repr(self.type_var)}, " - f"shape={self.shape})") - - -class CaptureDef(TensorExpression): - """Defines an SSA value captured by the operation. +class ScalarDef(TensorExpression): + """Scalar operand definition. - The captured SSA values are not indexed by the indexing_maps of the - structured op (as opposed to memrefs and tensors). A unique name - identifies the captures and an index determines their position the - operation's parameter list. + Scalar operands are forwarded to the body of the structured op as they are. + A unique name identifies the scalars and an index determines their position in + the operation's parameter list. """ def __init__(self, type_var: TypeVar): - if not isinstance(type_var, TypeVar): - raise ValueError(f"CaptureDef requires a TypeVar. Got: {repr(type_var)}") - self.owner = None # type: Optional["LinalgOpDef"] - self.type_var = type_var - self.capture_name = None # type: Optional[str] - self.registered_index = -1 # type: int + self.operand_def = OperandDef(type_var, (), True, False) - def attach(self, index: int, capture_name: str, owner: "LinalgOpDef"): - if self.owner: - raise ValueError(f"CaptureDef already registered with op: {self}") - self.registered_index = index - self.capture_name = capture_name - self.owner = owner + @property + def scalar_name(self) -> str: + name = self.operand_def.name + assert name is not None, "ScalarDef not attached" + return name def to_scalar_expression(self) -> ScalarExpression: - return ScalarCapture(self.capture_name).expr() - - def __repr__(self): - return (f"{self.capture_name}:CaptureDef({repr(self.type_var)})") + return ScalarArg(self.scalar_name).expr() class Comprehension: @@ -472,43 +471,34 @@ doc: Optional[str] = None): self.metadata = OpMetadataDef( name=name, cpp_class_name=cpp_class_name, doc=doc) - self.registered_tensors = dict() # type: Dict[str, TensorDef] - self.registered_captures = dict() # type: Dict[str, CaptureDef] + self.registered_operands = dict() # type: Dict[str, OperandDef] self.comprehensions = list() # type: List[Comprehension] self._affine_state = AffineBuildState() @property - def inputs(self) -> Sequence[TensorDef]: - return [t for t in self.registered_tensors.values() if not t.output] + def outputs(self) -> Sequence[OperandDef]: + return [ + operand for operand in self.registered_operands.values() + if operand.output + ] - @property - def outputs(self) -> Sequence[TensorDef]: - return [t for t in self.registered_tensors.values() if t.output] - - def add_tensor(self, tensor_name: str, tensor: TensorDef): - """Registers a tensor.""" - if tensor_name in self.registered_tensors: - raise ValueError(f"Tensor {tensor_name} is already registered " - f"to {self.registered_tensors['tensor_name']}") - tensor.attach(len(self.registered_tensors), tensor_name, self) - self.registered_tensors[tensor_name] = tensor - - def add_capture(self, capture_name: str, capture: CaptureDef): - """Registers a capture.""" - if capture_name in self.registered_captures: - raise ValueError(f"Capture {capture_name} is already registered " - f"to {self.registered_captures['capture_name']}") - capture.attach(len(self.registered_captures), capture_name, self) - self.registered_captures[capture_name] = capture + def add_operand(self, name: str, operand: OperandDef): + """Registers an operand.""" + if name in self.registered_operands: + raise ValueError(f"The operand {name} is already registered " + f"to {self.registered_operands['name']}") + if not operand.output and self.outputs: + raise ValueError(f"The operand {name} is an input registered after " + f"the output {self.outputs[-1]}") + operand.attach(len(self.registered_operands), name, self) + self.registered_operands[name] = operand def __repr__(self): lines = [ f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_class_name}," ] - for name, tensor in self.registered_tensors.items(): - lines.append(f" {tensor}") - for name, capture in self.registered_captures.items(): - lines.append(f" {capture}") + for name, operand in self.registered_operands.items(): + lines.append(f" {operand}") if self.comprehensions: lines[-1] += " {" for comprehension in self.comprehensions: diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py @@ -18,11 +18,7 @@ from .comprehension import * from .yaml_helper import * -__all__ = [ - "LinalgStructuredOpConfig", - "LinalgOpConfig", - "TensorDefConfig", -] +__all__ = ["LinalgStructuredOpConfig", "LinalgOpConfig", "OperandDefConfig"] def _serialize_affine_map(affine_map: _ir.AffineMap) -> str: @@ -43,49 +39,42 @@ return f"Use({self.tensor_use}, indexing_map={self.indexing_map})" -class TensorDefConfig(YAMLObject): - """Wrapper around a TensorDef with additional context-bound state.""" - yaml_tag = "LinalgTensorDef" +class OperandDefConfig(YAMLObject): + """Wrapper containing an operand definition with additional state.""" + yaml_tag = "!LinalgOperandDefConfig" - def __init__(self, tensor_def: TensorDef, shape_map: _ir.AffineMap): - self.tensor_def = tensor_def - self.shape_map = shape_map + def __init__(self, + operand_def: OperandDef, + shape_map: Optional[_ir.AffineMap] = None): + self.operand_def = operand_def + self.shape_map = shape_map # type: Optional[_ir.AffineMap] self.indexing_map = None # type: Optional[_ir.AffineMap] @property - def usage(self) -> str: - if self.tensor_def.output: - return "output" - else: - return "input" - - def to_yaml_custom_dict(self): - return dict( - name=self.tensor_def.tensor_name, - usage=self.usage, - shape=_serialize_affine_map(self.shape_map), - element_type_var=self.tensor_def.type_var.name, - ) - - def __repr__(self): - return f"Def({self.tensor_def}, shape_map={self.shape_map}, indexing_map={self.indexing_map})" + def name(self) -> str: + return self.operand_def.name + @property + def type_var(self) -> TypeVar: + return self.operand_def.type_var -class CaptureDefConfig(YAMLObject): - """Wrapper around a CaptureDef.""" - yaml_tag = "LinalgCaptureDef" - - def __init__(self, capture_def: CaptureDef): - self.capture_def = capture_def + @property + def usage(self) -> str: + if self.operand_def.output: + return "output" + return "input" def to_yaml_custom_dict(self): - return dict( - name=self.capture_def.capture_name, - type_var=self.capture_def.type_var.name, - ) + self_dict = dict(name=self.name) + self_dict["usage"] = self.usage + if not self.operand_def.scalar: + self_dict["shape"] = _serialize_affine_map(self.shape_map) + self_dict["type_var"] = self.type_var.name + return self_dict def __repr__(self): - return f"Def({self.capture_def})" + return (f"OperandDefConfig({self.operand_def}, " + f"shape_map={self.shape_map}, indexing_map={self.indexing_map})") class LinalgIndexingMapsConfig(YAMLObject): @@ -124,67 +113,73 @@ self.context = context if context is not None else _ir.Context() self.affine_state = AffineBuildState() self.writes = list() # type: List[Tuple[TensorUse, TensorExpression]] - self.tensor_args = dict() # type: Dict[TensorDef, TensorDefConfig] - self.capture_args = dict() # type: Dict[CaptureDef, CaptureDefConfig] + self.operands = dict() # type: Dict[OperandDef, OperandDefConfig] self.uses = dict() # type: Dict[TensorUse, TensorUseConfig] # Compute the ordered set of writes and collect the tensor, capture, and # index uses. - collected_uses = set() - collected_captures = set() + collected_tensor_uses = set() + collected_scalar_uses = set() collected_indices = set() for write_use, read_use in zip(comprehension.definitions, comprehension.values): self.writes.append((write_use, read_use)) for write_use, read_use in self.writes: - collected_uses.add(write_use) - read_use.collect_uses(collected_uses) - read_use.collect_captures(collected_captures) + collected_tensor_uses.add(write_use) + read_use.collect_tensor_uses(collected_tensor_uses) + read_use.collect_scalar_uses(collected_scalar_uses) read_use.collect_indices(collected_indices) # Need to add all definitions before uses, so process twice. - for use in collected_uses: - self.add_tensor_arg(use.tensor_def) - for capture in collected_captures: - self.add_capture_arg(capture) - for use in collected_uses: - self.add_use(use) + for use in collected_tensor_uses: + self.add_operand(use.operand_def) + for use in collected_scalar_uses: + self.add_operand(use.operand_def) + for use in collected_tensor_uses: + self.add_tensor_use(use) # Now normalize all defs and uses indexing maps now that full count of # dims and symbols are known. for cuse in self.uses.values(): cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map) - for cdef in self.tensor_args.values(): - cdef.shape_map = self._normalize_affine_map( - cdef.shape_map, with_dims=False) + for cdef in self.operands.values(): + if not cdef.operand_def.scalar: + cdef.shape_map = self._normalize_affine_map( + cdef.shape_map, with_dims=False) # Now for each write use, propagate the indexing maps from the use to the # tensor, ensuring that there are not conflicts. for write_use, _ in self.writes: - write_tensor_def = self.tensor_args[write_use.tensor_def] - if write_tensor_def.indexing_map: + write_tensor_config = self.operands[write_use.operand_def] + if write_tensor_config.indexing_map: raise ValueError( - f"Unexpected multi-write to a single tensor: {write_tensor_def}") - write_tensor_def.indexing_map = self.uses[write_use].indexing_map + f"Unexpected multi-write to a single tensor: {write_tensor_config}") + write_tensor_config.indexing_map = self.uses[write_use].indexing_map # For each read use, propagate the indexing maps from the use to the # tensor, ensuring that there are not conflicts. for _, read_expr in self.writes: read_uses = set() # type: Set[TensorUse] - read_expr.collect_uses(read_uses) + read_expr.collect_tensor_uses(read_uses) for read_use in read_uses: - read_tensor_def = self.tensor_args[read_use.tensor_def] - if (read_tensor_def.indexing_map and - read_tensor_def.indexing_map != self.uses[read_use].indexing_map): + read_operand_config = self.operands[read_use.operand_def] + if (read_operand_config.indexing_map and + read_operand_config.indexing_map != + self.uses[read_use].indexing_map): raise ValueError( f"Unexpected multi-read of a tensor with different accesses:" - f"{read_tensor_def} vs {read_use}") - read_tensor_def.indexing_map = self.uses[read_use].indexing_map + f"{read_operand_config} vs {read_use}") + read_operand_config.indexing_map = self.uses[read_use].indexing_map + + # Set the indexing map of all scalar uses to the empty map. + for operand_config in self.operands.values(): + if operand_config.operand_def.scalar: + operand_config.indexing_map = self._create_empty_affine_map() # Sanity check that all defs have an indexing map. - assert all(d.indexing_map for d in self.tensor_args.values()), ( - f"Missing indexing map on TensorDef: {self.tensor_args}") + assert all(d.indexing_map for d in self.operands.values()), ( + f"Missing indexing map on OperandConfigDef: {self.operands}") # Collect reduction dims and ensure all the same. all_reduction_dims = set(comprehension.all_reduction_dims) @@ -209,22 +204,10 @@ ] @property - def ordered_tensor_args(self) -> Sequence[TensorDefConfig]: + def ordered_operands(self) -> Sequence[OperandDefConfig]: return sorted( - self.tensor_args.values(), - key=lambda tdc: tdc.tensor_def.registered_index) - - @property - def ordered_tensor_uses(self) -> Sequence[TensorUseConfig]: - return sorted( - self.uses.values(), - key=lambda tuc: tuc.tensor_use.tensor_def.registered_index) - - @property - def ordered_capture_args(self) -> Sequence[CaptureDefConfig]: - return sorted( - self.capture_args.values(), - key=lambda cdc: cdc.capture_def.registered_index) + self.operands.values(), + key=lambda operand: operand.operand_def.registered_index) @property def ordered_dims(self) -> Sequence[Tuple[str, int]]: @@ -238,7 +221,7 @@ @property def indexing_maps(self) -> Sequence[_ir.AffineMap]: - return [use.indexing_map for use in self.ordered_tensor_uses] + return [d.indexing_map for d in self.ordered_operands] @property def iterator_types(self) -> Sequence[str]: @@ -251,23 +234,25 @@ return [get_type(*dim) for dim in self.ordered_dims] - def add_tensor_arg(self, tensor_def: TensorDef): - if tensor_def in self.tensor_args: + def add_operand(self, operand_def: OperandDef): + if operand_def in self.operands: + return + if operand_def.scalar: + self.operands[operand_def] = OperandDefConfig(operand_def) return with self.context: local_state = AffineBuildState( global_state=self.affine_state, allow_new_dims=False) exprs = [] - for expr in tensor_def.shape: + for expr in operand_def.shape: exprs.append(expr.build(state=local_state)) assert local_state.local_dim_count == 0 - indexing_map = _ir.AffineMap.get( + shape_map = _ir.AffineMap.get( dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs) + def_config = OperandDefConfig(operand_def, shape_map) + self.operands[operand_def] = def_config - def_config = TensorDefConfig(tensor_def, indexing_map) - self.tensor_args[tensor_def] = def_config - - def add_use(self, tensor_use: TensorUse): + def add_tensor_use(self, tensor_use: TensorUse): if tensor_use in self.uses: return with self.context: @@ -285,11 +270,13 @@ use_config = TensorUseConfig(tensor_use, indexing_map) self.uses[tensor_use] = use_config - def add_capture_arg(self, capture_def: CaptureDef): - if capture_def in self.capture_args: - return - def_config = CaptureDefConfig(capture_def) - self.capture_args[capture_def] = def_config + def _create_empty_affine_map(self) -> _ir.AffineMap: + """Create an affine map with an empty range.""" + with self.context: + return _ir.AffineMap.get( + dim_count=self.affine_state.dim_count, + symbol_count=self.affine_state.symbol_count, + exprs=list()) def _normalize_affine_map(self, affine_map: _ir.AffineMap, @@ -302,9 +289,7 @@ exprs=list(affine_map.results)) def to_yaml_custom_dict(self): - self_dict = dict(args=self.ordered_tensor_args) - if self.ordered_capture_args: - self_dict["captures"] = self.ordered_capture_args + self_dict = dict(args=self.ordered_operands) # TODO: Refactor the hierarchy internally when supporting more # than static (preserving this serialized form). self_dict["indexing_maps"] = LinalgIndexingMapsConfig( @@ -315,11 +300,8 @@ def __repr__(self): lines = [f"LinalgGenericOpConfig(reduction_dims={self.reduction_dims},"] - lines.append("tensor_args=[") - for def_config in self.ordered_tensor_args: - lines.append(f" {repr(def_config)}") - lines.append("], capture_args=[") - for def_config in self.ordered_capture_args: + lines.append("operands=[") + for def_config in self.ordered_operands: lines.append(f" {repr(def_config)}") lines.append("], indexing_maps=[") for m in self.indexing_maps: diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -53,8 +53,8 @@ False, a named form is emitted (which must have been built in to the compiler). """ - op_configs = LinalgOpConfig.from_linalg_op_def(self.model, - context=ir.Context.current) + op_configs = LinalgOpConfig.from_linalg_op_def( + self.model, context=ir.Context.current) if len(op_configs) != 1: # TODO: Support composite ops. @@ -63,8 +63,9 @@ ctx = ir.Context.current linalgDialect = ctx.get_dialect_descriptor("linalg") - fully_qualified_name = 'linalg.' + self.op_name - emit_generic = (emit_generic or not ctx.is_registered_operation(fully_qualified_name)) + fully_qualified_name = "linalg." + self.op_name + emit_generic = ( + emit_generic or not ctx.is_registered_operation(fully_qualified_name)) op_config = op_configs[0] if op_config.structured_op: @@ -72,9 +73,9 @@ return emit_generic_structured_op(op_config.structured_op, *args, **kwargs) else: - return emit_named_structured_op( - op_config.structured_op, self.op_name, - self.model.metadata.cpp_class_name, *args, **kwargs) + return emit_named_structured_op(op_config.structured_op, self.op_name, + self.model.metadata.cpp_class_name, + *args, **kwargs) raise NotImplementedError( f"Emission of linalg op type not supported: {op_config}") @@ -86,9 +87,8 @@ op_class_name=None) -> DefinedOpCallable: if dsl_func is None: # Curry the keyword args in for delayed application. - return functools.partial(tc_def_op, - op_name=op_name, - op_class_name=op_class_name) + return functools.partial( + tc_def_op, op_name=op_name, op_class_name=op_class_name) # Determine default names by introspecting the function. if op_name is None: op_name = dsl_func.__name__ @@ -96,9 +96,8 @@ # Camel case it. op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op" - tc_model = LinalgOpDef(name=op_name, - cpp_class_name=op_class_name, - doc=inspect.getdoc(dsl_func)) + tc_model = LinalgOpDef( + name=op_name, cpp_class_name=op_class_name, doc=inspect.getdoc(dsl_func)) # Extract arguments and TensorDefs from the signature. dsl_func_args = list() @@ -106,12 +105,12 @@ for param_name, param in sig.parameters.items(): param_default = param.default if isinstance(param_default, TensorDef): - tc_model.add_tensor(param_name, param_default) - elif isinstance(param_default, CaptureDef): - tc_model.add_capture(param_name, param_default) + tc_model.add_operand(param_name, param_default.operand_def) + elif isinstance(param_default, ScalarDef): + tc_model.add_operand(param_name, param_default.operand_def) else: raise ValueError(f"@tc_def_op function parameters must be defaulted as " - f"TensorDef(...) or CaptureDef(...): Found {param_name}" + f"TensorDef(...) or ScalarDef(...): Found {param_name}" f": {param_default}") dsl_func_args.append(param_default) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -29,20 +29,15 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, - *ins: Value, outs: Sequence[Value], - captures: Sequence[Value]): - all_arg_defs = op_config.ordered_tensor_args + *ins: Value, outs: Sequence[Value]): + all_arg_defs = op_config.ordered_operands in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "input"] out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "output"] - capture_arg_defs = op_config.ordered_capture_args # Verify outs and captures are sequences. if not isinstance(outs, Sequence): raise ValueError(f"Expected named argument outs to have type Sequence " f"but got {type(outs)}") - if not isinstance(captures, Sequence): - raise ValueError(f"Expected named argument captures to have type Sequence " - f"but got {type(outs)}") # Arity validation. if len(ins) != len(in_arg_defs): @@ -51,9 +46,6 @@ if outs and len(outs) != len(out_arg_defs): raise ValueError(f"Expected {len(out_arg_defs)} outputs but got " f"{len(outs)} for {op_config}") - if captures and len(captures) != len(capture_arg_defs): - raise ValueError(f"Expected {len(capture_arg_defs)} captures but got " - f"{len(captures)} for {op_config}") outs, out_types = _infer_structured_outs(op_config, in_arg_defs, ins, out_arg_defs, outs) @@ -68,18 +60,10 @@ type_mapping["I64"] = IntegerType.get_signless(64) # Extract type vars for input/output based types. - for arg_def, arg_element_type in zip( - in_arg_defs + out_arg_defs, - _get_shaped_element_types_from_values(*ins, *outs)): - _add_type_mapping(arg_def.tensor_def.type_var.name, arg_element_type, - type_mapping) - - # Extract type vars for captures and compute capture argument mapping. - capture_arg_mapping = dict() # type: Dict[str, Value] - for arg_def, capture_value in zip(capture_arg_defs, captures): - _add_type_mapping(arg_def.capture_def.type_var.name, capture_value.type, - type_mapping) - capture_arg_mapping[arg_def.capture_def.capture_name] = capture_value + block_arg_types = list() # type: List[Type] + for arg_def, arg_element_type in zip(in_arg_defs + out_arg_defs, + _get_types_from_values(*ins, *outs)): + _add_type_mapping(arg_def, arg_element_type, type_mapping, block_arg_types) # Emit the generic op. # TODO: Support emission of pure memref form. @@ -94,18 +78,16 @@ [StringAttr.get(s) for s in op_config.iterator_types]) return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, - type_mapping, capture_arg_mapping, indexing_maps_attr, - iterator_types_attr) + type_mapping, indexing_maps_attr, iterator_types_attr, + block_arg_types) def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, - outs: Sequence[Value] = (), - captures: Sequence[Value] = ()): + outs: Sequence[Value] = ()): all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ - capture_arg_mapping, indexing_maps_attr, iterator_types_attr = \ - prepare_common_structured_op(op_config, *ins, outs = outs, - captures=captures) + indexing_maps_attr, iterator_types_attr, block_arg_types = \ + prepare_common_structured_op(op_config, *ins, outs = outs) generic_op = linalg.GenericOp( result_tensors=result_types, @@ -117,16 +99,14 @@ library_call=None) # TODO: Make optional. # Construct the body. - block_arg_names = _get_tensor_def_names(*in_arg_defs, *out_arg_defs) - block_arg_types = _get_shaped_element_types_from_values(*ins, *outs) + block_arg_names = _get_operand_def_names(*in_arg_defs, *out_arg_defs) block = generic_op.regions[0].blocks.append(*block_arg_types) block_arg_mapping = dict(zip(block_arg_names, block.arguments)) with InsertionPoint(block): - body_builder = _BodyBuilder(type_mapping, block_arg_mapping, - capture_arg_mapping) + body_builder = _BodyBuilder(type_mapping, block_arg_mapping) for assignment in op_config.assignments: body_builder.assign(assignment) - body_builder.yield_outputs(*_get_tensor_def_names(*out_arg_defs)) + body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs)) if len(result_types) == 1: return generic_op.result @@ -138,12 +118,10 @@ op_name: str, op_class_name: str, *ins: Value, - outs: Sequence[Value] = (), - captures: Sequence[Value] = ()): + outs: Sequence[Value] = ()): all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ - capture_arg_mapping, indexing_maps_attr, iterator_types_attr = \ - prepare_common_structured_op(op_config, *ins, outs = outs, - captures = captures) + indexing_maps_attr, iterator_types_attr, block_arg_types = \ + prepare_common_structured_op(op_config, *ins, outs = outs) # If we get here, there must exist a builtin class `op_class_name`. ctx = Context.current @@ -173,11 +151,9 @@ """Constructs a structured op body by evaluating assignments.""" def __init__(self, type_mapping: Dict[str, Type], - block_arg_mapping: Dict[str, Value], - capture_arg_mapping: Dict[str, Value]): + block_arg_mapping: Dict[str, Value]): self.type_mapping = type_mapping self.block_arg_mapping = block_arg_mapping - self.capture_arg_mapping = capture_arg_mapping self.yield_mapping = dict() # type: Dict[str, Value] def assign(self, assignment: ScalarAssign): @@ -194,13 +170,6 @@ except KeyError: raise ValueError(f"Argument {expr.scalar_arg.arg} is not bound for " f"this structured op.") - elif expr.scalar_capture: - try: - return self.capture_arg_mapping[expr.scalar_capture.capture] - except KeyError: - raise ValueError( - f"Capture {expr.scalar_capture.capture} is not bound for " - f"this structured op.") elif expr.scalar_const: value_attr = Attribute.parse(expr.scalar_const.value) return std.ConstantOp(value_attr.type, value_attr).result @@ -229,7 +198,7 @@ to_type = self.type_mapping[type_var_name] except KeyError: raise ValueError(f"Unbound type variable '{type_var_name}' (" - f"expected one of {self.type_mappings.keys()}") + f"expected one of {self.type_mapping.keys()}") if operand.type == to_type: return operand if _is_integer_type(to_type): @@ -300,9 +269,9 @@ def _infer_structured_outs(op_config: LinalgStructuredOpConfig, - in_arg_defs: Sequence[TensorDefConfig], + in_arg_defs: Sequence[OperandDefConfig], ins: Sequence[Value], - out_arg_defs: Sequence[TensorDefConfig], + out_arg_defs: Sequence[OperandDefConfig], outs: Sequence[Value]): """Infers implicit outs and output types. @@ -319,28 +288,34 @@ "structured ops") -def _get_shaped_element_types_from_values(*values: Value) -> Sequence[Type]: +def _get_types_from_values(*values: Value) -> Sequence[Type]: types = [] for v in values: - try: - t = ShapedType(v.type) - except Exception as e: - raise ValueError(f"Expected ShapedType but got {v}") from e - types.append(t.element_type) + types.append(v.type) return types -def _get_tensor_def_names( - *tensor_def_configs: TensorDefConfig) -> Sequence[str]: - return [tdc.tensor_def.tensor_name for tdc in tensor_def_configs] +def _get_operand_def_names(*operand_configs: OperandDefConfig) -> Sequence[str]: + return [odc.operand_def.name for odc in operand_configs] -def _add_type_mapping(name: str, type: Type, type_mapping: Dict[str, Type]): +def _add_type_mapping(operand_config: OperandDefConfig, operand_type: Type, + type_mapping: Dict[str, Type], + block_arg_types: Sequence[Type]): + element_or_self_type = operand_type + # Get the element type for tensor operands and the type itself for scalars. + if operand_config.operand_def.shape: + try: + element_or_self_type = ShapedType(operand_type).element_type + except Exception as e: + raise ValueError(f"Expected ShapedType but got {operand_type}") from e + name = operand_config.type_var.name if name in type_mapping: - if type_mapping[name] != type: + if type_mapping[name] != element_or_self_type: raise ValueError(f"Cannot overwrite type mapping {name} = " - f"{type_mapping[name]} by type {type}") - type_mapping[name] = type + f"{type_mapping[name]} by type {element_or_self_type}") + type_mapping[name] = element_or_self_type + block_arg_types.append(element_or_self_type) def _is_floating_point_type(t: Type) -> bool: diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py @@ -22,7 +22,6 @@ "ScalarAssign", "ScalarApplyFn", "ScalarArg", - "ScalarCapture", "ScalarConst", "ScalarIndex", "ScalarExpression", @@ -57,19 +56,6 @@ return f"(ScalarArg({self.arg})" -class ScalarCapture: - """A type of ScalarExpression that references a named capture.""" - - def __init__(self, capture: str): - self.capture = capture - - def expr(self) -> "ScalarExpression": - return ScalarExpression(scalar_capture=self) - - def __repr__(self): - return f"(ScalarCapture({self.capture})" - - class ScalarConst: """A type of ScalarExpression representing a constant.""" @@ -116,7 +102,6 @@ Can be one of: - ScalarApplyFn - ScalarArg - - ScalarCapture - ScalarConst - ScalarIndex - ScalarSymbolicCast @@ -126,18 +111,15 @@ def __init__(self, scalar_apply: Optional[ScalarApplyFn] = None, scalar_arg: Optional[ScalarArg] = None, - scalar_capture: Optional[ScalarCapture] = None, scalar_const: Optional[ScalarConst] = None, scalar_index: Optional[ScalarIndex] = None, symbolic_cast: Optional[ScalarSymbolicCast] = None): - if (bool(scalar_apply) + bool(scalar_arg) + bool(scalar_capture) + - bool(scalar_const) + bool(scalar_index) + bool(symbolic_cast)) != 1: - raise ValueError( - "One of 'scalar_apply', 'scalar_arg', 'scalar_capture', 'scalar_const', " - "'scalar_index', 'symbolic_cast' must be specified") + if (bool(scalar_apply) + bool(scalar_arg) + bool(scalar_const) + + bool(scalar_index) + bool(symbolic_cast)) != 1: + raise ValueError("One of 'scalar_apply', 'scalar_arg', 'scalar_const', " + "'scalar_index', 'symbolic_cast' must be specified") self.scalar_apply = scalar_apply self.scalar_arg = scalar_arg - self.scalar_capture = scalar_capture self.scalar_const = scalar_const self.scalar_index = scalar_index self.symbolic_cast = symbolic_cast @@ -151,8 +133,6 @@ )) elif self.scalar_arg: return dict(scalar_arg=self.scalar_arg.arg) - elif self.scalar_capture: - return dict(scalar_capture=self.scalar_capture.capture) elif self.scalar_const: return dict(scalar_const=self.scalar_const.value) elif self.scalar_index: diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -75,7 +75,11 @@ @linalg_structured_op -def fill_rng_2d(O=TensorDef(T, S.M, S.N, output=True)): +def fill_rng_2d( + min=ScalarDef(F64), + max=ScalarDef(F64), + seed=ScalarDef(I32), + O=TensorDef(T, S.M, S.N, output=True)): """Fills the output tensor with pseudo random numbers. The operation generations pseudo random numbers using a linear congruential @@ -85,13 +89,7 @@ and runs them in parallel. The seed operand and the indices of the data element seed the random number generation. The min and max operands limit the range of the generated random numbers. - - Note: The captures are hard-coded till there is capture support on the C++ - side. """ - min = cast(F64, const(-1000)) - max = cast(F64, const(+1000)) - seed = cast(I32, const(42)) multiplier = cast(I32, const(1103515245)) increment = cast(I32, const(12345)) rand1 = (cast(I32, index(D.m)) + seed) * multiplier + increment diff --git a/mlir/test/python/dialects/linalg/opdsl/arguments.py b/mlir/test/python/dialects/linalg/opdsl/arguments.py --- a/mlir/test/python/dialects/linalg/opdsl/arguments.py +++ b/mlir/test/python/dialects/linalg/opdsl/arguments.py @@ -9,15 +9,15 @@ # CHECK: name: A # CHECK: usage: input # CHECK: shape: affine_map<()[s0, s1, s2] -> (s0, s2)> -# CHECK: element_type_var: T +# CHECK: type_var: T # CHECK: name: B # CHECK: usage: input # CHECK: shape: affine_map<()[s0, s1, s2] -> (s2, s1)> -# CHECK: element_type_var: T +# CHECK: type_var: T # CHECK: name: C # CHECK: usage: output # CHECK: shape: affine_map<()[s0, s1, s2] -> (s0, s1)> -# CHECK: element_type_var: U +# CHECK: type_var: U @linalg_structured_op def matmul( A=TensorDef(T, S.M, S.K), @@ -28,10 +28,11 @@ # CHECK: --- # CHECK-LABEL: fill -# CHECK: captures: -# CHECK: - ! -# CHECK: name: value -# CHECK: type_var: T +# CHECK: args: +# CHECK: name: value +# CHECK: usage: input +# CHECK-NOT: shape: +# CHECK: type_var: T @linalg_structured_op -def fill(O=TensorDef(T, S.M, S.K, output=True), value=CaptureDef(T)): +def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)): O[D.m, D.n] = value diff --git a/mlir/test/python/dialects/linalg/opdsl/assignments.py b/mlir/test/python/dialects/linalg/opdsl/assignments.py --- a/mlir/test/python/dialects/linalg/opdsl/assignments.py +++ b/mlir/test/python/dialects/linalg/opdsl/assignments.py @@ -82,7 +82,7 @@ # CHECK: assignments: # CHECK: - # CHECK: arg: O -# CHECK: scalar_capture: value +# CHECK: scalar_arg: value @linalg_structured_op -def fill(O=TensorDef(T, S.M, S.K, output=True), value=CaptureDef(T)): +def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)): O[D.m, D.n] = value diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py --- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py @@ -26,10 +26,10 @@ @linalg_structured_op def fill_rng( - O=TensorDef(T, S.M, S.N, output=True), - min=CaptureDef(F64), - max=CaptureDef(F64), - seed=CaptureDef(I32)): + min=ScalarDef(F64), + max=ScalarDef(F64), + seed=ScalarDef(I32), + O=TensorDef(T, S.M, S.N, output=True)): multiplier = cast(I32, const(1103515245)) increment = cast(I32, const(12345)) rand1 = (cast(I32, index(D.m)) + seed) * multiplier + increment @@ -159,7 +159,7 @@ return matmul_poly(lhs, rhs, outs=[init_result]) # CHECK-LABEL: @test_fill_rng - # CHECK-SAME: %{{.*}} tensor<4x16xi32>, %[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32 + # CHECK: ^{{.*}}(%[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32, %{{.*}} # CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index # CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index # CHECK-DAG: %[[IDX0_CAST:.+]] = index_cast %[[IDX0]] : index to i32 @@ -178,10 +178,10 @@ # CHECK-DAG: %[[RND4:.+]] = mulf %{{.+}}, %[[FACT]] : f64 # CHECK-DAG: %[[RND5:.+]] = addf %[[RND4]], %[[MIN]] : f64 # CHECK-DAG: %{{.*}} = fptosi %[[RND5]] : f64 to i32 - @builtin.FuncOp.from_py_func( - RankedTensorType.get((4, 16), i32), f64, f64, i32) - def test_fill_rng(init_result, min, max, seed): - return fill_rng(outs=[init_result], captures=[min, max, seed]) + @builtin.FuncOp.from_py_func(f64, f64, i32, + RankedTensorType.get((4, 16), i32)) + def test_fill_rng(min, max, seed, init_result): + return fill_rng(min, max, seed, outs=[init_result]) print(module) diff --git a/mlir/test/python/dialects/linalg/opsrun.py b/mlir/test/python/dialects/linalg/opsrun.py --- a/mlir/test/python/dialects/linalg/opsrun.py +++ b/mlir/test/python/dialects/linalg/opsrun.py @@ -43,9 +43,12 @@ fill_boiler = """ func @main() -> i32 attributes {llvm.emit_c_interface} { %O = memref.alloc() : memref<4x16xi32> + %min = constant -1000.0 : f64 + %max = constant 1000.0 : f64 + %seed = constant 42 : i32 - call @fill_on_buffers(%O) : - (memref<4x16xi32>) -> () + call @fill_on_buffers(%min, %max, %seed, %O) : + (f64, f64, i32, memref<4x16xi32>) -> () %c0 = constant 0 : index %0 = memref.load %O[%c0, %c0] : memref<4x16xi32> @@ -128,33 +131,6 @@ test_matmul_generic() -def test_fill_builtin(): - with Context() as ctx, Location.unknown(): - module = Module.create() - f64 = F64Type.get() - i32 = IntegerType.get_signless(32) - with InsertionPoint(module.body): - - @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), i32)) - def fill_on_buffers(out): - linalg.fill_rng_2d(outs=[out]) - - execution_engine = ExecutionEngine(transform(module, fill_boiler)) - - # TODO: FFI-based solution to allow testing and printing with python code. - # Prepare arguments: one result i32. - # Arguments must be passed as pointers. - c_int_p = ctypes.c_int * 1 - res = c_int_p(-1) - execution_engine.invoke("main", res) - - log("RESULT: ", res[0]) - # CHECK: RESULT: -480 - - -test_fill_builtin() - - def test_fill_generic(): with Context() as ctx, Location.unknown(): module = Module.create() @@ -162,9 +138,9 @@ i32 = IntegerType.get_signless(32) with InsertionPoint(module.body): - @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), i32)) - def fill_on_buffers(out): - linalg.fill_rng_2d(outs=[out]) + @builtin.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32)) + def fill_on_buffers(min, max, seed, out): + linalg.fill_rng_2d(min, max, seed, outs=[out], emit_generic=True) execution_engine = ExecutionEngine(transform(module, fill_boiler))