diff --git a/mlir/docs/Tools/LinalgOpDsl.md b/mlir/docs/Tools/LinalgOpDsl.md new file mode 100644 --- /dev/null +++ b/mlir/docs/Tools/LinalgOpDsl.md @@ -0,0 +1,116 @@ +# linalg_opdsl tool + +Python based DSL for authoring Linalg op definitions and generating +`linalg.generic` IR based on them for samples. + +The tool `linalg_opdsl` provides a high level DSL for constructing +structured op definitions in a way that can be exported to built-in, named +structured ops via the above YAML-based definitions or used interactively to +emit corresponding `linalg.generic` IR for the composition. + +## Basic usage + +The tool is bundled with the MLIR Python bindings. To use from the CMake build +tree, MLIR must be built with Python bindings enabled +(`-DMLIR_BINDINGS_PYTHON_ENABLED=ON`). Then add the `python` directory in the +build tree to your `PYTHONPATH` environment variable (e.g. +`export PYTHONPATH=$PWD/build/python`). Optionally, use an installed MLIR +package, if available, to avoid building. + +```shell +# Dump the `core_named_ops.py` module as YAML. +python -m mlir.tools.linalg_opdsl.dump_oplib .ops.core_named_ops +``` + +The tool is meant for use during both development and runtime, but not as +a build tool of the core compiler: in order to export static named op +definitions to be built as part of the compiler, the corresponding Linalg +dialect YAML file must be updated and reviewed. TODO: Develop a script to +automate op updates to these files. + +## Language Guide + +This tool is new and rapidly evolving. For language examples, refer to the +built-in ops in the `mlir.tools.linalg_opdsl.ops` package +(`lib/Bindings/Python/mlir/tools/linalg_opdsl/ops` in the repository). 
+ +Using a matmul as an example, we will decompose the language: + +```python +T1 = TV.T1 +T2 = TV.T2 + +@linalg_structured_op +def matmul(A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True)): + """Performs a matrix multiplication of two 2D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ContractionOpInterface) + C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) +``` + +Here we have a simple type polymorphic contraction that takes arguments `A` +and `B` and outputs `C`. Each is bound to a `TensorDef`, which specifies: + +* The symbolic element type (`T1`, `T2`, `U` above). +* Symbolic shape expressions with symbols that are bound globally for the op ( +note that in this simple example, the shape expressions are just symbol +references, but they are permitted to be a constrained set of affine +expressions). +* Usage (`output=True`). + +The docstring will be transferred to the op definition verbatim. + +Special identifying op interfaces can be declared for the op via +`implements(interface1[, interface2...])`. + +## Assignments + +The bulk of the language consists of assignment expressions of the form above. +The iteration dimension order is determined lexically based on the order +encountered in the expression (following operator precedence if math operators +are used). TODO: Introduce a directive to fix the dimension bindings. + +Reduction dimensions are inferred to be any dimensions on the RHS that are not +on the LHS. 
+ +A number of arithmetic primitive functions are supported: + +* `PrimFn.add(a, b)` (also via overloading the binary `+` operator) +* `PrimFn.exp(a)` +* `PrimFn.log(a)` +* `PrimFn.mul(a, b)` (also via overloading the binary `*` operator) +* `PrimFn.max(a, b)` +* `PrimFn.sub(a, b)` (also via overloading the binary `-` operator) + +Reduction functions can appear as the outer-most function on the RHS: + +* `ReduceFn.add` (also overloading the inplace `+=` on a LHS) +* `ReduceFn.mul` +* `ReduceFn.max` + +There are also special forms: + +* `cast(TypeVar, operand)` + +## Types + +All types in assignment expressions are late bound based on actual input +and output types of constructed ops. Assignment expressions with no `cast` +calls will generally require uniform types throughout and will fail to +verify if violated. The presence of a `cast` allows for a limited form of +numeric type conversion between element types that can be derived from inputs +and outputs (and in the future, attributes). `cast` calls with a `TypeVar` +first argument are emitted as `symbolic_cast` primitives in the YAML definition. + +Casting will perform `int<->float` type conversions and will perform any +necessary extension or truncation within the type family. Note that presently, +any integer type is assumed to be signed for the purpose of determining how to +extend or truncate. Supporting unsigned integer types is left for future work. + +Not all functions are applicable for all numeric types, and on mismatch, op +verification will fail. 
diff --git a/mlir/lib/Bindings/Python/mlir/tools/__init__.py b/mlir/lib/Bindings/Python/mlir/tools/__init__.py new file mode 100644 diff --git a/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/__init__.py b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/__init__.py new file mode 100644 diff --git a/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/dump_oplib.py b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/dump_oplib.py new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/dump_oplib.py @@ -0,0 +1,90 @@ +#!/usr/bin/which python +# Command line tool to load an oplib module and dump all of the operations +# it contains in some format. +"""Loads one or more modules containing op definitions and dumps them. + +The dump format can be: + +* `--format=yaml` (default) +* `--format=repr` + +Positional arguments are interpreted as module names (optionally, relative to +this module). Loose module files can be specified via `--file <filepath>`. + +Sample usage: + # Dump the YAML op definitions for the core named ops (as in the dialect + # source tree). + python -m mlir.tools.linalg_opdsl.dump_oplib .ops.core_named_ops + +Note: YAML output is emitted in "document list" format with each operation +as its own "document". Practically, this means that each operation (or group +of composite ops) is emitted with a "---" preceding it, which can be useful +for testing. 
def create_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the oplib dump tool.

    Returns:
        A parser that accepts zero or more positional module names, an
        optional ``--file`` list of loose Python files, and a ``--format``
        selector (``yaml`` by default, or ``repr``).
    """
    p = argparse.ArgumentParser(description="Dump an oplib in various formats")
    p.add_argument("modules",
                   metavar="M",
                   type=str,
                   nargs="*",
                   help="Op module to dump")
    p.add_argument("--file",
                   metavar="F",
                   type=str,
                   nargs="*",
                   help="Python op file to dump")
    p.add_argument("--format",
                   type=str,
                   dest="format",
                   default="yaml",
                   choices=("yaml", "repr"),
                   help="Format in which to dump")
    return p


def load_module_from_file(module_name, file_path):
    """Execute the Python file at ``file_path`` as a module and return it.

    Args:
        module_name: Name given to the created module object.
        file_path: Path of the Python source file to execute.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec/loader can be derived for
            ``file_path`` (for example, an unrecognised file suffix).
    """
    # `import importlib` alone does not guarantee that the `util` submodule is
    # loaded, so import it explicitly before using it.
    import importlib.util

    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Could not load module '{module_name}' from file: {file_path}")
    m = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(m)
    return m


def main(args):
    """Load the requested op modules and print every op config they define.

    Args:
        args: Parsed arguments as produced by ``create_arg_parser()``; reads
            ``args.modules``, ``args.file`` and ``args.format``.

    Raises:
        ValueError: If a ``LinalgOpConfig`` cannot be created for a defined
            op (the original exception is chained as the cause).
    """
    # Load all modules: named modules first, then loose files.
    modules = []
    for module_name in args.modules:
        modules.append(
            importlib.import_module(module_name,
                                    package="mlir.tools.linalg_opdsl"))
    for i, file_path in enumerate(args.file or []):
        modules.append(load_module_from_file(f"_mlir_eval_oplib{i}", file_path))

    # Collect the configs of every op callable found at module level.
    configs = []
    for m in modules:
        for value in m.__dict__.values():
            # TODO: This class layering is awkward.
            if isinstance(value, DefinedOpCallable):
                try:
                    linalg_config = LinalgOpConfig.from_linalg_op_def(
                        value.model)
                except Exception as e:
                    raise ValueError(
                        f"Could not create LinalgOpConfig from {value.model}"
                    ) from e
                configs.extend(linalg_config)

    # Print.
    if args.format == "yaml":
        print(yaml_dump_all(configs))
    elif args.format == "repr":
        for config in configs:
            print(repr(config))


if __name__ == "__main__":
    main(create_arg_parser().parse_args())
__all__ = [
    "AffineBuildState",
    "AffineExprDef",
    "D",
    "DimDef",
    "S",
    "SymbolDef",
]

# Type aliases.
SymbolPosMap = Dict[str, int]


class AffineBuildState:
    """Internal state for the AffineExprDef._create impls.

    Maps dim and symbol names to positions. The `all_*` dicts may be shared
    with a parent ("global") state; the `local_*` dicts record only the names
    requested through this instance.
    """

    def __init__(self,
                 *,
                 global_state: Optional["AffineBuildState"] = None,
                 allow_new_symbols: bool = True,
                 allow_new_dims: bool = True):
        if global_state is None:
            self.all_symbols = dict()  # type: Dict[str, int]
            self.all_dims = dict()  # type: Dict[str, int]
        else:
            # Alias the global dict so new names are visible to the parent.
            self.all_symbols = global_state.all_symbols
            self.all_dims = global_state.all_dims

        # Map of symbols and dims in the current build.
        self.local_symbols = dict()  # type: Dict[str, int]
        self.local_dims = dict()  # type: Dict[str, int]
        self.allow_new_symbols = allow_new_symbols
        self.allow_new_dims = allow_new_dims

    def get_dim(self, dimname: str) -> int:
        """Gets the dim position given a name.

        Raises:
            ValueError: If the name is unknown and new dims are not allowed.
        """
        pos = self.all_dims.get(dimname)
        if pos is None:
            if not self.allow_new_dims:
                raise ValueError(
                    f"New dimensions not allowed in the current affine expression: "
                    f"Requested '{dimname}', Available: {self.all_dims}")
            pos = len(self.all_dims)
            self.all_dims[dimname] = pos
        self.local_dims[dimname] = pos
        return pos

    def get_symbol(self, symname: str) -> int:
        """Gets a symbol position given a name.

        Raises:
            ValueError: If the name is unknown and new symbols are not allowed.
        """
        pos = self.all_symbols.get(symname)
        if pos is None:
            if not self.allow_new_symbols:
                raise ValueError(
                    f"New symbols not allowed in the current affine expression: "
                    f"Requested '{symname}', Available: {self.all_symbols}")
            pos = len(self.all_symbols)
            self.all_symbols[symname] = pos
        self.local_symbols[symname] = pos
        return pos

    @property
    def local_dim_count(self) -> int:
        return len(self.local_dims)

    @property
    def local_symbol_count(self) -> int:
        return len(self.local_symbols)

    @property
    def dim_count(self) -> int:
        return len(self.all_dims)

    @property
    def symbol_count(self) -> int:
        return len(self.all_symbols)

    def __repr__(self):
        lines = [f"AffineBuildState<"]
        lines.append(f" symbols={self.local_symbols}")
        lines.append(f" dims={self.local_dims}>")
        return "\n".join(lines)


class AffineExprDef:
    """Base class for an affine expression being defined."""

    def build(self,
              state: Optional[AffineBuildState] = None) -> "_ir.AffineExpr":
        """Builds the corresponding _ir.AffineExpr from the definitions.

        A fresh AffineBuildState is used when `state` is not given.
        """
        state = AffineBuildState() if state is None else state
        expr = self._create(state)
        return expr

    def _create(self, state: AffineBuildState) -> "_ir.AffineExpr":
        raise NotImplementedError()

    @staticmethod
    def coerce_from(py_value):
        """Converts a Python int to a constant expr; passes exprs through.

        Raises:
            TypeError: If `py_value` is neither an int nor an AffineExprDef.
        """
        if isinstance(py_value, int):
            return AffineConstantExpr(py_value)
        # Raise explicitly: an assert would be stripped under `python -O`.
        if not isinstance(py_value, AffineExprDef):
            raise TypeError(
                f"Expected an int or AffineExprDef. Got: {repr(py_value)}")
        return py_value

    def visit_affine_exprs(self, callback):
        """Visits all AffineExprDefs including self."""
        callback(self)

    def __add__(lhs, rhs):
        rhs = AffineExprDef.coerce_from(rhs)
        return AffineBinaryExprDef(_ir.AffineAddExpr, lhs, rhs)

    def __mul__(lhs, rhs):
        rhs = AffineExprDef.coerce_from(rhs)
        return AffineBinaryExprDef(_ir.AffineMulExpr, lhs, rhs)

    def __mod__(lhs, rhs):
        rhs = AffineExprDef.coerce_from(rhs)
        return AffineBinaryExprDef(_ir.AffineModExpr, lhs, rhs)

    def __floordiv__(lhs, rhs):
        rhs = AffineExprDef.coerce_from(rhs)
        return AffineBinaryExprDef(_ir.AffineFloorDivExpr, lhs, rhs)

    def __truediv__(lhs, rhs):
        # TODO: Not really a ceil div - taking liberties for the DSL.
        rhs = AffineExprDef.coerce_from(rhs)
        return AffineBinaryExprDef(_ir.AffineCeilDivExpr, lhs, rhs)


class AffineConstantExpr(AffineExprDef):
    """An affine constant being defined."""

    def __init__(self, value: int):
        # Raise explicitly: an assert would be stripped under `python -O`.
        if not isinstance(value, int):
            raise TypeError(
                f"AffineConstantExpr requires an int. Got: {repr(value)}")
        self.value = value

    def _create(self, state: AffineBuildState) -> "_ir.AffineExpr":
        return _ir.AffineConstantExpr.get(self.value)

    def __repr__(self):
        return f"Const({self.value})"


class AffineBinaryExprDef(AffineExprDef):
    """An affine binary expression being defined.

    `ir_ctor` is the IR expression class whose `get(lhs, rhs)` builds the node.
    """

    def __init__(self, ir_ctor, lhs: AffineExprDef, rhs: AffineExprDef):
        self.ir_ctor = ir_ctor
        self.lhs = lhs
        self.rhs = rhs

    def _create(self, state: AffineBuildState) -> "_ir.AffineExpr":
        return self.ir_ctor.get(self.lhs._create(state),
                                self.rhs._create(state))

    def visit_affine_exprs(self, callback):
        """Visits all AffineExprDefs including self."""
        super().visit_affine_exprs(callback)
        self.lhs.visit_affine_exprs(callback)
        self.rhs.visit_affine_exprs(callback)

    def __repr__(self):
        return f"{self.ir_ctor.__name__}({repr(self.lhs)}, {repr(self.rhs)})"


class DimDef(AffineExprDef):
    """Represents a named dimension.

    Instances are uniqued by name: constructing the same name twice returns
    the same object.
    """
    ALL_DIMS = dict()  # type: Dict[str, "DimDef"]
    dimname: str

    def __new__(cls, dimname: str):
        existing = cls.ALL_DIMS.get(dimname)
        if existing is not None:
            return existing
        new = super().__new__(cls)
        new.dimname = dimname
        cls.ALL_DIMS[dimname] = new
        return new

    def __repr__(self):
        return f"Dim({self.dimname})"

    def _create(self, state: AffineBuildState) -> "_ir.AffineExpr":
        pos = state.get_dim(self.dimname)
        return _ir.AffineDimExpr.get(position=pos)

    @classmethod
    def create_expando(cls):
        """Create an expando object that creates unique dims on attr access."""

        class ExpandoDims:

            def __getattr__(self, n):
                return cls(n)

        return ExpandoDims()


# Global accessor for on-demand dims.
D = DimDef.create_expando()


class SymbolDef(AffineExprDef):
    """Represents a named symbol.

        >>> s1 = SymbolDef("s1")
        >>> s1
        Symbol(s1)
        >>> s2 = SymbolDef("s2")
        >>> s1 is s2
        False
        >>> s1 is SymbolDef("s1")
        True
    """
    ALL_SYMBOLS = dict()  # type: Dict[str, "SymbolDef"]
    symname: str

    def __new__(cls, symname: str):
        existing = cls.ALL_SYMBOLS.get(symname)
        if existing is not None:
            return existing
        new = super().__new__(cls)
        new.symname = symname
        cls.ALL_SYMBOLS[symname] = new
        return new

    def __repr__(self):
        return f"Symbol({self.symname})"

    def _create(self, state: AffineBuildState) -> "_ir.AffineExpr":
        pos = state.get_symbol(self.symname)
        return _ir.AffineSymbolExpr.get(position=pos)

    @classmethod
    def create_expando(cls):
        """Create an expando object that creates unique symbols on attr access."""

        class ExpandoSymbols:

            def __getattr__(self, n):
                return cls(n)

        return ExpandoSymbols()


# Global accessor for on-demand symbols.
S = SymbolDef.create_expando()
# Type aliases.
AffineDimList = Dict[str, _ir.AffineExpr]


class TensorExpression:
    """An expression that can appear on the RHS of a comprehension."""

    def to_scalar_expression(self) -> ScalarExpression:
        raise NotImplementedError()

    def visit_affine_exprs(self, callback):
        """Visits all affine expressions reachable by the expression."""
        pass

    def _get_all_dim_defs(self) -> Set[DimDef]:
        """Recursively gets all DimDef affine expressions that are referenced."""
        results = set()

        def visitor(affine_expr):
            if isinstance(affine_expr, DimDef):
                results.add(affine_expr)

        self.visit_affine_exprs(visitor)
        return results

    def collect_uses(self, uses: Set["TensorUse"]):
        """Collects all TensorUses reachable through this expression."""
        pass

    def __add__(self, rhs: "TensorExpression") -> "TensorExpression":
        return PrimFn.add(self, rhs)

    def __mul__(self, rhs) -> "TensorExpression":
        return PrimFn.mul(self, rhs)

    def __sub__(self, rhs) -> "TensorExpression":
        return PrimFn.sub(self, rhs)

    def __hash__(self):
        return hash(id(self))


class TensorUse(TensorExpression):
    """A used tensor represented by its (tensor_name, indices).

    Note that forming a comprehension via direct assignment is performed through
    __setitem__ on the TensorDef level. However, performing a reduction with
    compound ops (+=, *=, etc) is done by doing a:
      TensorDef.__getitem__
      TensorUse.__iadd__
      TensorDef.__setitem__
    """

    def __init__(self, tensor_def: "TensorDef",
                 indices: Sequence[AffineExprDef]):
        self.tensor_def = tensor_def
        self.indices = tuple(indices)

    def to_scalar_expression(self) -> ScalarExpression:
        assert self.tensor_def.tensor_name is not None
        return ScalarArg(self.tensor_def.tensor_name).expr()

    @property
    def tensor_name(self) -> str:
        n = self.tensor_def.tensor_name
        assert n is not None, "TensorDef not attached"
        return n

    def visit_affine_exprs(self, callback):
        for ind in self.indices:
            ind.visit_affine_exprs(callback)

    def collect_uses(self, uses: Set["TensorUse"]):
        uses.add(self)

    def __iadd__(self, rhs: TensorExpression) -> TensorExpression:
        return ReduceFn.add(*self._compute_reduce_dims(rhs))(rhs)

    def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]:
        """For implicit reductions, computes default reduction dims.

        Assumes that the rhs is the expression being reduced and self is being
        reduced into. Any dims referenced on the rhs and not in self are
        considered reduction dims. The result is a set, so it carries no
        ordering.
        """
        rhs_dims = rhs._get_all_dim_defs()
        lhs_dims = self._get_all_dim_defs()
        return rhs_dims - lhs_dims

    def __repr__(self):
        return f"{self.tensor_name}[{', '.join([repr(i) for i in self.indices])}]"


class TensorDef:
    """Bookkeeping of a single registered tensor, held in dict by name."""

    def __init__(self,
                 type_var: TypeVar,
                 *shape: AffineExprDef,
                 indexing_map: Optional[_ir.AffineMap] = None,
                 output: bool = False):
        if not isinstance(type_var, TypeVar):
            raise ValueError(
                f"TensorDef requires a TypeVar. Got: {repr(type_var)}")
        self.owner = None  # type: Optional["LinalgOpDef"]
        self.type_var = type_var
        self.shape = shape
        self.indexing_map = indexing_map
        self.output = output
        self.tensor_name = None  # type: Optional[str]
        self.registered_index = -1  # type: int

    @property
    def rank(self) -> int:
        """The rank of the tensor."""
        return len(self.shape)

    def attach(self, index: int, tensor_name: str, owner: "LinalgOpDef"):
        """Binds this tensor to `owner` under `tensor_name` at `index`.

        Raises:
            ValueError: If the tensor is already attached to an op.
        """
        if self.owner:
            raise ValueError(f"TensorDef already registered with op: {self}")
        self.registered_index = index
        self.tensor_name = tensor_name
        self.owner = owner

    def __getitem__(self, dims) -> TensorUse:
        """Creates a TensorUse of this tensor indexed by affine expressions.

        A single subscript is treated as a 1-tuple and `None` denotes a
        0d-scalar use.

        Raises:
            KeyError: If any subscript is not an AffineExprDef.
        """
        assert self.owner, "TensorDef is not attached to an op"
        if not isinstance(dims, tuple):
            dims = (dims,)  # Handle single subscript case.
        # Special case: (None) is a 0d-scalar use.
        if dims == (None,):
            dims = ()

        exprs = []
        for expr_def in dims:
            if not isinstance(expr_def, AffineExprDef):
                raise KeyError(
                    "A TensorDef can only be subscripted by a tuple of affine dims"
                )
            exprs.append(expr_def)
        return TensorUse(self, exprs)

    def __setitem__(self, dims, value):
        """Creates a new 1:1 comprehension by binding this tensor to an expression.

        Note that due to the way assignment works in Python, we have to capture
        direct assignment as a setitem on the TensorDef.
        """
        if not isinstance(value, TensorExpression):
            raise ValueError(
                f"Only TensorExpressions can be assigned to TensorDefs. "
                f"Got: {repr(value)}")
        use = self[dims]
        comp = Comprehension((use, value))
        self.owner.comprehensions.append(comp)

    def __hash__(self):
        return hash(id(self))

    def __repr__(self):
        output = "OUTPUT " if self.output else ""
        return (f"{self.tensor_name}:TensorDef({output}{repr(self.type_var)}, "
                f"shape={self.shape})")


class Comprehension:
    """Represents a single comprehension."""

    def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]):
        self.definitions = list()  # List[TensorUse]
        self.values = list()  # List[TensorExpression]

        # Find the lhs to reduction rhs.
        for assign, value in bindings:
            if isinstance(value, ReduceApply):
                if value.lhs:
                    raise ValueError(
                        f"Reduction expression already assigns: {value}")
                value.lhs = assign
            self.definitions.append(assign)
            self.values.append(value)

    @property
    def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]:
        """Gets the distinct reduction-dim tuples used by the comprehension.

        Values that are not reductions contribute an empty tuple.
        """
        result = set()
        for use in self.values:
            if isinstance(use, ReduceApply):
                result.add(use.reduce.reduce_dims)
            else:
                result.add(tuple())
        return result

    def __repr__(self):
        if len(self.definitions) > 1:
            defs_repr = f"({', '.join(repr(d) for d in self.definitions)})"
            values_repr = f"({', '.join(repr(v) for v in self.values)})"
        else:
            defs_repr = f"{repr(self.definitions[0])}"
            values_repr = f"{repr(self.values[0])}"

        return f"{defs_repr} = {values_repr}"


class PrimFnType:
    """Primitive operations."""

    def __init__(self, prim_name: str):
        self.prim_name = prim_name

    def __call__(self, *args):
        return PrimApply(self, args)

    def reduce(self, *reduce_dims: DimDef):
        """Shortcut to create a Reduce operation from this primitive."""
        return ReduceFnType(self, *reduce_dims)

    def __repr__(self):
        return f"{self.prim_name}"


class PrimFn:
    add = PrimFnType("add")
    exp = PrimFnType("exp")
    log = PrimFnType("log")
    mul = PrimFnType("mul")
    max = PrimFnType("max")
    sub = PrimFnType("sub")


class ReduceFnType:
    """A reduction operator that reduces into its LHS from its RHS."""

    def __init__(self, operator: PrimFnType, *reduce_dims: DimDef):
        """Initializes the ReduceFn with a primitive function and dims."""
        if not isinstance(operator, PrimFnType):
            raise ValueError(
                f"Reduce expected a Prim operator. Got: {operator}")
        self.operator = operator
        self.reduce_dims = tuple(reduce_dims)

    def __call__(self, *args: TensorExpression):
        return ReduceApply(self, args)

    def __repr__(self):
        return (f"reduce_{self.operator.prim_name}"
                f"({', '.join(repr(d) for d in self.reduce_dims)})")


class ReduceFn:
    add = PrimFn.add.reduce
    mul = PrimFn.mul.reduce
    max = PrimFn.max.reduce


class PrimApply(TensorExpression):
    """Application of a primitive."""

    def __init__(self, prim: PrimFnType, args: Sequence[TensorExpression]):
        self.prim = prim
        self.args = tuple(args)

    def to_scalar_expression(self) -> ScalarExpression:
        return ScalarApplyFn(self.prim.prim_name,
                             *[arg.to_scalar_expression() for arg in self.args
                              ]).expr()

    def visit_affine_exprs(self, callback):
        for arg in self.args:
            arg.visit_affine_exprs(callback)

    def collect_uses(self, uses: Set["TensorUse"]):
        for arg in self.args:
            arg.collect_uses(uses)

    def __repr__(self):
        return f"{repr(self.prim)}({', '.join(repr(a) for a in self.args)})"


class cast(TensorExpression):
    """Casts the element type to a type (typically symbolic TypeVar)."""

    def __init__(self, to_type: TypeVar, operand: TensorExpression):
        self.to_type = to_type
        self.operand = operand

    def to_scalar_expression(self) -> ScalarExpression:
        return ScalarSymbolicCast(self.to_type,
                                  self.operand.to_scalar_expression()).expr()

    def visit_affine_exprs(self, callback):
        self.operand.visit_affine_exprs(callback)

    def collect_uses(self, uses: Set["TensorUse"]):
        self.operand.collect_uses(uses)

    def __repr__(self):
        return f"cast({self.to_type}, {repr(self.operand)})"


class ReduceApply(TensorExpression):
    """Application of a reduction.

    This captures the lhs (initial value) separately from the rhs.
    """

    def __init__(self, reduce: ReduceFnType, args: Sequence[TensorExpression]):
        self.reduce = reduce
        self.lhs = None  # type: Optional[TensorUse]
        self.args = tuple(args)

    def to_scalar_expression(self) -> ScalarExpression:
        if self.lhs is None:
            raise ValueError(
                f"Cannot scalarize a ReduceApply that has not been "
                f"bound to its lhs: {self}")
        full_args = [self.lhs.to_scalar_expression()
                    ] + [arg.to_scalar_expression() for arg in self.args]
        return ScalarApplyFn(self.reduce.operator.prim_name, *full_args).expr()

    def visit_affine_exprs(self, callback):
        for ind in self.reduce.reduce_dims:
            ind.visit_affine_exprs(callback)
        for arg in self.args:
            arg.visit_affine_exprs(callback)

    def collect_uses(self, uses: Set["TensorUse"]):
        for arg in self.args:
            arg.collect_uses(uses)

    def __repr__(self):
        return f"{repr(self.reduce)}({', '.join(repr(a) for a in self.args)})"


class OpInterfaceDef:
    """An interface that an op implements."""

    def __init__(self, cpp_name: str):
        self.cpp_name = cpp_name


ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface")


class OpMetadataDef(YAMLObject):
    """Metadata about the op (generally not behavior impacting)."""
    yaml_tag = "!LinalgOpMetadata"

    def __init__(self, name: str, cpp_op_name: Optional[str],
                 doc: Optional[str]):
        self.name = name
        # The C++ op name defaults to the op name when not given.
        self.cpp_op_name = cpp_op_name if cpp_op_name is not None else name
        self.doc = doc
        self.implements = []  # type: List[OpInterfaceDef]

    def to_yaml_custom_dict(self):
        d = dict(
            name=self.name,
            cpp_op_name=self.cpp_op_name,
            doc=self.doc,
        )
        # Only emit `implements` when at least one interface was declared.
        if self.implements:
            d["implements"] = [intr.cpp_name for intr in self.implements]
        return d


class LinalgOpDef:
    """Definition of a linalg op."""

    def __init__(self,
                 name: str,
                 cpp_op_name: Optional[str] = None,
                 doc: Optional[str] = None):
        self.metadata = OpMetadataDef(name=name,
                                      cpp_op_name=cpp_op_name,
                                      doc=doc)
        self.registered_tensors = dict()  # type: Dict[str, TensorDef]
        self.comprehensions = list()  # type: List[Comprehension]
        self._affine_state = AffineBuildState()

    @property
    def inputs(self) -> Sequence[TensorDef]:
        return [t for t in self.registered_tensors.values() if not t.output]

    @property
    def outputs(self) -> Sequence[TensorDef]:
        return [t for t in self.registered_tensors.values() if t.output]

    def add_tensor(self, tensor_name: str, tensor: TensorDef):
        """Registers a tensor.

        Raises:
            ValueError: If `tensor_name` is already registered, or if the
                tensor is already attached to an op.
        """
        if tensor_name in self.registered_tensors:
            # Look up by the variable, not the literal string 'tensor_name':
            # the literal raised KeyError and masked this error.
            raise ValueError(f"Tensor {tensor_name} is already registered "
                             f"to {self.registered_tensors[tensor_name]}")
        tensor.attach(len(self.registered_tensors), tensor_name, self)
        self.registered_tensors[tensor_name] = tensor

    def tensor(self, name):
        """Gets a registered tensor by name.

        Raises:
            KeyError: If no tensor is registered under `name`.
        """
        try:
            return self.registered_tensors[name]
        except KeyError:
            # Suppress the redundant inner KeyError in the traceback.
            raise KeyError(f"Tensor {name} is not registered") from None

    def __repr__(self):
        lines = [
            f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_op_name},"
        ]
        for name, tensor in self.registered_tensors.items():
            lines.append(f" {tensor}")
        if self.comprehensions:
            lines[-1] += " {"
            for comprehension in self.comprehensions:
                lines.append(f" {comprehension}")
            lines.append("}")
        return "\n".join(lines)
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Represents configured ops as emitted for code generation. + +Classes in this module generally are directly serializable to YAML for use +by the code generator. + +TODO: These should just be dumb containers or serialization code but they +currently encode too many details of how the language is interpreted. Move this +to helpers on the comprehension objects themselves. +""" + +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple + +from mlir import ir as _ir + +from .comprehension import * +from .yaml_helper import * + +__all__ = [ + "LinalgStructuredOpConfig", + "LinalgOpConfig", +] + + +def _serialize_affine_map(affine_map: _ir.AffineMap) -> str: + with affine_map.context: + # Affine map printing/parsing is via an AffineMap attr. + attr = _ir.AffineMapAttr.get(affine_map) + return str(attr) + + +class TensorUseConfig: + """Wrapper around a TensorUse with additional context-bound state.""" + + def __init__(self, tensor_use: TensorUse, indexing_map: _ir.AffineMap): + self.tensor_use = tensor_use + self.indexing_map = indexing_map + + def __repr__(self): + return f"Use({self.tensor_use}, indexing_map={self.indexing_map})" + + +class TensorDefConfig(YAMLObject): + """Wrapper around a TensorDef with additional context-bound state.""" + yaml_tag = "!LinalgTensorDef" + + def __init__(self, tensor_def: TensorDef, shape_map: _ir.AffineMap): + self.tensor_def = tensor_def + self.shape_map = shape_map + self.indexing_map = None # type: Optional[_ir.AffineMap] + + def to_yaml_custom_dict(self): + + def get_usage(): + if self.tensor_def.output: + return "output" + else: + return "input" + + return dict( + name=self.tensor_def.tensor_name, + usage=get_usage(), + shape=_serialize_affine_map(self.shape_map), + element_type_var=self.tensor_def.type_var.name, + ) + + def __repr__(self): + return f"Def({self.tensor_def}, shape_map={self.shape_map}, indexing_map={self.indexing_map})" + + +class LinalgIndexingMapsConfig(YAMLObject): +
"""Abstracts the style of indexing maps that the op exports. + + Presently only static (tied to the op name) indexing maps are supported. In + the future, it is expected that we will have additional variants: + - Dynamic based on attributes + - Dynamic based on operands + Each is expected to require a different variant of specification. + """ + yaml_tag = "!LinalgIndexingMapsConfig" + + def __init__(self, + static_indexing_maps: Optional[Sequence[_ir.AffineMap]] = None): + self.static_indexing_maps = static_indexing_maps + + def to_yaml_custom_dict(self): + if self.static_indexing_maps is not None: + return dict(static_indexing_maps=[ + _serialize_affine_map(m) for m in self.static_indexing_maps + ]) + raise ValueError( + f"LinalgIndexingMapsConfig must have one type of indexing map " + f"(got none)") + + +class LinalgStructuredOpConfig(YAMLObject): + """Configuration for metadata sufficient to construct a linalg single + contraction named op.""" + + yaml_tag = "!LinalgStructuredOpConfig" + + def __init__(self, + comprehension: Comprehension, + context: Optional[_ir.Context] = None): + self.context = context if context is not None else _ir.Context() + self.affine_state = AffineBuildState() + self.writes = list() # type: List[Tuple[TensorUse, TensorExpression]] + self.tensor_args = dict() # type: Dict[TensorDef, TensorDefConfig] + self.uses = dict() # type: Dict[TensorUse, TensorUseConfig] + + # Compute the ordered set of writes. + collected_uses = set() + for write_use, read_use in zip(comprehension.definitions, + comprehension.values): + self.writes.append((write_use, read_use)) + + for write_use, read_use in self.writes: + collected_uses.add(write_use) + read_use.collect_uses(collected_uses) + + # Need to add all definitions before uses, so process twice.
+ for use in collected_uses: + self.add_tensor_arg(use.tensor_def) + for use in collected_uses: + self.add_use(use) + + # Now normalize all defs and uses indexing maps now that full count of + # dims and symbols are known. + for cuse in self.uses.values(): + cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map) + for cdef in self.tensor_args.values(): + cdef.shape_map = self._normalize_affine_map(cdef.shape_map, + with_dims=False) + + # Now for each write use, propagate the indexing maps from the use to the + # tensor, ensuring that there are not conflicts. + for write_use, _ in self.writes: + write_tensor_def = self.tensor_args[write_use.tensor_def] + if write_tensor_def.indexing_map: + raise ValueError( + f"Unexpected multi-write to a single tensor: {write_tensor_def}") + write_tensor_def.indexing_map = self.uses[write_use].indexing_map + + # For each read use, propagate the indexing maps from the use to the + # tensor, ensuring that there are not conflicts. + for _, read_expr in self.writes: + read_uses = set() # type: Set[TensorUse] + read_expr.collect_uses(read_uses) + for read_use in read_uses: + read_tensor_def = self.tensor_args[read_use.tensor_def] + if (read_tensor_def.indexing_map and + read_tensor_def.indexing_map != self.uses[read_use].indexing_map): + raise ValueError( + f"Unexpected multi-read of a tensor with different accesses: " + f"{read_tensor_def} vs {read_use}") + read_tensor_def.indexing_map = self.uses[read_use].indexing_map + + # Sanity check that all defs have an indexing map. + assert all(d.indexing_map for d in self.tensor_args.values()), ( + f"Missing indexing map on TensorDef: {self.tensor_args}") + + # Collect reduction dims and ensure all the same. + all_reduction_dims = set(comprehension.all_reduction_dims) + if len(all_reduction_dims) != 1: + raise ValueError( + f"All writes within a generic must have the same reduction " + f"dims.
Got: {all_reduction_dims}") + self.reduction_dims = next(iter(all_reduction_dims)) + + # Generate the scalar assignments (used to build a body). + self.assignments = [ + ScalarAssign(write_use.tensor_name, read_expr.to_scalar_expression()) + for write_use, read_expr in self.writes + ] + + @property + def ordered_tensor_args(self) -> Sequence[TensorDefConfig]: + return sorted(self.tensor_args.values(), + key=lambda tdc: tdc.tensor_def.registered_index) + + @property + def ordered_tensor_uses(self) -> Sequence[TensorUseConfig]: + return sorted(self.uses.values(), + key=lambda tuc: tuc.tensor_use.tensor_def.registered_index) + + @property + def ordered_dims(self) -> Sequence[Tuple[str, int]]: + """Gets the ordered list of dim bindings (symbolic name, position). + + TODO: The original parser relies on parse ordering to arrive at the + iterator types, but that ordering is not defined on the Python side, so + this may be ambiguous. + """ + return list(self.affine_state.all_dims.items()) + + @property + def indexing_maps(self) -> Sequence[_ir.AffineMap]: + return [use.indexing_map for use in self.ordered_tensor_uses] + + @property + def iterator_types(self) -> Sequence[str]: + + def get_type(symbolic_name, position): + for reduction_dim_expr in self.reduction_dims: + if reduction_dim_expr.dimname == symbolic_name: + return "reduction" + return "parallel" + + return [get_type(*dim) for dim in self.ordered_dims] + + def add_tensor_arg(self, tensor_def: TensorDef): + if tensor_def in self.tensor_args: + return + with self.context: + local_state = AffineBuildState(global_state=self.affine_state, + allow_new_dims=False) + exprs = [] + for expr in tensor_def.shape: + exprs.append(expr.build(state=local_state)) + assert local_state.local_dim_count == 0 + indexing_map = _ir.AffineMap.get(dim_count=0, + symbol_count=local_state.symbol_count, + exprs=exprs) + + def_config = TensorDefConfig(tensor_def, indexing_map) + self.tensor_args[tensor_def] = def_config + + def add_use(self, 
tensor_use: TensorUse): + if tensor_use in self.uses: + return + with self.context: + local_state = AffineBuildState(global_state=self.affine_state, + allow_new_symbols=False) + exprs = [] + for expr in tensor_use.indices: + exprs.append(expr.build(state=local_state)) + assert local_state.local_symbol_count == 0 + indexing_map = _ir.AffineMap.get(dim_count=local_state.dim_count, + symbol_count=local_state.symbol_count, + exprs=exprs) + + use_config = TensorUseConfig(tensor_use, indexing_map) + self.uses[tensor_use] = use_config + + def _normalize_affine_map(self, + affine_map: _ir.AffineMap, + with_dims: bool = True) -> _ir.AffineMap: + """Normalizes an indexing map to have the max known symbols and dims.""" + with self.context: + return _ir.AffineMap.get( + dim_count=self.affine_state.dim_count if with_dims else 0, + symbol_count=self.affine_state.symbol_count, + exprs=list(affine_map.results)) + + def to_yaml_custom_dict(self): + self_dict = dict( + args=self.ordered_tensor_args, + # TODO: Refactor the hierarchy internally when supporting more + # than static (preserving this serialized form). + indexing_maps=LinalgIndexingMapsConfig( + static_indexing_maps=self.indexing_maps), + iterator_types=self.iterator_types, + assignments=self.assignments, + ) + return self_dict + + def __repr__(self): + lines = [f"LinalgGenericOpConfig(reduction_dims={self.reduction_dims},"] + lines.append("tensor_args=[") + for def_config in self.ordered_tensor_args: + lines.append(f" {repr(def_config)}") + lines.append("], indexing_maps=[") + for m in self.indexing_maps: + lines.append(f" {repr(m)}") + lines.append(f"], iterator_types=[") + for t in self.iterator_types: + lines.append(f" {t}") + lines.append("])") + return "\n".join(lines) + + +class LinalgOpConfig(YAMLObject): + """Container for any supported linalg op type. + + This includes the concrete type by name for ease of parsing by systems + that ignore tags. 
+ """ + yaml_tag = "!LinalgOpConfig" + + def __init__(self, + metadata: OpMetadataDef, + *, + structured_op: Optional[LinalgStructuredOpConfig] = None): + self.metadata = metadata + self.structured_op = structured_op + + def to_yaml_custom_dict(self): + self_dict = dict(metadata=self.metadata,) + if self.structured_op: + self_dict["structured_op"] = self.structured_op + return self_dict + + @staticmethod + def from_linalg_op_def( + tc_op_def: LinalgOpDef, + context: Optional[_ir.Context] = None) -> Sequence["LinalgOpConfig"]: + """Expands a LinalgOpDef into corresponding Linalg configured ops.""" + # TODO: Many LinalgOpDef patterns need to expand to multiple generics. + assert len( + tc_op_def.comprehensions) == 1, "Only one comprehension supported" + return [ + LinalgOpConfig(tc_op_def.metadata, + structured_op=LinalgStructuredOpConfig( + tc_op_def.comprehensions[0], context)), + ] + + def __repr__(self): + return (f"LinalgOpConfig(metadata={self.metadata},\n" + f"structured_op={self.structured_op})") diff --git a/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/dsl.py b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/dsl.py new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/dsl.py @@ -0,0 +1,91 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Dict, List + +from contextlib import contextmanager +import functools +import inspect +import threading + +from mlir import ir +from .comprehension import * + +_CONTEXT = threading.local() + + +@contextmanager +def bind_op_def(model: LinalgOpDef): + if hasattr(_CONTEXT, "current_op_def"): + raise ValueError("Cannot recursively define an operation") + _CONTEXT.current_op_def = model + try: + yield model + finally: + del _CONTEXT.current_op_def + + +def current_op_def() -> LinalgOpDef: + try: + return _CONTEXT.current_op_def + except AttributeError: + raise ValueError( + "Attempt to access the current op definition being defined " + "but none is set. Did you mean to call this in an op definition?") + + +class DefinedOpCallable: + """Callable that wraps any defined op function.""" + + def __init__(self, op_name: str, model: LinalgOpDef): + self.op_name = op_name + self.model = model + + def __call__(self, *args, **kwargs): + # TODO: Upstream the emitter and invoke here + raise NotImplementedError("Linalg generic emission not yet implemented") + + +def linalg_structured_op(dsl_func=None, + *, + op_name=None, + op_class_name=None) -> DefinedOpCallable: + if dsl_func is None: + # Curry the keyword args in for delayed application. + return functools.partial(linalg_structured_op, + op_name=op_name, + op_class_name=op_class_name) + # Determine default names by introspecting the function. + if op_name is None: + op_name = dsl_func.__name__ + if op_class_name is None: + # Camel case it. + op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op" + + tc_model = LinalgOpDef(name=op_name, + cpp_op_name=op_class_name, + doc=inspect.getdoc(dsl_func)) + + # Extract arguments and TensorDefs from the signature.
+ dsl_func_args = list() + sig = inspect.signature(dsl_func) + for param_name, param in sig.parameters.items(): + param_default = param.default + if not isinstance(param_default, TensorDef): + raise ValueError(f"@linalg_structured_op function parameters must be " + f"defaulted as TensorDef(...): Found {param_name}: {param_default}") + dsl_func_args.append(param_default) + tc_model.add_tensor(param_name, param_default) + + # Invoke the DSL func to finish populating the model. + with bind_op_def(tc_model): + dsl_func(*dsl_func_args) + + # TODO: The returned callable should be an IR emitter but that is not + # upstreamed yet. + return DefinedOpCallable(op_name, tc_model) + + +def implements(*interfaces: OpInterfaceDef): + current_op_def().metadata.implements.extend(interfaces) diff --git a/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/scalar_expr.py b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/scalar_expr.py new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/scalar_expr.py @@ -0,0 +1,124 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Models DAGs of scalar math expressions. + +Used for generating region bodies at the "math" level where they are still type +polymorphic. This is modeled to be polymorphic by attribute name for interop +with serialization schemes that are just plain-old-dicts. + +These classes are typically not user accessed and are created as a by-product +of interpreting a comprehension DSL and model the operations to perform in the +op body. The class hierarchy is laid out to map well to a form of YAML that +can be easily consumed from the C++ side, not necessarily for ergonomics.
+""" + +from typing import Optional, Sequence + +from .yaml_helper import * +from .types import * + +__all__ = [ + "ScalarAssign", + "ScalarApplyFn", + "ScalarArg", + "ScalarExpression", + "ScalarSymbolicCast", +] + + +class ScalarApplyFn: + """A type of ScalarExpression that applies a named function to operands.""" + + def __init__(self, fn_name: str, *operands: "ScalarExpression"): + self.fn_name = fn_name + self.operands = operands + + def expr(self) -> "ScalarExpression": + return ScalarExpression(scalar_apply=self) + + def __repr__(self): + return f"ScalarApplyFn<{self.fn_name}>({', '.join(map(repr, self.operands))})" + + +class ScalarArg: + """A type of ScalarExpression that references a named argument.""" + + def __init__(self, arg: str): + self.arg = arg + + def expr(self) -> "ScalarExpression": + return ScalarExpression(scalar_arg=self) + + def __repr__(self): + return f"ScalarArg({self.arg})" + + +class ScalarSymbolicCast: + """A type of ScalarExpression that symbolically casts an operand to a TypeVar. + """ + + def __init__(self, to_type: TypeVar, operand: "ScalarExpression"): + self.to_type = to_type + self.operand = operand + + def expr(self) -> "ScalarExpression": + return ScalarExpression(symbolic_cast=self) + + def __repr__(self): + return f"ScalarSymbolicCast({self.to_type}, {self.operand})" + + +class ScalarExpression(YAMLObject): + """An expression on scalar values.
+ + Can be one of: + - ScalarApplyFn + - ScalarArg + - ScalarSymbolicCast + """ + yaml_tag = "!ScalarExpression" + + def __init__(self, + scalar_apply: Optional[ScalarApplyFn] = None, + scalar_arg: Optional[ScalarArg] = None, + symbolic_cast: Optional[ScalarSymbolicCast] = None): + if (bool(scalar_apply) + bool(scalar_arg) + bool(symbolic_cast)) != 1: + raise ValueError( + "Exactly one of 'scalar_apply', 'scalar_arg', 'symbolic_cast' must be " + "specified") + self.scalar_apply = scalar_apply + self.scalar_arg = scalar_arg + self.symbolic_cast = symbolic_cast + + def to_yaml_custom_dict(self): + if self.scalar_apply: + return dict(scalar_apply=dict( + fn_name=self.scalar_apply.fn_name, + operands=list(self.scalar_apply.operands), + )) + elif self.scalar_arg: + return dict(scalar_arg=self.scalar_arg.arg) + elif self.symbolic_cast: + # Note that even though operands must be arity 1, we write it the + # same way as for apply because it allows handling code to be more + # generic vs having a special form. + return dict(symbolic_cast=dict(type_var=self.symbolic_cast.to_type.name, + operands=[self.symbolic_cast.operand])) + else: + raise ValueError(f"Unexpected ScalarExpression type: {self}") + + +class ScalarAssign(YAMLObject): + """An assignment to a named argument (LHS of a comprehension).""" + yaml_tag = "!ScalarAssign" + + def __init__(self, arg: str, value: ScalarExpression): + self.arg = arg + self.value = value + + def to_yaml_custom_dict(self): + return dict(arg=self.arg, value=self.value) + + def __repr__(self): + return f"ScalarAssign({self.arg}, {self.value})" diff --git a/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/types.py b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/types.py new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/types.py @@ -0,0 +1,69 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Facility for symbolically referencing type variables. + +Type variables are instances of the TypeVar class, which is uniqued by name. +An "expando" accessor `TV` is provided that generates a named TypeVar for +any attribute access: + + >>> TV.T + TypeVar(T) + >>> TV.T is TV.U + False + >>> TV.T is TV.T + True +""" + +from enum import Enum +from typing import Dict + +__all__ = [ + "TypeVar", + "TV", + + # TypeVar aliases. + "T", + "U", + "V", +] + + +class TypeVar: + """A replaceable type variable. + + Type variables are uniqued by name. + """ + ALL_TYPEVARS = dict() # type: Dict[str, "TypeVar"] + + def __new__(cls, name: str): + existing = cls.ALL_TYPEVARS.get(name) + if existing is not None: + return existing + new = super().__new__(cls) + new.name = name + cls.ALL_TYPEVARS[name] = new + return new + + def __repr__(self): + return f"TypeVar({self.name})" + + @classmethod + def create_expando(cls): + """Create an expando class that creates unique type vars on attr access.""" + + class ExpandoTypeVars: + + def __getattr__(self, n): + return cls(n) + + return ExpandoTypeVars() + + +# Expando access via TV.foo +TV = TypeVar.create_expando() + +# Some common type name aliases. +T = TV.T +U = TV.U +V = TV.V diff --git a/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/yaml_helper.py b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/yaml_helper.py new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/yaml_helper.py @@ -0,0 +1,54 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""YAML serialization is routed through here to centralize common logic.""" + +import sys + +try: + import yaml +except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"This tool requires PyYAML but it was not installed.
" + f"Recommend: {sys.executable} -m pip install PyYAML") from e + +__all__ = [ + "yaml_dump", + "yaml_dump_all", + "YAMLObject", +] + + +class YAMLObject(yaml.YAMLObject): + + @classmethod + def to_yaml(cls, dumper, self): + """Default to a custom dictionary mapping.""" + return dumper.represent_mapping(cls.yaml_tag, self.to_yaml_custom_dict()) + + def to_yaml_custom_dict(self): + raise NotImplementedError() + + def as_linalg_yaml(self): + return yaml_dump(self) + + +def multiline_str_representer(dumper, data): + if len(data.splitlines()) > 1: + return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|') + else: + return dumper.represent_scalar('tag:yaml.org,2002:str', data) + + +yaml.add_representer(str, multiline_str_representer) + + +def yaml_dump(data, sort_keys=False, **kwargs): + return yaml.dump(data, sort_keys=sort_keys, **kwargs) + + +def yaml_dump_all(data, sort_keys=False, explicit_start=True, **kwargs): + return yaml.dump_all(data, + sort_keys=sort_keys, + explicit_start=explicit_start, + **kwargs) diff --git a/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/ops/__init__.py b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/ops/__init__.py new file mode 100644 diff --git a/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/ops/core_named_ops.py b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/ops/core_named_ops.py new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/ops/core_named_ops.py @@ -0,0 +1,70 @@ +from ..lang import * + +T1 = TV.T1 +T2 = TV.T2 + +Batch = S.Batch + + +@linalg_structured_op +def matmul(A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True)): + """Performs a matrix multiplication of two 2D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output.
+ """ + implements(ContractionOpInterface) + C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) + + +@linalg_structured_op +def batch_matmul(A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.K, S.N), + C=TensorDef(U, Batch, S.M, S.N, output=True)): + """Performs a batched matrix multiplication of two 3D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ContractionOpInterface) + C[D.b, D.m, D.n] += cast(U, A[D.b, D.m, D.k]) * cast(U, B[D.b, D.k, D.n]) + + +@linalg_structured_op +def matvec(A=TensorDef(T1, S.M, S.N), + y=TensorDef(T2, S.N), + x=TensorDef(U, S.M, output=True)): + """Performs a matrix-vector multiplication. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ContractionOpInterface) + x[D.m] += cast(U, A[D.m, D.n]) * cast(U, y[D.n]) + + +@linalg_structured_op +def vecmat(y=TensorDef(T1, S.M), + A=TensorDef(T2, S.M, S.N), + x=TensorDef(U, S.N, output=True)): + """Performs a vector-matrix multiplication. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ContractionOpInterface) + x[D.n] += cast(U, y[D.m]) * cast(U, A[D.m, D.n]) + + +@linalg_structured_op +def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, + output=True)): + """Performs a dot product of two vectors to a scalar result. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output.
+ """ + implements(ContractionOpInterface) + C[None] += cast(U, A[D.m]) * cast(U, B[D.m]) diff --git a/mlir/test/Bindings/Python/tools/linalg_opdsl/assignments.py b/mlir/test/Bindings/Python/tools/linalg_opdsl/assignments.py new file mode 100644 --- /dev/null +++ b/mlir/test/Bindings/Python/tools/linalg_opdsl/assignments.py @@ -0,0 +1,29 @@ +# RUN: %PYTHON -m mlir.tools.linalg_opdsl.dump_oplib --file %s | FileCheck %s + +from mlir.tools.linalg_opdsl.lang import * + +# CHECK: --- +# CHECK-LABEL: matmul +# CHECK: assignments: +# CHECK: - +# CHECK: arg: C +# CHECK: value: +# CHECK: scalar_apply: +# CHECK: fn_name: add +# CHECK: operands: +# CHECK: scalar_apply: +# CHECK: fn_name: mul +# CHECK: operands: +# CHECK: symbolic_cast: +# CHECK: type_var: U +# CHECK: operands: +# CHECK: scalar_arg: A +# CHECK: symbolic_cast: +# CHECK: type_var: U +# CHECK: operands: +# CHECK: scalar_arg: B +@linalg_structured_op +def matmul(A=TensorDef(T, S.M, S.K), + B=TensorDef(T, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True)): + C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) diff --git a/mlir/test/Bindings/Python/tools/linalg_opdsl/doctests.py b/mlir/test/Bindings/Python/tools/linalg_opdsl/doctests.py new file mode 100644 --- /dev/null +++ b/mlir/test/Bindings/Python/tools/linalg_opdsl/doctests.py @@ -0,0 +1,13 @@ +# RUN: %PYTHON %s + +import doctest +import importlib + +def test_module(module_name): + print(f"--- Testing module: {module_name}") + m = importlib.import_module(module_name) + doctest.testmod(m, verbose=True, raise_on_error=True, report=True) + + +test_module("mlir.tools.linalg_opdsl.lang.affine") +test_module("mlir.tools.linalg_opdsl.lang.types") diff --git a/mlir/test/Bindings/Python/tools/linalg_opdsl/interfaces.py b/mlir/test/Bindings/Python/tools/linalg_opdsl/interfaces.py new file mode 100644 --- /dev/null +++ b/mlir/test/Bindings/Python/tools/linalg_opdsl/interfaces.py @@ -0,0 +1,14 @@ +# RUN: %PYTHON -m mlir.tools.linalg_opdsl.dump_oplib --file %s | 
FileCheck %s + +from mlir.tools.linalg_opdsl.lang import * + +# CHECK: --- +# CHECK-LABEL: matmul +# CHECK: implements: +# CHECK-NEXT: - LinalgContractionOpInterface +@linalg_structured_op +def matmul(A=TensorDef(T, S.M, S.K), + B=TensorDef(T, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True)): + implements(ContractionOpInterface) + C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) diff --git a/mlir/test/Bindings/Python/tools/linalg_opdsl/lit.local.cfg b/mlir/test/Bindings/Python/tools/linalg_opdsl/lit.local.cfg new file mode 100644 --- /dev/null +++ b/mlir/test/Bindings/Python/tools/linalg_opdsl/lit.local.cfg @@ -0,0 +1,9 @@ +# TODO: This tool requires PyYAML, which is not yet a required build/test +# dependency. Remove this exclusion once it is a required dep. + +# Since both lit and the python bindings use the same python interpreter, +# we can just check whether yaml can be imported here and exclude if not. +try: + import yaml +except ModuleNotFoundError: + config.unsupported = True diff --git a/mlir/test/Bindings/Python/tools/linalg_opdsl/shape_maps_iteration.py b/mlir/test/Bindings/Python/tools/linalg_opdsl/shape_maps_iteration.py new file mode 100644 --- /dev/null +++ b/mlir/test/Bindings/Python/tools/linalg_opdsl/shape_maps_iteration.py @@ -0,0 +1,43 @@ +# RUN: %PYTHON -m mlir.tools.linalg_opdsl.dump_oplib --file %s | FileCheck %s + +from mlir.tools.linalg_opdsl.lang import * + + +# Verify that simple case with iteration order defined lexically and reduction +# dims auto discovered emits the right shape, indexing maps and iterator types. 
+# CHECK: --- +# CHECK-LABEL: matmul +# CHECK: shape: affine_map<()[s0, s1, s2] -> (s0, s2)> +# CHECK: shape: affine_map<()[s0, s1, s2] -> (s2, s1)> +# CHECK: shape: affine_map<()[s0, s1, s2] -> (s0, s1)> +# CHECK: static_indexing_maps: +# CHECK-NEXT: - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)> +# CHECK-NEXT: - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)> +# CHECK-NEXT: - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)> +# CHECK: iterator_types: +# CHECK-NEXT: - parallel +# CHECK-NEXT: - parallel +# CHECK-NEXT: - reduction +@linalg_structured_op +def matmul(A=TensorDef(T, S.M, S.K), + B=TensorDef(T, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True)): + C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) + + +# Verifies that assignment to a scalar (represented as [None]) is represented +# correctly. +# CHECK: --- +# CHECK-LABEL: dot +# CHECK: shape: affine_map<()[s0] -> (s0)> +# CHECK: shape: affine_map<()[s0] -> (s0)> +# CHECK: shape: affine_map<()[s0] -> ()> +# CHECK: static_indexing_maps: +# CHECK-NEXT: - affine_map<(d0)[s0] -> (d0)> +# CHECK-NEXT: - affine_map<(d0)[s0] -> (d0)> +# CHECK-NEXT: - affine_map<(d0)[s0] -> ()> +# CHECK: iterator_types: +# CHECK-NEXT: - reduction +@linalg_structured_op +def dot(A=TensorDef(T, S.M), B=TensorDef(T, S.M), C=TensorDef(U, output=True)): + C[None] += cast(U, A[D.m]) * cast(U, B[D.m]) diff --git a/mlir/test/Bindings/Python/tools/linalg_opdsl/test_core_named_ops.py b/mlir/test/Bindings/Python/tools/linalg_opdsl/test_core_named_ops.py new file mode 100644 --- /dev/null +++ b/mlir/test/Bindings/Python/tools/linalg_opdsl/test_core_named_ops.py @@ -0,0 +1,4 @@ +# RUN: %PYTHON -m mlir.tools.linalg_opdsl.dump_oplib .ops.core_named_ops | FileCheck %s + +# Just verify that at least one known op is generated. +# CHECK: name: matmul