diff --git a/mlir/lib/Bindings/Python/mlir/runtime/__init__.py b/mlir/lib/Bindings/Python/mlir/runtime/__init__.py
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/runtime/__init__.py
@@ -0,0 +1 @@
+from .np_to_memref import *
diff --git a/mlir/lib/Bindings/Python/mlir/runtime/np_to_memref.py b/mlir/lib/Bindings/Python/mlir/runtime/np_to_memref.py
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/runtime/np_to_memref.py
@@ -0,0 +1,140 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# This file contains functions to convert between memref descriptors and NumPy
+# arrays.
+
+import numpy as np
+import ctypes
+
+
+def make_nd_memref_descriptor(rank, dtype):
+    """
+    Return a ctypes.Structure class for a memref descriptor, where rank>0.
+
+    `rank` sets the length of the shape and strides arrays; `dtype` is the
+    ctypes element type that the aligned pointer points to.
+    """
+
+    class MemRefDescriptor(ctypes.Structure):
+        """
+        Memref descriptor layout for the given rank/dtype, where rank>0.
+        """
+
+        _fields_ = [
+            ("allocated", ctypes.c_longlong),
+            ("aligned", ctypes.POINTER(dtype)),
+            ("offset", ctypes.c_longlong),
+            ("shape", ctypes.c_longlong * rank),
+            ("strides", ctypes.c_longlong * rank),
+        ]
+
+    return MemRefDescriptor
+
+
+def make_zero_d_memref_descriptor(dtype):
+    """
+    Return a ctypes.Structure class for a memref descriptor, where rank=0.
+
+    The layout has no shape or strides fields; `dtype` is the ctypes element
+    type that the aligned pointer points to.
+    """
+
+    class MemRefDescriptor(ctypes.Structure):
+        """
+        Memref descriptor layout for the given dtype, where rank=0.
+        """
+
+        _fields_ = [
+            ("allocated", ctypes.c_longlong),
+            ("aligned", ctypes.POINTER(dtype)),
+            ("offset", ctypes.c_longlong),
+        ]
+
+    return MemRefDescriptor
+
+
+class UnrankedMemRefDescriptor(ctypes.Structure):
+    """
+    ctypes struct for an unranked memref: a rank and an untyped pointer that
+    get_unranked_memref_descriptor points at a ranked descriptor.
+    """
+
+    _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)]
+
+
+def get_ranked_memref_descriptor(nparray):
+    """
+    Return a ranked memref descriptor for the given numpy array.
+
+    The descriptor records the address of the array's buffer rather than a
+    copy of its data, with offset 0 and strides converted to element counts.
+    """
+    ctp = np.ctypeslib.as_ctypes_type(nparray.dtype)
+    if nparray.ndim == 0:
+        x = make_zero_d_memref_descriptor(ctp)()
+        x.allocated = nparray.ctypes.data
+        x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
+        x.offset = ctypes.c_longlong(0)
+        return x
+
+    x = make_nd_memref_descriptor(nparray.ndim, ctp)()
+    x.allocated = nparray.ctypes.data
+    x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
+    x.offset = ctypes.c_longlong(0)
+    x.shape = nparray.ctypes.shape
+
+    # Numpy uses byte quantities to express strides, MLIR OTOH uses the
+    # torch abstraction which specifies strides in terms of elements.
+    strides_ctype_t = ctypes.c_longlong * nparray.ndim
+    x.strides = strides_ctype_t(*[s // nparray.itemsize for s in nparray.strides])
+    return x
+
+
+def get_unranked_memref_descriptor(nparray):
+    """
+    Return a generic/unranked memref descriptor for the given numpy array.
+
+    The result holds the array's rank and a void pointer to the ranked
+    descriptor built by get_ranked_memref_descriptor.
+    """
+    d = UnrankedMemRefDescriptor()
+    d.rank = nparray.ndim
+    x = get_ranked_memref_descriptor(nparray)
+    d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
+    return d
+
+
+def unranked_memref_to_numpy(unranked_memref, np_dtype):
+    """
+    Converts unranked memrefs to numpy arrays.
+
+    `unranked_memref[0]` must expose the rank and descriptor fields; `np_dtype`
+    is the element type used to reinterpret the untyped descriptor pointer.
+    """
+    descriptor = make_nd_memref_descriptor(
+        unranked_memref[0].rank, np.ctypeslib.as_ctypes_type(np_dtype)
+    )
+    val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor))
+    # With the descriptor typed, the conversion is the same as the ranked one.
+    return ranked_memref_to_numpy(val)
+
+
+def ranked_memref_to_numpy(ranked_memref):
+    """
+    Converts ranked memrefs to numpy arrays.
+
+    `ranked_memref[0]` must expose the aligned, shape and strides fields. The
+    array is built over the memory at the aligned pointer; the descriptor's
+    offset field is not read.
+    """
+    desc = ranked_memref[0]
+    np_arr = np.ctypeslib.as_array(desc.aligned, shape=desc.shape)
+    # Descriptor strides are in elements; numpy expects them in bytes.
+    strided_arr = np.lib.stride_tricks.as_strided(
+        np_arr,
+        np.ctypeslib.as_array(desc.shape),
+        np.ctypeslib.as_array(desc.strides) * np_arr.itemsize,
+    )
+    return strided_arr
diff --git a/mlir/test/Bindings/Python/execution_engine.py b/mlir/test/Bindings/Python/execution_engine.py
--- a/mlir/test/Bindings/Python/execution_engine.py
+++ b/mlir/test/Bindings/Python/execution_engine.py
@@ -4,6 +4,7 @@
 from mlir.ir import *
 from mlir.passmanager import *
 from mlir.execution_engine import *
+from mlir.runtime import *

 # Log everything to stderr and flush so that we have a unified stream to match
 # errors/info emitted by MLIR to stderr.
@@ -131,3 +132,179 @@
     log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]*2))

 run(testBasicCallback)
+
+# Test callback with an unranked memref.
+# CHECK-LABEL: TEST: testUnrankedMemRefCallback
+def testUnrankedMemRefCallback():
+  # Define a callback function that takes an unranked memref, converts it to a numpy array and prints it.
+  @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
+  def callback(a):
+    arr = unranked_memref_to_numpy(a, np.float32)
+    log("Inside callback: ")
+    log(arr)
+
+  with Context():
+    # The module just forwards to a runtime function known as "some_callback_into_python".
+    module = Module.parse(
+        r"""
+func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
+  call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
+  return
+}
+func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
+"""
+    )
+    execution_engine = ExecutionEngine(lowerToLLVM(module))
+    execution_engine.register_runtime("some_callback_into_python", callback)
+    inp_arr = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
+    # CHECK: Inside callback:
+    # CHECK{LITERAL}: [[1. 2.]
+    # CHECK{LITERAL}: [3. 4.]]
+    execution_engine.invoke(
+        "callback_memref",
+        ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(inp_arr))),
+    )
+    inp_arr_1 = np.array([5, 6, 7], dtype=np.float32)
+    strided_arr = np.lib.stride_tricks.as_strided(
+        inp_arr_1, strides=(4, 0), shape=(3, 4)
+    )
+    # CHECK: Inside callback:
+    # CHECK{LITERAL}: [[5. 5. 5. 5.]
+    # CHECK{LITERAL}: [6. 6. 6. 6.]
+    # CHECK{LITERAL}: [7. 7. 7. 7.]]
+    execution_engine.invoke(
+        "callback_memref",
+        ctypes.pointer(
+            ctypes.pointer(get_unranked_memref_descriptor(strided_arr))
+        ),
+    )
+
+run(testUnrankedMemRefCallback)
+
+# Test callback with a ranked memref.
+# CHECK-LABEL: TEST: testRankedMemRefCallback
+def testRankedMemRefCallback():
+  # Define a callback function that takes a ranked memref, converts it to a numpy array and prints it.
+  @ctypes.CFUNCTYPE(
+      None,
+      ctypes.POINTER(
+          make_nd_memref_descriptor(2, np.ctypeslib.as_ctypes_type(np.float32))
+      ),
+  )
+  def callback(a):
+    arr = ranked_memref_to_numpy(a)
+    log("Inside Callback: ")
+    log(arr)
+
+  with Context():
+    # The module just forwards to a runtime function known as "some_callback_into_python".
+    module = Module.parse(
+        r"""
+func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } {
+  call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> ()
+  return
+}
+func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface }
+"""
+    )
+    execution_engine = ExecutionEngine(lowerToLLVM(module))
+    execution_engine.register_runtime("some_callback_into_python", callback)
+    inp_arr = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32)
+    # CHECK: Inside Callback:
+    # CHECK{LITERAL}: [[1. 5.]
+    # CHECK{LITERAL}: [6. 7.]]
+    execution_engine.invoke(
+        "callback_memref", ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr)))
+    )
+
+run(testRankedMemRefCallback)
+
+# Test addition of two memrefs.
+# CHECK-LABEL: TEST: testMemrefAdd
+def testMemrefAdd():
+  with Context():
+    module = Module.parse(
+      """
+      module  {
+      func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
+        %0 = constant 0 : index
+        %1 = memref.load %arg0[%0] : memref<1xf32>
+        %2 = memref.load %arg1[] : memref<f32>
+        %3 = addf %1, %2 : f32
+        memref.store %3, %arg2[%0] : memref<1xf32>
+        return
+      }
+     } """
+    )
+    arg1 = np.array([32.5]).astype(np.float32)
+    arg2 = np.array(6).astype(np.float32)
+    res = np.array([0]).astype(np.float32)
+
+    arg1_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg1)))
+    arg2_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg2)))
+    res_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(res)))
+
+    execution_engine = ExecutionEngine(lowerToLLVM(module))
+    execution_engine.invoke(
+        "main", arg1_memref_ptr, arg2_memref_ptr, res_memref_ptr
+    )
+    # CHECK: [32.5] + 6.0 = [38.5]
+    log("{0} + {1} = {2}".format(arg1, arg2, res))
+
+run(testMemrefAdd)
+
+# Test addition of two 2D memrefs.
+# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
+def testDynamicMemrefAdd2D():
+  with Context():
+    module = Module.parse(
+      """
+      module  {
+        func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} {
+          %c0 = constant 0 : index
+          %c2 = constant 2 : index
+          %c1 = constant 1 : index
+          br ^bb1(%c0 : index)
+        ^bb1(%0: index):  // 2 preds: ^bb0, ^bb5
+          %1 = cmpi slt, %0, %c2 : index
+          cond_br %1, ^bb2, ^bb6
+        ^bb2:  // pred: ^bb1
+          %c0_0 = constant 0 : index
+          %c2_1 = constant 2 : index
+          %c1_2 = constant 1 : index
+          br ^bb3(%c0_0 : index)
+        ^bb3(%2: index):  // 2 preds: ^bb2, ^bb4
+          %3 = cmpi slt, %2, %c2_1 : index
+          cond_br %3, ^bb4, ^bb5
+        ^bb4:  // pred: ^bb3
+          %4 = memref.load %arg0[%0, %2] : memref<2x2xf32>
+          %5 = memref.load %arg1[%0, %2] : memref<?x?xf32>
+          %6 = addf %4, %5 : f32
+          memref.store %6, %arg2[%0, %2] : memref<2x2xf32>
+          %7 = addi %2, %c1_2 : index
+          br ^bb3(%7 : index)
+        ^bb5:  // pred: ^bb3
+          %8 = addi %0, %c1 : index
+          br ^bb1(%8 : index)
+        ^bb6:  // pred: ^bb1
+          return
+        }
+      }
+        """
+    )
+    arg1 = np.random.randn(2,2).astype(np.float32)
+    arg2 = np.random.randn(2,2).astype(np.float32)
+    res = np.random.randn(2,2).astype(np.float32)
+
+    arg1_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg1)))
+    arg2_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg2)))
+    res_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(res)))
+
+    execution_engine = ExecutionEngine(lowerToLLVM(module))
+    execution_engine.invoke(
+        "memref_add_2d", arg1_memref_ptr, arg2_memref_ptr, res_memref_ptr
+    )
+    # CHECK: True
+    log(np.allclose(arg1+arg2, res))
+
+run(testDynamicMemrefAdd2D)