diff --git a/mlir/include/mlir-c/AffineMap.h b/mlir/include/mlir-c/AffineMap.h
--- a/mlir/include/mlir-c/AffineMap.h
+++ b/mlir/include/mlir-c/AffineMap.h
@@ -169,6 +169,13 @@
 MLIR_CAPI_EXPORTED MlirAffineMap
 mlirAffineMapGetMinorSubMap(MlirAffineMap affineMap, intptr_t numResults);
 
+/// Apply AffineExpr::replace(`expression`, `replacement`) to each of the
+/// results and return a new AffineMap with the new results and the specified
+/// number of dims and symbols.
+MLIR_CAPI_EXPORTED MlirAffineMap mlirAffineMapReplace(
+    MlirAffineMap affineMap, MlirAffineExpr expression,
+    MlirAffineExpr replacement, intptr_t numResultDims, intptr_t numResultSyms);
+
 /// Returns the simplified affine map resulting from dropping the symbols that
 /// do not appear in any of the individual maps in `affineMaps`.
 /// Asserts that all maps in `affineMaps` are normalized to the same number of
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -654,6 +654,14 @@
                  mlirAffineMapGetMinorSubMap(self, nResults);
              return PyAffineMap(self.getContext(), affineMap);
            })
+      .def("replace",
+           [](PyAffineMap &self, PyAffineExpr &expression,
+              PyAffineExpr &replacement, intptr_t numResultDims,
+              intptr_t numResultSyms) {
+             MlirAffineMap affineMap = mlirAffineMapReplace(
+                 self, expression, replacement, numResultDims, numResultSyms);
+             return PyAffineMap(self.getContext(), affineMap);
+           })
       .def_property_readonly(
           "is_permutation",
           [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
diff --git a/mlir/lib/CAPI/IR/AffineMap.cpp b/mlir/lib/CAPI/IR/AffineMap.cpp
--- a/mlir/lib/CAPI/IR/AffineMap.cpp
+++ b/mlir/lib/CAPI/IR/AffineMap.cpp
@@ -138,6 +138,15 @@
   return wrap(unwrap(affineMap).getMinorSubMap(numResults));
 }
 
+MlirAffineMap mlirAffineMapReplace(MlirAffineMap affineMap,
+                                   MlirAffineExpr expression,
+                                   MlirAffineExpr replacement,
+                                   intptr_t numResultDims,
+                                   intptr_t numResultSyms) {
+  return wrap(unwrap(affineMap).replace(unwrap(expression), unwrap(replacement),
+                                        numResultDims, numResultSyms));
+}
+
 void mlirAffineMapCompressUnusedSymbols(
     MlirAffineMap *affineMaps, intptr_t size, void *result,
     void (*populateResult)(void *res, intptr_t idx, MlirAffineMap m)) {
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -9,6 +9,7 @@
 """
 
 from typing import Any, Dict, List, Optional, Sequence, Set, Tuple
+from enum import Enum
 
 from mlir import ir as _ir
 
@@ -133,18 +134,31 @@
     return f"{self.tensor_name}[{', '.join([repr(i) for i in self.indices])}]"
 
 
+class OperandKind(Enum):
+  InputTensor = 0
+  Scalar = 1
+  OutputTensor = 2
+  Attribute = 3
+
+
 class OperandDef:
-  """Definition of a Tensor or Scalar operand passed to an operation."""
+  """Definition of an operand passed to an operation.
+
+  Keeps the meta information of tensor, scalar, and attribute operands and
+  provides the shared registration functionality.
+  """
 
-  def __init__(self, type_var: TypeVar, shape: Sequence[AffineExprDef],
-               scalar: bool, output: bool):
+  def __init__(self,
+               kind: OperandKind,
+               type_var: TypeVar,
+               size_exprs: Optional[Sequence[AffineExprDef]] = None):
     if not isinstance(type_var, TypeVar):
-      raise ValueError(f"OperandDef requires a TypeVar. Got: {repr(type_var)}")
+      raise ValueError(
+          f"OperandDef requires a TypeVar but got {repr(type_var)}")
     self.owner = None  # type: Optional["LinalgOpDef"]
     self.type_var = type_var
-    self.shape = shape
-    self.scalar = scalar
-    self.output = output
+    self.size_exprs = size_exprs
+    self.kind = kind
     self.name = None  # type: Optional[str]
     self.registered_index = -1  # type: int
@@ -159,10 +173,8 @@
     return hash(id(self))
 
   def __repr__(self):
-    output = "OUTPUT " if self.output else ""
-    scalar = "SCALAR " if self.scalar else ""
-    return (f"{self.name}:OperandDef({output}{scalar}"
-            f"{repr(self.type_var)}, shape={self.shape})")
+    return (f"{self.name}:OperandDef(kind={self.kind.name}, "
+            f"type={repr(self.type_var)}, size_exprs={self.size_exprs})")
 
 
 class TensorDef:
@@ -170,14 +182,17 @@
 
   Tensor operands are indexed using the associated indexing_map when forwarded
   to the body of the structured op. A unique name identifies the tensor operands
-  and an index determines their position in the operation's parameter list.
+  and an index determines their position in the operation's parameter list. A
+  tensor definition takes a type, a shape, and an optional flag to mark output
+  tensors.
   """
 
   def __init__(self,
               type_var: TypeVar,
               *shape: AffineExprDef,
               output: bool = False):
-    self.operand_def = OperandDef(type_var, shape, False, output)
+    kind = OperandKind.OutputTensor if output else OperandKind.InputTensor
+    self.operand_def = OperandDef(kind, type_var, size_exprs=shape)
 
   def __getitem__(self, dims) -> TensorUse:
     assert self.operand_def.owner, "TensorDef is not attached to an op"
@@ -221,7 +236,7 @@
   """
 
   def __init__(self, type_var: TypeVar):
-    self.operand_def = OperandDef(type_var, (), True, False)
+    self.operand_def = OperandDef(OperandKind.Scalar, type_var)
 
   @property
   def scalar_name(self) -> str:
@@ -233,6 +248,22 @@
     return ScalarArg(self.scalar_name).expr()
 
 
+class AttributeDef:
+  """Index attribute definition.
+
+  Index attributes provide a way to define and set symbols that can be used in
+  indexing expressions. Every attribute specifies a tuple of symbols that at
+  compile-time are replaced by integer values.
+  """
+  yaml_tag = "!LinalgAttributeDef"
+
+  def __init__(self, *sizes: SymbolDef):
+    if any(not isinstance(size, SymbolDef) for size in sizes):
+      raise ValueError(f"AttributeDef requires sizes of type SymbolDef "
+                       f"but got {sizes}")
+    self.operand_def = OperandDef(OperandKind.Attribute, I64, size_exprs=sizes)
+
+
 class Comprehension:
   """Represents a single comprehension."""
@@ -303,7 +334,7 @@
   def __init__(self, operator: PrimFnType, *reduce_dims: DimDef):
     """Initializes the ReduceFn with a primitive function and dims."""
     if not isinstance(operator, PrimFnType):
-      raise ValueError(f"Reduce expected a Prim operator. Got: {operator}")
+      raise ValueError(f"Reduce expected a Prim operator but got {operator}")
     self.operator = operator
     self.reduce_dims = tuple(reduce_dims)
@@ -353,7 +384,7 @@
       self.value = str(
           _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value)))
     else:
-      raise ValueError(f"const requires int or float. Got: {type(value)}")
+      raise ValueError(f"const requires int or float but got {type(value)}")
 
   def to_scalar_expression(self) -> ScalarExpression:
     return ScalarConst(self.value).expr()
@@ -475,21 +506,22 @@
     self.comprehensions = list()  # type: List[Comprehension]
     self._affine_state = AffineBuildState()
 
-  @property
-  def outputs(self) -> Sequence[OperandDef]:
-    return [
-        operand for operand in self.registered_operands.values()
-        if operand.output
-    ]
-
   def add_operand(self, name: str, operand: OperandDef):
     """Registers an operand."""
     if name in self.registered_operands:
       raise ValueError(f"The operand {name} is already registered "
                        f"to {self.registered_operands['name']}")
-    if not operand.output and self.outputs:
-      raise ValueError(f"The operand {name} is an input registered after "
-                       f"the output {self.outputs[-1]}")
+    # Ensure output tensors are registered after input tensors and scalars,
+    # and that attributes are registered after all other operand kinds.
+    registered_kinds = [
+        operand.kind.value for operand in self.registered_operands.values()
+    ]
+    if registered_kinds:
+      maximum = max(registered_kinds)
+      if maximum > operand.kind.value and maximum > OperandKind.Scalar.value:
+        raise ValueError(
+            f"The operand {name} of kind {operand.kind.name} is registered "
+            f"after an operand of kind {OperandKind(maximum).name}")
     operand.attach(len(self.registered_operands), name, self)
     self.registered_operands[name] = operand
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
@@ -45,9 +45,11 @@
 
   def __init__(self,
               operand_def: OperandDef,
-               shape_map: Optional[_ir.AffineMap] = None):
+               shape_map: Optional[_ir.AffineMap] = None,
+               attribute_map: Optional[_ir.AffineMap] = None):
     self.operand_def = operand_def
     self.shape_map = shape_map  # type: Optional[_ir.AffineMap]
+    self.attribute_map = attribute_map  # type: Optional[_ir.AffineMap]
     self.indexing_map = None  # type: Optional[_ir.AffineMap]
 
   @property
@@ -60,21 +62,25 @@
 
   @property
   def usage(self) -> str:
-    if self.operand_def.output:
-      return "output"
-    return "input"
+    if self.operand_def.kind == OperandKind.Attribute:
+      return "IndexAttribute"
+    if self.operand_def.kind == OperandKind.OutputTensor:
+      return "OutputOperand"
+    return "InputOperand"
 
   def to_yaml_custom_dict(self):
-    self_dict = dict(name=self.name)
-    self_dict["usage"] = self.usage
-    if not self.operand_def.scalar:
-      self_dict["shape"] = _serialize_affine_map(self.shape_map)
-    self_dict["type_var"] = self.type_var.name
+    self_dict = dict(
+        name=self.name, usage=self.usage, type_var=self.type_var.name)
+    if self.shape_map:
+      self_dict["shape_map"] = _serialize_affine_map(self.shape_map)
+    if self.attribute_map:
+      self_dict["attribute_map"] = _serialize_affine_map(self.attribute_map)
     return self_dict
 
   def __repr__(self):
     return (f"OperandDefConfig({self.operand_def}, "
-            f"shape_map={self.shape_map}, indexing_map={self.indexing_map})")
+            f"shape_map={self.shape_map}, attribute_map={self.attribute_map}, "
+            f"indexing_map={self.indexing_map})")
 
 
 class LinalgIndexingMapsConfig(YAMLObject):
@@ -109,6 +115,7 @@
 
   def __init__(self,
               comprehension: Comprehension,
+               registered_operands: Sequence[OperandDef],
               context: Optional[_ir.Context] = None):
     self.context = context if context is not None else _ir.Context()
     self.affine_state = AffineBuildState()
@@ -131,22 +138,33 @@
       read_use.collect_scalar_uses(collected_scalar_uses)
       read_use.collect_indices(collected_indices)
 
-    # Need to add all definitions before uses, so process twice.
+    # Collect all attribute definitions.
+    collected_attr_defs = list()
+    for operand in registered_operands:
+      if operand.kind == OperandKind.Attribute:
+        collected_attr_defs.append(operand)
+
+    # Add all definitions before uses, so process twice.
     for use in collected_tensor_uses:
       self.add_operand(use.operand_def)
     for use in collected_scalar_uses:
       self.add_operand(use.operand_def)
+    for definition in collected_attr_defs:
+      self.add_operand(definition)
     for use in collected_tensor_uses:
       self.add_tensor_use(use)
 
-    # Now normalize all defs and uses indexing maps now that full count of
-    # dims and symbols are known.
+    # Normalize all shape and indexing maps now that the full count of dims
+    # and symbols is known.
     for cuse in self.uses.values():
       cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map)
-    for cdef in self.operands.values():
-      if not cdef.operand_def.scalar:
-        cdef.shape_map = self._normalize_affine_map(
-            cdef.shape_map, with_dims=False)
+    for operand_config in self.operands.values():
+      if operand_config.shape_map:
+        operand_config.shape_map = self._normalize_affine_map(
+            operand_config.shape_map, with_dims=False)
+      if operand_config.attribute_map:
+        operand_config.attribute_map = self._normalize_affine_map(
+            operand_config.attribute_map, with_dims=False)
 
     # Now for each write use, propagate the indexing maps from the use to the
     # tensor, ensuring that there are not conflicts.
@@ -174,12 +192,16 @@
 
     # Set the indexing map of all scalar uses to the empty map.
     for operand_config in self.operands.values():
-      if operand_config.operand_def.scalar:
-        operand_config.indexing_map = self._create_empty_affine_map()
+      if operand_config.operand_def.kind == OperandKind.Scalar:
+        operand_config.indexing_map = self._get_scalar_map()
 
-    # Sanity check that all defs have an indexing map.
-    assert all(d.indexing_map for d in self.operands.values()), (
-        f"Missing indexing map on OperandConfigDef: {self.operands}")
+    # Check all registered tensor and scalar operands have an indexing map.
+    for operand in registered_operands:
+      if operand.kind == OperandKind.Attribute:
+        continue
+      if not (operand in self.operands and self.operands[operand].indexing_map):
+        raise ValueError(f"Failed to compute an indexing map for operand "
+                         f"{operand.name}")
 
     # Collect reduction dims and ensure all the same.
     all_reduction_dims = set(comprehension.all_reduction_dims)
@@ -189,7 +211,7 @@
           f"dims. Got: {all_reduction_dims}")
     self.reduction_dims = next(iter(all_reduction_dims))
 
-    # Check the index dimension exists and resolve
+    # Check the index dimension exists and resolve.
     for index in collected_indices:
       if index.dim_def.dimname not in self.affine_state.all_dims:
         raise ValueError(
@@ -221,7 +243,7 @@
 
   @property
   def indexing_maps(self) -> Sequence[_ir.AffineMap]:
-    return [d.indexing_map for d in self.ordered_operands]
+    return [o.indexing_map for o in self.ordered_operands if o.indexing_map]
 
   @property
   def iterator_types(self) -> Sequence[str]:
@@ -237,20 +259,24 @@
   def add_operand(self, operand_def: OperandDef):
     if operand_def in self.operands:
       return
-    if operand_def.scalar:
+    if operand_def.kind == OperandKind.Scalar:
       self.operands[operand_def] = OperandDefConfig(operand_def)
       return
     with self.context:
       local_state = AffineBuildState(
          global_state=self.affine_state, allow_new_dims=False)
       exprs = []
-      for expr in operand_def.shape:
+      for expr in operand_def.size_exprs:
        exprs.append(expr.build(state=local_state))
       assert local_state.local_dim_count == 0
-      shape_map = _ir.AffineMap.get(
+      affine_map = _ir.AffineMap.get(
          dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs)
-      def_config = OperandDefConfig(operand_def, shape_map)
-      self.operands[operand_def] = def_config
+      if operand_def.kind == OperandKind.Attribute:
+        self.operands[operand_def] = OperandDefConfig(
+            operand_def, attribute_map=affine_map)
+      else:
+        self.operands[operand_def] = OperandDefConfig(
+            operand_def, shape_map=affine_map)
 
   def add_tensor_use(self, tensor_use: TensorUse):
     if tensor_use in self.uses:
@@ -261,7 +287,6 @@
       exprs = []
       for expr in tensor_use.indices:
        exprs.append(expr.build(state=local_state))
-      assert local_state.local_symbol_count == 0
       indexing_map = _ir.AffineMap.get(
          dim_count=local_state.dim_count,
          symbol_count=local_state.symbol_count,
@@ -270,8 +295,8 @@
       use_config = TensorUseConfig(tensor_use, indexing_map)
       self.uses[tensor_use] = use_config
 
-  def _create_empty_affine_map(self) -> _ir.AffineMap:
-    """Create an affine map with an empty range."""
+  def _get_scalar_map(self) -> _ir.AffineMap:
+    """Create an empty affine map used to index a scalar."""
     with self.context:
       return _ir.AffineMap.get(
          dim_count=self.affine_state.dim_count,
@@ -345,8 +370,9 @@
     return [
         LinalgOpConfig(
             tc_op_def.metadata,
-            structured_op=LinalgStructuredOpConfig(tc_op_def.comprehensions[0],
-                                                   context)),
+            structured_op=LinalgStructuredOpConfig(
+                tc_op_def.comprehensions[0],
+                tc_op_def.registered_operands.values(), context)),
     ]
 
   def __repr__(self):
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
@@ -44,15 +44,20 @@
     self.op_name = op_name
     self.model = model
 
-  def __call__(self, *args, emit_generic: bool = False, **kwargs):
+  def __call__(self, *ins: ir.Value, outs: Sequence[ir.Value], **kwargs):
     """Emits the corresponding op definition as IR.
 
     Most arguments are passed through to the underlying emitter. The following
-    are interpreted here:
-      emit_generic: Emits a generic form as appropriate (default True). If
+    keyword argument is interpreted here:
+      emit_generic: Emits a generic form as appropriate (default False). If
         False, a named form is emitted (which must have been built in to the
        compiler).
""" + emit_generic = kwargs.pop("emit_generic", False) + if not isinstance(emit_generic, bool): + raise ValueError(f"The named argument 'emit_generic' needs to be " + f" of type bool but got {type(emit_generic)}") + op_configs = LinalgOpConfig.from_linalg_op_def( self.model, context=ir.Context.current) @@ -70,12 +75,16 @@ op_config = op_configs[0] if op_config.structured_op: if emit_generic: - return emit_generic_structured_op(op_config.structured_op, *args, - **kwargs) + return emit_generic_structured_op( + op_config.structured_op, *ins, outs=outs, **kwargs) else: - return emit_named_structured_op(op_config.structured_op, self.op_name, - self.model.metadata.cpp_class_name, - *args, **kwargs) + return emit_named_structured_op( + op_config.structured_op, + self.op_name, + self.model.metadata.cpp_class_name, + *ins, + outs=outs, + **kwargs) raise NotImplementedError( f"Emission of linalg op type not supported: {op_config}") @@ -104,14 +113,12 @@ sig = inspect.signature(dsl_func) for param_name, param in sig.parameters.items(): param_default = param.default - if isinstance(param_default, TensorDef): - tc_model.add_operand(param_name, param_default.operand_def) - elif isinstance(param_default, ScalarDef): + if isinstance(param_default, (TensorDef, ScalarDef, AttributeDef)): tc_model.add_operand(param_name, param_default.operand_def) else: raise ValueError(f"@tc_def_op function parameters must be defaulted as " - f"TensorDef(...) or ScalarDef(...): Found {param_name}" - f": {param_default}") + f"TensorDef(...), ScalarDef(...), or AttributeDef(...): " + f"Found {param_name}: {param_default}") dsl_func_args.append(param_default) # Invoke the DSL func to finish populating the model. diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -13,6 +13,7 @@ from .scalar_expr import * from .config import * +import numpy as np __all__ = [ "emit_generic_structured_op", @@ -29,12 +30,14 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, - *ins: Value, outs: Sequence[Value]): + *ins: Value, outs: Sequence[Value], + **attrs: Sequence[int]): all_arg_defs = op_config.ordered_operands - in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "input"] - out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "output"] + in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "InputOperand"] + out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "OutputOperand"] + attr_arg_defs = [arg for arg in all_arg_defs if arg.usage == "IndexAttribute"] - # Verify outs and captures are sequences. + # Verify outs is a sequence. if not isinstance(outs, Sequence): raise ValueError(f"Expected named argument outs to have type Sequence " f"but got {type(outs)}") @@ -47,6 +50,40 @@ raise ValueError(f"Expected {len(out_arg_defs)} outputs but got " f"{len(outs)} for {op_config}") + # Compute a replacement list for all attribute symbols. 
+  expressions = []  # type: Sequence[AffineExpr]
+  replacements = []  # type: Sequence[AffineExpr]
+  for attr in attr_arg_defs:
+    if attr.name not in attrs:
+      raise ValueError(f"Expected named argument for the attribute {attr.name}")
+    attribute_values = attrs.get(attr.name)
+    if not all(isinstance(value, int) for value in attribute_values):
+      raise ValueError(f"Attribute {attr.name} needs to be of type "
+                       f"Sequence[int] but got {type(attribute_values)}")
+    results = attr.attribute_map.results  # type: AffineExprList
+    if len(attribute_values) != len(results):
+      raise ValueError(f"Attribute {attr.name} expects {len(results)} "
+                       f"values but got {len(attribute_values)}")
+    for expr, value in zip(results, attribute_values):
+      expressions.append(expr)
+      replacements.append(AffineConstantExpr.get(value))
+
+  # Replace all index attribute symbols by their value.
+  # TODO: Add support for shape symbols.
+  indexing_maps = []  # type: Sequence[AffineMap]
+  for curr in op_config.indexing_maps:
+    for expression, replacement in zip(expressions, replacements):
+      curr = curr.replace(expression, replacement, curr.n_dims, curr.n_symbols)
+    indexing_maps.append(curr)
+
+  # TODO: Linalg verification does not currently allow symbols.
+  # Compress them for now and verify none are left.
+  indexing_maps = AffineMap.compress_unused_symbols(indexing_maps,
+                                                    Context.current)
+  if any(indexing_map.n_symbols != 0 for indexing_map in indexing_maps):
+    raise ValueError(f"Expected indexing_maps to use no symbols after "
+                     f"replacement and compression but got {indexing_maps}")
+
   outs, out_types = _infer_structured_outs(op_config, in_arg_defs, ins,
                                            out_arg_defs, outs)
@@ -67,27 +104,28 @@
 
   # Emit the generic op.
   # TODO: Support emission of pure memref form.
-  indexing_maps_attr = ArrayAttr.get([
-      AffineMapAttr.get(am)
-      # TODO: linalg verification does not currently allow symbols.
-      # Compress them for now.
-      for am in AffineMap.compress_unused_symbols(op_config.indexing_maps,
-                                                  Context.current)
-  ])
+  indexing_maps_attr = ArrayAttr.get(
+      [AffineMapAttr.get(am) for am in indexing_maps])
   iterator_types_attr = ArrayAttr.get(
       [StringAttr.get(s) for s in op_config.iterator_types])
 
+  # Compute a dictionary storing all index attributes.
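+  # For example, strides=[2, 4] becomes an i64 elements attribute that prints
+  # as dense<[2, 4]> : tensor<2xi64> and is attached to the named op under
+  # the name "strides".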
+  index_attributes = {}  # type: Dict[str, DenseElementsAttr]
+  for attr in attr_arg_defs:
+    attribute_values = attrs.get(attr.name)
+    array = np.array(attribute_values, dtype=np.int64)
+    index_attributes[attr.name] = DenseElementsAttr.get(array)
+
   return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types,
           type_mapping, indexing_maps_attr, iterator_types_attr,
-          block_arg_types)
+          index_attributes, block_arg_types)
 
 
-def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
-                               *ins: Value,
-                               outs: Sequence[Value] = ()):
+def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
+                               outs: Sequence[Value], **attrs: Sequence[int]):
   all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
-  indexing_maps_attr, iterator_types_attr, block_arg_types = \
-     prepare_common_structured_op(op_config, *ins, outs = outs)
+  indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \
+     prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
 
   generic_op = linalg.GenericOp(
       result_tensors=result_types,
@@ -114,14 +152,12 @@
   return generic_op.results
 
 
-def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
-                             op_name: str,
-                             op_class_name: str,
-                             *ins: Value,
-                             outs: Sequence[Value] = ()):
+def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
+                             op_class_name: str, *ins: Value,
+                             outs: Sequence[Value], **attrs: Sequence[int]):
   all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
-  indexing_maps_attr, iterator_types_attr, block_arg_types = \
-     prepare_common_structured_op(op_config, *ins, outs = outs)
+  indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \
+     prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
 
   # If we get here, there must exist a builtin class `op_class_name`.
   ctx = Context.current
@@ -141,6 +177,10 @@
         "linalg.memoized_indexing_maps"] = indexing_maps_attr
   # iterator_types are hardcoded in C++ both in the yaml and non-yaml path.
 
+  # Additionally set all named attributes.
+  for name, value in index_attributes.items():
+    named_op.operation.attributes[name] = value
+
   if len(result_types) == 1:
     return named_op.result
   else:
@@ -304,7 +344,7 @@
                               block_arg_types: Sequence[Type]):
   element_or_self_type = operand_type
   # Get the element type for tensor operands and the type itself for scalars.
-  if operand_config.operand_def.shape:
+  if operand_config.shape_map:
     try:
       element_or_self_type = ShapedType(operand_type).element_type
     except Exception as e:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -74,6 +74,19 @@
   C[None] += cast(U, A[D.m]) * cast(U, B[D.m])
 
 
+@linalg_structured_op
+def depthwise_conv_2d_input_nhwc_filter_hwc_poly(
+    I=TensorDef(T1, S.N, S.IH, S.IW, S.C),
+    K=TensorDef(T2, S.KH, S.KW, S.C),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=AttributeDef(S.S0, S.S1),
+    dilations=AttributeDef(S.D0, S.D1)):
+  """A depth-wise 2-D convolution operation."""
+  O[D.n, D.oh, D.ow, D.c] += cast(
+      U, I[D.n, D.oh * S.S0 + D.kh * S.D0, D.ow * S.S1 + D.kw * S.D1,
+           D.c]) * cast(U, K[D.kh, D.kw, D.c])
+
+
 @linalg_structured_op
 def fill_rng_2d(
     min=ScalarDef(F64),
diff --git a/mlir/test/python/dialects/linalg/opdsl/arguments.py b/mlir/test/python/dialects/linalg/opdsl/arguments.py
--- a/mlir/test/python/dialects/linalg/opdsl/arguments.py
+++ b/mlir/test/python/dialects/linalg/opdsl/arguments.py
@@ -7,17 +7,17 @@
 # CHECK-LABEL: matmul
 # CHECK: args:
 # CHECK: name: A
-# CHECK: usage: input
-# CHECK: shape: affine_map<()[s0, s1, s2] -> (s0, s2)>
+# CHECK: usage: InputOperand
 # CHECK: type_var: T
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
 # CHECK: name: B
-# CHECK: usage: input
-# CHECK: shape: affine_map<()[s0, s1, s2] -> (s2, s1)>
+# CHECK: usage: InputOperand
 # CHECK: type_var: T
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
 # CHECK: name: C
-# CHECK: usage: output
-# CHECK: shape: affine_map<()[s0, s1, s2] -> (s0, s1)>
+# CHECK: usage: OutputOperand
 # CHECK: type_var: U
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
 @linalg_structured_op
 def matmul(
     A=TensorDef(T, S.M, S.K),
     B=TensorDef(T, S.K, S.N),
     C=TensorDef(U, S.M, S.N, output=True)):
   C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
@@ -30,9 +30,32 @@
 # CHECK-LABEL: fill
 # CHECK: args:
 # CHECK: name: value
-# CHECK: usage: input
-# CHECK-NOT: shape:
+# CHECK: usage: InputOperand
+# CHECK-NOT: shape_map:
 # CHECK: type_var: T
 @linalg_structured_op
 def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
   O[D.m, D.n] = value
+
+
+# CHECK: ---
+# CHECK-LABEL: strided_copy
+# CHECK: args:
+# CHECK: name: I
+# CHECK: usage: InputOperand
+# CHECK: type_var: T
+# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2, s3)>
+# CHECK: name: O
+# CHECK: usage: OutputOperand
+# CHECK: type_var: T
+# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1)>
+# CHECK: name: strides
+# CHECK: usage: IndexAttribute
+# CHECK: type_var: I64
+# CHECK: attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s5)>
+@linalg_structured_op
+def strided_copy(
+    I=TensorDef(T, S.W, S.H),
+    O=TensorDef(T, S.OH, S.OW, output=True),
+    strides=AttributeDef(S.S0, S.S1)):
+  O[D.oh, D.ow] = I[D.h * S.S0, D.w * S.S1]
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
--- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -7,6 +7,9 @@
 
 from mlir.dialects.linalg.opdsl.lang import *
 
+T1 = TV.T1
+T2 = TV.T2
+
 
 @linalg_structured_op
 def matmul_mono(
@@ -18,12 +21,24 @@
 
 @linalg_structured_op
 def matmul_poly(
-    A=TensorDef(TV.T1, S.M, S.K),
-    B=TensorDef(TV.T2, S.K, S.N),
+    A=TensorDef(T1, S.M, S.K),
+    B=TensorDef(T2, S.K, S.N),
     C=TensorDef(U, S.M, S.N, output=True)):
   C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
 
 
+@linalg_structured_op
+def conv_poly(
+    I=TensorDef(T1, S.N, S.IH, S.IW, S.C),
+    K=TensorDef(T2, S.KH, S.KW, S.C),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=AttributeDef(S.S0, S.S1),
+    dilations=AttributeDef(S.D0, S.D1)):
+  O[D.n, D.oh, D.ow, D.c] += cast(
+      U, I[D.n, D.oh * S.S0 + D.kh * S.D0, D.ow * S.S1 + D.kw * S.D1,
+           D.c]) * cast(U, K[D.kh, D.kw, D.c])
+
+
 @linalg_structured_op
 def fill_rng(
     min=ScalarDef(F64),
@@ -57,6 +72,10 @@
   # CHECK: #[[$MAPB:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
   # CHECK: #[[$MAPC:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 
+  # CHECK: #[[$MAPI:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 4 + d5 * 2, d3)>
+  # CHECK: #[[$MAPK:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
+  # CHECK: #[[$MAPO:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
   # CHECK-LABEL: func @test_matmul_mono
   # CHECK-SAME: %[[A:.+]]: tensor<4x16xf32>
   # CHECK-SAME: %[[B:.+]]: tensor<16x8xf32>
@@ -161,17 +180,11 @@
   # CHECK-LABEL: @test_fill_rng
   # CHECK: ^{{.*}}(%[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32, %{{.*}}
   # CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
-  # CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
   # CHECK-DAG: %[[IDX0_CAST:.+]] = index_cast %[[IDX0]] : index to i32
-  # CHECK-DAG: %[[IDX1_CAST:.+]] = index_cast %[[IDX1]] : index to i32
   # CHECK-DAG: %[[RND0:.+]] = addi %[[IDX0_CAST]], %[[SEED]] : i32
   # CHECK-DAG: %[[CST0:.+]] = constant 1103515245 : i64
   # CHECK-DAG: %[[CST0_CAST:.+]] = trunci %[[CST0]] : i64 to i32
-  # CHECK-DAG: %[[CST1:.+]] = constant 12345 : i64
-  # CHECK-DAG: %[[CST1_CAST:.+]] = trunci %[[CST1]] : i64 to i32
-  # CHECK-DAG: %[[RND1:.+]] = muli %[[RND0]], %[[CST0_CAST]] : i32
-  # CHECK-DAG: %[[RND2:.+]] = addi %[[RND1]], %[[CST1_CAST]] : i32
-  # Skip random number computation for the second index.
+  # Skip the remaining random number computation and match the scaling logic.
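+  # (The removed CHECK-DAG lines covered the rest of the multiply-add chain
+  # and the analogous computation for the second index.)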
   # CHECK-DAG: %[[DIFF:.+]] = subf %[[MAX]], %[[MIN]] : f64
   # CHECK-DAG: %[[CST3:.+]] = constant 2.3283063999999999E-10 : f64
   # CHECK-DAG: %[[FACT:.+]] = mulf %[[DIFF]], %[[CST3]] : f64
@@ -183,5 +196,24 @@
   def test_fill_rng(min, max, seed, init_result):
     return fill_rng(min, max, seed, outs=[init_result])
 
+  # CHECK-LABEL: @test_f32i32_conv
+  # CHECK: linalg.generic
+  # CHECK-SAME: indexing_maps = [#[[$MAPI]], #[[$MAPK]], #[[$MAPO]]]
+  # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]
+  # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[FILTER:.+]]: f32, %[[OUT:.+]]: i32)
+  # CHECK-NEXT: %[[IN_CAST:.+]] = fptosi %[[IN]] : f32 to i32
+  # CHECK-NEXT: %[[FILTER_CAST:.+]] = fptosi %[[FILTER]] : f32 to i32
+  # CHECK-NEXT: %[[PROD:.+]] = muli %[[IN_CAST]], %[[FILTER_CAST]] : i32
+  # CHECK-NEXT: %[[SUM:.+]] = addi %[[OUT]], %[[PROD]] : i32
+  # CHECK-NEXT: linalg.yield %[[SUM]] : i32
+  # CHECK-NEXT: -> tensor<1x2x4x1xi32>
+  @builtin.FuncOp.from_py_func(
+      RankedTensorType.get((1, 4, 16, 1), f32),
+      RankedTensorType.get((2, 2, 1), f32),
+      RankedTensorType.get((1, 2, 4, 1), i32))
+  def test_f32i32_conv(input, filter, init_result):
+    return conv_poly(
+        input, filter, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+
 print(module)
diff --git a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
--- a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
+++ b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
@@ -7,9 +7,9 @@
 # dims auto discovered emits the right shape, indexing maps and iterator types.
 # CHECK: ---
 # CHECK-LABEL: matmul
-# CHECK: shape: affine_map<()[s0, s1, s2] -> (s0, s2)>
-# CHECK: shape: affine_map<()[s0, s1, s2] -> (s2, s1)>
-# CHECK: shape: affine_map<()[s0, s1, s2] -> (s0, s1)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
 # CHECK: static_indexing_maps:
 # CHECK-NEXT: - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
 # CHECK-NEXT: - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
@@ -19,9 +19,10 @@
 # CHECK-NEXT: - parallel
 # CHECK-NEXT: - reduction
 @linalg_structured_op
-def matmul(A=TensorDef(T, S.M, S.K),
-           B=TensorDef(T, S.K, S.N),
-           C=TensorDef(U, S.M, S.N, output=True)):
+def matmul(
+    A=TensorDef(T, S.M, S.K),
+    B=TensorDef(T, S.K, S.N),
+    C=TensorDef(U, S.M, S.N, output=True)):
   C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
 
 
@@ -29,9 +30,9 @@
 # correctly.
 # CHECK: ---
 # CHECK-LABEL: dot
-# CHECK: shape: affine_map<()[s0] -> (s0)>
-# CHECK: shape: affine_map<()[s0] -> (s0)>
-# CHECK: shape: affine_map<()[s0] -> ()>
+# CHECK: shape_map: affine_map<()[s0] -> (s0)>
+# CHECK: shape_map: affine_map<()[s0] -> (s0)>
+# CHECK: shape_map: affine_map<()[s0] -> ()>
 # CHECK: static_indexing_maps:
 # CHECK-NEXT: - affine_map<(d0)[s0] -> (d0)>
 # CHECK-NEXT: - affine_map<(d0)[s0] -> (d0)>
diff --git a/mlir/test/python/dialects/linalg/opsrun.py b/mlir/test/python/dialects/linalg/opsrun.py
--- a/mlir/test/python/dialects/linalg/opsrun.py
+++ b/mlir/test/python/dialects/linalg/opsrun.py
@@ -58,6 +58,30 @@
 }
 """
 
+conv_boiler = """
+func @main() -> i32 attributes {llvm.emit_c_interface} {
+  %v0 = constant 0 : i32
+  %v1 = constant 1.0 : f64
+  %v2 = constant 2.0 : f64
+
+  %input = memref.alloc() : memref<1x4x16x1xf64>
+  %filter = memref.alloc() : memref<2x2x1xf64>
+  %output = memref.alloc() : memref<1x2x4x1xi32>
+  linalg.fill(%input, %v1) : memref<1x4x16x1xf64>, f64
+  linalg.fill(%filter, %v2) : memref<2x2x1xf64>, f64
+  linalg.fill(%output, %v0) : memref<1x2x4x1xi32>, i32
+
+  call @conv_on_buffers(%input, %filter, %output) :
+    (memref<1x4x16x1xf64>, memref<2x2x1xf64>, memref<1x2x4x1xi32>) -> ()
+
+  %c0 = constant 0 : index
+  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>
+
+  // TODO: FFI-based solution to allow testing and printing with python code.
+  return %0 : i32
+}
+"""
+
 
 def transform(module, boilerplate):
   import mlir.conversions
@@ -69,8 +93,9 @@
   mod = Module.parse(
       str(module.operation.regions[0].blocks[0].operations[0].operation) +
       boilerplate)
-  pm = PassManager.parse("func(convert-linalg-to-loops, convert-scf-to-std)," +
-                         "convert-vector-to-llvm," + "convert-std-to-llvm")
+  pm = PassManager.parse("func(convert-linalg-to-loops, lower-affine, " +
+                         "convert-scf-to-std), convert-vector-to-llvm," +
+                         "convert-std-to-llvm")
   pm.run(mod)
   return mod
 
@@ -183,3 +208,38 @@
 
 
 test_fill_generic()
+
+
+def test_conv_generic():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f64 = F64Type.get()
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(
+          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2, 1), f64),
+          MemRefType.get((1, 2, 4, 1), i32))
+      def conv_on_buffers(input, filter, output):
+        linalg.depthwise_conv_2d_input_nhwc_filter_hwc_poly(
+            input,
+            filter,
+            outs=[output],
+            strides=[2, 4],
+            dilations=[1, 2],
+            emit_generic=True)
+
+    execution_engine = ExecutionEngine(transform(module, conv_boiler))
+
+    # TODO: FFI-based solution to allow testing and printing with python code.
+    # Prepare arguments: one result i32.
+    # Arguments must be passed as pointers.
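+    # A one-element ctypes array serves as the out-parameter: the engine
+    # writes the i32 returned by @main into res[0].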
+    c_int_p = ctypes.c_int * 1
+    res = c_int_p(-1)
+    execution_engine.invoke("main", res)
+
+    log("RESULT: ", res[0])
+    # CHECK: RESULT: 8
+
+
+test_conv_generic()
diff --git a/mlir/test/python/ir/affine_map.py b/mlir/test/python/ir/affine_map.py
--- a/mlir/test/python/ir/affine_map.py
+++ b/mlir/test/python/ir/affine_map.py
@@ -3,6 +3,7 @@
 
 import gc
 from mlir.ir import *
+
 
 def run(f):
   print("\nTEST:", f.__name__)
   f()
@@ -21,6 +22,7 @@
     assert am2 == am1
     assert am2.context is ctx
 
+
 run(testAffineMapCapsule)
 
 
@@ -97,6 +99,7 @@
       # CHECK: number of results out of bounds
       print(e)
 
+
 run(testAffineMapGet)
 
 
@@ -117,6 +120,7 @@
     map34 = map5.get_minor_submap(2)
     print(map34)
 
+
 run(testAffineMapDerive)
 
 
@@ -142,6 +146,7 @@
     # CHECK: False
     print(map3.is_projected_permutation)
 
+
 run(testAffineMapProperties)
 
 
@@ -175,23 +180,22 @@
       print(expr)
     assert list(map3.results) == [d2, d0, d1]
 
+
 run(testAffineMapExprs)
 
+
 # CHECK-LABEL: TEST: testCompressUnusedSymbols
 def testCompressUnusedSymbols():
   with Context() as ctx:
-    d0, d1, d2 = (
-        AffineDimExpr.get(0),
-        AffineDimExpr.get(1),
-        AffineDimExpr.get(2))
-    s0, s1, s2 = (
-        AffineSymbolExpr.get(0),
-        AffineSymbolExpr.get(1),
-        AffineSymbolExpr.get(2))
+    d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1),
+                  AffineDimExpr.get(2))
+    s0, s1, s2 = (AffineSymbolExpr.get(0), AffineSymbolExpr.get(1),
+                  AffineSymbolExpr.get(2))
     maps = [
         AffineMap.get(3, 3, [d2, d0, d1]),
         AffineMap.get(3, 3, [d2, d0 + s2, d1]),
-        AffineMap.get(3, 3, [d1, d2, d0])]
+        AffineMap.get(3, 3, [d1, d2, d0])
+    ]
 
     compressed_maps = AffineMap.compress_unused_symbols(maps, ctx)
 
@@ -207,3 +211,29 @@
 
 
 run(testCompressUnusedSymbols)
+
+
+# CHECK-LABEL: TEST: testReplace
+def testReplace():
+  with Context() as ctx:
+    d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1),
+                  AffineDimExpr.get(2))
+    s0, s1, s2 = (AffineSymbolExpr.get(0), AffineSymbolExpr.get(1),
+                  AffineSymbolExpr.get(2))
+    map1 = AffineMap.get(3, 3, [d2, d0 + s1 + s2, d1 + s0])
+
+    replace0 = map1.replace(s0, AffineConstantExpr.get(42), 3, 3)
+    replace1 = map1.replace(s1, AffineConstantExpr.get(42), 3, 3)
+    replace2 = map1.replace(s2, AffineConstantExpr.get(42), 3, 2)
+
+    # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s1 + s2, d1 + 42)
+    print(replace0)
+
+    # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2 + 42, d1 + s0)
+    print(replace1)
+
+    # CHECK: (d0, d1, d2)[s0, s1] -> (d2, d0 + s1 + 42, d1 + s0)
+    print(replace2)
+
+
+run(testReplace)
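For reference, a minimal sketch of how the new AffineMap.replace binding behaves, mirroring testReplace above (assumes the in-tree mlir Python bindings are importable; the map and values are illustrative):

from mlir.ir import AffineConstantExpr, AffineMap, AffineSymbolExpr, Context

with Context():
  s0, s1 = AffineSymbolExpr.get(0), AffineSymbolExpr.get(1)
  # A map with two symbol results: ()[s0, s1] -> (s0, s1).
  map0 = AffineMap.get(0, 2, [s0, s1])
  # Substitute s0 with the constant 2; the result keeps 0 dims and 2 symbols.
  map1 = map0.replace(s0, AffineConstantExpr.get(2), 0, 2)
  print(map1)  # ()[s0, s1] -> (2, s1)

The OpDSL emitter relies on the same mechanism to substitute index attribute symbols, as exercised by conv_poly(input, filter, outs=[init_result], strides=[2, 4], dilations=[1, 2]) in the tests above.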