diff --git a/mlir/lib/Bindings/Python/mlir/runtime/__init__.py b/mlir/lib/Bindings/Python/mlir/runtime/__init__.py
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/runtime/__init__.py
@@ -0,0 +1 @@
+from .np_to_memref import *
diff --git a/mlir/lib/Bindings/Python/mlir/runtime/np_to_memref.py b/mlir/lib/Bindings/Python/mlir/runtime/np_to_memref.py
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/runtime/np_to_memref.py
@@ -0,0 +1,137 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# This file contains functions to convert between MemRefs and NumPy arrays.
+
+import numpy as np
+import ctypes
+
+
+def make_nd_memref_descriptor(rank, dtype):
+  """
+  Return a ctypes structure class for a ranked memref descriptor, rank > 0.
+
+  `rank` is the length of the shape and strides arrays and `dtype` is the
+  ctypes element type that the aligned pointer refers to.
+  """
+
+  class MemRefDescriptor(ctypes.Structure):
+    """
+    Build an empty descriptor for the given rank/dtype, where rank>0.
+    """
+
+    _fields_ = [
+        ("allocated", ctypes.c_longlong),
+        ("aligned", ctypes.POINTER(dtype)),
+        ("offset", ctypes.c_longlong),
+        ("shape", ctypes.c_longlong * rank),
+        ("strides", ctypes.c_longlong * rank),
+    ]
+
+  return MemRefDescriptor
+
+
+def make_zero_d_memref_descriptor(dtype):
+  """
+  Return a ctypes structure class for a rank-0 memref descriptor.
+
+  Rank-0 descriptors carry no shape or strides arrays.
+  """
+
+  class MemRefDescriptor(ctypes.Structure):
+    """
+    Build an empty descriptor for the given dtype, where rank=0.
+    """
+
+    _fields_ = [
+        ("allocated", ctypes.c_longlong),
+        ("aligned", ctypes.POINTER(dtype)),
+        ("offset", ctypes.c_longlong),
+    ]
+
+  return MemRefDescriptor
+
+
+class UnrankedMemRefDescriptor(ctypes.Structure):
+  """
+  ctypes structure for an unranked memref: the rank plus an untyped pointer
+  to the matching ranked descriptor.
+  """
+
+  _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)]
+
+
+def get_ranked_memref_descriptor(nparray):
+  """
+  Return a ranked memref descriptor for the given numpy array.
+
+  The descriptor points at the array's existing buffer; no data is copied.
+  """
+  ctp = np.ctypeslib.as_ctypes_type(nparray.dtype)
+  if nparray.ndim == 0:
+    x = make_zero_d_memref_descriptor(ctp)()
+    x.allocated = nparray.ctypes.data
+    x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
+    # The aligned pointer already addresses the first element and the offset
+    # is counted in elements, so it is zero rather than the item size.
+    x.offset = ctypes.c_longlong(0)
+    return x
+
+  x = make_nd_memref_descriptor(nparray.ndim, ctp)()
+  x.allocated = nparray.ctypes.data
+  x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
+  # Same as the rank-0 case: the data starts right at the aligned pointer.
+  x.offset = ctypes.c_longlong(0)
+  x.shape = nparray.ctypes.shape
+
+  # Numpy uses byte quantities to express strides, MLIR OTOH uses the
+  # torch abstraction which specifies strides in terms of elements.
+  strides_ctype_t = ctypes.c_longlong * nparray.ndim
+  x.strides = strides_ctype_t(*[s // nparray.itemsize for s in nparray.strides])
+  return x
+
+
+def get_unranked_memref_descriptor(nparray):
+  """
+  Return a generic/unranked memref descriptor for the given numpy array.
+  """
+  d = UnrankedMemRefDescriptor()
+  d.rank = nparray.ndim
+  x = get_ranked_memref_descriptor(nparray)
+  d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
+  return d
+
+
+def unranked_memref_to_numpy(unranked_memref, np_dtype):
+  """
+  Converts unranked memrefs to numpy arrays.
+
+  `unranked_memref` is a pointer to an UnrankedMemRefDescriptor and `np_dtype`
+  is the numpy element type used to interpret the underlying buffer.
+  """
+  descriptor = make_nd_memref_descriptor(
+      unranked_memref[0].rank, np.ctypeslib.as_ctypes_type(np_dtype)
+  )
+  val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor))
+  # Once the descriptor is typed, the ranked conversion applies unchanged.
+  return ranked_memref_to_numpy(val)
+
+
+def ranked_memref_to_numpy(ranked_memref):
+  """
+  Converts ranked memrefs to numpy arrays.
+
+  `ranked_memref` is a pointer to a ranked descriptor with rank > 0. The
+  result is built on the memref's buffer rather than on a copy.
+  NOTE: the descriptor's offset field is not read, so a zero offset is
+  assumed.
+  """
+  desc = ranked_memref[0]
+  np_arr = np.ctypeslib.as_array(desc.aligned, shape=desc.shape)
+  # Descriptor strides are in elements; numpy expects them in bytes.
+  return np.lib.stride_tricks.as_strided(
+      np_arr,
+      np.ctypeslib.as_array(desc.shape),
+      np.ctypeslib.as_array(desc.strides) * np_arr.itemsize,
+  )
diff --git a/mlir/test/Bindings/Python/execution_engine.py b/mlir/test/Bindings/Python/execution_engine.py
--- a/mlir/test/Bindings/Python/execution_engine.py
+++ b/mlir/test/Bindings/Python/execution_engine.py
@@ -4,6 +4,7 @@
 from mlir.ir import *
 from mlir.passmanager import *
 from mlir.execution_engine import *
+from mlir.runtime import *
 
 # Log everything to stderr and flush so that we have a unified stream to match
 # errors/info emitted by MLIR to stderr.
@@ -131,3 +132,123 @@
     log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]*2))
 
 run(testBasicCallback)
+
+# Test callback with an unranked memref
+# CHECK-LABEL: TEST: testUnrankedMemRefCallback
+def testUnrankedMemRefCallback():
+  # Define a callback function that takes an unranked memref, converts it to a numpy array and prints it.
+  @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
+  def callback(a):
+    arr = unranked_memref_to_numpy(a, np.float32)
+    print("Inside callback: ")
+    print(arr)
+
+  with Context():
+    # The module just forwards to a runtime function known as "some_callback_into_python".
+    module = Module.parse(
+        r"""
+func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
+  call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
+  return
+}
+func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
+"""
+    )
+    execution_engine = ExecutionEngine(lowerToLLVM(module))
+    execution_engine.register_runtime("some_callback_into_python", callback)
+    inp_arr = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
+    # CHECK: Inside callback:
+    # CHECK{LITERAL}: [[1. 2.]
+    # CHECK{LITERAL}:  [3. 4.]]
+    execution_engine.invoke(
+        "callback_memref",
+        ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(inp_arr))),
+    )
+    inp_arr_1 = np.array([5, 6, 7], dtype=np.float32)
+    strided_arr = np.lib.stride_tricks.as_strided(
+        inp_arr_1, strides=(4, 0), shape=(3, 4)
+    )
+    # CHECK: Inside callback:
+    # CHECK{LITERAL}: [[5. 5. 5. 5.]
+    # CHECK{LITERAL}:  [6. 6. 6. 6.]
+    # CHECK{LITERAL}:  [7. 7. 7. 7.]]
+    execution_engine.invoke(
+        "callback_memref",
+        ctypes.pointer(
+            ctypes.pointer(get_unranked_memref_descriptor(strided_arr))
+        ),
+    )
+
+run(testUnrankedMemRefCallback)
+
+# Test callback with a ranked memref.
+# CHECK-LABEL: TEST: testRankedMemRefCallback
+def testRankedMemRefCallback():
+  # Define a callback function that takes a ranked memref, converts it to a numpy array and prints it.
+  @ctypes.CFUNCTYPE(
+      None,
+      ctypes.POINTER(
+          make_nd_memref_descriptor(2, np.ctypeslib.as_ctypes_type(np.float32))
+      ),
+  )
+  def callback(a):
+    arr = ranked_memref_to_numpy(a)
+    print("Inside Callback: ")
+    print(arr)
+
+  with Context():
+    # The module just forwards to a runtime function known as "some_callback_into_python".
+    module = Module.parse(
+        r"""
+func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } {
+  call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> ()
+  return
+}
+func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface }
+"""
+    )
+    execution_engine = ExecutionEngine(lowerToLLVM(module))
+    execution_engine.register_runtime("some_callback_into_python", callback)
+    inp_arr = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32)
+    # CHECK: Inside Callback:
+    # CHECK{LITERAL}: [[1. 5.]
+    # CHECK{LITERAL}:  [6. 7.]]
+    execution_engine.invoke(
+        "callback_memref", ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr)))
+    )
+
+run(testRankedMemRefCallback)
+
+# Test addition of two memrefs (one rank-1, one rank-0).
+# CHECK-LABEL: TEST: testMemrefAdd
+def testMemrefAdd():
+  with Context():
+    module = Module.parse(
+        """
+      module {
+        func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
+          %0 = constant 0 : index
+          %1 = memref.load %arg0[%0] : memref<1xf32>
+          %2 = memref.load %arg1[] : memref<f32>
+          %3 = addf %1, %2 : f32
+          memref.store %3, %arg2[%0] : memref<1xf32>
+          return
+        }
+      } """
+    )
+    arg1 = np.array([32.5]).astype(np.float32)
+    arg2 = np.array(6).astype(np.float32)
+    res = np.array([0]).astype(np.float32)
+
+    arg1_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg1)))
+    arg2_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg2)))
+    res_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(res)))
+
+    execution_engine = ExecutionEngine(lowerToLLVM(module))
+    execution_engine.invoke(
+        "main", arg1_memref_ptr, arg2_memref_ptr, res_memref_ptr
+    )
+    # CHECK: [32.5] + 6.0 = [38.5]
+    log("{0} + {1} = {2}".format(arg1, arg2, res))
+
+run(testMemrefAdd)