diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_unary_ops.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_unary_ops.py
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_unary_ops.py
@@ -0,0 +1,49 @@
+# RUN: SUPPORTLIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext %PYTHON %s | FileCheck %s
+
+import numpy as np
+import os
+import sys
+
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(_SCRIPT_PATH)
+from tools import mlir_pytaco_api as pt
+
+# A is a 2x3 sparse tensor with four stored values; the entries at (0, 0)
+# and (1, 0) are never inserted and stay implicit zeros. Every unary op
+# tested below maps 0 to 0, so each result is expected to keep A's four
+# stored coordinates, listed below in row-major order.
+i, j = pt.get_index_vars(2)
+A = pt.tensor([2, 3])
+B = pt.tensor([2, 3])
+A.insert([0, 1], 10.3)
+A.insert([1, 1], 40.7)
+A.insert([0, 2], -11.3)
+A.insert([1, 2], -41.7)
+
+# abs: every stored value becomes non-negative.
+B[i, j] = abs(A[i, j])
+indices, values = B.get_coordinates_and_values()
+passed = np.array_equal(indices, [[0, 1], [0, 2], [1, 1], [1, 2]])
+passed += np.allclose(values, [10.3, 11.3, 40.7, 41.7])
+
+# ceil rounds toward +inf, e.g. 10.3 -> 11 and -11.3 -> -11.
+B[i, j] = pt.ceil(A[i, j])
+indices, values = B.get_coordinates_and_values()
+passed += np.array_equal(indices, [[0, 1], [0, 2], [1, 1], [1, 2]])
+passed += np.allclose(values, [11, -11, 41, -41])
+
+# floor rounds toward -inf, e.g. 10.3 -> 10 and -11.3 -> -12.
+B[i, j] = pt.floor(A[i, j])
+indices, values = B.get_coordinates_and_values()
+passed += np.array_equal(indices, [[0, 1], [0, 2], [1, 1], [1, 2]])
+passed += np.allclose(values, [10, -12, 40, -42])
+
+# Unary minus flips the sign of every stored value.
+B[i, j] = -A[i, j]
+indices, values = B.get_coordinates_and_values()
+passed += np.array_equal(indices, [[0, 1], [0, 2], [1, 1], [1, 2]])
+passed += np.allclose(values, [-10.3, 11.3, -40.7, 41.7])
+
+# Four ops x two checks (coordinates and values) each = 8 passes.
+# CHECK: Number of passed: 8
+print("Number of passed:", passed)
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
@@ -53,6 +53,7 @@
 _ENTRY_NAME = "main"
 
 # Type aliases for type annotation.
+_UnaryOp = Callable[[Any], Any] _BinaryOp = Callable[[Any, Any], Any] _ExprVisitor = Callable[..., None] _ExprInfoDict = Dict["IndexExpr", "_ExprInfo"] @@ -1223,6 +1224,14 @@ raise ValueError(f"Expected IndexExpr: {rhs}") return _BinaryExpr(op, self, rhs) + def _build_unary_expr(self, op: _UnaryOp) -> "_UnaryExpr": + """Build a unary expression. + + Args: + op: A _UnaryOp object representing the unary operation. + """ + return _UnaryExpr(op, self) + def __add__(self, rhs) -> "_BinaryExpr": """Defines the operator +. @@ -1253,6 +1262,22 @@ """ return self._verify_operand_and_build_expr(rhs, operator.mul) + def __abs__(self) -> "_UnaryExpr": + """Defines the operator abs. + + Returns: + A _UnaryExpr object representing the operation. + """ + return self._build_unary_expr(operator.abs) + + def __neg__(self) -> "_UnaryExpr": + """Defines the operator neg. + + Returns: + A _UnaryExpr object representing the operation. + """ + return self._build_unary_expr(operator.neg) + def __sub__(self, rhs) -> "_BinaryExpr": """Defines the operator -. @@ -1603,6 +1628,75 @@ input_accesses.append(expr) +def _op_ceil(__a: Any) -> Any: + """Marker _UnaryOp for ceil, mapped to lang.UnaryFn.ceil by identity.""" + pass + + +def _op_floor(__a: Any) -> Any: + """Marker _UnaryOp for floor, mapped to lang.UnaryFn.floor by identity.""" + pass + + +def _op_unary_to_callable(op: _UnaryOp) -> lang.UnaryFnType: + """Returns the linalg dialect function object for the given operation.""" + op_to_callable = { + operator.abs: lang.UnaryFn.abs, + operator.neg: lang.UnaryFn.negf, + _op_ceil: lang.UnaryFn.ceil, + _op_floor: lang.UnaryFn.floor, + } + return op_to_callable[op] + + +@dataclasses.dataclass(frozen=True) +class _UnaryExpr(IndexExpr): + """The representation for a unary operation. + + Attributes: + op: A _UnaryOp representing the operation. + a: An IndexExpr representing the operand for the operation.
+ """ + op: _UnaryOp + a: IndexExpr + + def __post_init__(self) -> None: + """Verifies that the operand is an IndexExpr.""" + assert isinstance(self.a, IndexExpr) + + def _emit_expression( + self, + expr_to_opnd: Dict[IndexExpr, lang.OperandDef], + expr_to_info: _ExprInfoDict, + ) -> lang.ScalarExpression: + """Emits the expression tree and returns the expression.""" + # The current expression node is an internal node of the structured op. + if self not in expr_to_opnd: + a = self.a._emit_expression(expr_to_opnd, expr_to_info) + return _op_unary_to_callable(self.op)(a) + + # The current expression is a leaf node of the structured op. That is, it is + # a temporary tensor generated by its child structured op. + op_info = expr_to_info[self].structop_info + assert op_info is not None + dims = _mlir_dimensions_from_index_vars(op_info.dst_indices) + return lang.TensorUse(expr_to_opnd[self], dims) + + def _visit(self, + func: _ExprVisitor, + args, + *, + leaf_checker: _SubtreeLeafChecker = None) -> None: + """A post-order visitor.""" + if leaf_checker is None or not leaf_checker(self, *args): + self.a._visit(func, args, leaf_checker=leaf_checker) + func(self, *args) + + def dtype(self) -> DType: + """Returns the data type of the operation.""" + return self.a.dtype() + + def _op_to_callable(op: _BinaryOp) -> lang.BinaryFnType: """Returns the linalg dialect function object for the given operation.""" op_to_callable = { @@ -1612,7 +1706,6 @@ } return op_to_callable[op] - @dataclasses.dataclass(frozen=True) class _BinaryExpr(IndexExpr): """The representation for a binary operation.
@@ -1740,6 +1833,15 @@ mode_formats = tuple(expr.tensor.format.format_pack.formats) assert len(src_dims) == len(mode_formats) dim_infos = tuple([_DimInfo(d, m) for d, m in zip(src_dims, mode_formats)]) + elif isinstance(expr, _UnaryExpr): + a_info = expr_to_info[expr.a] + index_to_dim_info = { + i: d for i, d in zip(a_info.src_indices, a_info.dim_infos) + } + # Here we rely on the fact that dictionaries keep the insertion order for + # keys and values. + src_indices = tuple(index_to_dim_info.keys()) + dim_infos = tuple(index_to_dim_info.values()) else: assert isinstance(expr, _BinaryExpr) a_info = expr_to_info[expr.a] @@ -1826,6 +1928,10 @@ expr_info.acc_reduce_indices = ( a_info.acc_reduce_indices | b_info.acc_reduce_indices | expr_info.reduce_indices) + elif isinstance(expr, _UnaryExpr): + a_info = expr_to_info[expr.a] + expr_info.acc_reduce_indices = ( + a_info.acc_reduce_indices | expr_info.reduce_indices) else: assert isinstance(expr, Access) # Handle simple reduction expression in the format of A[i] = B[i, j]. @@ -1965,3 +2071,51 @@ opnd = lang.OperandDef(lang.OperandKind.INPUT_TENSOR, lang.T, dim_sym) op_def.add_operand(name, opnd) return opnd + + +def _check_and_build_unary(a: Access, op: _UnaryOp) -> "_UnaryExpr": + """Build a unary expression that applies op to a. + + Args: + a: The operand, which could be any Python object from user inputs. + op: An _UnaryOp object representing the operation. + + Returns: + A _UnaryExpr object representing the operation. + + Raises: + ValueError: If a is not an Access. + """ + if not isinstance(a, Access): + raise ValueError(f"Expected an Access Operand: {a}") + return a._build_unary_expr(op) + + +def ceil(a: Access) -> "_UnaryExpr": + """Defines the operation ceil. + + Args: + a: The operand, which could be any Python object from user inputs. + + Returns: + A _UnaryExpr object representing the operation. + + Raises: + ValueError: If a is not an Access.
+ """ + return _check_and_build_unary(a, _op_ceil) + + +def floor(a: Access) -> "_UnaryExpr": + """Defines the operation floor. + + Args: + a: The operand, which could be any Python object from user inputs. + + Returns: + A _UnaryExpr object representing the operation. + + Raises: + ValueError: If a is not an Access. + """ + return _check_and_build_unary(a, _op_floor) diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py --- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py @@ -16,6 +16,8 @@ from . import mlir_pytaco_io # Functions defined by PyTACO API. +ceil = mlir_pytaco.ceil +floor = mlir_pytaco.floor get_index_vars = mlir_pytaco.get_index_vars from_array = mlir_pytaco.Tensor.from_array read = mlir_pytaco_io.read