diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -25,15 +25,17 @@ def as_ctype(dtp):
     """Converts dtype to ctype."""
-    if dtp is np.dtype(np.complex128):
+    # Compare with `==`, not `is`: equal dtypes need not be the same object.
+    # https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
+    if dtp == np.dtype(np.complex128):
         return C128
-    if dtp is np.dtype(np.complex64):
+    if dtp == np.dtype(np.complex64):
         return C64
-    if dtp is np.dtype(np.float16):
+    if dtp == np.dtype(np.float16):
         return F16
     return np.ctypeslib.as_ctypes_type(dtp)
 
 
 def to_numpy(array):
     """Converts ctypes array back to numpy dtype array."""
     if array.dtype == C128: