diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -8,7 +8,7 @@ represent actual op definitions (i.e. YAML). """ -from typing import Any, Dict, List, Optional, Sequence, Set, Tuple +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple from enum import Enum from ..... import ir as _ir @@ -17,6 +17,10 @@ from .types import * from .yaml_helper import * +############################################################################### +# Tensor expression nodes. +############################################################################### + class TensorExpression: """An expression that can appear on the RHS of a comprehension.""" @@ -24,19 +28,18 @@ def to_scalar_expression(self) -> ScalarExpression: raise NotImplementedError() - def visit_tensor_exprs(self, callback): + def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): """Visits all tensor expression reachable by the expression.""" callback(self) def collect_dim_uses(self, uses: Set["DimDef"]): """Collects all DimDefs reachable through this expression.""" - results = set() - def visit_dim_def(dim_def): + def visit_dim_def(dim_def: AffineExprDef): if isinstance(dim_def, DimDef): uses.add(dim_def) - def visit_affine_exprs(expr): + def visit_affine_exprs(expr: "TensorExpression"): if isinstance(expr, TensorUse): for ind in expr.indices: ind.visit_affine_exprs(visit_dim_def) @@ -49,7 +52,7 @@ def collect_tensor_uses(self, uses: Set["TensorUse"]): """Collects all TensorUses reachable through this expression.""" - def visit_tensor_use(expr): + def visit_tensor_use(expr: "TensorExpression"): if isinstance(expr, TensorUse): uses.add(expr) @@ -58,7 +61,7 @@ def collect_indices(self, indices: Set["index"]): """Collects all index accesses reachable through
this expression.""" - def visit_index(expr): + def visit_index(expr: "TensorExpression"): if isinstance(expr, index): indices.add(expr) @@ -67,7 +70,7 @@ def collect_scalar_uses(self, uses: Set["ScalarDef"]): """Collects all ScalarDefs reachable through this expression.""" - def visit_scalar_def(expr): + def visit_scalar_def(expr: "TensorExpression"): if isinstance(expr, ScalarDef): uses.add(expr) @@ -111,26 +114,261 @@ assert name is not None, "TensorDef not attached" return name - def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn": - return ReduceFnUse(ArithFn.add, *self._compute_reduce_dims(rhs))(rhs) - def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]: - """For implicit reductions, computes default reduction dims. - - Assumes that the rhs is the expression being reduced and self is being - reduced into. Any indices referenced on the rhs and not in self are - considered reduction dims and will be ordered as encountered on the rhs. - """ + # Computes the reduction dims for implicit reductions. Assumes that the rhs + # is the expression being reduced and self is being reduced into. Any + # indices referenced on the rhs and not in self are considered reduction + # dims. The result is a set, so no particular dim order is guaranteed.
rhs_dims = set() lhs_dims = set() rhs.collect_dim_uses(rhs_dims) self.collect_dim_uses(lhs_dims) return rhs_dims - lhs_dims + def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn": + return ReduceFnUse(ArithFn.add, *self._compute_reduce_dims(rhs))(rhs) + def __repr__(self): return f"{self.tensor_name}[{', '.join([repr(i) for i in self.indices])}]" +class TensorArithFn(TensorExpression): + """Application of an arithmetic function.""" + + def __init__(self, arith_fn: "ArithFnType", args: Sequence[TensorExpression]): + self.arith_fn = arith_fn + self.args = tuple(args) + + def to_scalar_expression(self) -> ScalarExpression: + return ScalarArithFn(self.arith_fn.fn_name, + *[arg.to_scalar_expression() for arg in self.args + ]).expr() + + def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): + super().visit_tensor_exprs(callback) + for arg in self.args: + arg.visit_tensor_exprs(callback) + + def __repr__(self): + return f"{repr(self.arith_fn)}({', '.join(repr(a) for a in self.args)})" + + +class TensorTypeFn(TensorExpression): + """Application of a type conversion function.""" + + def __init__(self, type_fn: "TypeFnType", type_var: TypeVar, + arg: TensorExpression): + self.type_fn = type_fn + self.type_var = type_var + self.arg = arg + + def to_scalar_expression(self) -> ScalarExpression: + return ScalarTypeFn(self.type_fn.fn_name, self.type_var, + self.arg.to_scalar_expression()).expr() + + def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): + super().visit_tensor_exprs(callback) + self.arg.visit_tensor_exprs(callback) + + def __repr__(self): + return f"{repr(self.type_fn)}({self.type_var}, {self.arg})" + + +class TensorReduceFn(TensorExpression): + """Application of a reduction function. + + This captures the lhs (initial value) separately from the rhs.
+ """ + + def __init__(self, reduce_use: "ReduceFnUse", + args: Sequence[TensorExpression]): + self.reduce_use = reduce_use + self.lhs = None # type: Optional[TensorUse] + self.args = tuple(args) + + def to_scalar_expression(self) -> ScalarExpression: + if self.lhs is None: + raise ValueError(f"Cannot scalarize a TensorReduceFn that has not been " + f"bound to its lhs: {self}") + full_args = [self.lhs.to_scalar_expression() + ] + [arg.to_scalar_expression() for arg in self.args] + return ScalarArithFn(self.reduce_use.arith_fn.fn_name, *full_args).expr() + + def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): + for arg in self.args: + arg.visit_tensor_exprs(callback) + + def __repr__(self): + return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})" + + +class const(TensorExpression): + """Returns the given constant floating point or integer value.""" + + def __init__(self, value: Any): + with _ir.Context(): + if isinstance(value, float): + self.value = str(_ir.FloatAttr.get_f64(float(value))) + elif isinstance(value, int): + self.value = str( + _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value))) + else: + raise ValueError(f"const requires int or float but got {type(value)}") + + def to_scalar_expression(self) -> ScalarExpression: + return ScalarConst(self.value).expr() + + def __repr__(self): + return f"const({self.value})" + + +class index(TensorExpression): + """Returns the iteration index for a given dimension name. + + Resolves the given dimension name to obtain its position in the iteration + domain of the operation.
+ """ + + def __init__(self, dim: DimDef): + self.dim_def = dim + self.dim = -1 + + def resolve_dimension_name(self, affine_state: AffineBuildState): + self.dim = affine_state.get_dim(self.dim_def.dimname) + + def to_scalar_expression(self) -> ScalarExpression: + assert self.dim != -1, "Dimension name not resolved" + return ScalarIndex(self.dim).expr() + + def __repr__(self): + return f"index({repr(self.dim)})" + + +############################################################################### +# Function types and function definitions. +############################################################################### + + +class TypeFnType: + """Type conversion function. + + A type conversion function takes a target type and a tensor expression and + returns the casted tensor expression. + """ + + def __init__(self, fn_name: str): + self.fn_name = fn_name + + def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorTypeFn: + return TensorTypeFn(self, type_var, arg) + + def __repr__(self): + return f"{self.fn_name}" + + +class TypeFn: + """Type conversion function namespace. + + As the integer types are signless, signedness is implemented by different cast + functions that treat integers as signed (`cast`) or unsigned + (`cast_unsigned`) values. + + Examples: + - cast(I32 -> I64) -> `arith.ExtSIOp` + - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp` + """ + cast = TypeFnType("cast") + cast_unsigned = TypeFnType("cast_unsigned") + + +class ArithFnType: + """Arithmetic function. + + An arithmetic function takes one or more tensor expressions and returns the + function evaluation result. + """ + + def __init__(self, fn_name: str): + self.fn_name = fn_name + + def __call__(self, *args) -> "TensorArithFn": + return TensorArithFn(self, args) + + def __repr__(self): + return f"{self.fn_name}" + + +class ArithFn: + """Arithmetic function namespace.
+ + As the integer types are signless, signedness is implemented by different + functions that treat integers as signed or unsigned values. + + Examples: + - max -> `arith.MaxSIOp` + - max_unsigned -> `arith.MaxUIOp` + """ + add = ArithFnType("add") + exp = ArithFnType("exp") + log = ArithFnType("log") + mul = ArithFnType("mul") + max = ArithFnType("max") + min = ArithFnType("min") + sub = ArithFnType("sub") + max_unsigned = ArithFnType("max_unsigned") + min_unsigned = ArithFnType("min_unsigned") + + +class ReduceFnUse: + """Reduction function use. + + A reduction use specifies the reduction function and dimensions. + """ + + def __init__(self, arith_fn: ArithFnType, *reduce_dims: DimDef): + self.arith_fn = arith_fn + self.reduce_dims = reduce_dims + + def __call__(self, *args: TensorExpression): + return TensorReduceFn(self, args) + + def __repr__(self): + return (f"reduce_{self.arith_fn.fn_name}" + f"({', '.join(repr(d) for d in self.reduce_dims)})") + + +class ReduceFnType: + """Reduction function. + + An arithmetic function that reduces its RHS into its LHS. + """ + + def __init__(self, arith_fn: ArithFnType): + if not isinstance(arith_fn, ArithFnType): + raise ValueError(f"Reduce expected a ArithFnType but got {arith_fn}") + self.arith_fn = arith_fn + + def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse: + return ReduceFnUse(self.arith_fn, *reduce_dims) + + def __repr__(self): + return (f"reduce_{self.arith_fn.fn_name}") + + +class ReduceFn: + add = ReduceFnType(ArithFn.add) + mul = ReduceFnType(ArithFn.mul) + max = ReduceFnType(ArithFn.max) + min = ReduceFnType(ArithFn.min) + max_unsigned = ReduceFnType(ArithFn.max_unsigned) + min_unsigned = ReduceFnType(ArithFn.min_unsigned) + + +############################################################################### +# Operand definitions.
+############################################################################### + + class OperandKind(Enum): InputTensor = 0 Scalar = 1 @@ -150,7 +388,7 @@ type_var: Optional[TypeVar] = None, size_exprs: Optional[Sequence[AffineExprDef]] = None, index_dims: Optional[Sequence[DimDef]] = None, - default_vals : Optional[Sequence[int]] = None): + default_vals: Optional[Sequence[int]] = None): if type_var and not isinstance(type_var, TypeVar): raise ValueError( f"OperandDef requires a TypeVar but got {repr(type_var)}") @@ -206,7 +444,7 @@ self.operand_def = OperandDef( kind, type_var=type_var, size_exprs=shape, index_dims=index_dims) - def __getitem__(self, dims) -> TensorUse: + def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse: assert self.operand_def.owner, "TensorDef is not attached to an op" state = AffineBuildState( global_state=self.operand_def.owner._affine_state, @@ -225,7 +463,7 @@ exprs.append(expr_def) return TensorUse(self.operand_def, exprs) - def __setitem__(self, dims, value): + def __setitem__(self, dims: Sequence[AffineExprDef], value: TensorExpression): """Creates a new 1:1 comprehension by binding this tensor to an expression. Note that due to the way assignment works in Python, we have to capture @@ -282,6 +520,11 @@ OperandKind.IndexAttr, size_exprs=sizes, default_vals=default) +############################################################################### +# Operation definition. +############################################################################### + + class Comprehension: """Represents a single comprehension.""" @@ -320,232 +563,6 @@ return f"{defs_repr} = {values_repr}" -class TypeFnType: - """Type conversion function. - - A type conversion function takes a target type and a tensor expression and - returns the casted tensor expression.
- """ - - def __init__(self, fn_name: str): - self.fn_name = fn_name - - def __call__(self, type_var: TypeVar, - arg: TensorExpression) -> "TensorTypeFn": - return TensorTypeFn(self, type_var, arg) - - def __repr__(self): - return f"{self.fn_name}" - - -class TypeFn: - """Type conversion function namespace. - - As the integer types are signless, signedness is implement by different cast - functions that treat integers as signed (`cast`) or unsigned - (`cast_unsigned`) values. - - Examples: - - cast(I32 -> I64) -> `arith.ExtSIOp` - - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp` - """ - cast = TypeFnType("cast") - cast_unsigned = TypeFnType("cast_unsigned") - - -class ArithFnType: - """Arithmetic function. - - An arithmetic function takes one ore more tensor expressions and returns the - function evaluation result. - """ - - def __init__(self, fn_name: str): - self.fn_name = fn_name - - def __call__(self, *args) -> "TensorArithFn": - return TensorArithFn(self, args) - - def __repr__(self): - return f"{self.fn_name}" - - -class ArithFn: - """Arithmetic function namespace. - - As the integer types are signless, signedness is implement by different - functions that treat integers as signed or unsigned values. - - Examples: - - max -> `arith.MaxSIOp` - - max_unsinged -> `arith.MaxUIOp` - """ - add = ArithFnType("add") - exp = ArithFnType("exp") - log = ArithFnType("log") - mul = ArithFnType("mul") - max = ArithFnType("max") - min = ArithFnType("min") - sub = ArithFnType("sub") - max_unsigned = ArithFnType("max_unsigned") - min_unsigned = ArithFnType("min_unsigned") - - -class ReduceFnUse: - """Reduction function use. - - A reduction use specifies the reduction function and dimensions.
- """ - - def __init__(self, arith_fn: ArithFnType, *reduce_dims: DimDef): - self.arith_fn = arith_fn - self.reduce_dims = reduce_dims - - def __call__(self, *args: TensorExpression): - return TensorReduceFn(self, args) - - def __repr__(self): - return (f"reduce_{self.arith_fn.fn_name}" - f"({', '.join(repr(d) for d in self.reduce_dims)})") - - -class ReduceFnType: - """Reduction function. - - An arithmetic function that reduces its RHS into its LHS. - """ - - def __init__(self, arith_fn: ArithFnType): - if not isinstance(arith_fn, ArithFnType): - raise ValueError(f"Reduce expected a ArithFnType but got {arith_fn}") - self.arith_fn = arith_fn - - def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse: - return ReduceFnUse(self.arith_fn, *reduce_dims) - - def __repr__(self): - return (f"reduce_{self.arith_fn.fn_name}") - - -class ReduceFn: - add = ReduceFnType(ArithFn.add) - mul = ReduceFnType(ArithFn.mul) - max = ReduceFnType(ArithFn.max) - min = ReduceFnType(ArithFn.min) - max_unsigned = ReduceFnType(ArithFn.max_unsigned) - min_unsigned = ReduceFnType(ArithFn.min_unsigned) - - -class TensorArithFn(TensorExpression): - """Application of an arithmetic function.""" - - def __init__(self, arith_fn: ArithFnType, args: Sequence[TensorExpression]): - self.arith_fn = arith_fn - self.args = tuple(args) - - def to_scalar_expression(self) -> ScalarExpression: - return ScalarArithFn(self.arith_fn.fn_name, - *[arg.to_scalar_expression() for arg in self.args - ]).expr() - - def visit_tensor_exprs(self, callback): - super().visit_tensor_exprs(callback) - for arg in self.args: - arg.visit_tensor_exprs(callback) - - def __repr__(self): - return f"{repr(self.arith_fn)}({', '.join(repr(a) for a in self.args)})" - - -class TensorTypeFn(TensorExpression): - """Application of a type conversion function.""" - - def __init__(self, type_fn: TypeFn, type_var: TypeVar, arg: TensorExpression): - self.type_fn = type_fn - self.type_var = type_var - self.arg = arg - - def
to_scalar_expression(self) -> ScalarExpression: - return ScalarTypeFn(self.type_fn.fn_name, self.type_var, - self.arg.to_scalar_expression()).expr() - - def visit_tensor_exprs(self, callback): - super().visit_tensor_exprs(callback) - self.arg.visit_tensor_exprs(callback) - - def __repr__(self): - return f"{repr(self.type_fn)}({self.type_var}, {self.arg})" - - -class const(TensorExpression): - """Returns the given constant floating point or integer value.""" - - def __init__(self, value: Any): - with _ir.Context(): - if isinstance(value, float): - self.value = str(_ir.FloatAttr.get_f64(float(value))) - elif isinstance(value, int): - self.value = str( - _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value))) - else: - raise ValueError(f"const requires int or float but got {type(value)}") - - def to_scalar_expression(self) -> ScalarExpression: - return ScalarConst(self.value).expr() - - def __repr__(self): - return f"const({self.value})" - - -class index(TensorExpression): - """Returns the iteration index for a given dimension name. - - Resolves the given dimension name to obtain its position in the iteration - domain of the operation. - """ - - def __init__(self, dim: DimDef): - self.dim_def = dim - self.dim = -1 - - def resolve_dimension_name(self, affine_state: AffineBuildState): - self.dim = affine_state.get_dim(self.dim_def.dimname) - - def to_scalar_expression(self) -> ScalarExpression: - assert self.dim != -1, "Dimension name not resolved" - return ScalarIndex(self.dim).expr() - - def __repr__(self): - return f"index({repr(self.dim)})" - - -class TensorReduceFn(TensorExpression): - """Application of a reduction function. - - This captures the lhs (initial value) separately from the rhs.
- """ - - def __init__(self, reduce_use: ReduceFnUse, args: Sequence[TensorExpression]): - self.reduce_use = reduce_use - self.lhs = None # type: Optional[TensorUse] - self.args = tuple(args) - - def to_scalar_expression(self) -> ScalarExpression: - if self.lhs is None: - raise ValueError(f"Cannot scalarize a TensorReduceFn that has not been " - f"bound to its lhs: {self}") - full_args = [self.lhs.to_scalar_expression() - ] + [arg.to_scalar_expression() for arg in self.args] - return ScalarArithFn(self.reduce_use.arith_fn.fn_name, *full_args).expr() - - def visit_tensor_exprs(self, callback): - for arg in self.args: - arg.visit_tensor_exprs(callback) - - def __repr__(self): - return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})" - - class OpInterfaceDef: """An interface that an op implements."""