diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_simple_tensor_algebra.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_simple_tensor_algebra.py --- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_simple_tensor_algebra.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_simple_tensor_algebra.py @@ -31,5 +31,31 @@ passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]]) passed += np.allclose(values, [20.0, 5.0, 63.0]) -# CHECK: Number of passed: 2 +# PyTACO doesn't allow the use of index values, but MLIR-PyTACO removes this +# restriction. +E = pt.tensor([3]) +E[i] = i +indices, values = E.get_coordinates_and_values() +passed += np.array_equal(indices, [[0], [1], [2]]) +passed += np.allclose(values, [0.0, 1.0, 2.0]) + +F = pt.tensor([3]) +G = pt.tensor([3]) +F.insert([0], 10) +F.insert([2], 40) +G[i] = F[i] + i +indices, values = G.get_coordinates_and_values() +passed += np.array_equal(indices, [[0], [1], [2]]) +passed += np.allclose(values, [10.0, 1.0, 42.0]) + +H = pt.tensor([3]) +I = pt.tensor([3]) +H.insert([0], 10) +H.insert([2], 40) +I[i] = H[i] * i +indices, values = I.get_coordinates_and_values() +passed += np.array_equal(indices, [[0], [2]]) +passed += np.allclose(values, [0.0, 80.0]) + +# CHECK: Number of passed: 8 print("Number of passed:", passed) diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py --- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py @@ -611,7 +611,7 @@ # _StructOpInfo for the top level expression. 
expr_to_info[self].structop_info = _StructOpInfo(dst_indices, tuple(dst.shape), - self.dtype(), dst.name, + dst.dtype, dst.name, dst.format) return structop_roots @@ -650,7 +650,7 @@ raise ValueError("Destination IndexVar not used in the " f"source expression: {i}") else: - if d != index_to_dim_info[i].dim: + if d != index_to_dim_info[i].dim and index_to_dim_info[i].dim != -1: raise ValueError(f"Inconsistent destination dimension for {i}: " f"{d} vs {index_to_dim_info[i].dim}") @@ -739,7 +739,7 @@ return old_value -class IndexVar: +class IndexVar(IndexExpr): """The tensor index class. We support the TACO API index_var class with an alias of this class. @@ -763,6 +763,34 @@ """Returns the name of the IndexVar.""" return self._name + def _visit(self, + func: _ExprVisitor, + args, + *, + leaf_checker: _SubtreeLeafChecker = None) -> None: + """A post-order visitor.""" + if leaf_checker: + assert leaf_checker(self, *args) + func(self, *args) + + def _emit_expression( + self, + expr_to_opnd: Dict[IndexExpr, lang.OperandDef], + expr_to_info: _ExprInfoDict, + ) -> lang.ScalarExpression: + """Emits an index value cast to the data type of the tensor expression.""" + dim = getattr(lang.D, self.name) + index = lang.index(dim) + int_value = lang.TypeFn.cast_unsigned(lang.TV.I64, index) + return lang.TypeFn.cast_unsigned(lang.T, int_value) + + def dtype(self) -> DType: + """Returns the data type for the index value. + + This is unreachable for IndexVar. + """ + assert 0 + def get_index_vars(n: int) -> List[IndexVar]: """Returns a list of n IndexVar. @@ -1527,6 +1555,11 @@ mode_format: ModeFormat +def _get_dummy_dim_info() -> _DimInfo: + """Constructs the _DimInfo for an index used in tensor expressions.""" + return _DimInfo(-1, ModeFormat.DENSE) + + @dataclasses.dataclass() class _ExprInfo: """Expression information for validation and code generation. 
@@ -1788,9 +1821,12 @@ if i not in index_to_dim_info: index_to_dim_info[i] = d else: - if d.dim != index_to_dim_info[i].dim: + dim = index_to_dim_info[i].dim + if dim == -1 or d.dim == -1: + dim = dim if dim != -1 else d.dim + elif dim != d.dim: raise ValueError(f"Inconsistent source dimension for {i}: " - f"{d.dim} vs {index_to_dim_info[i].dim}") + f"{d.dim} vs {dim}") mode_format = _mode_format_estimator(expr.op)( index_to_dim_info[i].mode_format, d.mode_format) index_to_dim_info[i] = _DimInfo(d.dim, mode_format) @@ -1823,7 +1859,10 @@ if expr in expr_to_info: return - if isinstance(expr, Access): + if isinstance(expr, IndexVar): + src_indices = expr, # A tuple with one element. + dim_infos = _get_dummy_dim_info(), # A tuple with one element. + elif isinstance(expr, Access): src_indices = expr.indices src_dims = tuple(expr.tensor.shape) if expr.tensor.format is None: @@ -1883,6 +1922,9 @@ reduce_index: The IndexVar which we want to find out the proper expression to perform a reduction. expr_to_info: The dictionary to look up _ExprInfo for IndexExpr. + + Raises: + ValueError: If the expression is not proper or not supported. """ expr_info = expr_to_info[expr] if isinstance(expr, Access): @@ -1890,6 +1932,9 @@ if reduce_index in expr_info.src_indices: expr_info.reduce_indices.add(reduce_index) return + elif isinstance(expr, IndexVar): + # A[i] = B[i] + j is not allowed. + raise ValueError(f"IndexVar is not part of the iteration domain: {expr}.") assert (isinstance(expr, _BinaryExpr)) a_info = expr_to_info[expr.a] @@ -1933,6 +1978,11 @@ a_info = expr_to_info[expr.a] expr_info.acc_reduce_indices = ( a_info.acc_reduce_indices | expr_info.reduce_indices) + elif isinstance(expr, IndexVar): + # If an IndexVar is reducing itself, it means the IndexVar is outside the + # iteration domain. This usage is not allowed and we should emit an error + # before reaching here. 
+ assert not expr_info.reduce_indices else: assert isinstance(expr, Access) # Handle simple reduction expression in the format of A[i] = B[i, j]. @@ -2011,7 +2061,7 @@ """ return (expr != root and expr_to_info[expr].structop_info is not None) or isinstance( - expr, Access) + expr, Access) or isinstance(expr, IndexVar) def _gather_structured_op_input( diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py --- a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py @@ -201,14 +201,6 @@ C[i, j] = A[1, j] + B[i, j] """), "Expected IndexVars") -# CHECK: test_invalid_operation: passed -test_expect_error("invalid_operation", (""" -i, j = mlir_pytaco.get_index_vars(2) -A = mlir_pytaco.Tensor([2, 3]) -C = mlir_pytaco.Tensor([2, 3], _DENSE) -C[i, j] = A[i, j] + i - """), "Expected IndexExpr") - # CHECK: test_inconsistent_rank_indices: passed test_expect_error("inconsistent_rank_indices", (""" i, j = mlir_pytaco.get_index_vars(2) @@ -245,6 +237,15 @@ C.evaluate() """), "Inconsistent source dimension for IndexVar") +# CHECK: test_index_var_outside_domain: passed +test_expect_error("index_var_outside_domain", (""" +i, j = mlir_pytaco.get_index_vars(2) +A = mlir_pytaco.Tensor([3]) +B = mlir_pytaco.Tensor([3]) +B[i] = A[i] + j +B.evaluate() + """), "IndexVar is not part of the iteration domain") + # CHECK-LABEL: test_tensor_all_dense_sparse @testing_utils.run_test