Differential D135004 Diff 464452 mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
Changeset View
Changeset View
Standalone View
Standalone View
mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
Show First 20 Lines • Show All 347 Lines • ▼ Show 20 Lines | class Format: | ||||
def rank(self) -> int: | def rank(self) -> int: | ||||
"""Returns the number of dimensions represented by the format.""" | """Returns the number of dimensions represented by the format.""" | ||||
return self.format_pack.rank() | return self.format_pack.rank() | ||||
def get_permutation_and_sparsity(self) -> Tuple[np.ndarray, np.ndarray]: | def get_permutation_and_sparsity(self) -> Tuple[np.ndarray, np.ndarray]: | ||||
"""Constructs the numpy arrays for the permutation and sparsity.""" | """Constructs the numpy arrays for the permutation and sparsity.""" | ||||
perm = np.array(self.ordering.ordering, dtype=np.ulonglong) | perm = np.array(self.ordering.ordering, dtype=np.ulonglong) | ||||
a = [0 if s == ModeFormat.DENSE else 1 for s in self.format_pack.formats] | # FIXME(bixia): these magic numbers must be kept in sync with the | ||||
# definition of the enum on the C/C++ side. How can we convert a | |||||
# `ModeFormat` into its underlying `uint8_t` value so we can just say | |||||
# `convert(ModeFormat.DENSE)` and `convert(ModeFormat.COMPRESSED)`? | |||||
a = [4 if s == ModeFormat.DENSE else 8 for s in self.format_pack.formats] | |||||
sparse = np.array(a, dtype=np.uint8) | sparse = np.array(a, dtype=np.uint8) | ||||
return (perm, sparse) | return (perm, sparse) | ||||
def mlir_tensor_attr(self) -> Optional[sparse_tensor.EncodingAttr]: | def mlir_tensor_attr(self) -> Optional[sparse_tensor.EncodingAttr]: | ||||
"""Constructs the MLIR attributes for the tensor format.""" | """Constructs the MLIR attributes for the tensor format.""" | ||||
order = ( | order = ( | ||||
range(self.rank()) if | range(self.rank()) if | ||||
(self.ordering is None) else self.ordering.ordering) | (self.ordering is None) else self.ordering.ordering) | ||||
▲ Show 20 Lines • Show All 1,831 Lines • Show Last 20 Lines |