diff --git a/mlir/test/Bindings/Python/execution_engine.py b/mlir/test/Bindings/Python/execution_engine.py
--- a/mlir/test/Bindings/Python/execution_engine.py
+++ b/mlir/test/Bindings/Python/execution_engine.py
@@ -4,6 +4,7 @@
 from mlir.ir import *
 from mlir.passmanager import *
 from mlir.execution_engine import *
+import numpy as np
 
 # Log everything to stderr and flush so that we have a unified stream to match
 # errors/info emitted by MLIR to stderr.
@@ -131,3 +132,110 @@
     log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]*2))
 
 run(testBasicCallback)
+
+
+def make_rankedmemref_descriptor(rank, dtype):
+  """Build an empty ctypes descriptor type for a ranked memref.
+
+  The fields are, in order: allocated and aligned pointers, an element
+  offset, then `rank` sizes and `rank` strides (all counted in elements).
+  """
+  fields = [
+      ("allocated", ctypes.c_longlong),
+      ("aligned", ctypes.POINTER(dtype)),
+      ("offset", ctypes.c_longlong),
+      ("shape", ctypes.c_longlong * rank),
+      ("strides", ctypes.c_longlong * rank),
+  ]
+  return type("Memref_" + str(rank) + "D", (ctypes.Structure,),
+              {"_fields_": fields})
+
+
+def to_memref(nparray):
+  """Return a ranked memref descriptor viewing the data of `nparray`.
+
+  The descriptor only borrows the numpy buffer: `nparray` must stay alive
+  while the descriptor is in use.
+  """
+  ctp = np.ctypeslib.as_ctypes_type(nparray.dtype)
+  x = make_rankedmemref_descriptor(nparray.ndim, ctp)()
+  x.allocated = nparray.ctypes.data
+  # Use the descriptor's own element type rather than assuming f32.
+  x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
+  # The offset is counted in elements from `aligned`, and the numpy data
+  # pointer already addresses the first element.
+  x.offset = ctypes.c_longlong(0)
+  x.shape = nparray.ctypes.shape
+  # numpy strides are in bytes; memref strides are in elements.
+  strides_ctype_t = ctypes.c_longlong * nparray.ndim
+  x.strides = strides_ctype_t(
+      *[stride // nparray.itemsize for stride in nparray.strides])
+  return x
+
+
+class UnrankedMemRefDescriptor(ctypes.Structure):
+  """ctypes mirror of an unranked memref descriptor.
+
+  `descriptor` is an opaque pointer to a ranked descriptor of rank `rank`.
+  """
+  _fields_ = [
+      ("rank", ctypes.c_longlong),
+      ("descriptor", ctypes.c_void_p),
+  ]
+
+
+def make_unrankedmemref_descriptor(nparray):
+  """Return an unranked memref descriptor for the given numpy array.
+
+  Like `to_memref`, the result only borrows `nparray`'s buffer, so
+  `nparray` must stay alive while the descriptor is in use.
+  """
+  d = UnrankedMemRefDescriptor()
+  d.rank = nparray.ndim
+  x = to_memref(nparray)
+  d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
+  return d
+
+
+# Test callback with a memref
+# CHECK-LABEL: TEST: testMemRefCallback
+def testMemRefCallback():
+  # Define a callback function that takes a memref, converts it to a numpy
+  # array and logs it.
+  @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
+  def callback(a):
+    # View the opaque descriptor as a ranked f32 descriptor, matching the
+    # memref<*xf32> signature of the MLIR function below.
+    d = make_rankedmemref_descriptor(a[0].rank, ctypes.c_float)
+    x = ctypes.cast(a[0].descriptor, ctypes.POINTER(d))
+    # NOTE: ignores offset/strides; fine for the contiguous, zero-offset
+    # input built below.
+    arr = np.ctypeslib.as_array(x[0].aligned, shape=x[0].shape)
+    # Go through log() (stderr + flush) so the output stays ordered with the
+    # rest of the stream that is matched.
+    log("Inside callback: ")
+    log(arr)
+
+  with Context():
+    # The module just forwards to a runtime function known as "some_callback_into_python".
+    module = Module.parse(r"""
+func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
+  call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
+  return
+}
+func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
+""")
+    execution_engine = ExecutionEngine(lowerToLLVM(module))
+    execution_engine.register_runtime("some_callback_into_python", callback)
+    inp_arr = np.array([[1., 2.], [3., 4.]], np.float32)
+    # numpy prints this array as two lines, "[[1. 2.]" and " [3. 4.]]";
+    # {LITERAL} keeps the matcher from treating "[[" as a substitution.
+    # CHECK: Inside callback:
+    # CHECK{LITERAL}: [[1. 2.]
+    # CHECK{LITERAL}: [3. 4.]]
+    execution_engine.invoke(
+        "callback_memref",
+        ctypes.pointer(ctypes.pointer(make_unrankedmemref_descriptor(inp_arr))))
+
+
+run(testMemRefCallback)