# Dtypes are compared with ``==`` rather than ``is``: two equal dtypes are not
# guaranteed to be the same object, so an identity test can miss a match.
# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
def as_ctype(dtp):
    """Map a numpy dtype to the ctypes type used to represent it.

    complex128, complex64 and float16 resolve to the C128, C64 and F16 types
    defined elsewhere in this module. Any other dtype is handed to
    ``np.ctypeslib.as_ctypes_type`` unchanged.
    """
    # Checked in order; the first dtype comparing equal to ``dtp`` wins.
    special_cases = (
        (np.dtype(np.complex128), C128),
        (np.dtype(np.complex64), C64),
        (np.dtype(np.float16), F16),
    )
    for candidate, ctype in special_cases:
        if dtp == candidate:
            return ctype
    return np.ctypeslib.as_ctypes_type(dtp)