diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
@@ -30,6 +30,7 @@
 import threading

 # Import MLIR related modules.
+from mlir import execution_engine
 from mlir import ir
 from mlir import runtime
 from mlir.dialects import arith
@@ -644,6 +645,7 @@
     dtype = dtype or DType(Type.FLOAT32)
     self._name = name or self._get_unique_name()
     self._assignment = None
+    self._engine = None
     self._sparse_value_location = _SparseValueInfo._UNPACKED
     self._dense_storage = None
     self._dtype = dtype
@@ -978,17 +980,72 @@
                        f"len({indices}) != {self.order}.")
     self._assignment = _Assignment(indices, value)
+    self._engine = None

-  def evaluate(self) -> None:
-    """Evaluates the assignment to the tensor."""
-    result = self._assignment.expression.evaluate(self,
-                                                  self._assignment.indices)
-    self._assignment = None
+  def compile(self, force_recompile: bool = False) -> None:
+    """Compiles the tensor assignment to an execution engine.
+
+    Calling compile the second time does not do anything unless
+    force_recompile is True.
+
+    Args:
+      force_recompile: A boolean value to enable recompilation, such as for the
+        purpose of timing.
+
+    Raises:
+      ValueError: If the assignment is not proper or not supported.
+    """
+    if self._assignment is None or (self._engine is not None and
+                                    not force_recompile):
+      return
+
+    self._engine = self._assignment.expression.compile(self,
+                                                       self._assignment.indices)
+
+  def compute(self) -> None:
+    """Executes the engine for the tensor assignment.
+
+    Raises:
+      ValueError: If the assignment hasn't been compiled yet.
+    """
+    if self._assignment is None:
+      return
+
+    if self._engine is None:
+      raise ValueError("Need to invoke compile() before invoking compute().")
+
+    input_accesses = self._assignment.expression.get_input_accesses()
+    # Gather the pointers for the input buffers.
+    input_pointers = [a.tensor.ctype_pointer() for a in input_accesses]
+    if self.is_dense():
+      # The pointer to receive dense output is the first argument to the
+      # execution engine.
+      arg_pointers = [self.dense_dst_ctype_pointer()] + input_pointers
+    else:
+      # The pointer to receive the sparse tensor output is the last argument
+      # to the execution engine and is a pointer to pointer of char.
+      arg_pointers = input_pointers + [
+          ctypes.pointer(ctypes.pointer(ctypes.c_char(0)))
+      ]
+
+    # Invoke the execution engine to run the module.
+    self._engine.invoke(_ENTRY_NAME, *arg_pointers)
+
+    # Retrieve the result.
     if self.is_dense():
+      result = runtime.ranked_memref_to_numpy(arg_pointers[0][0])
       assert isinstance(result, np.ndarray)
       self._dense_storage = result
     else:
-      self._set_packed_sparse_tensor(result)
+      self._set_packed_sparse_tensor(arg_pointers[-1][0])
+
+    self._assignment = None
+    self._engine = None
+
+  def evaluate(self) -> None:
+    """Evaluates the tensor assignment."""
+    self.compile()
+    self.compute()

   def _sync_value(self) -> None:
     """Updates the tensor value by evaluating the pending assignment."""
@@ -1444,29 +1501,31 @@
     linalg_funcop.func_op.attributes[
         "llvm.emit_c_interface"] = ir.UnitAttr.get()

-  def evaluate(
+  def get_input_accesses(self) -> List["Access"]:
+    """Compute the list of input accesses for the expression."""
+    input_accesses = []
+    self._visit(_gather_input_accesses_index_vars, (input_accesses,))
+    return input_accesses
+
+  def compile(
       self,
       dst: Tensor,
       dst_indices: Tuple[IndexVar, ...],
-  ) -> Union[np.ndarray, ctypes.c_void_p]:
-    """Evaluates tensor assignment dst[dst_indices] = expression.
+  ) -> execution_engine.ExecutionEngine:
+    """Compiles the tensor assignment dst[dst_indices] = expression.

     Args:
       dst: The destination tensor.
       dst_indices: The tuple of IndexVar used to access the destination tensor.

     Returns:
-      The result of the dense tensor represented in numpy ndarray or the pointer
-      to the MLIR sparse tensor.
+      The execution engine for the tensor assignment.

     Raises:
       ValueError: If the expression is not proper or not supported.
     """
     expr_to_info = self._validate_and_collect_expr_info(dst, dst_indices)
-
-    # Compute a list of input accesses.
-    input_accesses = []
-    self._visit(_gather_input_accesses_index_vars, (input_accesses,))
+    input_accesses = self.get_input_accesses()

     # Build and compile the module to produce the execution engine.
     with ir.Context(), ir.Location.unknown():
@@ -1475,29 +1534,7 @@
                                              input_accesses)
       engine = utils.compile_and_build_engine(module)

-    # Gather the pointers for the input buffers.
-    input_pointers = [a.tensor.ctype_pointer() for a in input_accesses]
-    if dst.is_dense():
-      # The pointer to receive dense output is the first argument to the
-      # execution engine.
-      arg_pointers = [dst.dense_dst_ctype_pointer()] + input_pointers
-    else:
-      # The pointer to receive sparse output is the last argument to the
-      # execution engine. The pointer to receive a sparse tensor output is a
-      # pointer to pointer of char.
-      arg_pointers = input_pointers + [
-          ctypes.pointer(ctypes.pointer(ctypes.c_char(0)))
-      ]
-
-    # Invoke the execution engine to run the module and return the result.
-    engine.invoke(_ENTRY_NAME, *arg_pointers)
-
-    if dst.is_dense():
-      return runtime.ranked_memref_to_numpy(arg_pointers[0][0])
-
-    # Return the sparse tensor pointer.
-    return arg_pointers[-1][0]
-
+    return engine


 @dataclasses.dataclass(frozen=True)
 class Access(IndexExpr):
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
@@ -50,11 +50,22 @@
   A.insert([1, 2], 6.0)
   B = mlir_pytaco.Tensor([I, J])
   B[i, j] = A[i, j]
+  passed = (B._assignment is not None)
+  passed += (B._engine is None)
+  try:
+    B.compute()
+  except ValueError as e:
+    passed += (str(e).startswith("Need to invoke compile"))
+  B.compile()
+  passed += (B._engine is not None)
+  B.compute()
+  passed += (B._assignment is None)
+  passed += (B._engine is None)
   indices, values = B.get_coordinates_and_values()
-  passed = np.array_equal(indices, [[0, 1], [1, 2]])
+  passed += np.array_equal(indices, [[0, 1], [1, 2]])
   passed += np.allclose(values, [5.0, 6.0])

-  # CHECK: Number of passed: 2
+  # CHECK: Number of passed: 8
   print("Number of passed:", passed)
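
For reference, below is a minimal sketch of how the compile/compute split introduced by this patch is driven from client code, mirroring the updated unit test. The sys.path-based import, the get_index_vars helper, the [2, 3] shape, and the reliance on the default tensor format accepting insert() are assumptions for illustration; only Tensor, insert, the index assignment syntax, compile(), compute(), evaluate(), and get_coordinates_and_values() are taken from the patch itself.

import numpy as np

# Assumption: the taco test directory is on sys.path so the tools package
# resolves, as the unit tests arrange before importing.
from tools import mlir_pytaco

# Index variables for a rank-2 assignment. get_index_vars is assumed to be
# the mlir_pytaco helper used for this purpose.
i, j = mlir_pytaco.get_index_vars(2)

# Source tensor with a couple of nonzeros; the [2, 3] shape is arbitrary and
# the default tensor format is assumed to support insert().
A = mlir_pytaco.Tensor([2, 3])
A.insert([0, 1], 5.0)
A.insert([1, 2], 6.0)

# Record the deferred assignment; nothing is compiled or executed yet.
B = mlir_pytaco.Tensor([2, 3])
B[i, j] = A[i, j]

# New two-step path: build the execution engine once, then run it. Per the
# patch, compute() before compile() raises ValueError, and a second compile()
# is a no-op unless force_recompile=True.
B.compile()
B.compute()

# evaluate() remains the one-shot path; it is now just compile() + compute().
indices, values = B.get_coordinates_and_values()
print("indices:", indices)
print("values:", values)

The cached engine lets a caller time or repeat compute() without paying for recompilation, which is why compile() only rebuilds when the assignment changes or force_recompile is passed.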