diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py --- a/mlir/python/mlir/runtime/np_to_memref.py +++ b/mlir/python/mlir/runtime/np_to_memref.py @@ -8,112 +8,121 @@ import ctypes +class C128(ctypes.Structure): + """A ctype representation for MLIR's Double Complex.""" + _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)] + + +class C64(ctypes.Structure): + """A ctype representation for MLIR's Float Complex.""" + _fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)] + + +def as_ctype(dtp): + """Converts dtype to ctype.""" + if dtp is np.dtype(np.complex128): + return C128 + if dtp is np.dtype(np.complex64): + return C64 + return np.ctypeslib.as_ctypes_type(dtp) + + def make_nd_memref_descriptor(rank, dtype): - class MemRefDescriptor(ctypes.Structure): - """ - Build an empty descriptor for the given rank/dtype, where rank>0. - """ - _fields_ = [ - ("allocated", ctypes.c_longlong), - ("aligned", ctypes.POINTER(dtype)), - ("offset", ctypes.c_longlong), - ("shape", ctypes.c_longlong * rank), - ("strides", ctypes.c_longlong * rank), - ] + class MemRefDescriptor(ctypes.Structure): + """Builds an empty descriptor for the given rank/dtype, where rank>0.""" - return MemRefDescriptor + _fields_ = [ + ("allocated", ctypes.c_longlong), + ("aligned", ctypes.POINTER(dtype)), + ("offset", ctypes.c_longlong), + ("shape", ctypes.c_longlong * rank), + ("strides", ctypes.c_longlong * rank), + ] + + return MemRefDescriptor def make_zero_d_memref_descriptor(dtype): - class MemRefDescriptor(ctypes.Structure): - """ - Build an empty descriptor for the given dtype, where rank=0. - """ - _fields_ = [ - ("allocated", ctypes.c_longlong), - ("aligned", ctypes.POINTER(dtype)), - ("offset", ctypes.c_longlong), - ] + class MemRefDescriptor(ctypes.Structure): + """Builds an empty descriptor for the given dtype, where rank=0.""" - return MemRefDescriptor + _fields_ = [ + ("allocated", ctypes.c_longlong), + ("aligned", ctypes.POINTER(dtype)), + ("offset", ctypes.c_longlong), + ] + return MemRefDescriptor -class UnrankedMemRefDescriptor(ctypes.Structure): - """ Creates a ctype struct for memref descriptor""" - _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)] +class UnrankedMemRefDescriptor(ctypes.Structure): + """Creates a ctype struct for memref descriptor""" + _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)] def get_ranked_memref_descriptor(nparray): - """ - Return a ranked memref descriptor for the given numpy array. - """ - if nparray.ndim == 0: - x = make_zero_d_memref_descriptor(np.ctypeslib.as_ctypes_type(nparray.dtype))() - x.allocated = nparray.ctypes.data - x.aligned = nparray.ctypes.data_as( - ctypes.POINTER(np.ctypeslib.as_ctypes_type(nparray.dtype)) - ) - x.offset = ctypes.c_longlong(0) - return x - - x = make_nd_memref_descriptor( - nparray.ndim, np.ctypeslib.as_ctypes_type(nparray.dtype) - )() + """Returns a ranked memref descriptor for the given numpy array.""" + ctp = as_ctype(nparray.dtype) + if nparray.ndim == 0: + x = make_zero_d_memref_descriptor(ctp)() x.allocated = nparray.ctypes.data - x.aligned = nparray.ctypes.data_as( - ctypes.POINTER(np.ctypeslib.as_ctypes_type(nparray.dtype)) - ) + x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp)) x.offset = ctypes.c_longlong(0) - x.shape = nparray.ctypes.shape - - # Numpy uses byte quantities to express strides, MLIR OTOH uses the - # torch abstraction which specifies strides in terms of elements. - strides_ctype_t = ctypes.c_longlong * nparray.ndim - x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides]) return x + x = make_nd_memref_descriptor(nparray.ndim, ctp)() + x.allocated = nparray.ctypes.data + x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp)) + x.offset = ctypes.c_longlong(0) + x.shape = nparray.ctypes.shape + + # Numpy uses byte quantities to express strides, MLIR OTOH uses the + # torch abstraction which specifies strides in terms of elements. + strides_ctype_t = ctypes.c_longlong * nparray.ndim + x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides]) + return x + def get_unranked_memref_descriptor(nparray): - """ - Return a generic/unranked memref descriptor for the given numpy array. - """ - d = UnrankedMemRefDescriptor() - d.rank = nparray.ndim - x = get_ranked_memref_descriptor(nparray) - d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p) - return d + """Returns a generic/unranked memref descriptor for the given numpy array.""" + d = UnrankedMemRefDescriptor() + d.rank = nparray.ndim + x = get_ranked_memref_descriptor(nparray) + d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p) + return d def unranked_memref_to_numpy(unranked_memref, np_dtype): - """ - Converts unranked memrefs to numpy arrays. - """ - descriptor = make_nd_memref_descriptor( - unranked_memref[0].rank, np.ctypeslib.as_ctypes_type(np_dtype) - ) - val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor)) - np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape) - strided_arr = np.lib.stride_tricks.as_strided( - np_arr, - np.ctypeslib.as_array(val[0].shape), - np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize, - ) - return strided_arr + """Converts unranked memrefs to numpy arrays.""" + ctp = as_ctype(np_dtype) + descriptor = make_nd_memref_descriptor(unranked_memref[0].rank, ctp) + val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor)) + np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape) + strided_arr = np.lib.stride_tricks.as_strided( + np_arr, + np.ctypeslib.as_array(val[0].shape), + np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize, + ) + if strided_arr.dtype == C128: + return strided_arr.view("complex128") + if strided_arr.dtype == C64: + return strided_arr.view("complex64") + return strided_arr def ranked_memref_to_numpy(ranked_memref): - """ - Converts ranked memrefs to numpy arrays. - """ - np_arr = np.ctypeslib.as_array( - ranked_memref[0].aligned, shape=ranked_memref[0].shape - ) - strided_arr = np.lib.stride_tricks.as_strided( - np_arr, - np.ctypeslib.as_array(ranked_memref[0].shape), - np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize, - ) - return strided_arr + """Converts ranked memrefs to numpy arrays.""" + np_arr = np.ctypeslib.as_array( + ranked_memref[0].aligned, shape=ranked_memref[0].shape) + strided_arr = np.lib.stride_tricks.as_strided( + np_arr, + np.ctypeslib.as_array(ranked_memref[0].shape), + np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize, + ) + if strided_arr.dtype == C128: + return strided_arr.view("complex128") + if strided_arr.dtype == C64: + return strided_arr.view("complex64") + return strided_arr diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py --- a/mlir/test/python/execution_engine.py +++ b/mlir/test/python/execution_engine.py @@ -64,7 +64,7 @@ def lowerToLLVM(module): import mlir.conversions pm = PassManager.parse( - "convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts") + "convert-complex-to-llvm,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts") pm.run(module) return module @@ -266,6 +266,102 @@ run(testMemrefAdd) +# Test addition of two complex memrefs +# CHECK-LABEL: TEST: testComplexMemrefAdd +def testComplexMemrefAdd(): + with Context(): + module = Module.parse(""" + module { + func.func @main(%arg0: memref<1xcomplex>, + %arg1: memref<1xcomplex>, + %arg2: memref<1xcomplex>) attributes { llvm.emit_c_interface } { + %0 = arith.constant 0 : index + %1 = memref.load %arg0[%0] : memref<1xcomplex> + %2 = memref.load %arg1[%0] : memref<1xcomplex> + %3 = complex.add %1, %2 : complex + memref.store %3, %arg2[%0] : memref<1xcomplex> + return + } + } """) + + arg1 = np.array([1.+2.j]).astype(np.complex128) + arg2 = np.array([3.+4.j]).astype(np.complex128) + arg3 = np.array([0.+0.j]).astype(np.complex128) + + arg1_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(arg1))) + arg2_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(arg2))) + arg3_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(arg3))) + + execution_engine = ExecutionEngine(lowerToLLVM(module)) + execution_engine.invoke("main", + arg1_memref_ptr, + arg2_memref_ptr, + arg3_memref_ptr) + # CHECK: [1.+2.j] + [3.+4.j] = [4.+6.j] + log("{0} + {1} = {2}".format(arg1, arg2, arg3)) + + # test to-numpy utility + # CHECK: [4.+6.j] + npout = ranked_memref_to_numpy(arg3_memref_ptr[0]) + log(npout) + + +run(testComplexMemrefAdd) + + +# Test addition of two complex unranked memrefs +# CHECK-LABEL: TEST: testComplexUnrankedMemrefAdd +def testComplexUnrankedMemrefAdd(): + with Context(): + module = Module.parse(""" + module { + func.func @main(%arg0: memref<*xcomplex>, + %arg1: memref<*xcomplex>, + %arg2: memref<*xcomplex>) attributes { llvm.emit_c_interface } { + %A = memref.cast %arg0 : memref<*xcomplex> to memref<1xcomplex> + %B = memref.cast %arg1 : memref<*xcomplex> to memref<1xcomplex> + %C = memref.cast %arg2 : memref<*xcomplex> to memref<1xcomplex> + %0 = arith.constant 0 : index + %1 = memref.load %A[%0] : memref<1xcomplex> + %2 = memref.load %B[%0] : memref<1xcomplex> + %3 = complex.add %1, %2 : complex + memref.store %3, %C[%0] : memref<1xcomplex> + return + } + } """) + + arg1 = np.array([5.+6.j]).astype(np.complex64) + arg2 = np.array([7.+8.j]).astype(np.complex64) + arg3 = np.array([0.+0.j]).astype(np.complex64) + + arg1_memref_ptr = ctypes.pointer( + ctypes.pointer(get_unranked_memref_descriptor(arg1))) + arg2_memref_ptr = ctypes.pointer( + ctypes.pointer(get_unranked_memref_descriptor(arg2))) + arg3_memref_ptr = ctypes.pointer( + ctypes.pointer(get_unranked_memref_descriptor(arg3))) + + execution_engine = ExecutionEngine(lowerToLLVM(module)) + execution_engine.invoke("main", + arg1_memref_ptr, + arg2_memref_ptr, + arg3_memref_ptr) + # CHECK: [5.+6.j] + [7.+8.j] = [12.+14.j] + log("{0} + {1} = {2}".format(arg1, arg2, arg3)) + + # test to-numpy utility + # CHECK: [12.+14.j] + npout = unranked_memref_to_numpy(arg3_memref_ptr[0], + np.dtype(np.complex64)) + log(npout) + + +run(testComplexUnrankedMemrefAdd) + + # Test addition of two 2d_memref # CHECK-LABEL: TEST: testDynamicMemrefAdd2D def testDynamicMemrefAdd2D():