diff --git a/mlir/benchmark/python/__init__.py b/mlir/benchmark/python/__init__.py
new file mode 100644
diff --git a/mlir/benchmark/python/benchmark_sparse.py b/mlir/benchmark/python/benchmark_sparse.py
new file mode 100644
--- /dev/null
+++ b/mlir/benchmark/python/benchmark_sparse.py
@@ -0,0 +1,88 @@
+import ctypes
+import numpy as np
+import re
+
+import mlir.all_passes_registration
+
+from mbr import BenchmarkRunConfig
+from mlir import ir
+from mlir import runtime as rt
+from mlir.dialects import builtin
+from mlir.dialects.linalg.opdsl import lang as dsl
+from mlir.passmanager import PassManager
+
+
+def setup_passes(mlir_module):
+    """Run the lowering pipeline below on `mlir_module`, in place.
+
+    mbr applies this to every benchmark in this file that has no
+    `setup_passes_for_<name>` function of its own.
+    """
+    opt = "parallelization-strategy=0 vectorization-strategy=0 vl=1 enable-simd-index32=False"
+    # Only the sparsification entry interpolates a value, so the other entries
+    # are plain strings and their braces are not doubled.
+    pipeline = (
+        "builtin.func(linalg-generalize-named-ops,linalg-fuse-elementwise-ops),"
+        f"sparsification{{{opt}}},"
+        "sparse-tensor-conversion,"
+        "builtin.func(linalg-bufferize,convert-linalg-to-loops,convert-vector-to-scf),"
+        "convert-scf-to-std,"
+        "func-bufferize,"
+        "tensor-constant-bufferize,"
+        "builtin.func(tensor-bufferize,std-bufferize,finalizing-bufferize),"
+        "convert-vector-to-llvm{reassociate-fp-reductions=1 enable-index-optimizations=1},"
+        "lower-affine,"
+        "convert-memref-to-llvm,"
+        "convert-std-to-llvm,"
+        "reconcile-unrealized-casts"
+    )
+    PassManager.parse(pipeline).run(mlir_module)
+
+
+@dsl.linalg_structured_op
+def matmul_dsl(
+    A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
+    B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
+    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)
+):
+    """Matrix multiplication: C[m, n] += A[m, k] * B[k, n]."""
+    C[dsl.D.m, dsl.D.n] += A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
+
+
+def benchmark_sparse_kernel_multiplication():
+    """Benchmark a (1000x1500) x (1500x2000) f64 matrix multiplication.
+
+    NOTE(review): the tensor types below carry no sparse encoding, so this
+    appears to measure a dense matmul despite its name -- confirm intent.
+    """
+    with ir.Context(), ir.Location.unknown():
+        module = ir.Module.create()
+        f64 = ir.F64Type.get()
+        param1_type = ir.RankedTensorType.get([1000, 1500], f64)
+        param2_type = ir.RankedTensorType.get([1500, 2000], f64)
+        result_type = ir.RankedTensorType.get([1000, 2000], f64)
+        with ir.InsertionPoint(module.body):
+            @builtin.FuncOp.from_py_func(param1_type, param2_type, result_type)
+            def sparse_kernel(x, y, z):
+                return matmul_dsl(x, y, outs=[z])
+
+    # The leading result_type buffer presumably receives the kernel's return
+    # value through the C interface; the remaining three are x, y and z.
+    argument_types = [result_type, param1_type, param2_type, result_type]
+    arguments = []
+    for argument_type in argument_types:
+        argument_type_str = str(argument_type)
+        # e.g. "tensor<1000x1500xf64>" -> "1000x1500xf64" -> [1000, 1500]
+        dimensions_str = re.sub("<|>|tensor", "", argument_type_str)
+        dimensions = [int(dim) for dim in dimensions_str.split("x")[:-1]]
+        if argument_type == result_type:
+            # Both the result buffer and the output operand z start zeroed.
+            argument = np.zeros(dimensions, np.float64)
+        else:
+            argument = np.random.uniform(low=0.0, high=100.0, size=dimensions)
+        arguments.append(ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(argument))))
+
+    return BenchmarkRunConfig(
+        module=module,
+        arguments=arguments,
+    )
diff --git a/mlir/utils/mbr/README.md b/mlir/utils/mbr/README.md
new file mode 100644
--- /dev/null
+++ b/mlir/utils/mbr/README.md
@@ -0,0 +1,76 @@
+# MBR - MLIR Benchmark Runner
+mbr is a command line tool to run mlir benchmarks. In mbr, we define kernels to benchmark and mbr discovers and runs
+them.
+
+## Defining benchmarks
+To define a benchmark, we need two things:
+1. An MLIR module which contains the kernel to benchmark.
+2. The input arguments for the kernel.
+
+A benchmark is defined as a python function that returns a `BenchmarkRunConfig` object. Here's an example.
+We explain it after the example code.
+
+```python
+from mbr import BenchmarkRunConfig
+# Other imports
+
+
+def benchmark_something():
+    with ir.Context(), ir.Location.unknown():
+        module = ir.Module.create()
+        # Define a kernel and insert it into module's body.
+
+    arguments = []
+    # Prepare arguments for the kernel.
+
+    return BenchmarkRunConfig(
+        module=module,
+        arguments=arguments,
+    )
+```
+A benchmark function's name must start with the prefix `"benchmark_"` for it to be discoverable by mbr.
+In a benchmark function, we define an MLIR module and the arguments to call its kernel.
We then
+wrap the module and arguments in a `BenchmarkRunConfig` object provided by mbr and return it.
+
+These benchmark functions must be in python files whose names are prefixed by `"benchmark_"` for them to be
+discoverable by mbr. For example `benchmark_sparse.py` is a discoverable name but `benchmarksparse.py` isn't.
+
+## Defining passes
+We can define passes to be applied while running a benchmark. To define passes to be applied to
+all benchmarks in a file, define a function `setup_passes` as follows.
+
+```python
+import mlir.all_passes_registration
+
+from mlir.passmanager import PassManager
+
+
+def setup_passes(mlir_module):
+    # Define pipeline string
+    PassManager.parse(pipeline).run(mlir_module)
+```
+
+To define passes that are applied only to a specific benchmark, define a function like this.
+
+```python
+def setup_passes_for_something(mlir_module):
+    ...
+
+
+def benchmark_something():
+    ...
+```
+That is, for passes to be applied to a benchmark named `benchmark_something`, we remove the prefix
+`benchmark_` from its name and add the prefix `setup_passes_for_` instead.
+
+If a file defines both `setup_passes` and `setup_passes_for_something`, mbr runs only
+`setup_passes_for_something` for `benchmark_something`. That is, benchmark-specific pass functions take
+precedence over the general `setup_passes` function.
+
+## Running benchmarks
+To run the benchmarks, run the `main.py` script from the llvm project's root directory. Here's an invocation
+that will run all the benchmarks in the project's `mlir` directory.
+
+```bash
+PYTHONPATH=build/tools/mlir/python_packages/mlir_core MLIR_C_RUNNER_UTILS=build/lib/libmlir_c_runner_utils.dylib MLIR_RUNNER_UTILS=build/lib/libmlir_runner_utils.dylib python mlir/utils/mbr/mbr/main.py --machine arm-m1-home --revision v1
+```
diff --git a/mlir/utils/mbr/mbr.egg-info/PKG-INFO b/mlir/utils/mbr/mbr.egg-info/PKG-INFO
new file mode 100644
--- /dev/null
+++ b/mlir/utils/mbr/mbr.egg-info/PKG-INFO
@@ -0,0 +1,10 @@
+Metadata-Version: 2.1
+Name: mbr
+Version: 1.0.0
+Summary: UNKNOWN
+Home-page: UNKNOWN
+License: UNKNOWN
+Platform: UNKNOWN
+
+UNKNOWN
+
diff --git a/mlir/utils/mbr/mbr.egg-info/SOURCES.txt b/mlir/utils/mbr/mbr.egg-info/SOURCES.txt
new file mode 100644
--- /dev/null
+++ b/mlir/utils/mbr/mbr.egg-info/SOURCES.txt
@@ -0,0 +1,10 @@
+setup.py
+mbr/__init__.py
+mbr/compile.py
+mbr/discovery.py
+mbr/main.py
+mbr.egg-info/PKG-INFO
+mbr.egg-info/SOURCES.txt
+mbr.egg-info/dependency_links.txt
+mbr.egg-info/entry_points.txt
+mbr.egg-info/top_level.txt
\ No newline at end of file
diff --git a/mlir/utils/mbr/mbr.egg-info/dependency_links.txt b/mlir/utils/mbr/mbr.egg-info/dependency_links.txt
new file mode 100644
--- /dev/null
+++ b/mlir/utils/mbr/mbr.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+
diff --git a/mlir/utils/mbr/mbr.egg-info/entry_points.txt b/mlir/utils/mbr/mbr.egg-info/entry_points.txt
new file mode 100644
--- /dev/null
+++ b/mlir/utils/mbr/mbr.egg-info/entry_points.txt
@@ -0,0 +1,3 @@
+[console_scripts]
+mbr = mbr.main:main
+
diff --git a/mlir/utils/mbr/mbr.egg-info/top_level.txt b/mlir/utils/mbr/mbr.egg-info/top_level.txt
new file mode 100644
--- /dev/null
+++ b/mlir/utils/mbr/mbr.egg-info/top_level.txt
@@ -0,0 +1 @@
+mbr
diff --git a/mlir/utils/mbr/mbr/__init__.py b/mlir/utils/mbr/mbr/__init__.py
new file mode 100644
--- /dev/null
+++ b/mlir/utils/mbr/mbr/__init__.py
@@ -0,0 +1,17 @@
+import dataclasses
+import typing
+
+from mlir import ir
+
+
+@dataclasses.dataclass
+class BenchmarkRunConfig:
+    """What a `benchmark_*` function returns to mbr.
+
+    mbr compiles `module` and invokes its kernel with `arguments`.
+    """
+    # Module whose body holds exactly one operation: the kernel function.
+    module: ir.Module
+    # Kernel arguments (e.g. ctypes pointers to pointers to ranked memref
+    # descriptors); mbr passes them, plus a timing buffer, when invoking it.
+    arguments: list[typing.Any]
diff --git a/mlir/utils/mbr/mbr/compile.py b/mlir/utils/mbr/mbr/compile.py
new file mode 100644
--- /dev/null
+++ b/mlir/utils/mbr/mbr/compile.py
@@ -0,0 +1,212 @@
+import ctypes
+import os
+import numpy as np
+import typing
+import re
+import time
+
+from mlir import ir
+from mlir import runtime as rt
+from mlir.dialects import arith
+from mlir.dialects import builtin
+from mlir.dialects import memref
+from mlir.dialects import scf
+from mlir.dialects import std
+from mlir.execution_engine import ExecutionEngine
+
+
+def _get_tensor_dimensions(tensor_type) -> list[int]:
+    """Return the static dimensions of a ranked tensor type.
+
+    For example `tensor<1000x1500xf64>` yields `[1000, 1500]`.
+    """
+    dimensions_str = re.sub("<|>|tensor", "", str(tensor_type))
+    # The last "x"-separated component is the element type, not a dimension.
+    return [int(dim) for dim in dimensions_str.split("x")[:-1]]
+
+
+def create_random_np_tensor(tensor_type):
+    """Return a float64 array shaped like `tensor_type`, uniform in [0, 100)."""
+    return np.random.uniform(low=0.0, high=100.0, size=_get_tensor_dimensions(tensor_type))
+
+
+def create_zero_np_tensor(tensor_type):
+    """Return a float64 array of zeros shaped like `tensor_type`."""
+    return np.zeros(_get_tensor_dimensions(tensor_type), np.float64)
+
+
+def construct_arguments_for_kernel_function(kernel_func: builtin.FuncOp):
+    """Build ctypes memref arguments for calling `kernel_func` through its C interface.
+
+    The list holds a zero tensor shaped like the last input (presumably the
+    slot that receives the kernel's result), random tensors for every input
+    but the last, and a zero tensor for the last input (the output operand).
+    Each tensor is wrapped as a pointer to a pointer to a ranked memref
+    descriptor.
+    """
+    output_type = kernel_func.type.inputs[-1]
+    tensor_np_args = [create_zero_np_tensor(output_type)]
+    tensor_np_args.extend(
+        create_random_np_tensor(input_type) for input_type in kernel_func.type.inputs[:-1]
+    )
+    tensor_np_args.append(create_zero_np_tensor(output_type))
+    return [
+        ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_tensor)))
+        for np_tensor in tensor_np_args
+    ]
+
+
+def get_kernel_func_from_module(module: ir.Module) -> builtin.FuncOp:
+    """Return the single operation of `module`, expected to be the kernel function.
+
+    Raises:
+        ValueError: if the module does not have exactly one region, one block
+            and one operation.
+    """
+    regions = module.operation.regions
+    if len(regions) != 1:
+        raise ValueError("Expected kernel module to have only one region")
+    blocks = regions[0].blocks
+    if len(blocks) != 1:
+        raise ValueError("Expected kernel module to have only one block")
+    operations = blocks[0].operations
+    if len(operations) != 1:
+        raise ValueError("Expected kernel module to have only one operation")
+    return operations[0]
+
+
+def emit_timer_func() -> builtin.FuncOp:
+    """Create the declaration `func private @nano_time() -> i64`.
+
+    `llvm.emit_c_interface` presumably routes calls through a C wrapper
+    named `_mlir_ciface_nano_time`; NOTE(review): that symbol is assumed to
+    come from one of the runtime libraries loaded in
+    compile_and_run_benchmark -- confirm.
+    """
+    i64_type = ir.IntegerType.get_signless(64)
+    nano_time = builtin.FuncOp(
+        "nano_time", ([], [i64_type]), visibility="private")
+    nano_time.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+    return nano_time
+
+
+def emit_benchmark_wrapped_main_func(
+    func: builtin.FuncOp,
+    timer_func: builtin.FuncOp
+) -> builtin.FuncOp:
+    """Create a public `main` that calls and times `func` in a loop.
+
+    `main` takes `func`'s arguments plus a trailing `memref<?xi64>` buffer.
+    The loop runs once per buffer element, calls `timer_func` before and
+    after `func`, and stores the difference at the iteration index. The
+    `len(func.type.results)` arguments just before the buffer seed the loop's
+    iteration values; each iteration yields `func`'s results as the next
+    values, and `main` returns the final ones.
+    """
+    i64_type = ir.IntegerType.get_signless(64)
+    memref_of_i64_type = ir.MemRefType.get([-1], i64_type)
+    wrapped_func = builtin.FuncOp(
+        # Same signature plus an extra i64 buffer that receives one timing per iteration.
+        "main",
+        (func.arguments.types + [memref_of_i64_type], func.type.results),
+        visibility="public"
+    )
+    wrapped_func.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+
+    num_results = len(func.type.results)
+    with ir.InsertionPoint(wrapped_func.add_entry_block()):
+        timer_buffer = wrapped_func.arguments[-1]
+        zero = arith.ConstantOp.create_index(0)
+        n_iterations = memref.DimOp(ir.IndexType.get(), timer_buffer, zero)
+        one = arith.ConstantOp.create_index(1)
+        iter_args = list(wrapped_func.arguments[-num_results - 1:-1])
+        loop = scf.ForOp(zero, n_iterations, one, iter_args)
+        with ir.InsertionPoint(loop.body):
+            start = std.CallOp(timer_func, [])
+            call = std.CallOp(
+                func, wrapped_func.arguments[:-num_results - 1] + loop.inner_iter_args
+            )
+            end = std.CallOp(timer_func, [])
+            time_taken = arith.SubIOp(end, start)
+            memref.StoreOp(time_taken, timer_buffer, [loop.induction_variable])
+            scf.YieldOp(list(call.results))
+        std.ReturnOp(loop)
+
+    return wrapped_func
+
+
+def _get_runtime_library(env_var: str) -> str:
+    """Return the shared library path named by environment variable `env_var`.
+
+    Raises:
+        FileNotFoundError: if the variable is unset or names a missing path.
+    """
+    library_path = os.getenv(env_var, "")
+    if not os.path.exists(library_path):
+        raise FileNotFoundError(
+            f"{library_path} does not exist. Please pass a valid value for {env_var} environment variable."
+        )
+    return library_path
+
+
+def compile_and_run_benchmark(
+    benchmark_identifier: str,
+    kernel_func_module: ir.Module,
+    kernel_func_arguments: list[typing.Any],
+    setup_pass_function: typing.Optional[typing.Callable],
+    num_runs: int = 10,
+):
+    """Compile the kernel in `kernel_func_module` and time `num_runs` calls to it.
+
+    Args:
+        benchmark_identifier: Name reported for this benchmark.
+        kernel_func_module: Module holding exactly one operation, the kernel.
+        kernel_func_arguments: Arguments for the generated `main`, minus the
+            timing buffer. The list is not modified; the timing buffer is
+            appended to a copy.
+        setup_pass_function: Optional callable run on the combined module
+            (timer declaration, `main` wrapper and kernel) before compilation.
+        num_runs: How many times the kernel is called and timed.
+
+    Returns:
+        A dict with "name", "compile_time" (seconds spent constructing the
+        ExecutionEngine) and "execution_time" (seconds per run, as a list).
+
+    Raises:
+        FileNotFoundError: if MLIR_C_RUNNER_UTILS or MLIR_RUNNER_UTILS does
+            not name an existing path.
+    """
+    c_runner_utils = _get_runtime_library("MLIR_C_RUNNER_UTILS")
+    runner_utils = _get_runtime_library("MLIR_RUNNER_UTILS")
+
+    with ir.Context(), ir.Location.unknown():
+        # Re-parse the kernel into this context so that the wrapper below is
+        # built only from types and ops owned by a single context.
+        kernel_module = ir.Module.parse(str(kernel_func_module))
+        kernel_func = get_kernel_func_from_module(kernel_module)
+        timer_func = emit_timer_func()
+        wrapped_func = emit_benchmark_wrapped_main_func(kernel_func, timer_func)
+        main_module_with_benchmark = ir.Module.parse(
+            str(timer_func) + str(wrapped_func) + str(kernel_func)
+        )
+
+        if setup_pass_function:
+            setup_pass_function(main_module_with_benchmark)
+        # The generated `main` runs the kernel once per element of this buffer.
+        np_timers_ns = np.zeros([num_runs], dtype=np.int64)
+        # Copy instead of appending, so the caller's list (e.g. a
+        # BenchmarkRunConfig's arguments) does not grow on every call.
+        invoke_arguments = list(kernel_func_arguments) + [
+            ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_timers_ns)))
+        ]
+
+        compilation_time_start_seconds = time.perf_counter()
+        engine = ExecutionEngine(main_module_with_benchmark, 3, shared_libs=[c_runner_utils, runner_utils])
+        compilation_time_seconds = time.perf_counter() - compilation_time_start_seconds
+        engine.invoke("main", *invoke_arguments)
+        return {
+            "name": benchmark_identifier,
+            "compile_time": compilation_time_seconds,
+            # Nanoseconds -> seconds, as plain Python floats for JSON output.
+            "execution_time": (np_timers_ns * 1e-9).tolist(),
+        }
diff --git a/mlir/utils/mbr/mbr/discovery.py b/mlir/utils/mbr/mbr/discovery.py
new file mode 100644
--- /dev/null
+++ b/mlir/utils/mbr/mbr/discovery.py
@@ -0,0 +1,46 @@
+import importlib
+import os
+import pathlib
+import sys
+import types
+
+
+def discover_benchmark_modules(top_level_path):
+    """Import and yield every `benchmark_*.py` module found under `top_level_path`.
+
+    Files are visited in sorted path order. While a module is imported and
+    yielded, its directory is at the front of `sys.path` (so it can import
+    sibling files); the entry is removed again before the next file.
+
+    NOTE(review): modules are imported by file stem, so a second benchmark
+    file with an already-imported stem resolves to the cached module in
+    `sys.modules` instead of being loaded.
+    """
+    for benchmark_path in sorted(pathlib.Path(top_level_path).rglob("benchmark_*.py")):
+        benchmark_abs_dir = os.path.abspath(benchmark_path.parent)
+        sys.path.insert(0, benchmark_abs_dir)
+        try:
+            yield importlib.import_module(benchmark_path.stem)
+        finally:
+            # Remove this entry rather than popping the last one: importing the
+            # module may itself have changed sys.path.
+            if benchmark_abs_dir in sys.path:
+                sys.path.remove(benchmark_abs_dir)
+
+
+def _get_functions_with_prefix(module, prefix):
+    """Yield the plain Python functions of `module` whose attribute name starts with `prefix`."""
+    for attribute_name in dir(module):
+        attribute = getattr(module, attribute_name)
+        if isinstance(attribute, types.FunctionType) and attribute_name.startswith(prefix):
+            yield attribute
+
+
+def get_benchmark_functions(module):
+    """Yield the `benchmark_*` functions of `module`, in `dir()` order."""
+    yield from _get_functions_with_prefix(module, "benchmark_")
+
+
+def get_setup_pass_functions_for_benchmark(module):
+    """Yield the `setup_passes*` functions of `module`, in `dir()` order."""
+    yield from _get_functions_with_prefix(module, "setup_passes")
diff --git a/mlir/utils/mbr/mbr/main.py b/mlir/utils/mbr/mbr/main.py
new file mode 100644
--- /dev/null
+++ b/mlir/utils/mbr/mbr/main.py
@@ -0,0 +1,141 @@
+import argparse
+import datetime
+import json
+import sys
+
+from urllib import error as urlerror
+from urllib import parse as urlparse
+from urllib import request
+
+try:
+    from mbr.compile import compile_and_run_benchmark
+    from mbr.discovery import discover_benchmark_modules, get_benchmark_functions, get_setup_pass_functions_for_benchmark
+except ModuleNotFoundError as error:
+    if error.name != "mbr":
+        raise
+    # The mbr package is not importable: this file is being run directly as a
+    # script, so its sibling modules are importable by their plain names.
+    from compile import compile_and_run_benchmark
+    from discovery import discover_benchmark_modules, get_benchmark_functions, get_setup_pass_functions_for_benchmark
+
+
+def _select_setup_pass_function(benchmark_function_name, setup_pass_functions):
+    """Return the pass-setup function for a benchmark, or None if it has none.
+
+    `setup_passes_for_<name>` for `benchmark_<name>` takes precedence over the
+    file-wide `setup_passes`.
+    """
+    # Replace only the leading prefix, not every "benchmark_" in the name.
+    specific_name = "setup_passes_for_" + benchmark_function_name.removeprefix("benchmark_")
+    if specific_name in setup_pass_functions:
+        return setup_pass_functions[specific_name]
+    return setup_pass_functions.get("setup_passes")
+
+
+def run_benchmarks(top_level_path="mlir"):
+    """Discover, compile and run every benchmark under `top_level_path`.
+
+    Returns:
+        One result dict per benchmark, as returned by compile_and_run_benchmark.
+    """
+    # Import every benchmark module before running any benchmark.
+    modules = list(discover_benchmark_modules(top_level_path))
+    benchmark_dicts = []
+    for module in modules:
+        setup_pass_functions = {
+            function.__name__: function
+            for function in get_setup_pass_functions_for_benchmark(module)
+        }
+        for benchmark_function in get_benchmark_functions(module):
+            benchmark_run_config = benchmark_function()
+            benchmark_function_name = benchmark_function.__name__
+            benchmark_identifier = ":".join([module.__name__, benchmark_function_name])
+            benchmark_dicts.append(
+                compile_and_run_benchmark(
+                    benchmark_identifier,
+                    benchmark_run_config.module,
+                    benchmark_run_config.arguments,
+                    _select_setup_pass_function(benchmark_function_name, setup_pass_functions),
+                )
+            )
+    return benchmark_dicts
+
+
+def _parse_args(argv):
+    """Parse command-line arguments from `argv` (sys.argv[1:] when None)."""
+    parser = argparse.ArgumentParser(description="Run MLIR benchmarks and report the results to LNT.")
+    parser.add_argument(
+        "--machine",
+        required=True,
+        help="A platform identifier on which the benchmarks are run. For example arm-m1-home."
+    )
+    parser.add_argument(
+        "--revision",
+        required=True,
+        help="The key used to identify different runs. Could be anything as long as it can be sorted by python's sort function."
+    )
+    parser.add_argument(
+        "--url",
+        help="The lnt server url to send the results to.",
+        default="http://localhost:8000/db_default/v4/nts/submitRun"
+    )
+    parser.add_argument(
+        "--result-stdout",
+        help="Print benchmarking results to stdout instead of sending it to lnt.",
+        default=False,
+        action=argparse.BooleanOptionalAction
+    )
+    parser.add_argument(
+        "--benchmark-dir",
+        help="Directory searched recursively for benchmark_*.py files.",
+        default="mlir"
+    )
+    return parser.parse_args(argv)
+
+
+def main(argv=None):
+    """Command-line entry point: run all benchmarks and report the results.
+
+    With --result-stdout the LNT JSON report is printed; otherwise it is
+    submitted to the LNT server at --url.
+
+    Returns:
+        The process exit status: 0 on success, 1 if submitting the results
+        failed.
+    """
+    args = _parse_args(argv)
+
+    complete_benchmark_start_time = datetime.datetime.utcnow().isoformat()
+    benchmark_function_dicts = run_benchmarks(args.benchmark_dir)
+    complete_benchmark_end_time = datetime.datetime.utcnow().isoformat()
+    lnt_dict = {
+        "format_version": "2",
+        "machine": {"name": args.machine},
+        "run": {
+            "start_time": complete_benchmark_start_time,
+            "end_time": complete_benchmark_end_time,
+            "llvm_project_revision": args.revision
+        },
+        "tests": benchmark_function_dicts,
+        "name": "MLIR benchmark suite"
+    }
+    lnt_json = json.dumps(lnt_dict, indent=4)
+    if args.result_stdout:
+        print(lnt_json)
+        return 0
+
+    request_data = urlparse.urlencode({"input_data": lnt_json}).encode("ascii")
+    req = request.Request(args.url, request_data)
+    try:
+        # Only success matters; the context manager closes the response.
+        with request.urlopen(req):
+            pass
+    except urlerror.URLError as e:
+        # HTTPError subclasses URLError, so this also covers server errors.
+        print(e, file=sys.stderr)
+        return 1
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/mlir/utils/mbr/setup.py b/mlir/utils/mbr/setup.py
new file mode 100644
--- /dev/null
+++ b/mlir/utils/mbr/setup.py
@@ -0,0 +1,14 @@
+from setuptools import setup
+from setuptools import find_packages
+
+
+setup(
+    name="mbr",
+    version="1.0.0",
+    packages=find_packages(),
+    entry_points={
+        "console_scripts": [
+            "mbr = mbr.main:main",
+        ],
+    },
+)