diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h --- a/mlir/lib/Bindings/Python/IRModules.h +++ b/mlir/lib/Bindings/Python/IRModules.h @@ -28,7 +28,6 @@ public: PyMlirContextRef(PyMlirContext &referrent, pybind11::object object) : referrent(referrent), object(std::move(object)) {} - ~PyMlirContextRef() { printf("REF DESTROY\n"); } /// Releases the object held by this instance, causing its reference count /// to remain artifically inflated by one. This must be used to return @@ -57,7 +56,16 @@ static PyMlirContextRef forContext(MlirContext context); ~PyMlirContext(); - MlirContext context; + /// Accesses the underlying MlirContext. + MlirContext get() { return context; } + + /// Gets a strong reference to this context, which will ensure it is kept + /// alive for the life of the reference. + PyMlirContextRef getRef() { + assert(handle && "getRef() called before forContext() set the handle"); + return PyMlirContextRef( + *this, pybind11::reinterpret_borrow<pybind11::object>(handle)); + } /// Gets the count of live context objects. Used for testing. static size_t getLiveCount(); @@ -74,6 +82,10 @@ using LiveContextMap = llvm::DenseMap<void *, std::pair<pybind11::handle, PyMlirContext *>>; static LiveContextMap &getLiveContexts(); + + MlirContext context; + // The handle is set as part of lookup with forContext() (post construction). + pybind11::handle handle; }; /// Base class for all objects that directly or indirectly depend on an @@ -82,9 +94,11 @@ /// Immutable objects that depend on a context extend this directly. class BaseContextObject { public: + /// Preferred constructor that uses an explicit context ref. + BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) {} + /// Convenience constructor that will look up the context in the live list. BaseContextObject(MlirContext context) : contextRef(PyMlirContext::forContext(context)) {} - BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) {} /// Accesses the context reference.
PyMlirContextRef &getContext() { return contextRef; } @@ -96,6 +110,8 @@ /// Wrapper around an MlirLocation. class PyLocation : public BaseContextObject { public: + PyLocation(PyMlirContextRef contextRef, MlirLocation loc) + : BaseContextObject(std::move(contextRef)), loc(loc) {} PyLocation(MlirLocation loc) : BaseContextObject(mlirLocationGetContext(loc)), loc(loc) {} MlirLocation loc; @@ -104,6 +120,8 @@ /// Wrapper around MlirModule. class PyModule : public BaseContextObject { public: + PyModule(PyMlirContextRef contextRef, MlirModule module) + : BaseContextObject(std::move(contextRef)), module(module) {} PyModule(MlirModule module) : BaseContextObject(mlirModuleGetContext(module)), module(module) {} PyModule(PyModule &) = delete; @@ -190,6 +208,8 @@ /// The lifetime of a type is bound by the PyContext that created it. class PyAttribute : public BaseContextObject { public: + PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr) + : BaseContextObject(std::move(contextRef)), attr(attr) {} PyAttribute(MlirAttribute attr) : BaseContextObject(mlirAttributeGetContext(attr)), attr(attr) {} bool operator==(const PyAttribute &other); @@ -224,6 +244,8 @@ /// The lifetime of a type is bound by the PyContext that created it. class PyType : public BaseContextObject { public: + PyType(PyMlirContextRef contextRef, MlirType type) + : BaseContextObject(std::move(contextRef)), type(type) {} PyType(MlirType type) : BaseContextObject(mlirTypeGetContext(type)), type(type) {} bool operator==(const PyType &other); diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -185,30 +185,32 @@ // Note that the only public way to construct an instance is via the // forContext method, which always puts the associated handle into // liveContexts.
+ py::gil_scoped_acquire acquire; getLiveContexts().erase(context.ptr); mlirContextDestroy(context); } PyMlirContextRef PyMlirContext::forContext(MlirContext context) { + py::gil_scoped_acquire acquire; auto &liveContexts = getLiveContexts(); auto it = liveContexts.find(context.ptr); if (it == liveContexts.end()) { - // Create + // Create. PyMlirContext *unownedContextWrapper = new PyMlirContext(context); py::object pyRef = py::cast(unownedContextWrapper); + unownedContextWrapper->handle = pyRef; liveContexts[context.ptr] = std::make_pair(pyRef, unownedContextWrapper); return PyMlirContextRef(*unownedContextWrapper, std::move(pyRef)); } else { - // Existing + // Use existing. py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); return PyMlirContextRef(*it->second.second, std::move(pyRef)); } } PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { - // Heap allocate so it lives forever. - static LiveContextMap *liveContexts = new LiveContextMap(); - return *liveContexts; + static auto &liveContexts = *new LiveContextMap(); // Intentionally leaked. + return liveContexts; } size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } @@ -277,9 +279,11 @@ using IsAFunctionTy = int (*)(MlirAttribute); PyConcreteAttribute() = default; + PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) + : BaseTy(std::move(contextRef), attr) {} PyConcreteAttribute(MlirAttribute attr) : BaseTy(attr) {} PyConcreteAttribute(PyAttribute &orig) - : PyConcreteAttribute(castFrom(orig)) {} + : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {} static MlirAttribute castFrom(PyAttribute &orig) { if (!DerivedTy::isaFunction(orig.attr)) { @@ -312,8 +316,8 @@ "get", [](PyMlirContext &context, std::string value) { MlirAttribute attr = - mlirStringAttrGet(context.context, value.size(), &value[0]); - return PyStringAttribute(attr); + mlirStringAttrGet(context.get(), value.size(), value.data()); + return PyStringAttribute(context.getRef(), attr); }, "Gets a uniqued string attribute");
c.def_static( @@ -358,8 +362,11 @@ using IsAFunctionTy = int (*)(MlirType); PyConcreteType() = default; + PyConcreteType(PyMlirContextRef contextRef, MlirType t) + : BaseTy(std::move(contextRef), t) {} PyConcreteType(MlirType t) : BaseTy(t) {} - PyConcreteType(PyType &orig) : PyConcreteType(castFrom(orig)) {} + PyConcreteType(PyType &orig) + : PyConcreteType(orig.getContext(), castFrom(orig)) {} static MlirType castFrom(PyType &orig) { if (!DerivedTy::isaFunction(orig.type)) { @@ -391,22 +398,22 @@ c.def_static( "get_signless", [](PyMlirContext &context, unsigned width) { - MlirType t = mlirIntegerTypeGet(context.context, width); - return PyIntegerType(t); + MlirType t = mlirIntegerTypeGet(context.get(), width); + return PyIntegerType(context.getRef(), t); }, "Create a signless integer type"); c.def_static( "get_signed", [](PyMlirContext &context, unsigned width) { - MlirType t = mlirIntegerTypeSignedGet(context.context, width); - return PyIntegerType(t); + MlirType t = mlirIntegerTypeSignedGet(context.get(), width); + return PyIntegerType(context.getRef(), t); }, "Create a signed integer type"); c.def_static( "get_unsigned", [](PyMlirContext &context, unsigned width) { - MlirType t = mlirIntegerTypeUnsignedGet(context.context, width); - return PyIntegerType(t); + MlirType t = mlirIntegerTypeUnsignedGet(context.get(), width); + return PyIntegerType(context.getRef(), t); }, "Create an unsigned integer type"); c.def_property_readonly( @@ -443,8 +450,8 @@ static void bindDerived(ClassTy &c) { c.def(py::init([](PyMlirContext &context) { - MlirType t = mlirIndexTypeGet(context.context); - return PyIndexType(t); + MlirType t = mlirIndexTypeGet(context.get()); + return PyIndexType(context.getRef(), t); }), "Create a index type."); } @@ -459,8 +466,8 @@ static void bindDerived(ClassTy &c) { c.def(py::init([](PyMlirContext &context) { - MlirType t = mlirBF16TypeGet(context.context); - return PyBF16Type(t); + MlirType t = mlirBF16TypeGet(context.get()); + return
PyBF16Type(context.getRef(), t); }), "Create a bf16 type."); } @@ -475,8 +482,8 @@ static void bindDerived(ClassTy &c) { c.def(py::init([](PyMlirContext &context) { - MlirType t = mlirF16TypeGet(context.context); - return PyF16Type(t); + MlirType t = mlirF16TypeGet(context.get()); + return PyF16Type(context.getRef(), t); }), "Create a f16 type."); } @@ -491,8 +498,8 @@ static void bindDerived(ClassTy &c) { c.def(py::init([](PyMlirContext &context) { - MlirType t = mlirF32TypeGet(context.context); - return PyF32Type(t); + MlirType t = mlirF32TypeGet(context.get()); + return PyF32Type(context.getRef(), t); }), "Create a f32 type."); } @@ -507,8 +514,8 @@ static void bindDerived(ClassTy &c) { c.def(py::init([](PyMlirContext &context) { - MlirType t = mlirF64TypeGet(context.context); - return PyF64Type(t); + MlirType t = mlirF64TypeGet(context.get()); + return PyF64Type(context.getRef(), t); }), "Create a f64 type."); } @@ -523,8 +530,8 @@ static void bindDerived(ClassTy &c) { c.def(py::init([](PyMlirContext &context) { - MlirType t = mlirNoneTypeGet(context.context); - return PyNoneType(t); + MlirType t = mlirNoneTypeGet(context.get()); + return PyNoneType(context.getRef(), t); }), "Create a none type."); } @@ -831,8 +838,8 @@ SmallVector<MlirType, 4> elements; for (auto element : elementList) elements.push_back(element.cast<PyType>().type); - MlirType t = mlirTupleTypeGet(context.context, num, elements.data()); - return PyTupleType(t); + MlirType t = mlirTupleTypeGet(context.get(), num, elements.data()); + return PyTupleType(context.getRef(), t); }, "Create a tuple type"); c.def( @@ -868,14 +875,13 @@ .def_static("_get_live_count", &PyMlirContext::getLiveCount) .def("_get_context_again", [](PyMlirContext &self) { - auto ref = PyMlirContext::forContext(self.context); + auto ref = PyMlirContext::forContext(self.get()); return ref.release(); }) .def( "parse_module", [](PyMlirContext &self, const std::string module) { - auto moduleRef = - mlirModuleCreateParse(self.context, module.c_str());
+ auto moduleRef = mlirModuleCreateParse(self.get(), module.c_str()); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. if (mlirModuleIsNull(moduleRef)) { @@ -883,14 +889,14 @@ PyExc_ValueError, "Unable to parse module assembly (see diagnostics)"); } - return PyModule(moduleRef); + return PyModule(self.getRef(), moduleRef); }, kContextParseDocstring) .def( "parse_attr", [](PyMlirContext &self, std::string attrSpec) { MlirAttribute type = - mlirAttributeParseGet(self.context, attrSpec.c_str()); + mlirAttributeParseGet(self.get(), attrSpec.c_str()); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. if (mlirAttributeIsNull(type)) { @@ -904,7 +910,7 @@ .def( "parse_type", [](PyMlirContext &self, std::string typeSpec) { - MlirType type = mlirTypeParseGet(self.context, typeSpec.c_str()); + MlirType type = mlirTypeParseGet(self.get(), typeSpec.c_str()); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. if (mlirTypeIsNull(type)) { @@ -918,14 +924,16 @@ .def( "get_unknown_location", [](PyMlirContext &self) { - return PyLocation(mlirLocationUnknownGet(self.context)); + return PyLocation(self.getRef(), + mlirLocationUnknownGet(self.get())); }, kContextGetUnknownLocationDocstring) .def( "get_file_location", [](PyMlirContext &self, std::string filename, int line, int col) { - return PyLocation(mlirLocationFileLineColGet( - self.context, filename.c_str(), line, col)); + return PyLocation(self.getRef(), + mlirLocationFileLineColGet( + self.get(), filename.c_str(), line, col)); }, kContextGetFileLocationDocstring, py::arg("filename"), py::arg("line"), py::arg("col")) @@ -935,7 +943,7 @@ // The creating context is explicitly captured on regions to // facilitate illegal assemblies of objects from multiple contexts // that would invalidate the memory model.
- return PyRegion(self.context, mlirRegionCreate(), + return PyRegion(self.get(), mlirRegionCreate(), /*detached=*/true); }, py::keep_alive<0, 1>(), kContextCreateRegionDocstring) @@ -946,7 +954,7 @@ // types must be from the same context. - for (auto pyType : pyTypes) { + for (const auto &pyType : pyTypes) { if (!mlirContextEqual(mlirTypeGetContext(pyType.type), - self.context)) { + self.get())) { throw SetPyError( PyExc_ValueError, "All types used to construct a block must be from " @@ -955,8 +963,7 @@ } llvm::SmallVector<MlirType, 4> types(pyTypes.begin(), pyTypes.end()); - return PyBlock(self.context, - mlirBlockCreate(types.size(), &types[0]), - /*detached=*/true); + MlirBlock block = mlirBlockCreate(types.size(), types.data()); + return PyBlock(self.get(), block, /*detached=*/true); }, py::keep_alive<0, 1>(), kContextCreateBlockDocstring);