diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h --- a/mlir/lib/Bindings/Python/IRModules.h +++ b/mlir/lib/Bindings/Python/IRModules.h @@ -19,32 +19,61 @@ class PyMlirContext; class PyModule; +class PyOperation; -/// Holds a C++ PyMlirContext and associated py::object, making it convenient -/// to have an auto-releasing C++-side keep-alive reference to the context. -/// The reference to the PyMlirContext is a simple C++ reference and the -/// py::object holds the reference count which keeps it alive. -class PyMlirContextRef { +/// Template for a reference to a concrete type which captures a python +/// reference to its underlying python object. +template <typename T> +class PyObjectRef { public: - PyMlirContextRef(PyMlirContext &referrent, pybind11::object object) - : referrent(referrent), object(std::move(object)) {} - ~PyMlirContextRef() {} + PyObjectRef() : referrent(nullptr) {} + PyObjectRef(T *referrent, pybind11::object object) + : referrent(referrent), object(std::move(object)) { + assert(this->referrent && + "cannot construct PyObjectRef with null referrent"); + assert(this->object && "cannot construct PyObjectRef with null object"); + } + PyObjectRef(PyObjectRef &&other) + : referrent(other.referrent), object(std::move(other.object)) { + other.referrent = nullptr; + assert(!other.object); + } + PyObjectRef(const PyObjectRef &other) + : referrent(other.referrent), object(other.object /* copies */) {} + ~PyObjectRef() {} + + int getRefCount() { + if (!object) + return 0; + return object.ref_count(); + } - /// Releases the object held by this instance, causing its reference count - /// to remain artifically inflated by one. This must be used to return - /// the referenced PyMlirContext from a function. Otherwise, the destructor - /// of this reference would be called prior to the default take_ownership - /// policy assuming that the reference count has been transferred to it.
- PyMlirContext *release(); + /// Releases the object held by this instance, returning it. + /// This is the proper thing to return from a function that wants to return + /// the reference. Note that this does not work from initializers. + pybind11::object releaseObject() { + assert(referrent && object); + referrent = nullptr; + return std::move(object); + } - PyMlirContext &operator->() { return referrent; } - pybind11::object getObject() { return object; } + T *operator->() { + assert(referrent && object); + return referrent; + } + pybind11::object getObject() { + assert(referrent && object); + return object; + } + operator bool() const { return referrent && object; } private: - PyMlirContext &referrent; + T *referrent; pybind11::object object; }; +using PyMlirContextRef = PyObjectRef<PyMlirContext>; + /// Wrapper around MlirContext. class PyMlirContext { public: @@ -52,6 +81,16 @@ PyMlirContext(const PyMlirContext &) = delete; PyMlirContext(PyMlirContext &&) = delete; + /// For the case of a python __init__ (py::init) method, pybind11 is quite + /// strict about needing to return a pointer that is not yet associated to + /// a py::object. Since the forContext() method acts like a pool, possibly + /// returning a recycled context, it does not satisfy this need. The usual + /// way in python to accomplish such a thing is to override __new__, but + /// that is also not supported by pybind11. Instead, we use this entry + /// point which always constructs a fresh context (which cannot alias an + /// existing one because it is fresh). + static PyMlirContext *createNewContextForInit(); + /// Returns a context reference for the singleton PyMlirContext wrapper for /// the given context. static PyMlirContextRef forContext(MlirContext context); @@ -63,29 +102,37 @@ /// Gets a strong reference to this context, which will ensure it is kept /// alive for the life of the reference.
+ PyMlirContextRef getRef() { - return PyMlirContextRef( - *this, pybind11::reinterpret_borrow<pybind11::object>(handle)); + return PyMlirContextRef(this, pybind11::cast(this)); } /// Gets the count of live context objects. Used for testing. static size_t getLiveCount(); + /// Gets the count of live operations associated with this context. + /// Used for testing. + size_t getLiveOperationCount(); + private: PyMlirContext(MlirContext context); - // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, // preserving the relationship that an MlirContext maps to a single // PyMlirContext wrapper. This could be replaced in the future with an // extension mechanism on the MlirContext for stashing user pointers. // Note that this holds a handle, which does not imply ownership. // Mappings will be removed when the context is destructed. - using LiveContextMap = - llvm::DenseMap<void *, std::pair<pybind11::handle, PyMlirContext *>>; + using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>; static LiveContextMap &getLiveContexts(); + // Interns all live operations associated with this context. Operations + // tracked in this map are valid. When an operation is invalidated, it is + // removed from this map, and while it still exists as an instance, any + // attempt to access it will raise an error. + using LiveOperationMap = + llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>; + LiveOperationMap liveOperations; + MlirContext context; - // The handle is set as part of lookup with forContext() (post construction). - pybind11::handle handle; + friend class PyOperation; }; /// Base class for all objects that directly or indirectly depend on an @@ -94,7 +141,10 @@ /// Immutable objects that depend on a context extend this directly. class BaseContextObject { public: - BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) {} + BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) { + assert(this->contextRef && + "context object constructed with null context ref"); + } /// Accesses the context reference.
PyMlirContextRef &getContext() { return contextRef; } @@ -112,22 +162,88 @@ }; /// Wrapper around MlirModule. +/// This is the top-level, user-owned object that contains regions/ops/blocks. +class PyModule; +using PyModuleRef = PyObjectRef<PyModule>; class PyModule : public BaseContextObject { public: - PyModule(PyMlirContextRef contextRef, MlirModule module) - : BaseContextObject(std::move(contextRef)), module(module) {} + /// Creates a reference to the module. + static PyModuleRef create(PyMlirContextRef contextRef, MlirModule module); PyModule(PyModule &) = delete; - PyModule(PyModule &&other) - : BaseContextObject(std::move(other.getContext())) { - module = other.module; - other.module.ptr = nullptr; - } ~PyModule() { if (module.ptr) mlirModuleDestroy(module); } + /// Gets the backing MlirModule. + MlirModule get() { return module; } + + /// Gets a strong reference to this module. + PyModuleRef getRef() { + return PyModuleRef(this, + pybind11::reinterpret_borrow<pybind11::object>(handle)); + } + +private: + PyModule(PyMlirContextRef contextRef, MlirModule module) + : BaseContextObject(std::move(contextRef)), module(module) {} MlirModule module; + pybind11::handle handle; +}; + +/// Wrapper around MlirOperation. +/// Operations exist in either an attached or detached state. In the detached +/// state (as on creation), an operation is owned by the creator and its +/// lifetime extends either until its reference count drops to zero or it +/// being attached to a parent. +/// When attached to a parent, the operation will capture a reference to its +/// parent PyOperation. Attached operations are valid until they are removed +/// by their parent or a bulk IR modification takes place. +class PyOperation; +using PyOperationRef = PyObjectRef<PyOperation>; +class PyOperation : public BaseContextObject { +public: + ~PyOperation(); + /// Returns a PyOperation for the given MlirOperation, optionally associating + /// it with a parentKeepAlive (which must match on all such calls for the + /// same operation).
+ static PyOperationRef + forOperation(PyMlirContextRef contextRef, MlirOperation operation, + pybind11::object parentKeepAlive = pybind11::object()); + + /// Creates a detached operation. The operation must not be associated with + /// any existing live operation. + static PyOperationRef + createDetached(PyMlirContextRef contextRef, MlirOperation operation, + pybind11::object parentKeepAlive = pybind11::object()); + + /// Gets the backing operation. + MlirOperation get() { + checkValid(); + return operation; + } + + PyOperationRef getRef() { + return PyOperationRef( + this, pybind11::reinterpret_borrow<pybind11::object>(handle)); + } + + bool isAttached() { return attached; } + void checkValid(); + +private: + PyOperation(PyMlirContextRef contextRef, MlirOperation operation); + static PyOperationRef createInstance(PyMlirContextRef contextRef, + MlirOperation operation, + pybind11::object parentKeepAlive); + + MlirOperation operation; + pybind11::handle handle; + // Keeps the parent alive, regardless of whether it is an Operation or + // Module. + pybind11::object parentKeepAlive; + bool attached = true; + bool valid = true; }; /// Wrapper around an MlirRegion.
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -174,13 +174,12 @@ // PyMlirContext //------------------------------------------------------------------------------ -PyMlirContext *PyMlirContextRef::release() { - object.release(); - return &referrent; +PyMlirContext::PyMlirContext(MlirContext context) : context(context) { + py::gil_scoped_acquire acquire; + auto &liveContexts = getLiveContexts(); + liveContexts[context.ptr] = this; } -PyMlirContext::PyMlirContext(MlirContext context) : context(context) {} - PyMlirContext::~PyMlirContext() { // Note that the only public way to construct an instance is via the // forContext method, which always puts the associated handle into @@ -190,6 +189,11 @@ mlirContextDestroy(context); } +PyMlirContext *PyMlirContext::createNewContextForInit() { + MlirContext context = mlirContextCreate(); + return new PyMlirContext(context); +} + PyMlirContextRef PyMlirContext::forContext(MlirContext context) { py::gil_scoped_acquire acquire; auto &liveContexts = getLiveContexts(); @@ -198,14 +202,13 @@ // Create. PyMlirContext *unownedContextWrapper = new PyMlirContext(context); py::object pyRef = py::cast(unownedContextWrapper); - unownedContextWrapper->handle = pyRef; - liveContexts[context.ptr] = std::make_pair(pyRef, unownedContextWrapper); - return PyMlirContextRef(*unownedContextWrapper, std::move(pyRef)); - } else { - // Use existing. - py::object pyRef = py::reinterpret_borrow(it->second.first); - return PyMlirContextRef(*it->second.second, std::move(pyRef)); + assert(pyRef && "cast to py::object failed"); + liveContexts[context.ptr] = unownedContextWrapper; + return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); } + // Use existing. 
+ py::object pyRef = py::cast(it->second); + return PyMlirContextRef(it->second, std::move(pyRef)); } PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { @@ -215,8 +218,99 @@ size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } +size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } + +//------------------------------------------------------------------------------ +// PyModule +//------------------------------------------------------------------------------ + +PyModuleRef PyModule::create(PyMlirContextRef contextRef, MlirModule module) { + PyModule *unownedModule = new PyModule(std::move(contextRef), module); + // Note that the default return value policy on cast is automatic_reference, + // which does not take ownership (delete will not be called). + // Just be explicit. + py::object pyRef = + py::cast(unownedModule, py::return_value_policy::take_ownership); + unownedModule->handle = pyRef; + return PyModuleRef(unownedModule, std::move(pyRef)); +} + +//------------------------------------------------------------------------------ +// PyOperation +//------------------------------------------------------------------------------ + +PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) + : BaseContextObject(std::move(contextRef)), operation(operation) {} + +PyOperation::~PyOperation() { + auto &liveOperations = getContext()->liveOperations; + assert(liveOperations.count(operation.ptr) == 1 && + "destroying operation not in live map"); + liveOperations.erase(operation.ptr); + if (!isAttached()) { + mlirOperationDestroy(operation); + } +} + +PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, + MlirOperation operation, + py::object parentKeepAlive) { + auto &liveOperations = contextRef->liveOperations; + // Create. 
+ PyOperation *unownedOperation = + new PyOperation(std::move(contextRef), operation); + // Note that the default return value policy on cast is automatic_reference, + // which does not take ownership (delete will not be called). + // Just be explicit. + py::object pyRef = + py::cast(unownedOperation, py::return_value_policy::take_ownership); + unownedOperation->handle = pyRef; + if (parentKeepAlive) { + unownedOperation->parentKeepAlive = std::move(parentKeepAlive); + } + liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); + return PyOperationRef(unownedOperation, std::move(pyRef)); +} + +PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, + MlirOperation operation, + py::object parentKeepAlive) { + auto &liveOperations = contextRef->liveOperations; + auto it = liveOperations.find(operation.ptr); + if (it == liveOperations.end()) { + // Create. + return createInstance(std::move(contextRef), operation, + std::move(parentKeepAlive)); + } + // Use existing. + PyOperation *existing = it->second.second; + assert(existing->parentKeepAlive.is(parentKeepAlive)); + py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); + return PyOperationRef(existing, std::move(pyRef)); +} + +PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, + MlirOperation operation, + py::object parentKeepAlive) { + auto &liveOperations = contextRef->liveOperations; + assert(liveOperations.count(operation.ptr) == 0 && + "cannot create detached operation that already exists"); + (void)liveOperations; + + PyOperationRef created = createInstance(std::move(contextRef), operation, + std::move(parentKeepAlive)); + created->attached = false; + return created; +} + +void PyOperation::checkValid() { + if (!valid) { + throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); + } +} + //------------------------------------------------------------------------------ -// PyBlock, PyRegion, and PyOperation. +// PyBlock, PyRegion.
//------------------------------------------------------------------------------ void PyRegion::attachToParent() { @@ -865,29 +959,27 @@ void mlir::python::populateIRSubmodule(py::module &m) { // Mapping of MlirContext py::class_<PyMlirContext>(m, "Context") - .def(py::init<>([]() { - MlirContext context = mlirContextCreate(); - auto contextRef = PyMlirContext::forContext(context); - return contextRef.release(); - })) + .def(py::init<>(&PyMlirContext::createNewContextForInit)) .def_static("_get_live_count", &PyMlirContext::getLiveCount) .def("_get_context_again", [](PyMlirContext &self) { - auto ref = PyMlirContext::forContext(self.get()); - return ref.release(); + PyMlirContextRef ref = PyMlirContext::forContext(self.get()); + return ref.releaseObject(); }) + .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) .def( "parse_module", - [](PyMlirContext &self, const std::string module) { - auto moduleRef = mlirModuleCreateParse(self.get(), module.c_str()); + [](PyMlirContext &self, const std::string &moduleAsm) { + MlirModule module = + mlirModuleCreateParse(self.get(), moduleAsm.c_str()); // TODO: Rework error reporting once diagnostic engine is exposed // in C API.
- if (mlirModuleIsNull(moduleRef)) { + if (mlirModuleIsNull(module)) { throw SetPyError( PyExc_ValueError, "Unable to parse module assembly (see diagnostics)"); } - return PyModule(self.getRef(), moduleRef); + return PyModule::create(self.getRef(), module).releaseObject(); }, kContextParseDocstring) .def( @@ -975,16 +1067,25 @@ // Mapping of Module py::class_<PyModule>(m, "Module") + .def_property_readonly( + "operation", + [](PyModule &self) { + return PyOperation::forOperation(self.getContext(), + mlirModuleGetOperation(self.get()), + self.getRef().releaseObject()) + .releaseObject(); + }, + "Accesses the module as an operation") .def( "dump", [](PyModule &self) { - mlirOperationDump(mlirModuleGetOperation(self.module)); + mlirOperationDump(mlirModuleGetOperation(self.get())); }, kDumpDocstring) .def( "__str__", [](PyModule &self) { - auto operation = mlirModuleGetOperation(self.module); + MlirOperation operation = mlirModuleGetOperation(self.get()); PyPrintAccumulator printAccum; mlirOperationPrint(operation, printAccum.getCallback(), printAccum.getUserData()); @@ -992,6 +1093,31 @@ }, kOperationStrDunderDocstring); + // Mapping of Operation. + py::class_<PyOperation>(m, "Operation") + .def_property_readonly( + "first_region", + [](PyOperation &self) { + self.checkValid(); + if (mlirOperationGetNumRegions(self.get()) == 0) { + throw SetPyError(PyExc_IndexError, "Operation has no regions"); + } + return PyRegion(self.getContext()->get(), + mlirOperationGetRegion(self.get(), 0), + /*detached=*/false); + }, + py::keep_alive<0, 1>(), "Gets the operation's first region") + .def( + "__str__", + [](PyOperation &self) { + self.checkValid(); + PyPrintAccumulator printAccum; + mlirOperationPrint(self.get(), printAccum.getCallback(), + printAccum.getUserData()); + return printAccum.join(); + }, + kOperationStrDunderDocstring); + // Mapping of PyRegion.
py::class_(m, "Region") .def( diff --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py --- a/mlir/test/Bindings/Python/ir_attributes.py +++ b/mlir/test/Bindings/Python/ir_attributes.py @@ -1,16 +1,21 @@ # RUN: %PYTHON %s | FileCheck %s +import gc import mlir def run(f): print("\nTEST:", f.__name__) f() + gc.collect() + assert mlir.ir.Context._get_live_count() == 0 # CHECK-LABEL: TEST: testParsePrint def testParsePrint(): ctx = mlir.ir.Context() t = ctx.parse_attr('"hello"') + ctx = None + gc.collect() # CHECK: "hello" print(str(t)) # CHECK: Attribute("hello") diff --git a/mlir/test/Bindings/Python/ir_location.py b/mlir/test/Bindings/Python/ir_location.py --- a/mlir/test/Bindings/Python/ir_location.py +++ b/mlir/test/Bindings/Python/ir_location.py @@ -1,15 +1,21 @@ # RUN: %PYTHON %s | FileCheck %s +import gc import mlir def run(f): print("\nTEST:", f.__name__) f() + gc.collect() + assert mlir.ir.Context._get_live_count() == 0 + # CHECK-LABEL: TEST: testUnknown def testUnknown(): ctx = mlir.ir.Context() loc = ctx.get_unknown_location() + ctx = None + gc.collect() # CHECK: unknown str: loc(unknown) print("unknown str:", str(loc)) # CHECK: unknown repr: loc(unknown) @@ -22,6 +28,8 @@ def testFileLineCol(): ctx = mlir.ir.Context() loc = ctx.get_file_location("foo.txt", 123, 56) + ctx = None + gc.collect() # CHECK: file str: loc("foo.txt":123:56) print("file str:", str(loc)) # CHECK: file repr: loc("foo.txt":123:56) diff --git a/mlir/test/Bindings/Python/ir_module.py b/mlir/test/Bindings/Python/ir_module.py --- a/mlir/test/Bindings/Python/ir_module.py +++ b/mlir/test/Bindings/Python/ir_module.py @@ -1,10 +1,14 @@ # RUN: %PYTHON %s | FileCheck %s +import gc import mlir def run(f): print("\nTEST:", f.__name__) f() + gc.collect() + assert mlir.ir.Context._get_live_count() == 0 + # Verify successful parse. 
# CHECK-LABEL: TEST: testParseSuccess @@ -12,6 +16,9 @@ def testParseSuccess(): ctx = mlir.ir.Context() module = ctx.parse_module(r"""module @successfulParse {}""") + print("CLEAR CONTEXT") + ctx = None # Ensure that module captures the context. + gc.collect() module.dump() # Just outputs to stderr. Verifies that it functions. print(str(module)) @@ -47,3 +54,33 @@ print(str(module)) run(testRoundtripUnicode) + + +# Tests that module.operation works and correctly interns instances. +# CHECK-LABEL: TEST: testModuleOperation +def testModuleOperation(): + ctx = mlir.ir.Context() + module = ctx.parse_module(r"""module @successfulParse {}""") + op1 = module.operation + assert ctx._get_live_operation_count() == 1 + # CHECK: module @successfulParse + print(op1) + + # Ensure that operations are the same on multiple calls. + op2 = module.operation + assert ctx._get_live_operation_count() == 1 + assert op1 is op2 + + # Ensure that if module is de-referenced, the operations are still valid. + module = None + gc.collect() + print(op1) + + # Collect and verify lifetime. 
+ op1 = None + op2 = None + gc.collect() + print("LIVE OPERATIONS:", ctx._get_live_operation_count()) + assert ctx._get_live_operation_count() == 0 + +run(testModuleOperation) diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py --- a/mlir/test/Bindings/Python/ir_operation.py +++ b/mlir/test/Bindings/Python/ir_operation.py @@ -1,10 +1,13 @@ # RUN: %PYTHON %s | FileCheck %s +import gc import mlir def run(f): print("\nTEST:", f.__name__) f() + gc.collect() + assert mlir.ir.Context._get_live_count() == 0 # CHECK-LABEL: TEST: testDetachedRegionBlock diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py --- a/mlir/test/Bindings/Python/ir_types.py +++ b/mlir/test/Bindings/Python/ir_types.py @@ -1,16 +1,21 @@ # RUN: %PYTHON %s | FileCheck %s +import gc import mlir def run(f): print("\nTEST:", f.__name__) f() + gc.collect() + assert mlir.ir.Context._get_live_count() == 0 # CHECK-LABEL: TEST: testParsePrint def testParsePrint(): ctx = mlir.ir.Context() t = ctx.parse_type("i32") + ctx = None + gc.collect() # CHECK: i32 print(str(t)) # CHECK: Type(i32)