diff --git a/mlir/benchmark/python/README.md b/mlir/benchmark/python/README.md
new file mode 100644
--- /dev/null
+++ b/mlir/benchmark/python/README.md
@@ -0,0 +1,46 @@
+# MLIR Python Benchmarks
+
+This directory contains definitions of MLIR benchmarks and a way to run them. These benchmarks are implemented in Python
+using MLIR's Python bindings. Below, we describe how to implement and run these benchmarks.
+
+## Writing new benchmarks
+A benchmark is an MLIR module containing a kernel function and a timed call to that kernel function.
+
+Let's look at a benchmark example and walk through it.
+
+```python
+@benchmark(pipeline_string, ntimes)
+def benchmark_something_module():
+    module = ir.Module.create()
+    # Define arguments for the kernel function
+    with ir.InsertionPoint(module.body):
+        @builtin.FuncOp.from_py_func()
+        def kernel_name():
+            # Kernel implementation
+```
+
+A benchmark function creates and returns a module containing the kernel function; the decorator adds the timed call to it.
+A benchmark is decorated with the `benchmark` decorator, which takes two required arguments:
+1. `pipeline_string`: The pass pipeline, passed to the pass manager to lower the module.
+2. `ntimes`: Number of times the kernel function is executed and timed.
+
+The outer function must have the prefix `"benchmark_"`. The inner kernel function, named `"kernel_name"` above, is used
+as the benchmark identifier and appears as the test name in the LNT report. Viewing the results in LNT is described
+below.
+
+Benchmarks must live in Python files in this directory whose names end with `"_bench.py"`.
+
+## Viewing benchmark results
+We use [LNT](https://llvm.org/docs/lnt/index.html) to view benchmark results. We can push benchmark results to an LNT
+server using the `run.py` script in this directory. Here's the script invocation from the root directory of the LLVM
+project.
+
+```bash
+$ PYTHONPATH=<path_to_mlir_python_packages> MLIR_C_RUNNER_UTILS=<path_to_c_runner_utils_library> MLIR_RUNNER_UTILS=<path_to_runner_utils_library> python mlir/benchmark/python/run.py --machine <machine_name> --revision <revision> --url <lnt_submit_url>
+```
+For more on what `machine` and `revision` mean in the LNT context, see the [LNT concepts](https://llvm.org/docs/lnt/concepts.html)
+page. For a description of the command-line options, run
+
+```bash
+python mlir/benchmark/python/run.py --help
+```
\ No newline at end of file
diff --git a/mlir/benchmark/python/__init__.py b/mlir/benchmark/python/__init__.py
new file mode 100644
diff --git a/mlir/benchmark/python/common.py b/mlir/benchmark/python/common.py
new file mode 100644
--- /dev/null
+++ b/mlir/benchmark/python/common.py
@@ -0,0 +1,209 @@
+"""Shared helpers for the MLIR Python benchmarks.
+
+Each benchmark builds an MLIR module holding one kernel function; the
+`benchmark` decorator wraps that kernel in a timed `main`, lowers the module
+with a caller-supplied pass pipeline, JIT-compiles it, and runs it.
+"""
+import ctypes
+import functools
+import numpy as np
+import os
+import re
+import time
+import typing
+
+from mlir import ir
+from mlir import runtime as rt
+from mlir.dialects import builtin
+from mlir.dialects import arith
+from mlir.dialects import memref
+from mlir.dialects import scf
+from mlir.dialects import std
+from mlir.execution_engine import ExecutionEngine
+from mlir.passmanager import PassManager
+
+
+def _tensor_type_dimensions(tensor_type) -> typing.List[int]:
+    """Returns the static dimension sizes of a ranked tensor type.
+
+    Sizes are parsed from the textual form of the type, e.g.
+    `tensor<1000x1500xf64>` yields `[1000, 1500]`; the trailing element type
+    is dropped. A dynamic size (`?`) makes `int()` raise `ValueError`.
+    """
+    dimensions_str = re.sub("<|>|tensor", "", str(tensor_type))
+    return [int(dim) for dim in dimensions_str.split("x")[:-1]]
+
+
+def create_random_np_tensor(tensor_type):
+    """Returns a float64 array shaped like `tensor_type`, filled with values
+    drawn uniformly from [0, 100)."""
+    return np.random.uniform(
+        low=0.0, high=100.0, size=_tensor_type_dimensions(tensor_type))
+
+
+def create_zero_np_tensor(tensor_type):
+    """Returns a float64 array of zeros shaped like `tensor_type`."""
+    return np.zeros(_tensor_type_dimensions(tensor_type), np.float64)
+
+
+def construct_arguments_for_kernel_function(kernel_func: builtin.FuncOp):
+    """Builds the ctypes arguments for invoking the benchmark `main`.
+
+    The list holds, in order: a zero buffer shaped like the kernel's last
+    input (presumably where the C interface stores the returned tensor --
+    TODO confirm), one random tensor per kernel input except the last, and a
+    zero buffer for the last input (assumed to be the kernel's output
+    tensor). Each array is passed as a pointer to a pointer to its ranked
+    memref descriptor; `benchmark` appends the timing buffer afterwards.
+
+    NOTE(review): every element type is assumed to be f64.
+    """
+    output_type = kernel_func.type.inputs[-1]
+    tensor_np_args = [create_zero_np_tensor(output_type)]
+    tensor_np_args.extend(
+        create_random_np_tensor(input_type)
+        for input_type in kernel_func.type.inputs[:-1])
+    tensor_np_args.append(create_zero_np_tensor(output_type))
+    return [
+        ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_tensor)))
+        for np_tensor in tensor_np_args
+    ]
+
+
+def emit_timer_func() -> builtin.FuncOp:
+    """Declares the external `nano_time` function used for timing.
+
+    The declaration is private, takes no arguments, returns an i64 and is
+    marked `llvm.emit_c_interface`. Its definition is expected to come from
+    the runner utility libraries that `benchmark` loads (assumption -- verify
+    they export it).
+    """
+    i64_type = ir.IntegerType.get_signless(64)
+    nano_time = builtin.FuncOp(
+        "nano_time", ([], [i64_type]), visibility="private")
+    nano_time.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+    return nano_time
+
+
+def emit_benchmark_wrapped_main_func(
+    func: builtin.FuncOp,
+    timer_func: builtin.FuncOp
+) -> builtin.FuncOp:
+    """Emits a public `main` that calls `func` repeatedly and times each call.
+
+    `main` takes the arguments of `func` followed by a 1-D i64 memref and
+    loops once per element of that memref. Each iteration calls `timer_func`,
+    `func`, and `timer_func` again, then stores the difference of the two
+    timer readings at the iteration index. The trailing
+    `len(func.type.results)` arguments of `func` are loop-carried: the results
+    of each call replace them in the next call, and `main` returns the results
+    of the last call.
+    """
+    i64_type = ir.IntegerType.get_signless(64)
+    memref_of_i64_type = ir.MemRefType.get([-1], i64_type)
+    wrapped_func = builtin.FuncOp(
+        # Same signature and an extra buffer of indices to save timings.
+        "main",
+        (func.arguments.types + [memref_of_i64_type], func.type.results),
+        visibility="public"
+    )
+    wrapped_func.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+
+    num_results = len(func.type.results)
+    with ir.InsertionPoint(wrapped_func.add_entry_block()):
+        timer_buffer = wrapped_func.arguments[-1]
+        zero = arith.ConstantOp.create_index(0)
+        # The runtime size of the timing buffer is the iteration count.
+        n_iterations = memref.DimOp(ir.IndexType.get(), timer_buffer, zero)
+        one = arith.ConstantOp.create_index(1)
+        iter_args = list(wrapped_func.arguments[-num_results - 1:-1])
+        loop = scf.ForOp(zero, n_iterations, one, iter_args)
+        with ir.InsertionPoint(loop.body):
+            start = std.CallOp(timer_func, [])
+            call = std.CallOp(
+                func, wrapped_func.arguments[:-num_results - 1] + loop.inner_iter_args
+            )
+            end = std.CallOp(timer_func, [])
+            time_taken = arith.SubIOp(end, start)
+            memref.StoreOp(time_taken, timer_buffer, [loop.induction_variable])
+            scf.YieldOp(list(call.results))
+        std.ReturnOp(loop)
+
+    return wrapped_func
+
+
+def _get_existing_path_from_env(variable_name: str) -> str:
+    """Returns the path held by the environment variable `variable_name`.
+
+    Raises:
+      FileNotFoundError: if the variable is unset or names a missing path.
+    """
+    path = os.getenv(variable_name, "")
+    if not os.path.exists(path):
+        raise FileNotFoundError(
+            f"{path} does not exist. Please pass a valid value for "
+            f"{variable_name} environment variable.")
+    return path
+
+
+def benchmark(pipeline: str, number_of_runs: int):
+    """Turns a kernel-module builder into a runnable benchmark.
+
+    Args:
+      pipeline: Pass pipeline string given to `PassManager.parse`; the parsed
+        pipeline is run on the module before it is compiled.
+      number_of_runs: How many times the kernel is executed and timed.
+
+    The decorated function must return an `ir.Module` whose first operation
+    is the kernel function. Calling the resulting benchmark returns a dict
+    with the kernel's name, the `ExecutionEngine` construction time in seconds
+    (`compile_time`) and the per-run execution times in seconds.
+
+    Raises (when the benchmark is called):
+      FileNotFoundError: if MLIR_C_RUNNER_UTILS or MLIR_RUNNER_UTILS does not
+        name an existing path.
+    """
+    def decorator(create_kernel_module: typing.Callable):
+        @functools.wraps(create_kernel_module)
+        def wrapper():
+            c_runner_utils = _get_existing_path_from_env("MLIR_C_RUNNER_UTILS")
+            runner_utils = _get_existing_path_from_env("MLIR_RUNNER_UTILS")
+
+            with ir.Context(), ir.Location.unknown():
+                kernel_func_main_module = create_kernel_module()
+                kernel_func = kernel_func_main_module.operation.regions[0].blocks[0].operations[0]
+                # Emit the timer declaration and the wrapping `main` into a
+                # scratch module, then re-parse them together with the kernel
+                # into the one module that gets lowered and compiled.
+                scratch_module = ir.Module.create()
+                with ir.InsertionPoint(scratch_module.body):
+                    timer_func = emit_timer_func()
+                    wrapped_func = emit_benchmark_wrapped_main_func(kernel_func, timer_func)
+                main_module_with_benchmark = ir.Module.parse(
+                    str(timer_func) + str(wrapped_func) + str(kernel_func)
+                )
+
+                PassManager.parse(pipeline).run(main_module_with_benchmark)
+
+                tensor_mem_args = construct_arguments_for_kernel_function(kernel_func)
+                np_timers_ns = np.zeros([number_of_runs], dtype=np.int64)
+                tensor_mem_args.append(
+                    ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_timers_ns)))
+                )
+
+                # Only ExecutionEngine construction is timed, not the pass pipeline.
+                compilation_time_start_seconds = time.perf_counter()
+                engine = ExecutionEngine(main_module_with_benchmark, 3, shared_libs=[c_runner_utils, runner_utils])
+                compilation_time_seconds = time.perf_counter() - compilation_time_start_seconds
+                # `main` stores one timer delta per run; it is treated as nanoseconds.
+                engine.invoke("main", *tensor_mem_args)
+                np_timers_s = np_timers_ns * 1e-9
+                return {
+                    # `name` is a StringAttr; str() on it would keep the quotes.
+                    "name": ir.StringAttr(kernel_func.name).value,
+                    "compile_time": compilation_time_seconds,
+                    "execution_time": np_timers_s.tolist(),
+                }
+
+        return wrapper
+    return decorator
diff --git a/mlir/benchmark/python/run.py b/mlir/benchmark/python/run.py
new file mode 100644
--- /dev/null
+++ b/mlir/benchmark/python/run.py
@@ -0,0 +1,68 @@
+import argparse
+import datetime
+import glob
+import importlib
+import json
+import os
+import types
+
+from urllib import error as urlerror
+from urllib import parse as urlparse
+from urllib import request
+
+BENCHMARK_FILE_SUFFIX_REGEX = "*_bench.py"
+BENCHMARK_FUNCTION_PREFIX = "benchmark_"
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Runs the MLIR Python benchmarks and submits the results to an LNT server.")
+    parser.add_argument(
+        "--machine",
+        required=True,
+        help="A platform identifier on which the benchmarks are run. For example <hardware>-<arch>-<optimization level>-<branch-name>."
+    )
+    parser.add_argument(
+        "--revision",
+        required=True,
+        help="The key used to identify different runs. Could be anything as long as it can be sorted by python's sort function."
+    )
+    parser.add_argument(
+        "--url",
+        help="The lnt server url to send the results to.",
+        default="http://localhost:8000/db_default/v4/nts/submitRun"
+    )
+    args = parser.parse_args()
+    # Import every *_bench.py module next to this script and run its benchmarks.
+    script_dir = os.path.abspath(os.path.dirname(__file__))
+    benchmark_file_paths = glob.glob(f"{script_dir}/{BENCHMARK_FILE_SUFFIX_REGEX}")
+    benchmark_function_dicts = []
+    complete_benchmark_start_time = datetime.datetime.utcnow().isoformat()
+    for benchmark_file_path in benchmark_file_paths:
+        benchmark_filename = os.path.basename(benchmark_file_path)
+        module_name = os.path.splitext(benchmark_filename)[0]
+        module = importlib.import_module(module_name)
+        for attribute_string in dir(module):
+            if attribute_string.startswith(BENCHMARK_FUNCTION_PREFIX):
+                attribute = getattr(module, attribute_string)
+                if isinstance(attribute, types.FunctionType):
+                    benchmark_function_dicts.append(attribute())
+    # Record the end time and submit all results as one LNT run (format version 2).
+    complete_benchmark_end_time = datetime.datetime.utcnow().isoformat()
+    lnt_dict = {
+        "format_version": "2",
+        "machine": {"name": args.machine},
+        "run": {
+            "start_time": complete_benchmark_start_time,
+            "end_time": complete_benchmark_end_time,
+            "llvm_project_revision": args.revision
+        },
+        "tests": benchmark_function_dicts,
+        "name": "MLIR benchmark suite"
+    }
+    lnt_json = json.dumps(lnt_dict)
+    request_data = urlparse.urlencode({"input_data": lnt_json}).encode("ascii")
+    req = request.Request(args.url, request_data)
+    try:
+        request.urlopen(req).close()
+    except urlerror.HTTPError as e:
+        print(e)
diff --git a/mlir/benchmark/python/sparse_bench.py
b/mlir/benchmark/python/sparse_bench.py
new file mode 100644
--- /dev/null
+++ b/mlir/benchmark/python/sparse_bench.py
@@ -0,0 +1,61 @@
+"""Benchmarks for kernels lowered through the MLIR sparse compiler pipeline."""
+# Imported for its side effect of registering MLIR passes (presumably needed
+# so `PassManager.parse` recognizes the passes in the pipeline below).
+import mlir.all_passes_registration
+
+from mlir import ir
+from mlir.dialects import builtin
+from mlir.dialects.linalg.opdsl import lang as dsl
+from common import benchmark
+
+
+# Matrix multiplication that accumulates A * B into the output tensor C.
+@dsl.linalg_structured_op
+def matmul_dsl(
+    A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
+    B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
+    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)
+):
+    C[dsl.D.m, dsl.D.n] += A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
+
+
+def get_sparse_kernel_pipeline() -> str:
+    """Returns the pass pipeline that lowers a linalg kernel through
+    sparsification down to the LLVM dialect."""
+    opt = "parallelization-strategy=0 vectorization-strategy=0 vl=1 enable-simd-index32=False"
+    # Only the sparsification entry interpolates a value; the other entries
+    # are plain literals, so their braces are written as-is.
+    return (
+        "builtin.func(linalg-generalize-named-ops,linalg-fuse-elementwise-ops),"
+        f"sparsification{{{opt}}},"
+        "sparse-tensor-conversion,"
+        "builtin.func(linalg-bufferize,convert-linalg-to-loops,convert-vector-to-scf),"
+        "convert-scf-to-std,"
+        "func-bufferize,"
+        "tensor-constant-bufferize,"
+        "builtin.func(tensor-bufferize,std-bufferize,finalizing-bufferize),"
+        "convert-vector-to-llvm{reassociate-fp-reductions=1 enable-index-optimizations=1},"
+        "lower-affine,"
+        "convert-memref-to-llvm,"
+        "convert-std-to-llvm,"
+        "reconcile-unrealized-casts"
+    )
+
+
+@benchmark(get_sparse_kernel_pipeline(), 100)
+def benchmark_sparse_kernel_module():
+    """Builds a module holding a (1000x1500) x (1500x2000) f64 matmul kernel."""
+    # NOTE(review): a, b and c carry no sparse tensor encoding, so this times a
+    # dense matmul through the sparse pipeline -- confirm that is intended.
+    module = ir.Module.create()
+    f64 = ir.F64Type.get()
+    a = ir.RankedTensorType.get([1000, 1500], f64)
+    b = ir.RankedTensorType.get([1500, 2000], f64)
+    c = ir.RankedTensorType.get([1000, 2000], f64)
+    with ir.InsertionPoint(module.body):
+        @builtin.FuncOp.from_py_func(a, b, c)
+        def sparse_kernel(x, y, z):
+            return matmul_dsl(x, y, outs=[z])
+
+    return module
+