diff --git a/mlir/python/mlir/dialects/_builtin_ops_ext.py b/mlir/python/mlir/dialects/_builtin_ops_ext.py
--- a/mlir/python/mlir/dialects/_builtin_ops_ext.py
+++ b/mlir/python/mlir/dialects/_builtin_ops_ext.py
@@ -3,7 +3,7 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from typing import Optional, Sequence
+  from typing import Optional, Sequence, Union
 
   import inspect
 
@@ -82,8 +82,8 @@
     return self.attributes["sym_visibility"]
 
   @property
-  def name(self):
-    return self.attributes["sym_name"]
+  def name(self) -> StringAttr:
+    return StringAttr(self.attributes["sym_name"])
 
   @property
   def entry_block(self):
@@ -104,11 +104,15 @@
 
   @property
   def arg_attrs(self):
-    return self.attributes[ARGUMENT_ATTRIBUTE_NAME]
+    return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
 
   @arg_attrs.setter
-  def arg_attrs(self, attribute: ArrayAttr):
-    self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
+  def arg_attrs(self, attribute: Union[ArrayAttr, list]):
+    if isinstance(attribute, ArrayAttr):
+      self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
+    else:
+      self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
+          attribute, context=self.context)
 
   @property
   def arguments(self):
diff --git a/mlir/python/mlir/dialects/_std_ops_ext.py b/mlir/python/mlir/dialects/_std_ops_ext.py
--- a/mlir/python/mlir/dialects/_std_ops_ext.py
+++ b/mlir/python/mlir/dialects/_std_ops_ext.py
@@ -69,3 +69,75 @@
       return FloatAttr(self.value).value
     else:
       raise ValueError("only integer and float constants have literal values")
+
+
+class CallOp:
+  """Specialization for the call op class."""
+
+  def __init__(self,
+               calleeOrResults: Union[FuncOp, List[Type]],
+               argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
+               arguments: Optional[List] = None,
+               *,
+               loc=None,
+               ip=None):
+    """Creates a call operation.
+
+    The constructor accepts three different forms:
+
+      1. A function op to be called followed by a list of arguments.
+      2. A list of result types, followed by the name of the function to be
+         called as string, followed by a list of arguments.
+      3. A list of result types, followed by the name of the function to be
+         called as symbol reference attribute, followed by a list of arguments.
+
+    For example
+
+        f = builtin.FuncOp("foo", ...)
+        std.CallOp(f, [args])
+        std.CallOp([result_types], "foo", [args])
+        std.CallOp([result_types], FlatSymbolRefAttr.get("foo"), [args])
+
+    In all cases, the location and insertion point may be specified as keyword
+    arguments if not provided by the surrounding context managers.
+
+    Raises:
+      ValueError: if the positional arguments do not match any of the forms
+        above.
+    """
+
+    # TODO: consider supporting constructor "overloads", e.g., through a custom
+    # or pybind-provided metaclass.
+    if isinstance(calleeOrResults, FuncOp):
+      if not isinstance(argumentsOrCallee, list):
+        raise ValueError(
+            "when constructing a call to a function, expected " +
+            "the second argument to be a list of call arguments, " +
+            f"got {type(argumentsOrCallee)}")
+      if arguments is not None:
+        raise ValueError("unexpected third argument when constructing a " +
+                         "call to a function")
+
+      super().__init__(
+          calleeOrResults.type.results,
+          FlatSymbolRefAttr.get(
+              calleeOrResults.name.value,
+              context=_get_default_loc_context(loc)),
+          argumentsOrCallee,
+          loc=loc,
+          ip=ip)
+      return
+
+    if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
+      callee = argumentsOrCallee
+    elif isinstance(argumentsOrCallee, str):
+      callee = FlatSymbolRefAttr.get(
+          argumentsOrCallee, context=_get_default_loc_context(loc))
+    else:
+      # Reject lists and any other type explicitly; falling through would
+      # leave the op without ever calling the generated builder.
+      raise ValueError("when constructing a call to a function by name, " +
+                       "expected the second argument to be a string or a " +
+                       f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}")
+
+    super().__init__(calleeOrResults, callee, arguments, loc=loc, ip=ip)
diff --git a/mlir/test/python/dialects/builtin.py b/mlir/test/python/dialects/builtin.py
--- a/mlir/test/python/dialects/builtin.py
+++ b/mlir/test/python/dialects/builtin.py
@@ -171,7 +171,7 @@
     f32 = F32Type.get()
     f64 = F64Type.get()
     with InsertionPoint(module.body):
-      func = builtin.FuncOp("some_func", ([f32, f32], [f64, f64]))
+      func = builtin.FuncOp("some_func", ([f32, f32], [f32, f32]))
       with InsertionPoint(func.add_entry_block()):
         std.ReturnOp(func.arguments)
       func.arg_attrs = ArrayAttr.get([
@@ -186,6 +186,14 @@
           DictAttr.get({"res2": FloatAttr.get(f64, 256.0)})
       ])
 
+      other = builtin.FuncOp("other_func", ([f32, f32], []))
+      with InsertionPoint(other.add_entry_block()):
+        std.ReturnOp([])
+      other.arg_attrs = [
+          DictAttr.get({"foo": StringAttr.get("qux")}),
+          DictAttr.get()
+      ]
+
     # CHECK: [{baz, foo = "bar"}, {qux = []}]
     print(func.arg_attrs)
 
@@ -195,7 +203,11 @@
     # CHECK: func @some_func(
     # CHECK: %[[ARG0:.*]]: f32 {baz, foo = "bar"},
     # CHECK: %[[ARG1:.*]]: f32 {qux = []}) ->
-    # CHECK: f64 {res1 = 4.200000e+01 : f32},
-    # CHECK: f64 {res2 = 2.560000e+02 : f64})
+    # CHECK: f32 {res1 = 4.200000e+01 : f32},
+    # CHECK: f32 {res2 = 2.560000e+02 : f64})
     # CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
+    #
+    # CHECK: func @other_func(
+    # CHECK: %{{.*}}: f32 {foo = "qux"},
+    # CHECK: %{{.*}}: f32)
     print(module)
diff --git a/mlir/test/python/dialects/std.py b/mlir/test/python/dialects/std.py
--- a/mlir/test/python/dialects/std.py
+++ b/mlir/test/python/dialects/std.py
@@ -1,6 +1,7 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 from mlir.ir import *
+from mlir.dialects import builtin
 from mlir.dialects import std
 
 
@@ -62,3 +63,39 @@
   print(c1.literal_value)
 
 # CHECK: = constant 10 : index
+
+# CHECK-LABEL: TEST: testFunctionCalls
+@constructAndPrintInModule
+def testFunctionCalls():
+  foo = builtin.FuncOp("foo", ([], []))
+  bar = builtin.FuncOp("bar", ([], [IndexType.get()]))
+  qux = builtin.FuncOp("qux", ([], [F32Type.get()]))
+
+  with InsertionPoint(builtin.FuncOp("caller", ([], [])).add_entry_block()):
+    std.CallOp(foo, [])
+    std.CallOp([IndexType.get()], "bar", [])
+    std.CallOp([F32Type.get()], FlatSymbolRefAttr.get("qux"), [])
+    std.ReturnOp([])
+
+  # CHECK: unexpected third argument when constructing a call to a function
+  try:
+    std.CallOp(foo, [], [])
+  except ValueError as e:
+    print(e)
+
+  # CHECK: expected the second argument to be a string or a FlatSymbolRefAttr
+  try:
+    std.CallOp([IndexType.get()], 42, [])
+  except ValueError as e:
+    print(e)
+
+# CHECK: func @foo()
+# CHECK: func @bar() -> index
+# CHECK: func @qux() -> f32
+# CHECK: func @caller() {
+# CHECK:   call @foo() : () -> ()
+# CHECK:   %0 = call @bar() : () -> index
+# CHECK:   %1 = call @qux() : () -> f32
+# CHECK:   return
+# CHECK: }
+