diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py --- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py @@ -12,7 +12,9 @@ dense = pt.dense passed = 0 -all_types = [pt.int8, pt.int16, pt.int32, pt.int64, pt.float32, pt.float64] +all_types = [ + pt.int8, pt.int16, pt.int32, pt.int64, pt.float16, pt.float32, pt.float64 +] for t in all_types: i, j = pt.get_index_vars(2) A = pt.tensor([2, 3], dtype=t) @@ -29,5 +31,5 @@ passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]]) passed += np.allclose(values, [20.0, 10.0, 70.0]) -# CHECK: Number of passed: 18 +# CHECK: Number of passed: 21 print("Number of passed:", passed) diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py --- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py @@ -72,7 +72,7 @@ INT16 = np.int16 INT32 = np.int32 INT64 = np.int64 - # numpy _ctype_from_dtype_scalar can't handle np.float16 yet. + FLOAT16 = np.float16 FLOAT32 = np.float32 FLOAT64 = np.float64 COMPLEX64 = np.complex64 @@ -80,15 +80,15 @@ # All floating point type enums. -_FLOAT_TYPES = (Type.FLOAT32, Type.FLOAT64) +_FLOAT_TYPES = (Type.FLOAT16, Type.FLOAT32, Type.FLOAT64) # All integral type enums. _INT_TYPES = (Type.INT8, Type.INT16, Type.INT32, Type.INT64) # All complex type enums. _COMPLEX_TYPES = (Type.COMPLEX64, Type.COMPLEX128) # Type alias for any numpy type used to implement the runtime support for the # enum data types. -_AnyRuntimeType = Union[np.int8, np.int16, np.int32, np.int64, np.float32, - np.float64, np.complex64, np.complex128] +_AnyRuntimeType = Union[np.int8, np.int16, np.int32, np.int64, np.float16, + np.float32, np.float64, np.complex64, np.complex128] @dataclasses.dataclass(frozen=True) @@ -132,6 +132,7 @@ Type.INT16: "i16", Type.INT32: "i32", Type.INT64: "i64", + Type.FLOAT16: "f16", Type.FLOAT32: "f32", Type.FLOAT64: "f64", Type.COMPLEX64: "complex", @@ -147,6 +148,7 @@ np.int16: Type.INT16, np.int32: Type.INT32, np.int64: Type.INT64, + np.float16: Type.FLOAT16, np.float32: Type.FLOAT32, np.float64: Type.FLOAT64, np.complex64: Type.COMPLEX64, @@ -162,6 +164,7 @@ Type.INT16: ir.IntegerType.get_signless(16), Type.INT32: ir.IntegerType.get_signless(32), Type.INT64: ir.IntegerType.get_signless(64), + Type.FLOAT16: ir.F16Type.get(), Type.FLOAT32: ir.F32Type.get(), Type.FLOAT64: ir.F64Type.get(), Type.COMPLEX64: ir.ComplexType.get(ir.F32Type.get()), diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py --- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py @@ -39,6 +39,7 @@ int16 = mlir_pytaco.DType(mlir_pytaco.Type.INT16) int32 = mlir_pytaco.DType(mlir_pytaco.Type.INT32) int64 = mlir_pytaco.DType(mlir_pytaco.Type.INT64) +float16 = mlir_pytaco.DType(mlir_pytaco.Type.FLOAT16) float32 = mlir_pytaco.DType(mlir_pytaco.Type.FLOAT32) float64 = mlir_pytaco.DType(mlir_pytaco.Type.FLOAT64) complex64 = mlir_pytaco.DType(mlir_pytaco.Type.COMPLEX64) diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py --- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py @@ -89,6 +89,8 @@ c_lib.convertFromMLIRSparseTensorI32), (np.int64, c_lib.convertToMLIRSparseTensorI64, c_lib.convertFromMLIRSparseTensorI64), + (np.float16, c_lib.convertToMLIRSparseTensorF16, + c_lib.convertFromMLIRSparseTensorF16), (np.float32, c_lib.convertToMLIRSparseTensorF32, c_lib.convertFromMLIRSparseTensorF32), (np.float64, c_lib.convertToMLIRSparseTensorF64,