diff --git a/mlir/lib/Bindings/Python/mlir/runtime/__init__.py b/mlir/lib/Bindings/Python/mlir/runtime/__init__.py
new file mode 100644
diff --git a/mlir/lib/Bindings/Python/mlir/runtime/memref/__init__.py b/mlir/lib/Bindings/Python/mlir/runtime/memref/__init__.py
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/runtime/memref/__init__.py
@@ -0,0 +1 @@
+from .np_to_memref import *
diff --git a/mlir/lib/Bindings/Python/mlir/runtime/memref/np_to_memref.py b/mlir/lib/Bindings/Python/mlir/runtime/memref/np_to_memref.py
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/runtime/memref/np_to_memref.py
@@ -0,0 +1,106 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import numpy as np
+import ctypes
+
+
+def make_nd_memref_descriptor(rank, dtype):
+    """Return a ctypes Structure type for a memref of the given rank/dtype.
+
+    The fields are allocated, aligned, offset, shape[rank] and strides[rank],
+    in that order; the order must match the descriptor layout that functions
+    compiled with `llvm.emit_c_interface` expect. Intended for rank > 0; use
+    `make_zero_d_memref_descriptor` for rank 0.
+    """
+
+    class MemRefDescriptor(ctypes.Structure):
+        """Ranked memref descriptor for the given rank/dtype, where rank>0."""
+
+        _fields_ = [
+            ("allocated", ctypes.c_longlong),
+            ("aligned", ctypes.POINTER(dtype)),
+            ("offset", ctypes.c_longlong),
+            ("shape", ctypes.c_longlong * rank),
+            ("strides", ctypes.c_longlong * rank),
+        ]
+
+    return MemRefDescriptor
+
+
+def make_zero_d_memref_descriptor(dtype):
+    """Return a ctypes Structure type for a rank-0 memref of the given dtype.
+
+    A rank-0 descriptor has no shape or strides, only allocated, aligned and
+    offset.
+    """
+
+    class MemRefDescriptor(ctypes.Structure):
+        """Ranked memref descriptor for the given dtype, where rank=0."""
+
+        _fields_ = [
+            ("allocated", ctypes.c_longlong),
+            ("aligned", ctypes.POINTER(dtype)),
+            ("offset", ctypes.c_longlong),
+        ]
+
+    return MemRefDescriptor
+
+
+class UnrankedMemRefDescriptor(ctypes.Structure):
+    """Unranked memref descriptor: rank plus pointer to a ranked descriptor."""
+
+    _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)]
+
+
+def to_memref(nparray):
+    """Return a ranked memref descriptor that views the data of `nparray`.
+
+    No data is copied: `aligned` points at the array's first element, so the
+    offset is 0, and numpy's byte strides are converted to the element
+    strides used by memref descriptors. The descriptor holds a reference to
+    `nparray` so the buffer stays alive at least as long as the descriptor.
+
+    Raises:
+        ValueError: if a stride is not a whole multiple of the element size.
+    """
+    ctp = np.ctypeslib.as_ctypes_type(nparray.dtype)
+    if nparray.ndim == 0:
+        x = make_zero_d_memref_descriptor(ctp)()
+    else:
+        itemsize = nparray.itemsize
+        if any(stride % itemsize for stride in nparray.strides):
+            raise ValueError(
+                "strides {} are not multiples of the element size {}".format(
+                    nparray.strides, itemsize
+                )
+            )
+        x = make_nd_memref_descriptor(nparray.ndim, ctp)()
+        dims_t = ctypes.c_longlong * nparray.ndim
+        x.shape = dims_t(*nparray.shape)
+        # numpy strides are in bytes; memref strides are in elements.
+        x.strides = dims_t(*(stride // itemsize for stride in nparray.strides))
+    x.allocated = nparray.ctypes.data
+    x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
+    # The offset is counted in elements from `aligned`, which already points
+    # at the first element, so it is always 0 (not the element size).
+    x.offset = ctypes.c_longlong(0)
+    # The raw pointers above do not keep the array alive; hold it explicitly.
+    x._nparray = nparray
+    return x
+
+
+def make_unranked_memref_descriptor(nparray):
+    """Return an unranked memref descriptor for the given numpy array.
+
+    The result points at a ranked descriptor built by `to_memref` and holds a
+    Python reference to it, so the pointee lives as long as the result.
+    """
+    d = UnrankedMemRefDescriptor()
+    d.rank = nparray.ndim
+    x = to_memref(nparray)
+    d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
+    # `descriptor` is a raw void*; keep the ranked descriptor reachable.
+    d._ranked_descriptor = x
+    return d
diff --git a/mlir/test/Bindings/Python/execution_engine.py b/mlir/test/Bindings/Python/execution_engine.py
--- a/mlir/test/Bindings/Python/execution_engine.py
+++ b/mlir/test/Bindings/Python/execution_engine.py
@@ -4,6 +4,7 @@
 from mlir.ir import *
 from mlir.passmanager import *
 from mlir.execution_engine import *
+from mlir.runtime.memref import *
 
 # Log everything to stderr and flush so that we have a unified stream to match
 # errors/info emitted by MLIR to stderr.
@@ -131,3 +132,67 @@
     log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]*2))
 
 run(testBasicCallback)
+
+# Test callback with a memref
+# CHECK-LABEL: TEST: testMemRefCallback
+def testMemRefCallback():
+  # Define a callback function that takes a memref, converts it to a numpy array and prints it.
+  @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
+  def callback(a):
+    d = make_nd_memref_descriptor(a[0].rank, ctypes.c_float)
+    x = ctypes.cast(a[0].descriptor, ctypes.POINTER(d))
+    arr = np.ctypeslib.as_array(x[0].aligned, shape=x[0].shape)
+    # Use log rather than print so the output shares the flushed stderr stream.
+    log("Inside callback: ")
+    log(arr)
+
+  with Context():
+    # The module just forwards to a runtime function known as "some_callback_into_python".
+    module = Module.parse(r"""
+func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
+  call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
+  return
+}
+func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
+""")
+    execution_engine = ExecutionEngine(lowerToLLVM(module))
+    execution_engine.register_runtime("some_callback_into_python", callback)
+    inp_arr = np.array([[1.,2.],[3.,4.]], np.float32)
+    # CHECK: Inside callback:
+    # CHECK{LITERAL}: [[1. 2.]
+    # CHECK{LITERAL}:  [3. 4.]]
+    execution_engine.invoke("callback_memref", ctypes.pointer(ctypes.pointer(make_unranked_memref_descriptor(inp_arr))))
+
+run(testMemRefCallback)
+
+# Test invoking a function with ranked memref arguments, including a 0-d one.
+# CHECK-LABEL: TEST: testInvokeMemrefAdd
+def testInvokeMemrefAdd():
+  with Context():
+    module = Module.parse(
+      """
+      module {
+        func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
+          %0 = constant 0 : index
+          %1 = memref.load %arg0[%0] : memref<1xf32>
+          %2 = memref.load %arg1[] : memref<f32>
+          %3 = addf %1, %2 : f32
+          memref.store %3, %arg2[%0] : memref<1xf32>
+          return
+        }
+      } """
+    )
+    arg1 = np.array([32.5]).astype(np.float32)
+    arg2 = np.array(6).astype(np.float32)
+    res = np.array([0]).astype(np.float32)
+
+    arg1_memref_ptr = ctypes.pointer(ctypes.pointer(to_memref(arg1)))
+    arg2_memref_ptr = ctypes.pointer(ctypes.pointer(to_memref(arg2)))
+    res_memref_ptr = ctypes.pointer(ctypes.pointer(to_memref(res)))
+
+    execution_engine = ExecutionEngine(lowerToLLVM(module))
+    execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr, res_memref_ptr)
+    # CHECK: [32.5] + 6.0 = [38.5]
+    log("{0} + {1} = {2}".format(arg1, arg2, res))
+
+run(testInvokeMemrefAdd)