diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h --- a/mlir/include/mlir-c/Bindings/Python/Interop.h +++ b/mlir/include/mlir-c/Bindings/Python/Interop.h @@ -28,6 +28,7 @@ #define MLIR_PYTHON_CAPSULE_ATTRIBUTE "mlir.ir.Attribute._CAPIPtr" #define MLIR_PYTHON_CAPSULE_CONTEXT "mlir.ir.Context._CAPIPtr" +#define MLIR_PYTHON_CAPSULE_LOCATION "mlir.ir.Location._CAPIPtr" #define MLIR_PYTHON_CAPSULE_MODULE "mlir.ir.Module._CAPIPtr" #define MLIR_PYTHON_CAPSULE_OPERATION "mlir.ir.Operation._CAPIPtr" #define MLIR_PYTHON_CAPSULE_TYPE "mlir.ir.Type._CAPIPtr" @@ -106,6 +107,24 @@ return context; } +/** Creates a capsule object encapsulating the raw C-API MlirLocation. + * The returned capsule does not extend or affect ownership of any Python + * objects that reference the location in any way. */ +static inline PyObject *mlirPythonLocationToCapsule(MlirLocation loc) { + return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(loc), + MLIR_PYTHON_CAPSULE_LOCATION, NULL); +} + +/** Extracts an MlirLocation from a capsule as produced from + * mlirPythonLocationToCapsule. If the capsule is not of the right type, then + * a null location is returned (as checked via mlirLocationIsNull). In such a + * case, the Python APIs will have already set an error. */ +static inline MlirLocation mlirPythonCapsuleToLocation(PyObject *capsule) { + void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_LOCATION); + MlirLocation loc = {ptr}; + return loc; +} + /** Creates a capsule object encapsulating the raw C-API MlirModule. * The returned capsule does not extend or affect ownership of any Python * objects that reference the module in any way. */ diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -153,6 +153,14 @@ /// Gets the context that a location was created with.
MLIR_CAPI_EXPORTED MlirContext mlirLocationGetContext(MlirLocation location); +/// Checks if the location is null. +static inline int mlirLocationIsNull(MlirLocation location) { + return !location.ptr; +} + +/// Checks if two locations are equal. +MLIR_CAPI_EXPORTED int mlirLocationEqual(MlirLocation l1, MlirLocation l2); + /** Prints a location by sending chunks of the string representation and * forwarding `userData to `callback`. Note that the callback may be called * several times with consecutive chunks of the string. */ diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h --- a/mlir/lib/Bindings/Python/IRModules.h +++ b/mlir/lib/Bindings/Python/IRModules.h @@ -307,11 +307,24 @@ PyLocation(PyMlirContextRef contextRef, MlirLocation loc) : BaseContextObject(std::move(contextRef)), loc(loc) {} + operator MlirLocation() const { return loc; } + MlirLocation get() const { return loc; } + /// Enter and exit the context manager. pybind11::object contextEnter(); void contextExit(pybind11::object excType, pybind11::object excVal, pybind11::object excTb); + /// Gets a capsule wrapping the void* within the MlirLocation. + pybind11::object getCapsule(); + + /// Creates a PyLocation from the MlirLocation wrapped by a capsule. + /// The owning context is resolved via PyMlirContext::forContext, so the + /// result may reference a pre-existing PyMlirContext. Throws + /// pybind11::error_already_set if the capsule does not hold a location. + static PyLocation createFromCapsule(pybind11::object capsule); + +private: MlirLocation loc; }; @@ -324,6 +337,8 @@ static constexpr const char kTypeDescription[] = "[ThreadContextAware] mlir.ir.Location"; static PyLocation &resolve(); + + operator MlirLocation() const { return *get(); } }; /// Wrapper around MlirModule.
@@ -568,7 +583,19 @@ PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr) : BaseContextObject(std::move(contextRef)), attr(attr) {} bool operator==(const PyAttribute &other); + operator MlirAttribute() const { return attr; } + MlirAttribute get() const { return attr; } + /// Gets a capsule wrapping the void* within the MlirAttribute. + pybind11::object getCapsule(); + + /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule. + /// The owning context is resolved via PyMlirContext::forContext, so the + /// result may reference a pre-existing PyMlirContext. Throws + /// pybind11::error_already_set if the capsule does not hold an attribute. + static PyAttribute createFromCapsule(pybind11::object capsule); + +private: MlirAttribute attr; }; @@ -603,7 +630,18 @@ : BaseContextObject(std::move(contextRef)), type(type) {} bool operator==(const PyType &other); operator MlirType() const { return type; } + MlirType get() const { return type; } + + /// Gets a capsule wrapping the void* within the MlirType. + pybind11::object getCapsule(); + /// Creates a PyType from the MlirType wrapped by a capsule. + /// The owning context is resolved via PyMlirContext::forContext, so the + /// result may reference a pre-existing PyMlirContext. Throws + /// pybind11::error_already_set if the capsule does not hold a type.
+ static PyType createFromCapsule(pybind11::object capsule); + +private: MlirType type; }; diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -289,7 +289,7 @@ llvm::SmallVector argTypes; argTypes.reserve(pyArgTypes.size()); for (auto &pyArg : pyArgTypes) { - argTypes.push_back(pyArg.cast().type); + argTypes.push_back(pyArg.cast()); } MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); @@ -640,6 +640,18 @@ // PyLocation //------------------------------------------------------------------------------ +py::object PyLocation::getCapsule() { + return py::reinterpret_steal(mlirPythonLocationToCapsule(*this)); +} + +PyLocation PyLocation::createFromCapsule(py::object capsule) { + MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); + if (mlirLocationIsNull(rawLoc)) + throw py::error_already_set(); + return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), + rawLoc); +} + py::object PyLocation::contextEnter() { return PyThreadContextEntry::pushLocation(*this); } @@ -879,7 +891,7 @@ // TODO: Verify result type originate from the same context. if (!result) throw SetPyError(PyExc_ValueError, "result type cannot be None"); - mlirResults.push_back(result->type); + mlirResults.push_back(*result); } } // Unpack/validate attributes. @@ -890,7 +902,7 @@ auto name = it.first.cast(); auto &attribute = it.second.cast(); // TODO: Verify attribute originates from the same context. - mlirAttributes.emplace_back(std::move(name), attribute.attr); + mlirAttributes.emplace_back(std::move(name), attribute); } } // Unpack/validate successors. @@ -908,7 +920,7 @@ // Apply unpacked/validated to the operation state. Beyond this // point, exceptions cannot be thrown or else the state will leak. 
MlirOperationState state = - mlirOperationStateGet(toMlirStringRef(name), location->loc); + mlirOperationStateGet(toMlirStringRef(name), location); if (!mlirOperands.empty()) mlirOperationStateAddOperands(&state, mlirOperands.size(), mlirOperands.data()); @@ -1076,6 +1088,18 @@ return mlirAttributeEqual(attr, other.attr); } +py::object PyAttribute::getCapsule() { + return py::reinterpret_steal(mlirPythonAttributeToCapsule(*this)); +} + +PyAttribute PyAttribute::createFromCapsule(py::object capsule) { + MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); + if (mlirAttributeIsNull(rawAttr)) + throw py::error_already_set(); + return PyAttribute( + PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); +} + //------------------------------------------------------------------------------ // PyNamedAttribute. //------------------------------------------------------------------------------ @@ -1093,6 +1117,18 @@ return mlirTypeEqual(type, other.type); } +py::object PyType::getCapsule() { + return py::reinterpret_steal(mlirPythonTypeToCapsule(*this)); +} + +PyType PyType::createFromCapsule(py::object capsule) { + MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); + if (mlirTypeIsNull(rawType)) + throw py::error_already_set(); + return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), + rawType); +} + //------------------------------------------------------------------------------ // PyValue and subclases. 
//------------------------------------------------------------------------------ @@ -1315,7 +1351,7 @@ void dunderSetItem(const std::string &name, PyAttribute attr) { mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), - attr.attr); + attr); } void dunderDelItem(const std::string &name) { @@ -1378,13 +1414,13 @@ : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {} static MlirAttribute castFrom(PyAttribute &orig) { - if (!DerivedTy::isaFunction(orig.attr)) { + if (!DerivedTy::isaFunction(orig)) { auto origRepr = py::repr(py::cast(orig)).cast(); throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast attribute to ") + DerivedTy::pyClassName + " (from " + origRepr + ")"); } - return orig.attr; + return orig; } static void bind(py::module &m) { @@ -1408,8 +1444,7 @@ c.def_static( "get", [](PyType &type, double value, DefaultingPyLocation loc) { - MlirAttribute attr = - mlirFloatAttrDoubleGetChecked(type.type, value, loc->loc); + MlirAttribute attr = mlirFloatAttrDoubleGetChecked(type, value, loc); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. 
if (mlirAttributeIsNull(attr)) { @@ -1443,7 +1478,7 @@ c.def_property_readonly( "value", [](PyFloatAttribute &self) { - return mlirFloatAttrGetValueDouble(self.attr); + return mlirFloatAttrGetValueDouble(self); }, "Returns the value of the float point attribute"); } @@ -1460,7 +1495,7 @@ c.def_static( "get", [](PyType &type, int64_t value) { - MlirAttribute attr = mlirIntegerAttrGet(type.type, value); + MlirAttribute attr = mlirIntegerAttrGet(type, value); return PyIntegerAttribute(type.getContext(), attr); }, py::arg("type"), py::arg("value"), @@ -1468,7 +1503,7 @@ c.def_property_readonly( "value", [](PyIntegerAttribute &self) { - return mlirIntegerAttrGetValueInt(self.attr); + return mlirIntegerAttrGetValueInt(self); }, "Returns the value of the integer attribute"); } @@ -1492,7 +1527,7 @@ "Gets an uniqued bool attribute"); c.def_property_readonly( "value", - [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self.attr); }, + [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); }, "Returns the value of the bool attribute"); } }; @@ -1517,7 +1552,7 @@ "get_typed", [](PyType &type, std::string value) { MlirAttribute attr = - mlirStringAttrTypedGet(type.type, value.size(), &value[0]); + mlirStringAttrTypedGet(type, value.size(), &value[0]); return PyStringAttribute(type.getContext(), attr); }, @@ -1525,7 +1560,7 @@ c.def_property_readonly( "value", [](PyStringAttribute &self) { - MlirStringRef stringRef = mlirStringAttrGetValue(self.attr); + MlirStringRef stringRef = mlirStringAttrGetValue(self); return py::str(stringRef.data, stringRef.length); }, "Returns the value of the string attribute"); @@ -1621,8 +1656,8 @@ PyAttribute &elementAttr) { auto contextWrapper = PyMlirContext::forContext(mlirTypeGetContext(shapedType)); - if (!mlirAttributeIsAInteger(elementAttr.attr) && - !mlirAttributeIsAFloat(elementAttr.attr)) { + if (!mlirAttributeIsAInteger(elementAttr) && + !mlirAttributeIsAFloat(elementAttr)) { std::string message = "Illegal element type for 
DenseElementsAttr: "; message.append(py::repr(py::cast(elementAttr))); throw SetPyError(PyExc_ValueError, message); @@ -1634,8 +1669,8 @@ message.append(py::repr(py::cast(shapedType))); throw SetPyError(PyExc_ValueError, message); } - MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType.type); - MlirType attrType = mlirAttributeGetType(elementAttr.attr); + MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); + MlirType attrType = mlirAttributeGetType(elementAttr); if (!mlirTypeEqual(shapedElementType, attrType)) { std::string message = "Shaped element type and attribute type must be equal: shaped="; @@ -1646,14 +1681,14 @@ } MlirAttribute elements = - mlirDenseElementsAttrSplatGet(shapedType.type, elementAttr.attr); + mlirDenseElementsAttrSplatGet(shapedType, elementAttr); return PyDenseElementsAttribute(contextWrapper->getRef(), elements); } - intptr_t dunderLen() { return mlirElementsAttrGetNumElements(attr); } + intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } py::buffer_info accessBuffer() { - MlirType shapedType = mlirAttributeGetType(this->attr); + MlirType shapedType = mlirAttributeGetType(*this); MlirType elementType = mlirShapedTypeGetElementType(shapedType); if (mlirTypeIsAF32(elementType)) { @@ -1699,7 +1734,7 @@ "Gets a DenseElementsAttr where all values are the same") .def_property_readonly("is_splat", [](PyDenseElementsAttribute &self) -> bool { - return mlirDenseElementsAttrIsSplat(self.attr); + return mlirDenseElementsAttrIsSplat(self); }) .def_buffer(&PyDenseElementsAttribute::accessBuffer); } @@ -1742,7 +1777,7 @@ // Prepare the data for the buffer_info. // Buffer is configured for read-only access below. Type *data = static_cast( - const_cast(mlirDenseElementsAttrGetRawData(this->attr))); + const_cast(mlirDenseElementsAttrGetRawData(*this))); // Prepare the shape for the buffer_info. 
SmallVector shape; for (intptr_t i = 0; i < rank; ++i) @@ -1782,7 +1817,7 @@ "attempt to access out of bounds element"); } - MlirType type = mlirAttributeGetType(attr); + MlirType type = mlirAttributeGetType(*this); type = mlirShapedTypeGetElementType(type); assert(mlirTypeIsAInteger(type) && "expected integer element type in dense int elements attribute"); @@ -1795,23 +1830,23 @@ bool isUnsigned = mlirIntegerTypeIsUnsigned(type); if (isUnsigned) { if (width == 1) { - return mlirDenseElementsAttrGetBoolValue(attr, pos); + return mlirDenseElementsAttrGetBoolValue(*this, pos); } if (width == 32) { - return mlirDenseElementsAttrGetUInt32Value(attr, pos); + return mlirDenseElementsAttrGetUInt32Value(*this, pos); } if (width == 64) { - return mlirDenseElementsAttrGetUInt64Value(attr, pos); + return mlirDenseElementsAttrGetUInt64Value(*this, pos); } } else { if (width == 1) { - return mlirDenseElementsAttrGetBoolValue(attr, pos); + return mlirDenseElementsAttrGetBoolValue(*this, pos); } if (width == 32) { - return mlirDenseElementsAttrGetInt32Value(attr, pos); + return mlirDenseElementsAttrGetInt32Value(*this, pos); } if (width == 64) { - return mlirDenseElementsAttrGetInt64Value(attr, pos); + return mlirDenseElementsAttrGetInt64Value(*this, pos); } } throw SetPyError(PyExc_TypeError, "Unsupported integer type"); @@ -1838,7 +1873,7 @@ "attempt to access out of bounds element"); } - MlirType type = mlirAttributeGetType(attr); + MlirType type = mlirAttributeGetType(*this); type = mlirShapedTypeGetElementType(type); // Dispatch element extraction to an appropriate C function based on the // elemental type of the attribute. py::float_ is implicitly constructible @@ -1846,10 +1881,10 @@ // TODO: consider caching the type properties in the constructor to avoid // querying them on each element access. 
if (mlirTypeIsAF32(type)) { - return mlirDenseElementsAttrGetFloatValue(attr, pos); + return mlirDenseElementsAttrGetFloatValue(*this, pos); } if (mlirTypeIsAF64(type)) { - return mlirDenseElementsAttrGetDoubleValue(attr, pos); + return mlirDenseElementsAttrGetDoubleValue(*this, pos); } throw SetPyError(PyExc_TypeError, "Unsupported floating-point type"); } @@ -1906,13 +1941,13 @@ : PyConcreteType(orig.getContext(), castFrom(orig)) {} static MlirType castFrom(PyType &orig) { - if (!DerivedTy::isaFunction(orig.type)) { + if (!DerivedTy::isaFunction(orig)) { auto origRepr = py::repr(py::cast(orig)).cast(); throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") + DerivedTy::pyClassName + " (from " + origRepr + ")"); } - return orig.type; + return orig; } static void bind(py::module &m) { @@ -1958,24 +1993,24 @@ "Create an unsigned integer type"); c.def_property_readonly( "width", - [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self.type); }, + [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); }, "Returns the width of the integer type"); c.def_property_readonly( "is_signless", [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsSignless(self.type); + return mlirIntegerTypeIsSignless(self); }, "Returns whether this is a signless integer"); c.def_property_readonly( "is_signed", [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsSigned(self.type); + return mlirIntegerTypeIsSigned(self); }, "Returns whether this is a signed integer"); c.def_property_readonly( "is_unsigned", [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsUnsigned(self.type); + return mlirIntegerTypeIsUnsigned(self); }, "Returns whether this is an unsigned integer"); } @@ -2101,8 +2136,8 @@ "get", [](PyType &elementType) { // The element must be a floating point or integer scalar type. 
- if (mlirTypeIsAIntegerOrFloat(elementType.type)) { - MlirType t = mlirComplexTypeGet(elementType.type); + if (mlirTypeIsAIntegerOrFloat(elementType)) { + MlirType t = mlirComplexTypeGet(elementType); return PyComplexType(elementType.getContext(), t); } throw SetPyError( @@ -2115,7 +2150,7 @@ c.def_property_readonly( "element_type", [](PyComplexType &self) -> PyType { - MlirType t = mlirComplexTypeGetElementType(self.type); + MlirType t = mlirComplexTypeGetElementType(self); return PyType(self.getContext(), t); }, "Returns element type."); @@ -2132,34 +2167,32 @@ c.def_property_readonly( "element_type", [](PyShapedType &self) { - MlirType t = mlirShapedTypeGetElementType(self.type); + MlirType t = mlirShapedTypeGetElementType(self); return PyType(self.getContext(), t); }, "Returns the element type of the shaped type."); c.def_property_readonly( "has_rank", - [](PyShapedType &self) -> bool { - return mlirShapedTypeHasRank(self.type); - }, + [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); }, "Returns whether the given shaped type is ranked."); c.def_property_readonly( "rank", [](PyShapedType &self) { self.requireHasRank(); - return mlirShapedTypeGetRank(self.type); + return mlirShapedTypeGetRank(self); }, "Returns the rank of the given ranked shaped type."); c.def_property_readonly( "has_static_shape", [](PyShapedType &self) -> bool { - return mlirShapedTypeHasStaticShape(self.type); + return mlirShapedTypeHasStaticShape(self); }, "Returns whether the given shaped type has a static shape."); c.def( "is_dynamic_dim", [](PyShapedType &self, intptr_t dim) -> bool { self.requireHasRank(); - return mlirShapedTypeIsDynamicDim(self.type, dim); + return mlirShapedTypeIsDynamicDim(self, dim); }, "Returns whether the dim-th dimension of the given shaped type is " "dynamic."); @@ -2167,7 +2200,7 @@ "get_dim_size", [](PyShapedType &self, intptr_t dim) { self.requireHasRank(); - return mlirShapedTypeGetDimSize(self.type, dim); + return 
mlirShapedTypeGetDimSize(self, dim); }, "Returns the dim-th dimension of the given ranked shaped type."); c.def_static( @@ -2187,7 +2220,7 @@ private: void requireHasRank() { - if (!mlirShapedTypeHasRank(type)) { + if (!mlirShapedTypeHasRank(*this)) { throw SetPyError( PyExc_ValueError, "calling this method requires that the type has a rank."); @@ -2208,7 +2241,7 @@ [](std::vector shape, PyType &elementType, DefaultingPyLocation loc) { MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(), - elementType.type, loc->loc); + elementType, loc); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. if (mlirTypeIsNull(t)) { @@ -2239,7 +2272,7 @@ [](std::vector shape, PyType &elementType, DefaultingPyLocation loc) { MlirType t = mlirRankedTensorTypeGetChecked( - shape.size(), shape.data(), elementType.type, loc->loc); + shape.size(), shape.data(), elementType, loc); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. if (mlirTypeIsNull(t)) { @@ -2270,8 +2303,7 @@ c.def_static( "get", [](PyType &elementType, DefaultingPyLocation loc) { - MlirType t = - mlirUnrankedTensorTypeGetChecked(elementType.type, loc->loc); + MlirType t = mlirUnrankedTensorTypeGetChecked(elementType, loc); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. if (mlirTypeIsNull(t)) { @@ -2306,8 +2338,7 @@ [](PyType &elementType, std::vector shape, unsigned memorySpace, DefaultingPyLocation loc) { MlirType t = mlirMemRefTypeContiguousGetChecked( - elementType.type, shape.size(), shape.data(), memorySpace, - loc->loc); + elementType, shape.size(), shape.data(), memorySpace, loc); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. 
if (mlirTypeIsNull(t)) { @@ -2326,14 +2357,14 @@ .def_property_readonly( "num_affine_maps", [](PyMemRefType &self) -> intptr_t { - return mlirMemRefTypeGetNumAffineMaps(self.type); + return mlirMemRefTypeGetNumAffineMaps(self); }, "Returns the number of affine layout maps in the given MemRef " "type.") .def_property_readonly( "memory_space", [](PyMemRefType &self) -> unsigned { - return mlirMemRefTypeGetMemorySpace(self.type); + return mlirMemRefTypeGetMemorySpace(self); }, "Returns the memory space of the given MemRef type."); } @@ -2352,8 +2383,8 @@ "get", [](PyType &elementType, unsigned memorySpace, DefaultingPyLocation loc) { - MlirType t = mlirUnrankedMemRefTypeGetChecked(elementType.type, - memorySpace, loc->loc); + MlirType t = + mlirUnrankedMemRefTypeGetChecked(elementType, memorySpace, loc); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. if (mlirTypeIsNull(t)) { @@ -2372,7 +2403,7 @@ .def_property_readonly( "memory_space", [](PyUnrankedMemRefType &self) -> unsigned { - return mlirUnrankedMemrefGetMemorySpace(self.type); + return mlirUnrankedMemrefGetMemorySpace(self); }, "Returns the memory space of the given Unranked MemRef type."); } @@ -2393,7 +2424,7 @@ // Mapping py::list to SmallVector. 
SmallVector elements; for (auto element : elementList) - elements.push_back(element.cast().type); + elements.push_back(element.cast()); MlirType t = mlirTupleTypeGet(context->get(), num, elements.data()); return PyTupleType(context->getRef(), t); }, @@ -2402,14 +2433,14 @@ c.def( "get_type", [](PyTupleType &self, intptr_t pos) -> PyType { - MlirType t = mlirTupleTypeGetType(self.type, pos); + MlirType t = mlirTupleTypeGetType(self, pos); return PyType(self.getContext(), t); }, "Returns the pos-th type in the tuple type."); c.def_property_readonly( "num_types", [](PyTupleType &self) -> intptr_t { - return mlirTupleTypeGetNumTypes(self.type); + return mlirTupleTypeGetNumTypes(self); }, "Returns the number of types contained in a tuple."); } @@ -2439,11 +2470,11 @@ c.def_property_readonly( "inputs", [](PyFunctionType &self) { - MlirType t = self.type; + MlirType t = self; auto contextRef = self.getContext(); py::list types; - for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self.type); - i < e; ++i) { + for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; + ++i) { types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i))); } return types; @@ -2452,12 +2483,12 @@ c.def_property_readonly( "results", [](PyFunctionType &self) { - MlirType t = self.type; auto contextRef = self.getContext(); py::list types; - for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self.type); - i < e; ++i) { - types.append(PyType(contextRef, mlirFunctionTypeGetResult(t, i))); + for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; + ++i) { + types.append( + PyType(contextRef, mlirFunctionTypeGetResult(self, i))); } return types; }, @@ -2584,8 +2615,15 @@ // Mapping of Location //---------------------------------------------------------------------------- py::class_(m, "Location") + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) .def("__enter__", 
&PyLocation::contextEnter) .def("__exit__", &PyLocation::contextExit) + .def("__eq__", + [](PyLocation &self, PyLocation &other) -> bool { + return mlirLocationEqual(self, other); + }) + .def("__eq__", [](PyLocation &self, py::object other) { return false; }) .def_property_readonly_static( "current", [](py::object & /*class*/) { @@ -2620,7 +2658,7 @@ "Context that owns the Location") .def("__repr__", [](PyLocation &self) { PyPrintAccumulator printAccum; - mlirLocationPrint(self.loc, printAccum.getCallback(), + mlirLocationPrint(self, printAccum.getCallback(), printAccum.getUserData()); return printAccum.join(); }); @@ -2650,7 +2688,7 @@ .def_static( "create", [](DefaultingPyLocation loc) { - MlirModule module = mlirModuleCreateEmpty(loc->loc); + MlirModule module = mlirModuleCreateEmpty(loc); return PyModule::forModule(module).releaseObject(); }, py::arg("loc") = py::none(), "Creates an empty module") @@ -2881,6 +2919,9 @@ // Mapping of PyAttribute. //---------------------------------------------------------------------------- py::class_(m, "Attribute") + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyAttribute::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) .def_static( "parse", [](std::string attrSpec, DefaultingPyMlirContext context) { @@ -2904,25 +2945,25 @@ .def_property_readonly("type", [](PyAttribute &self) { return PyType(self.getContext()->getRef(), - mlirAttributeGetType(self.attr)); + mlirAttributeGetType(self)); }) .def( "get_named", [](PyAttribute &self, std::string name) { - return PyNamedAttribute(self.attr, std::move(name)); + return PyNamedAttribute(self, std::move(name)); }, py::keep_alive<0, 1>(), "Binds a name to the attribute") .def("__eq__", [](PyAttribute &self, PyAttribute &other) { return self == other; }) .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) .def( - "dump", [](PyAttribute &self) { mlirAttributeDump(self.attr); }, + "dump", [](PyAttribute &self) { 
mlirAttributeDump(self); }, kDumpDocstring) .def( "__str__", [](PyAttribute &self) { PyPrintAccumulator printAccum; - mlirAttributePrint(self.attr, printAccum.getCallback(), + mlirAttributePrint(self, printAccum.getCallback(), printAccum.getUserData()); return printAccum.join(); }, @@ -2935,7 +2976,7 @@ // being excessive. PyPrintAccumulator printAccum; printAccum.parts.append("Attribute("); - mlirAttributePrint(self.attr, printAccum.getCallback(), + mlirAttributePrint(self, printAccum.getCallback(), printAccum.getUserData()); printAccum.parts.append(")"); return printAccum.join(); @@ -2990,6 +3031,8 @@ // Mapping of PyType. //---------------------------------------------------------------------------- py::class_(m, "Type") + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) .def_static( "parse", [](std::string typeSpec, DefaultingPyMlirContext context) { @@ -3012,12 +3055,12 @@ .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) .def("__eq__", [](PyType &self, py::object &other) { return false; }) .def( - "dump", [](PyType &self) { mlirTypeDump(self.type); }, kDumpDocstring) + "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) .def( "__str__", [](PyType &self) { PyPrintAccumulator printAccum; - mlirTypePrint(self.type, printAccum.getCallback(), + mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); return printAccum.join(); }, @@ -3029,8 +3072,7 @@ // assembly forms and printing them is useful. 
PyPrintAccumulator printAccum; printAccum.parts.append("Type("); - mlirTypePrint(self.type, printAccum.getCallback(), - printAccum.getUserData()); + mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); printAccum.parts.append(")"); return printAccum.join(); }); diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ b/mlir/lib/Bindings/Python/PybindUtils.h @@ -50,7 +50,7 @@ Defaulting() = default; Defaulting(ReferrentTy &referrent) : referrent(&referrent) {} - ReferrentTy *get() { return referrent; } + ReferrentTy *get() const { return referrent; } ReferrentTy *operator->() { return referrent; } private: diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -119,6 +119,10 @@ return wrap(UnknownLoc::get(unwrap(context))); } +int mlirLocationEqual(MlirLocation l1, MlirLocation l2) { + return unwrap(l1) == unwrap(l2); +} + MlirContext mlirLocationGetContext(MlirLocation location) { return wrap(unwrap(location).getContext()); } diff --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py --- a/mlir/test/Bindings/Python/ir_attributes.py +++ b/mlir/test/Bindings/Python/ir_attributes.py @@ -74,6 +74,20 @@ run(testAttrEqDoesNotRaise) +# CHECK-LABEL: TEST: testAttrCapsule +def testAttrCapsule(): + with Context() as ctx: + a1 = Attribute.parse('"attr1"') + # CHECK: mlir.ir.Attribute._CAPIPtr + attr_capsule = a1._CAPIPtr + print(attr_capsule) + a2 = Attribute._CAPICreate(attr_capsule) + assert a2 == a1 + assert a2.context is ctx + +run(testAttrCapsule) + + # CHECK-LABEL: TEST: testStandardAttrCasts def testStandardAttrCasts(): with Context(): diff --git a/mlir/test/Bindings/Python/ir_location.py b/mlir/test/Bindings/Python/ir_location.py --- a/mlir/test/Bindings/Python/ir_location.py +++ b/mlir/test/Bindings/Python/ir_location.py @@ -38,3 +38,16 @@ 
run(testFileLineCol) + +# CHECK-LABEL: TEST: testLocationCapsule +def testLocationCapsule(): + with Context() as ctx: + loc1 = Location.file("foo.txt", 123, 56) + # CHECK: mlir.ir.Location._CAPIPtr + loc_capsule = loc1._CAPIPtr + print(loc_capsule) + loc2 = Location._CAPICreate(loc_capsule) + assert loc2 == loc1 + assert loc2.context is ctx + +run(testLocationCapsule) diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py --- a/mlir/test/Bindings/Python/ir_types.py +++ b/mlir/test/Bindings/Python/ir_types.py @@ -74,6 +74,20 @@ run(testTypeEqDoesNotRaise) +# CHECK-LABEL: TEST: testTypeCapsule +def testTypeCapsule(): + with Context() as ctx: + t1 = Type.parse("i32", ctx) + # CHECK: mlir.ir.Type._CAPIPtr + type_capsule = t1._CAPIPtr + print(type_capsule) + t2 = Type._CAPICreate(type_capsule) + assert t2 == t1 + assert t2.context is ctx + +run(testTypeCapsule) + + # CHECK-LABEL: TEST: testStandardTypeCasts def testStandardTypeCasts(): ctx = Context()