diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_scalar_tensor_algebra.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_scalar_tensor_algebra.py
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_scalar_tensor_algebra.py
@@ -0,0 +1,28 @@
+# RUN: SUPPORTLIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext %PYTHON %s | FileCheck %s
+
+import numpy as np
+import os
+import sys
+
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(_SCRIPT_PATH)
+from tools import mlir_pytaco_api as pt
+
+compressed = pt.compressed
+
+i, j = pt.get_index_vars(2)
+A = pt.tensor([2, 3])
+S = pt.tensor(3)  # S is a scalar tensor.
+B = pt.tensor([2, 3], compressed)
+A.insert([0, 1], 10)
+A.insert([1, 2], 40)
+
+# Use [0] to index the scalar tensor.
+B[i, j] = A[i, j] * S[0]
+
+indices, values = B.get_coordinates_and_values()
+passed = np.array_equal(indices, [[0, 1], [1, 2]])
+passed += np.allclose(values, [30.0, 120.0])
+
+# CHECK: Number of passed: 2
+print("Number of passed:", passed)
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
@@ -447,27 +447,6 @@
   return ir.RankedTensorType.get(shape, ir_type, attr)
 
 
-def _verify_and_normalize_indices(indices) -> Tuple[IndexVar, ...]:
-  """Verifies and normalizes the indices for a tensor access.
-
-  Args:
-    indices: The index expression used to access a tensor, which could be any
-      Python object from user inputs.
-
-  Returns:
-    A tuple of IndexVar.
-
-  Raises:
-    ValueError: If indices is not an IndexVar or a tuple of IndexVar.
-  """
-  if isinstance(indices, IndexVar):
-    return (indices,)
-  elif isinstance(indices, tuple) and _all_instance_of(indices, IndexVar):
-    return indices
-
-  raise ValueError(f"Expected IndexVars: {indices}")
-
-
 @dataclasses.dataclass(frozen=True)
 class _StructOpInfo:
   """Information for generating a structured op in the linalg dialect.
@@ -761,7 +740,7 @@
 
   def is_dense(self) -> bool:
     """Returns true if the tensor doesn't have sparsity annotation."""
-    return self._format is None
+    return self.order == 0 or self._format is None
 
   def to_array(self) -> np.ndarray:
     """Returns the numpy array for the Tensor.
@@ -918,6 +897,34 @@
     """Returns the shape of the Tensor."""
     return self._shape
 
+  def _verify_and_normalize_indices(self, indices) -> Tuple[IndexVar, ...]:
+    """Verifies and normalizes the indices to access the tensor.
+
+    Args:
+      indices: The index expression used to access a tensor, which could be any
+        Python object from user inputs.
+
+    Returns:
+      A tuple of IndexVar.
+
+    Raises:
+      ValueError: If indices is not 0 for scalar tensors, or not an IndexVar or
+        a tuple of IndexVar for other tensors.
+    """
+    if self.order == 0:
+      # Reject bools explicitly: bool is a subclass of int.
+      if (not isinstance(indices, int) or isinstance(indices, bool) or
+          indices != 0):
+        raise ValueError(f"Expected 0 to index scalar tensors: {indices}")
+      return ()
+
+    if isinstance(indices, IndexVar):
+      return (indices,)
+    elif isinstance(indices, tuple) and _all_instance_of(indices, IndexVar):
+      return indices
+
+    raise ValueError(f"Expected IndexVars: {indices}")
+
   def __getitem__(self, key) -> "Access":
     """Verifies and processes a tensor access.
 
@@ -936,7 +943,7 @@
     Raises:
       ValueError: If key is not an IndexVar or a tuple of IndexVar.
     """
-    indices = _verify_and_normalize_indices(key)
+    indices = self._verify_and_normalize_indices(key)
     return Access(self, indices)
 
   def __setitem__(self, key, value) -> None:
@@ -960,7 +967,7 @@
         or a tuple of IndexVar, or the length of the indices is not the same as
         the rank of the tensor.
     """
-    indices = _verify_and_normalize_indices(key)
+    indices = self._verify_and_normalize_indices(key)
     if len(indices) != self.order:
       raise ValueError("Mismatch between indices and tensor rank: "
                        f"len({indices}) != {self.order}.")
@@ -985,8 +992,8 @@
 
   def mlir_tensor_type(self) -> ir.RankedTensorType:
     """Returns the MLIR type for the tensor."""
-    mlir_attr = None if (
-        self._format is None) else self._format.mlir_tensor_attr()
+    mlir_attr = (None if (self._format is None or self.order == 0) else
+                 self._format.mlir_tensor_attr())
     return _mlir_tensor_type(self._dtype, tuple(self._shape), mlir_attr)
 
   def dense_dst_ctype_pointer(self) -> ctypes.pointer: