diff --git a/mlir/test/python/dialects/sparse_tensor/test_SpMM.py b/mlir/test/python/dialects/sparse_tensor/test_SpMM.py
--- a/mlir/test/python/dialects/sparse_tensor/test_SpMM.py
+++ b/mlir/test/python/dialects/sparse_tensor/test_SpMM.py
@@ -55,24 +55,19 @@
 def boilerplate(attr: st.EncodingAttr):
   """Returns boilerplate main method.
 
-  This method sets up a boilerplate main method that calls the generated
-  sparse kernel. For convenience, this part is purely done as string input.
+  This method sets up a boilerplate main method that takes three tensors
+  (a, b, c), converts the first tensor a into a sparse tensor, and then
+  calls the sparse kernel for matrix multiplication. For convenience,
+  this part is purely done as string input.
   """
   return f"""
-func @main(%c: tensor<3x2xf64>) -> tensor<3x2xf64>
+func @main(%ad: tensor<3x4xf64>, %b: tensor<4x2xf64>, %c: tensor<3x2xf64>) -> tensor<3x2xf64>
   attributes {{ llvm.emit_c_interface }} {{
-  %0 = constant dense<[ [ 1.1, 0.0, 0.0, 1.4 ],
-                        [ 0.0, 0.0, 0.0, 0.0 ],
-                        [ 0.0, 0.0, 3.3, 0.0 ]]> : tensor<3x4xf64>
-  %a = sparse_tensor.convert %0 : tensor<3x4xf64> to tensor<3x4xf64, {attr}>
-  %b = constant dense<[ [ 1.0, 2.0 ],
-                        [ 4.0, 3.0 ],
-                        [ 5.0, 6.0 ],
-                        [ 8.0, 7.0 ]]> : tensor<4x2xf64>
-  %1 = call @spMxM(%a, %b, %c) : (tensor<3x4xf64, {attr}>,
+  %a = sparse_tensor.convert %ad : tensor<3x4xf64> to tensor<3x4xf64, {attr}>
+  %0 = call @spMxM(%a, %b, %c) : (tensor<3x4xf64, {attr}>,
                                   tensor<4x2xf64>,
                                   tensor<3x2xf64>) -> tensor<3x2xf64>
-  return %1 : tensor<3x2xf64>
+  return %0 : tensor<3x2xf64>
 }}
 """
 
@@ -83,25 +78,34 @@
   module = build_SpMM(attr)
   func = str(module.operation.regions[0].blocks[0].operations[0].operation)
   module = ir.Module.parse(func + boilerplate(attr))
 
+  # Compile.
   compiler(module)
   engine = execution_engine.ExecutionEngine(
       module, opt_level=0, shared_libs=[support_lib])
-  # Set up numpy input, invoke the kernel, and get numpy output.
+
+  # Set up numpy input and buffer for output.
+  a = np.array(
+      [[1.1, 0.0, 0.0, 1.4], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 3.3, 0.0]],
+      np.float64)
+  b = np.array([[1.0, 2.0], [4.0, 3.0], [5.0, 6.0], [8.0, 7.0]], np.float64)
+  c = np.zeros((3, 2), np.float64)
+  out = np.zeros((3, 2), np.float64)
+
+  mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
+  mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
+  mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
+  mem_out = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(out)))
+
+  # Invoke the kernel and get numpy output.
   # Built-in bufferization uses in-out buffers.
   # TODO: replace with inplace comprehensive bufferization.
-  Cin = np.zeros((3, 2), np.double)
-  Cout = np.zeros((3, 2), np.double)
-  Cin_memref_ptr = ctypes.pointer(
-      ctypes.pointer(rt.get_ranked_memref_descriptor(Cin)))
-  Cout_memref_ptr = ctypes.pointer(
-      ctypes.pointer(rt.get_ranked_memref_descriptor(Cout)))
-  engine.invoke('main', Cout_memref_ptr, Cin_memref_ptr)
-  Cresult = rt.ranked_memref_to_numpy(Cout_memref_ptr[0])
+  engine.invoke('main', mem_out, mem_a, mem_b, mem_c)
 
   # Sanity check on computed result.
-  expected = [[12.3, 12.0], [0.0, 0.0], [16.5, 19.8]]
-  if np.allclose(Cresult, expected):
+  expected = np.matmul(a, b)
+  c = rt.ranked_memref_to_numpy(mem_out[0])
+  if np.allclose(c, expected):
     pass
   else:
     quit(f'FAILURE')
@@ -132,7 +136,10 @@
 # CHECK: Passed 72 tests
 @run
 def testSpMM():
+  # Obtain path to runtime support library.
   support_lib = os.getenv('SUPPORT_LIB')
+  assert os.path.exists(support_lib), f'{support_lib} does not exist'
+
   with ir.Context() as ctx, ir.Location.unknown():
     count = 0
     # Fixed compiler optimization strategy.