diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h --- a/mlir/include/mlir-c/Bindings/Python/Interop.h +++ b/mlir/include/mlir-c/Bindings/Python/Interop.h @@ -26,12 +26,14 @@ #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" #include "mlir-c/IR.h" +#include "mlir-c/IntegerSet.h" #include "mlir-c/Pass.h" #define MLIR_PYTHON_CAPSULE_AFFINE_EXPR "mlir.ir.AffineExpr._CAPIPtr" #define MLIR_PYTHON_CAPSULE_AFFINE_MAP "mlir.ir.AffineMap._CAPIPtr" #define MLIR_PYTHON_CAPSULE_ATTRIBUTE "mlir.ir.Attribute._CAPIPtr" #define MLIR_PYTHON_CAPSULE_CONTEXT "mlir.ir.Context._CAPIPtr" +#define MLIR_PYTHON_CAPSULE_INTEGER_SET "mlir.ir.IntegerSet._CAPIPtr" #define MLIR_PYTHON_CAPSULE_LOCATION "mlir.ir.Location._CAPIPtr" #define MLIR_PYTHON_CAPSULE_MODULE "mlir.ir.Module._CAPIPtr" #define MLIR_PYTHON_CAPSULE_OPERATION "mlir.ir.Operation._CAPIPtr" @@ -240,6 +242,25 @@ return affineMap; } +/** Creates a capsule object encapsulating the raw C-API MlirIntegerSet. + * The returned capsule does not extend or affect ownership of any Python + * objects that reference the set in any way. */ +static inline PyObject * +mlirPythonIntegerSetToCapsule(MlirIntegerSet integerSet) { + return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(integerSet), + MLIR_PYTHON_CAPSULE_INTEGER_SET, NULL); +} + +/** Extracts an MlirIntegerSet from a capsule as produced from + * mlirPythonIntegerSetToCapsule. If the capsule is not of the right type, then + * a null set is returned (as checked via mlirIntegerSetIsNull). In such a + * case, the Python APIs will have already set an error.
*/ +static inline MlirIntegerSet mlirPythonCapsuleToIntegerSet(PyObject *capsule) { + void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_INTEGER_SET); + MlirIntegerSet integerSet = {ptr}; + return integerSet; +} + #ifdef __cplusplus } #endif diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h --- a/mlir/lib/Bindings/Python/IRModules.h +++ b/mlir/lib/Bindings/Python/IRModules.h @@ -16,6 +16,7 @@ #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" #include "mlir-c/IR.h" +#include "mlir-c/IntegerSet.h" #include "llvm/ADT/DenseMap.h" namespace mlir { @@ -726,6 +727,26 @@ MlirAffineMap affineMap; }; +class PyIntegerSet : public BaseContextObject { +public: + PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet) + : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {} + bool operator==(const PyIntegerSet &other); + operator MlirIntegerSet() const { return integerSet; } + MlirIntegerSet get() const { return integerSet; } + + /// Gets a capsule wrapping the void* within the MlirIntegerSet. + pybind11::object getCapsule(); + + /// Creates a PyIntegerSet from the MlirIntegerSet wrapped by a capsule. + /// Throws error_already_set if the capsule has the wrong type. The result + /// wraps the same underlying set, which is owned by the context.
+ static PyIntegerSet createFromCapsule(pybind11::object capsule); + +private: + MlirIntegerSet integerSet; +}; + void populateIRSubmodule(pybind11::module &m); } // namespace python diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -15,6 +15,7 @@ #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" +#include "mlir-c/IntegerSet.h" #include "mlir-c/Registration.h" #include "llvm/ADT/SmallVector.h" #include @@ -3331,6 +3332,102 @@ rawAffineMap); } +//------------------------------------------------------------------------------ +// PyIntegerSet and utilities. +//------------------------------------------------------------------------------ + +class PyIntegerSetConstraint { +public: + PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {} + + PyAffineExpr getExpr() { + return PyAffineExpr(set.getContext(), + mlirIntegerSetGetConstraint(set, pos)); + } + + bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); } + + static void bind(py::module &m) { + py::class_(m, "IntegerSetConstraint") + .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr) + .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq); + } + +private: + PyIntegerSet set; + intptr_t pos; +}; + +class PyIntegerSetConstraintList + : public Sliceable { +public: + static constexpr const char *pyClassName = "IntegerSetConstraintList"; + + PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0, + intptr_t length = -1, intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ?
mlirIntegerSetGetNumConstraints(set) : length, + step), + set(set) {} + + intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); } + + PyIntegerSetConstraint getElement(intptr_t pos) { + return PyIntegerSetConstraint(set, pos); + } + + PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length, + intptr_t step) { + return PyIntegerSetConstraintList(set, startIndex, length, step); + } + +private: + PyIntegerSet set; +}; + +bool PyIntegerSet::operator==(const PyIntegerSet &other) { + return mlirIntegerSetEqual(integerSet, other.integerSet); +} + +py::object PyIntegerSet::getCapsule() { + return py::reinterpret_steal( + mlirPythonIntegerSetToCapsule(*this)); +} + +PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) { + MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr()); + if (mlirIntegerSetIsNull(rawIntegerSet)) + throw py::error_already_set(); + return PyIntegerSet( + PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)), + rawIntegerSet); +} + +/// Attempts to populate `result` with the content of `list` cast to the +/// appropriate type (Python and C types are provided as template arguments). +/// Throws py::cast_error on failure, using `action` to describe what the +/// caller was attempting to do. +template +static void pyListToVector(py::list list, llvm::SmallVectorImpl &result, + StringRef action) { + result.reserve(py::len(list)); + for (py::handle item : list) { + try { + result.push_back(item.cast()); + } catch (py::cast_error &err) { + std::string msg = (llvm::Twine("Invalid expression when ") + action + + " (" + err.what() + ")") + .str(); + throw py::cast_error(msg); + } catch (py::reference_cast_error &err) { + std::string msg = (llvm::Twine("Invalid expression (None?)
when ") + + action + " (" + err.what() + ")") + .str(); + throw py::cast_error(msg); + } + } +} + //------------------------------------------------------------------------------ // Populates the pybind11 IR submodule. //------------------------------------------------------------------------------ @@ -4152,24 +4249,8 @@ [](intptr_t dimCount, intptr_t symbolCount, py::list exprs, DefaultingPyMlirContext context) { SmallVector affineExprs; - affineExprs.reserve(py::len(exprs)); - for (py::handle expr : exprs) { - try { - affineExprs.push_back(expr.cast()); - } catch (py::cast_error &err) { - std::string msg = - std::string("Invalid expression when attempting to create " - "an AffineMap (") + - err.what() + ")"; - throw py::cast_error(msg); - } catch (py::reference_cast_error &err) { - std::string msg = - std::string("Invalid expression (None?) when attempting to " - "create an AffineMap (") + - err.what() + ")"; - throw py::cast_error(msg); - } - } + pyListToVector( + exprs, affineExprs, "attempting to create an AffineMap"); MlirAffineMap map = mlirAffineMapGet(context->get(), dimCount, symbolCount, affineExprs.size(), affineExprs.data()); @@ -4275,4 +4356,125 @@ return PyAffineMapExprList(self); }); PyAffineMapExprList::bind(m); + + //---------------------------------------------------------------------------- + // Mapping of PyIntegerSet.
+ //---------------------------------------------------------------------------- + py::class_(m, "IntegerSet") + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyIntegerSet::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) + .def("__eq__", [](PyIntegerSet &self, + PyIntegerSet &other) { return self == other; }) + .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; }) + .def("__str__", + [](PyIntegerSet &self) { + PyPrintAccumulator printAccum; + mlirIntegerSetPrint(self, printAccum.getCallback(), + printAccum.getUserData()); + return printAccum.join(); + }) + .def("__repr__", + [](PyIntegerSet &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append("IntegerSet("); + mlirIntegerSetPrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }) + .def_property_readonly( + "context", + [](PyIntegerSet &self) { return self.getContext().getObject(); }) + .def( + "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); }, + kDumpDocstring) + .def_static( + "get", + [](intptr_t numDims, intptr_t numSymbols, py::list exprs, + std::vector eqFlags, DefaultingPyMlirContext context) { + if (exprs.size() != eqFlags.size()) + throw py::value_error( + "Expected the number of constraints to match " + "that of equality flags"); + if (exprs.empty()) + throw py::value_error("Expected non-empty list of constraints"); + + // Copy over to a SmallVector because std::vector has a + // specialization for booleans that packs data and does not + // expose a `bool *`.
+ SmallVector flags(eqFlags.begin(), eqFlags.end()); + + SmallVector affineExprs; + pyListToVector(exprs, affineExprs, + "attempting to create an IntegerSet"); + MlirIntegerSet set = mlirIntegerSetGet( + context->get(), numDims, numSymbols, exprs.size(), + affineExprs.data(), flags.data()); + return PyIntegerSet(context->getRef(), set); + }, + py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"), + py::arg("eq_flags"), py::arg("context") = py::none()) + .def_static( + "get_empty", + [](intptr_t numDims, intptr_t numSymbols, + DefaultingPyMlirContext context) { + MlirIntegerSet set = + mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols); + return PyIntegerSet(context->getRef(), set); + }, + py::arg("num_dims"), py::arg("num_symbols"), + py::arg("context") = py::none()) + .def("get_replaced", + [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs, + intptr_t numResultDims, intptr_t numResultSymbols) { + if (static_cast(dimExprs.size()) != + mlirIntegerSetGetNumDims(self)) + throw py::value_error( + "Expected the number of dimension replacement expressions " + "to match that of dimensions"); + if (static_cast(symbolExprs.size()) != + mlirIntegerSetGetNumSymbols(self)) + throw py::value_error( + "Expected the number of symbol replacement expressions " + "to match that of symbols"); + + SmallVector dimAffineExprs, symbolAffineExprs; + pyListToVector( + dimExprs, dimAffineExprs, + "attempting to create an IntegerSet by replacing dimensions"); + pyListToVector( + symbolExprs, symbolAffineExprs, + "attempting to create an IntegerSet by replacing symbols"); + MlirIntegerSet set = mlirIntegerSetReplaceGet( + self, dimAffineExprs.data(), symbolAffineExprs.data(), + numResultDims, numResultSymbols); + return PyIntegerSet(self.getContext(), set); + }) + .def_property_readonly("is_canonical_empty", + [](PyIntegerSet &self) { + return mlirIntegerSetIsCanonicalEmpty(self); + }) + .def_property_readonly( + "n_dims", + [](PyIntegerSet &self) { return
mlirIntegerSetGetNumDims(self); }) + .def_property_readonly( + "n_symbols", + [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); }) + .def_property_readonly( + "n_inputs", + [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); }) + .def_property_readonly("n_equalities", + [](PyIntegerSet &self) { + return mlirIntegerSetGetNumEqualities(self); + }) + .def_property_readonly("n_inequalities", + [](PyIntegerSet &self) { + return mlirIntegerSetGetNumInequalities(self); + }) + .def_property_readonly("constraints", [](PyIntegerSet &self) { + return PyIntegerSetConstraintList(self); + }); + PyIntegerSetConstraint::bind(m); + PyIntegerSetConstraintList::bind(m); } diff --git a/mlir/test/Bindings/Python/ir_integer_set.py b/mlir/test/Bindings/Python/ir_integer_set.py new file mode 100644 --- /dev/null +++ b/mlir/test/Bindings/Python/ir_integer_set.py @@ -0,0 +1,128 @@ +# RUN: %PYTHON %s | FileCheck %s + +import gc +from mlir.ir import * + +def run(f): + print("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + + +# CHECK-LABEL: TEST: testIntegerSetCapsule +def testIntegerSetCapsule(): + with Context() as ctx: + is1 = IntegerSet.get_empty(1, 1, ctx) + capsule = is1._CAPIPtr + # CHECK: mlir.ir.IntegerSet._CAPIPtr + print(capsule) + is2 = IntegerSet._CAPICreate(capsule) + assert is1 == is2 + assert is2.context is ctx + +run(testIntegerSetCapsule) + + +# CHECK-LABEL: TEST: testIntegerSetGet +def testIntegerSetGet(): + with Context(): + d0 = AffineDimExpr.get(0) + d1 = AffineDimExpr.get(1) + s0 = AffineSymbolExpr.get(0) + c42 = AffineConstantExpr.get(42) + + # CHECK: (d0, d1)[s0] : (d0 - d1 == 0, s0 - 42 >= 0) + set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42], [True, False]) + print(set0) + + # CHECK: (d0)[s0] : (1 == 0) + set1 = IntegerSet.get_empty(1, 1) + print(set1) + + # CHECK: (d0)[s0, s1] : (d0 - s1 == 0, s0 - 42 >= 0) + set2 = set0.get_replaced([d0, AffineSymbolExpr.get(1)], [s0], 1, 2) + print(set2) + + try:
IntegerSet.get(2, 1, [], []) + except ValueError as e: + # CHECK: Expected non-empty list of constraints + print(e) + + try: + IntegerSet.get(2, 1, [d0 - d1], [True, False]) + except ValueError as e: + # CHECK: Expected the number of constraints to match that of equality flags + print(e) + + try: + IntegerSet.get(2, 1, [0], [True]) + except RuntimeError as e: + # CHECK: Invalid expression when attempting to create an IntegerSet + print(e) + + try: + IntegerSet.get(2, 1, [None], [True]) + except RuntimeError as e: + # CHECK: Invalid expression (None?) when attempting to create an IntegerSet + print(e) + + try: + set0.get_replaced([d0], [s0], 1, 1) + except ValueError as e: + # CHECK: Expected the number of dimension replacement expressions to match that of dimensions + print(e) + + try: + set0.get_replaced([d0, d1], [s0, s0], 1, 1) + except ValueError as e: + # CHECK: Expected the number of symbol replacement expressions to match that of symbols + print(e) + + try: + set0.get_replaced([d0, 1], [s0], 1, 1) + except RuntimeError as e: + # CHECK: Invalid expression when attempting to create an IntegerSet by replacing dimensions + print(e) + + try: + set0.get_replaced([d0, d1], [None], 1, 1) + except RuntimeError as e: + # CHECK: Invalid expression (None?)
when attempting to create an IntegerSet by replacing symbols + print(e) + +run(testIntegerSetGet) + + +# CHECK-LABEL: TEST: testIntegerSetProperties +def testIntegerSetProperties(): + with Context(): + d0 = AffineDimExpr.get(0) + d1 = AffineDimExpr.get(1) + s0 = AffineSymbolExpr.get(0) + c42 = AffineConstantExpr.get(42) + + set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42, s0 - d0], [True, False, False]) + # CHECK: 2 + print(set0.n_dims) + # CHECK: 1 + print(set0.n_symbols) + # CHECK: 3 + print(set0.n_inputs) + # CHECK: 1 + print(set0.n_equalities) + # CHECK: 2 + print(set0.n_inequalities) + + # CHECK: 3 + print(len(set0.constraints)) + + # CHECK-DAG: d0 - d1 == 0 + # CHECK-DAG: s0 - 42 >= 0 + # CHECK-DAG: -d0 + s0 >= 0 + for cstr in set0.constraints: + print(cstr.expr, end='') + print(" == 0" if cstr.is_eq else " >= 0") + +run(testIntegerSetProperties)