diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -492,6 +492,9 @@
 /// Returns whether the value is null.
 static inline int mlirValueIsNull(MlirValue value) { return !value.ptr; }
 
+/// Returns 1 if two values are equal, 0 otherwise.
+MLIR_CAPI_EXPORTED int mlirValueEqual(MlirValue value1, MlirValue value2);
+
 /// Returns 1 if the value is a block argument, 0 otherwise.
 MLIR_CAPI_EXPORTED int mlirValueIsABlockArgument(MlirValue value);
 
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1213,34 +1213,34 @@
   MlirBlock block;
 };
 
-/// A list of operation results. Internally, these are stored as consecutive
-/// elements, random access is cheap. The result list is associated with the
-/// operation whose results these are, and extends the lifetime of this
+/// A list of operation operands. Internally, these are stored as consecutive
+/// elements, random access is cheap. The operand list is associated with the
+/// operation whose operands these are, and extends the lifetime of this
 /// operation.
-class PyOpOperandList {
+class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
 public:
-  PyOpOperandList(PyOperationRef operation) : operation(operation) {}
+  static constexpr const char *pyClassName = "OpOperandList";
 
-  /// Returns the length of the result list.
-  intptr_t dunderLen() {
+  PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
+                  intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirOperationGetNumOperands(operation->get())
+                               : length,
+                  step),
+        operation(operation) {}
+
+  intptr_t getNumElements() {
     operation->checkValid();
     return mlirOperationGetNumOperands(operation->get());
   }
 
-  /// Returns `index`-th element in the result list.
-  PyValue dunderGetItem(intptr_t index) {
-    if (index < 0 || index >= dunderLen()) {
-      throw SetPyError(PyExc_IndexError,
-                       "attempt to access out of bounds region");
-    }
-    return PyValue(operation, mlirOperationGetOperand(operation->get(), index));
+  PyValue getElement(intptr_t pos) {
+    operation->checkValid();
+    return PyValue(operation, mlirOperationGetOperand(operation->get(), pos));
   }
 
-  /// Defines a Python class in the bindings.
-  static void bind(py::module &m) {
-    py::class_<PyOpOperandList>(m, "OpOperandList")
-        .def("__len__", &PyOpOperandList::dunderLen)
-        .def("__getitem__", &PyOpOperandList::dunderGetItem);
+  PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    return PyOpOperandList(operation, startIndex, length, step);
   }
 
 private:
@@ -1251,31 +1251,31 @@
 /// elements, random access is cheap. The result list is associated with the
 /// operation whose results these are, and extends the lifetime of this
 /// operation.
-class PyOpResultList {
+class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
 public:
-  PyOpResultList(PyOperationRef operation) : operation(operation) {}
+  static constexpr const char *pyClassName = "OpResultList";
 
-  /// Returns the length of the result list.
-  intptr_t dunderLen() {
+  PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
+                 intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirOperationGetNumResults(operation->get())
+                               : length,
+                  step),
+        operation(operation) {}
+
+  intptr_t getNumElements() {
     operation->checkValid();
     return mlirOperationGetNumResults(operation->get());
   }
 
-  /// Returns `index`-th element in the result list.
-  PyOpResult dunderGetItem(intptr_t index) {
-    if (index < 0 || index >= dunderLen()) {
-      throw SetPyError(PyExc_IndexError,
-                       "attempt to access out of bounds region");
-    }
+  PyOpResult getElement(intptr_t index) {
+    operation->checkValid();
     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
     return PyOpResult(value);
   }
 
-  /// Defines a Python class in the bindings.
-  static void bind(py::module &m) {
-    py::class_<PyOpResultList>(m, "OpResultList")
-        .def("__len__", &PyOpResultList::dunderLen)
-        .def("__getitem__", &PyOpResultList::dunderGetItem);
+  PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    return PyOpResultList(operation, startIndex, length, step);
   }
 
 private:
@@ -2932,6 +2932,11 @@
       .def(
           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
           kDumpDocstring)
+      .def("__eq__",
+           [](PyValue &self, PyValue &other) {
+             return mlirValueEqual(self.get(), other.get()) != 0;
+           })
+      .def("__eq__", [](PyValue &self, py::object other) { return false; })
       .def(
           "__str__",
           [](PyValue &self) {
diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -185,6 +185,93 @@
   bool invoked = false;
 };
 
+/// A CRTP base class for pseudo-containers willing to support Python-style
+/// slicing access on top of indexed access. Calling ::bind on this class
+/// will define `__len__` as well as `__getitem__` with integer and slice
+/// arguments.
+///
+/// This is intended for pseudo-containers that can refer to arbitrary slices of
+/// underlying storage indexed by a single integer. Indexing those with an
+/// integer produces an instance of ElementTy. Indexing those with a slice
+/// produces a new instance of Derived, which can be sliced further.
+///
+/// A derived class must provide the following:
+///   - a `static const char *pyClassName` field containing the name of the
+///     Python class to bind;
+///   - an instance method `intptr_t getNumElements()` that returns the number
+///     of elements in the backing container (NOT that of the slice);
+///   - an instance method `ElementTy getElement(intptr_t)` that returns a
+///     single element at the given index;
+///   - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that
+///     constructs a new instance of the derived pseudo-container with the
+///     given slice parameters (to be forwarded to the Sliceable constructor).
+///
+/// A derived class may additionally define:
+///   - a `static void bindDerived(ClassTy &)` method to bind additional methods
+///     to the Python class.
+template <typename Derived, typename ElementTy>
+class Sliceable {
+protected:
+  using ClassTy = pybind11::class_<Derived>;
+
+public:
+  explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
+      : startIndex(startIndex), length(length), step(step) {
+    assert(length >= 0 && "expected non-negative slice length");
+  }
+
+  /// Returns the length of the slice.
+  intptr_t dunderLen() const { return length; }
+
+  /// Returns the element at the given slice index. Supports negative indices
+  /// by taking elements in inverse order. Throws if the index is out of bounds.
+  ElementTy dunderGetItem(intptr_t index) {
+    // Negative indices mean we count from the end.
+    if (index < 0)
+      index = length + index;
+    if (index < 0 || index >= length) {
+      throw python::SetPyError(PyExc_IndexError,
+                               "attempt to access out of bounds");
+    }
+
+    // Compute the linear index given the current slice properties.
+    intptr_t linearIndex = index * step + startIndex;
+    assert(linearIndex >= 0 &&
+           linearIndex < static_cast<Derived *>(this)->getNumElements() &&
+           "linear index out of bounds, the slice is ill-formed");
+    return static_cast<Derived *>(this)->getElement(linearIndex);
+  }
+
+  /// Returns a new instance of the pseudo-container restricted to the given
+  /// slice.
+  Derived dunderGetItemSlice(pybind11::slice slice) {
+    pybind11::ssize_t start, stop, extraStep, sliceLength;
+    if (!slice.compute(dunderLen(), &start, &stop, &extraStep, &sliceLength)) {
+      // compute() failed with a Python error already set (e.g. zero step).
+      throw pybind11::error_already_set();
+    }
+    return static_cast<Derived *>(this)->slice(startIndex + start * step,
+                                               sliceLength, step * extraStep);
+  }
+
+  /// Binds the indexing and length methods in the Python class.
+  static void bind(pybind11::module &m) {
+    auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName)
+                     .def("__len__", &Sliceable::dunderLen)
+                     .def("__getitem__", &Sliceable::dunderGetItem)
+                     .def("__getitem__", &Sliceable::dunderGetItemSlice);
+    Derived::bindDerived(clazz);
+  }
+
+  /// Hook for derived classes willing to bind more methods.
+  static void bindDerived(ClassTy &) {}
+
+private:
+  intptr_t startIndex;
+  intptr_t length;
+  intptr_t step;
+};
+
 } // namespace mlir
 
 #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -479,6 +479,10 @@
 // Value API.
 //===----------------------------------------------------------------------===//
 
+int mlirValueEqual(MlirValue value1, MlirValue value2) {
+  return unwrap(value1) == unwrap(value2);
+}
+
 int mlirValueIsABlockArgument(MlirValue value) {
   return unwrap(value).isa<BlockArgument>();
 }
diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -155,6 +155,64 @@
 run(testOperationOperands)
 
 
+# CHECK-LABEL: TEST: testOperationOperandsSlice
+def testOperationOperandsSlice():
+  with Context() as ctx:
+    ctx.allow_unregistered_dialects = True
+    module = Module.parse(r"""
+      func @f1() {
+        %0 = "test.producer0"() : () -> i64
+        %1 = "test.producer1"() : () -> i64
+        %2 = "test.producer2"() : () -> i64
+        %3 = "test.producer3"() : () -> i64
+        %4 = "test.producer4"() : () -> i64
+        "test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
+        return
+      }""")
+    func = module.body.operations[0]
+    entry_block = func.regions[0].blocks[0]
+    consumer = entry_block.operations[5]
+    assert len(consumer.operands) == 5
+    for left, right in zip(consumer.operands, consumer.operands[::-1][::-1]):
+      assert left == right
+
+    # CHECK: test.producer0
+    # CHECK: test.producer1
+    # CHECK: test.producer2
+    # CHECK: test.producer3
+    # CHECK: test.producer4
+    full_slice = consumer.operands[:]
+    for operand in full_slice:
+      print(operand)
+
+    # CHECK: test.producer0
+    # CHECK: test.producer1
+    first_two = consumer.operands[0:2]
+    for operand in first_two:
+      print(operand)
+
+    # CHECK: test.producer3
+    # CHECK: test.producer4
+    last_two = consumer.operands[3:]
+    for operand in last_two:
+      print(operand)
+
+    # CHECK: test.producer0
+    # CHECK: test.producer2
+    # CHECK: test.producer4
+    even = consumer.operands[::2]
+    for operand in even:
+      print(operand)
+
+    # CHECK: test.producer2
+    every_fourth = consumer.operands[::2][1::2]
+    for operand in every_fourth:
+      print(operand)
+
+
+run(testOperationOperandsSlice)
+
+
 # CHECK-LABEL: TEST: testDetachedOperation
 def testDetachedOperation():
   ctx = Context()
@@ -277,6 +335,57 @@
 run(testOperationResultList)
 
 
+# CHECK-LABEL: TEST: testOperationResultListSlice
+def testOperationResultListSlice():
+  with Context() as ctx:
+    ctx.allow_unregistered_dialects = True
+    module = Module.parse(r"""
+      func @f1() {
+        "some.op"() : () -> (i1, i2, i3, i4, i5)
+        return
+      }
+    """)
+    func = module.body.operations[0]
+    entry_block = func.regions[0].blocks[0]
+    producer = entry_block.operations[0]
+
+    assert len(producer.results) == 5
+    for left, right in zip(producer.results, producer.results[::-1][::-1]):
+      assert left == right
+      assert left.result_number == right.result_number
+
+    # CHECK: Result 0, type i1
+    # CHECK: Result 1, type i2
+    # CHECK: Result 2, type i3
+    # CHECK: Result 3, type i4
+    # CHECK: Result 4, type i5
+    full_slice = producer.results[:]
+    for res in full_slice:
+      print(f"Result {res.result_number}, type {res.type}")
+
+    # CHECK: Result 1, type i2
+    # CHECK: Result 2, type i3
+    # CHECK: Result 3, type i4
+    middle = producer.results[1:4]
+    for res in middle:
+      print(f"Result {res.result_number}, type {res.type}")
+
+    # CHECK: Result 1, type i2
+    # CHECK: Result 3, type i4
+    odd = producer.results[1::2]
+    for res in odd:
+      print(f"Result {res.result_number}, type {res.type}")
+
+    # CHECK: Result 3, type i4
+    # CHECK: Result 1, type i2
+    inverted_middle = producer.results[-2:0:-2]
+    for res in inverted_middle:
+      print(f"Result {res.result_number}, type {res.type}")
+
+
+run(testOperationResultListSlice)
+
+
 # CHECK-LABEL: TEST: testOperationAttributes
 def testOperationAttributes():
   ctx = Context()