diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py b/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py --- a/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py @@ -1,6 +1,11 @@ # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Optional, Sequence + +import inspect + from ..ir import * @@ -93,3 +98,99 @@ raise IndexError('The function already has an entry block!') self.body.blocks.append(*self.type.inputs) return self.body.blocks[0] + + @classmethod + def from_py_func(cls, + *inputs: Type, + results: Optional[Sequence[Type]] = None, + name: Optional[str] = None): + """Decorator to define an MLIR FuncOp specified as a python function. + + Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are + active for the current thread (i.e. established in a `with` block). + + When applied as a decorator to a Python function, an entry block will + be constructed for the FuncOp with types as specified in `*inputs`. The + block arguments will be passed positionally to the Python function. In + addition, if the Python function accepts keyword arguments generally or + has a corresponding keyword argument, the following will be passed: + * `func_op`: The `func` op being defined. + + By default, the function name will be the Python function `__name__`. This + can be overridden by passing the `name` argument to the decorator. + + If `results` is not specified, then the decorator will implicitly + insert a `ReturnOp` with the `Value`s returned from the decorated + function. It will also set the `FuncOp` type with the actual return + value types. If `results` is specified, then the decorated function + must return `None` and no implicit `ReturnOp` is added (nor are the result + types updated). 
The implicit behavior is intended for simple, single-block + cases, and users should specify result types explicitly for any complicated + cases. + + The decorated function can further be called from Python and will insert + a `CallOp` at the then-current insertion point, returning None (no results), + a single Value (one result), or a sequence of Values (multiple results). + This mechanism cannot be used to emit recursive calls (by construction). + """ + + def decorator(f): + from . import std + # Introspect the callable for optional features. + sig = inspect.signature(f) + has_arg_func_op = False + for param in sig.parameters.values(): + if param.kind == param.VAR_KEYWORD: + has_arg_func_op = True + if param.name == "func_op" and (param.kind + == param.POSITIONAL_OR_KEYWORD or + param.kind == param.KEYWORD_ONLY): + has_arg_func_op = True + + # Emit the FuncOp. + implicit_return = results is None + symbol_name = name or f.__name__ + function_type = FunctionType.get( + inputs=inputs, results=[] if implicit_return else results) + func_op = cls(name=symbol_name, type=function_type) + with InsertionPoint(func_op.add_entry_block()): + func_args = func_op.entry_block.arguments + func_kwargs = {} + if has_arg_func_op: + func_kwargs["func_op"] = func_op + return_values = f(*func_args, **func_kwargs) + if not implicit_return: + return_types = list(results) + assert return_values is None, ( + "Capturing a python function with explicit `results=` " + "requires that the wrapped function returns None.") + else: + # Coerce return values, add ReturnOp and rewrite func type. + if return_values is None: + return_values = [] + elif isinstance(return_values, Value): + return_values = [return_values] + else: + return_values = list(return_values) + std.ReturnOp(return_values) + # Recompute the function type. 
+ return_types = [v.type for v in return_values] + function_type = FunctionType.get(inputs=inputs, results=return_types) + func_op.attributes["type"] = TypeAttr.get(function_type) + + def emit_call_op(*call_args): + call_op = std.CallOp(return_types, FlatSymbolRefAttr.get(symbol_name), + call_args) + if not return_types:  # Always a list here; empty means no results. + return None + elif len(return_types) == 1: + return call_op.result + else: + return call_op.results + + wrapped = emit_call_op + wrapped.__name__ = f.__name__ + wrapped.func_op = func_op + return wrapped + + return decorator diff --git a/mlir/test/Bindings/Python/dialects/builtin.py b/mlir/test/Bindings/Python/dialects/builtin.py --- a/mlir/test/Bindings/Python/dialects/builtin.py +++ b/mlir/test/Bindings/Python/dialects/builtin.py @@ -8,9 +8,106 @@ def run(f): print("\nTEST:", f.__name__) f() + return f + + +# CHECK-LABEL: TEST: testFromPyFunc +@run +def testFromPyFunc(): + with Context() as ctx, Location.unknown() as loc: + m = builtin.ModuleOp() + f32 = F32Type.get() + f64 = F64Type.get() + with InsertionPoint.at_block_terminator(m.body): + # CHECK-LABEL: func @unary_return(%arg0: f64) -> f64 + # CHECK: return %arg0 : f64 + @builtin.FuncOp.from_py_func(f64) + def unary_return(a): + return a + + # CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64) + # CHECK: return %arg0, %arg1 : f32, f64 + @builtin.FuncOp.from_py_func(f32, f64) + def binary_return(a, b): + return a, b + + # CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64) + # CHECK: return + @builtin.FuncOp.from_py_func(f32, f64) + def none_return(a, b): + pass + + # CHECK-LABEL: func @call_unary + # CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64 + # CHECK: return %0 : f64 + @builtin.FuncOp.from_py_func(f64) + def call_unary(a): + return unary_return(a) + + # CHECK-LABEL: func @call_binary + # CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64) + # CHECK: return %0#0, %0#1 : f32, f64 + @builtin.FuncOp.from_py_func(f32, f64) + def 
call_binary(a, b): + return binary_return(a, b) + + # CHECK-LABEL: func @call_none + # CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> () + # CHECK: return + @builtin.FuncOp.from_py_func(f32, f64) + def call_none(a, b): + return none_return(a, b) + + ## Variants: explicit `name=`, receiving `func_op`, and explicit `results=`. + # CHECK-LABEL: func @from_name_arg + @builtin.FuncOp.from_py_func(f32, f64, name="from_name_arg") + def explicit_name(a, b): + return b + + @builtin.FuncOp.from_py_func(f32, f64) + def positional_func_op(a, b, func_op): + assert isinstance(func_op, builtin.FuncOp) + return b + + @builtin.FuncOp.from_py_func(f32, f64) + def kw_func_op(a, b=None, func_op=None): + assert isinstance(func_op, builtin.FuncOp) + return b + + @builtin.FuncOp.from_py_func(f32, f64) + def kwargs_func_op(a, b=None, **kwargs): + assert isinstance(kwargs["func_op"], builtin.FuncOp) + return b + + # CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64 + # CHECK: return %arg1 : f64 + @builtin.FuncOp.from_py_func(f32, f64, results=[f64]) + def explicit_results(a, b): + std.ReturnOp([b]) + + print(m) + + +# CHECK-LABEL: TEST: testFromPyFuncErrors +@run +def testFromPyFuncErrors(): + with Context() as ctx, Location.unknown() as loc: + m = builtin.ModuleOp() + f32 = F32Type.get() + f64 = F64Type.get() + with InsertionPoint.at_block_terminator(m.body): + try:  # Returning a value with explicit `results=` should raise AssertionError. + + @builtin.FuncOp.from_py_func(f64, results=[f64]) + def unary_return(a): + return a + except AssertionError as e: + # CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None. 
+ print(e) # CHECK-LABEL: TEST: testBuildFuncOp +@run def testBuildFuncOp(): ctx = Context() with Location.unknown(ctx) as loc: @@ -64,6 +161,3 @@ # CHECK: return %arg0 : tensor<2x3x4xf32> # CHECK: } print(m) - - -run(testBuildFuncOp) diff --git a/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py --- a/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py +++ b/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py @@ -10,46 +10,6 @@ from mlir.dialects.linalg.opdsl.lang import * -# TODO: Find a home for this quality of life helper. -def build_function(*inputs: Type, results: Optional[Sequence[Type]] = None): - """Decorator that emits a function in a more pythonic way. - - If result types are not specified, they are inferred from the function - returns. The `ReturnOp` is implicitly added upon the wrapped function return. - """ - - def decorator(f): - return_types = results - symbol_name = f.__name__ - function_type = FunctionType.get(inputs=inputs, results=results or []) - func_op = builtin.FuncOp(name=symbol_name, type=function_type) - with InsertionPoint(func_op.add_entry_block()): - func_args = func_op.entry_block.arguments - return_values = f(*func_args) - if return_values is None: - return_values = [] - elif isinstance(return_values, Value): - return_values = [return_values] - else: - return_values = list(return_values) - std.ReturnOp(return_values) - if return_types is None: - # Recompute the function type. - return_types = [v.type for v in return_values] - function_type = FunctionType.get(inputs=inputs, results=return_types) - # TODO: Have an API or a setter for this. - func_op.attributes["type"] = TypeAttr.get(function_type) - - # TODO: When turning this into a real facility, return a function that emits - # a `call` to the function instead of doing nothing. 
- wrapped = lambda: None - wrapped.__name__ = symbol_name - wrapped.func_op = func_op - return wrapped - - return decorator - - @linalg_structured_op def matmul_mono(A=TensorDef(T, S.M, S.K), B=TensorDef(T, S.K, S.N), @@ -92,8 +52,8 @@ # CHECK-SAME: ins(%[[A]], %[[B]] # CHECK-SAME: outs(%[[INITC]] - @build_function(RankedTensorType.get((4, 16), f32), - RankedTensorType.get((16, 8), f32)) + @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32), + RankedTensorType.get((16, 8), f32)) def test_matmul_mono(lhs, rhs): # TODO: Enable outs inference and add sugar for InitTensorOp # construction. @@ -114,9 +74,9 @@ # CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32 # CHECK-NEXT: linalg.yield %[[ADD]] : i32 # CHECK-NEXT: -> tensor<4x8xi32> - @build_function(RankedTensorType.get((4, 16), i8), - RankedTensorType.get((16, 8), i8), - RankedTensorType.get((4, 8), i32)) + @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i8), + RankedTensorType.get((16, 8), i8), + RankedTensorType.get((4, 8), i32)) def test_i8i8i32_matmul(lhs, rhs, init_result): return matmul_poly(lhs, rhs, outs=[init_result]) @@ -128,9 +88,9 @@ # CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32 # CHECK-NEXT: linalg.yield %[[ADD]] : i32 # CHECK-NEXT: -> tensor<4x8xi32> - @build_function(RankedTensorType.get((4, 16), i8), - RankedTensorType.get((16, 8), i16), - RankedTensorType.get((4, 8), i32)) + @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i8), + RankedTensorType.get((16, 8), i16), + RankedTensorType.get((4, 8), i32)) def test_i8i16i32_matmul(lhs, rhs, init_result): return matmul_poly(lhs, rhs, outs=[init_result]) @@ -142,9 +102,9 @@ # CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i16 # CHECK-NEXT: linalg.yield %[[ADD]] : i16 # CHECK-NEXT: -> tensor<4x8xi16> - @build_function(RankedTensorType.get((4, 16), i32), - RankedTensorType.get((16, 8), i32), - RankedTensorType.get((4, 8), i16)) + @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i32), + 
RankedTensorType.get((16, 8), i32), + RankedTensorType.get((4, 8), i16)) def test_i32i32i16_matmul(lhs, rhs, init_result): return matmul_poly(lhs, rhs, outs=[init_result]) @@ -156,9 +116,9 @@ # CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32 # CHECK-NEXT: linalg.yield %[[ADD]] : f32 # CHECK-NEXT: -> tensor<4x8xf32> - @build_function(RankedTensorType.get((4, 16), i8), - RankedTensorType.get((16, 8), i8), - RankedTensorType.get((4, 8), f32)) + @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i8), + RankedTensorType.get((16, 8), i8), + RankedTensorType.get((4, 8), f32)) def test_i8i8f32_matmul(lhs, rhs, init_result): return matmul_poly(lhs, rhs, outs=[init_result]) @@ -170,9 +130,9 @@ # CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32 # CHECK-NEXT: linalg.yield %[[ADD]] : f32 # CHECK-NEXT: -> tensor<4x8xf32> - @build_function(RankedTensorType.get((4, 16), f16), - RankedTensorType.get((16, 8), f16), - RankedTensorType.get((4, 8), f32)) + @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f16), + RankedTensorType.get((16, 8), f16), + RankedTensorType.get((4, 8), f32)) def test_f16f16f32_matmul(lhs, rhs, init_result): return matmul_poly(lhs, rhs, outs=[init_result]) @@ -184,9 +144,9 @@ # CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32 # CHECK-NEXT: linalg.yield %[[ADD]] : f32 # CHECK-NEXT: -> tensor<4x8xf32> - @build_function(RankedTensorType.get((4, 16), f64), - RankedTensorType.get((16, 8), f64), - RankedTensorType.get((4, 8), f32)) + @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f64), + RankedTensorType.get((16, 8), f64), + RankedTensorType.get((4, 8), f32)) def test_f64f64f32_matmul(lhs, rhs, init_result): return matmul_poly(lhs, rhs, outs=[init_result])