diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt --- a/mlir/lib/Bindings/Python/CMakeLists.txt +++ b/mlir/lib/Bindings/Python/CMakeLists.txt @@ -70,7 +70,10 @@ python SOURCES MainModule.cpp - IRModules.cpp + IRAffine.cpp + IRAttributes.cpp + IRCore.cpp + IRTypes.cpp PybindUtils.cpp Pass.cpp ExecutionEngine.cpp diff --git a/mlir/lib/Bindings/Python/ExecutionEngine.cpp b/mlir/lib/Bindings/Python/ExecutionEngine.cpp --- a/mlir/lib/Bindings/Python/ExecutionEngine.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngine.cpp @@ -8,7 +8,7 @@ #include "ExecutionEngine.h" -#include "IRModules.h" +#include "IRModule.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/ExecutionEngine.h" diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -0,0 +1,781 @@ +//===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "IRModule.h" + +#include "PybindUtils.h" + +#include "mlir-c/AffineMap.h" +#include "mlir-c/Bindings/Python/Interop.h" +#include "mlir-c/IntegerSet.h" + +namespace py = pybind11; +using namespace mlir; +using namespace mlir::python; + +using llvm::SmallVector; +using llvm::StringRef; +using llvm::Twine; + +static const char kDumpDocstring[] = + R"(Dumps a debug representation of the object to stderr.)"; + +/// Attempts to populate `result` with the content of `list` casted to the +/// appropriate type (Python and C types are provided as template arguments). +/// Throws errors in case of failure, using "action" to describe what the caller +/// was attempting to do. 
+template <typename PyType, typename CType>
+static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result,
+                           StringRef action) {
+  result.reserve(py::len(list));
+  for (py::handle item : list) {
+    try {
+      result.push_back(item.cast<PyType>());
+    } catch (const py::cast_error &err) {
+      std::string msg = (llvm::Twine("Invalid expression when ") + action +
+                         " (" + err.what() + ")")
+                            .str();
+      throw py::cast_error(msg);
+    } catch (const py::reference_cast_error &err) {
+      // Judging by the "(None?)" hint, this appears to be what pybind11 raises
+      // when the list item is None.
+      std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
+                         action + " (" + err.what() + ")")
+                            .str();
+      throw py::cast_error(msg);
+    }
+  }
+}
+
+/// Returns true if `permutation` holds every value of [0, size) exactly once.
+/// The vector is only read, so it is taken by const reference (no copy).
+template <typename T>
+static bool isPermutation(const std::vector<T> &permutation) {
+  llvm::SmallVector<bool, 8> seen(permutation.size(), false);
+  for (auto val : permutation) {
+    if (val < permutation.size()) {
+      if (seen[val])
+        return false;
+      seen[val] = true;
+      continue;
+    }
+    return false;
+  }
+  return true;
+}
+
+namespace {
+
+/// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
+/// and should be castable from it. Intermediate hierarchy classes can be
+/// modeled by specifying BaseTy.
+template <typename DerivedTy, typename BaseTy = PyAffineExpr>
+class PyConcreteAffineExpr : public BaseTy {
+public:
+  // Derived classes must define statics for:
+  //   IsAFunctionTy isaFunction
+  //   const char *pyClassName
+  // and redefine bindDerived.
+  using ClassTy = py::class_<DerivedTy, BaseTy>;
+  using IsAFunctionTy = bool (*)(MlirAffineExpr);
+
+  PyConcreteAffineExpr() = default;
+  PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
+      : BaseTy(std::move(contextRef), affineExpr) {}
+  /// Downcasting constructor; throws ValueError (via castFrom) when `orig` is
+  /// not of the derived kind.
+  PyConcreteAffineExpr(PyAffineExpr &orig)
+      : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
+
+  /// Returns the raw expression of `orig` if DerivedTy::isaFunction accepts
+  /// it, otherwise raises a Python ValueError naming the target class.
+  static MlirAffineExpr castFrom(PyAffineExpr &orig) {
+    if (!DerivedTy::isaFunction(orig)) {
+      auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
+      throw SetPyError(PyExc_ValueError,
+                       Twine("Cannot cast affine expression to ") +
+                           DerivedTy::pyClassName + " (from " + origRepr + ")");
+    }
+    return orig;
+  }
+
+  /// Registers the Python class (with a casting constructor from AffineExpr)
+  /// and then lets the derived class add its own members.
+  static void bind(py::module &m) {
+    auto cls = ClassTy(m, DerivedTy::pyClassName);
+    cls.def(py::init<PyAffineExpr &>());
+    DerivedTy::bindDerived(cls);
+  }
+
+  /// Implemented by derived classes to add methods to the Python subclass.
+  static void bindDerived(ClassTy &m) {}
+};
+
+/// Python binding for constant affine expressions ("AffineConstantExpr").
+class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
+  static constexpr const char *pyClassName = "AffineConstantExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  static PyAffineConstantExpr get(intptr_t value,
+                                  DefaultingPyMlirContext context) {
+    MlirAffineExpr affineExpr =
+        mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value));
+    return PyAffineConstantExpr(context->getRef(), affineExpr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
+                 py::arg("context") = py::none());
+    c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
+      return mlirAffineConstantExprGetValue(self);
+    });
+  }
+};
+
+/// Python binding for dimension affine expressions ("AffineDimExpr").
+class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
+  static constexpr const char *pyClassName = "AffineDimExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
+    MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos);
+    return PyAffineDimExpr(context->getRef(), affineExpr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
+                 py::arg("context") = py::none());
+    c.def_property_readonly("position", [](PyAffineDimExpr &self) {
+      return mlirAffineDimExprGetPosition(self);
+    });
+  }
+};
+
+/// Python binding for symbol affine expressions ("AffineSymbolExpr").
+class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
+  static constexpr const char *pyClassName = "AffineSymbolExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
+    MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos);
+    return PyAffineSymbolExpr(context->getRef(), affineExpr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
+                 py::arg("context") = py::none());
+    c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
+      return mlirAffineSymbolExprGetPosition(self);
+    });
+  }
+};
+
+/// Common Python base of the binary affine expressions (add, mul, mod,
+/// floordiv, ceildiv); exposes their two operands.
+class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
+  static constexpr const char *pyClassName = "AffineBinaryExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  PyAffineExpr lhs() {
+    MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get());
+    return PyAffineExpr(getContext(), lhsExpr);
+  }
+
+  PyAffineExpr rhs() {
+    MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get());
+    return PyAffineExpr(getContext(), rhsExpr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
+    c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
+  }
+};
+
+/// Python binding for sum expressions ("AffineAddExpr").
+class PyAffineAddExpr
+    : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
+  static constexpr const char *pyClassName = "AffineAddExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  /// The result lives in the context of `lhs`.
+  static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
+    MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs);
+    return PyAffineAddExpr(lhs.getContext(), expr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineAddExpr::get);
+  }
+};
+
+/// Python binding for product expressions ("AffineMulExpr").
+class PyAffineMulExpr
+    : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
+  static constexpr const char *pyClassName = "AffineMulExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  /// The result lives in the context of `lhs`.
+  static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
+    MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs);
+    return PyAffineMulExpr(lhs.getContext(), expr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineMulExpr::get);
+  }
+};
+
+/// Python binding for modulo expressions ("AffineModExpr").
+class PyAffineModExpr
+    : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
+  static constexpr const char *pyClassName = "AffineModExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  /// The result lives in the context of `lhs`.
+  static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
+    MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs);
+    return PyAffineModExpr(lhs.getContext(), expr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineModExpr::get);
+  }
+};
+
+/// Python binding for floor-division expressions ("AffineFloorDivExpr").
+class PyAffineFloorDivExpr
+    : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
+  static constexpr const char *pyClassName = "AffineFloorDivExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  /// The result lives in the context of `lhs`.
+  static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
+    MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs);
+    return PyAffineFloorDivExpr(lhs.getContext(), expr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineFloorDivExpr::get);
+  }
+};
+
+/// Python binding for ceil-division expressions ("AffineCeilDivExpr").
+class PyAffineCeilDivExpr
+    : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
+  static constexpr const char *pyClassName = "AffineCeilDivExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  /// The result lives in the context of `lhs`.
+  static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
+    MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs);
+    return PyAffineCeilDivExpr(lhs.getContext(), expr);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineCeilDivExpr::get);
+  }
+};
+
+} // namespace
+
+bool PyAffineExpr::operator==(const PyAffineExpr &other) {
+  return mlirAffineExprEqual(affineExpr, other.affineExpr);
+}
+
+py::object PyAffineExpr::getCapsule() {
+  return py::reinterpret_steal<py::object>(
+      mlirPythonAffineExprToCapsule(*this));
+}
+
+/// Rebuilds a PyAffineExpr from a C API capsule; a null expression means the
+/// capsule conversion failed, so the pending Python error is propagated.
+PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
+  MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
+  if (mlirAffineExprIsNull(rawAffineExpr))
+    throw py::error_already_set();
+  return PyAffineExpr(
+      PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
+      rawAffineExpr);
+}
+
+//------------------------------------------------------------------------------
+// PyAffineMap and utilities.
+//------------------------------------------------------------------------------
+namespace {
+
+/// A list of expressions contained in an affine map. Internally these are
+/// stored as a consecutive array leading to inexpensive random access. Both
+/// the map and the expression are owned by the context so we need not bother
+/// with lifetime extension.
+class PyAffineMapExprList
+    : public Sliceable<PyAffineMapExprList, PyAffineExpr> {
+public:
+  static constexpr const char *pyClassName = "AffineExprList";
+
+  /// A `length` of -1 means "all results of the map".
+  PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0,
+                      intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirAffineMapGetNumResults(map) : length,
+                  step),
+        affineMap(map) {}
+
+  intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); }
+
+  PyAffineExpr getElement(intptr_t pos) {
+    return PyAffineExpr(affineMap.getContext(),
+                        mlirAffineMapGetResult(affineMap, pos));
+  }
+
+  PyAffineMapExprList slice(intptr_t startIndex, intptr_t length,
+                            intptr_t step) {
+    return PyAffineMapExprList(affineMap, startIndex, length, step);
+  }
+
+private:
+  PyAffineMap affineMap;
+};
+} // end namespace
+
+bool PyAffineMap::operator==(const PyAffineMap &other) {
+  return mlirAffineMapEqual(affineMap, other.affineMap);
+}
+
+py::object PyAffineMap::getCapsule() {
+  return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this));
+}
+
+/// Rebuilds a PyAffineMap from a C API capsule; a null map means the capsule
+/// conversion failed, so the pending Python error is propagated.
+PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
+  MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
+  if (mlirAffineMapIsNull(rawAffineMap))
+    throw py::error_already_set();
+  return PyAffineMap(
+      PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)),
+      rawAffineMap);
+}
+
+//------------------------------------------------------------------------------
+// PyIntegerSet and utilities.
+//------------------------------------------------------------------------------
+namespace {
+
+/// A single constraint of an integer set, identified by its position.
+class PyIntegerSetConstraint {
+public:
+  PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {}
+
+  PyAffineExpr getExpr() {
+    return PyAffineExpr(set.getContext(),
+                        mlirIntegerSetGetConstraint(set, pos));
+  }
+
+  bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
+
+  static void bind(py::module &m) {
+    py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint")
+        .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
+        .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
+  }
+
+private:
+  PyIntegerSet set;
+  intptr_t pos;
+};
+
+/// Sliceable Python list view over the constraints of an integer set.
+class PyIntegerSetConstraintList
+    : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
+public:
+  static constexpr const char *pyClassName = "IntegerSetConstraintList";
+
+  /// A `length` of -1 means "all constraints of the set".
+  PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0,
+                             intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
+                  step),
+        set(set) {}
+
+  intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); }
+
+  PyIntegerSetConstraint getElement(intptr_t pos) {
+    return PyIntegerSetConstraint(set, pos);
+  }
+
+  PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
+                                   intptr_t step) {
+    return PyIntegerSetConstraintList(set, startIndex, length, step);
+  }
+
+private:
+  PyIntegerSet set;
+};
+} // namespace
+
+bool PyIntegerSet::operator==(const PyIntegerSet &other) {
+  return mlirIntegerSetEqual(integerSet, other.integerSet);
+}
+
+py::object PyIntegerSet::getCapsule() {
+  return py::reinterpret_steal<py::object>(
+      mlirPythonIntegerSetToCapsule(*this));
+}
+
+/// Rebuilds a PyIntegerSet from a C API capsule; a null set means the capsule
+/// conversion failed, so the pending Python error is propagated.
+PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
+  MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
+  if (mlirIntegerSetIsNull(rawIntegerSet))
+    throw py::error_already_set();
+  return PyIntegerSet(
+      PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
+      rawIntegerSet);
+}
+
+/// Registers the AffineExpr (and subclasses), AffineMap and IntegerSet Python
+/// classes, plus their helper list/constraint classes, on module `m`.
+void mlir::python::populateIRAffine(py::module &m) {
+  //----------------------------------------------------------------------------
+  // Mapping of PyAffineExpr and derived classes.
+  //----------------------------------------------------------------------------
+  py::class_<PyAffineExpr>(m, "AffineExpr")
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+                             &PyAffineExpr::getCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
+      .def("__add__",
+           [](PyAffineExpr &self, PyAffineExpr &other) {
+             return PyAffineAddExpr::get(self, other);
+           })
+      .def("__mul__",
+           [](PyAffineExpr &self, PyAffineExpr &other) {
+             return PyAffineMulExpr::get(self, other);
+           })
+      .def("__mod__",
+           [](PyAffineExpr &self, PyAffineExpr &other) {
+             return PyAffineModExpr::get(self, other);
+           })
+      .def("__sub__",
+           [](PyAffineExpr &self, PyAffineExpr &other) {
+             // a - b is built as a + (-1 * b).
+             auto negOne =
+                 PyAffineConstantExpr::get(-1, *self.getContext().get());
+             return PyAffineAddExpr::get(self,
+                                         PyAffineMulExpr::get(negOne, other));
+           })
+      .def("__eq__", [](PyAffineExpr &self,
+                        PyAffineExpr &other) { return self == other; })
+      // Fallback overload: comparing with any non-AffineExpr yields False.
+      .def("__eq__",
+           [](PyAffineExpr &self, py::object &other) { return false; })
+      .def("__str__",
+           [](PyAffineExpr &self) {
+             PyPrintAccumulator printAccum;
+             mlirAffineExprPrint(self, printAccum.getCallback(),
+                                 printAccum.getUserData());
+             return printAccum.join();
+           })
+      .def("__repr__",
+           [](PyAffineExpr &self) {
+             PyPrintAccumulator printAccum;
+             printAccum.parts.append("AffineExpr(");
+             mlirAffineExprPrint(self, printAccum.getCallback(),
+                                 printAccum.getUserData());
+             printAccum.parts.append(")");
+             return printAccum.join();
+           })
+      .def_property_readonly(
+          "context",
+          [](PyAffineExpr &self) { return self.getContext().getObject(); })
+      .def_static(
+          "get_add", &PyAffineAddExpr::get,
+          "Gets an affine expression containing a sum of two expressions.")
+      .def_static(
+          "get_mul", &PyAffineMulExpr::get,
+          "Gets an affine expression containing a product of two expressions.")
+      .def_static("get_mod", &PyAffineModExpr::get,
+                  "Gets an affine expression containing the modulo of dividing "
+                  "one expression by another.")
+      .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
+                  "Gets an affine expression containing the rounded-down "
+                  "result of dividing one expression by another.")
+      .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
+                  "Gets an affine expression containing the rounded-up result "
+                  "of dividing one expression by another.")
+      .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
+                  py::arg("context") = py::none(),
+                  "Gets a constant affine expression with the given value.")
+      .def_static(
+          "get_dim", &PyAffineDimExpr::get, py::arg("position"),
+          py::arg("context") = py::none(),
+          "Gets an affine expression of a dimension at the given position.")
+      .def_static(
+          "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
+          py::arg("context") = py::none(),
+          "Gets an affine expression of a symbol at the given position.")
+      .def(
+          "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
+          kDumpDocstring);
+  PyAffineConstantExpr::bind(m);
+  PyAffineDimExpr::bind(m);
+  PyAffineSymbolExpr::bind(m);
+  PyAffineBinaryExpr::bind(m);
+  PyAffineAddExpr::bind(m);
+  PyAffineMulExpr::bind(m);
+  PyAffineModExpr::bind(m);
+  PyAffineFloorDivExpr::bind(m);
+  PyAffineCeilDivExpr::bind(m);
+
+  //----------------------------------------------------------------------------
+  // Mapping of PyAffineMap.
+  //----------------------------------------------------------------------------
+  py::class_<PyAffineMap>(m, "AffineMap")
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+                             &PyAffineMap::getCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
+      .def("__eq__",
+           [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
+      .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
+      .def("__str__",
+           [](PyAffineMap &self) {
+             PyPrintAccumulator printAccum;
+             mlirAffineMapPrint(self, printAccum.getCallback(),
+                                printAccum.getUserData());
+             return printAccum.join();
+           })
+      .def("__repr__",
+           [](PyAffineMap &self) {
+             PyPrintAccumulator printAccum;
+             printAccum.parts.append("AffineMap(");
+             mlirAffineMapPrint(self, printAccum.getCallback(),
+                                printAccum.getUserData());
+             printAccum.parts.append(")");
+             return printAccum.join();
+           })
+      .def_property_readonly(
+          "context",
+          [](PyAffineMap &self) { return self.getContext().getObject(); },
+          "Context that owns the Affine Map")
+      .def(
+          "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
+          kDumpDocstring)
+      .def_static(
+          "get",
+          [](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
+             DefaultingPyMlirContext context) {
+            SmallVector<MlirAffineExpr> affineExprs;
+            pyListToVector<PyAffineExpr, MlirAffineExpr>(
+                exprs, affineExprs, "attempting to create an AffineMap");
+            MlirAffineMap map =
+                mlirAffineMapGet(context->get(), dimCount, symbolCount,
+                                 affineExprs.size(), affineExprs.data());
+            return PyAffineMap(context->getRef(), map);
+          },
+          py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
+          py::arg("context") = py::none(),
+          "Gets a map with the given expressions as results.")
+      .def_static(
+          "get_constant",
+          [](intptr_t value, DefaultingPyMlirContext context) {
+            MlirAffineMap affineMap =
+                mlirAffineMapConstantGet(context->get(), value);
+            return PyAffineMap(context->getRef(), affineMap);
+          },
+          py::arg("value"), py::arg("context") = py::none(),
+          "Gets an affine map with a single constant result")
+      .def_static(
+          "get_empty",
+          [](DefaultingPyMlirContext context) {
+            MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
+            return PyAffineMap(context->getRef(), affineMap);
+          },
+          py::arg("context") = py::none(), "Gets an empty affine map.")
+      .def_static(
+          "get_identity",
+          [](intptr_t nDims, DefaultingPyMlirContext context) {
+            MlirAffineMap affineMap =
+                mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
+            return PyAffineMap(context->getRef(), affineMap);
+          },
+          py::arg("n_dims"), py::arg("context") = py::none(),
+          "Gets an identity map with the given number of dimensions.")
+      .def_static(
+          "get_minor_identity",
+          [](intptr_t nDims, intptr_t nResults,
+             DefaultingPyMlirContext context) {
+            MlirAffineMap affineMap =
+                mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
+            return PyAffineMap(context->getRef(), affineMap);
+          },
+          py::arg("n_dims"), py::arg("n_results"),
+          py::arg("context") = py::none(),
+          "Gets a minor identity map with the given number of dimensions and "
+          "results.")
+      .def_static(
+          "get_permutation",
+          [](std::vector<unsigned> permutation,
+             DefaultingPyMlirContext context) {
+            if (!isPermutation(permutation))
+              throw py::cast_error("Invalid permutation when attempting to "
+                                   "create an AffineMap");
+            MlirAffineMap affineMap = mlirAffineMapPermutationGet(
+                context->get(), permutation.size(), permutation.data());
+            return PyAffineMap(context->getRef(), affineMap);
+          },
+          py::arg("permutation"), py::arg("context") = py::none(),
+          "Gets an affine map that permutes its inputs.")
+      .def("get_submap",
+           [](PyAffineMap &self, std::vector<intptr_t> &resultPos) {
+             intptr_t numResults = mlirAffineMapGetNumResults(self);
+             for (intptr_t pos : resultPos) {
+               if (pos < 0 || pos >= numResults)
+                 throw py::value_error("result position out of bounds");
+             }
+             MlirAffineMap affineMap = mlirAffineMapGetSubMap(
+                 self, resultPos.size(), resultPos.data());
+             return PyAffineMap(self.getContext(), affineMap);
+           })
+      .def("get_major_submap",
+           [](PyAffineMap &self, intptr_t nResults) {
+             // Valid range is [1, numResults]. Zero must be rejected because
+             // the underlying AffineMap::getMajorSubMap returns a null map
+             // for it, which must not be wrapped in a PyAffineMap. Asking for
+             // all results is valid and yields the map itself.
+             if (nResults <= 0 || nResults > mlirAffineMapGetNumResults(self))
+               throw py::value_error("number of results out of bounds");
+             MlirAffineMap affineMap =
+                 mlirAffineMapGetMajorSubMap(self, nResults);
+             return PyAffineMap(self.getContext(), affineMap);
+           })
+      .def("get_minor_submap",
+           [](PyAffineMap &self, intptr_t nResults) {
+             // Same bounds as get_major_submap: AffineMap::getMinorSubMap also
+             // returns a null map when asked for zero results.
+             if (nResults <= 0 || nResults > mlirAffineMapGetNumResults(self))
+               throw py::value_error("number of results out of bounds");
+             MlirAffineMap affineMap =
+                 mlirAffineMapGetMinorSubMap(self, nResults);
+             return PyAffineMap(self.getContext(), affineMap);
+           })
+      .def_property_readonly(
+          "is_permutation",
+          [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
+      .def_property_readonly("is_projected_permutation",
+                             [](PyAffineMap &self) {
+                               return mlirAffineMapIsProjectedPermutation(self);
+                             })
+      .def_property_readonly(
+          "n_dims",
+          [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
+      .def_property_readonly(
+          "n_inputs",
+          [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
+      .def_property_readonly(
+          "n_symbols",
+          [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
+      .def_property_readonly("results", [](PyAffineMap &self) {
+        return PyAffineMapExprList(self);
+      });
+  PyAffineMapExprList::bind(m);
+
+  //----------------------------------------------------------------------------
+  // Mapping of PyIntegerSet.
+  //----------------------------------------------------------------------------
+  py::class_<PyIntegerSet>(m, "IntegerSet")
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+                             &PyIntegerSet::getCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
+      .def("__eq__", [](PyIntegerSet &self,
+                        PyIntegerSet &other) { return self == other; })
+      .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
+      .def("__str__",
+           [](PyIntegerSet &self) {
+             PyPrintAccumulator printAccum;
+             mlirIntegerSetPrint(self, printAccum.getCallback(),
+                                 printAccum.getUserData());
+             return printAccum.join();
+           })
+      .def("__repr__",
+           [](PyIntegerSet &self) {
+             PyPrintAccumulator printAccum;
+             printAccum.parts.append("IntegerSet(");
+             mlirIntegerSetPrint(self, printAccum.getCallback(),
+                                 printAccum.getUserData());
+             printAccum.parts.append(")");
+             return printAccum.join();
+           })
+      .def_property_readonly(
+          "context",
+          [](PyIntegerSet &self) { return self.getContext().getObject(); })
+      .def(
+          "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
+          kDumpDocstring)
+      .def_static(
+          "get",
+          [](intptr_t numDims, intptr_t numSymbols, py::list exprs,
+             std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
+            if (exprs.size() != eqFlags.size())
+              throw py::value_error(
+                  "Expected the number of constraints to match "
+                  "that of equality flags");
+            if (exprs.empty())
+              throw py::value_error("Expected non-empty list of constraints");
+
+            // Copy over to a SmallVector because std::vector has a
+            // specialization for booleans that packs data and does not
+            // expose a `bool *`.
+            SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
+
+            SmallVector<MlirAffineExpr> affineExprs;
+            pyListToVector<PyAffineExpr, MlirAffineExpr>(
+                exprs, affineExprs, "attempting to create an IntegerSet");
+            MlirIntegerSet set = mlirIntegerSetGet(
+                context->get(), numDims, numSymbols, exprs.size(),
+                affineExprs.data(), flags.data());
+            return PyIntegerSet(context->getRef(), set);
+          },
+          py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
+          py::arg("eq_flags"), py::arg("context") = py::none())
+      .def_static(
+          "get_empty",
+          [](intptr_t numDims, intptr_t numSymbols,
+             DefaultingPyMlirContext context) {
+            MlirIntegerSet set =
+                mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
+            return PyIntegerSet(context->getRef(), set);
+          },
+          py::arg("num_dims"), py::arg("num_symbols"),
+          py::arg("context") = py::none())
+      .def("get_replaced",
+           [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
+              intptr_t numResultDims, intptr_t numResultSymbols) {
+             if (static_cast<intptr_t>(dimExprs.size()) !=
+                 mlirIntegerSetGetNumDims(self))
+               throw py::value_error(
+                   "Expected the number of dimension replacement expressions "
+                   "to match that of dimensions");
+             if (static_cast<intptr_t>(symbolExprs.size()) !=
+                 mlirIntegerSetGetNumSymbols(self))
+               throw py::value_error(
+                   "Expected the number of symbol replacement expressions "
+                   "to match that of symbols");
+
+             SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
+             pyListToVector<PyAffineExpr, MlirAffineExpr>(
+                 dimExprs, dimAffineExprs,
+                 "attempting to create an IntegerSet by replacing dimensions");
+             pyListToVector<PyAffineExpr, MlirAffineExpr>(
+                 symbolExprs, symbolAffineExprs,
+                 "attempting to create an IntegerSet by replacing symbols");
+             MlirIntegerSet set = mlirIntegerSetReplaceGet(
+                 self, dimAffineExprs.data(), symbolAffineExprs.data(),
+                 numResultDims, numResultSymbols);
+             return PyIntegerSet(self.getContext(), set);
+           })
+      .def_property_readonly("is_canonical_empty",
+                             [](PyIntegerSet &self) {
+                               return mlirIntegerSetIsCanonicalEmpty(self);
+                             })
+      .def_property_readonly(
+          "n_dims",
+          [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
+      .def_property_readonly(
+          "n_symbols",
+          [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
+      .def_property_readonly(
+          "n_inputs",
+          [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
+      .def_property_readonly("n_equalities",
+                             [](PyIntegerSet &self) {
+                               return mlirIntegerSetGetNumEqualities(self);
+                             })
+      .def_property_readonly("n_inequalities",
+                             [](PyIntegerSet &self) {
+                               return mlirIntegerSetGetNumInequalities(self);
+                             })
+      .def_property_readonly("constraints", [](PyIntegerSet &self) {
+        return PyIntegerSetConstraintList(self);
+      });
+  PyIntegerSetConstraint::bind(m);
+  PyIntegerSetConstraintList::bind(m);
+}
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -0,0 +1,761 @@
+//===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "IRModule.h"
+
+#include "PybindUtils.h"
+
+#include "mlir-c/BuiltinAttributes.h"
+#include "mlir-c/BuiltinTypes.h"
+
+namespace py = pybind11;
+using namespace mlir;
+using namespace mlir::python;
+
+using llvm::SmallVector;
+using llvm::StringRef;
+using llvm::Twine;
+
+namespace {
+
+/// Wraps the bytes of `s` in an MlirStringRef. The result points into `s`
+/// (no copy is made), so it must not outlive `s`.
+static MlirStringRef toMlirStringRef(const std::string &s) {
+  return mlirStringRefCreate(s.data(), s.size());
+}
+
+/// CRTP base classes for Python attributes that subclass Attribute and should
+/// be castable from it (i.e. via something like StringAttr(attr)).
+/// By default, attribute class hierarchies are one level deep (i.e. a
+/// concrete attribute class extends PyAttribute); however, intermediate
+/// python-visible base classes can be modeled by specifying a BaseTy.
+template <typename DerivedTy, typename BaseTy = PyAttribute>
+class PyConcreteAttribute : public BaseTy {
+public:
+  // Derived classes must define statics for:
+  //   IsAFunctionTy isaFunction
+  //   const char *pyClassName
+  using ClassTy = py::class_<DerivedTy, BaseTy>;
+  using IsAFunctionTy = bool (*)(MlirAttribute);
+
+  PyConcreteAttribute() = default;
+  PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
+      : BaseTy(std::move(contextRef), attr) {}
+  /// Downcasting constructor; throws ValueError (via castFrom) when `orig` is
+  /// not of the derived kind.
+  PyConcreteAttribute(PyAttribute &orig)
+      : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
+
+  /// Returns the raw attribute of `orig` if DerivedTy::isaFunction accepts
+  /// it, otherwise raises a Python ValueError naming the target class.
+  static MlirAttribute castFrom(PyAttribute &orig) {
+    if (!DerivedTy::isaFunction(orig)) {
+      auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
+      throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") +
+                                             DerivedTy::pyClassName +
+                                             " (from " + origRepr + ")");
+    }
+    return orig;
+  }
+
+  /// Registers the Python class (with a casting constructor that keeps the
+  /// source attribute alive) and lets the derived class add its members.
+  static void bind(py::module &m) {
+    auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol());
+    cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
+    DerivedTy::bindDerived(cls);
+  }
+
+  /// Implemented by derived classes to add methods to the Python subclass.
+  static void bindDerived(ClassTy &m) {}
+};
+
+/// AffineMapAttr: an attribute wrapping an AffineMap.
+class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
+  static constexpr const char *pyClassName = "AffineMapAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](PyAffineMap &affineMap) {
+          MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
+          return PyAffineMapAttribute(affineMap.getContext(), attr);
+        },
+        py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
+  }
+};
+
+/// ArrayAttr: a list of attributes, exposed to Python with len/index/iter.
+class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
+  static constexpr const char *pyClassName = "ArrayAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  /// Python iterator over the elements of an ArrayAttr.
+  class PyArrayAttributeIterator {
+  public:
+    PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {}
+
+    PyArrayAttributeIterator &dunderIter() { return *this; }
+
+    PyAttribute dunderNext() {
+      if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) {
+        throw py::stop_iteration();
+      }
+      return PyAttribute(attr.getContext(),
+                         mlirArrayAttrGetElement(attr.get(), nextIndex++));
+    }
+
+    static void bind(py::module &m) {
+      py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
+          .def("__iter__", &PyArrayAttributeIterator::dunderIter)
+          .def("__next__", &PyArrayAttributeIterator::dunderNext);
+    }
+
+  private:
+    PyAttribute attr;
+    // intptr_t to match the index type passed to mlirArrayAttrGetElement.
+    intptr_t nextIndex = 0;
+  };
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](py::list attributes, DefaultingPyMlirContext context) {
+          SmallVector<MlirAttribute> mlirAttributes;
+          mlirAttributes.reserve(py::len(attributes));
+          for (auto attribute : attributes) {
+            try {
+              mlirAttributes.push_back(attribute.cast<PyAttribute>());
+            } catch (const py::cast_error &err) {
+              std::string msg = std::string("Invalid attribute when attempting "
+                                            "to create an ArrayAttribute (") +
+                                err.what() + ")";
+              throw py::cast_error(msg);
+            } catch (const py::reference_cast_error &err) {
+              // This exception seems thrown when the value is "None".
+              std::string msg =
+                  std::string("Invalid attribute (None?) when attempting to "
+                              "create an ArrayAttribute (") +
+                  err.what() + ")";
+              throw py::cast_error(msg);
+            }
+          }
+          MlirAttribute attr = mlirArrayAttrGet(
+              context->get(), mlirAttributes.size(), mlirAttributes.data());
+          return PyArrayAttribute(context->getRef(), attr);
+        },
+        py::arg("attributes"), py::arg("context") = py::none(),
+        "Gets a uniqued Array attribute");
+    c.def("__getitem__",
+          [](PyArrayAttribute &arr, intptr_t i) {
+            intptr_t numElements = mlirArrayAttrGetNumElements(arr);
+            // Python-style negative indices count from the end. Anything still
+            // outside [0, numElements) would read past the array, so reject it
+            // (previously negative indices were passed straight through).
+            if (i < 0)
+              i += numElements;
+            if (i < 0 || i >= numElements)
+              throw py::index_error("ArrayAttribute index out of range");
+            return PyAttribute(arr.getContext(),
+                               mlirArrayAttrGetElement(arr, i));
+          })
+        .def("__len__",
+             [](const PyArrayAttribute &arr) {
+               return mlirArrayAttrGetNumElements(arr);
+             })
+        .def("__iter__", [](const PyArrayAttribute &arr) {
+          return PyArrayAttributeIterator(arr);
+        });
+  }
+};
+
+/// Float Point Attribute subclass - FloatAttr.
+class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
+  static constexpr const char *pyClassName = "FloatAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](PyType &type, double value, DefaultingPyLocation loc) {
+          MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
+          // TODO: Rework error reporting once diagnostic engine is exposed
+          // in C API.
+          if (mlirAttributeIsNull(attr)) {
+            throw SetPyError(PyExc_ValueError,
+                             Twine("invalid '") +
+                                 py::repr(py::cast(type)).cast<std::string>() +
+                                 "' and expected floating point type.");
+          }
+          return PyFloatAttribute(type.getContext(), attr);
+        },
+        py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
+        "Gets an uniqued float point attribute associated to a type");
+    c.def_static(
+        "get_f32",
+        [](double value, DefaultingPyMlirContext context) {
+          MlirAttribute attr = mlirFloatAttrDoubleGet(
+              context->get(), mlirF32TypeGet(context->get()), value);
+          return PyFloatAttribute(context->getRef(), attr);
+        },
+        py::arg("value"), py::arg("context") = py::none(),
+        "Gets an uniqued float point attribute associated to a f32 type");
+    c.def_static(
+        "get_f64",
+        [](double value, DefaultingPyMlirContext context) {
+          MlirAttribute attr = mlirFloatAttrDoubleGet(
+              context->get(), mlirF64TypeGet(context->get()), value);
+          return PyFloatAttribute(context->getRef(), attr);
+        },
+        py::arg("value"), py::arg("context") = py::none(),
+        "Gets an uniqued float point attribute associated to a f64 type");
+    c.def_property_readonly(
+        "value",
+        [](PyFloatAttribute &self) {
+          return mlirFloatAttrGetValueDouble(self);
+        },
+        "Returns the value of the float point attribute");
+  }
+};
+
+/// Integer Attribute subclass - IntegerAttr.
+class PyIntegerAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; + static constexpr const char *pyClassName = "IntegerAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &type, int64_t value) { + MlirAttribute attr = mlirIntegerAttrGet(type, value); + return PyIntegerAttribute(type.getContext(), attr); + }, + py::arg("type"), py::arg("value"), + "Gets an uniqued integer attribute associated to a type"); + c.def_property_readonly( + "value", + [](PyIntegerAttribute &self) { + return mlirIntegerAttrGetValueInt(self); + }, + "Returns the value of the integer attribute"); + } +}; + +/// Bool Attribute subclass - BoolAttr. +class PyBoolAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; + static constexpr const char *pyClassName = "BoolAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](bool value, DefaultingPyMlirContext context) { + MlirAttribute attr = mlirBoolAttrGet(context->get(), value); + return PyBoolAttribute(context->getRef(), attr); + }, + py::arg("value"), py::arg("context") = py::none(), + "Gets an uniqued bool attribute"); + c.def_property_readonly( + "value", + [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); }, + "Returns the value of the bool attribute"); + } +}; + +class PyFlatSymbolRefAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; + static constexpr const char *pyClassName = "FlatSymbolRefAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::string value, DefaultingPyMlirContext context) { + MlirAttribute attr = + mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); + 
return PyFlatSymbolRefAttribute(context->getRef(), attr); + }, + py::arg("value"), py::arg("context") = py::none(), + "Gets a uniqued FlatSymbolRef attribute"); + c.def_property_readonly( + "value", + [](PyFlatSymbolRefAttribute &self) { + MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); + return py::str(stringRef.data, stringRef.length); + }, + "Returns the value of the FlatSymbolRef attribute as a string"); + } +}; + +class PyStringAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; + static constexpr const char *pyClassName = "StringAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::string value, DefaultingPyMlirContext context) { + MlirAttribute attr = + mlirStringAttrGet(context->get(), toMlirStringRef(value)); + return PyStringAttribute(context->getRef(), attr); + }, + py::arg("value"), py::arg("context") = py::none(), + "Gets a uniqued string attribute"); + c.def_static( + "get_typed", + [](PyType &type, std::string value) { + MlirAttribute attr = + mlirStringAttrTypedGet(type, toMlirStringRef(value)); + return PyStringAttribute(type.getContext(), attr); + }, + + "Gets a uniqued string attribute associated to a type"); + c.def_property_readonly( + "value", + [](PyStringAttribute &self) { + MlirStringRef stringRef = mlirStringAttrGetValue(self); + return py::str(stringRef.data, stringRef.length); + }, + "Returns the value of the string attribute"); + } +}; + +// TODO: Support construction of bool elements. +// TODO: Support construction of string elements. 
+class PyDenseElementsAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; + static constexpr const char *pyClassName = "DenseElementsAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static PyDenseElementsAttribute + getFromBuffer(py::buffer array, bool signless, + DefaultingPyMlirContext contextWrapper) { + // Request a contiguous view. In exotic cases, this will cause a copy. + int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; + Py_buffer *view = new Py_buffer(); + if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) { + delete view; + throw py::error_already_set(); + } + py::buffer_info arrayInfo(view); + + MlirContext context = contextWrapper->get(); + // Switch on the types that can be bulk loaded between the Python and + // MLIR-C APIs. + // See: https://docs.python.org/3/library/struct.html#format-characters + if (arrayInfo.format == "f") { + // f32 + assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); + return PyDenseElementsAttribute( + contextWrapper->getRef(), + bulkLoad(context, mlirDenseElementsAttrFloatGet, + mlirF32TypeGet(context), arrayInfo)); + } else if (arrayInfo.format == "d") { + // f64 + assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); + return PyDenseElementsAttribute( + contextWrapper->getRef(), + bulkLoad(context, mlirDenseElementsAttrDoubleGet, + mlirF64TypeGet(context), arrayInfo)); + } else if (isSignedIntegerFormat(arrayInfo.format)) { + if (arrayInfo.itemsize == 4) { + // i32 + MlirType elementType = signless ? mlirIntegerTypeGet(context, 32) + : mlirIntegerTypeSignedGet(context, 32); + return PyDenseElementsAttribute(contextWrapper->getRef(), + bulkLoad(context, + mlirDenseElementsAttrInt32Get, + elementType, arrayInfo)); + } else if (arrayInfo.itemsize == 8) { + // i64 + MlirType elementType = signless ? 
mlirIntegerTypeGet(context, 64) + : mlirIntegerTypeSignedGet(context, 64); + return PyDenseElementsAttribute(contextWrapper->getRef(), + bulkLoad(context, + mlirDenseElementsAttrInt64Get, + elementType, arrayInfo)); + } + } else if (isUnsignedIntegerFormat(arrayInfo.format)) { + if (arrayInfo.itemsize == 4) { + // unsigned i32 + MlirType elementType = signless + ? mlirIntegerTypeGet(context, 32) + : mlirIntegerTypeUnsignedGet(context, 32); + return PyDenseElementsAttribute(contextWrapper->getRef(), + bulkLoad(context, + mlirDenseElementsAttrUInt32Get, + elementType, arrayInfo)); + } else if (arrayInfo.itemsize == 8) { + // unsigned i64 + MlirType elementType = signless + ? mlirIntegerTypeGet(context, 64) + : mlirIntegerTypeUnsignedGet(context, 64); + return PyDenseElementsAttribute(contextWrapper->getRef(), + bulkLoad(context, + mlirDenseElementsAttrUInt64Get, + elementType, arrayInfo)); + } + } + + // TODO: Fall back to string-based get. + std::string message = "unimplemented array format conversion from format: "; + message.append(arrayInfo.format); + throw SetPyError(PyExc_ValueError, message); + } + + static PyDenseElementsAttribute getSplat(PyType shapedType, + PyAttribute &elementAttr) { + auto contextWrapper = + PyMlirContext::forContext(mlirTypeGetContext(shapedType)); + if (!mlirAttributeIsAInteger(elementAttr) && + !mlirAttributeIsAFloat(elementAttr)) { + std::string message = "Illegal element type for DenseElementsAttr: "; + message.append(py::repr(py::cast(elementAttr))); + throw SetPyError(PyExc_ValueError, message); + } + if (!mlirTypeIsAShaped(shapedType) || + !mlirShapedTypeHasStaticShape(shapedType)) { + std::string message = + "Expected a static ShapedType for the shaped_type parameter: "; + message.append(py::repr(py::cast(shapedType))); + throw SetPyError(PyExc_ValueError, message); + } + MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); + MlirType attrType = mlirAttributeGetType(elementAttr); + if 
(!mlirTypeEqual(shapedElementType, attrType)) { + std::string message = + "Shaped element type and attribute type must be equal: shaped="; + message.append(py::repr(py::cast(shapedType))); + message.append(", element="); + message.append(py::repr(py::cast(elementAttr))); + throw SetPyError(PyExc_ValueError, message); + } + + MlirAttribute elements = + mlirDenseElementsAttrSplatGet(shapedType, elementAttr); + return PyDenseElementsAttribute(contextWrapper->getRef(), elements); + } + + intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } + + py::buffer_info accessBuffer() { + MlirType shapedType = mlirAttributeGetType(*this); + MlirType elementType = mlirShapedTypeGetElementType(shapedType); + + if (mlirTypeIsAF32(elementType)) { + // f32 + return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue); + } else if (mlirTypeIsAF64(elementType)) { + // f64 + return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue); + } else if (mlirTypeIsAInteger(elementType) && + mlirIntegerTypeGetWidth(elementType) == 32) { + if (mlirIntegerTypeIsSignless(elementType) || + mlirIntegerTypeIsSigned(elementType)) { + // i32 + return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value); + } else if (mlirIntegerTypeIsUnsigned(elementType)) { + // unsigned i32 + return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value); + } + } else if (mlirTypeIsAInteger(elementType) && + mlirIntegerTypeGetWidth(elementType) == 64) { + if (mlirIntegerTypeIsSignless(elementType) || + mlirIntegerTypeIsSigned(elementType)) { + // i64 + return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value); + } else if (mlirIntegerTypeIsUnsigned(elementType)) { + // unsigned i64 + return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value); + } + } + + std::string message = "unimplemented array format."; + throw SetPyError(PyExc_ValueError, message); + } + + static void bindDerived(ClassTy &c) { + c.def("__len__", &PyDenseElementsAttribute::dunderLen) + 
.def_static("get", PyDenseElementsAttribute::getFromBuffer, + py::arg("array"), py::arg("signless") = true, + py::arg("context") = py::none(), + "Gets from a buffer or ndarray") + .def_static("get_splat", PyDenseElementsAttribute::getSplat, + py::arg("shaped_type"), py::arg("element_attr"), + "Gets a DenseElementsAttr where all values are the same") + .def_property_readonly("is_splat", + [](PyDenseElementsAttribute &self) -> bool { + return mlirDenseElementsAttrIsSplat(self); + }) + .def_buffer(&PyDenseElementsAttribute::accessBuffer); + } + +private: + template + static MlirAttribute + bulkLoad(MlirContext context, + MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *), + MlirType mlirElementType, py::buffer_info &arrayInfo) { + SmallVector shape(arrayInfo.shape.begin(), + arrayInfo.shape.begin() + arrayInfo.ndim); + auto shapedType = + mlirRankedTensorTypeGet(shape.size(), shape.data(), mlirElementType); + intptr_t numElements = arrayInfo.size; + const ElementTy *contents = static_cast(arrayInfo.ptr); + return ctor(shapedType, numElements, contents); + } + + static bool isUnsignedIntegerFormat(const std::string &format) { + if (format.empty()) + return false; + char code = format[0]; + return code == 'I' || code == 'B' || code == 'H' || code == 'L' || + code == 'Q'; + } + + static bool isSignedIntegerFormat(const std::string &format) { + if (format.empty()) + return false; + char code = format[0]; + return code == 'i' || code == 'b' || code == 'h' || code == 'l' || + code == 'q'; + } + + template + py::buffer_info bufferInfo(MlirType shapedType, + Type (*value)(MlirAttribute, intptr_t)) { + intptr_t rank = mlirShapedTypeGetRank(shapedType); + // Prepare the data for the buffer_info. + // Buffer is configured for read-only access below. + Type *data = static_cast( + const_cast(mlirDenseElementsAttrGetRawData(*this))); + // Prepare the shape for the buffer_info. 
+ SmallVector shape; + for (intptr_t i = 0; i < rank; ++i) + shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); + // Prepare the strides for the buffer_info. + SmallVector strides; + intptr_t strideFactor = 1; + for (intptr_t i = 1; i < rank; ++i) { + strideFactor = 1; + for (intptr_t j = i; j < rank; ++j) { + strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); + } + strides.push_back(sizeof(Type) * strideFactor); + } + strides.push_back(sizeof(Type)); + return py::buffer_info(data, sizeof(Type), + py::format_descriptor::format(), rank, shape, + strides, /*readonly=*/true); + } +}; // namespace + +/// Refinement of the PyDenseElementsAttribute for attributes containing integer +/// (and boolean) values. Supports element access. +class PyDenseIntElementsAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; + static constexpr const char *pyClassName = "DenseIntElementsAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + /// Returns the element at the given linear position. Asserts if the index is + /// out of range. + py::int_ dunderGetItem(intptr_t pos) { + if (pos < 0 || pos >= dunderLen()) { + throw SetPyError(PyExc_IndexError, + "attempt to access out of bounds element"); + } + + MlirType type = mlirAttributeGetType(*this); + type = mlirShapedTypeGetElementType(type); + assert(mlirTypeIsAInteger(type) && + "expected integer element type in dense int elements attribute"); + // Dispatch element extraction to an appropriate C function based on the + // elemental type of the attribute. py::int_ is implicitly constructible + // from any C++ integral type and handles bitwidth correctly. + // TODO: consider caching the type properties in the constructor to avoid + // querying them on each element access. 
+ unsigned width = mlirIntegerTypeGetWidth(type); + bool isUnsigned = mlirIntegerTypeIsUnsigned(type); + if (isUnsigned) { + if (width == 1) { + return mlirDenseElementsAttrGetBoolValue(*this, pos); + } + if (width == 32) { + return mlirDenseElementsAttrGetUInt32Value(*this, pos); + } + if (width == 64) { + return mlirDenseElementsAttrGetUInt64Value(*this, pos); + } + } else { + if (width == 1) { + return mlirDenseElementsAttrGetBoolValue(*this, pos); + } + if (width == 32) { + return mlirDenseElementsAttrGetInt32Value(*this, pos); + } + if (width == 64) { + return mlirDenseElementsAttrGetInt64Value(*this, pos); + } + } + throw SetPyError(PyExc_TypeError, "Unsupported integer type"); + } + + static void bindDerived(ClassTy &c) { + c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); + } +}; + +class PyDictAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; + static constexpr const char *pyClassName = "DictAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } + + static void bindDerived(ClassTy &c) { + c.def("__len__", &PyDictAttribute::dunderLen); + c.def_static( + "get", + [](py::dict attributes, DefaultingPyMlirContext context) { + SmallVector mlirNamedAttributes; + mlirNamedAttributes.reserve(attributes.size()); + for (auto &it : attributes) { + auto &mlir_attr = it.second.cast(); + auto name = it.first.cast(); + mlirNamedAttributes.push_back(mlirNamedAttributeGet( + mlirIdentifierGet(mlirAttributeGetContext(mlir_attr), + toMlirStringRef(name)), + mlir_attr)); + } + MlirAttribute attr = + mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), + mlirNamedAttributes.data()); + return PyDictAttribute(context->getRef(), attr); + }, + py::arg("value"), py::arg("context") = py::none(), + "Gets an uniqued dict attribute"); + c.def("__getitem__", [](PyDictAttribute &self, const std::string 
&name) { + MlirAttribute attr = + mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); + if (mlirAttributeIsNull(attr)) { + throw SetPyError(PyExc_KeyError, + "attempt to access a non-existent attribute"); + } + return PyAttribute(self.getContext(), attr); + }); + c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { + if (index < 0 || index >= self.dunderLen()) { + throw SetPyError(PyExc_IndexError, + "attempt to access out of bounds attribute"); + } + MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); + return PyNamedAttribute( + namedAttr.attribute, + std::string(mlirIdentifierStr(namedAttr.name).data)); + }); + } +}; + +/// Refinement of PyDenseElementsAttribute for attributes containing +/// floating-point values. Supports element access. +class PyDenseFPElementsAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; + static constexpr const char *pyClassName = "DenseFPElementsAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + py::float_ dunderGetItem(intptr_t pos) { + if (pos < 0 || pos >= dunderLen()) { + throw SetPyError(PyExc_IndexError, + "attempt to access out of bounds element"); + } + + MlirType type = mlirAttributeGetType(*this); + type = mlirShapedTypeGetElementType(type); + // Dispatch element extraction to an appropriate C function based on the + // elemental type of the attribute. py::float_ is implicitly constructible + // from float and double. + // TODO: consider caching the type properties in the constructor to avoid + // querying them on each element access. 
+ if (mlirTypeIsAF32(type)) { + return mlirDenseElementsAttrGetFloatValue(*this, pos); + } + if (mlirTypeIsAF64(type)) { + return mlirDenseElementsAttrGetDoubleValue(*this, pos); + } + throw SetPyError(PyExc_TypeError, "Unsupported floating-point type"); + } + + static void bindDerived(ClassTy &c) { + c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); + } +}; + +class PyTypeAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; + static constexpr const char *pyClassName = "TypeAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType value, DefaultingPyMlirContext context) { + MlirAttribute attr = mlirTypeAttrGet(value.get()); + return PyTypeAttribute(context->getRef(), attr); + }, + py::arg("value"), py::arg("context") = py::none(), + "Gets a uniqued Type attribute"); + c.def_property_readonly("value", [](PyTypeAttribute &self) { + return PyType(self.getContext()->getRef(), + mlirTypeAttrGetValue(self.get())); + }); + } +}; + +/// Unit Attribute subclass. Unit attributes don't have values. 
+class PyUnitAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; + static constexpr const char *pyClassName = "UnitAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + return PyUnitAttribute(context->getRef(), + mlirUnitAttrGet(context->get())); + }, + py::arg("context") = py::none(), "Create a Unit attribute."); + } +}; + +} // namespace + +void mlir::python::populateIRAttributes(py::module &m) { + PyAffineMapAttribute::bind(m); + PyArrayAttribute::bind(m); + PyArrayAttribute::PyArrayAttributeIterator::bind(m); + PyBoolAttribute::bind(m); + PyDenseElementsAttribute::bind(m); + PyDenseFPElementsAttribute::bind(m); + PyDenseIntElementsAttribute::bind(m); + PyDictAttribute::bind(m); + PyFlatSymbolRefAttribute::bind(m); + PyFloatAttribute::bind(m); + PyIntegerAttribute::bind(m); + PyStringAttribute::bind(m); + PyTypeAttribute::bind(m); + PyUnitAttribute::bind(m); +} diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRCore.cpp rename from mlir/lib/Bindings/Python/IRModules.cpp rename to mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -6,16 +6,14 @@ // //===----------------------------------------------------------------------===// -#include "IRModules.h" +#include "IRModule.h" #include "Globals.h" #include "PybindUtils.h" -#include "mlir-c/AffineMap.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" -#include "mlir-c/IntegerSet.h" #include "mlir-c/Registration.h" #include "llvm/ADT/SmallVector.h" #include @@ -138,12 +136,6 @@ return py::reinterpret_borrow((PyClassMethod_New(cf.ptr()))); } -/// Checks whether the given type is an integer or float type. 
-static int mlirTypeIsAIntegerOrFloat(MlirType type) { - return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) || - mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type); -} - static py::object createCustomDialectWrapper(const std::string &dialectNamespace, py::object dialectDescriptor) { @@ -161,21 +153,6 @@ return mlirStringRefCreate(s.data(), s.size()); } -template -static bool isPermutation(std::vector permutation) { - llvm::SmallVector seen(permutation.size(), false); - for (auto val : permutation) { - if (val < permutation.size()) { - if (seen[val]) - return false; - seen[val] = true; - continue; - } - return false; - } - return true; -} - //------------------------------------------------------------------------------ // Collections. //------------------------------------------------------------------------------ @@ -1466,7 +1443,8 @@ /// CRTP base class for Python MLIR values that subclass Value and should be /// castable from it. The value hierarchy is one level deep and is not supposed /// to accommodate other levels unless core MLIR changes. -template class PyConcreteValue : public PyValue { +template +class PyConcreteValue : public PyValue { public: // Derived classes must define statics for: // IsAFunctionTy isaFunction @@ -1717,1910 +1695,169 @@ } // end namespace //------------------------------------------------------------------------------ -// Builtin attribute subclasses. +// Populates the core exports of the 'ir' submodule. //------------------------------------------------------------------------------ -namespace { - -/// CRTP base classes for Python attributes that subclass Attribute and should -/// be castable from it (i.e. via something like StringAttr(attr)). -/// By default, attribute class hierarchies are one level deep (i.e. a -/// concrete attribute class extends PyAttribute); however, intermediate -/// python-visible base classes can be modeled by specifying a BaseTy. 
-template -class PyConcreteAttribute : public BaseTy { -public: - // Derived classes must define statics for: - // IsAFunctionTy isaFunction - // const char *pyClassName - using ClassTy = py::class_; - using IsAFunctionTy = bool (*)(MlirAttribute); - - PyConcreteAttribute() = default; - PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) - : BaseTy(std::move(contextRef), attr) {} - PyConcreteAttribute(PyAttribute &orig) - : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {} - - static MlirAttribute castFrom(PyAttribute &orig) { - if (!DerivedTy::isaFunction(orig)) { - auto origRepr = py::repr(py::cast(orig)).cast(); - throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") + - DerivedTy::pyClassName + - " (from " + origRepr + ")"); - } - return orig; - } - - static void bind(py::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol()); - cls.def(py::init(), py::keep_alive<0, 1>()); - DerivedTy::bindDerived(cls); - } - - /// Implemented by derived classes to add methods to the Python subclass. 
- static void bindDerived(ClassTy &m) {} -}; - -class PyAffineMapAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; - static constexpr const char *pyClassName = "AffineMapAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyAffineMap &affineMap) { - MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); - return PyAffineMapAttribute(affineMap.getContext(), attr); - }, - py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); - } -}; - -class PyArrayAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; - static constexpr const char *pyClassName = "ArrayAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - class PyArrayAttributeIterator { - public: - PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {} - - PyArrayAttributeIterator &dunderIter() { return *this; } - - PyAttribute dunderNext() { - if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) { - throw py::stop_iteration(); - } - return PyAttribute(attr.getContext(), - mlirArrayAttrGetElement(attr.get(), nextIndex++)); - } - - static void bind(py::module &m) { - py::class_(m, "ArrayAttributeIterator") - .def("__iter__", &PyArrayAttributeIterator::dunderIter) - .def("__next__", &PyArrayAttributeIterator::dunderNext); - } - - private: - PyAttribute attr; - int nextIndex = 0; - }; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](py::list attributes, DefaultingPyMlirContext context) { - SmallVector mlirAttributes; - mlirAttributes.reserve(py::len(attributes)); - for (auto attribute : attributes) { - try { - mlirAttributes.push_back(attribute.cast()); - } catch (py::cast_error &err) { - std::string msg = std::string("Invalid attribute when attempting " - "to create an ArrayAttribute (") + - err.what() + ")"; - throw py::cast_error(msg); - } catch 
(py::reference_cast_error &err) { - // This exception seems thrown when the value is "None". - std::string msg = - std::string("Invalid attribute (None?) when attempting to " - "create an ArrayAttribute (") + - err.what() + ")"; - throw py::cast_error(msg); +void mlir::python::populateIRCore(py::module &m) { + //---------------------------------------------------------------------------- + // Mapping of MlirContext + //---------------------------------------------------------------------------- + py::class_(m, "Context") + .def(py::init<>(&PyMlirContext::createNewContextForInit)) + .def_static("_get_live_count", &PyMlirContext::getLiveCount) + .def("_get_context_again", + [](PyMlirContext &self) { + PyMlirContextRef ref = PyMlirContext::forContext(self.get()); + return ref.releaseObject(); + }) + .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) + .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyMlirContext::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) + .def("__enter__", &PyMlirContext::contextEnter) + .def("__exit__", &PyMlirContext::contextExit) + .def_property_readonly_static( + "current", + [](py::object & /*class*/) { + auto *context = PyThreadContextEntry::getDefaultContext(); + if (!context) + throw SetPyError(PyExc_ValueError, "No current Context"); + return context; + }, + "Gets the Context bound to the current thread or raises ValueError") + .def_property_readonly( + "dialects", + [](PyMlirContext &self) { return PyDialects(self.getRef()); }, + "Gets a container for accessing dialects by name") + .def_property_readonly( + "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, + "Alias for 'dialect'") + .def( + "get_dialect_descriptor", + [=](PyMlirContext &self, std::string &name) { + MlirDialect dialect = mlirContextGetOrLoadDialect( + self.get(), {name.data(), name.size()}); + if 
(mlirDialectIsNull(dialect)) { + throw SetPyError(PyExc_ValueError, + Twine("Dialect '") + name + "' not found"); } - } - MlirAttribute attr = mlirArrayAttrGet( - context->get(), mlirAttributes.size(), mlirAttributes.data()); - return PyArrayAttribute(context->getRef(), attr); - }, - py::arg("attributes"), py::arg("context") = py::none(), - "Gets a uniqued Array attribute"); - c.def("__getitem__", - [](PyArrayAttribute &arr, intptr_t i) { - if (i >= mlirArrayAttrGetNumElements(arr)) - throw py::index_error("ArrayAttribute index out of range"); - return PyAttribute(arr.getContext(), - mlirArrayAttrGetElement(arr, i)); - }) - .def("__len__", - [](const PyArrayAttribute &arr) { - return mlirArrayAttrGetNumElements(arr); - }) - .def("__iter__", [](const PyArrayAttribute &arr) { - return PyArrayAttributeIterator(arr); - }); - } -}; - -/// Float Point Attribute subclass - FloatAttr. -class PyFloatAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; - static constexpr const char *pyClassName = "FloatAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &type, double value, DefaultingPyLocation loc) { - MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. 
- if (mlirAttributeIsNull(attr)) { - throw SetPyError(PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(type)).cast() + - "' and expected floating point type."); - } - return PyFloatAttribute(type.getContext(), attr); - }, - py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), - "Gets an uniqued float point attribute associated to a type"); - c.def_static( - "get_f32", - [](double value, DefaultingPyMlirContext context) { - MlirAttribute attr = mlirFloatAttrDoubleGet( - context->get(), mlirF32TypeGet(context->get()), value); - return PyFloatAttribute(context->getRef(), attr); - }, - py::arg("value"), py::arg("context") = py::none(), - "Gets an uniqued float point attribute associated to a f32 type"); - c.def_static( - "get_f64", - [](double value, DefaultingPyMlirContext context) { - MlirAttribute attr = mlirFloatAttrDoubleGet( - context->get(), mlirF64TypeGet(context->get()), value); - return PyFloatAttribute(context->getRef(), attr); - }, - py::arg("value"), py::arg("context") = py::none(), - "Gets an uniqued float point attribute associated to a f64 type"); - c.def_property_readonly( - "value", - [](PyFloatAttribute &self) { - return mlirFloatAttrGetValueDouble(self); - }, - "Returns the value of the float point attribute"); - } -}; - -/// Integer Attribute subclass - IntegerAttr. 
-class PyIntegerAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; - static constexpr const char *pyClassName = "IntegerAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &type, int64_t value) { - MlirAttribute attr = mlirIntegerAttrGet(type, value); - return PyIntegerAttribute(type.getContext(), attr); - }, - py::arg("type"), py::arg("value"), - "Gets an uniqued integer attribute associated to a type"); - c.def_property_readonly( - "value", - [](PyIntegerAttribute &self) { - return mlirIntegerAttrGetValueInt(self); - }, - "Returns the value of the integer attribute"); - } -}; - -/// Bool Attribute subclass - BoolAttr. -class PyBoolAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; - static constexpr const char *pyClassName = "BoolAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](bool value, DefaultingPyMlirContext context) { - MlirAttribute attr = mlirBoolAttrGet(context->get(), value); - return PyBoolAttribute(context->getRef(), attr); - }, - py::arg("value"), py::arg("context") = py::none(), - "Gets an uniqued bool attribute"); - c.def_property_readonly( - "value", - [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); }, - "Returns the value of the bool attribute"); - } -}; - -class PyFlatSymbolRefAttribute - : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; - static constexpr const char *pyClassName = "FlatSymbolRefAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::string value, DefaultingPyMlirContext context) { - MlirAttribute attr = - mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); - 
return PyFlatSymbolRefAttribute(context->getRef(), attr); - }, - py::arg("value"), py::arg("context") = py::none(), - "Gets a uniqued FlatSymbolRef attribute"); - c.def_property_readonly( - "value", - [](PyFlatSymbolRefAttribute &self) { - MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); - return py::str(stringRef.data, stringRef.length); - }, - "Returns the value of the FlatSymbolRef attribute as a string"); - } -}; - -class PyStringAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; - static constexpr const char *pyClassName = "StringAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::string value, DefaultingPyMlirContext context) { - MlirAttribute attr = - mlirStringAttrGet(context->get(), toMlirStringRef(value)); - return PyStringAttribute(context->getRef(), attr); - }, - py::arg("value"), py::arg("context") = py::none(), - "Gets a uniqued string attribute"); - c.def_static( - "get_typed", - [](PyType &type, std::string value) { - MlirAttribute attr = - mlirStringAttrTypedGet(type, toMlirStringRef(value)); - return PyStringAttribute(type.getContext(), attr); - }, - - "Gets a uniqued string attribute associated to a type"); - c.def_property_readonly( - "value", - [](PyStringAttribute &self) { - MlirStringRef stringRef = mlirStringAttrGetValue(self); - return py::str(stringRef.data, stringRef.length); - }, - "Returns the value of the string attribute"); - } -}; - -// TODO: Support construction of bool elements. -// TODO: Support construction of string elements. 
-class PyDenseElementsAttribute - : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; - static constexpr const char *pyClassName = "DenseElementsAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static PyDenseElementsAttribute - getFromBuffer(py::buffer array, bool signless, - DefaultingPyMlirContext contextWrapper) { - // Request a contiguous view. In exotic cases, this will cause a copy. - int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; - Py_buffer *view = new Py_buffer(); - if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) { - delete view; - throw py::error_already_set(); - } - py::buffer_info arrayInfo(view); - - MlirContext context = contextWrapper->get(); - // Switch on the types that can be bulk loaded between the Python and - // MLIR-C APIs. - // See: https://docs.python.org/3/library/struct.html#format-characters - if (arrayInfo.format == "f") { - // f32 - assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); - return PyDenseElementsAttribute( - contextWrapper->getRef(), - bulkLoad(context, mlirDenseElementsAttrFloatGet, - mlirF32TypeGet(context), arrayInfo)); - } else if (arrayInfo.format == "d") { - // f64 - assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); - return PyDenseElementsAttribute( - contextWrapper->getRef(), - bulkLoad(context, mlirDenseElementsAttrDoubleGet, - mlirF64TypeGet(context), arrayInfo)); - } else if (isSignedIntegerFormat(arrayInfo.format)) { - if (arrayInfo.itemsize == 4) { - // i32 - MlirType elementType = signless ? mlirIntegerTypeGet(context, 32) - : mlirIntegerTypeSignedGet(context, 32); - return PyDenseElementsAttribute(contextWrapper->getRef(), - bulkLoad(context, - mlirDenseElementsAttrInt32Get, - elementType, arrayInfo)); - } else if (arrayInfo.itemsize == 8) { - // i64 - MlirType elementType = signless ? 
mlirIntegerTypeGet(context, 64) - : mlirIntegerTypeSignedGet(context, 64); - return PyDenseElementsAttribute(contextWrapper->getRef(), - bulkLoad(context, - mlirDenseElementsAttrInt64Get, - elementType, arrayInfo)); - } - } else if (isUnsignedIntegerFormat(arrayInfo.format)) { - if (arrayInfo.itemsize == 4) { - // unsigned i32 - MlirType elementType = signless - ? mlirIntegerTypeGet(context, 32) - : mlirIntegerTypeUnsignedGet(context, 32); - return PyDenseElementsAttribute(contextWrapper->getRef(), - bulkLoad(context, - mlirDenseElementsAttrUInt32Get, - elementType, arrayInfo)); - } else if (arrayInfo.itemsize == 8) { - // unsigned i64 - MlirType elementType = signless - ? mlirIntegerTypeGet(context, 64) - : mlirIntegerTypeUnsignedGet(context, 64); - return PyDenseElementsAttribute(contextWrapper->getRef(), - bulkLoad(context, - mlirDenseElementsAttrUInt64Get, - elementType, arrayInfo)); - } - } - - // TODO: Fall back to string-based get. - std::string message = "unimplemented array format conversion from format: "; - message.append(arrayInfo.format); - throw SetPyError(PyExc_ValueError, message); - } - - static PyDenseElementsAttribute getSplat(PyType shapedType, - PyAttribute &elementAttr) { - auto contextWrapper = - PyMlirContext::forContext(mlirTypeGetContext(shapedType)); - if (!mlirAttributeIsAInteger(elementAttr) && - !mlirAttributeIsAFloat(elementAttr)) { - std::string message = "Illegal element type for DenseElementsAttr: "; - message.append(py::repr(py::cast(elementAttr))); - throw SetPyError(PyExc_ValueError, message); - } - if (!mlirTypeIsAShaped(shapedType) || - !mlirShapedTypeHasStaticShape(shapedType)) { - std::string message = - "Expected a static ShapedType for the shaped_type parameter: "; - message.append(py::repr(py::cast(shapedType))); - throw SetPyError(PyExc_ValueError, message); - } - MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); - MlirType attrType = mlirAttributeGetType(elementAttr); - if 
(!mlirTypeEqual(shapedElementType, attrType)) { - std::string message = - "Shaped element type and attribute type must be equal: shaped="; - message.append(py::repr(py::cast(shapedType))); - message.append(", element="); - message.append(py::repr(py::cast(elementAttr))); - throw SetPyError(PyExc_ValueError, message); - } - - MlirAttribute elements = - mlirDenseElementsAttrSplatGet(shapedType, elementAttr); - return PyDenseElementsAttribute(contextWrapper->getRef(), elements); - } + return PyDialectDescriptor(self.getRef(), dialect); + }, + "Gets or loads a dialect by name, returning its descriptor object") + .def_property( + "allow_unregistered_dialects", + [](PyMlirContext &self) -> bool { + return mlirContextGetAllowUnregisteredDialects(self.get()); + }, + [](PyMlirContext &self, bool value) { + mlirContextSetAllowUnregisteredDialects(self.get(), value); + }); - intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } - - py::buffer_info accessBuffer() { - MlirType shapedType = mlirAttributeGetType(*this); - MlirType elementType = mlirShapedTypeGetElementType(shapedType); - - if (mlirTypeIsAF32(elementType)) { - // f32 - return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue); - } else if (mlirTypeIsAF64(elementType)) { - // f64 - return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue); - } else if (mlirTypeIsAInteger(elementType) && - mlirIntegerTypeGetWidth(elementType) == 32) { - if (mlirIntegerTypeIsSignless(elementType) || - mlirIntegerTypeIsSigned(elementType)) { - // i32 - return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value); - } else if (mlirIntegerTypeIsUnsigned(elementType)) { - // unsigned i32 - return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value); - } - } else if (mlirTypeIsAInteger(elementType) && - mlirIntegerTypeGetWidth(elementType) == 64) { - if (mlirIntegerTypeIsSignless(elementType) || - mlirIntegerTypeIsSigned(elementType)) { - // i64 - return bufferInfo(shapedType, 
mlirDenseElementsAttrGetInt64Value); - } else if (mlirIntegerTypeIsUnsigned(elementType)) { - // unsigned i64 - return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value); - } - } + //---------------------------------------------------------------------------- + // Mapping of PyDialectDescriptor + //---------------------------------------------------------------------------- + py::class_(m, "DialectDescriptor") + .def_property_readonly("namespace", + [](PyDialectDescriptor &self) { + MlirStringRef ns = + mlirDialectGetNamespace(self.get()); + return py::str(ns.data, ns.length); + }) + .def("__repr__", [](PyDialectDescriptor &self) { + MlirStringRef ns = mlirDialectGetNamespace(self.get()); + std::string repr(""); + return repr; + }); - std::string message = "unimplemented array format."; - throw SetPyError(PyExc_ValueError, message); - } + //---------------------------------------------------------------------------- + // Mapping of PyDialects + //---------------------------------------------------------------------------- + py::class_(m, "Dialects") + .def("__getitem__", + [=](PyDialects &self, std::string keyName) { + MlirDialect dialect = + self.getDialectForKey(keyName, /*attrError=*/false); + py::object descriptor = + py::cast(PyDialectDescriptor{self.getContext(), dialect}); + return createCustomDialectWrapper(keyName, std::move(descriptor)); + }) + .def("__getattr__", [=](PyDialects &self, std::string attrName) { + MlirDialect dialect = + self.getDialectForKey(attrName, /*attrError=*/true); + py::object descriptor = + py::cast(PyDialectDescriptor{self.getContext(), dialect}); + return createCustomDialectWrapper(attrName, std::move(descriptor)); + }); - static void bindDerived(ClassTy &c) { - c.def("__len__", &PyDenseElementsAttribute::dunderLen) - .def_static("get", PyDenseElementsAttribute::getFromBuffer, - py::arg("array"), py::arg("signless") = true, - py::arg("context") = py::none(), - "Gets from a buffer or ndarray") - .def_static("get_splat", 
PyDenseElementsAttribute::getSplat, - py::arg("shaped_type"), py::arg("element_attr"), - "Gets a DenseElementsAttr where all values are the same") - .def_property_readonly("is_splat", - [](PyDenseElementsAttribute &self) -> bool { - return mlirDenseElementsAttrIsSplat(self); - }) - .def_buffer(&PyDenseElementsAttribute::accessBuffer); - } + //---------------------------------------------------------------------------- + // Mapping of PyDialect + //---------------------------------------------------------------------------- + py::class_(m, "Dialect") + .def(py::init(), "descriptor") + .def_property_readonly( + "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) + .def("__repr__", [](py::object self) { + auto clazz = self.attr("__class__"); + return py::str(""); + }); -private: - template - static MlirAttribute - bulkLoad(MlirContext context, - MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *), - MlirType mlirElementType, py::buffer_info &arrayInfo) { - SmallVector shape(arrayInfo.shape.begin(), - arrayInfo.shape.begin() + arrayInfo.ndim); - auto shapedType = - mlirRankedTensorTypeGet(shape.size(), shape.data(), mlirElementType); - intptr_t numElements = arrayInfo.size; - const ElementTy *contents = static_cast(arrayInfo.ptr); - return ctor(shapedType, numElements, contents); - } - - static bool isUnsignedIntegerFormat(const std::string &format) { - if (format.empty()) - return false; - char code = format[0]; - return code == 'I' || code == 'B' || code == 'H' || code == 'L' || - code == 'Q'; - } - - static bool isSignedIntegerFormat(const std::string &format) { - if (format.empty()) - return false; - char code = format[0]; - return code == 'i' || code == 'b' || code == 'h' || code == 'l' || - code == 'q'; - } - - template - py::buffer_info bufferInfo(MlirType shapedType, - Type (*value)(MlirAttribute, intptr_t)) { - intptr_t rank = mlirShapedTypeGetRank(shapedType); - // Prepare the data for the buffer_info. 
- // Buffer is configured for read-only access below. - Type *data = static_cast( - const_cast(mlirDenseElementsAttrGetRawData(*this))); - // Prepare the shape for the buffer_info. - SmallVector shape; - for (intptr_t i = 0; i < rank; ++i) - shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); - // Prepare the strides for the buffer_info. - SmallVector strides; - intptr_t strideFactor = 1; - for (intptr_t i = 1; i < rank; ++i) { - strideFactor = 1; - for (intptr_t j = i; j < rank; ++j) { - strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); - } - strides.push_back(sizeof(Type) * strideFactor); - } - strides.push_back(sizeof(Type)); - return py::buffer_info(data, sizeof(Type), - py::format_descriptor::format(), rank, shape, - strides, /*readonly=*/true); - } -}; // namespace - -/// Refinement of the PyDenseElementsAttribute for attributes containing integer -/// (and boolean) values. Supports element access. -class PyDenseIntElementsAttribute - : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; - static constexpr const char *pyClassName = "DenseIntElementsAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - /// Returns the element at the given linear position. Asserts if the index is - /// out of range. - py::int_ dunderGetItem(intptr_t pos) { - if (pos < 0 || pos >= dunderLen()) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds element"); - } - - MlirType type = mlirAttributeGetType(*this); - type = mlirShapedTypeGetElementType(type); - assert(mlirTypeIsAInteger(type) && - "expected integer element type in dense int elements attribute"); - // Dispatch element extraction to an appropriate C function based on the - // elemental type of the attribute. py::int_ is implicitly constructible - // from any C++ integral type and handles bitwidth correctly. 
- // TODO: consider caching the type properties in the constructor to avoid - // querying them on each element access. - unsigned width = mlirIntegerTypeGetWidth(type); - bool isUnsigned = mlirIntegerTypeIsUnsigned(type); - if (isUnsigned) { - if (width == 1) { - return mlirDenseElementsAttrGetBoolValue(*this, pos); - } - if (width == 32) { - return mlirDenseElementsAttrGetUInt32Value(*this, pos); - } - if (width == 64) { - return mlirDenseElementsAttrGetUInt64Value(*this, pos); - } - } else { - if (width == 1) { - return mlirDenseElementsAttrGetBoolValue(*this, pos); - } - if (width == 32) { - return mlirDenseElementsAttrGetInt32Value(*this, pos); - } - if (width == 64) { - return mlirDenseElementsAttrGetInt64Value(*this, pos); - } - } - throw SetPyError(PyExc_TypeError, "Unsupported integer type"); - } - - static void bindDerived(ClassTy &c) { - c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); - } -}; - -class PyDictAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; - static constexpr const char *pyClassName = "DictAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } - - static void bindDerived(ClassTy &c) { - c.def("__len__", &PyDictAttribute::dunderLen); - c.def_static( - "get", - [](py::dict attributes, DefaultingPyMlirContext context) { - SmallVector mlirNamedAttributes; - mlirNamedAttributes.reserve(attributes.size()); - for (auto &it : attributes) { - auto &mlir_attr = it.second.cast(); - auto name = it.first.cast(); - mlirNamedAttributes.push_back(mlirNamedAttributeGet( - mlirIdentifierGet(mlirAttributeGetContext(mlir_attr), - toMlirStringRef(name)), - mlir_attr)); - } - MlirAttribute attr = - mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), - mlirNamedAttributes.data()); - return PyDictAttribute(context->getRef(), attr); - }, - py::arg("value"), py::arg("context") = 
py::none(), - "Gets an uniqued dict attribute"); - c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { - MlirAttribute attr = - mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); - if (mlirAttributeIsNull(attr)) { - throw SetPyError(PyExc_KeyError, - "attempt to access a non-existent attribute"); - } - return PyAttribute(self.getContext(), attr); - }); - c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { - if (index < 0 || index >= self.dunderLen()) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds attribute"); - } - MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); - return PyNamedAttribute( - namedAttr.attribute, - std::string(mlirIdentifierStr(namedAttr.name).data)); - }); - } -}; - -/// Refinement of PyDenseElementsAttribute for attributes containing -/// floating-point values. Supports element access. -class PyDenseFPElementsAttribute - : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; - static constexpr const char *pyClassName = "DenseFPElementsAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - py::float_ dunderGetItem(intptr_t pos) { - if (pos < 0 || pos >= dunderLen()) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds element"); - } - - MlirType type = mlirAttributeGetType(*this); - type = mlirShapedTypeGetElementType(type); - // Dispatch element extraction to an appropriate C function based on the - // elemental type of the attribute. py::float_ is implicitly constructible - // from float and double. - // TODO: consider caching the type properties in the constructor to avoid - // querying them on each element access. 
- if (mlirTypeIsAF32(type)) { - return mlirDenseElementsAttrGetFloatValue(*this, pos); - } - if (mlirTypeIsAF64(type)) { - return mlirDenseElementsAttrGetDoubleValue(*this, pos); - } - throw SetPyError(PyExc_TypeError, "Unsupported floating-point type"); - } - - static void bindDerived(ClassTy &c) { - c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); - } -}; - -class PyTypeAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; - static constexpr const char *pyClassName = "TypeAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType value, DefaultingPyMlirContext context) { - MlirAttribute attr = mlirTypeAttrGet(value.get()); - return PyTypeAttribute(context->getRef(), attr); - }, - py::arg("value"), py::arg("context") = py::none(), - "Gets a uniqued Type attribute"); - c.def_property_readonly("value", [](PyTypeAttribute &self) { - return PyType(self.getContext()->getRef(), - mlirTypeAttrGetValue(self.get())); - }); - } -}; - -/// Unit Attribute subclass. Unit attributes don't have values. -class PyUnitAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; - static constexpr const char *pyClassName = "UnitAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - return PyUnitAttribute(context->getRef(), - mlirUnitAttrGet(context->get())); - }, - py::arg("context") = py::none(), "Create a Unit attribute."); - } -}; - -} // namespace - -//------------------------------------------------------------------------------ -// Builtin type subclasses. -//------------------------------------------------------------------------------ - -namespace { - -/// CRTP base classes for Python types that subclass Type and should be -/// castable from it (i.e. 
via something like IntegerType(t)). -/// By default, type class hierarchies are one level deep (i.e. a -/// concrete type class extends PyType); however, intermediate python-visible -/// base classes can be modeled by specifying a BaseTy. -template -class PyConcreteType : public BaseTy { -public: - // Derived classes must define statics for: - // IsAFunctionTy isaFunction - // const char *pyClassName - using ClassTy = py::class_; - using IsAFunctionTy = bool (*)(MlirType); - - PyConcreteType() = default; - PyConcreteType(PyMlirContextRef contextRef, MlirType t) - : BaseTy(std::move(contextRef), t) {} - PyConcreteType(PyType &orig) - : PyConcreteType(orig.getContext(), castFrom(orig)) {} - - static MlirType castFrom(PyType &orig) { - if (!DerivedTy::isaFunction(orig)) { - auto origRepr = py::repr(py::cast(orig)).cast(); - throw SetPyError(PyExc_ValueError, Twine("Cannot cast type to ") + - DerivedTy::pyClassName + - " (from " + origRepr + ")"); - } - return orig; - } - - static void bind(py::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName); - cls.def(py::init(), py::keep_alive<0, 1>()); - cls.def_static("isinstance", [](PyType &otherType) -> bool { - return DerivedTy::isaFunction(otherType); - }); - DerivedTy::bindDerived(cls); - } - - /// Implemented by derived classes to add methods to the Python subclass. 
- static void bindDerived(ClassTy &m) {} -}; - -class PyIntegerType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger; - static constexpr const char *pyClassName = "IntegerType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get_signless", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - py::arg("width"), py::arg("context") = py::none(), - "Create a signless integer type"); - c.def_static( - "get_signed", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeSignedGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - py::arg("width"), py::arg("context") = py::none(), - "Create a signed integer type"); - c.def_static( - "get_unsigned", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - py::arg("width"), py::arg("context") = py::none(), - "Create an unsigned integer type"); - c.def_property_readonly( - "width", - [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); }, - "Returns the width of the integer type"); - c.def_property_readonly( - "is_signless", - [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsSignless(self); - }, - "Returns whether this is a signless integer"); - c.def_property_readonly( - "is_signed", - [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsSigned(self); - }, - "Returns whether this is a signed integer"); - c.def_property_readonly( - "is_unsigned", - [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsUnsigned(self); - }, - "Returns whether this is an unsigned integer"); - } -}; - -/// Index Type subclass - IndexType. 
-class PyIndexType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex; - static constexpr const char *pyClassName = "IndexType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirIndexTypeGet(context->get()); - return PyIndexType(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a index type."); - } -}; - -/// Floating Point Type subclass - BF16Type. -class PyBF16Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16; - static constexpr const char *pyClassName = "BF16Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirBF16TypeGet(context->get()); - return PyBF16Type(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a bf16 type."); - } -}; - -/// Floating Point Type subclass - F16Type. -class PyF16Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16; - static constexpr const char *pyClassName = "F16Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF16TypeGet(context->get()); - return PyF16Type(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a f16 type."); - } -}; - -/// Floating Point Type subclass - F32Type. 
-class PyF32Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32; - static constexpr const char *pyClassName = "F32Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF32TypeGet(context->get()); - return PyF32Type(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a f32 type."); - } -}; - -/// Floating Point Type subclass - F64Type. -class PyF64Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64; - static constexpr const char *pyClassName = "F64Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF64TypeGet(context->get()); - return PyF64Type(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a f64 type."); - } -}; - -/// None Type subclass - NoneType. -class PyNoneType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone; - static constexpr const char *pyClassName = "NoneType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirNoneTypeGet(context->get()); - return PyNoneType(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a none type."); - } -}; - -/// Complex Type subclass - ComplexType. -class PyComplexType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex; - static constexpr const char *pyClassName = "ComplexType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &elementType) { - // The element must be a floating point or integer scalar type. 
- if (mlirTypeIsAIntegerOrFloat(elementType)) { - MlirType t = mlirComplexTypeGet(elementType); - return PyComplexType(elementType.getContext(), t); - } - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point or integer type."); - }, - "Create a complex type"); - c.def_property_readonly( - "element_type", - [](PyComplexType &self) -> PyType { - MlirType t = mlirComplexTypeGetElementType(self); - return PyType(self.getContext(), t); - }, - "Returns element type."); - } -}; - -class PyShapedType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped; - static constexpr const char *pyClassName = "ShapedType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_property_readonly( - "element_type", - [](PyShapedType &self) { - MlirType t = mlirShapedTypeGetElementType(self); - return PyType(self.getContext(), t); - }, - "Returns the element type of the shaped type."); - c.def_property_readonly( - "has_rank", - [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); }, - "Returns whether the given shaped type is ranked."); - c.def_property_readonly( - "rank", - [](PyShapedType &self) { - self.requireHasRank(); - return mlirShapedTypeGetRank(self); - }, - "Returns the rank of the given ranked shaped type."); - c.def_property_readonly( - "has_static_shape", - [](PyShapedType &self) -> bool { - return mlirShapedTypeHasStaticShape(self); - }, - "Returns whether the given shaped type has a static shape."); - c.def( - "is_dynamic_dim", - [](PyShapedType &self, intptr_t dim) -> bool { - self.requireHasRank(); - return mlirShapedTypeIsDynamicDim(self, dim); - }, - "Returns whether the dim-th dimension of the given shaped type is " - "dynamic."); - c.def( - "get_dim_size", - [](PyShapedType &self, intptr_t dim) { - self.requireHasRank(); - return mlirShapedTypeGetDimSize(self, dim); - }, - "Returns the 
dim-th dimension of the given ranked shaped type."); - c.def_static( - "is_dynamic_size", - [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); }, - "Returns whether the given dimension size indicates a dynamic " - "dimension."); - c.def( - "is_dynamic_stride_or_offset", - [](PyShapedType &self, int64_t val) -> bool { - self.requireHasRank(); - return mlirShapedTypeIsDynamicStrideOrOffset(val); - }, - "Returns whether the given value is used as a placeholder for dynamic " - "strides and offsets in shaped types."); - } - -private: - void requireHasRank() { - if (!mlirShapedTypeHasRank(*this)) { - throw SetPyError( - PyExc_ValueError, - "calling this method requires that the type has a rank."); - } - } -}; - -/// Vector Type subclass - VectorType. -class PyVectorType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector; - static constexpr const char *pyClassName = "VectorType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector shape, PyType &elementType, - DefaultingPyLocation loc) { - MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), - elementType); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point or integer type."); - } - return PyVectorType(elementType.getContext(), t); - }, - py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(), - "Create a vector type"); - } -}; - -/// Ranked Tensor Type subclass - RankedTensorType. 
-class PyRankedTensorType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; - static constexpr const char *pyClassName = "RankedTensorType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector shape, PyType &elementType, - DefaultingPyLocation loc) { - MlirType t = mlirRankedTensorTypeGetChecked( - loc, shape.size(), shape.data(), elementType); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } - return PyRankedTensorType(elementType.getContext(), t); - }, - py::arg("shape"), py::arg("element_type"), py::arg("loc") = py::none(), - "Create a ranked tensor type"); - } -}; - -/// Unranked Tensor Type subclass - UnrankedTensorType. -class PyUnrankedTensorType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor; - static constexpr const char *pyClassName = "UnrankedTensorType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &elementType, DefaultingPyLocation loc) { - MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } - return PyUnrankedTensorType(elementType.getContext(), t); - }, - py::arg("element_type"), py::arg("loc") = py::none(), - "Create a unranked tensor type"); - } -}; - -class PyMemRefLayoutMapList; - -/// Ranked MemRef Type subclass - MemRefType. 
-class PyMemRefType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; - static constexpr const char *pyClassName = "MemRefType"; - using PyConcreteType::PyConcreteType; - - PyMemRefLayoutMapList getLayout(); - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector shape, PyType &elementType, - std::vector layout, PyAttribute *memorySpace, - DefaultingPyLocation loc) { - SmallVector maps; - maps.reserve(layout.size()); - for (PyAffineMap &map : layout) - maps.push_back(map); - - MlirAttribute memSpaceAttr = {}; - if (memorySpace) - memSpaceAttr = *memorySpace; - - MlirType t = mlirMemRefTypeGetChecked(loc, elementType, shape.size(), - shape.data(), maps.size(), - maps.data(), memSpaceAttr); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } - return PyMemRefType(elementType.getContext(), t); - }, - py::arg("shape"), py::arg("element_type"), - py::arg("layout") = py::list(), py::arg("memory_space") = py::none(), - py::arg("loc") = py::none(), "Create a memref type") - .def_property_readonly("layout", &PyMemRefType::getLayout, - "The list of layout maps of the MemRef type.") - .def_property_readonly( - "memory_space", - [](PyMemRefType &self) -> PyAttribute { - MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); - return PyAttribute(self.getContext(), a); - }, - "Returns the memory space of the given MemRef type."); - } -}; - -/// A list of affine layout maps in a memref type. Internally, these are stored -/// as consecutive elements, random access is cheap. Both the type and the maps -/// are owned by the context, no need to worry about lifetime extension. 
-class PyMemRefLayoutMapList - : public Sliceable { -public: - static constexpr const char *pyClassName = "MemRefLayoutMapList"; - - PyMemRefLayoutMapList(PyMemRefType type, intptr_t startIndex = 0, - intptr_t length = -1, intptr_t step = 1) - : Sliceable(startIndex, - length == -1 ? mlirMemRefTypeGetNumAffineMaps(type) : length, - step), - memref(type) {} - - intptr_t getNumElements() { return mlirMemRefTypeGetNumAffineMaps(memref); } - - PyAffineMap getElement(intptr_t index) { - return PyAffineMap(memref.getContext(), - mlirMemRefTypeGetAffineMap(memref, index)); - } - - PyMemRefLayoutMapList slice(intptr_t startIndex, intptr_t length, - intptr_t step) { - return PyMemRefLayoutMapList(memref, startIndex, length, step); - } - -private: - PyMemRefType memref; -}; - -PyMemRefLayoutMapList PyMemRefType::getLayout() { - return PyMemRefLayoutMapList(*this); -} - -/// Unranked MemRef Type subclass - UnrankedMemRefType. -class PyUnrankedMemRefType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef; - static constexpr const char *pyClassName = "UnrankedMemRefType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &elementType, PyAttribute *memorySpace, - DefaultingPyLocation loc) { - MlirAttribute memSpaceAttr = {}; - if (memorySpace) - memSpaceAttr = *memorySpace; - - MlirType t = - mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. 
- if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } - return PyUnrankedMemRefType(elementType.getContext(), t); - }, - py::arg("element_type"), py::arg("memory_space"), - py::arg("loc") = py::none(), "Create a unranked memref type") - .def_property_readonly( - "memory_space", - [](PyUnrankedMemRefType &self) -> PyAttribute { - MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); - return PyAttribute(self.getContext(), a); - }, - "Returns the memory space of the given Unranked MemRef type."); - } -}; - -/// Tuple Type subclass - TupleType. -class PyTupleType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple; - static constexpr const char *pyClassName = "TupleType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get_tuple", - [](py::list elementList, DefaultingPyMlirContext context) { - intptr_t num = py::len(elementList); - // Mapping py::list to SmallVector. - SmallVector elements; - for (auto element : elementList) - elements.push_back(element.cast()); - MlirType t = mlirTupleTypeGet(context->get(), num, elements.data()); - return PyTupleType(context->getRef(), t); - }, - py::arg("elements"), py::arg("context") = py::none(), - "Create a tuple type"); - c.def( - "get_type", - [](PyTupleType &self, intptr_t pos) -> PyType { - MlirType t = mlirTupleTypeGetType(self, pos); - return PyType(self.getContext(), t); - }, - "Returns the pos-th type in the tuple type."); - c.def_property_readonly( - "num_types", - [](PyTupleType &self) -> intptr_t { - return mlirTupleTypeGetNumTypes(self); - }, - "Returns the number of types contained in a tuple."); - } -}; - -/// Function type. 
-class PyFunctionType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction; - static constexpr const char *pyClassName = "FunctionType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector inputs, std::vector results, - DefaultingPyMlirContext context) { - SmallVector inputsRaw(inputs.begin(), inputs.end()); - SmallVector resultsRaw(results.begin(), results.end()); - MlirType t = mlirFunctionTypeGet(context->get(), inputsRaw.size(), - inputsRaw.data(), resultsRaw.size(), - resultsRaw.data()); - return PyFunctionType(context->getRef(), t); - }, - py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(), - "Gets a FunctionType from a list of input and result types"); - c.def_property_readonly( - "inputs", - [](PyFunctionType &self) { - MlirType t = self; - auto contextRef = self.getContext(); - py::list types; - for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; - ++i) { - types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i))); - } - return types; - }, - "Returns the list of input types in the FunctionType."); - c.def_property_readonly( - "results", - [](PyFunctionType &self) { - auto contextRef = self.getContext(); - py::list types; - for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; - ++i) { - types.append( - PyType(contextRef, mlirFunctionTypeGetResult(self, i))); - } - return types; - }, - "Returns the list of result types in the FunctionType."); - } -}; - -} // namespace - -//------------------------------------------------------------------------------ -// PyAffineExpr and subclasses. -//------------------------------------------------------------------------------ - -namespace { -/// CRTP base class for Python MLIR affine expressions that subclass AffineExpr -/// and should be castable from it. Intermediate hierarchy classes can be -/// modeled by specifying BaseTy. 
-template -class PyConcreteAffineExpr : public BaseTy { -public: - // Derived classes must define statics for: - // IsAFunctionTy isaFunction - // const char *pyClassName - // and redefine bindDerived. - using ClassTy = py::class_; - using IsAFunctionTy = bool (*)(MlirAffineExpr); - - PyConcreteAffineExpr() = default; - PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) - : BaseTy(std::move(contextRef), affineExpr) {} - PyConcreteAffineExpr(PyAffineExpr &orig) - : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {} - - static MlirAffineExpr castFrom(PyAffineExpr &orig) { - if (!DerivedTy::isaFunction(orig)) { - auto origRepr = py::repr(py::cast(orig)).cast(); - throw SetPyError(PyExc_ValueError, - Twine("Cannot cast affine expression to ") + - DerivedTy::pyClassName + " (from " + origRepr + ")"); - } - return orig; - } - - static void bind(py::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName); - cls.def(py::init()); - DerivedTy::bindDerived(cls); - } - - /// Implemented by derived classes to add methods to the Python subclass. 
- static void bindDerived(ClassTy &m) {} -}; - -class PyAffineConstantExpr : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant; - static constexpr const char *pyClassName = "AffineConstantExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - static PyAffineConstantExpr get(intptr_t value, - DefaultingPyMlirContext context) { - MlirAffineExpr affineExpr = - mlirAffineConstantExprGet(context->get(), static_cast(value)); - return PyAffineConstantExpr(context->getRef(), affineExpr); - } - - static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"), - py::arg("context") = py::none()); - c.def_property_readonly("value", [](PyAffineConstantExpr &self) { - return mlirAffineConstantExprGetValue(self); - }); - } -}; - -class PyAffineDimExpr : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim; - static constexpr const char *pyClassName = "AffineDimExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) { - MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos); - return PyAffineDimExpr(context->getRef(), affineExpr); - } - - static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineDimExpr::get, py::arg("position"), - py::arg("context") = py::none()); - c.def_property_readonly("position", [](PyAffineDimExpr &self) { - return mlirAffineDimExprGetPosition(self); - }); - } -}; - -class PyAffineSymbolExpr : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol; - static constexpr const char *pyClassName = "AffineSymbolExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) { - MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos); - return 
PyAffineSymbolExpr(context->getRef(), affineExpr); - } - - static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"), - py::arg("context") = py::none()); - c.def_property_readonly("position", [](PyAffineSymbolExpr &self) { - return mlirAffineSymbolExprGetPosition(self); - }); - } -}; - -class PyAffineBinaryExpr : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary; - static constexpr const char *pyClassName = "AffineBinaryExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - PyAffineExpr lhs() { - MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get()); - return PyAffineExpr(getContext(), lhsExpr); - } - - PyAffineExpr rhs() { - MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get()); - return PyAffineExpr(getContext(), rhsExpr); - } - - static void bindDerived(ClassTy &c) { - c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs); - c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs); - } -}; - -class PyAffineAddExpr - : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd; - static constexpr const char *pyClassName = "AffineAddExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { - MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs); - return PyAffineAddExpr(lhs.getContext(), expr); - } - - static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineAddExpr::get); - } -}; - -class PyAffineMulExpr - : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul; - static constexpr const char *pyClassName = "AffineMulExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { - MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs); - return PyAffineMulExpr(lhs.getContext(), expr); - } - - 
static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineMulExpr::get); - } -}; - -class PyAffineModExpr - : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod; - static constexpr const char *pyClassName = "AffineModExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { - MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs); - return PyAffineModExpr(lhs.getContext(), expr); - } - - static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineModExpr::get); - } -}; - -class PyAffineFloorDivExpr - : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv; - static constexpr const char *pyClassName = "AffineFloorDivExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { - MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs); - return PyAffineFloorDivExpr(lhs.getContext(), expr); - } - - static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineFloorDivExpr::get); - } -}; - -class PyAffineCeilDivExpr - : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv; - static constexpr const char *pyClassName = "AffineCeilDivExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { - MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs); - return PyAffineCeilDivExpr(lhs.getContext(), expr); - } - - static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineCeilDivExpr::get); - } -}; -} // namespace - -bool PyAffineExpr::operator==(const PyAffineExpr &other) { - return mlirAffineExprEqual(affineExpr, other.affineExpr); -} - -py::object PyAffineExpr::getCapsule() { - return py::reinterpret_steal( - mlirPythonAffineExprToCapsule(*this)); -} - 
-PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) { - MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr()); - if (mlirAffineExprIsNull(rawAffineExpr)) - throw py::error_already_set(); - return PyAffineExpr( - PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)), - rawAffineExpr); -} - -//------------------------------------------------------------------------------ -// PyAffineMap and utilities. -//------------------------------------------------------------------------------ - -namespace { -/// A list of expressions contained in an affine map. Internally these are -/// stored as a consecutive array leading to inexpensive random access. Both -/// the map and the expression are owned by the context so we need not bother -/// with lifetime extension. -class PyAffineMapExprList - : public Sliceable { -public: - static constexpr const char *pyClassName = "AffineExprList"; - - PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0, - intptr_t length = -1, intptr_t step = 1) - : Sliceable(startIndex, - length == -1 ? 
mlirAffineMapGetNumResults(map) : length, - step), - affineMap(map) {} - - intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); } - - PyAffineExpr getElement(intptr_t pos) { - return PyAffineExpr(affineMap.getContext(), - mlirAffineMapGetResult(affineMap, pos)); - } - - PyAffineMapExprList slice(intptr_t startIndex, intptr_t length, - intptr_t step) { - return PyAffineMapExprList(affineMap, startIndex, length, step); - } - -private: - PyAffineMap affineMap; -}; -} // end namespace - -bool PyAffineMap::operator==(const PyAffineMap &other) { - return mlirAffineMapEqual(affineMap, other.affineMap); -} - -py::object PyAffineMap::getCapsule() { - return py::reinterpret_steal(mlirPythonAffineMapToCapsule(*this)); -} - -PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) { - MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr()); - if (mlirAffineMapIsNull(rawAffineMap)) - throw py::error_already_set(); - return PyAffineMap( - PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)), - rawAffineMap); -} - -//------------------------------------------------------------------------------ -// PyIntegerSet and utilities. 
-//------------------------------------------------------------------------------ - -class PyIntegerSetConstraint { -public: - PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {} - - PyAffineExpr getExpr() { - return PyAffineExpr(set.getContext(), - mlirIntegerSetGetConstraint(set, pos)); - } - - bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); } - - static void bind(py::module &m) { - py::class_(m, "IntegerSetConstraint") - .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr) - .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq); - } - -private: - PyIntegerSet set; - intptr_t pos; -}; - -class PyIntegerSetConstraintList - : public Sliceable { -public: - static constexpr const char *pyClassName = "IntegerSetConstraintList"; - - PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0, - intptr_t length = -1, intptr_t step = 1) - : Sliceable(startIndex, - length == -1 ? mlirIntegerSetGetNumConstraints(set) : length, - step), - set(set) {} - - intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); } - - PyIntegerSetConstraint getElement(intptr_t pos) { - return PyIntegerSetConstraint(set, pos); - } - - PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length, - intptr_t step) { - return PyIntegerSetConstraintList(set, startIndex, length, step); - } - -private: - PyIntegerSet set; -}; - -bool PyIntegerSet::operator==(const PyIntegerSet &other) { - return mlirIntegerSetEqual(integerSet, other.integerSet); -} - -py::object PyIntegerSet::getCapsule() { - return py::reinterpret_steal( - mlirPythonIntegerSetToCapsule(*this)); -} - -PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) { - MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr()); - if (mlirIntegerSetIsNull(rawIntegerSet)) - throw py::error_already_set(); - return PyIntegerSet( - PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)), - rawIntegerSet); -} - -/// 
Attempts to populate `result` with the content of `list` casted to the -/// appropriate type (Python and C types are provided as template arguments). -/// Throws errors in case of failure, using "action" to describe what the caller -/// was attempting to do. -template -static void pyListToVector(py::list list, llvm::SmallVectorImpl &result, - StringRef action) { - result.reserve(py::len(list)); - for (py::handle item : list) { - try { - result.push_back(item.cast()); - } catch (py::cast_error &err) { - std::string msg = (llvm::Twine("Invalid expression when ") + action + - " (" + err.what() + ")") - .str(); - throw py::cast_error(msg); - } catch (py::reference_cast_error &err) { - std::string msg = (llvm::Twine("Invalid expression (None?) when ") + - action + " (" + err.what() + ")") - .str(); - throw py::cast_error(msg); - } - } -} - -//------------------------------------------------------------------------------ -// Populates the pybind11 IR submodule. -//------------------------------------------------------------------------------ - -void mlir::python::populateIRSubmodule(py::module &m) { - //---------------------------------------------------------------------------- - // Mapping of MlirContext - //---------------------------------------------------------------------------- - py::class_(m, "Context") - .def(py::init<>(&PyMlirContext::createNewContextForInit)) - .def_static("_get_live_count", &PyMlirContext::getLiveCount) - .def("_get_context_again", - [](PyMlirContext &self) { - PyMlirContextRef ref = PyMlirContext::forContext(self.get()); - return ref.releaseObject(); - }) - .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) - .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyMlirContext::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) - .def("__enter__", &PyMlirContext::contextEnter) - .def("__exit__", 
&PyMlirContext::contextExit) - .def_property_readonly_static( - "current", - [](py::object & /*class*/) { - auto *context = PyThreadContextEntry::getDefaultContext(); - if (!context) - throw SetPyError(PyExc_ValueError, "No current Context"); - return context; - }, - "Gets the Context bound to the current thread or raises ValueError") - .def_property_readonly( - "dialects", - [](PyMlirContext &self) { return PyDialects(self.getRef()); }, - "Gets a container for accessing dialects by name") - .def_property_readonly( - "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, - "Alias for 'dialect'") - .def( - "get_dialect_descriptor", - [=](PyMlirContext &self, std::string &name) { - MlirDialect dialect = mlirContextGetOrLoadDialect( - self.get(), {name.data(), name.size()}); - if (mlirDialectIsNull(dialect)) { - throw SetPyError(PyExc_ValueError, - Twine("Dialect '") + name + "' not found"); - } - return PyDialectDescriptor(self.getRef(), dialect); - }, - "Gets or loads a dialect by name, returning its descriptor object") - .def_property( - "allow_unregistered_dialects", - [](PyMlirContext &self) -> bool { - return mlirContextGetAllowUnregisteredDialects(self.get()); - }, - [](PyMlirContext &self, bool value) { - mlirContextSetAllowUnregisteredDialects(self.get(), value); - }); - - //---------------------------------------------------------------------------- - // Mapping of PyDialectDescriptor - //---------------------------------------------------------------------------- - py::class_(m, "DialectDescriptor") - .def_property_readonly("namespace", - [](PyDialectDescriptor &self) { - MlirStringRef ns = - mlirDialectGetNamespace(self.get()); - return py::str(ns.data, ns.length); - }) - .def("__repr__", [](PyDialectDescriptor &self) { - MlirStringRef ns = mlirDialectGetNamespace(self.get()); - std::string repr(""); - return repr; - }); - - //---------------------------------------------------------------------------- - // Mapping of PyDialects - 
//---------------------------------------------------------------------------- - py::class_(m, "Dialects") - .def("__getitem__", - [=](PyDialects &self, std::string keyName) { - MlirDialect dialect = - self.getDialectForKey(keyName, /*attrError=*/false); - py::object descriptor = - py::cast(PyDialectDescriptor{self.getContext(), dialect}); - return createCustomDialectWrapper(keyName, std::move(descriptor)); - }) - .def("__getattr__", [=](PyDialects &self, std::string attrName) { - MlirDialect dialect = - self.getDialectForKey(attrName, /*attrError=*/true); - py::object descriptor = - py::cast(PyDialectDescriptor{self.getContext(), dialect}); - return createCustomDialectWrapper(attrName, std::move(descriptor)); - }); - - //---------------------------------------------------------------------------- - // Mapping of PyDialect - //---------------------------------------------------------------------------- - py::class_(m, "Dialect") - .def(py::init(), "descriptor") - .def_property_readonly( - "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) - .def("__repr__", [](py::object self) { - auto clazz = self.attr("__class__"); - return py::str(""); - }); - - //---------------------------------------------------------------------------- - // Mapping of Location - //---------------------------------------------------------------------------- - py::class_(m, "Location") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) - .def("__enter__", &PyLocation::contextEnter) - .def("__exit__", &PyLocation::contextExit) - .def("__eq__", - [](PyLocation &self, PyLocation &other) -> bool { - return mlirLocationEqual(self, other); - }) - .def("__eq__", [](PyLocation &self, py::object other) { return false; }) - .def_property_readonly_static( - "current", - [](py::object & /*class*/) { - auto *loc = PyThreadContextEntry::getDefaultLocation(); - if (!loc) - throw 
SetPyError(PyExc_ValueError, "No current Location"); - return loc; - }, - "Gets the Location bound to the current thread or raises ValueError") - .def_static( - "unknown", - [](DefaultingPyMlirContext context) { - return PyLocation(context->getRef(), - mlirLocationUnknownGet(context->get())); - }, - py::arg("context") = py::none(), - "Gets a Location representing an unknown location") - .def_static( - "file", - [](std::string filename, int line, int col, - DefaultingPyMlirContext context) { - return PyLocation( - context->getRef(), - mlirLocationFileLineColGet( - context->get(), toMlirStringRef(filename), line, col)); - }, - py::arg("filename"), py::arg("line"), py::arg("col"), - py::arg("context") = py::none(), kContextGetFileLocationDocstring) - .def_property_readonly( - "context", - [](PyLocation &self) { return self.getContext().getObject(); }, - "Context that owns the Location") - .def("__repr__", [](PyLocation &self) { - PyPrintAccumulator printAccum; - mlirLocationPrint(self, printAccum.getCallback(), - printAccum.getUserData()); - return printAccum.join(); - }); + //---------------------------------------------------------------------------- + // Mapping of Location + //---------------------------------------------------------------------------- + py::class_(m, "Location") + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) + .def("__enter__", &PyLocation::contextEnter) + .def("__exit__", &PyLocation::contextExit) + .def("__eq__", + [](PyLocation &self, PyLocation &other) -> bool { + return mlirLocationEqual(self, other); + }) + .def("__eq__", [](PyLocation &self, py::object other) { return false; }) + .def_property_readonly_static( + "current", + [](py::object & /*class*/) { + auto *loc = PyThreadContextEntry::getDefaultLocation(); + if (!loc) + throw SetPyError(PyExc_ValueError, "No current Location"); + return loc; + }, + "Gets the Location bound to the current 
thread or raises ValueError") + .def_static( + "unknown", + [](DefaultingPyMlirContext context) { + return PyLocation(context->getRef(), + mlirLocationUnknownGet(context->get())); + }, + py::arg("context") = py::none(), + "Gets a Location representing an unknown location") + .def_static( + "file", + [](std::string filename, int line, int col, + DefaultingPyMlirContext context) { + return PyLocation( + context->getRef(), + mlirLocationFileLineColGet( + context->get(), toMlirStringRef(filename), line, col)); + }, + py::arg("filename"), py::arg("line"), py::arg("col"), + py::arg("context") = py::none(), kContextGetFileLocationDocstring) + .def_property_readonly( + "context", + [](PyLocation &self) { return self.getContext().getObject(); }, + "Context that owns the Location") + .def("__repr__", [](PyLocation &self) { + PyPrintAccumulator printAccum; + mlirLocationPrint(self, printAccum.getCallback(), + printAccum.getUserData()); + return printAccum.join(); + }); //---------------------------------------------------------------------------- // Mapping of Module @@ -4022,22 +2259,6 @@ py::keep_alive<0, 1>(), "The underlying generic attribute of the NamedAttribute binding"); - // Builtin attribute bindings. - PyAffineMapAttribute::bind(m); - PyArrayAttribute::bind(m); - PyArrayAttribute::PyArrayAttributeIterator::bind(m); - PyBoolAttribute::bind(m); - PyDenseElementsAttribute::bind(m); - PyDenseFPElementsAttribute::bind(m); - PyDenseIntElementsAttribute::bind(m); - PyDictAttribute::bind(m); - PyFlatSymbolRefAttribute::bind(m); - PyFloatAttribute::bind(m); - PyIntegerAttribute::bind(m); - PyStringAttribute::bind(m); - PyTypeAttribute::bind(m); - PyUnitAttribute::bind(m); - //---------------------------------------------------------------------------- // Mapping of PyType. //---------------------------------------------------------------------------- @@ -4088,25 +2309,6 @@ return printAccum.join(); }); - // Builtin type bindings. 
- PyIntegerType::bind(m); - PyIndexType::bind(m); - PyBF16Type::bind(m); - PyF16Type::bind(m); - PyF32Type::bind(m); - PyF64Type::bind(m); - PyNoneType::bind(m); - PyComplexType::bind(m); - PyShapedType::bind(m); - PyVectorType::bind(m); - PyRankedTensorType::bind(m); - PyUnrankedTensorType::bind(m); - PyMemRefType::bind(m); - PyMemRefLayoutMapList::bind(m); - PyUnrankedMemRefType::bind(m); - PyTupleType::bind(m); - PyFunctionType::bind(m); - //---------------------------------------------------------------------------- // Mapping of Value. //---------------------------------------------------------------------------- @@ -4152,359 +2354,4 @@ PyOpResultList::bind(m); PyRegionIterator::bind(m); PyRegionList::bind(m); - - //---------------------------------------------------------------------------- - // Mapping of PyAffineExpr and derived classes. - //---------------------------------------------------------------------------- - py::class_(m, "AffineExpr") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyAffineExpr::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule) - .def("__add__", - [](PyAffineExpr &self, PyAffineExpr &other) { - return PyAffineAddExpr::get(self, other); - }) - .def("__mul__", - [](PyAffineExpr &self, PyAffineExpr &other) { - return PyAffineMulExpr::get(self, other); - }) - .def("__mod__", - [](PyAffineExpr &self, PyAffineExpr &other) { - return PyAffineModExpr::get(self, other); - }) - .def("__sub__", - [](PyAffineExpr &self, PyAffineExpr &other) { - auto negOne = - PyAffineConstantExpr::get(-1, *self.getContext().get()); - return PyAffineAddExpr::get(self, - PyAffineMulExpr::get(negOne, other)); - }) - .def("__eq__", [](PyAffineExpr &self, - PyAffineExpr &other) { return self == other; }) - .def("__eq__", - [](PyAffineExpr &self, py::object &other) { return false; }) - .def("__str__", - [](PyAffineExpr &self) { - PyPrintAccumulator printAccum; - mlirAffineExprPrint(self, printAccum.getCallback(), - 
printAccum.getUserData()); - return printAccum.join(); - }) - .def("__repr__", - [](PyAffineExpr &self) { - PyPrintAccumulator printAccum; - printAccum.parts.append("AffineExpr("); - mlirAffineExprPrint(self, printAccum.getCallback(), - printAccum.getUserData()); - printAccum.parts.append(")"); - return printAccum.join(); - }) - .def_property_readonly( - "context", - [](PyAffineExpr &self) { return self.getContext().getObject(); }) - .def_static( - "get_add", &PyAffineAddExpr::get, - "Gets an affine expression containing a sum of two expressions.") - .def_static( - "get_mul", &PyAffineMulExpr::get, - "Gets an affine expression containing a product of two expressions.") - .def_static("get_mod", &PyAffineModExpr::get, - "Gets an affine expression containing the modulo of dividing " - "one expression by another.") - .def_static("get_floor_div", &PyAffineFloorDivExpr::get, - "Gets an affine expression containing the rounded-down " - "result of dividing one expression by another.") - .def_static("get_ceil_div", &PyAffineCeilDivExpr::get, - "Gets an affine expression containing the rounded-up result " - "of dividing one expression by another.") - .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"), - py::arg("context") = py::none(), - "Gets a constant affine expression with the given value.") - .def_static( - "get_dim", &PyAffineDimExpr::get, py::arg("position"), - py::arg("context") = py::none(), - "Gets an affine expression of a dimension at the given position.") - .def_static( - "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"), - py::arg("context") = py::none(), - "Gets an affine expression of a symbol at the given position.") - .def( - "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); }, - kDumpDocstring); - PyAffineConstantExpr::bind(m); - PyAffineDimExpr::bind(m); - PyAffineSymbolExpr::bind(m); - PyAffineBinaryExpr::bind(m); - PyAffineAddExpr::bind(m); - PyAffineMulExpr::bind(m); - PyAffineModExpr::bind(m); - 
PyAffineFloorDivExpr::bind(m); - PyAffineCeilDivExpr::bind(m); - - //---------------------------------------------------------------------------- - // Mapping of PyAffineMap. - //---------------------------------------------------------------------------- - py::class_(m, "AffineMap") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyAffineMap::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule) - .def("__eq__", - [](PyAffineMap &self, PyAffineMap &other) { return self == other; }) - .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; }) - .def("__str__", - [](PyAffineMap &self) { - PyPrintAccumulator printAccum; - mlirAffineMapPrint(self, printAccum.getCallback(), - printAccum.getUserData()); - return printAccum.join(); - }) - .def("__repr__", - [](PyAffineMap &self) { - PyPrintAccumulator printAccum; - printAccum.parts.append("AffineMap("); - mlirAffineMapPrint(self, printAccum.getCallback(), - printAccum.getUserData()); - printAccum.parts.append(")"); - return printAccum.join(); - }) - .def_property_readonly( - "context", - [](PyAffineMap &self) { return self.getContext().getObject(); }, - "Context that owns the Affine Map") - .def( - "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); }, - kDumpDocstring) - .def_static( - "get", - [](intptr_t dimCount, intptr_t symbolCount, py::list exprs, - DefaultingPyMlirContext context) { - SmallVector affineExprs; - pyListToVector( - exprs, affineExprs, "attempting to create an AffineMap"); - MlirAffineMap map = - mlirAffineMapGet(context->get(), dimCount, symbolCount, - affineExprs.size(), affineExprs.data()); - return PyAffineMap(context->getRef(), map); - }, - py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"), - py::arg("context") = py::none(), - "Gets a map with the given expressions as results.") - .def_static( - "get_constant", - [](intptr_t value, DefaultingPyMlirContext context) { - MlirAffineMap affineMap = - 
mlirAffineMapConstantGet(context->get(), value); - return PyAffineMap(context->getRef(), affineMap); - }, - py::arg("value"), py::arg("context") = py::none(), - "Gets an affine map with a single constant result") - .def_static( - "get_empty", - [](DefaultingPyMlirContext context) { - MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get()); - return PyAffineMap(context->getRef(), affineMap); - }, - py::arg("context") = py::none(), "Gets an empty affine map.") - .def_static( - "get_identity", - [](intptr_t nDims, DefaultingPyMlirContext context) { - MlirAffineMap affineMap = - mlirAffineMapMultiDimIdentityGet(context->get(), nDims); - return PyAffineMap(context->getRef(), affineMap); - }, - py::arg("n_dims"), py::arg("context") = py::none(), - "Gets an identity map with the given number of dimensions.") - .def_static( - "get_minor_identity", - [](intptr_t nDims, intptr_t nResults, - DefaultingPyMlirContext context) { - MlirAffineMap affineMap = - mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults); - return PyAffineMap(context->getRef(), affineMap); - }, - py::arg("n_dims"), py::arg("n_results"), - py::arg("context") = py::none(), - "Gets a minor identity map with the given number of dimensions and " - "results.") - .def_static( - "get_permutation", - [](std::vector permutation, - DefaultingPyMlirContext context) { - if (!isPermutation(permutation)) - throw py::cast_error("Invalid permutation when attempting to " - "create an AffineMap"); - MlirAffineMap affineMap = mlirAffineMapPermutationGet( - context->get(), permutation.size(), permutation.data()); - return PyAffineMap(context->getRef(), affineMap); - }, - py::arg("permutation"), py::arg("context") = py::none(), - "Gets an affine map that permutes its inputs.") - .def("get_submap", - [](PyAffineMap &self, std::vector &resultPos) { - intptr_t numResults = mlirAffineMapGetNumResults(self); - for (intptr_t pos : resultPos) { - if (pos < 0 || pos >= numResults) - throw py::value_error("result 
position out of bounds"); - } - MlirAffineMap affineMap = mlirAffineMapGetSubMap( - self, resultPos.size(), resultPos.data()); - return PyAffineMap(self.getContext(), affineMap); - }) - .def("get_major_submap", - [](PyAffineMap &self, intptr_t nResults) { - if (nResults >= mlirAffineMapGetNumResults(self)) - throw py::value_error("number of results out of bounds"); - MlirAffineMap affineMap = - mlirAffineMapGetMajorSubMap(self, nResults); - return PyAffineMap(self.getContext(), affineMap); - }) - .def("get_minor_submap", - [](PyAffineMap &self, intptr_t nResults) { - if (nResults >= mlirAffineMapGetNumResults(self)) - throw py::value_error("number of results out of bounds"); - MlirAffineMap affineMap = - mlirAffineMapGetMinorSubMap(self, nResults); - return PyAffineMap(self.getContext(), affineMap); - }) - .def_property_readonly( - "is_permutation", - [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); }) - .def_property_readonly("is_projected_permutation", - [](PyAffineMap &self) { - return mlirAffineMapIsProjectedPermutation(self); - }) - .def_property_readonly( - "n_dims", - [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); }) - .def_property_readonly( - "n_inputs", - [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); }) - .def_property_readonly( - "n_symbols", - [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); }) - .def_property_readonly("results", [](PyAffineMap &self) { - return PyAffineMapExprList(self); - }); - PyAffineMapExprList::bind(m); - - //---------------------------------------------------------------------------- - // Mapping of PyIntegerSet. 
- //---------------------------------------------------------------------------- - py::class_(m, "IntegerSet") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyIntegerSet::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) - .def("__eq__", [](PyIntegerSet &self, - PyIntegerSet &other) { return self == other; }) - .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; }) - .def("__str__", - [](PyIntegerSet &self) { - PyPrintAccumulator printAccum; - mlirIntegerSetPrint(self, printAccum.getCallback(), - printAccum.getUserData()); - return printAccum.join(); - }) - .def("__repr__", - [](PyIntegerSet &self) { - PyPrintAccumulator printAccum; - printAccum.parts.append("IntegerSet("); - mlirIntegerSetPrint(self, printAccum.getCallback(), - printAccum.getUserData()); - printAccum.parts.append(")"); - return printAccum.join(); - }) - .def_property_readonly( - "context", - [](PyIntegerSet &self) { return self.getContext().getObject(); }) - .def( - "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); }, - kDumpDocstring) - .def_static( - "get", - [](intptr_t numDims, intptr_t numSymbols, py::list exprs, - std::vector eqFlags, DefaultingPyMlirContext context) { - if (exprs.size() != eqFlags.size()) - throw py::value_error( - "Expected the number of constraints to match " - "that of equality flags"); - if (exprs.empty()) - throw py::value_error("Expected non-empty list of constraints"); - - // Copy over to a SmallVector because std::vector has a - // specialization for booleans that packs data and does not - // expose a `bool *`. 
- SmallVector flags(eqFlags.begin(), eqFlags.end()); - - SmallVector affineExprs; - pyListToVector(exprs, affineExprs, - "attempting to create an IntegerSet"); - MlirIntegerSet set = mlirIntegerSetGet( - context->get(), numDims, numSymbols, exprs.size(), - affineExprs.data(), flags.data()); - return PyIntegerSet(context->getRef(), set); - }, - py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"), - py::arg("eq_flags"), py::arg("context") = py::none()) - .def_static( - "get_empty", - [](intptr_t numDims, intptr_t numSymbols, - DefaultingPyMlirContext context) { - MlirIntegerSet set = - mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols); - return PyIntegerSet(context->getRef(), set); - }, - py::arg("num_dims"), py::arg("num_symbols"), - py::arg("context") = py::none()) - .def("get_replaced", - [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs, - intptr_t numResultDims, intptr_t numResultSymbols) { - if (static_cast(dimExprs.size()) != - mlirIntegerSetGetNumDims(self)) - throw py::value_error( - "Expected the number of dimension replacement expressions " - "to match that of dimensions"); - if (static_cast(symbolExprs.size()) != - mlirIntegerSetGetNumSymbols(self)) - throw py::value_error( - "Expected the number of symbol replacement expressions " - "to match that of symbols"); - - SmallVector dimAffineExprs, symbolAffineExprs; - pyListToVector( - dimExprs, dimAffineExprs, - "attempting to create an IntegerSet by replacing dimensions"); - pyListToVector( - symbolExprs, symbolAffineExprs, - "attempting to create an IntegerSet by replacing symbols"); - MlirIntegerSet set = mlirIntegerSetReplaceGet( - self, dimAffineExprs.data(), symbolAffineExprs.data(), - numResultDims, numResultSymbols); - return PyIntegerSet(self.getContext(), set); - }) - .def_property_readonly("is_canonical_empty", - [](PyIntegerSet &self) { - return mlirIntegerSetIsCanonicalEmpty(self); - }) - .def_property_readonly( - "n_dims", - [](PyIntegerSet &self) { return 
mlirIntegerSetGetNumDims(self); }) - .def_property_readonly( - "n_symbols", - [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); }) - .def_property_readonly( - "n_inputs", - [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); }) - .def_property_readonly("n_equalities", - [](PyIntegerSet &self) { - return mlirIntegerSetGetNumEqualities(self); - }) - .def_property_readonly("n_inequalities", - [](PyIntegerSet &self) { - return mlirIntegerSetGetNumInequalities(self); - }) - .def_property_readonly("constraints", [](PyIntegerSet &self) { - return PyIntegerSetConstraintList(self); - }); - PyIntegerSetConstraint::bind(m); - PyIntegerSetConstraintList::bind(m); } diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModule.h rename from mlir/lib/Bindings/Python/IRModules.h rename to mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModules.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -747,7 +747,10 @@ MlirIntegerSet integerSet; }; -void populateIRSubmodule(pybind11::module &m); +void populateIRAffine(pybind11::module &m); +void populateIRAttributes(pybind11::module &m); +void populateIRCore(pybind11::module &m); +void populateIRTypes(pybind11::module &m); } // namespace python } // namespace mlir diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -0,0 +1,678 @@ +//===- IRTypes.cpp - Exports builtin and standard types -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "IRModule.h" + +#include "PybindUtils.h" + +#include "mlir-c/BuiltinTypes.h" + +namespace py = pybind11; +using namespace mlir; +using namespace mlir::python; + +using llvm::SmallVector; +using llvm::Twine; + +namespace { + +/// Checks whether the given type is an integer or float type. +static int mlirTypeIsAIntegerOrFloat(MlirType type) { + return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) || + mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type); +} + +/// CRTP base classes for Python types that subclass Type and should be +/// castable from it (i.e. via something like IntegerType(t)). +/// By default, type class hierarchies are one level deep (i.e. a +/// concrete type class extends PyType); however, intermediate python-visible +/// base classes can be modeled by specifying a BaseTy. +template +class PyConcreteType : public BaseTy { +public: + // Derived classes must define statics for: + // IsAFunctionTy isaFunction + // const char *pyClassName + using ClassTy = py::class_; + using IsAFunctionTy = bool (*)(MlirType); + + PyConcreteType() = default; + PyConcreteType(PyMlirContextRef contextRef, MlirType t) + : BaseTy(std::move(contextRef), t) {} + PyConcreteType(PyType &orig) + : PyConcreteType(orig.getContext(), castFrom(orig)) {} + + static MlirType castFrom(PyType &orig) { + if (!DerivedTy::isaFunction(orig)) { + auto origRepr = py::repr(py::cast(orig)).cast(); + throw SetPyError(PyExc_ValueError, Twine("Cannot cast type to ") + + DerivedTy::pyClassName + + " (from " + origRepr + ")"); + } + return orig; + } + + static void bind(py::module &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName); + cls.def(py::init(), py::keep_alive<0, 1>()); + cls.def_static("isinstance", [](PyType &otherType) -> bool { + return DerivedTy::isaFunction(otherType); + }); + DerivedTy::bindDerived(cls); 
+ } + + /// Implemented by derived classes to add methods to the Python subclass. + static void bindDerived(ClassTy &m) {} +}; + +class PyIntegerType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger; + static constexpr const char *pyClassName = "IntegerType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get_signless", + [](unsigned width, DefaultingPyMlirContext context) { + MlirType t = mlirIntegerTypeGet(context->get(), width); + return PyIntegerType(context->getRef(), t); + }, + py::arg("width"), py::arg("context") = py::none(), + "Create a signless integer type"); + c.def_static( + "get_signed", + [](unsigned width, DefaultingPyMlirContext context) { + MlirType t = mlirIntegerTypeSignedGet(context->get(), width); + return PyIntegerType(context->getRef(), t); + }, + py::arg("width"), py::arg("context") = py::none(), + "Create a signed integer type"); + c.def_static( + "get_unsigned", + [](unsigned width, DefaultingPyMlirContext context) { + MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width); + return PyIntegerType(context->getRef(), t); + }, + py::arg("width"), py::arg("context") = py::none(), + "Create an unsigned integer type"); + c.def_property_readonly( + "width", + [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); }, + "Returns the width of the integer type"); + c.def_property_readonly( + "is_signless", + [](PyIntegerType &self) -> bool { + return mlirIntegerTypeIsSignless(self); + }, + "Returns whether this is a signless integer"); + c.def_property_readonly( + "is_signed", + [](PyIntegerType &self) -> bool { + return mlirIntegerTypeIsSigned(self); + }, + "Returns whether this is a signed integer"); + c.def_property_readonly( + "is_unsigned", + [](PyIntegerType &self) -> bool { + return mlirIntegerTypeIsUnsigned(self); + }, + "Returns whether this is an unsigned integer"); + } +}; + +/// Index Type subclass - IndexType. 
+class PyIndexType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex; + static constexpr const char *pyClassName = "IndexType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirIndexTypeGet(context->get()); + return PyIndexType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a index type."); + } +}; + +/// Floating Point Type subclass - BF16Type. +class PyBF16Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16; + static constexpr const char *pyClassName = "BF16Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirBF16TypeGet(context->get()); + return PyBF16Type(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a bf16 type."); + } +}; + +/// Floating Point Type subclass - F16Type. +class PyF16Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16; + static constexpr const char *pyClassName = "F16Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirF16TypeGet(context->get()); + return PyF16Type(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a f16 type."); + } +}; + +/// Floating Point Type subclass - F32Type. 
+class PyF32Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32; + static constexpr const char *pyClassName = "F32Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirF32TypeGet(context->get()); + return PyF32Type(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a f32 type."); + } +}; + +/// Floating Point Type subclass - F64Type. +class PyF64Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64; + static constexpr const char *pyClassName = "F64Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirF64TypeGet(context->get()); + return PyF64Type(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a f64 type."); + } +}; + +/// None Type subclass - NoneType. +class PyNoneType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone; + static constexpr const char *pyClassName = "NoneType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirNoneTypeGet(context->get()); + return PyNoneType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a none type."); + } +}; + +/// Complex Type subclass - ComplexType. +class PyComplexType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex; + static constexpr const char *pyClassName = "ComplexType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &elementType) { + // The element must be a floating point or integer scalar type. 
+ if (mlirTypeIsAIntegerOrFloat(elementType)) { + MlirType t = mlirComplexTypeGet(elementType); + return PyComplexType(elementType.getContext(), t); + } + throw SetPyError( + PyExc_ValueError, + Twine("invalid '") + + py::repr(py::cast(elementType)).cast() + + "' and expected floating point or integer type."); + }, + "Create a complex type"); + c.def_property_readonly( + "element_type", + [](PyComplexType &self) -> PyType { + MlirType t = mlirComplexTypeGetElementType(self); + return PyType(self.getContext(), t); + }, + "Returns element type."); + } +}; + +class PyShapedType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped; + static constexpr const char *pyClassName = "ShapedType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_property_readonly( + "element_type", + [](PyShapedType &self) { + MlirType t = mlirShapedTypeGetElementType(self); + return PyType(self.getContext(), t); + }, + "Returns the element type of the shaped type."); + c.def_property_readonly( + "has_rank", + [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); }, + "Returns whether the given shaped type is ranked."); + c.def_property_readonly( + "rank", + [](PyShapedType &self) { + self.requireHasRank(); + return mlirShapedTypeGetRank(self); + }, + "Returns the rank of the given ranked shaped type."); + c.def_property_readonly( + "has_static_shape", + [](PyShapedType &self) -> bool { + return mlirShapedTypeHasStaticShape(self); + }, + "Returns whether the given shaped type has a static shape."); + c.def( + "is_dynamic_dim", + [](PyShapedType &self, intptr_t dim) -> bool { + self.requireHasRank(); + return mlirShapedTypeIsDynamicDim(self, dim); + }, + "Returns whether the dim-th dimension of the given shaped type is " + "dynamic."); + c.def( + "get_dim_size", + [](PyShapedType &self, intptr_t dim) { + self.requireHasRank(); + return mlirShapedTypeGetDimSize(self, dim); + }, + "Returns the 
dim-th dimension of the given ranked shaped type."); + c.def_static( + "is_dynamic_size", + [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); }, + "Returns whether the given dimension size indicates a dynamic " + "dimension."); + c.def( + "is_dynamic_stride_or_offset", + [](PyShapedType &self, int64_t val) -> bool { + self.requireHasRank(); + return mlirShapedTypeIsDynamicStrideOrOffset(val); + }, + "Returns whether the given value is used as a placeholder for dynamic " + "strides and offsets in shaped types."); + } + +private: + void requireHasRank() { + if (!mlirShapedTypeHasRank(*this)) { + throw SetPyError( + PyExc_ValueError, + "calling this method requires that the type has a rank."); + } + } +}; + +/// Vector Type subclass - VectorType. +class PyVectorType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector; + static constexpr const char *pyClassName = "VectorType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector shape, PyType &elementType, + DefaultingPyLocation loc) { + MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), + elementType); + // TODO: Rework error reporting once diagnostic engine is exposed + // in C API. + if (mlirTypeIsNull(t)) { + throw SetPyError( + PyExc_ValueError, + Twine("invalid '") + + py::repr(py::cast(elementType)).cast() + + "' and expected floating point or integer type."); + } + return PyVectorType(elementType.getContext(), t); + }, + py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(), + "Create a vector type"); + } +}; + +/// Ranked Tensor Type subclass - RankedTensorType. 
+class PyRankedTensorType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; + static constexpr const char *pyClassName = "RankedTensorType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector shape, PyType &elementType, + DefaultingPyLocation loc) { + MlirType t = mlirRankedTensorTypeGetChecked( + loc, shape.size(), shape.data(), elementType); + // TODO: Rework error reporting once diagnostic engine is exposed + // in C API. + if (mlirTypeIsNull(t)) { + throw SetPyError( + PyExc_ValueError, + Twine("invalid '") + + py::repr(py::cast(elementType)).cast() + + "' and expected floating point, integer, vector or " + "complex " + "type."); + } + return PyRankedTensorType(elementType.getContext(), t); + }, + py::arg("shape"), py::arg("element_type"), py::arg("loc") = py::none(), + "Create a ranked tensor type"); + } +}; + +/// Unranked Tensor Type subclass - UnrankedTensorType. +class PyUnrankedTensorType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor; + static constexpr const char *pyClassName = "UnrankedTensorType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &elementType, DefaultingPyLocation loc) { + MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType); + // TODO: Rework error reporting once diagnostic engine is exposed + // in C API. + if (mlirTypeIsNull(t)) { + throw SetPyError( + PyExc_ValueError, + Twine("invalid '") + + py::repr(py::cast(elementType)).cast() + + "' and expected floating point, integer, vector or " + "complex " + "type."); + } + return PyUnrankedTensorType(elementType.getContext(), t); + }, + py::arg("element_type"), py::arg("loc") = py::none(), + "Create a unranked tensor type"); + } +}; + +class PyMemRefLayoutMapList; + +/// Ranked MemRef Type subclass - MemRefType. 
+class PyMemRefType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef; // Must accept (ranked) memref types: castFrom() and isinstance() dispatch on this, and the memref accessors below assume it. + static constexpr const char *pyClassName = "MemRefType"; + using PyConcreteType::PyConcreteType; + + PyMemRefLayoutMapList getLayout(); + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector shape, PyType &elementType, + std::vector layout, PyAttribute *memorySpace, + DefaultingPyLocation loc) { + SmallVector maps; + maps.reserve(layout.size()); + for (PyAffineMap &map : layout) + maps.push_back(map); + + MlirAttribute memSpaceAttr = {}; + if (memorySpace) + memSpaceAttr = *memorySpace; + + MlirType t = mlirMemRefTypeGetChecked(loc, elementType, shape.size(), + shape.data(), maps.size(), + maps.data(), memSpaceAttr); + // TODO: Rework error reporting once diagnostic engine is exposed + // in C API. + if (mlirTypeIsNull(t)) { + throw SetPyError( + PyExc_ValueError, + Twine("invalid '") + + py::repr(py::cast(elementType)).cast() + + "' and expected floating point, integer, vector or " + "complex " + "type."); + } + return PyMemRefType(elementType.getContext(), t); + }, + py::arg("shape"), py::arg("element_type"), + py::arg("layout") = py::list(), py::arg("memory_space") = py::none(), + py::arg("loc") = py::none(), "Create a memref type") + .def_property_readonly("layout", &PyMemRefType::getLayout, + "The list of layout maps of the MemRef type.") + .def_property_readonly( + "memory_space", + [](PyMemRefType &self) -> PyAttribute { + MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); + return PyAttribute(self.getContext(), a); + }, + "Returns the memory space of the given MemRef type."); + } +}; + +/// A list of affine layout maps in a memref type. Internally, these are stored +/// as consecutive elements, random access is cheap. 
Both the type and the maps +/// are owned by the context, no need to worry about lifetime extension.
+class PyMemRefLayoutMapList + : public Sliceable { +public: + static constexpr const char *pyClassName = "MemRefLayoutMapList"; + + PyMemRefLayoutMapList(PyMemRefType type, intptr_t startIndex = 0, + intptr_t length = -1, intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? mlirMemRefTypeGetNumAffineMaps(type) : length, + step), + memref(type) {} + + intptr_t getNumElements() { return mlirMemRefTypeGetNumAffineMaps(memref); } + + PyAffineMap getElement(intptr_t index) { + return PyAffineMap(memref.getContext(), + mlirMemRefTypeGetAffineMap(memref, index)); + } + + PyMemRefLayoutMapList slice(intptr_t startIndex, intptr_t length, + intptr_t step) { + return PyMemRefLayoutMapList(memref, startIndex, length, step); + } + +private: + PyMemRefType memref; +}; + +PyMemRefLayoutMapList PyMemRefType::getLayout() { + return PyMemRefLayoutMapList(*this); +} + +/// Unranked MemRef Type subclass - UnrankedMemRefType. +class PyUnrankedMemRefType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef; + static constexpr const char *pyClassName = "UnrankedMemRefType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &elementType, PyAttribute *memorySpace, + DefaultingPyLocation loc) { + MlirAttribute memSpaceAttr = {}; + if (memorySpace) + memSpaceAttr = *memorySpace; + + MlirType t = + mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr); + // TODO: Rework error reporting once diagnostic engine is exposed + // in C API. 
+ if (mlirTypeIsNull(t)) { + throw SetPyError( + PyExc_ValueError, + Twine("invalid '") + + py::repr(py::cast(elementType)).cast() + + "' and expected floating point, integer, vector or " + "complex " + "type."); + } + return PyUnrankedMemRefType(elementType.getContext(), t); + }, + py::arg("element_type"), py::arg("memory_space"), + py::arg("loc") = py::none(), "Create a unranked memref type") + .def_property_readonly( + "memory_space", + [](PyUnrankedMemRefType &self) -> PyAttribute { + MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); + return PyAttribute(self.getContext(), a); + }, + "Returns the memory space of the given Unranked MemRef type."); + } +}; + +/// Tuple Type subclass - TupleType. +class PyTupleType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple; + static constexpr const char *pyClassName = "TupleType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get_tuple", + [](py::list elementList, DefaultingPyMlirContext context) { + intptr_t num = py::len(elementList); + // Mapping py::list to SmallVector. + SmallVector elements; + for (auto element : elementList) + elements.push_back(element.cast()); + MlirType t = mlirTupleTypeGet(context->get(), num, elements.data()); + return PyTupleType(context->getRef(), t); + }, + py::arg("elements"), py::arg("context") = py::none(), + "Create a tuple type"); + c.def( + "get_type", + [](PyTupleType &self, intptr_t pos) -> PyType { + MlirType t = mlirTupleTypeGetType(self, pos); + return PyType(self.getContext(), t); + }, + "Returns the pos-th type in the tuple type."); + c.def_property_readonly( + "num_types", + [](PyTupleType &self) -> intptr_t { + return mlirTupleTypeGetNumTypes(self); + }, + "Returns the number of types contained in a tuple."); + } +}; + +/// Function type. 
+class PyFunctionType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction; + static constexpr const char *pyClassName = "FunctionType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector inputs, std::vector results, + DefaultingPyMlirContext context) { + SmallVector inputsRaw(inputs.begin(), inputs.end()); + SmallVector resultsRaw(results.begin(), results.end()); + MlirType t = mlirFunctionTypeGet(context->get(), inputsRaw.size(), + inputsRaw.data(), resultsRaw.size(), + resultsRaw.data()); + return PyFunctionType(context->getRef(), t); + }, + py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(), + "Gets a FunctionType from a list of input and result types"); + c.def_property_readonly( + "inputs", + [](PyFunctionType &self) { + MlirType t = self; + auto contextRef = self.getContext(); + py::list types; + for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; + ++i) { + types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i))); + } + return types; + }, + "Returns the list of input types in the FunctionType."); + c.def_property_readonly( + "results", + [](PyFunctionType &self) { + auto contextRef = self.getContext(); + py::list types; + for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; + ++i) { + types.append( + PyType(contextRef, mlirFunctionTypeGetResult(self, i))); + } + return types; + }, + "Returns the list of result types in the FunctionType."); + } +}; + +} // namespace + +void mlir::python::populateIRTypes(py::module &m) { + PyIntegerType::bind(m); + PyIndexType::bind(m); + PyBF16Type::bind(m); + PyF16Type::bind(m); + PyF32Type::bind(m); + PyF64Type::bind(m); + PyNoneType::bind(m); + PyComplexType::bind(m); + PyShapedType::bind(m); + PyVectorType::bind(m); + PyRankedTensorType::bind(m); + PyUnrankedTensorType::bind(m); + PyMemRefType::bind(m); + PyMemRefLayoutMapList::bind(m); + 
PyUnrankedMemRefType::bind(m); + PyTupleType::bind(m); + PyFunctionType::bind(m); +} diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -12,7 +12,7 @@ #include "ExecutionEngine.h" #include "Globals.h" -#include "IRModules.h" +#include "IRModule.h" #include "Pass.h" namespace py = pybind11; @@ -211,7 +211,10 @@ // Define and populate IR submodule. auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); - populateIRSubmodule(irModule); + populateIRCore(irModule); + populateIRAffine(irModule); + populateIRAttributes(irModule); + populateIRTypes(irModule); // Define and populate PassManager submodule. auto passModule = diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -8,7 +8,7 @@ #include "Pass.h" -#include "IRModules.h" +#include "IRModule.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/Pass.h"