diff --git a/mlir/docs/Dialects/Linalg.md b/mlir/docs/Dialects/Linalg.md
--- a/mlir/docs/Dialects/Linalg.md
+++ b/mlir/docs/Dialects/Linalg.md
@@ -675,6 +675,118 @@
Most of the above documentation roughly applies to this path and will be ported
as migration continues.
+### Op Specification via Python-based Op DSL
+
+The tool `linalg_opdsl` provides a high-level DSL for constructing
+structured op definitions in a way that can be exported to built-in, named
+structured ops via the above YAML-based definitions or used interactively to
+emit corresponding `linalg.generic` IR for the composition.
+
+#### Language Guide
+
+This tool is new and rapidly evolving. For language examples, refer to the
+built-in ops in the `mlir.tools.linalg_opdsl.ops` package
+(`lib/Bindings/Python/mlir/tools/linalg_opdsl/ops` in the repository).
+
+Using a matmul as an example, we will walk through the language:
+
+```python
+from mlir.tools.linalg_opdsl.lang import *
+
+T1 = TV.T1
+T2 = TV.T2
+U = TV.U
+
+@linalg_structured_op
+def matmul(A=TensorDef(T1, S.M, S.K),
+           B=TensorDef(T2, S.K, S.N),
+           C=TensorDef(U, S.M, S.N, output=True)):
+  """Performs a matrix multiplication of two 2D inputs.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output.
+  """
+  implements(ContractionOpInterface)
+  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
+```
+
+Here we have a simple, type-polymorphic contraction that takes arguments `A`
+and `B` and outputs `C`. Each is bound to a `TensorDef`, which specifies:
+
+* The symbolic element type (`T1`, `T2`, `U` above).
+* Symbolic shape expressions with symbols that are bound globally for the op
+  (note that in this simple example, the shape expressions are just symbol
+  references, but they are permitted to be a constrained set of affine
+  expressions).
+* Usage (`output=True`).
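+
+Because shape symbols are bound once per op, every operand dimension that
+names the same symbol (for example `S.K` in both `A` and `B`) must have the
+same size. The plain-Python sketch below models that consistency check; the
+helper `bind_shape_symbols` is hypothetical, is not part of `linalg_opdsl`, and
+only handles bare symbol references rather than general affine shape
+expressions.
+
+```python
+from typing import Dict, Sequence, Tuple
+
+
+def bind_shape_symbols(signatures: Sequence[Tuple[str, ...]],
+                       shapes: Sequence[Tuple[int, ...]]) -> Dict[str, int]:
+  """Binds each shape symbol to a size, requiring all uses to agree."""
+  bindings: Dict[str, int] = {}
+  for symbols, shape in zip(signatures, shapes):
+    if len(symbols) != len(shape):
+      raise ValueError(f"rank mismatch: {symbols} vs {shape}")
+    for symbol, size in zip(symbols, shape):
+      # The first operand that uses a symbol fixes its size.
+      bound = bindings.setdefault(symbol, size)
+      if bound != size:
+        raise ValueError(f"{symbol} is bound to {bound} but got {size}")
+  return bindings
+
+
+# matmul: A is (M, K), B is (K, N), C is (M, N).
+print(bind_shape_symbols([("M", "K"), ("K", "N"), ("M", "N")],
+                         [(4, 3), (3, 5), (4, 5)]))
+# -> {'M': 4, 'K': 3, 'N': 5}
+```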
+
+The docstring will be transferred to the op definition verbatim.
+
+Special identifying op interfaces can be declared for the op via
+`implements(interface1[, interface2...])`.
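+
+To make the semantics of the `matmul` comprehension concrete, the following
+plain-Python sketch spells out the loop nest it describes for nested-list
+operands. It is an illustration only, not code produced by `linalg_opdsl`;
+`reference_matmul` is a hypothetical name, and `float` stands in for the
+accumulator type `U`.
+
+```python
+from typing import Callable, List, Sequence
+
+
+def reference_matmul(a: Sequence[Sequence[int]],
+                     b: Sequence[Sequence[int]],
+                     c: List[List[float]],
+                     to_u: Callable[[int], float] = float) -> List[List[float]]:
+  """Accumulates a @ b into c in place, casting each operand with `to_u`."""
+  for m in range(len(a)):  # D.m: parallel, appears on the LHS.
+    for n in range(len(b[0])):  # D.n: parallel, appears on the LHS.
+      for k in range(len(b)):  # D.k: reduction, appears only on the RHS.
+        # C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
+        c[m][n] += to_u(a[m][k]) * to_u(b[k][n])
+  return c
+
+
+zeros = [[0.0, 0.0], [0.0, 0.0]]
+print(reference_matmul([[1, 2], [3, 4]], [[5, 6], [7, 8]], zeros))
+# -> [[19.0, 22.0], [43.0, 50.0]]
+```
+
+Note that the initial contents of `c` take part in the sum, mirroring how a
+reduction accumulates into its output.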
+
+##### Assignments
+
+The bulk of the language consists of assignment expressions of the form above.
+The iteration dimension order is determined lexically based on the order
+encountered in the expression (following operator precedence if math operators
+are used). TODO: Introduce a directive to fix the dimension bindings.
+
+Reduction dimensions are inferred to be any dimensions on the RHS that are not
+on the LHS.
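+
+These two rules can be modelled in a few lines of plain Python. This is a
+sketch rather than the tool's implementation: `infer_dims` is hypothetical, and
+it assumes that the LHS indices are visited before the RHS indices when
+assigning the lexical order.
+
+```python
+from typing import List, Sequence, Tuple
+
+
+def infer_dims(lhs: Sequence[str],
+               rhs: Sequence[str]) -> Tuple[List[str], List[str]]:
+  """Returns (iteration_dims, reduction_dims) for a single assignment."""
+  order: List[str] = []
+  for dim in list(lhs) + list(rhs):  # First occurrence fixes the position.
+    if dim not in order:
+      order.append(dim)
+  lhs_dims = set(lhs)
+  # Reduction dims: referenced on the RHS but never on the LHS.
+  reductions = [dim for dim in order if dim not in lhs_dims]
+  return order, reductions
+
+
+# C[D.m, D.n] += A[D.m, D.k] * B[D.k, D.n]
+print(infer_dims(["m", "n"], ["m", "k", "k", "n"]))
+# -> (['m', 'n', 'k'], ['k'])
+```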
+
+A number of arithmetic primitive functions are supported:
+
+* `PrimFn.add(a, b)` (also via overloading the binary `+` operator)
+* `PrimFn.exp(a)`
+* `PrimFn.log(a)`
+* `PrimFn.mul(a, b)` (also via overloading the binary `*` operator)
+* `PrimFn.max(a, b)`
+* `PrimFn.sub(a, b)` (also via overloading the binary `-` operator)
+
+Reduction functions can appear as the outermost function on the RHS:
+
+* `ReduceFn.add` (also via overloading the in-place `+=` operator on the LHS)
+* `ReduceFn.mul`
+* `ReduceFn.max`
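+
+A reduction combines the current LHS value, which supplies the initial value,
+with every RHS element along the reduction dimensions. For example, a row-wise
+maximum could be written as `O[D.m] = ReduceFn.max(D.n)(I[D.m, D.n])`. The
+plain-Python sketch below shows the corresponding computation for a single
+reduction dimension; `reduce_rows` is a hypothetical name.
+
+```python
+from typing import Callable, List, Sequence
+
+
+def reduce_rows(prim: Callable[[float, float], float],
+                inputs: Sequence[Sequence[float]],
+                outputs: List[float]) -> List[float]:
+  """Reduces row m of `inputs` into `outputs[m]` using the binary `prim`."""
+  for m, row in enumerate(inputs):  # D.m: parallel (on the LHS).
+    for value in row:  # D.n: reduction (RHS only).
+      outputs[m] = prim(outputs[m], value)
+  return outputs
+
+
+print(reduce_rows(max, [[1.0, 5.0, 3.0], [-2.0, -7.0, -1.0]],
+                  [float("-inf"), float("-inf")]))
+# -> [5.0, -1.0]
+```
+
+Starting from `[0.0, 0.0]` instead would yield `[5.0, 0.0]`, since the
+initial output values participate in the reduction.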
+
+There are also special forms:
+
+* `cast(TypeVar, operand)`
+
+##### Types
+
+All types in assignment expressions are late-bound based on the actual input
+and output types of constructed ops. Assignment expressions with no `cast`
+calls will generally require uniform types throughout and will fail to
+verify if violated. The presence of a `cast` allows for a limited form of
+numeric type conversion between element types that can be derived from inputs
+and outputs (and in the future, attributes). `cast` calls with a `TypeVar`
+first argument are emitted as `symbolic_cast` primitives in the YAML definition.
+
+Casting will perform `int<->float` type conversions and will perform any
+necessary extension or truncation within a type family. Note that presently,
+any integer type is assumed to be signed for the purpose of determining how to
+extend or truncate. Supporting unsigned integer types is left for future work.
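+
+The integer part of this rule can be sketched in plain Python as
+two's-complement wrapping with every integer treated as signed. The helper
+`cast_signed_int` is hypothetical, and the float-to-integer branch assumes
+truncation toward zero; the exact conversion performed by the emitted IR is
+not specified here.
+
+```python
+from typing import Union
+
+
+def cast_signed_int(value: Union[int, float], to_bits: int) -> int:
+  """Converts `value` to a signed integer of width `to_bits`."""
+  if isinstance(value, float):
+    value = int(value)  # Assumption: float -> int truncates toward zero.
+  wrapped = value & ((1 << to_bits) - 1)  # Keep only the low `to_bits` bits.
+  if wrapped >= 1 << (to_bits - 1):  # Top bit set: reinterpret as negative.
+    wrapped -= 1 << to_bits
+  return wrapped
+
+
+print(cast_signed_int(300, 8), cast_signed_int(200, 8),
+      cast_signed_int(-56, 32), cast_signed_int(-3.9, 8))
+# -> 44 -56 -56 -3
+```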
+
+Not all functions are applicable for all numeric types, and on mismatch, op
+verification will fail.
+
+#### Basic Usage
+
+To use from the build tree, MLIR must be built with Python bindings enabled
+(`-DMLIR_BINDINGS_PYTHON_ENABLED=ON`). Then add the `python` directory in the
+build tree to your `PYTHONPATH` environment variable (e.g.
+`export PYTHONPATH=$PWD/build/python`). Optionally, use an installed MLIR
+package, if available, to avoid building.
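+
+As a quick check that the bindings are visible on `PYTHONPATH`, Python's
+import machinery can locate the package without importing the `linalg_opdsl`
+module itself (its parent packages are imported). The helper below uses only
+the standard library; `opdsl_available` is a hypothetical name.
+
+```python
+import importlib.util
+
+
+def opdsl_available(name: str = "mlir.tools.linalg_opdsl") -> bool:
+  """Returns True if the module `name` can be found on the current sys.path."""
+  try:
+    return importlib.util.find_spec(name) is not None
+  except ImportError:
+    # find_spec imports parent packages (here `mlir`), which may be missing.
+    return False
+
+
+print("linalg_opdsl found" if opdsl_available() else "linalg_opdsl missing")
+```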
+
+```shell
+# Dump the `core_named_ops.py` module as YAML.
+python -m mlir.tools.linalg_opdsl.dump_oplib .ops.core_named_ops
+```
+
+The `linalg_opdsl` tool is meant as both a development and runtime tool, not
+a build tool: in order to export static named op definitions to be built as
+part of the compiler, the corresponding Linalg dialect YAML file must be
+updated and reviewed. TODO: Develop a script to automate updates to these files.
+
## Open Issues and Design Alternatives
Multiple open issues and design alternatives are in flight and it is time to lay
diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt
--- a/mlir/lib/Bindings/Python/CMakeLists.txt
+++ b/mlir/lib/Bindings/Python/CMakeLists.txt
@@ -5,18 +5,9 @@
# Copy python source tree.
################################################################################
-set(PY_SRC_FILES
- mlir/__init__.py
- mlir/_dlloader.py
- mlir/conversions/__init__.py
- mlir/dialects/__init__.py
- mlir/dialects/_linalg.py
- mlir/dialects/_builtin.py
- mlir/ir.py
- mlir/execution_engine.py
- mlir/passmanager.py
- mlir/transforms/__init__.py
-)
+file(GLOB_RECURSE PY_SRC_FILES
+ RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
+ "${CMAKE_CURRENT_SOURCE_DIR}/mlir/*.py")
add_custom_target(MLIRBindingsPythonSources ALL
DEPENDS ${PY_SRC_FILES}
@@ -25,11 +16,13 @@
foreach(PY_SRC_FILE ${PY_SRC_FILES})
set(PY_DEST_FILE "${PROJECT_BINARY_DIR}/python/${PY_SRC_FILE}")
+ get_filename_component(PY_DEST_DIR "${PY_DEST_FILE}" DIRECTORY)
+ file(MAKE_DIRECTORY "${PY_DEST_DIR}")
add_custom_command(
TARGET MLIRBindingsPythonSources PRE_BUILD
COMMENT "Copying python source ${PY_SRC_FILE} -> ${PY_DEST_FILE}"
DEPENDS "${PY_SRC_FILE}"
- COMMAND "${CMAKE_COMMAND}" -E copy_if_different
+ COMMAND "${CMAKE_COMMAND}" -E create_symlink
"${CMAKE_CURRENT_SOURCE_DIR}/${PY_SRC_FILE}" "${PY_DEST_FILE}"
)
endforeach()
diff --git a/mlir/lib/Bindings/Python/mlir/tools/__init__.py b/mlir/lib/Bindings/Python/mlir/tools/__init__.py
new file mode 100644
diff --git a/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/__init__.py b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/__init__.py
new file mode 100644
diff --git a/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/dump_oplib.py b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/dump_oplib.py
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/dump_oplib.py
@@ -0,0 +1,90 @@
+#!/usr/bin/env python
+# Command line tool to load an oplib module and dump all of the operations
+# it contains in some format.
+"""Loads one or more modules containing op definitions and dumps them.
+
+The dump format can be:
+
+* `--format=yaml` (default)
+* `--format=repr`
+
+Positional arguments are interpreted as module names (optionally, relative to
+this module). Loose module files can be specified via `--file`, which accepts
+one or more file paths.
+
+Sample usage:
+  # Dump the YAML op definitions for the core named ops (as in the dialect
+  # source tree).
+  python -m mlir.tools.linalg_opdsl.dump_oplib .ops.core_named_ops
+
+Note: YAML output is emitted in "document list" format with each operation
+as its own "document". Practically, this means that each operation (or group
+of composite ops) is emitted with a "---" preceding it, which can be useful
+for testing.
+"""
+
+import argparse
+import importlib
+import importlib.util
+
+from .lang import *
+from .lang.config import *
+from .lang.yaml_helper import *
+
+
+def create_arg_parser() -> argparse.ArgumentParser:
+ p = argparse.ArgumentParser(description="Dump an oplib in various formats")
+ p.add_argument("modules",
+ metavar="M",
+ type=str,
+ nargs="*",
+ help="Op module to dump")
+ p.add_argument("--file",
+ metavar="F",
+ type=str,
+ nargs="*",
+ help="Python op file to dump")
+ p.add_argument("--format",
+ type=str,
+ dest="format",
+ default="yaml",
+ choices=("yaml", "repr"),
+ help="Format in which to dump")
+ return p
+
+
+def load_module_from_file(module_name, file_path):
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ m = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(m)
+ return m
+
+
+def main(args):
+ # Load all configs.
+ configs = []
+ modules = []
+ for module_name in args.modules:
+ modules.append(
+ importlib.import_module(module_name, package="mlir.tools.linalg_opdsl"))
+ for i, file_path in enumerate(args.file or []):
+ modules.append(load_module_from_file(f"_mlir_eval_oplib{i}", file_path))
+ for m in modules:
+ for attr_name, value in m.__dict__.items():
+ # TODO: This class layering is awkward.
+ if isinstance(value, DefinedOpCallable):
+ try:
+ linalg_config = LinalgOpConfig.from_linalg_op_def(value.model)
+ except Exception as e:
+ raise ValueError(
+ f"Could not create LinalgOpConfig from {value.model}") from e
+ configs.extend(linalg_config)
+
+ # Print.
+ if args.format == "yaml":
+ print(yaml_dump_all(configs))
+ elif args.format == "repr":
+ for config in configs:
+ print(repr(config))
+
+
+if __name__ == "__main__":
+ main(create_arg_parser().parse_args())
diff --git a/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/__init__.py b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/__init__.py
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/__init__.py
@@ -0,0 +1 @@
+from .dsl import *
diff --git a/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/affine.py b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/affine.py
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/affine.py
@@ -0,0 +1,307 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+"""DSL for constructing affine expressions and maps.
+
+These python wrappers allow construction of affine expressions in a more
+pythonic fashion that is later instantiated as an IR AffineExpr. Separating the
+AST from construction of the map allows for manipulations of symbols and dims
+beyond the scope of one expression.
+
+Affine expression construction:
+ >>> with _ir.Context():
+ ... s = AffineBuildState()
+ ... (S.K + S.M).build(s)
+ ... (S.K * S.M).build(s)
+ ... (S.K // S.M).build(s)
+ ... (S.K / S.M).build(s)
+ ... (S.K % 4).build(s)
+ ... (D.i + D.j * 4).build(s)
+ ... s
+ AffineExpr(s0 + s1)
+ AffineExpr(s0 * s1)
+ AffineExpr(s0 floordiv s1)
+ AffineExpr(s0 ceildiv s1)
+ AffineExpr(s0 mod 4)
+ AffineExpr(d0 + d1 * 4)
+ AffineBuildState<
+ symbols={'K': 0, 'M': 1}
+ dims={'i': 0, 'j': 1}>
+
+In the DSL, dimensions and symbols are name-uniqued instances of DimDef and
+SymbolDef. There are shortcut "expando" instances that will create a
+corresponding DimDef/SymbolDef upon accessing an attribute:
+
+Referencing a named dimension:
+
+ >>> D.i
+ Dim(i)
+ >>> D.a is D.b
+ False
+ >>> D.a is D.a
+ True
+
+Referencing a named symbol:
+
+ >>> S.foobar
+ Symbol(foobar)
+ >>> S.a is S.b
+ False
+ >>> S.a is S.a
+ True
+"""
+
+from typing import Callable, Dict, Optional, Tuple, Union
+
+from mlir import ir as _ir
+
+__all__ = [
+ "AffineBuildState",
+ "AffineExprDef",
+ "D",
+ "DimDef",
+ "S",
+ "SymbolDef",
+]
+
+# Type aliases.
+SymbolPosMap = Dict[str, int]
+
+
+class AffineBuildState:
+ """Internal state for the AffineExprDef._create impls."""
+
+ def __init__(self,
+ *,
+ global_state: "AffineBuildState" = None,
+ allow_new_symbols: bool = True,
+ allow_new_dims: bool = True):
+ if not global_state:
+ self.all_symbols = dict() # type: Dict[str, int]
+ self.all_dims = dict() # type: Dict[str, int]
+ else:
+ # Alias the global dict.
+ self.all_symbols = global_state.all_symbols
+ self.all_dims = global_state.all_dims
+
+ # Map of symbols and dims in the current build.
+ self.local_symbols = dict() # type: Dict[str, int]
+ self.local_dims = dict() # type: Dict[str, int]
+ self.allow_new_symbols = allow_new_symbols
+ self.allow_new_dims = allow_new_dims
+
+ def get_dim(self, dimname: str) -> int:
+ """Gets the dim position given a name."""
+ pos = self.all_dims.get(dimname)
+ if pos is None:
+ if not self.allow_new_dims:
+ raise ValueError(
+ f"New dimensions not allowed in the current affine expression: "
+ f"Requested '{dimname}', Available: {self.all_dims}")
+ pos = len(self.all_dims)
+ self.all_dims[dimname] = pos
+ self.local_dims[dimname] = pos
+ return pos
+
+ def get_symbol(self, symname: str) -> int:
+ """Gets a symbol position given a name."""
+ pos = self.all_symbols.get(symname)
+ if pos is None:
+ if not self.allow_new_symbols:
+ raise ValueError(
+ f"New symbols not allowed in the current affine expression: "
+ f"Requested '{symname}', Available: {self.all_symbols}")
+ pos = len(self.all_symbols)
+ self.all_symbols[symname] = pos
+ self.local_symbols[symname] = pos
+ return pos
+
+ @property
+ def local_dim_count(self) -> int:
+ return len(self.local_dims)
+
+ @property
+ def local_symbol_count(self) -> int:
+ return len(self.local_symbols)
+
+ @property
+ def dim_count(self) -> int:
+ return len(self.all_dims)
+
+ @property
+ def symbol_count(self) -> int:
+ return len(self.all_symbols)
+
+ def __repr__(self):
+ lines = [f"AffineBuildState<"]
+ lines.append(f" symbols={self.local_symbols}")
+ lines.append(f" dims={self.local_dims}>")
+ return "\n".join(lines)
+
+
+class AffineExprDef:
+ """Base class for an affine expression being defined."""
+
+ def build(self, state: Optional[AffineBuildState] = None) -> _ir.AffineExpr:
+ """Builds the corresponding _ir.AffineExpr from the definitions.
+ """
+ state = AffineBuildState() if state is None else state
+ expr = self._create(state)
+ return expr
+
+ def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
+ raise NotImplementedError()
+
+ @staticmethod
+ def coerce_from(py_value):
+ if isinstance(py_value, int):
+ return AffineConstantExpr(py_value)
+ assert isinstance(py_value, AffineExprDef)
+ return py_value
+
+ def visit_affine_exprs(self, callback):
+ """Visits all AffineExprDefs including self."""
+ callback(self)
+
+ def __add__(lhs, rhs):
+ rhs = AffineExprDef.coerce_from(rhs)
+ return AffineBinaryExprDef(_ir.AffineAddExpr, lhs, rhs)
+
+ def __mul__(lhs, rhs):
+ rhs = AffineExprDef.coerce_from(rhs)
+ return AffineBinaryExprDef(_ir.AffineMulExpr, lhs, rhs)
+
+ def __mod__(lhs, rhs):
+ rhs = AffineExprDef.coerce_from(rhs)
+ return AffineBinaryExprDef(_ir.AffineModExpr, lhs, rhs)
+
+ def __floordiv__(lhs, rhs):
+ rhs = AffineExprDef.coerce_from(rhs)
+ return AffineBinaryExprDef(_ir.AffineFloorDivExpr, lhs, rhs)
+
+ def __truediv__(lhs, rhs):
+ # TODO: Not really a ceil div - taking liberties for the DSL.
+ rhs = AffineExprDef.coerce_from(rhs)
+ return AffineBinaryExprDef(_ir.AffineCeilDivExpr, lhs, rhs)
+
+
+class AffineConstantExpr(AffineExprDef):
+ """An affine constant being defined."""
+
+ def __init__(self, value: int):
+ assert isinstance(value, int)
+ self.value = value
+
+ def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
+ return _ir.AffineConstantExpr.get(self.value)
+
+ def __repr__(self):
+ return f"Const({self.value})"
+
+
+class AffineBinaryExprDef(AffineExprDef):
+ """An affine binary expression being defined."""
+
+ def __init__(self, ir_ctor, lhs: AffineExprDef, rhs: AffineExprDef):
+ self.ir_ctor = ir_ctor
+ self.lhs = lhs
+ self.rhs = rhs
+
+ def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
+ return self.ir_ctor.get(self.lhs._create(state), self.rhs._create(state))
+
+ def visit_affine_exprs(self, callback):
+ """Visits all AffineExprDefs including self."""
+ super().visit_affine_exprs(callback)
+ self.lhs.visit_affine_exprs(callback)
+ self.rhs.visit_affine_exprs(callback)
+
+ def __repr__(self):
+ return f"{self.ir_ctor.__name__}({repr(self.lhs)}, {repr(self.rhs)})"
+
+
+class DimDef(AffineExprDef):
+ """Represents a named dimension."""
+ ALL_DIMS = dict() # type: Dict[str, "DimDef"]
+ dimname: str
+
+ def __new__(cls, dimname: str):
+ existing = cls.ALL_DIMS.get(dimname)
+ if existing is not None:
+ return existing
+ new = super().__new__(cls)
+ new.dimname = dimname
+ cls.ALL_DIMS[dimname] = new
+ return new
+
+ def __repr__(self):
+ return f"Dim({self.dimname})"
+
+ def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
+ pos = state.get_dim(self.dimname)
+ return _ir.AffineDimExpr.get(position=pos)
+
+ @classmethod
+ def create_expando(cls):
+ """Create an expando class that creates unique dims based on attr access.
+ """
+
+ class ExpandoDims:
+
+ def __getattr__(self, n):
+ return cls(n)
+
+ return ExpandoDims()
+
+
+D = DimDef.create_expando()
+
+
+class SymbolDef(AffineExprDef):
+ """Represents a named symbol.
+
+ >>> s1 = SymbolDef("s1")
+ >>> s1
+ Symbol(s1)
+ >>> s2 = SymbolDef("s2")
+ >>> s1 is s2
+ False
+ >>> s1 is SymbolDef("s1")
+ True
+ """
+ ALL_SYMBOLS = dict() # type: Dict[str, "SymbolDef"]
+ symname: str
+
+ def __new__(cls, symname: str):
+ existing = cls.ALL_SYMBOLS.get(symname)
+ if existing is not None:
+ return existing
+ new = super().__new__(cls)
+ new.symname = symname
+ cls.ALL_SYMBOLS[symname] = new
+ return new
+
+ def __repr__(self):
+ return f"Symbol({self.symname})"
+
+ def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
+ pos = state.get_symbol(self.symname)
+ return _ir.AffineSymbolExpr.get(position=pos)
+
+ @classmethod
+ def create_expando(cls):
+ """Create an expando class that creates unique symbols based on attr access.
+ """
+
+ class ExpandoSymbols:
+
+ def __getattr__(self, n):
+ return cls(n)
+
+ return ExpandoSymbols()
+
+
+# Global accessor for on-demand symbols.
+S = SymbolDef.create_expando()
diff --git a/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/comprehension.py b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/comprehension.py
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/comprehension.py
@@ -0,0 +1,425 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+"""Model classes representing a tensor-comprehension.
+
+These classes model the language more at an AST level as evaluated. Reasoning
+about it typically involves processing this form into config objects that
+represent actual op definitions (i.e. YAML).
+"""
+
+from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
+
+from mlir import ir as _ir
+
+from .affine import *
+from .scalar_expr import *
+from .types import *
+from .yaml_helper import *
+
+# Type aliases.
+AffineDimList = Dict[str, _ir.AffineExpr]
+
+
+class TensorExpression:
+ """An expression that can appear on the RHS of a comprehension."""
+
+ def to_scalar_expression(self) -> ScalarExpression:
+ raise NotImplementedError()
+
+ def visit_affine_exprs(self, callback):
+ """Visits all affine expressions reachable by the expression."""
+ pass
+
+ def _get_all_dim_defs(self) -> Set[DimDef]:
+ """Recursively gets all DimDef affine expressions that are referenced."""
+ results = set()
+
+ def visitor(affine_expr):
+ if isinstance(affine_expr, DimDef):
+ results.add(affine_expr)
+
+ self.visit_affine_exprs(visitor)
+ return results
+
+ def collect_uses(self, uses: Set["TensorUse"]):
+ """Collects all TensorUses reachable through this expression."""
+ pass
+
+ def __add__(self, rhs: "TensorExpression") -> "TensorExpression":
+ return PrimFn.add(self, rhs)
+
+ def __mul__(self, rhs) -> "TensorExpression":
+ return PrimFn.mul(self, rhs)
+
+ def __sub__(self, rhs) -> "TensorExpression":
+ return PrimFn.sub(self, rhs)
+
+ def __hash__(self):
+ return hash(id(self))
+
+
+class TensorUse(TensorExpression):
+ """A used tensor represented by its (tensor_name, indices).
+
+ Note that forming a comprehension via direct assignment is performed through
+ __setitem__ on the TensorDef level. However, performing a reduction with
+ compound ops (+=, *=, etc) is done by doing a:
+ TensorDef.__getitem__
+ TensorUse.__iadd__
+ TensorDef.__setitem__
+ """
+
+ def __init__(self, tensor_def: "TensorDef", indices: Sequence[AffineExprDef]):
+ self.tensor_def = tensor_def
+ self.indices = tuple(indices)
+
+ def to_scalar_expression(self) -> ScalarExpression:
+ assert self.tensor_def.tensor_name is not None
+ return ScalarArg(self.tensor_def.tensor_name).expr()
+
+ @property
+ def tensor_name(self) -> str:
+ n = self.tensor_def.tensor_name
+ assert n is not None, "TensorDef not attached"
+ return n
+
+ def visit_affine_exprs(self, callback):
+ for ind in self.indices:
+ ind.visit_affine_exprs(callback)
+
+ def collect_uses(self, uses: Set["TensorUse"]):
+ uses.add(self)
+
+ def __iadd__(self, rhs: TensorExpression) -> TensorExpression:
+ return ReduceFn.add(*self._compute_reduce_dims(rhs))(rhs)
+
+ def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]:
+ """For implicit reductions, computes default reduction dims.
+
+ Assumes that the rhs is the expression being reduced and self is being
+ reduced into. Any indices referenced on the rhs and not in self are
+ considered reduction dims and will be ordered as encountered on the rhs.
+ """
+ rhs_dims = rhs._get_all_dim_defs()
+ lhs_dims = self._get_all_dim_defs()
+ return rhs_dims - lhs_dims
+
+ def __repr__(self):
+ return f"{self.tensor_name}[{', '.join([repr(i) for i in self.indices])}]"
+
+
+class TensorDef:
+ """Bookkeeping of a single registered tensor, held in dict by name."""
+
+ def __init__(self,
+ type_var: TypeVar,
+ *shape: AffineExprDef,
+ indexing_map: Optional[_ir.AffineMap] = None,
+ output: bool = False):
+ if not isinstance(type_var, TypeVar):
+ raise ValueError(f"TensorDef requires a TypeVar. Got: {repr(type_var)}")
+ self.owner = None # type: Optional["LinalgOpDef"]
+ self.type_var = type_var
+ self.shape = shape
+ self.indexing_map = indexing_map
+ self.output = output
+ self.tensor_name = None # type: Optional[str]
+ self.registered_index = -1 # type: int
+
+ @property
+ def rank(self) -> int:
+ """The rank of the tensor."""
+ return len(self.shape)
+
+ def attach(self, index: int, tensor_name: str, owner: "LinalgOpDef"):
+ if self.owner:
+ raise ValueError(f"TensorDef already registered with op: {self}")
+ self.registered_index = index
+ self.tensor_name = tensor_name
+ self.owner = owner
+
+ def __getitem__(self, dims) -> TensorUse:
+ assert self.owner, "TensorDef is not attached to an op"
+ state = AffineBuildState(global_state=self.owner._affine_state,
+ allow_new_symbols=False)
+ if not isinstance(dims, tuple):
+ dims = (dims,) # Handle single subscript case.
+ # Special case: (None) is a 0d-scalar use.
+ if dims == (None,):
+ dims = ()
+
+ exprs = []
+ for expr_def in dims:
+ if not isinstance(expr_def, AffineExprDef):
+ raise KeyError(
+ "A TensorDef can only be subscripted by a tuple of affine dims")
+ exprs.append(expr_def)
+ return TensorUse(self, exprs)
+
+ def __setitem__(self, dims, value):
+ """Creates a new 1:1 comprehension by binding this tensor to an expression.
+
+ Note that due to the way assignment works in Python, we have to capture
+ direct assignment as a setitem on the TensorDef.
+ """
+ if not isinstance(value, TensorExpression):
+ raise ValueError(f"Only TensorExpressions can be assigned to TensorDefs. "
+ f"Got: {repr(value)}")
+ use = self[dims]
+ comp = Comprehension((use, value))
+ self.owner.comprehensions.append(comp)
+
+ def __hash__(self):
+ return hash(id(self))
+
+ def __repr__(self):
+ output = "OUTPUT " if self.output else ""
+ return (f"{self.tensor_name}:TensorDef({output}{repr(self.type_var)}, "
+ f"shape={self.shape})")
+
+
+class Comprehension:
+ """Represents a single comprehension."""
+
+ def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]):
+ self.definitions = list() # type: List[TensorUse]
+ self.values = list() # type: List[TensorExpression]
+
+ # Bind each reduction's lhs to the tensor use that it assigns into.
+ for assign, value in bindings:
+ if isinstance(value, ReduceApply):
+ if value.lhs:
+ raise ValueError(f"Reduction expression already assigns: {value}")
+ value.lhs = assign
+ self.definitions.append(assign)
+ self.values.append(value)
+
+ @property
+ def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]:
+ """Gets the set of reduction dim tuples (() for non-reduction values)."""
+ result = set()
+ for use in self.values:
+ if isinstance(use, ReduceApply):
+ result.add(use.reduce.reduce_dims)
+ else:
+ result.add(tuple())
+ return result
+
+ def __repr__(self):
+ if len(self.definitions) > 1:
+ defs_repr = f"({', '.join(repr(d) for d in self.definitions)})"
+ values_repr = f"({', '.join(repr(v) for v in self.values)})"
+ else:
+ defs_repr = f"{repr(self.definitions[0])}"
+ values_repr = f"{repr(self.values[0])}"
+
+ return f"{defs_repr} = {values_repr}"
+
+
+class PrimFnType:
+ """Primitive operations."""
+
+ def __init__(self, prim_name: str):
+ self.prim_name = prim_name
+
+ def __call__(self, *args):
+ return PrimApply(self, args)
+
+ def reduce(self, *reduce_dims: DimDef):
+ """Shortcut to create a Reduce operation from this primitive."""
+ return ReduceFnType(self, *reduce_dims)
+
+ def __repr__(self):
+ return f"{self.prim_name}"
+
+
+class PrimFn:
+ add = PrimFnType("add")
+ exp = PrimFnType("exp")
+ log = PrimFnType("log")
+ mul = PrimFnType("mul")
+ max = PrimFnType("max")
+ sub = PrimFnType("sub")
+
+
+class ReduceFnType:
+ """A reduction operator that reduces into its LHS from its RHS."""
+
+ def __init__(self, operator: PrimFnType, *reduce_dims: DimDef):
+ """Initializes the ReduceFn with a primitive function and dims."""
+ if not isinstance(operator, PrimFnType):
+ raise ValueError(f"Reduce expected a Prim operator. Got: {operator}")
+ self.operator = operator
+ self.reduce_dims = tuple(reduce_dims)
+
+ def __call__(self, *args: TensorExpression):
+ return ReduceApply(self, args)
+
+ def __repr__(self):
+ return (f"reduce_{self.operator.prim_name}"
+ f"({', '.join(repr(d) for d in self.reduce_dims)})")
+
+
+class ReduceFn:
+ add = PrimFn.add.reduce
+ mul = PrimFn.mul.reduce
+ max = PrimFn.max.reduce
+
+
+class PrimApply(TensorExpression):
+ """Application of a primitive."""
+
+ def __init__(self, prim: PrimFnType, args: Sequence[TensorExpression]):
+ self.prim = prim
+ self.args = tuple(args)
+
+ def to_scalar_expression(self) -> ScalarExpression:
+ return ScalarApplyFn(self.prim.prim_name,
+ *[arg.to_scalar_expression() for arg in self.args
+ ]).expr()
+
+ def visit_affine_exprs(self, callback):
+ for arg in self.args:
+ arg.visit_affine_exprs(callback)
+
+ def collect_uses(self, uses: Set["TensorUse"]):
+ for arg in self.args:
+ arg.collect_uses(uses)
+
+ def __repr__(self):
+ return f"{repr(self.prim)}({', '.join(repr(a) for a in self.args)})"
+
+
+class cast(TensorExpression):
+ """Casts the element type to a type (typically symbolic TypeVar)."""
+
+ def __init__(self, to_type: TypeVar, operand: TensorExpression):
+ self.to_type = to_type
+ self.operand = operand
+
+ def to_scalar_expression(self) -> ScalarExpression:
+ return ScalarSymbolicCast(self.to_type,
+ self.operand.to_scalar_expression()).expr()
+
+ def visit_affine_exprs(self, callback):
+ self.operand.visit_affine_exprs(callback)
+
+ def collect_uses(self, uses: Set["TensorUse"]):
+ self.operand.collect_uses(uses)
+
+ def __repr__(self):
+ return f"cast({self.to_type}, {repr(self.operand)})"
+
+
+class ReduceApply(TensorExpression):
+ """Application of a reduction.
+
+ This captures the lhs (initial value) separately from the rhs.
+ """
+
+ def __init__(self, reduce: ReduceFnType, args: Sequence[TensorExpression]):
+ self.reduce = reduce
+ self.lhs = None # type: Optional[TensorUse]
+ self.args = tuple(args)
+
+ def to_scalar_expression(self) -> ScalarExpression:
+ if self.lhs is None:
+ raise ValueError(f"Cannot scalarize a ReduceApply that has not been "
+ f"bound to its lhs: {self}")
+ full_args = [self.lhs.to_scalar_expression()
+ ] + [arg.to_scalar_expression() for arg in self.args]
+ return ScalarApplyFn(self.reduce.operator.prim_name, *full_args).expr()
+
+ def visit_affine_exprs(self, callback):
+ for ind in self.reduce.reduce_dims:
+ ind.visit_affine_exprs(callback)
+ for arg in self.args:
+ arg.visit_affine_exprs(callback)
+
+ def collect_uses(self, uses: Set["TensorUse"]):
+ for arg in self.args:
+ arg.collect_uses(uses)
+
+ def __repr__(self):
+ return f"{repr(self.reduce)}({', '.join(repr(a) for a in self.args)})"
+
+
+class OpInterfaceDef:
+ """An interface that an op implements."""
+
+ def __init__(self, cpp_name: str):
+ self.cpp_name = cpp_name
+
+
+ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface")
+
+
+class OpMetadataDef(YAMLObject):
+ """Metadata about the op (generally not behavior impacting)."""
+ yaml_tag = "!LinalgOpMetadata"
+
+ def __init__(self, name: str, cpp_op_name: Optional[str], doc: Optional[str]):
+ self.name = name
+ self.cpp_op_name = cpp_op_name if cpp_op_name is not None else name
+ self.doc = doc
+ self.implements = [] # type: List[OpInterfaceDef]
+
+ def to_yaml_custom_dict(self):
+ d = dict(
+ name=self.name,
+ cpp_op_name=self.cpp_op_name,
+ doc=self.doc,
+ )
+ if self.implements:
+ d["implements"] = [intr.cpp_name for intr in self.implements]
+ return d
+
+
+class LinalgOpDef:
+ """Definition of a linalg op."""
+
+ def __init__(self,
+ name: str,
+ cpp_op_name: Optional[str] = None,
+ doc: Optional[str] = None):
+ self.metadata = OpMetadataDef(name=name, cpp_op_name=cpp_op_name, doc=doc)
+ self.registered_tensors = dict() # type: Dict[str, TensorDef]
+ self.comprehensions = list() # type: List[Comprehension]
+ self._affine_state = AffineBuildState()
+
+ @property
+ def inputs(self) -> Sequence[TensorDef]:
+ return [t for t in self.registered_tensors.values() if not t.output]
+
+ @property
+ def outputs(self) -> Sequence[TensorDef]:
+ return [t for t in self.registered_tensors.values() if t.output]
+
+ def add_tensor(self, tensor_name: str, tensor: TensorDef):
+ """Registers a tensor."""
+ if tensor_name in self.registered_tensors:
+ raise ValueError(f"Tensor {tensor_name} is already registered "
+ f"to {self.registered_tensors[tensor_name]}")
+ tensor.attach(len(self.registered_tensors), tensor_name, self)
+ self.registered_tensors[tensor_name] = tensor
+
+ def tensor(self, name):
+ """Gets a registered tensor by name."""
+ try:
+ return self.registered_tensors[name]
+ except KeyError:
+ raise KeyError(f"Tensor {name} is not registered")
+
+ def __repr__(self):
+ lines = [
+ f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_op_name},"
+ ]
+ for name, tensor in self.registered_tensors.items():
+ lines.append(f" {tensor}")
+ if self.comprehensions:
+ lines[-1] += " {"
+ for comprehension in self.comprehensions:
+ lines.append(f" {comprehension}")
+ lines.append("}")
+ return "\n".join(lines)
diff --git a/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/config.py b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/config.py
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/config.py
@@ -0,0 +1,321 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+"""Represents configured ops as emitted for code generation.
+
+Classes in this module generally are directly serializable to YAML for use
+by the code generator.
+
+TODO: These should just be dumb containers or serialization code but they
+currently encode too many details of how the language is interpreted. Move this
+to helpers on the comprehension objects themselves.
+"""
+
+from typing import Any, Dict, Optional
+
+from mlir import ir as _ir
+
+from .comprehension import *
+from .yaml_helper import *
+
+__all__ = [
+ "LinalgStructuredOpConfig",
+ "LinalgOpConfig",
+]
+
+
+def _serialize_affine_map(affine_map: _ir.AffineMap) -> str:
+ with affine_map.context:
+ # Affine map printing/parsing is via an AffineMap attr.
+ attr = _ir.AffineMapAttr.get(affine_map)
+ return str(attr)
+
+
+class TensorUseConfig:
+ """Wrapper around a TensorUse with additional context-bound state."""
+
+ def __init__(self, tensor_use: TensorUse, indexing_map: _ir.AffineMap):
+ self.tensor_use = tensor_use
+ self.indexing_map = indexing_map
+
+ def __repr__(self):
+ return f"Use({self.tensor_use}, indexing_map={self.indexing_map})"
+
+
+class TensorDefConfig(YAMLObject):
+ """Wrapper around a TensorDef with additional context-bound state."""
+ yaml_tag = "LinalgTensorDef"
+
+ def __init__(self, tensor_def: TensorDef, shape_map: _ir.AffineMap):
+ self.tensor_def = tensor_def
+ self.shape_map = shape_map
+ self.indexing_map = None # type: Optional[_ir.AffineMap]
+
+ def to_yaml_custom_dict(self):
+
+ def get_usage():
+ if self.tensor_def.output:
+ return "output"
+ else:
+ return "input"
+
+ return dict(
+ name=self.tensor_def.tensor_name,
+ usage=get_usage(),
+ shape=_serialize_affine_map(self.shape_map),
+ element_type_var=self.tensor_def.type_var.name,
+ )
+
+ def __repr__(self):
+ return f"Def({self.tensor_def}, shape_map={self.shape_map}, indexing_map={self.indexing_map})"
+
+
+class LinalgIndexingMapsConfig(YAMLObject):
+ """Abstracts the style of indexing maps that the op exports.
+
+ Presently only static (tied to the op name) indexing maps are supported. In
+ the future, it is expected that we will have additional variants:
+ - Dynamic based on attributes
+ - Dynamic based on operands
+ Each is expected to require a different variant of specification.
+ """
+ yaml_tag = "!LinalgIndexingMapsConfig"
+
+ def __init__(self,
+ static_indexing_maps: Optional[Sequence[_ir.AffineMap]] = None):
+ self.static_indexing_maps = static_indexing_maps
+
+ def to_yaml_custom_dict(self):
+ if self.static_indexing_maps is not None:
+ return dict(static_indexing_maps=[
+ _serialize_affine_map(m) for m in self.static_indexing_maps
+ ])
+ raise ValueError(
+ "LinalgIndexingMapsConfig must have one type of indexing map "
+ "(got none)")
+
+
+class LinalgStructuredOpConfig(YAMLObject):
+ """Configuration for metadata sufficient to construct a linalg single
+ contraction named op."""
+
+ yaml_tag = "!LinalgStructuredOpConfig"
+
+ def __init__(self,
+ comprehension: Comprehension,
+ context: Optional[_ir.Context] = None):
+ self.context = context if context is not None else _ir.Context()
+ self.affine_state = AffineBuildState()
+ self.writes = list() # type: List[Tuple[TensorUse, TensorExpression]]
+ self.tensor_args = dict() # type: Dict[TensorDef, TensorDefConfig]
+ self.uses = dict() # type: Dict[TensorUse, TensorUseConfig]
+
+ # Compute the ordered set of writes.
+ collected_uses = set()
+ for write_use, read_use in zip(comprehension.definitions,
+ comprehension.values):
+ self.writes.append((write_use, read_use))
+
+ for write_use, read_use in self.writes:
+ collected_uses.add(write_use)
+ read_use.collect_uses(collected_uses)
+
+ # Need to add all definitions before uses, so process twice.
+ for use in collected_uses:
+ self.add_tensor_arg(use.tensor_def)
+ for use in collected_uses:
+ self.add_use(use)
+
+ # Now normalize the indexing maps of all defs and uses, since the full count
+ # of dims and symbols is known.
+ for cuse in self.uses.values():
+ cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map)
+ for cdef in self.tensor_args.values():
+ cdef.shape_map = self._normalize_affine_map(cdef.shape_map,
+ with_dims=False)
+
+ # Now for each write use, propagate the indexing maps from the use to the
+ # tensor, ensuring that there are no conflicts.
+ for write_use, _ in self.writes:
+ write_tensor_def = self.tensor_args[write_use.tensor_def]
+ if write_tensor_def.indexing_map:
+ raise ValueError(
+ f"Unexpected multi-write to a single tensor: {write_tensor_def}")
+ write_tensor_def.indexing_map = self.uses[write_use].indexing_map
+
+ # For each read use, propagate the indexing maps from the use to the
+ # tensor, ensuring that there are no conflicts.
+ for _, read_expr in self.writes:
+ read_uses = set() # type: Set[TensorUse]
+ read_expr.collect_uses(read_uses)
+ for read_use in read_uses:
+ read_tensor_def = self.tensor_args[read_use.tensor_def]
+ if (read_tensor_def.indexing_map and
+ read_tensor_def.indexing_map != self.uses[read_use].indexing_map):
+ raise ValueError(
+ f"Unexpected multi-read of a tensor with different accesses: "
+ f"{read_tensor_def} vs {read_use}")
+ read_tensor_def.indexing_map = self.uses[read_use].indexing_map
+
+ # Sanity check that all defs have an indexing map.
+ assert all(d.indexing_map for d in self.tensor_args.values()), (
+ f"Missing indexing map on TensorDef: {self.tensor_args}")
+
+ # Collect reduction dims and ensure they are all the same.
+ all_reduction_dims = set(comprehension.all_reduction_dims)
+ if len(all_reduction_dims) != 1:
+ raise ValueError(
+ f"All writes within a generic must have the same reduction "
+ f"dims. Got: {all_reduction_dims}")
+ self.reduction_dims = next(iter(all_reduction_dims))
+
+ # Generate the scalar assignments (used to build a body).
+ self.assignments = [
+ ScalarAssign(write_use.tensor_name, read_expr.to_scalar_expression())
+ for write_use, read_expr in self.writes
+ ]
+
+ @property
+ def ordered_tensor_args(self) -> Sequence[TensorDefConfig]:
+ return sorted(self.tensor_args.values(),
+ key=lambda tdc: tdc.tensor_def.registered_index)
+
+ @property
+ def ordered_tensor_uses(self) -> Sequence[TensorUseConfig]:
+ return sorted(self.uses.values(),
+ key=lambda tuc: tuc.tensor_use.tensor_def.registered_index)
+
+ @property
+ def ordered_dims(self) -> Sequence[Tuple[str, int]]:
+ """Gets the ordered list of dim bindings (symbolic name, position).
+
+ TODO: The original parser relies on parse ordering to arrive at the
+ iterator types, but that ordering is not defined on the Python side, so
+ this may be ambiguous.
+ """
+ return list(self.affine_state.all_dims.items())
+
+ @property
+ def indexing_maps(self) -> Sequence[_ir.AffineMap]:
+ return [use.indexing_map for use in self.ordered_tensor_uses]
+
+ @property
+ def iterator_types(self) -> Sequence[str]:
+
+ def get_type(symbolic_name, position):
+ for reduction_dim_expr in self.reduction_dims:
+ if reduction_dim_expr.dimname == symbolic_name:
+ return "reduction"
+ return "parallel"
+
+ return [get_type(*dim) for dim in self.ordered_dims]
+
+ def add_tensor_arg(self, tensor_def: TensorDef):
+ if tensor_def in self.tensor_args:
+ return
+ with self.context:
+ local_state = AffineBuildState(global_state=self.affine_state,
+ allow_new_dims=False)
+ exprs = []
+ for expr in tensor_def.shape:
+ exprs.append(expr.build(state=local_state))
+ assert local_state.local_dim_count == 0
+ indexing_map = _ir.AffineMap.get(dim_count=0,
+ symbol_count=local_state.symbol_count,
+ exprs=exprs)
+
+ def_config = TensorDefConfig(tensor_def, indexing_map)
+ self.tensor_args[tensor_def] = def_config
+
+ def add_use(self, tensor_use: TensorUse):
+ if tensor_use in self.uses:
+ return
+ with self.context:
+ local_state = AffineBuildState(global_state=self.affine_state,
+ allow_new_symbols=False)
+ exprs = []
+ for expr in tensor_use.indices:
+ exprs.append(expr.build(state=local_state))
+ assert local_state.local_symbol_count == 0
+ indexing_map = _ir.AffineMap.get(dim_count=local_state.dim_count,
+ symbol_count=local_state.symbol_count,
+ exprs=exprs)
+
+ use_config = TensorUseConfig(tensor_use, indexing_map)
+ self.uses[tensor_use] = use_config
+
+ def _normalize_affine_map(self,
+ affine_map: _ir.AffineMap,
+ with_dims: bool = True) -> _ir.AffineMap:
+ """Normalizes an indexing map to have the max known symbols and dims."""
+ with self.context:
+ return _ir.AffineMap.get(
+ dim_count=self.affine_state.dim_count if with_dims else 0,
+ symbol_count=self.affine_state.symbol_count,
+ exprs=list(affine_map.results))
+
+ def to_yaml_custom_dict(self):
+ self_dict = dict(
+ args=self.ordered_tensor_args,
+ # TODO: Refactor the hierarchy internally when supporting more
+ # than static (preserving this serialized form).
+ indexing_maps=LinalgIndexingMapsConfig(
+ static_indexing_maps=self.indexing_maps),
+ iterator_types=self.iterator_types,
+ assignments=self.assignments,
+ )
+ return self_dict
+
+ def __repr__(self):
+ lines = [f"LinalgStructuredOpConfig(reduction_dims={self.reduction_dims},"]
+ lines.append("tensor_args=[")
+ for def_config in self.ordered_tensor_args:
+ lines.append(f" {repr(def_config)}")
+ lines.append("], indexing_maps=[")
+ for m in self.indexing_maps:
+ lines.append(f" {repr(m)}")
+ lines.append(f"], iterator_types=[")
+ for t in self.iterator_types:
+ lines.append(f" {t}")
+ lines.append("])")
+ return "\n".join(lines)
+
+
+class LinalgOpConfig(YAMLObject):
+ """Container for any supported linalg op type.
+
+ This includes the concrete type by name for ease of parsing by systems
+ that ignore tags.
+ """
+ yaml_tag = "!LinalgOpConfig"
+
+ def __init__(self,
+ metadata: OpMetadataDef,
+ *,
+ structured_op: Optional[LinalgStructuredOpConfig] = None):
+ self.metadata = metadata
+ self.structured_op = structured_op
+
+ def to_yaml_custom_dict(self):
+ self_dict = dict(metadata=self.metadata,)
+ if self.structured_op:
+ self_dict["structured_op"] = self.structured_op
+ return self_dict
+
+ @staticmethod
+ def from_linalg_op_def(
+ tc_op_def: LinalgOpDef,
+ context: Optional[_ir.Context] = None) -> Sequence["LinalgOpConfig"]:
+ """Expands a LinalgOpDef into corresponding Linalg configured ops."""
+ # TODO: Many LinalgOpDef patterns need to expand to multiple generics.
+ assert len(
+ tc_op_def.comprehensions) == 1, "Only one comprehension supported"
+ return [
+ LinalgOpConfig(tc_op_def.metadata,
+ structured_op=LinalgStructuredOpConfig(
+ tc_op_def.comprehensions[0], context)),
+ ]
+
+ def __repr__(self):
+ return (f"LinalgOpConfig(metadata={self.metadata},\n"
+ f"structured_op={self.structured_op})")
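`LinalgStructuredOpConfig.iterator_types` marks a loop dim as `reduction` if it is one of the comprehension's reduction dims, and as `parallel` otherwise. The constructor also rejects comprehensions whose writes disagree on those dims. The standalone sketch below shows this classification. `infer_reduction_dims` is hypothetical: it assumes that for `+=`, the reduced dims are the ones read on the right-hand side but absent from the left. The actual derivation lives in `comprehension.py`, which is not part of this excerpt. The dim order is taken as an input (m, n, k for matmul, as `shape_maps_iteration.py` expects), since the TODO on `ordered_dims` notes that this order is not well defined on the Python side.

```python
from typing import Iterable, List, Sequence, Set


def infer_reduction_dims(write_dims: Iterable[str],
                         read_dims: Iterable[str]) -> frozenset:
    """Assumption: dims read on the RHS but not written on the LHS are reduced."""
    return frozenset(read_dims) - frozenset(write_dims)


def common_reduction_dims(per_write: Sequence[frozenset]) -> frozenset:
    """Mirrors the config check: every write must reduce over the same dims."""
    distinct = set(per_write)
    if len(distinct) != 1:
        raise ValueError("All writes within a generic must have the same "
                         f"reduction dims. Got: {distinct}")
    return next(iter(distinct))


def iterator_types(ordered_dims: Sequence[str],
                   reduction_dims: Set[str]) -> List[str]:
    return ["reduction" if d in reduction_dims else "parallel"
            for d in ordered_dims]


# matmul: C[m, n] += A[m, k] * B[k, n]
red = common_reduction_dims(
    [infer_reduction_dims(["m", "n"], ["m", "k", "k", "n"])])
print(iterator_types(["m", "n", "k"], red))
# ['parallel', 'parallel', 'reduction']

# dot: C[None] += A[m] * B[m] writes a 0-d result, so m is reduced.
print(iterator_types(["m"], infer_reduction_dims([], ["m", "m"])))
# ['reduction']
```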
diff --git a/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/dsl.py b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/dsl.py
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/dsl.py
@@ -0,0 +1,91 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Dict, List
+
+from contextlib import contextmanager
+import functools
+import inspect
+import threading
+
+from mlir import ir
+from .comprehension import *
+
+_CONTEXT = threading.local()
+
+
+@contextmanager
+def bind_op_def(model: LinalgOpDef):
+ if hasattr(_CONTEXT, "current_op_def"):
+ raise ValueError("Cannot recursively define an operation")
+ _CONTEXT.current_op_def = model
+ try:
+ yield model
+ finally:
+ del _CONTEXT.current_op_def
+
+
+def current_op_def() -> LinalgOpDef:
+ try:
+ return _CONTEXT.current_op_def
+ except AttributeError:
+ raise ValueError(
+ "Attempt to access the current op definition being defined "
+ "but none is set. Did you mean to call this in an op definition?")
+
+
+class DefinedOpCallable:
+ """Callable that wraps any defined op function."""
+
+ def __init__(self, op_name: str, model: LinalgOpDef):
+ self.op_name = op_name
+ self.model = model
+
+ def __call__(self, *args, **kwargs):
+ # TODO: Upstream the emitter and invoke here
+ raise NotImplementedError("Linalg generic emission not yet implemented")
+
+
+def linalg_structured_op(dsl_func=None,
+ *,
+ op_name=None,
+ op_class_name=None) -> DefinedOpCallable:
+ if dsl_func is None:
+ # Curry the keyword args in for delayed application.
+ return functools.partial(linalg_structured_op,
+ op_name=op_name,
+ op_class_name=op_class_name)
+ # Determine default names by introspecting the function.
+ if op_name is None:
+ op_name = dsl_func.__name__
+ if op_class_name is None:
+ # Camel case it.
+ op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op"
+
+ tc_model = LinalgOpDef(name=op_name,
+ cpp_op_name=op_class_name,
+ doc=inspect.getdoc(dsl_func))
+
+ # Extract arguments and TensorDefs from the signature.
+ dsl_func_args = list()
+ sig = inspect.signature(dsl_func)
+ for param_name, param in sig.parameters.items():
+ param_default = param.default
+ if not isinstance(param_default, TensorDef):
+ raise ValueError(f"@linalg_structured_op function parameters must be defaulted as "
+ f"TensorDef(...): Found {param_name}: {param_default}")
+ dsl_func_args.append(param_default)
+ tc_model.add_tensor(param_name, param_default)
+
+ # Invoke the DSL func to finish populating the model.
+ with bind_op_def(tc_model):
+ dsl_func(*dsl_func_args)
+
+ # TODO: The returned callable should be an IR emitter but that is not
+ # upstreamed yet.
+ return DefinedOpCallable(op_name, tc_model)
+
+
+def implements(*interfaces: OpInterfaceDef):
+ current_op_def().metadata.implements.extend(interfaces)
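`linalg_structured_op` works in two steps. First, it reads names, the docstring, and default arguments from the decorated function. Then it runs the function body while the new model is bound to a thread-local slot, so that helpers like `implements()` can find the model without it being passed in explicitly. A `functools.partial` makes the keyword-argument form `@linalg_structured_op(op_name=...)` work. The following is a minimal standalone sketch of that pattern. `SketchOpModel`, `structured_op`, and `current_model` are hypothetical names. The sketch uses `None` defaults where the real decorator requires `TensorDef(...)`, and it returns the model directly rather than a `DefinedOpCallable`.

```python
import functools
import inspect
import threading
from typing import List, Optional

_STATE = threading.local()


class SketchOpModel:
    """Hypothetical stand-in for LinalgOpDef: name, class name, doc, params."""

    def __init__(self, name: str, class_name: str, doc: Optional[str]):
        self.name = name
        self.class_name = class_name
        self.doc = doc
        self.params: List[str] = []
        self.interfaces: List[str] = []


def current_model() -> SketchOpModel:
    try:
        return _STATE.model
    except AttributeError:
        raise ValueError("No op definition is currently being built") from None


def structured_op(fn=None, *, op_name=None, op_class_name=None):
    if fn is None:
        # Used as @structured_op(op_name=...): curry the keywords, wait for fn.
        return functools.partial(structured_op, op_name=op_name,
                                 op_class_name=op_class_name)
    op_name = op_name or fn.__name__
    # snake_case -> CamelCase + "Op", e.g. batch_matmul -> BatchMatmulOp.
    op_class_name = op_class_name or (
        "".join(part.title() for part in op_name.split("_")) + "Op")
    model = SketchOpModel(op_name, op_class_name, inspect.getdoc(fn))
    defaults = []
    for name, param in inspect.signature(fn).parameters.items():
        model.params.append(name)
        defaults.append(param.default)
    if hasattr(_STATE, "model"):
        raise ValueError("Cannot recursively define an operation")
    _STATE.model = model
    try:
        fn(*defaults)  # The body reaches the model through current_model().
    finally:
        del _STATE.model
    return model


def implements(*interfaces: str) -> None:
    current_model().interfaces.extend(interfaces)


@structured_op
def batch_matmul(A=None, B=None, C=None):
    """Docstrings are captured verbatim."""
    implements("ContractionOpInterface")


@structured_op(op_name="my_dot")
def dot(A=None, B=None, C=None):
    pass


print(batch_matmul.class_name, batch_matmul.interfaces)
# BatchMatmulOp ['ContractionOpInterface']
print(dot.name, dot.class_name)  # my_dot MyDotOp
```

The `try`/`finally` clears the thread-local slot even if the body raises, so one failed definition does not block the next.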
diff --git a/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/scalar_expr.py b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/scalar_expr.py
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/scalar_expr.py
@@ -0,0 +1,124 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+"""Models DAGs of scalar math expressions.
+
+Used for generating region bodies at the "math" level where they are still type
+polymorphic. This is modeled to be polymorphic by attribute name for interop
+with serialization schemes that are just plain-old-dicts.
+
+These classes are typically not accessed directly by users: they are created
+as a by-product of interpreting a comprehension DSL and model the operations to
+perform in the op body. The class hierarchy is laid out to map well to a form
+of YAML that can be easily consumed from the C++ side, rather than for
+ergonomics.
+"""
+
+from typing import Optional, Sequence
+
+from .yaml_helper import *
+from .types import *
+
+__all__ = [
+ "ScalarAssign",
+ "ScalarApplyFn",
+ "ScalarArg",
+ "ScalarExpression",
+ "ScalarSymbolicCast",
+]
+
+
+class ScalarApplyFn:
+ """A type of ScalarExpression that applies a named function to operands."""
+
+ def __init__(self, fn_name: str, *operands: "ScalarExpression"):
+ self.fn_name = fn_name
+ self.operands = operands
+
+ def expr(self) -> "ScalarExpression":
+ return ScalarExpression(scalar_apply=self)
+
+ def __repr__(self):
+ return f"ScalarApplyFn<{self.fn_name}>({', '.join(repr(o) for o in self.operands)})"
+
+
+class ScalarArg:
+ """A type of ScalarExpression that references a named argument."""
+
+ def __init__(self, arg: str):
+ self.arg = arg
+
+ def expr(self) -> "ScalarExpression":
+ return ScalarExpression(scalar_arg=self)
+
+ def __repr__(self):
+ return f"ScalarArg({self.arg})"
+
+
+class ScalarSymbolicCast:
+ """A type of ScalarExpression that symbolically casts an operand to a TypeVar.
+ """
+
+ def __init__(self, to_type: TypeVar, operand: "ScalarExpression"):
+ self.to_type = to_type
+ self.operand = operand
+
+ def expr(self) -> "ScalarExpression":
+ return ScalarExpression(symbolic_cast=self)
+
+ def __repr__(self):
+ return f"ScalarSymbolicCast({self.to_type}, {self.operand})"
+
+
+class ScalarExpression(YAMLObject):
+ """An expression on scalar values.
+
+ Can be one of:
+ - ScalarApplyFn
+ - ScalarArg
+ - ScalarSymbolicCast
+ """
+ yaml_tag = "!ScalarExpression"
+
+ def __init__(self,
+ scalar_apply: Optional[ScalarApplyFn] = None,
+ scalar_arg: Optional[ScalarArg] = None,
+ symbolic_cast: Optional[ScalarSymbolicCast] = None):
+ if (bool(scalar_apply) + bool(scalar_arg) + bool(symbolic_cast)) != 1:
+ raise ValueError(
+ "Exactly one of 'scalar_apply', 'scalar_arg', 'symbolic_cast' must be "
+ "specified")
+ self.scalar_apply = scalar_apply
+ self.scalar_arg = scalar_arg
+ self.symbolic_cast = symbolic_cast
+
+ def to_yaml_custom_dict(self):
+ if self.scalar_apply:
+ return dict(scalar_apply=dict(
+ fn_name=self.scalar_apply.fn_name,
+ operands=list(self.scalar_apply.operands),
+ ))
+ elif self.scalar_arg:
+ return dict(scalar_arg=self.scalar_arg.arg)
+ elif self.symbolic_cast:
+ # Note that even though a cast always has exactly one operand, we write it
+ # as an operand list, the same way as for apply, because that lets handling
+ # code be more generic instead of needing a special form.
+ return dict(symbolic_cast=dict(type_var=self.symbolic_cast.to_type.name,
+ operands=[self.symbolic_cast.operand]))
+ else:
+ raise ValueError(f"Unexpected ScalarExpression type: {self}")
+
+
+class ScalarAssign(YAMLObject):
+ """An assignment to a named argument (LHS of a comprehension)."""
+ yaml_tag = "!ScalarAssign"
+
+ def __init__(self, arg: str, value: ScalarExpression):
+ self.arg = arg
+ self.value = value
+
+ def to_yaml_custom_dict(self):
+ return dict(arg=self.arg, value=self.value)
+
+ def __repr__(self):
+ return f"ScalarAssign({self.arg}, {self.value})"
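`ScalarExpression` is a tagged union. Exactly one of its three variants must be set, and serialization turns the tree into nested plain mappings: `scalar_apply` with `fn_name` and `operands`, `scalar_arg`, and `symbolic_cast` with `type_var` and `operands`. The following is a standalone sketch that mirrors that nesting for the matmul body. The `SketchExpr` class and the `arg`/`cast`/`apply` helpers are hypothetical. The sketch also assumes that `+=` lowers to `add(C, product)` with the accumulator first. The `assignments.py` CHECK lines confirm the `add`, `mul`, and cast nesting, but they do not spell out the `C` operand.

```python
from typing import Any, Dict, Optional, Sequence, Tuple


class SketchExpr:
    """Hypothetical stand-in for ScalarExpression: exactly one variant is set."""

    def __init__(self, *,
                 apply: Optional[Tuple[str, Sequence["SketchExpr"]]] = None,
                 arg: Optional[str] = None,
                 cast: Optional[Tuple[str, "SketchExpr"]] = None):
        if (apply is not None) + (arg is not None) + (cast is not None) != 1:
            raise ValueError("Exactly one of 'apply', 'arg', 'cast' must be set")
        self.apply, self.arg, self.cast = apply, arg, cast

    def to_dict(self) -> Dict[str, Any]:
        """Produces the same nesting as ScalarExpression.to_yaml_custom_dict."""
        if self.apply is not None:
            fn_name, operands = self.apply
            return {"scalar_apply": {"fn_name": fn_name,
                                     "operands": [o.to_dict() for o in operands]}}
        if self.arg is not None:
            return {"scalar_arg": self.arg}
        type_var, operand = self.cast
        # A cast has one operand but still uses an operand list, like apply.
        return {"symbolic_cast": {"type_var": type_var,
                                  "operands": [operand.to_dict()]}}


def arg(name: str) -> SketchExpr:
    return SketchExpr(arg=name)


def cast(type_var: str, operand: SketchExpr) -> SketchExpr:
    return SketchExpr(cast=(type_var, operand))


def apply(fn_name: str, *operands: SketchExpr) -> SketchExpr:
    return SketchExpr(apply=(fn_name, operands))


# Assumed lowering of C[m, n] += cast(U, A[m, k]) * cast(U, B[k, n]).
value = apply("add", arg("C"),
              apply("mul", cast("U", arg("A")), cast("U", arg("B"))))
print(value.to_dict()["scalar_apply"]["fn_name"])  # add
```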
diff --git a/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/types.py b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/types.py
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/types.py
@@ -0,0 +1,69 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+"""Facility for symbolically referencing type variables.
+
+Type variables are instances of the TypeVar class, which is uniqued by name.
+An "expando" accessor `TV` is provided that generates a named TypeVar for
+any attribute access:
+
+ >>> TV.T
+ TypeVar(T)
+ >>> TV.T is TV.U
+ False
+ >>> TV.T is TV.T
+ True
+"""
+
+from enum import Enum
+from typing import Dict
+
+__all__ = [
+ "TypeVar",
+ "TV",
+
+ # TypeVar aliases.
+ "T",
+ "U",
+ "V",
+]
+
+
+class TypeVar:
+ """A replaceable type variable.
+
+ Type variables are uniqued by name.
+ """
+ ALL_TYPEVARS = dict() # type: Dict[str, "TypeVar"]
+
+ def __new__(cls, name: str):
+ existing = cls.ALL_TYPEVARS.get(name)
+ if existing is not None:
+ return existing
+ new = super().__new__(cls)
+ new.name = name
+ cls.ALL_TYPEVARS[name] = new
+ return new
+
+ def __repr__(self):
+ return f"TypeVar({self.name})"
+
+ @classmethod
+ def create_expando(cls):
+ """Create an expando class that creates unique type vars on attr access."""
+
+ class ExpandoTypeVars:
+
+ def __getattr__(self, n):
+ return cls(n)
+
+ return ExpandoTypeVars()
+
+
+# Expando access via TV.foo
+TV = TypeVar.create_expando()
+
+# Some common type name aliases.
+T = TV.T
+U = TV.U
+V = TV.V
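`TypeVar` is uniqued by name, so `TV.U`, `TypeVar("U")`, and the module alias `U` are all the same object. Because of this, a symbolic type variable can serve directly as a dictionary key (default identity hashing) when it is bound to a concrete element type. The following standalone sketch replicates the pattern rather than importing the package. The `bind_types` helper and the element type strings `"i8"` and `"i32"` are hypothetical and only illustrate the idea.

```python
from typing import Dict


class TypeVar:
    """Replica of the name-uniqued TypeVar pattern (standalone, not imported)."""
    ALL_TYPEVARS: Dict[str, "TypeVar"] = {}

    def __new__(cls, name: str):
        existing = cls.ALL_TYPEVARS.get(name)
        if existing is not None:
            return existing
        new = super().__new__(cls)
        new.name = name
        cls.ALL_TYPEVARS[name] = new
        return new

    def __repr__(self):
        return f"TypeVar({self.name})"


class _ExpandoTypeVars:
    """Any attribute access yields the uniqued TypeVar of that name."""

    def __getattr__(self, name: str) -> TypeVar:
        return TypeVar(name)


TV = _ExpandoTypeVars()


def bind_types(bindings: Dict[TypeVar, str], tv: TypeVar) -> str:
    """Hypothetical: resolve a symbolic type var to a concrete element type."""
    try:
        return bindings[tv]
    except KeyError:
        raise ValueError(f"Unbound type variable: {tv}") from None


# Because TV.U and TypeVar("U") are the same object, either works as a key.
bindings = {TV.T1: "i8", TV.T2: "i8", TypeVar("U"): "i32"}
print(bind_types(bindings, TV.U))  # i32
```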
diff --git a/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/yaml_helper.py b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/yaml_helper.py
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/lang/yaml_helper.py
@@ -0,0 +1,54 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+"""YAML serialization is routed through here to centralize common logic."""
+
+import sys
+
+try:
+ import yaml
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError(
+ f"This tool requires PyYAML but it was not installed. "
+ f"Recommend: {sys.executable} -m pip install PyYAML") from e
+
+__all__ = [
+ "yaml_dump",
+ "yaml_dump_all",
+ "YAMLObject",
+]
+
+
+class YAMLObject(yaml.YAMLObject):
+
+ @classmethod
+ def to_yaml(cls, dumper, self):
+ """Default to a custom dictionary mapping."""
+ return dumper.represent_mapping(cls.yaml_tag, self.to_yaml_custom_dict())
+
+ def to_yaml_custom_dict(self):
+ raise NotImplementedError()
+
+ def as_linalg_yaml(self):
+ return yaml_dump(self)
+
+
+def multiline_str_representer(dumper, data):
+ if len(data.splitlines()) > 1:
+ return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
+ else:
+ return dumper.represent_scalar('tag:yaml.org,2002:str', data)
+
+
+yaml.add_representer(str, multiline_str_representer)
+
+
+def yaml_dump(data, sort_keys=False, **kwargs):
+ return yaml.dump(data, sort_keys=sort_keys, **kwargs)
+
+
+def yaml_dump_all(data, sort_keys=False, explicit_start=True, **kwargs):
+ return yaml.dump_all(data,
+ sort_keys=sort_keys,
+ explicit_start=explicit_start,
+ **kwargs)
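Each `YAMLObject` subclass here only has to return a mapping from `to_yaml_custom_dict`. The PyYAML dumper then recurses into nested values, and `yaml_dump` passes `sort_keys=False` so that fields keep their declaration order. The sketch below reproduces that structure using the standard-library `json` module instead of PyYAML, so it runs without the optional dependency. Because of this substitution, tags such as `!ScalarAssign` are dropped, since JSON has no equivalent. The `Sketch*` classes and the `lower` helper are hypothetical.

```python
import json
from typing import Any


class SketchSerializable:
    """Hypothetical analogue of YAMLObject: subclasses supply a plain mapping."""

    def to_custom_dict(self) -> dict:
        raise NotImplementedError()


def lower(value: Any) -> Any:
    """Recursively replaces serializable objects with their custom mappings."""
    if isinstance(value, SketchSerializable):
        return lower(value.to_custom_dict())
    if isinstance(value, dict):
        return {key: lower(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [lower(item) for item in value]
    return value


class SketchAssign(SketchSerializable):
    def __init__(self, arg: str, value: Any):
        self.arg = arg
        self.value = value

    def to_custom_dict(self) -> dict:
        # Insertion order is the field order a consumer will see.
        return dict(arg=self.arg, value=self.value)


class SketchArg(SketchSerializable):
    def __init__(self, arg: str):
        self.arg = arg

    def to_custom_dict(self) -> dict:
        return dict(scalar_arg=self.arg)


text = json.dumps(lower([SketchAssign("C", SketchArg("A"))]), sort_keys=False)
print(text)  # [{"arg": "C", "value": {"scalar_arg": "A"}}]
```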
diff --git a/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/ops/__init__.py b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/ops/__init__.py
new file mode 100644
diff --git a/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/ops/core_named_ops.py b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/ops/core_named_ops.py
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/ops/core_named_ops.py
@@ -0,0 +1,70 @@
+from ..lang import *
+
+T1 = TV.T1
+T2 = TV.T2
+
+Batch = S.Batch
+
+
+@linalg_structured_op
+def matmul(A=TensorDef(T1, S.M, S.K),
+ B=TensorDef(T2, S.K, S.N),
+ C=TensorDef(U, S.M, S.N, output=True)):
+ """Performs a matrix multiplication of two 2D inputs.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ implements(ContractionOpInterface)
+ C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
+
+
+@linalg_structured_op
+def batch_matmul(A=TensorDef(T1, Batch, S.M, S.K),
+ B=TensorDef(T2, Batch, S.K, S.N),
+ C=TensorDef(U, Batch, S.M, S.N, output=True)):
+ """Performs a batched matrix multiplication of two 3D inputs.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ implements(ContractionOpInterface)
+ C[D.b, D.m, D.n] += cast(U, A[D.b, D.m, D.k]) * cast(U, B[D.b, D.k, D.n])
+
+
+@linalg_structured_op
+def matvec(A=TensorDef(T1, S.M, S.N),
+ y=TensorDef(T2, S.N),
+ x=TensorDef(U, S.M, output=True)):
+ """Performs a matrix-vector multiplication.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ implements(ContractionOpInterface)
+ x[D.m] += cast(U, A[D.m, D.n]) * cast(U, y[D.n])
+
+
+@linalg_structured_op
+def vecmat(y=TensorDef(T1, S.M),
+ A=TensorDef(T2, S.M, S.N),
+ x=TensorDef(U, S.N, output=True)):
+ """Performs a vector-matrix multiplication.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ implements(ContractionOpInterface)
+ x[D.n] += cast(U, y[D.m]) * cast(U, A[D.m, D.n])
+
+
+@linalg_structured_op
+def dot(A=TensorDef(T1, S.M),
+ B=TensorDef(T2, S.M),
+ C=TensorDef(U, output=True)):
+ """Performs a dot product of two vectors to a scalar result.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ implements(ContractionOpInterface)
+ C[None] += cast(U, A[D.m]) * cast(U, B[D.m])
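This is an illustrative plain-Python reference for what the `matmul` and `dot` comprehensions compute, not code generated by the tool. Every index that appears only on the right-hand side (`k` for matmul, `m` for dot) is summed into the output. `cast` converts each operand to the accumulator type before the multiply. Here Python's `float` stands in for the output type `U`. The sketch shows where the promotion happens, but it cannot show overflow behavior, since Python integers do not overflow.

```python
from typing import Callable, List, Sequence

Matrix = List[List[float]]


def matmul_reference(A: Sequence[Sequence[int]], B: Sequence[Sequence[int]],
                     C: Matrix, cast: Callable = float) -> Matrix:
    """C[m, n] += cast(A[m, k]) * cast(B[k, n]) with k as the reduction dim."""
    M, K = len(A), len(A[0])
    N = len(B[0])
    assert len(B) == K and len(C) == M and len(C[0]) == N, "shape mismatch"
    for m in range(M):          # parallel
        for n in range(N):      # parallel
            for k in range(K):  # reduction
                C[m][n] += cast(A[m][k]) * cast(B[k][n])
    return C


def dot_reference(A: Sequence[int], B: Sequence[int], C: float,
                  cast: Callable = float) -> float:
    """C[None] += cast(A[m]) * cast(B[m]): the output is 0-d, so m is reduced."""
    assert len(A) == len(B), "shape mismatch"
    for m in range(len(A)):     # reduction
        C += cast(A[m]) * cast(B[m])
    return C


C = matmul_reference([[1, 2], [3, 4]], [[5, 6], [7, 8]], [[0.0, 0.0], [0.0, 0.0]])
print(C)  # [[19.0, 22.0], [43.0, 50.0]]
print(dot_reference([1, 2, 3], [4, 5, 6], 0.0))  # 32.0
```

Because the output is updated in place (`+=`), a nonzero initial `C` is accumulated into rather than overwritten, matching the comprehension's semantics.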
diff --git a/mlir/test/Bindings/Python/tools/linalg_opdsl/assignments.py b/mlir/test/Bindings/Python/tools/linalg_opdsl/assignments.py
new file mode 100644
--- /dev/null
+++ b/mlir/test/Bindings/Python/tools/linalg_opdsl/assignments.py
@@ -0,0 +1,29 @@
+# RUN: %PYTHON -m mlir.tools.linalg_opdsl.dump_oplib --file %s | FileCheck %s
+
+from mlir.tools.linalg_opdsl.lang import *
+
+# CHECK: ---
+# CHECK-LABEL: matmul
+# CHECK: assignments:
+# CHECK: -
+# CHECK: arg: C
+# CHECK: value:
+# CHECK: scalar_apply:
+# CHECK: fn_name: add
+# CHECK: operands:
+# CHECK: scalar_apply:
+# CHECK: fn_name: mul
+# CHECK: operands:
+# CHECK: symbolic_cast:
+# CHECK: type_var: U
+# CHECK: operands:
+# CHECK: scalar_arg: A
+# CHECK: symbolic_cast:
+# CHECK: type_var: U
+# CHECK: operands:
+# CHECK: scalar_arg: B
+@linalg_structured_op
+def matmul(A=TensorDef(T, S.M, S.K),
+ B=TensorDef(T, S.K, S.N),
+ C=TensorDef(U, S.M, S.N, output=True)):
+ C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
diff --git a/mlir/test/Bindings/Python/tools/linalg_opdsl/doctests.py b/mlir/test/Bindings/Python/tools/linalg_opdsl/doctests.py
new file mode 100644
--- /dev/null
+++ b/mlir/test/Bindings/Python/tools/linalg_opdsl/doctests.py
@@ -0,0 +1,13 @@
+# RUN: %PYTHON %s
+
+import doctest
+import importlib
+
+def test_module(module_name):
+ print(f"--- Testing module: {module_name}")
+ m = importlib.import_module(module_name)
+ doctest.testmod(m, verbose=True, raise_on_error=True, report=True)
+
+
+test_module("mlir.tools.linalg_opdsl.lang.affine")
+test_module("mlir.tools.linalg_opdsl.lang.types")
diff --git a/mlir/test/Bindings/Python/tools/linalg_opdsl/interfaces.py b/mlir/test/Bindings/Python/tools/linalg_opdsl/interfaces.py
new file mode 100644
--- /dev/null
+++ b/mlir/test/Bindings/Python/tools/linalg_opdsl/interfaces.py
@@ -0,0 +1,14 @@
+# RUN: %PYTHON -m mlir.tools.linalg_opdsl.dump_oplib --file %s | FileCheck %s
+
+from mlir.tools.linalg_opdsl.lang import *
+
+# CHECK: ---
+# CHECK-LABEL: matmul
+# CHECK: implements:
+# CHECK-NEXT: - LinalgContractionOpInterface
+@linalg_structured_op
+def matmul(A=TensorDef(T, S.M, S.K),
+ B=TensorDef(T, S.K, S.N),
+ C=TensorDef(U, S.M, S.N, output=True)):
+ implements(ContractionOpInterface)
+ C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
diff --git a/mlir/test/Bindings/Python/tools/linalg_opdsl/lit.local.cfg b/mlir/test/Bindings/Python/tools/linalg_opdsl/lit.local.cfg
new file mode 100644
--- /dev/null
+++ b/mlir/test/Bindings/Python/tools/linalg_opdsl/lit.local.cfg
@@ -0,0 +1,9 @@
+# TODO: This tool requires PyYAML, which is not yet a required build/test
+# dependency. Remove this exclusion once it is a required dep.
+
+# Since both lit and the python bindings use the same python interpreter,
+# we can just check whether yaml can be imported here and exclude if not.
+try:
+ import yaml
+except ModuleNotFoundError:
+ config.unsupported = True
diff --git a/mlir/test/Bindings/Python/tools/linalg_opdsl/shape_maps_iteration.py b/mlir/test/Bindings/Python/tools/linalg_opdsl/shape_maps_iteration.py
new file mode 100644
--- /dev/null
+++ b/mlir/test/Bindings/Python/tools/linalg_opdsl/shape_maps_iteration.py
@@ -0,0 +1,43 @@
+# RUN: %PYTHON -m mlir.tools.linalg_opdsl.dump_oplib --file %s | FileCheck %s
+
+from mlir.tools.linalg_opdsl.lang import *
+
+
+# Verify that the simple case (iteration order defined lexically, reduction
+# dims auto-discovered) emits the right shape, indexing maps and iterator types.
+# CHECK: ---
+# CHECK-LABEL: matmul
+# CHECK: shape: affine_map<()[s0, s1, s2] -> (s0, s2)>
+# CHECK: shape: affine_map<()[s0, s1, s2] -> (s2, s1)>
+# CHECK: shape: affine_map<()[s0, s1, s2] -> (s0, s1)>
+# CHECK: static_indexing_maps:
+# CHECK-NEXT: - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
+# CHECK-NEXT: - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
+# CHECK-NEXT: - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
+# CHECK: iterator_types:
+# CHECK-NEXT: - parallel
+# CHECK-NEXT: - parallel
+# CHECK-NEXT: - reduction
+@linalg_structured_op
+def matmul(A=TensorDef(T, S.M, S.K),
+ B=TensorDef(T, S.K, S.N),
+ C=TensorDef(U, S.M, S.N, output=True)):
+ C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
+
+
+# Verifies that assignment to a scalar (represented as [None]) is represented
+# correctly.
+# CHECK: ---
+# CHECK-LABEL: dot
+# CHECK: shape: affine_map<()[s0] -> (s0)>
+# CHECK: shape: affine_map<()[s0] -> (s0)>
+# CHECK: shape: affine_map<()[s0] -> ()>
+# CHECK: static_indexing_maps:
+# CHECK-NEXT: - affine_map<(d0)[s0] -> (d0)>
+# CHECK-NEXT: - affine_map<(d0)[s0] -> (d0)>
+# CHECK-NEXT: - affine_map<(d0)[s0] -> ()>
+# CHECK: iterator_types:
+# CHECK-NEXT: - reduction
+@linalg_structured_op
+def dot(A=TensorDef(T, S.M), B=TensorDef(T, S.M), C=TensorDef(U, output=True)):
+ C[None] += cast(U, A[D.m]) * cast(U, B[D.m])
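The CHECK lines above rely on the normalization done in `LinalgStructuredOpConfig`. Each shape map gets every op-level symbol and no dims, and each indexing map gets every dim and symbol, even ones it never references. `render_map` below is a hypothetical, string-only sketch of that rendering, limited to maps whose results are plain dim or symbol references. The real code builds `AffineMap` objects through the MLIR Python bindings and supports more general affine expressions. The sketch takes the dim and symbol orders as inputs rather than deriving them; the order used here (M, N, K) is the one the matmul CHECK lines imply.

```python
from typing import Sequence


def render_map(dim_names: Sequence[str], symbol_names: Sequence[str],
               results: Sequence[str]) -> str:
    """Renders results (dim or symbol names) as a normalized affine_map string."""
    position = {name: f"d{i}" for i, name in enumerate(dim_names)}
    position.update({name: f"s{i}" for i, name in enumerate(symbol_names)})
    dims = ", ".join(f"d{i}" for i in range(len(dim_names)))
    syms = ", ".join(f"s{i}" for i in range(len(symbol_names)))
    exprs = ", ".join(position[name] for name in results)
    return f"affine_map<({dims})[{syms}] -> ({exprs})>"


# matmul: dims ordered (m, n, k); symbols ordered (M, N, K).
DIMS, SYMS = ["m", "n", "k"], ["M", "N", "K"]
print(render_map([], SYMS, ["M", "K"]))
# affine_map<()[s0, s1, s2] -> (s0, s2)>
print(render_map(DIMS, SYMS, ["m", "k"]))
# affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
print(render_map(["m"], ["M"], []))
# affine_map<(d0)[s0] -> ()>
```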
diff --git a/mlir/test/Bindings/Python/tools/linalg_opdsl/test_core_named_ops.py b/mlir/test/Bindings/Python/tools/linalg_opdsl/test_core_named_ops.py
new file mode 100644
--- /dev/null
+++ b/mlir/test/Bindings/Python/tools/linalg_opdsl/test_core_named_ops.py
@@ -0,0 +1,4 @@
+# RUN: %PYTHON -m mlir.tools.linalg_opdsl.dump_oplib .ops.core_named_ops | FileCheck %s
+
+# Just verify that at least one known op is generated.
+# CHECK: name: matmul