diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -207,6 +207,8 @@
 ///     constructs a new instance of the derived pseudo-container with the
 ///     given slice parameters (to be forwarded to the Sliceable constructor).
 ///
+/// The getNumElements() and getElement(intptr_t) callbacks must not throw.
+///
 /// A derived class may additionally define:
 ///   - a `static void bindDerived(ClassTy &)` method to bind additional methods
 ///     the python class.
@@ -215,49 +217,53 @@
 protected:
   using ClassTy = pybind11::class_<Derived>;
 
+  // Transforms `index` into a legal value to access the underlying sequence.
+  // Returns <0 on failure.
   intptr_t wrapIndex(intptr_t index) {
     if (index < 0)
       index = length + index;
-    if (index < 0 || index >= length) {
-      throw python::SetPyError(PyExc_IndexError,
-                               "attempt to access out of bounds");
-    }
+    if (index < 0 || index >= length)
+      return -1;
     return index;
   }
 
-public:
-  explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
-      : startIndex(startIndex), length(length), step(step) {
-    assert(length >= 0 && "expected non-negative slice length");
-  }
-
-  /// Returns the length of the slice.
-  intptr_t dunderLen() const { return length; }
-
   /// Returns the element at the given slice index. Supports negative indices
-  /// by taking elements in inverse order. Throws if the index is out of bounds.
-  ElementTy dunderGetItem(intptr_t index) {
+  /// by taking elements in inverse order. Returns a nullptr object if out
+  /// of bounds.
+  pybind11::object getItem(intptr_t index) {
     // Negative indices mean we count from the end.
     index = wrapIndex(index);
+    if (index < 0) {
+      PyErr_SetString(PyExc_IndexError, "index out of range");
+      return {};
+    }
 
     // Compute the linear index given the current slice properties.
     int linearIndex = index * step + startIndex;
     assert(linearIndex >= 0 &&
            linearIndex < static_cast<Derived *>(this)->getNumElements() &&
            "linear index out of bounds, the slice is ill-formed");
-    return static_cast<Derived *>(this)->getElement(linearIndex);
+    return pybind11::cast(
+        static_cast<Derived *>(this)->getElement(linearIndex));
   }
 
   /// Returns a new instance of the pseudo-container restricted to the given
-  /// slice.
-  Derived dunderGetItemSlice(pybind11::slice slice) {
+  /// slice. Returns a nullptr object on failure.
+  pybind11::object getItemSlice(PyObject *slice) {
     ssize_t start, stop, extraStep, sliceLength;
-    if (!slice.compute(dunderLen(), &start, &stop, &extraStep, &sliceLength)) {
-      throw python::SetPyError(PyExc_IndexError,
-                               "attempt to access out of bounds");
+    if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep,
+                             &sliceLength) != 0) {
+      PyErr_SetString(PyExc_IndexError, "index out of range");
+      return {};
     }
-    return static_cast<Derived *>(this)->slice(startIndex + start * step,
-                                               sliceLength, step * extraStep);
+    return pybind11::cast(static_cast<Derived *>(this)->slice(
+        startIndex + start * step, sliceLength, step * extraStep));
+  }
+
+public:
+  explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
+      : startIndex(startIndex), length(length), step(step) {
+    assert(length >= 0 && "expected non-negative slice length");
   }
 
   /// Returns a new vector (mapped to Python list) containing elements from two
@@ -267,10 +273,13 @@
     std::vector<ElementTy> elements;
     elements.reserve(length + other.length);
     for (intptr_t i = 0; i < length; ++i) {
-      elements.push_back(dunderGetItem(i));
+      // getElement() takes a linear index into the underlying sequence, so
+      // slice positions must be mapped through startIndex/step.
+      elements.push_back(
+          static_cast<Derived *>(this)->getElement(startIndex + i * step));
     }
     for (intptr_t i = 0; i < other.length; ++i) {
-      elements.push_back(other.dunderGetItem(i));
+      elements.push_back(other.getElement(other.startIndex + i * other.step));
     }
     return elements;
   }
@@ -279,11 +288,53 @@
   static void bind(pybind11::module &m) {
     auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName,
                                            pybind11::module_local())
-                     .def("__len__", &Sliceable::dunderLen)
-                     .def("__getitem__", &Sliceable::dunderGetItem)
-                     .def("__getitem__", &Sliceable::dunderGetItemSlice)
                      .def("__add__", &Sliceable::dunderAdd);
     Derived::bindDerived(clazz);
+
+    // Manually implement the sequence protocol via the C API. We do this
+    // because it is approx 4x faster than via pybind11, largely because that
+    // formulation requires a C++ exception to be thrown to detect end of
+    // sequence.
+    // Since we are in a C-context, any C++ exception that happens here
+    // will terminate the program. There is nothing in this implementation
+    // that should throw in a non-terminal way, so we forgo further
+    // exception marshalling.
+    // See: https://github.com/pybind/pybind11/issues/2842
+    auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr());
+    assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
+           "must be heap type");
+    heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t {
+      auto self = pybind11::cast<Derived *>(rawSelf);
+      return self->length;
+    };
+    // sq_item is called as part of the sequence protocol for iteration,
+    // list construction, etc.
+    heap_type->as_sequence.sq_item =
+        +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * {
+      auto self = pybind11::cast<Derived *>(rawSelf);
+      return self->getItem(index).release().ptr();
+    };
+    // mp_subscript is used for both slices and integer lookups.
+    heap_type->as_mapping.mp_subscript =
+        +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * {
+      auto self = pybind11::cast<Derived *>(rawSelf);
+      // Integer indexing: accept anything implementing __index__. Overflow is
+      // reported as IndexError (as for builtin lists) and propagated as-is.
+      if (PyIndex_Check(rawSubscript)) {
+        Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError);
+        if (index == -1 && PyErr_Occurred())
+          return nullptr;
+        return self->getItem(index).release().ptr();
+      }
+
+      // Slice-based indexing.
+      if (PySlice_Check(rawSubscript)) {
+        return self->getItemSlice(rawSubscript).release().ptr();
+      }
+
+      PyErr_SetString(PyExc_ValueError, "expected integer or slice");
+      return nullptr;
+    };
   }
 
   /// Hook for derived classes willing to bind more methods.
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -14,6 +14,19 @@
   return f
 
 
+def expect_index_error(callback):
+  """Checks that calling `callback` raises IndexError.
+
+  Raises RuntimeError if `callback` returns normally; any other exception
+  propagates unchanged.
+  """
+  try:
+    callback()
+  except IndexError:
+    return
+  raise RuntimeError("Expected IndexError")
+
+
 # Verify iterator based traversal of the op/region/block hierarchy.
 # CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
 @run
@@ -418,7 +431,9 @@
   for t in call.results.types:
     print(f"Result type {t}")
 
-
+  # Out-of-range indices, from either end of the list, must raise IndexError.
+  expect_index_error(lambda: call.results[3])
+  expect_index_error(lambda: call.results[-4])
 
 
 # CHECK-LABEL: TEST: testOperationResultListSlice
@@ -470,8 +485,6 @@
     print(f"Result {res.result_number}, type {res.type}")
 
 
-
-
 # CHECK-LABEL: TEST: testOperationAttributes
 @run
 def testOperationAttributes():