def as_ctype(dtp):
    """Converts dtype to ctype.

    Maps a NumPy dtype to the ctypes type used to describe its elements.
    complex128, complex64 and float16 are mapped to the module's own ctypes
    types (C128, C64, F16). Every other dtype is delegated to
    ``np.ctypeslib.as_ctypes_type``. (Presumably this is because NumPy's
    ctypeslib offers no ctypes equivalent for complex or half-precision
    values; confirm against the NumPy version in use.)

    Args:
      dtp: the dtype to convert. Typically a ``np.dtype``. Anything that
        ``np.dtype.__eq__`` and ``np.ctypeslib.as_ctypes_type`` accept
        (e.g. a NumPy scalar type) also works.

    Returns:
      A ctypes type describing one element of ``dtp``.

    Raises:
      Whatever ``np.ctypeslib.as_ctypes_type`` raises for dtypes it cannot
      convert.
    """
    # Compare by equality, never identity. np.dtype(...) is not guaranteed to
    # return a cached singleton, so a dtype obtained some other way (e.g. from
    # an array) can be equal to np.dtype(np.complex128) without being the same
    # object. With `is`, such dtypes would silently skip the custom mappings.
    if dtp == np.dtype(np.complex128):
        return C128
    if dtp == np.dtype(np.complex64):
        return C64
    if dtp == np.dtype(np.float16):
        return F16
    return np.ctypeslib.as_ctypes_type(dtp)