diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -114,13 +114,30 @@
     d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
     return d
 
 
+def move_aligned_ptr_by_offset(aligned_ptr, offset):
+    """Moves the supplied ctypes pointer ahead by `offset` elements.
+
+    `offset` is an element count (as in a memref descriptor's `offset` field),
+    not a byte count, so the address is advanced by `offset * sizeof(element)`.
+    The result has the same ctypes type as `aligned_ptr` and aliases the same
+    buffer; no data is copied. `aligned_ptr` must not be NULL, since its
+    `.contents` is read to obtain the address (ctypes raises ValueError).
+    """
+    aligned_addr = ctypes.addressof(aligned_ptr.contents)
+    elem_size = ctypes.sizeof(aligned_ptr.contents)
+    shift = offset * elem_size
+    content_ptr = ctypes.cast(aligned_addr + shift, type(aligned_ptr))
+    return content_ptr
+
+
 def unranked_memref_to_numpy(unranked_memref, np_dtype):
     """Converts unranked memrefs to numpy arrays."""
     ctp = as_ctype(np_dtype)
     descriptor = make_nd_memref_descriptor(unranked_memref[0].rank, ctp)
     val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor))
-    np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape)
+    content_ptr = move_aligned_ptr_by_offset(val[0].aligned, val[0].offset)
+    np_arr = np.ctypeslib.as_array(content_ptr, shape=val[0].shape)
     strided_arr = np.lib.stride_tricks.as_strided(
         np_arr,
         np.ctypeslib.as_array(val[0].shape),
@@ -131,8 +148,9 @@
 
 def ranked_memref_to_numpy(ranked_memref):
     """Converts ranked memrefs to numpy arrays."""
-    np_arr = np.ctypeslib.as_array(
-        ranked_memref[0].aligned, shape=ranked_memref[0].shape
-    )
+    content_ptr = move_aligned_ptr_by_offset(
+        ranked_memref[0].aligned, ranked_memref[0].offset
+    )
+    np_arr = np.ctypeslib.as_array(content_ptr, shape=ranked_memref[0].shape)
     strided_arr = np.lib.stride_tricks.as_strided(
         np_arr,
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -245,6 +245,88 @@
 run(testRankedMemRefCallback)
 
 
+# Test callback with a ranked memref with non-zero offset.
+# CHECK-LABEL: TEST: testRankedMemRefWithOffsetCallback
+def testRankedMemRefWithOffsetCallback():
+    # Define a callback function that takes a ranked memref, converts it to a numpy array and prints it.
+    @ctypes.CFUNCTYPE(
+        None,
+        ctypes.POINTER(
+            make_nd_memref_descriptor(1, np.ctypeslib.as_ctypes_type(np.float32))
+        ),
+    )
+    def callback(a):
+        arr = ranked_memref_to_numpy(a)
+        log("Inside Callback: ")
+        log(arr)
+
+    with Context():
+        # The module takes a subview of the argument memref and calls the callback with it.
+        module = Module.parse(
+            r"""
+func.func @callback_memref(%arg0: memref<5xf32>) attributes {llvm.emit_c_interface} {
+  %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<5xf32> -> memref<f32>, index, index, index
+  %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [3], sizes: [2], strides: [1] : memref<f32> to memref<2xf32, strided<[1], offset: 3>>
+  %cast = memref.cast %reinterpret_cast : memref<2xf32, strided<[1], offset: 3>> to memref<?xf32, strided<[1], offset: ?>>
+  call @some_callback_into_python(%cast) : (memref<?xf32, strided<[1], offset: ?>>) -> ()
+  return
+}
+func.func private @some_callback_into_python(memref<?xf32, strided<[1], offset: ?>>) attributes {llvm.emit_c_interface}
+"""
+        )
+        execution_engine = ExecutionEngine(lowerToLLVM(module))
+        execution_engine.register_runtime("some_callback_into_python", callback)
+        inp_arr = np.array([0, 0, 0, 1, 2], np.float32)
+        # CHECK: Inside Callback:
+        # CHECK{LITERAL}: [1. 2.]
+        execution_engine.invoke(
+            "callback_memref",
+            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr))),
+        )
+
+
+run(testRankedMemRefWithOffsetCallback)
+
+
+# Test callback with an unranked memref with non-zero offset.
+# CHECK-LABEL: TEST: testUnrankedMemRefWithOffsetCallback
+def testUnrankedMemRefWithOffsetCallback():
+    # Define a callback function that takes an unranked memref, converts it to a numpy array and prints it.
+    @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
+    def callback(a):
+        arr = unranked_memref_to_numpy(a, np.float32)
+        log("Inside callback: ")
+        log(arr)
+
+    with Context():
+        # The module takes a subview of the argument memref, casts it to an unranked memref and
+        # calls the callback with it.
+        module = Module.parse(
+            r"""
+func.func @callback_memref(%arg0: memref<5xf32>) attributes {llvm.emit_c_interface} {
+  %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<5xf32> -> memref<f32>, index, index, index
+  %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [3], sizes: [2], strides: [1] : memref<f32> to memref<2xf32, strided<[1], offset: 3>>
+  %cast = memref.cast %reinterpret_cast : memref<2xf32, strided<[1], offset: 3>> to memref<*xf32>
+  call @some_callback_into_python(%cast) : (memref<*xf32>) -> ()
+  return
+}
+func.func private @some_callback_into_python(memref<*xf32>) attributes {llvm.emit_c_interface}
+"""
+        )
+        execution_engine = ExecutionEngine(lowerToLLVM(module))
+        execution_engine.register_runtime("some_callback_into_python", callback)
+        inp_arr = np.array([1, 2, 3, 4, 5], np.float32)
+        # CHECK: Inside callback:
+        # CHECK{LITERAL}: [4. 5.]
+        execution_engine.invoke(
+            "callback_memref",
+            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr))),
+        )
+
+
+run(testUnrankedMemRefWithOffsetCallback)
+
+
 # Test addition of two memrefs.
 # CHECK-LABEL: TEST: testMemrefAdd
 def testMemrefAdd():