diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -385,9 +385,13 @@
                   step),
         affineMap(map) {}
 
-  intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); }
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyAffineMapExprList, PyAffineExpr>;
+
+  intptr_t getRawNumElements() { return mlirAffineMapGetNumResults(affineMap); }
 
-  PyAffineExpr getElement(intptr_t pos) {
+  PyAffineExpr getRawElement(intptr_t pos) {
     return PyAffineExpr(affineMap.getContext(),
                         mlirAffineMapGetResult(affineMap, pos));
   }
@@ -397,7 +401,6 @@
     return PyAffineMapExprList(affineMap, startIndex, length, step);
   }
 
-private:
   PyAffineMap affineMap;
 };
 } // namespace
@@ -460,9 +463,13 @@
                   step),
         set(set) {}
 
-  intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); }
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint>;
+
+  intptr_t getRawNumElements() { return mlirIntegerSetGetNumConstraints(set); }
 
-  PyIntegerSetConstraint getElement(intptr_t pos) {
+  PyIntegerSetConstraint getRawElement(intptr_t pos) {
     return PyIntegerSetConstraint(set, pos);
   }
 
@@ -471,7 +478,6 @@
     return PyIntegerSetConstraintList(set, startIndex, length, step);
   }
 
-private:
   PyIntegerSet set;
 };
 } // namespace
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1968,8 +1968,8 @@
 static std::vector<PyType> getValueTypes(Container &container,
                                          PyMlirContextRef &context) {
   std::vector<PyType> result;
-  result.reserve(container.getNumElements());
-  for (int i = 0, e = container.getNumElements(); i < e; ++i) {
+  result.reserve(container.size());
+  for (intptr_t i = 0, e = container.size(); i < e; ++i) {
     result.push_back(
         PyType(context, mlirValueGetType(container.getElement(i).get())));
   }
@@ -1993,14 +1993,24 @@
                   step),
         operation(std::move(operation)), block(block) {}
 
+  static void bindDerived(ClassTy &c) {
+    c.def_property_readonly("types", [](PyBlockArgumentList &self) {
+      return getValueTypes(self, self.operation->getContext());
+    });
+  }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
+
   /// Returns the number of arguments in the list.
-  intptr_t getNumElements() {
+  intptr_t getRawNumElements() {
     operation->checkValid();
     return mlirBlockGetNumArguments(block);
   }
 
-  /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
-  PyBlockArgument getElement(intptr_t pos) {
+  /// Returns the `pos`-th element in the list.
+  PyBlockArgument getRawElement(intptr_t pos) {
     MlirValue argument = mlirBlockGetArgument(block, pos);
     return PyBlockArgument(operation, argument);
   }
@@ -2011,13 +2021,6 @@
     return PyBlockArgumentList(operation, block, startIndex, length, step);
   }
 
-  static void bindDerived(ClassTy &c) {
-    c.def_property_readonly("types", [](PyBlockArgumentList &self) {
-      return getValueTypes(self, self.operation->getContext());
-    });
-  }
-
-private:
   PyOperationRef operation;
   MlirBlock block;
 };
@@ -2038,12 +2041,31 @@
                   step),
         operation(operation) {}
 
-  intptr_t getNumElements() {
+  /// Sets the `index`-th operand of this slice to `value`. Supports negative
+  /// indices; throws IndexError if out of bounds.
+  void dunderSetItem(intptr_t index, PyValue value) {
+    index = wrapIndex(index);
+    if (index < 0)
+      throw pybind11::index_error("index out of range");
+    // `index` is relative to the slice; map it onto the full operand list.
+    mlirOperationSetOperand(operation->get(), linearizeIndex(index),
+                            value.get());
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__setitem__", &PyOpOperandList::dunderSetItem);
+  }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyOpOperandList, PyValue>;
+
+  intptr_t getRawNumElements() {
     operation->checkValid();
     return mlirOperationGetNumOperands(operation->get());
   }
 
-  PyValue getElement(intptr_t pos) {
+  PyValue getRawElement(intptr_t pos) {
     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
     MlirOperation owner;
     if (mlirValueIsAOpResult(operand))
@@ -2061,16 +2083,6 @@
     return PyOpOperandList(operation, startIndex, length, step);
   }
 
-  void dunderSetItem(intptr_t index, PyValue value) {
-    index = wrapIndex(index);
-    mlirOperationSetOperand(operation->get(), index, value.get());
-  }
-
-  static void bindDerived(ClassTy &c) {
-    c.def("__setitem__", &PyOpOperandList::dunderSetItem);
-  }
-
-private:
   PyOperationRef operation;
 };
 
@@ -2090,12 +2102,22 @@
                   step),
         operation(operation) {}
 
-  intptr_t getNumElements() {
+  static void bindDerived(ClassTy &c) {
+    c.def_property_readonly("types", [](PyOpResultList &self) {
+      return getValueTypes(self, self.operation->getContext());
+    });
+  }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyOpResultList, PyOpResult>;
+
+  intptr_t getRawNumElements() {
     operation->checkValid();
     return mlirOperationGetNumResults(operation->get());
   }
 
-  PyOpResult getElement(intptr_t index) {
+  PyOpResult getRawElement(intptr_t index) {
     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
     return PyOpResult(value);
   }
@@ -2104,13 +2126,6 @@
     return PyOpResultList(operation, startIndex, length, step);
   }
 
-  static void bindDerived(ClassTy &c) {
-    c.def_property_readonly("types", [](PyOpResultList &self) {
-      return getValueTypes(self, self.operation->getContext());
-    });
-  }
-
-private:
   PyOperationRef operation;
 };
 
diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -199,15 +199,16 @@
 /// A derived class must provide the following:
 ///   - a `static const char *pyClassName ` field containing the name of the
 ///     Python class to bind;
-///   - an instance method `intptr_t getNumElements()` that returns the number
-///     of elements in the backing container (NOT that of the slice);
-///   - an instance method `ElementTy getElement(intptr_t)` that returns a
-///     single element at the given index.
+///   - an instance method `intptr_t getRawNumElements()` that returns the
+///     number of elements in the backing container (NOT that of the slice);
+///   - an instance method `ElementTy getRawElement(intptr_t)` that returns a
+///     single element at the given linear index (NOT slice index);
 ///   - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that
 ///     constructs a new instance of the derived pseudo-container with the
 ///     given slice parameters (to be forwarded to the Sliceable constructor).
 ///
-/// The getNumElements() and getElement(intptr_t) callbacks must not throw.
+/// The getRawNumElements() and getRawElement(intptr_t) callbacks must not
+/// throw.
 ///
 /// A derived class may additionally define:
 ///   - a `static void bindDerived(ClassTy &)` method to bind additional methods
@@ -217,8 +218,8 @@
 protected:
   using ClassTy = pybind11::class_<Derived>;
 
-  // Transforms `index` into a legal value to access the underlying sequence.
-  // Returns <0 on failure.
+  /// Transforms `index` into a legal value to access the underlying sequence.
+  /// Returns <0 on failure.
   intptr_t wrapIndex(intptr_t index) {
     if (index < 0)
       index = length + index;
@@ -227,6 +228,15 @@
     return index;
   }
 
+  /// Computes the linear index given the current slice properties.
+  intptr_t linearizeIndex(intptr_t index) {
+    intptr_t linearIndex = index * step + startIndex;
+    assert(linearIndex >= 0 &&
+           linearIndex < static_cast<Derived *>(this)->getRawNumElements() &&
+           "linear index out of bounds, the slice is ill-formed");
+    return linearIndex;
+  }
+
   /// Returns the element at the given slice index. Supports negative indices
   /// by taking elements in inverse order. Returns a nullptr object if out
   /// of bounds.
@@ -238,13 +248,8 @@
       return {};
     }
 
-    // Compute the linear index given the current slice properties.
-    int linearIndex = index * step + startIndex;
-    assert(linearIndex >= 0 &&
-           linearIndex < static_cast<Derived *>(this)->getNumElements() &&
-           "linear index out of bounds, the slice is ill-formed");
     return pybind11::cast(
-        static_cast<Derived *>(this)->getElement(linearIndex));
+        static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
   }
 
   /// Returns a new instance of the pseudo-container restricted to the given
@@ -266,6 +271,21 @@
     assert(length >= 0 && "expected non-negative slice length");
   }
 
+  /// Returns the `index`-th element in the slice, supports negative indices.
+  /// Throws if the index is out of bounds.
+  ElementTy getElement(intptr_t index) {
+    // Negative indices mean we count from the end.
+    index = wrapIndex(index);
+    if (index < 0) {
+      throw pybind11::index_error("index out of range");
+    }
+
+    return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index));
+  }
+
+  /// Returns the size of the slice.
+  intptr_t size() { return length; }
+
   /// Returns a new vector (mapped to Python list) containing elements from two
   /// slices. The new vector is necessary because slices may not be contiguous
   /// or even come from the same original sequence.
@@ -276,7 +296,7 @@
       elements.push_back(static_cast<Derived *>(this)->getElement(i));
     }
     for (intptr_t i = 0; i < other.length; ++i) {
-      elements.push_back(static_cast<Derived *>(this)->getElement(i));
+      elements.push_back(static_cast<Derived *>(&other)->getElement(i));
     }
     return elements;
   }
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -185,6 +185,19 @@
     for t in entry_block.arguments.types:
       print("Type: ", t)
 
+    # Check that slicing and type access compose.
+    # CHECK: Sliced type: i16
+    # CHECK: Sliced type: i24
+    for t in entry_block.arguments[1:].types:
+      print("Sliced type: ", t)
+
+    # Check that slice addition works as expected.
+    # CHECK: Argument 2, type i24
+    # CHECK: Argument 0, type i8
+    restructured = entry_block.arguments[-1:] + entry_block.arguments[:1]
+    for arg in restructured:
+      print(f"Argument {arg.arg_number}, type {arg.type}")
+
 
 # CHECK-LABEL: TEST: testOperationOperands
 @run