diff --git a/mlir/test/python/dialects/sparse_tensor/test_SpMM.py b/mlir/test/python/dialects/sparse_tensor/test_SpMM.py
--- a/mlir/test/python/dialects/sparse_tensor/test_SpMM.py
+++ b/mlir/test/python/dialects/sparse_tensor/test_SpMM.py
@@ -59,20 +59,13 @@
   sparse kernel. For convenience, this part is purely done as string input.
   """
   return f"""
-func @main(%c: tensor<3x2xf64>) -> tensor<3x2xf64>
+func @main(%ad: tensor<3x4xf64>, %b: tensor<4x2xf64>, %c: tensor<3x2xf64>) -> tensor<3x2xf64>
   attributes {{ llvm.emit_c_interface }} {{
-  %0 = constant dense<[ [ 1.1, 0.0, 0.0, 1.4 ],
-                        [ 0.0, 0.0, 0.0, 0.0 ],
-                        [ 0.0, 0.0, 3.3, 0.0 ]]> : tensor<3x4xf64>
-  %a = sparse_tensor.convert %0 : tensor<3x4xf64> to tensor<3x4xf64, {attr}>
-  %b = constant dense<[ [ 1.0, 2.0 ],
-                        [ 4.0, 3.0 ],
-                        [ 5.0, 6.0 ],
-                        [ 8.0, 7.0 ]]> : tensor<4x2xf64>
-  %1 = call @spMxM(%a, %b, %c) : (tensor<3x4xf64, {attr}>,
+  %a = sparse_tensor.convert %ad : tensor<3x4xf64> to tensor<3x4xf64, {attr}>
+  %0 = call @spMxM(%a, %b, %c) : (tensor<3x4xf64, {attr}>,
                                   tensor<4x2xf64>,
                                   tensor<3x2xf64>) -> tensor<3x2xf64>
-  return %1 : tensor<3x2xf64>
+  return %0 : tensor<3x2xf64>
 }}
 """
 
@@ -83,24 +76,38 @@
   module = build_SpMM(attr)
   func = str(module.operation.regions[0].blocks[0].operations[0].operation)
   module = ir.Module.parse(func + boilerplate(attr))
+
   # Compile.
   compiler(module)
   engine = execution_engine.ExecutionEngine(
       module, opt_level=0, shared_libs=[support_lib])
-  # Set up numpy input, invoke the kernel, and get numpy output.
-  # Built-in bufferization uses in-out buffers.
-  # TODO: replace with inplace comprehensive bufferization.
-  Cin = np.zeros((3, 2), np.double)
-  Cout = np.zeros((3, 2), np.double)
-  Cin_memref_ptr = ctypes.pointer(
-      ctypes.pointer(rt.get_ranked_memref_descriptor(Cin)))
+
+  # Set up numpy input and buffer for output.
+  Ca = np.array(
+      [[1.1, 0.0, 0.0, 1.4], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 3.3, 0.0]],
+      np.float64)
+  Cb = np.array([[1.0, 2.0], [4.0, 3.0], [5.0, 6.0], [8.0, 7.0]], np.float64)
+  Cc = np.zeros((3, 2), np.float64)
+  Cout = np.zeros((3, 2), np.float64)
+
+  Ca_memref_ptr = ctypes.pointer(
+      ctypes.pointer(rt.get_ranked_memref_descriptor(Ca)))
+  Cb_memref_ptr = ctypes.pointer(
+      ctypes.pointer(rt.get_ranked_memref_descriptor(Cb)))
+  Cc_memref_ptr = ctypes.pointer(
+      ctypes.pointer(rt.get_ranked_memref_descriptor(Cc)))
   Cout_memref_ptr = ctypes.pointer(
       ctypes.pointer(rt.get_ranked_memref_descriptor(Cout)))
-  engine.invoke('main', Cout_memref_ptr, Cin_memref_ptr)
-  Cresult = rt.ranked_memref_to_numpy(Cout_memref_ptr[0])
+
+  # Invoke the kernel and get numpy output.
+  # Built-in bufferization uses in-out buffers.
+  # TODO: replace with inplace comprehensive bufferization.
+  engine.invoke('main', Cout_memref_ptr, Ca_memref_ptr, Cb_memref_ptr,
+                Cc_memref_ptr)
 
   # Sanity check on computed result.
-  expected = [[12.3, 12.0], [0.0, 0.0], [16.5, 19.8]]
+  Cresult = rt.ranked_memref_to_numpy(Cout_memref_ptr[0])
+  expected = np.array([[12.3, 12.0], [0.0, 0.0], [16.5, 19.8]], np.float64)
   if np.allclose(Cresult, expected):
     pass
   else:
@@ -132,7 +139,10 @@
 # CHECK: Passed 72 tests
 @run
 def testSpMM():
+  # Obtain path to runtime support library.
   support_lib = os.getenv('SUPPORT_LIB')
+  assert os.path.exists(support_lib), f'{support_lib} does not exist'
+
   with ir.Context() as ctx, ir.Location.unknown():
     count = 0
     # Fixed compiler optimization strategy.
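
The patched test passes each numpy array to engine.invoke as a pointer-to-pointer to a ranked memref descriptor, and converts the output descriptor back to numpy. Below is a minimal sketch (not part of the patch) of just that numpy/memref round trip; it assumes the MLIR Python bindings (mlir.runtime, imported as rt like in the test) are on PYTHONPATH, and it omits the compiled kernel and execution engine, which the real test supplies.

import ctypes
import numpy as np
from mlir import runtime as rt

# Wrap a numpy array in a ranked memref descriptor and take a
# pointer-to-pointer, exactly as the patch does for Ca/Cb/Cc/Cout before
# engine.invoke('main', Cout_memref_ptr, Ca_memref_ptr, Cb_memref_ptr,
#               Cc_memref_ptr).
a = np.array([[1.1, 0.0], [0.0, 3.3]], np.float64)
a_memref_ptr = ctypes.pointer(
    ctypes.pointer(rt.get_ranked_memref_descriptor(a)))

# After an invoke, the test converts the output descriptor back to numpy with
# ranked_memref_to_numpy; here we simply round-trip the input descriptor to
# show the mechanism without a kernel.
a_roundtrip = rt.ranked_memref_to_numpy(a_memref_ptr[0])
assert np.allclose(a, a_roundtrip)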