diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -103,6 +103,9 @@ /** Creates a location with unknown position owned by the given context. */ MlirLocation mlirLocationUnknownGet(MlirContext context); +/** Gets the context that a location was created with. */ +MlirContext mlirLocationGetContext(MlirLocation location); + /** Prints a location by sending chunks of the string representation and * forwarding `userData to `callback`. Note that the callback may be called * several times with consecutive chunks of the string. */ @@ -119,6 +122,9 @@ /** Parses a module from the string and transfers ownership to the caller. */ MlirModule mlirModuleCreateParse(MlirContext context, const char *module); +/** Gets the context that a module was created with. */ +MlirContext mlirModuleGetContext(MlirModule module); + /** Checks whether a module is null. */ inline int mlirModuleIsNull(MlirModule module) { return !module.ptr; } @@ -342,6 +348,9 @@ /** Parses an attribute. The attribute is owned by the context. */ MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr); +/** Gets the context that an attribute was created with. */ +MlirContext mlirAttributeGetContext(MlirAttribute attribute); + /** Checks whether an attribute is null. */ inline int mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; } diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h --- a/mlir/lib/Bindings/Python/IRModules.h +++ b/mlir/lib/Bindings/Python/IRModules.h @@ -12,6 +12,7 @@ #include <pybind11/pybind11.h> #include "mlir-c/IR.h" +#include "llvm/ADT/DenseMap.h" namespace mlir { namespace python { @@ -19,28 +20,95 @@ class PyMlirContext; class PyModule; +/// Holds a C++ PyMlirContext and associated py::object, making it convenient +/// to have an auto-releasing C++-side keep-alive reference to the context.
+/// The reference to the PyMlirContext is a simple C++ reference and the +/// py::object holds the reference count which keeps it alive. +class PyMlirContextRef { +public: + PyMlirContextRef(PyMlirContext &referrent, pybind11::object object) + : referrent(referrent), object(std::move(object)) {} + + /// Releases the object held by this instance, causing its reference count + /// to remain artificially inflated by one. This must be used to return + /// the referenced PyMlirContext from a function. Otherwise, the destructor + /// of this reference would be called prior to the default take_ownership + /// policy assuming that the reference count has been transferred to it. + PyMlirContext *release(); + + /// Accessors for the referenced PyMlirContext and its keep-alive py::object. + PyMlirContext *operator->() { return &referrent; } + pybind11::object getObject() { return object; } + +private: + PyMlirContext &referrent; + pybind11::object object; +}; + /// Wrapper around MlirContext. class PyMlirContext { public: - PyMlirContext() { context = mlirContextCreate(); } - ~PyMlirContext() { mlirContextDestroy(context); } + PyMlirContext() = delete; + PyMlirContext(const PyMlirContext &) = delete; + PyMlirContext(PyMlirContext &&) = delete; + + /// Returns a context reference for the singleton PyMlirContext wrapper for + /// the given context. + static PyMlirContextRef forContext(MlirContext context); + ~PyMlirContext(); MlirContext context; + + /// Gets the count of live context objects. Used for testing. + static size_t getLiveCount(); + +private: + PyMlirContext(MlirContext context); + + // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, + // preserving the relationship that an MlirContext maps to a single + // PyMlirContext wrapper. This could be replaced in the future with an + // extension mechanism on the MlirContext for stashing user pointers. + // Note that this holds a handle, which does not imply ownership.
+ // Mappings will be removed when the context is destructed. + using LiveContextMap = + llvm::DenseMap<void *, std::pair<pybind11::handle, PyMlirContext *>>; + static LiveContextMap &getLiveContexts(); +}; + +/// Base class for all objects that directly or indirectly depend on an +/// MlirContext. The lifetime of the context will extend at least to the +/// lifetime of these instances. +/// Immutable objects that depend on a context extend this directly. +class BaseContextObject { +public: + BaseContextObject(MlirContext context) + : contextRef(PyMlirContext::forContext(context)) {} + BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) {} + + /// Accesses the context reference. + PyMlirContextRef &getContext() { return contextRef; } + +private: + PyMlirContextRef contextRef; }; /// Wrapper around an MlirLocation. -class PyLocation { +class PyLocation : public BaseContextObject { public: - PyLocation(MlirLocation loc) : loc(loc) {} + PyLocation(MlirLocation loc) + : BaseContextObject(mlirLocationGetContext(loc)), loc(loc) {} MlirLocation loc; }; /// Wrapper around MlirModule. -class PyModule { +class PyModule : public BaseContextObject { public: - PyModule(MlirModule module) : module(module) {} + PyModule(MlirModule module) + : BaseContextObject(mlirModuleGetContext(module)), module(module) {} PyModule(PyModule &) = delete; - PyModule(PyModule &&other) { + PyModule(PyModule &&other) + : BaseContextObject(std::move(other.getContext())) { module = other.module; other.module.ptr = nullptr; } @@ -120,9 +188,10 @@ /// Wrapper around the generic MlirAttribute. /// The lifetime of a type is bound by the PyContext that created it. -class PyAttribute { +class PyAttribute : public BaseContextObject { public: - PyAttribute(MlirAttribute attr) : attr(attr) {} + PyAttribute(MlirAttribute attr) + : BaseContextObject(mlirAttributeGetContext(attr)), attr(attr) {} bool operator==(const PyAttribute &other); MlirAttribute attr; @@ -153,9 +222,10 @@ /// Wrapper around the generic MlirType.
/// The lifetime of a type is bound by the PyContext that created it. -class PyType { +class PyType : public BaseContextObject { public: - PyType(MlirType type) : type(type) {} + PyType(MlirType type) + : BaseContextObject(mlirTypeGetContext(type)), type(type) {} bool operator==(const PyType &other); operator MlirType() const { return type; } diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -170,6 +170,49 @@ } // namespace +//------------------------------------------------------------------------------ +// PyMlirContext +//------------------------------------------------------------------------------ + +PyMlirContext *PyMlirContextRef::release() { + object.release(); + return &referrent; +} + +PyMlirContext::PyMlirContext(MlirContext context) : context(context) {} + +PyMlirContext::~PyMlirContext() { + // Note that the only public way to construct an instance is via the + // forContext method, which always puts the associated handle into + // liveContexts. + getLiveContexts().erase(context.ptr); + mlirContextDestroy(context); +} + +PyMlirContextRef PyMlirContext::forContext(MlirContext context) { + auto &liveContexts = getLiveContexts(); + auto it = liveContexts.find(context.ptr); + if (it == liveContexts.end()) { + // Create. + PyMlirContext *unownedContextWrapper = new PyMlirContext(context); + py::object pyRef = py::cast(unownedContextWrapper); + liveContexts[context.ptr] = std::make_pair(pyRef, unownedContextWrapper); + return PyMlirContextRef(*unownedContextWrapper, std::move(pyRef)); + } else { + // Use existing. + py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); + return PyMlirContextRef(*it->second.second, std::move(pyRef)); + } +} + +PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { + // Heap allocate so it lives forever.
+ static LiveContextMap *liveContexts = new LiveContextMap(); + return *liveContexts; +} + +size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } + //------------------------------------------------------------------------------ // PyBlock, PyRegion, and PyOperation. //------------------------------------------------------------------------------ @@ -272,7 +315,7 @@ mlirStringAttrGet(context.context, value.size(), &value[0]); return PyStringAttribute(attr); }, - py::keep_alive<0, 1>(), "Gets a uniqued string attribute"); + "Gets a uniqued string attribute"); c.def_static( "get_typed", [](PyType &type, std::string value) { @@ -280,7 +323,7 @@ mlirStringAttrTypedGet(type.type, value.size(), &value[0]); return PyStringAttribute(attr); }, - py::keep_alive<0, 1>(), + "Gets a uniqued string attribute associated to a type"); c.def_property_readonly( "value", @@ -351,21 +394,21 @@ MlirType t = mlirIntegerTypeGet(context.context, width); return PyIntegerType(t); }, - py::keep_alive<0, 1>(), "Create a signless integer type"); + "Create a signless integer type"); c.def_static( "get_signed", [](PyMlirContext &context, unsigned width) { MlirType t = mlirIntegerTypeSignedGet(context.context, width); return PyIntegerType(t); }, - py::keep_alive<0, 1>(), "Create a signed integer type"); + "Create a signed integer type"); c.def_static( "get_unsigned", [](PyMlirContext &context, unsigned width) { MlirType t = mlirIntegerTypeUnsignedGet(context.context, width); return PyIntegerType(t); }, - py::keep_alive<0, 1>(), "Create an unsigned integer type"); + "Create an unsigned integer type"); c.def_property_readonly( "width", [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self.type); }, @@ -403,7 +446,7 @@ MlirType t = mlirIndexTypeGet(context.context); return PyIndexType(t); }), - py::keep_alive<0, 1>(), "Create a index type."); + "Create a index type."); } }; @@ -419,7 +462,7 @@ MlirType t = mlirBF16TypeGet(context.context); return PyBF16Type(t); }), -
py::keep_alive<0, 1>(), "Create a bf16 type."); + "Create a bf16 type."); } }; @@ -435,7 +478,7 @@ MlirType t = mlirF16TypeGet(context.context); return PyF16Type(t); }), - py::keep_alive<0, 1>(), "Create a f16 type."); + "Create a f16 type."); } }; @@ -451,7 +494,7 @@ MlirType t = mlirF32TypeGet(context.context); return PyF32Type(t); }), - py::keep_alive<0, 1>(), "Create a f32 type."); + "Create a f32 type."); } }; @@ -467,7 +510,7 @@ MlirType t = mlirF64TypeGet(context.context); return PyF64Type(t); }), - py::keep_alive<0, 1>(), "Create a f64 type."); + "Create a f64 type."); } }; @@ -483,7 +526,7 @@ MlirType t = mlirNoneTypeGet(context.context); return PyNoneType(t); }), - py::keep_alive<0, 1>(), "Create a none type."); + "Create a none type."); } }; @@ -509,7 +552,7 @@ py::repr(py::cast(elementType)).cast<std::string>() + "' and expected floating point or integer type."); }, - py::keep_alive<0, 1>(), "Create a complex type"); + "Create a complex type"); c.def_property_readonly( "element_type", [](PyComplexType &self) -> PyType { @@ -533,7 +576,7 @@ MlirType t = mlirShapedTypeGetElementType(self.type); return PyType(t); }, - py::keep_alive<0, 1>(), "Returns the element type of the shaped type."); + "Returns the element type of the shaped type."); c.def_property_readonly( "has_rank", [](PyShapedType &self) -> bool { @@ -618,7 +661,7 @@ } return PyVectorType(t); }, - py::keep_alive<0, 2>(), "Create a vector type"); + "Create a vector type"); } }; @@ -650,7 +693,7 @@ } return PyRankedTensorType(t); }, - py::keep_alive<0, 2>(), "Create a ranked tensor type"); + "Create a ranked tensor type"); } }; @@ -682,7 +725,7 @@ } return PyUnrankedTensorType(t); }, - py::keep_alive<0, 1>(), "Create a unranked tensor type"); + "Create a unranked tensor type"); } }; @@ -717,7 +760,7 @@ } return PyMemRefType(t); }, - py::keep_alive<0, 1>(), "Create a memref type") + "Create a memref type") .def_property_readonly( "num_affine_maps", [](PyMemRefType &self) -> intptr_t { @@ -762,7 +805,7 @@ }
return PyUnrankedMemRefType(t); }, - py::keep_alive<0, 1>(), "Create a unranked memref type") + "Create a unranked memref type") .def_property_readonly( "memory_space", [](PyUnrankedMemRefType &self) -> unsigned { @@ -791,14 +834,14 @@ MlirType t = mlirTupleTypeGet(context.context, num, elements.data()); return PyTupleType(t); }, - py::keep_alive<0, 1>(), "Create a tuple type"); + "Create a tuple type"); c.def( "get_type", [](PyTupleType &self, intptr_t pos) -> PyType { MlirType t = mlirTupleTypeGetType(self.type, pos); return PyType(t); }, - py::keep_alive<0, 1>(), "Returns the pos-th type in the tuple type."); + "Returns the pos-th type in the tuple type."); c.def_property_readonly( "num_types", [](PyTupleType &self) -> intptr_t { @@ -817,7 +860,17 @@ void mlir::python::populateIRSubmodule(py::module &m) { // Mapping of MlirContext py::class_<PyMlirContext>(m, "Context") - .def(py::init<>()) + .def(py::init<>([]() { + MlirContext context = mlirContextCreate(); + auto contextRef = PyMlirContext::forContext(context); + return contextRef.release(); + })) + .def_static("_get_live_count", &PyMlirContext::getLiveCount) + .def("_get_context_again", + [](PyMlirContext &self) { + auto ref = PyMlirContext::forContext(self.context); + return ref.release(); + }) .def( "parse_module", [](PyMlirContext &self, const std::string module) { @@ -832,7 +885,7 @@ } return PyModule(moduleRef); }, - py::keep_alive<0, 1>(), kContextParseDocstring) + kContextParseDocstring) .def( "parse_attr", [](PyMlirContext &self, std::string attrSpec) { @@ -861,21 +914,21 @@ } return PyType(type); }, - py::keep_alive<0, 1>(), kContextParseTypeDocstring) + kContextParseTypeDocstring) .def( "get_unknown_location", [](PyMlirContext &self) { return PyLocation(mlirLocationUnknownGet(self.context)); }, - py::keep_alive<0, 1>(), kContextGetUnknownLocationDocstring) + kContextGetUnknownLocationDocstring) .def( "get_file_location", [](PyMlirContext &self, std::string filename, int line, int col) { return
PyLocation(mlirLocationFileLineColGet( self.context, filename.c_str(), line, col)); }, - py::keep_alive<0, 1>(), kContextGetFileLocationDocstring, - py::arg("filename"), py::arg("line"), py::arg("col")) + kContextGetFileLocationDocstring, py::arg("filename"), + py::arg("line"), py::arg("col")) .def( "create_region", [](PyMlirContext &self) { diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -48,6 +48,10 @@ return wrap(UnknownLoc::get(unwrap(context))); } +MlirContext mlirLocationGetContext(MlirLocation location) { + return wrap(unwrap(location).getContext()); +} + void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, void *userData) { detail::CallbackOstream stream(callback, userData); @@ -70,6 +74,10 @@ return MlirModule{owning.release().getOperation()}; } +MlirContext mlirModuleGetContext(MlirModule module) { + return wrap(unwrap(module).getContext()); +} + void mlirModuleDestroy(MlirModule module) { // Transfer ownership to an OwningModuleRef so that its destructor is called. OwningModuleRef(unwrap(module)); @@ -349,6 +357,10 @@ return wrap(mlir::parseAttribute(attr, unwrap(context))); } +MlirContext mlirAttributeGetContext(MlirAttribute attribute) { + return wrap(unwrap(attribute).getContext()); +} + int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { return unwrap(a1) == unwrap(a2); } diff --git a/mlir/test/Bindings/Python/context_lifecycle.py b/mlir/test/Bindings/Python/context_lifecycle.py new file mode 100644 --- /dev/null +++ b/mlir/test/Bindings/Python/context_lifecycle.py @@ -0,0 +1,42 @@ +# RUN: %PYTHON %s +# Standalone sanity check of context life-cycle. +import gc +import mlir + +assert mlir.ir.Context._get_live_count() == 0 + +# Create first context.
+print("CREATE C1") +c1 = mlir.ir.Context() +assert mlir.ir.Context._get_live_count() == 1 +c1_repr = repr(c1) +print("C1 =", c1_repr) + +print("GETTING AGAIN...") +c2 = c1._get_context_again() +c2_repr = repr(c2) +assert mlir.ir.Context._get_live_count() == 1 +assert c1_repr == c2_repr + +print("C2 =", c2) + +# Make sure the constructor creates a new context. +print("CREATE C3") +c3 = mlir.ir.Context() +assert mlir.ir.Context._get_live_count() == 2 +c3_repr = repr(c3) +print("C3 =", c3) +assert c3_repr != c1_repr +print("FREE C3") +c3 = None +gc.collect() +assert mlir.ir.Context._get_live_count() == 1 + +print("Free C1") +c1 = None +gc.collect() +assert mlir.ir.Context._get_live_count() == 1 +print("Free C2") +c2 = None +gc.collect() +assert mlir.ir.Context._get_live_count() == 0