diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py --- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py @@ -1683,10 +1683,16 @@ to perform a reduction. expr_to_info: The dictionary to look up _ExprInfo for IndexExpr. """ + expr_info = expr_to_info[expr] + if isinstance(expr, Access): + # Handle simple reduction expression in the format of A[i] = B[i, j]. + if reduce_index in expr_info.src_indices: + expr_info.reduce_indices.add(reduce_index) + return + assert (isinstance(expr, _BinaryExpr)) a_info = expr_to_info[expr.a] b_info = expr_to_info[expr.b] - expr_info = expr_to_info[expr] if reduce_index in a_info.src_indices and reduce_index in b_info.src_indices: expr_info.reduce_indices.add(reduce_index) @@ -1724,6 +1730,9 @@ | expr_info.reduce_indices) else: assert isinstance(expr, Access) + # Handle simple reduction expression in the format of A[i] = B[i, j]. + expr_info.acc_reduce_indices = expr_info.reduce_indices + def _gather_structured_op( @@ -1821,9 +1830,10 @@ structop_inputs: The resulting list of IndexExpr that provide input to the current structured op. """ - if (expr != root and expr not in structop_inputs) and ( - isinstance(expr, Access) or - (expr in expr_to_info and expr_to_info[expr].structop_info)): + if ((expr != root or isinstance(expr, Access)) and + expr not in structop_inputs) and (isinstance(expr, Access) or + (expr in expr_to_info and + expr_to_info[expr].structop_info)): structop_inputs.append(expr) @@ -1843,7 +1853,7 @@ An OperandDef in the linalg dialect for the input IndexExpr. """ op_info = expr_to_info[expr].structop_info - if op_info: + if op_info and not isinstance(expr, Access): # The input is a temporary tensor produced by another structured op. indices = op_info.dst_indices name = op_info.dst_name diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py --- a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py @@ -37,3 +37,42 @@ passed += (a.shape[0] == 5) # CHECK: Number of passed: 3 print("Number of passed:", passed) + + +# CHECK-LABEL: test_tensor_copy +@testing_utils.run_test +def test_tensor_copy(): + i, j = mlir_pytaco.get_index_vars(2) + I = 2 + J = 3 + A = mlir_pytaco.Tensor([I, J]) + A.insert([0, 1], 5.0) + A.insert([1, 2], 6.0) + B = mlir_pytaco.Tensor([I, J]) + B[i, j] = A[i, j] + indices, values = B.get_coordinates_and_values() + passed = np.allclose(indices, [[0, 1], [1, 2]]) + passed += np.allclose(values, [5.0, 6.0]) + + # CHECK: Number of passed: 2 + print("Number of passed:", passed) + + +# CHECK-LABEL: test_tensor_trivial_reduction +@testing_utils.run_test +def test_tensor_trivial_reduction(): + i, j = mlir_pytaco.get_index_vars(2) + I = 2 + J = 3 + A = mlir_pytaco.Tensor([I, J]) + A.insert([0, 1], 5.0) + A.insert([0, 2], 3.0) + A.insert([1, 2], 6.0) + B = mlir_pytaco.Tensor([I]) + B[i] = A[i, j] + indices, values = B.get_coordinates_and_values() + passed = np.allclose(indices, [[0], [1]]) + passed += np.allclose(values, [8.0, 6.0]) + + # CHECK: Number of passed: 2 + print("Number of passed:", passed)