diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/affine.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/affine.py
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/affine.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/affine.py
@@ -232,7 +232,6 @@
   """
 
   ALL_DIMS = dict()  # type: Dict[str, "DimDef"]
-  dimname: str
 
   def __new__(cls, dimname: str):
     existing = cls.ALL_DIMS.get(dimname)
@@ -276,7 +275,6 @@
   True
   """
 
   ALL_SYMBOLS = dict()  # type: Dict[str, "SymbolDef"]
-  symname: str
 
   def __new__(cls, symname: str):
     existing = cls.ALL_SYMBOLS.get(symname)
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -27,24 +27,49 @@
   def to_scalar_expression(self) -> ScalarExpression:
     raise NotImplementedError()
 
-  def visit_affine_exprs(self, callback):
-    """Visits all affine expressions reachable by the expression."""
-    pass
+  def visit_tensor_exprs(self, callback):
+    """Visits all tensor expressions reachable by the expression."""
+    callback(self)
 
   def _get_all_dim_defs(self) -> Set[DimDef]:
     """Recursively gets all DimDef affine expressions that are referenced."""
     results = set()
 
-    def visitor(affine_expr):
-      if isinstance(affine_expr, DimDef):
-        results.add(affine_expr)
+    def visit_dim_def(dim_def):
+      if isinstance(dim_def, DimDef):
+        results.add(dim_def)
 
-    self.visit_affine_exprs(visitor)
+    def visit_affine_exprs(expr):
+      if isinstance(expr, TensorUse):
+        for ind in expr.indices:
+          ind.visit_affine_exprs(visit_dim_def)
+      if isinstance(expr, ReduceApply):
+        for ind in expr.reduce.reduce_dims:
+          ind.visit_affine_exprs(visit_dim_def)
+
+    self.visit_tensor_exprs(visit_affine_exprs)
     return results
 
   def collect_uses(self, uses: Set["TensorUse"]):
     """Collects all TensorUses reachable through this expression."""
-    pass
+
+    def visit_tensor_use(expr):
+      if isinstance(expr, TensorUse):
+        uses.add(expr)
+
+    self.visit_tensor_exprs(visit_tensor_use)
+
+  def collect_indices(self, indices: Set["index"]):
+    """Collects all index accesses reachable through this expression."""
+
+    def visit_index(expr):
+      if isinstance(expr, index):
+        indices.add(expr)
+
+    self.visit_tensor_exprs(visit_index)
+
+  def collect_captures(self, captures: Set["CaptureDef"]):
+    """Collects all CaptureDefs reachable through this expression."""
+
+    def visit_capture_def(expr):
+      if isinstance(expr, CaptureDef):
+        captures.add(expr)
+
+    self.visit_tensor_exprs(visit_capture_def)
 
   def __add__(self, rhs: "TensorExpression") -> "TensorExpression":
     return PrimFn.add(self, rhs)
@@ -84,13 +109,6 @@
     assert n is not None, "TensorDef not attached"
     return n
 
-  def visit_affine_exprs(self, callback):
-    for ind in self.indices:
-      ind.visit_affine_exprs(callback)
-
-  def collect_uses(self, uses: Set["TensorUse"]):
-    uses.add(self)
-
   def __iadd__(self, rhs: TensorExpression) -> TensorExpression:
     return ReduceFn.add(*self._compute_reduce_dims(rhs))(rhs)
 
@@ -178,6 +196,29 @@
     return (f"{self.tensor_name}:TensorDef({output}{repr(self.type_var)}, "
             f"shape={self.shape})")
 
+
+class CaptureDef(TensorExpression):
+  """Bookkeeping of a single registered capture, held in dict by name."""
+
+  def __init__(self, type_var: TypeVar):
+    if not isinstance(type_var, TypeVar):
+      raise ValueError(f"CaptureDef requires a TypeVar. Got: {repr(type_var)}")
+    self.owner = None  # type: Optional["LinalgOpDef"]
+    self.type_var = type_var
+    self.capture_name = None  # type: Optional[str]
+    self.registered_index = -1  # type: int
+
+  def attach(self, index: int, capture_name: str, owner: "LinalgOpDef"):
+    if self.owner:
+      raise ValueError(f"CaptureDef already registered with op: {self}")
+    self.registered_index = index
+    self.capture_name = capture_name
+    self.owner = owner
+
+  def to_scalar_expression(self) -> ScalarExpression:
+    return ScalarCap(self.capture_name).expr()
+
+  def __repr__(self):
+    return (f"{self.capture_name}:CaptureDef({repr(self.type_var)})")
+
 
 class Comprehension:
   """Represents a single comprehension."""
@@ -279,17 +320,44 @@
                          *[arg.to_scalar_expression() for arg in self.args
                           ]).expr()
 
-  def visit_affine_exprs(self, callback):
-    for arg in self.args:
-      arg.visit_affine_exprs(callback)
-
-  def collect_uses(self, uses: Set["TensorUse"]):
+  def visit_tensor_exprs(self, callback):
+    super().visit_tensor_exprs(callback)
     for arg in self.args:
-      arg.collect_uses(uses)
+      arg.visit_tensor_exprs(callback)
 
   def __repr__(self):
     return f"{repr(self.prim)}({', '.join(repr(a) for a in self.args)})"
 
+
+class const(TensorExpression):
+  """Constant value."""
+
+  def __init__(self, type: TypeVar, value: object):
+    self.type = type
+    self.value = value
+
+  def to_scalar_expression(self) -> ScalarExpression:
+    return ScalarConst(self.type, self.value).expr()
+
+  def __repr__(self):
+    return f"const({self.type}, {self.value})"
+
+
+class index(TensorExpression):
+  """Returns the iteration index for a given dimension."""
+
+  def __init__(self, dim: DimDef):
+    self.dim_def = dim
+    self.dim = -1
+
+  def resolve_dimension(self, affine_state: AffineBuildState):
+    self.dim = affine_state.get_dim(self.dim_def.dimname)
+
+  def to_scalar_expression(self) -> ScalarExpression:
+    assert self.dim != -1, "Dimension not resolved"
+    return ScalarIndex(self.dim).expr()
+
+  def __repr__(self):
+    return f"index({repr(self.dim)})"
+
 
 class cast(TensorExpression):
   """Casts the element type to a type (typically symbolic TypeVar)."""
@@ -302,11 +370,9 @@
     return ScalarSymbolicCast(self.to_type,
                               self.operand.to_scalar_expression()).expr()
 
-  def visit_affine_exprs(self, callback):
-    self.operand.visit_affine_exprs(callback)
-
-  def collect_uses(self, uses: Set["TensorUse"]):
-    self.operand.collect_uses(uses)
+  def visit_tensor_exprs(self, callback):
+    super().visit_tensor_exprs(callback)
+    self.operand.visit_tensor_exprs(callback)
 
   def __repr__(self):
     return f"cast({self.to_type}, {repr(self.operand)})"
@@ -331,15 +397,9 @@
     ] + [arg.to_scalar_expression() for arg in self.args]
     return ScalarApplyFn(self.reduce.operator.prim_name, *full_args).expr()
 
-  def visit_affine_exprs(self, callback):
-    for ind in self.reduce.reduce_dims:
-      ind.visit_affine_exprs(callback)
-    for arg in self.args:
-      arg.visit_affine_exprs(callback)
-
-  def collect_uses(self, uses: Set["TensorUse"]):
+  def visit_tensor_exprs(self, callback):
     for arg in self.args:
-      arg.collect_uses(uses)
+      arg.visit_tensor_exprs(callback)
 
   def __repr__(self):
     return f"{repr(self.reduce)}({', '.join(repr(a) for a in self.args)})"
@@ -385,6 +445,7 @@
                doc: Optional[str] = None):
     self.metadata = OpMetadataDef(
         name=name, cpp_class_name=cpp_class_name, doc=doc)
     self.registered_tensors = dict()  # type: Dict[str, TensorDef]
+    self.registered_captures = dict()  # type: Dict[str, CaptureDef]
     self.comprehensions = list()  # type: List[Comprehension]
     self._affine_state = AffineBuildState()
@@ -404,12 +465,13 @@
     tensor.attach(len(self.registered_tensors), tensor_name, self)
     self.registered_tensors[tensor_name] = tensor
 
-  def tensor(self, name):
-    """Gets a registered tensor by name."""
-    try:
-      return self.registered_tensors[name]
-    except KeyError:
-      raise KeyError(f"Tensor {name} is not registered")
+  def add_capture(self, capture_name: str, capture: CaptureDef):
+    """Registers a capture."""
+    if capture_name in self.registered_captures:
+      raise ValueError(f"Capture {capture_name} is already registered "
+                       f"to {self.registered_captures[capture_name]}")
+    capture.attach(len(self.registered_captures), capture_name, self)
+    self.registered_captures[capture_name] = capture
 
   def __repr__(self):
     lines = [
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py
@@ -70,6 +70,22 @@
   def __repr__(self):
     return f"Def({self.tensor_def}, shape_map={self.shape_map}, indexing_map={self.indexing_map})"
 
+
+class CaptureDefConfig(YAMLObject):
+  """Wrapper around a CaptureDef."""
+  yaml_tag = "LinalgCaptureDef"
+
+  def __init__(self, capture_def: CaptureDef):
+    self.capture_def = capture_def
+
+  def to_yaml_custom_dict(self):
+    return dict(
+        name=self.capture_def.capture_name,
+        type_var=self.capture_def.type_var.name,
+    )
+
+  def __repr__(self):
+    return f"Def({self.capture_def})"
+
 
 class LinalgIndexingMapsConfig(YAMLObject):
   """Abstracts the style of indexing maps that the op exports.
@@ -109,10 +125,14 @@
     self.affine_state = AffineBuildState()
     self.writes = list()  # type: List[Tuple[TensorUse, TensorExpression]]
     self.tensor_args = dict()  # type: Dict[TensorDef, TensorDefConfig]
+    self.capture_args = dict()  # type: Dict[CaptureDef, CaptureDefConfig]
     self.uses = dict()  # type: Dict[TensorUse, TensorUseConfig]
 
-    # Compute the ordered set of writes.
+    # Compute the ordered set of writes and collect the tensor, capture, and
+    # index uses.
    collected_uses = set()
+    collected_captures = set()
+    collected_indices = set()
     for write_use, read_use in zip(comprehension.definitions,
                                    comprehension.values):
       self.writes.append((write_use, read_use))
@@ -120,10 +140,14 @@
     for write_use, read_use in self.writes:
       collected_uses.add(write_use)
       read_use.collect_uses(collected_uses)
+      read_use.collect_captures(collected_captures)
+      read_use.collect_indices(collected_indices)
 
     # Need to add all definitions before uses, so process twice.
     for use in collected_uses:
       self.add_tensor_arg(use.tensor_def)
+    for capture in collected_captures:
+      self.add_capture_arg(capture)
     for use in collected_uses:
       self.add_use(use)
 
@@ -170,6 +194,14 @@
           f"dims. Got: {all_reduction_dims}")
       self.reduction_dims = next(iter(all_reduction_dims))
 
+    # Check that the index dimensions exist and resolve them.
+    for index in collected_indices:
+      if index.dim_def.dimname not in self.affine_state.all_dims:
+        raise ValueError(
+            f"The dimension {index.dim_def.dimname} is not part of the iteration "
+            f"domain {self.affine_state.all_dims}")
+      index.resolve_dimension(self.affine_state)
+
     # Generate the scalar assignments (used to build a body).
     self.assignments = [
         ScalarAssign(write_use.tensor_name, read_expr.to_scalar_expression())
@@ -186,6 +218,11 @@
     return sorted(self.uses.values(),
                   key=lambda tuc: tuc.tensor_use.tensor_def.registered_index)
 
+  @property
+  def ordered_capture_args(self) -> Sequence[CaptureDefConfig]:
+    return sorted(self.capture_args.values(),
+                  key=lambda cdc: cdc.capture_def.registered_index)
+
   @property
   def ordered_dims(self) -> Sequence[Tuple[str, int]]:
     """Gets the ordered list of dim bindings (symbolic name, position).
@@ -245,6 +282,12 @@
     use_config = TensorUseConfig(tensor_use, indexing_map)
     self.uses[tensor_use] = use_config
 
+  def add_capture_arg(self, capture_def: CaptureDef):
+    if capture_def in self.capture_args:
+      return
+    def_config = CaptureDefConfig(capture_def)
+    self.capture_args[capture_def] = def_config
+
   def _normalize_affine_map(self,
                             affine_map: _ir.AffineMap,
                             with_dims: bool = True) -> _ir.AffineMap:
@@ -258,6 +301,7 @@
   def to_yaml_custom_dict(self):
     self_dict = dict(
         args=self.ordered_tensor_args,
+        caps=self.ordered_capture_args,
         # TODO: Refactor the hierarchy internally when supporting more
         # than static (preserving this serialized form).
         indexing_maps=LinalgIndexingMapsConfig(
@@ -272,6 +316,9 @@
     lines.append("tensor_args=[")
     for def_config in self.ordered_tensor_args:
       lines.append(f"  {repr(def_config)}")
+    lines.append("], capture_args=[")
+    for def_config in self.ordered_capture_args:
+      lines.append(f"  {repr(def_config)}")
     lines.append("], indexing_maps=[")
     for m in self.indexing_maps:
       lines.append(f"  {repr(m)}")
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
@@ -105,11 +105,15 @@
   sig = inspect.signature(dsl_func)
   for param_name, param in sig.parameters.items():
     param_default = param.default
-    if not isinstance(param_default, TensorDef):
+    if isinstance(param_default, TensorDef):
+      tc_model.add_tensor(param_name, param_default)
+    elif isinstance(param_default, CaptureDef):
+      tc_model.add_capture(param_name, param_default)
+    else:
       raise ValueError(f"@tc_def_op function parameters must be defaulted as "
-                       f"TensorDef(...): Found {param_name}: {param_default}")
+                       f"TensorDef(...) or CaptureDef(...): Found {param_name}"
+                       f": {param_default}")
     dsl_func_args.append(param_default)
-    tc_model.add_tensor(param_name, param_default)
 
   # Invoke the DSL func to finish populating the model.
   with bind_op_def(tc_model):
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -28,10 +28,20 @@
 
 def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
                                  *ins: Value,
-                                 outs: Value):
+                                 outs: Sequence[Value],
+                                 caps: Sequence[Value]):
   all_arg_defs = op_config.ordered_tensor_args
   in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "input"]
   out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "output"]
+  capture_arg_defs = op_config.ordered_capture_args
+
+  # Verify outs and caps are sequences.
+  if not isinstance(outs, Sequence):
+    raise ValueError(f"Expected named argument outs to have type Sequence "
+                     f"but got {type(outs)}")
+  if not isinstance(caps, Sequence):
+    raise ValueError(f"Expected named argument caps to have type Sequence "
+                     f"but got {type(caps)}")
 
   # Arity validation.
   if len(ins) != len(in_arg_defs):
@@ -40,19 +50,35 @@
   if outs and len(outs) != len(out_arg_defs):
     raise ValueError(f"Expected {len(out_arg_defs)} outputs but got "
                      f"{len(outs)} for {op_config}")
+  if caps and len(caps) != len(capture_arg_defs):
+    raise ValueError(f"Expected {len(capture_arg_defs)} captures but got "
+                     f"{len(caps)} for {op_config}")
 
   outs, out_types = _infer_structured_outs(op_config, in_arg_defs, ins,
                                            out_arg_defs, outs)
   result_types = [t for t in out_types if isa(RankedTensorType, t)]
 
-  # Extract type vars for input/output based types.
+  # Initialize the type dictionary with the predefined types.
   type_mapping = dict()  # type: Dict[str, Type]
+  type_mapping["F32"] = F32Type.get()
+  type_mapping["F64"] = F64Type.get()
+  type_mapping["I32"] = IntegerType.get_signless(32)
+  type_mapping["I64"] = IntegerType.get_signless(64)
+
+  # Extract type vars for input/output based types.
   for arg_def, arg_element_type in zip(
       in_arg_defs + out_arg_defs,
      _get_shaped_element_types_from_values(*ins, *outs)):
-    tv_name = arg_def.tensor_def.type_var.name
-    type_mapping[tv_name] = arg_element_type
+    _add_type_mapping(arg_def.tensor_def.type_var.name, arg_element_type,
+                      type_mapping)
+
+  # Extract type vars for captures and compute capture argument mapping.
+  capture_arg_mapping = dict()  # type: Dict[str, Value]
+  for arg_def, capture_value in zip(capture_arg_defs, caps):
+    _add_type_mapping(arg_def.capture_def.type_var.name, capture_value.type,
+                      type_mapping)
+    capture_arg_mapping[arg_def.capture_def.capture_name] = capture_value
 
   # Emit the generic op.
   # TODO: Support emission of pure memref form.
@@ -63,21 +89,21 @@
       for am in AffineMap.compress_unused_symbols(op_config.indexing_maps,
                                                   Context.current)])
   iterator_types_attr = ArrayAttr.get(
       [StringAttr.get(s) for s in op_config.iterator_types])
-  sparse_attr = ArrayAttr.get(
-      [BoolAttr.get(False) for s in list(ins) + list(outs)
-       if isa(RankedTensorType, s.type)])
-  if len(sparse_attr) == 0:
-    sparse_attr = None
+  # TODO: Add support for sparse operands once there is a stable interface.
+  sparse_attr = None
 
   return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types,
-          type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr)
+          type_mapping, capture_arg_mapping, indexing_maps_attr,
+          iterator_types_attr, sparse_attr)
 
 
 def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
                                *ins: Value,
-                               outs: Value = ()):
-  all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, \
-  type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr = \
-    prepare_common_structured_op(op_config, *ins, outs = outs)
+                               outs: Sequence[Value] = (),
+                               caps: Sequence[Value] = ()):
+  all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
+  capture_arg_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr = \
+    prepare_common_structured_op(op_config, *ins, outs=outs, caps=caps)
 
   generic_op = linalg.GenericOp(
       result_tensors=result_types,
@@ -95,7 +121,8 @@
   block = generic_op.regions[0].blocks.append(*block_arg_types)
   block_arg_mapping = dict(zip(block_arg_names, block.arguments))
   with InsertionPoint(block):
-    body_builder = _BodyBuilder(type_mapping, block_arg_mapping)
+    body_builder = _BodyBuilder(type_mapping, block_arg_mapping,
+                                capture_arg_mapping)
     for assignment in op_config.assignments:
       body_builder.assign(assignment)
     body_builder.yield_outputs(*_get_tensor_def_names(*out_arg_defs))
@@ -110,10 +137,11 @@
                                 op_name: str,
                                 op_class_name: str,
                                 *ins: Value,
-                                outs: Value = ()):
-  all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, \
-  type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr = \
-    prepare_common_structured_op(op_config, *ins, outs = outs)
+                                outs: Sequence[Value] = (),
+                                caps: Sequence[Value] = ()):
+  all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
+  capture_arg_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr = \
+    prepare_common_structured_op(op_config, *ins, outs=outs, caps=caps)
 
   # If we get here, there must exist a builtin class `op_class_name`.
   ctx = Context.current
@@ -127,7 +155,7 @@
   linalgDialect = ctx.get_dialect_descriptor("linalg")
   fill_builtin_region(linalgDialect, named_op.operation)
   # Note: mlir-linalg-ods-yaml-gen.cpp uses a special linalg.memoized_indexing_maps
-  # attribute that the non-yaml path does not.  The non-yaml path hardcodes the
+  # attribute that the non-yaml path does not. The non-yaml path hardcodes the
   # indexing_maps in C++ directly.
   named_op.operation.attributes["linalg.memoized_indexing_maps"] = indexing_maps_attr
   # iterator_types are hardcoded in C++ both in the yaml and non-yaml path.
@@ -141,10 +169,13 @@
 class _BodyBuilder:
   """Constructs a structured op body by evaluating assignments."""
 
-  def __init__(self, type_mapping: Dict[str, Type],
-               block_arg_mapping: Dict[str, Value]):
+  def __init__(self,
+               type_mapping: Dict[str, Type],
+               block_arg_mapping: Dict[str, Value],
+               capture_arg_mapping: Dict[str, Value]):
     self.type_mapping = type_mapping
     self.block_arg_mapping = block_arg_mapping
+    self.capture_arg_mapping = capture_arg_mapping
     self.yield_mapping = dict()  # type: Dict[str, Value]
 
   def assign(self, assignment: ScalarAssign):
@@ -161,6 +192,16 @@
       except KeyError:
         raise ValueError(f"Argument {expr.scalar_arg.arg} is not bound for "
                          f"this structured op.")
+    elif expr.scalar_cap:
+      try:
+        return self.capture_arg_mapping[expr.scalar_cap.cap]
+      except KeyError:
+        raise ValueError(f"Capture {expr.scalar_cap.cap} is not bound for "
+                         f"this structured op.")
+    elif expr.scalar_const:
+      return self.constant(expr.scalar_const.type.name, expr.scalar_const.value)
+    elif expr.scalar_index:
+      return self.index(expr.scalar_index.dim)
     elif expr.scalar_apply:
       try:
         fn = getattr(self, f"_eval_{expr.scalar_apply.fn_name}")
@@ -177,6 +218,25 @@
       return self.cast(expr.symbolic_cast.to_type.name, operand_value)
     raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
 
+  def constant(self, type_var_name: str, value: object) -> Value:
+    try:
+      type = self.type_mapping[type_var_name]
+    except KeyError:
+      raise ValueError(f"Unbound type variable '{type_var_name}' ("
+                       f"expected one of {self.type_mapping.keys()})")
+    try:
+      if _is_floating_point_type(type):
+        return std.ConstantOp(type, FloatAttr.get(type, float(value))).result
+      elif _is_integer_type(type):
+        return std.ConstantOp(type, IntegerAttr.get(type, int(value))).result
+    except ValueError:
+      raise ValueError(f"Unable to cast value {value} to type {type}")
+    raise NotImplementedError(f"Unimplemented constant type {type}")
+
+  def index(self, dim: int) -> Value:
+    dim_attr = IntegerAttr.get(IntegerType.get_signless(64), dim)
+    return linalg.IndexOp(IndexType.get(), dim_attr).result
+
   def cast(self, type_var_name: str, operand: Value) -> Value:
     try:
       to_type = self.type_mapping[type_var_name]
@@ -198,6 +258,8 @@
       operand_type = operand.type
       if _is_floating_point_type(operand_type):
         return std.FPToSIOp(to_type, operand).result
+      if _is_index_type(operand_type):
+        return std.IndexCastOp(to_type, operand).result
       # Assume integer.
       from_width = IntegerType(operand_type).width
       if to_width > from_width:
@@ -234,14 +296,21 @@
   def _eval_add(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
       return std.AddFOp(lhs.type, lhs, rhs).result
-    if _is_integer_type(lhs.type):
+    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
       return std.AddIOp(lhs.type, lhs, rhs).result
     raise NotImplementedError("Unsupported 'add' operand: {lhs}")
 
+  def _eval_sub(self, lhs: Value, rhs: Value) -> Value:
+    if _is_floating_point_type(lhs.type):
+      return std.SubFOp(lhs.type, lhs, rhs).result
+    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+      return std.SubIOp(lhs.type, lhs, rhs).result
+    raise NotImplementedError("Unsupported 'sub' operand: {lhs}")
+
   def _eval_mul(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
       return std.MulFOp(lhs.type, lhs, rhs).result
-    if _is_integer_type(lhs.type):
+    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
       return std.MulIOp(lhs.type, lhs, rhs).result
     raise NotImplementedError("Unsupported 'mul' operand: {lhs}")
@@ -281,6 +350,12 @@
     *tensor_def_configs: TensorDefConfig) -> Sequence[str]:
   return [tdc.tensor_def.tensor_name for tdc in tensor_def_configs]
 
+def _add_type_mapping(name: str, type: Type, type_mapping: Dict[str, Type]):
+  if name in type_mapping:
+    if type_mapping[name] != type:
+      raise ValueError(f"Cannot overwrite type mapping {name} = "
+                       f"{type_mapping[name]} by type {type}")
+  type_mapping[name] = type
 
 def _is_floating_point_type(t: Type) -> bool:
   # TODO: Create a FloatType in the Python API and implement the switch
@@ -288,10 +363,11 @@
   return (F64Type.isinstance(t) or F32Type.isinstance(t) or
           F16Type.isinstance(t) or BF16Type.isinstance(t))
 
-
 def _is_integer_type(t: Type) -> bool:
   return IntegerType.isinstance(t)
 
+def _is_index_type(t: Type) -> bool:
+  return IndexType.isinstance(t)
 
 def _get_floating_point_width(t: Type) -> int:
   # TODO: Create a FloatType in the Python API and implement the switch
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
@@ -22,6 +22,9 @@
     "ScalarAssign",
     "ScalarApplyFn",
     "ScalarArg",
+    "ScalarCap",
+    "ScalarConst",
+    "ScalarIndex",
     "ScalarExpression",
     "ScalarSymbolicCast",
 ]
@@ -53,6 +56,42 @@
   def __repr__(self):
     return f"(ScalarArg({self.arg})"
 
+
+class ScalarCap:
+  """A type of ScalarExpression that references a named capture."""
+
+  def __init__(self, cap: str):
+    self.cap = cap
+
+  def expr(self) -> "ScalarExpression":
+    return ScalarExpression(scalar_cap=self)
+
+  def __repr__(self):
+    return f"(ScalarCap({self.cap})"
+
+
+class ScalarConst:
+  """A type of ScalarExpression representing a constant."""
+
+  def __init__(self, type: TypeVar, value: object):
+    self.type = type
+    self.value = value
+
+  def expr(self) -> "ScalarExpression":
+    return ScalarExpression(scalar_const=self)
+
+  def __repr__(self):
+    return f"(ScalarConst({self.type}, {self.value})"
+
+
+class ScalarIndex:
+  """A type of ScalarExpression accessing an iteration index."""
+
+  def __init__(self, dim: int):
+    self.dim = dim
+
+  def expr(self) -> "ScalarExpression":
+    return ScalarExpression(scalar_index=self)
+
+  def __repr__(self):
+    return f"(ScalarIndex({self.dim})"
+
 
 class ScalarSymbolicCast:
   """A type of ScalarExpression that symbolically casts an operand to a
   TypeVar.
@@ -75,6 +114,9 @@
   Can be one of:
   - ScalarApplyFn
   - ScalarArg
+  - ScalarCap
+  - ScalarConst
+  - ScalarIndex
   - ScalarSymbolicCast
   """
   yaml_tag = "!ScalarExpression"
@@ -82,13 +124,20 @@
   def __init__(self,
                scalar_apply: Optional[ScalarApplyFn] = None,
                scalar_arg: Optional[ScalarArg] = None,
+               scalar_cap: Optional[ScalarCap] = None,
+               scalar_const: Optional[ScalarConst] = None,
+               scalar_index: Optional[ScalarIndex] = None,
                symbolic_cast: Optional[ScalarSymbolicCast] = None):
-    if (bool(scalar_apply) + bool(scalar_arg) + bool(symbolic_cast)) != 1:
+    if (bool(scalar_apply) + bool(scalar_arg) + bool(scalar_cap) +
+        bool(scalar_const) + bool(scalar_index) + bool(symbolic_cast)) != 1:
       raise ValueError(
-          "One of 'scalar_apply', 'scalar_block_arg', 'symbolic_cast' must be "
-          "specified")
+          "One of 'scalar_apply', 'scalar_arg', 'scalar_cap', 'scalar_const', "
+          "'scalar_index', 'symbolic_cast' must be specified")
     self.scalar_apply = scalar_apply
     self.scalar_arg = scalar_arg
+    self.scalar_cap = scalar_cap
+    self.scalar_const = scalar_const
+    self.scalar_index = scalar_index
     self.symbolic_cast = symbolic_cast
 
   def to_yaml_custom_dict(self):
@@ -99,6 +148,13 @@
       ))
     elif self.scalar_arg:
       return dict(scalar_arg=self.scalar_arg.arg)
+    elif self.scalar_cap:
+      return dict(scalar_cap=self.scalar_cap.cap)
+    elif self.scalar_const:
+      return dict(scalar_const=dict(type_var=self.scalar_const.type.name,
+                                    attributes=[self.scalar_const.value]))
+    elif self.scalar_index:
+      return dict(scalar_index=self.scalar_index.dim)
     elif self.symbolic_cast:
       # Note that even though operands must be arity 1, we write it the
       # same way as for apply because it allows handling code to be more
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/types.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/types.py
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/types.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/types.py
@@ -22,6 +22,12 @@
     "TypeVar",
     "TV",
 
+    # Predefined types.
+    "I32",
+    "I64",
+    "F32",
+    "F64",
+
     # TypeVar aliases.
     "T",
     "U",
@@ -63,6 +69,12 @@
 # Expando access via TV.foo
 TV = TypeVar.create_expando()
 
+# Predefined types.
+I32 = TV.I32
+I64 = TV.I64
+F32 = TV.F32
+F64 = TV.F64
+
 # Some common type name aliases.
 T = TV.T
 U = TV.U
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -10,7 +10,7 @@
 def matmul(A=TensorDef(T1, S.M, S.K),
            B=TensorDef(T2, S.K, S.N),
           C=TensorDef(U, S.M, S.N, output=True)):
-  """Performs a matrix multiplacation of two 2D inputs.
+  """Performs a matrix multiplication of two 2D inputs.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
   them to the same data type as the accumulator/output.
@@ -23,7 +23,7 @@
 def batch_matmul(A=TensorDef(T1, Batch, S.M, S.K),
                  B=TensorDef(T2, Batch, S.K, S.N),
                  C=TensorDef(U, Batch, S.M, S.N, output=True)):
-  """Performs a batched matrix multiplacation of two 3D inputs.
+  """Performs a batched matrix multiplication of two 3D inputs.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
   them to the same data type as the accumulator/output.
@@ -49,7 +49,7 @@
 def vecmat(y=TensorDef(T1, S.M),
            A=TensorDef(T2, S.M, S.N),
            x=TensorDef(U, S.N, output=True)):
-  """Performs a vector-matrix multiplacation.
+  """Performs a vector-matrix multiplication.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
   them to the same data type as the accumulator/output.
diff --git a/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
--- a/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -23,6 +23,17 @@
                 C=TensorDef(U, S.M, S.N, output=True)):
   C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
 
+
+@linalg_structured_op
+def fill_rng_2d(A=TensorDef(T, S.M, S.N, output=True),
+                min=CaptureDef(F64),
+                max=CaptureDef(F64),
+                seed=CaptureDef(I32)):
+  multiplier = const(I32, 1103515245)
+  increment = const(I32, 12345)
+  temp1 = (cast(I32, index(D.m)) + seed) * multiplier + increment
+  temp2 = (cast(I32, index(D.n)) + temp1) * multiplier + increment
+  scaling = (max - min) * const(F64, 2.3283064e-10)
+  A[D.m, D.n] = cast(T, cast(F64, temp2) * scaling + min)
 
 with Context() as ctx, Location.unknown():
   module = Module.create()
@@ -142,5 +153,27 @@
   def test_f64f64f32_matmul(lhs, rhs, init_result):
     return matmul_poly(lhs, rhs, outs=[init_result])
 
+  # CHECK-LABEL: @test_fill_rng_2d
+  # CHECK-SAME: %{{.*}} tensor<4x16xi32>, %[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32
+  # CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
+  # CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
+  # CHECK-DAG: %[[IDX0_CAST:.+]] = index_cast %[[IDX0]] : index to i32
+  # CHECK-DAG: %[[IDX1_CAST:.+]] = index_cast %[[IDX1]] : index to i32
+  # CHECK-DAG: %[[RND0:.+]] = addi %[[IDX0_CAST]], %[[SEED]] : i32
+  # CHECK-DAG: %[[CST0:.+]] = constant 1103515245 : i32
+  # CHECK-DAG: %[[CST1:.+]] = constant 12345 : i32
+  # CHECK-DAG: %[[RND1:.+]] = muli %[[RND0]], %[[CST0]] : i32
+  # CHECK-DAG: %[[RND2:.+]] = addi %[[RND1]], %[[CST1]] : i32
+  # CHECK: %[[RND3:.+]] = sitofp %{{.*}} : i32 to f64
+  # CHECK-DAG: %[[DIFF:.+]] = subf %[[MAX]], %[[MIN]] : f64
+  # CHECK-DAG: %[[CST2:.+]] = constant 2.3283063999999999E-10 : f64
+  # CHECK-DAG: %[[FACT:.+]] = mulf %[[DIFF]], %[[CST2]] : f64
+  # CHECK-DAG: %[[RND4:.+]] = mulf %[[RND3]], %[[FACT]] : f64
+  # CHECK-DAG: %[[RND5:.+]] = addf %[[RND4]], %[[MIN]] : f64
+  # CHECK-DAG: %{{.*}} = fptosi %[[RND5]] : f64 to i32
+  @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i32),
+                               f64, f64, i32)
+  def test_fill_rng_2d(init_result, min, max, seed):
+    return fill_rng_2d(outs=[init_result], caps=[min, max, seed])
 
 print(module)
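
Usage note (illustrative, not part of the patch): the new `caps=` keyword on the emitters mirrors `outs=`. Tensor operands are still passed positionally, one per `TensorDef` parameter, while the scalar SSA values for the `CaptureDef` parameters are passed as a sequence via `caps=`, in capture registration order (here `min`, `max`, `seed`); their types bind the corresponding `TypeVar`s through `_add_type_mapping`. A minimal sketch assuming the `fill_rng_2d` definition and the imports/context handling used in `emit_structured_generic.py` above; the function name `fill_example` and the 8x8 shape are made up for illustration:

# Hypothetical usage sketch (not part of this patch); assumes the imports and
# the fill_rng_2d op definition from emit_structured_generic.py above.
with Context() as ctx, Location.unknown():
  module = Module.create()
  f64 = F64Type.get()
  i32 = IntegerType.get_signless(32)
  with InsertionPoint(module.body):

    # The scalar captures are ordinary function arguments here; they are
    # forwarded to the structured op via caps=[...] in CaptureDef order.
    @builtin.FuncOp.from_py_func(
        RankedTensorType.get((8, 8), i32), f64, f64, i32)
    def fill_example(init_result, min, max, seed):
      return fill_rng_2d(outs=[init_result], caps=[min, max, seed])

print(module)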