diff --git a/mlir/.gitignore b/mlir/.gitignore
new file mode 100644
--- /dev/null
+++ b/mlir/.gitignore
@@ -0,0 +1,7 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+
+# Distribution / packaging
+.Python
+env/
+*.egg-info/
\ No newline at end of file
diff --git a/mlir/benchmark/python/__init__.py b/mlir/benchmark/python/__init__.py
new file mode 100644
diff --git a/mlir/benchmark/python/benchmark_sparse.py b/mlir/benchmark/python/benchmark_sparse.py
new file mode 100644
--- /dev/null
+++ b/mlir/benchmark/python/benchmark_sparse.py
@@ -0,0 +1,98 @@
+import ctypes
+import numpy as np
+import os
+import re
+
+import mlir.all_passes_registration
+
+from mbr import BenchmarkRunConfig
+from mlir import ir
+from mlir import runtime as rt
+from mlir.dialects import builtin
+from mlir.dialects.linalg.opdsl import lang as dsl
+from mlir.execution_engine import ExecutionEngine
+from mlir.passmanager import PassManager
+
+from common import emit_timer_func, emit_benchmark_wrapped_main_func, get_kernel_func_from_module
+
+
+def setup_passes(mlir_module):
+    opt = "parallelization-strategy=0 vectorization-strategy=0 vl=1 enable-simd-index32=False"
+    pipeline = (
+        f"builtin.func(linalg-generalize-named-ops,linalg-fuse-elementwise-ops),"
+        f"sparsification{{{opt}}},"
+        f"sparse-tensor-conversion,"
+        f"builtin.func(linalg-bufferize,convert-linalg-to-loops,convert-vector-to-scf),"
+        f"convert-scf-to-std,"
+        f"func-bufferize,"
+        f"tensor-constant-bufferize,"
+        f"builtin.func(tensor-bufferize,std-bufferize,finalizing-bufferize),"
+        f"convert-vector-to-llvm{{reassociate-fp-reductions=1 enable-index-optimizations=1}},"
+        f"lower-affine,"
+        f"convert-memref-to-llvm,"
+        f"convert-std-to-llvm,"
+        f"reconcile-unrealized-casts"
+    )
+    PassManager.parse(pipeline).run(mlir_module)
+
+
+@dsl.linalg_structured_op
+def matmul_dsl(
+    A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
+    B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
+    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)
+):
+    C[dsl.D.m, dsl.D.n] += A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
+
+
+def benchmark_sparse_kernel_multiplication():
+    with ir.Context(), ir.Location.unknown():
+        module = ir.Module.create()
+        f64 = ir.F64Type.get()
+        param1_type = ir.RankedTensorType.get([1000, 1500], f64)
+        param2_type = ir.RankedTensorType.get([1500, 2000], f64)
+        result_type = ir.RankedTensorType.get([1000, 2000], f64)
+        with ir.InsertionPoint(module.body):
+            @builtin.FuncOp.from_py_func(param1_type, param2_type, result_type)
+            def sparse_kernel(x, y, z):
+                return matmul_dsl(x, y, outs=[z])
+
+    def compiler():
+        with ir.Context(), ir.Location.unknown():
+            kernel_func = get_kernel_func_from_module(module)
+            timer_func = emit_timer_func()
+            wrapped_func = emit_benchmark_wrapped_main_func(kernel_func, timer_func)
+            main_module_with_benchmark = ir.Module.parse(
+                str(timer_func) + str(wrapped_func) + str(kernel_func)
+            )
+            setup_passes(main_module_with_benchmark)
+            c_runner_utils = os.getenv("MLIR_C_RUNNER_UTILS", "")
+            assert os.path.exists(c_runner_utils),\
+                f"{c_runner_utils} does not exist. Please pass a valid value for the MLIR_C_RUNNER_UTILS environment variable."
+            runner_utils = os.getenv("MLIR_RUNNER_UTILS", "")
+            assert os.path.exists(runner_utils),\
+                f"{runner_utils} does not exist. Please pass a valid value for the MLIR_RUNNER_UTILS environment variable."
+
+            engine = ExecutionEngine(main_module_with_benchmark, opt_level=3, shared_libs=[c_runner_utils, runner_utils])
+            return engine.invoke
+
+    def runner(engine_invoke):
+        compiled_program_args = []
+        for argument_type in [result_type, param1_type, param2_type, result_type]:
+            argument_type_str = str(argument_type)
+            dimensions_str = re.sub("<|>|tensor", "", argument_type_str)
+            dimensions = [int(dim) for dim in dimensions_str.split("x")[:-1]]
+            if argument_type == result_type:
+                argument = np.zeros(dimensions, np.float64)
+            else:
+                argument = np.random.uniform(low=0.0, high=100.0, size=dimensions)
+            compiled_program_args.append(ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(argument))))
+        np_timers_ns = np.array([0], dtype=np.int64)
+        compiled_program_args.append(ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_timers_ns))))
+        engine_invoke("main", *compiled_program_args)
+        return np_timers_ns[0]
+
+    return BenchmarkRunConfig(
+        compiler=compiler,
+        runner=runner,
+    )
diff --git a/mlir/benchmark/python/common.py b/mlir/benchmark/python/common.py
new file mode 100644
--- /dev/null
+++ b/mlir/benchmark/python/common.py
@@ -0,0 +1,57 @@
+from mlir import ir
+from mlir.dialects import arith
+from mlir.dialects import builtin
+from mlir.dialects import memref
+from mlir.dialects import scf
+from mlir.dialects import std
+
+
+def get_kernel_func_from_module(module: ir.Module) -> builtin.FuncOp:
+    assert len(module.operation.regions) == 1, \
+        "Expected kernel module to have only one region"
+    assert len(module.operation.regions[0].blocks) == 1, \
+        "Expected kernel module to have only one block"
+    assert len(module.operation.regions[0].blocks[0].operations) == 1, \
+        "Expected kernel module to have only one operation"
+    return module.operation.regions[0].blocks[0].operations[0]
+
+
+def emit_timer_func() -> builtin.FuncOp:
+    i64_type = ir.IntegerType.get_signless(64)
+    nano_time = builtin.FuncOp(
+        "nano_time", ([], [i64_type]), visibility="private")
+    nano_time.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+    return nano_time
+
+
+def emit_benchmark_wrapped_main_func(func: builtin.FuncOp, timer_func: builtin.FuncOp) -> builtin.FuncOp:
+    i64_type = ir.IntegerType.get_signless(64)
+    memref_of_i64_type = ir.MemRefType.get([-1], i64_type)
+    wrapped_func = builtin.FuncOp(
+        # Same signature as the kernel, plus an extra buffer to store the timings.
+ "main", + (func.arguments.types + [memref_of_i64_type], func.type.results), + visibility="public" + ) + wrapped_func.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + + num_results = len(func.type.results) + with ir.InsertionPoint(wrapped_func.add_entry_block()): + timer_buffer = wrapped_func.arguments[-1] + zero = arith.ConstantOp.create_index(0) + n_iterations = memref.DimOp(ir.IndexType.get(), timer_buffer, zero) + one = arith.ConstantOp.create_index(1) + iter_args = list(wrapped_func.arguments[-num_results - 1:-1]) + loop = scf.ForOp(zero, n_iterations, one, iter_args) + with ir.InsertionPoint(loop.body): + start = std.CallOp(timer_func, []) + call = std.CallOp( + func, wrapped_func.arguments[:-num_results - 1] + loop.inner_iter_args + ) + end = std.CallOp(timer_func, []) + time_taken = arith.SubIOp(end, start) + memref.StoreOp(time_taken, timer_buffer, [loop.induction_variable]) + scf.YieldOp(list(call.results)) + std.ReturnOp(loop) + + return wrapped_func diff --git a/mlir/utils/mbr/mbr/__init__.py b/mlir/utils/mbr/mbr/__init__.py new file mode 100644 --- /dev/null +++ b/mlir/utils/mbr/mbr/__init__.py @@ -0,0 +1,8 @@ +import dataclasses +import typing + + +@dataclasses.dataclass +class BenchmarkRunConfig: + compiler: typing.Callable + runner: typing.Callable diff --git a/mlir/utils/mbr/mbr/discovery.py b/mlir/utils/mbr/mbr/discovery.py new file mode 100644 --- /dev/null +++ b/mlir/utils/mbr/mbr/discovery.py @@ -0,0 +1,24 @@ +import importlib +import os +import pathlib +import sys +import types + + +def discover_benchmark_modules(top_level_path): + benchmark_files = pathlib.Path(top_level_path).rglob("benchmark_*.py") + for benchmark_filename in benchmark_files: + benchmark_abs_dir = os.path.abspath(os.path.dirname(benchmark_filename)) + sys.path.append(benchmark_abs_dir) + module_file_name = os.path.basename(benchmark_filename) + module_name = module_file_name.replace(".py", "") + module = importlib.import_module(module_name) + yield module + sys.path.pop() + + +def get_benchmark_functions(module): + for attribute_name in dir(module): + attribute = getattr(module, attribute_name) + if isinstance(attribute, types.FunctionType) and attribute_name.startswith("benchmark_"): + yield attribute diff --git a/mlir/utils/mbr/mbr/main.py b/mlir/utils/mbr/mbr/main.py new file mode 100644 --- /dev/null +++ b/mlir/utils/mbr/mbr/main.py @@ -0,0 +1,108 @@ +import argparse +import datetime +import json +import os +import time + +import numpy as np + +from urllib import error as urlerror +from urllib import parse as urlparse +from urllib import request + +from discovery import discover_benchmark_modules, get_benchmark_functions +from stats import has_enough_measurements + + +def main(top_level_directory): + if not os.path.exists(top_level_directory): + raise AssertionError(f"The top-level directory {top_level_directory} doesn't exist.") + + modules = [module for module in discover_benchmark_modules(top_level_directory)] + benchmark_dicts = [] + for module in modules: + benchmark_functions = [function for function in get_benchmark_functions(module)] + for benchmark_function in benchmark_functions: + benchmark_run_config = benchmark_function() + measurements_ns = np.array([]) + if benchmark_run_config.compiler: + start_compile_time_s = time.time() + compiled_callable = benchmark_run_config.compiler() + total_compile_time_s = time.time() - start_compile_time_s + runner_args = (compiled_callable,) + else: + total_compile_time_s = 0 + runner_args = () + while not 
diff --git a/mlir/utils/mbr/mbr/stats.py b/mlir/utils/mbr/mbr/stats.py
new file mode 100644
--- /dev/null
+++ b/mlir/utils/mbr/mbr/stats.py
@@ -0,0 +1,12 @@
+import numpy as np
+
+
+MAX_NUMBER_OF_MEASUREMENTS = 1e9  # 1 billion
+MAX_TIME_FOR_A_BENCHMARK_NS = 1e11  # 100 seconds
+
+
+def has_enough_measurements(measurements):
+    return (
+        np.sum(measurements) >= MAX_TIME_FOR_A_BENCHMARK_NS or
+        np.size(measurements) >= MAX_NUMBER_OF_MEASUREMENTS
+    )
diff --git a/mlir/utils/mbr/requirements.txt b/mlir/utils/mbr/requirements.txt
new file mode 100644
diff --git a/mlir/utils/mbr/setup.py b/mlir/utils/mbr/setup.py
new file mode 100644
--- /dev/null
+++ b/mlir/utils/mbr/setup.py
@@ -0,0 +1,14 @@
+from setuptools import setup
+from setuptools import find_packages
+
+
+setup(
+    name="mbr",
+    version="1.0.0",
+    packages=find_packages(),
+    entry_points={
+        "console_scripts": [
+            "mbr = mbr.main:main",
+        ],
+    },
+)
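Not part of the patch, but to illustrate the discovery contract implemented above: any benchmark_*.py file found under the top-level directory is imported, its benchmark_* functions are called, and each is expected to return a BenchmarkRunConfig. A minimal sketch follows; the file name, function name, and sleep workload are purely hypothetical, and compiler is left as None since this toy benchmark needs no compilation step.

# benchmark_example.py (hypothetical file placed under the top-level directory)
import time

from mbr import BenchmarkRunConfig


def benchmark_example_sleep():
    def runner():
        # Return the measured duration in nanoseconds; main.py converts it to seconds.
        start_ns = time.perf_counter_ns()
        time.sleep(0.001)
        return time.perf_counter_ns() - start_ns

    # A falsy compiler makes main.py skip the compile step and call runner() with no arguments.
    return BenchmarkRunConfig(compiler=None, runner=runner)

With the mbr package importable (for example, installed from mlir/utils/mbr) and the MLIR Python bindings on PYTHONPATH, running python mlir/utils/mbr/mbr/main.py --machine <name> --revision <rev> --result-stdout <top-level-directory> should discover such a file alongside benchmark_sparse.py, repeat each runner until has_enough_measurements is satisfied, and print the LNT report.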