diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1594,32 +1594,35 @@ /// elements, random access is cheap. The argument list is associated with the /// operation that contains the block (detached blocks are not allowed in /// Python bindings) and extends its lifetime. -class PyBlockArgumentList { +class PyBlockArgumentList + : public Sliceable { public: - PyBlockArgumentList(PyOperationRef operation, MlirBlock block) - : operation(std::move(operation)), block(block) {} + static constexpr const char *pyClassName = "BlockArgumentList"; - /// Returns the length of the block argument list. - intptr_t dunderLen() { + PyBlockArgumentList(PyOperationRef operation, MlirBlock block, + intptr_t startIndex = 0, intptr_t length = -1, + intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? mlirBlockGetNumArguments(block) : length, + step), + operation(std::move(operation)), block(block) {} + + /// Returns the total number of arguments in the underlying block. + intptr_t getNumElements() { operation->checkValid(); return mlirBlockGetNumArguments(block); } - /// Returns `index`-th element of the block argument list. - PyBlockArgument dunderGetItem(intptr_t index) { - if (index < 0 || index >= dunderLen()) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds region"); - } - PyValue value(operation, mlirBlockGetArgument(block, index)); - return PyBlockArgument(value); + /// Returns the `pos`-th element in the list. Performs no bounds checking. + PyBlockArgument getElement(intptr_t pos) { + MlirValue argument = mlirBlockGetArgument(block, pos); + return PyBlockArgument(operation, argument); } - /// Defines a Python class in the bindings. 
- static void bind(py::module &m) { - py::class_(m, "BlockArgumentList", py::module_local()) - .def("__len__", &PyBlockArgumentList::dunderLen) - .def("__getitem__", &PyBlockArgumentList::dunderGetItem); + /// Returns a sublist of this list. + PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, + intptr_t step) { + return PyBlockArgumentList(operation, block, startIndex, length, step); } private: diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ b/mlir/lib/Bindings/Python/PybindUtils.h @@ -260,13 +260,29 @@ sliceLength, step * extraStep); } + /// Returns a new vector (mapped to Python list) containing elements from two + /// slices. The new vector is necessary because slices may not be contiguous + /// or even come from the same original sequence. + std::vector dunderAdd(Derived &other) { + std::vector elements; + elements.reserve(length + other.length); + for (intptr_t i = 0; i < length; ++i) { + elements.push_back(dunderGetItem(i)); + } + for (intptr_t i = 0; i < other.length; ++i) { + elements.push_back(other.dunderGetItem(i)); + } + return elements; + } + /// Binds the indexing and length methods in the Python class. 
static void bind(pybind11::module &m) { auto clazz = pybind11::class_(m, Derived::pyClassName, pybind11::module_local()) .def("__len__", &Sliceable::dunderLen) .def("__getitem__", &Sliceable::dunderGetItem) - .def("__getitem__", &Sliceable::dunderGetItemSlice); + .def("__getitem__", &Sliceable::dunderGetItemSlice) + .def("__add__", &Sliceable::dunderAdd); Derived::bindDerived(clazz); } diff --git a/mlir/python/mlir/dialects/_builtin_ops_ext.py b/mlir/python/mlir/dialects/_builtin_ops_ext.py --- a/mlir/python/mlir/dialects/_builtin_ops_ext.py +++ b/mlir/python/mlir/dialects/_builtin_ops_ext.py @@ -11,6 +11,8 @@ except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e +ARGUMENT_ATTRIBUTE_NAME = "arg_attrs" +RESULT_ATTRIBUTE_NAME = "res_attrs" class ModuleOp: """Specialization for the module op class.""" @@ -100,6 +102,26 @@ self.body.blocks.append(*self.type.inputs) return self.body.blocks[0] + @property + def arg_attrs(self): + return self.attributes[ARGUMENT_ATTRIBUTE_NAME] + + @arg_attrs.setter + def arg_attrs(self, attribute: ArrayAttr): + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute + + @property + def arguments(self): + return self.entry_block.arguments + + @property + def result_attrs(self): + return self.attributes[RESULT_ATTRIBUTE_NAME] + + @result_attrs.setter + def result_attrs(self, attribute: ArrayAttr): + self.attributes[RESULT_ATTRIBUTE_NAME] = attribute + @classmethod def from_py_func(FuncOp, *inputs: Type, diff --git a/mlir/test/python/dialects/builtin.py b/mlir/test/python/dialects/builtin.py --- a/mlir/test/python/dialects/builtin.py +++ b/mlir/test/python/dialects/builtin.py @@ -161,3 +161,41 @@ # CHECK: return %arg0 : tensor<2x3x4xf32> # CHECK: } print(m) + + +# CHECK-LABEL: TEST: testFuncArgumentAccess +@run +def testFuncArgumentAccess(): + with Context(), Location.unknown(): + module = Module.create() + f32 = F32Type.get() + f64 = F64Type.get() + with InsertionPoint(module.body): + func = 
builtin.FuncOp("some_func", ([f32, f32], [f32, f32])) + with InsertionPoint(func.add_entry_block()): + std.ReturnOp(func.arguments) + func.arg_attrs = ArrayAttr.get([ + DictAttr.get({ + "foo": StringAttr.get("bar"), + "baz": UnitAttr.get() + }), + DictAttr.get({"qux": ArrayAttr.get([])}) + ]) + func.result_attrs = ArrayAttr.get([ + DictAttr.get({"res1": FloatAttr.get(f32, 42.0)}), + DictAttr.get({"res2": FloatAttr.get(f64, 256.0)}) + ]) + + # CHECK: [{baz, foo = "bar"}, {qux = []}] + print(func.arg_attrs) + + # CHECK: [{res1 = 4.200000e+01 : f32}, {res2 = 2.560000e+02 : f64}] + print(func.result_attrs) + + # CHECK: func @some_func( + # CHECK: %[[ARG0:.*]]: f32 {baz, foo = "bar"}, + # CHECK: %[[ARG1:.*]]: f32 {qux = []}) -> + # CHECK: f32 {res1 = 4.200000e+01 : f32}, + # CHECK: f32 {res2 = 2.560000e+02 : f64}) + # CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32 + print(module) diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -134,6 +134,17 @@ for arg in entry_block.arguments: print(f"Argument {arg.arg_number}, type {arg.type}") + # Check that slicing works for block argument lists. + # CHECK: Argument 1, type i16 + # CHECK: Argument 2, type i24 + for arg in entry_block.arguments[1:]: + print(f"Argument {arg.arg_number}, type {arg.type}") + + # Check that we can concatenate slices of argument lists. 
+ # CHECK: Length: 4 + print("Length: ", + len(entry_block.arguments[:2] + entry_block.arguments[1:])) + run(testBlockArgumentList) @@ -598,22 +609,24 @@ ctx = Context() with Location.unknown(ctx): try: - Operation.create("builtin.module", attributes={None:StringAttr.get("name")}) + Operation.create( + "builtin.module", attributes={None: StringAttr.get("name")}) except Exception as e: # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module" print(e) try: - Operation.create("builtin.module", attributes={42:StringAttr.get("name")}) + Operation.create( + "builtin.module", attributes={42: StringAttr.get("name")}) except Exception as e: # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module" print(e) try: - Operation.create("builtin.module", attributes={"some_key":ctx}) + Operation.create("builtin.module", attributes={"some_key": ctx}) except Exception as e: # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module" print(e) try: - Operation.create("builtin.module", attributes={"some_key":None}) + Operation.create("builtin.module", attributes={"some_key": None}) except Exception as e: # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module" print(e)