diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -42,13 +42,17 @@ dialectSearchPrefixes.swap(newValues); } + /// Clears positive and negative caches regarding what implementations are + /// available. Future lookups will do more expensive existence checks. + void clearImportCache(); + /// Loads a python module corresponding to the given dialect namespace. /// No-ops if the module has already been loaded or is not found. Raises /// an error on any evaluation issues. /// Note that this returns void because it is expected that the module /// contains calls to decorators and helpers that register the salient /// entities. - void loadDialectModule(const std::string &dialectNamespace); + void loadDialectModule(llvm::StringRef dialectNamespace); /// Decorator for registering a custom Dialect class. The class object must /// have a DIALECT_NAMESPACE attribute. @@ -65,27 +69,39 @@ /// This is intended to be called by implementation code. void registerOperationImpl(const std::string &operationName, pybind11::object pyClass, - pybind11::object rawClass); + pybind11::object rawOpViewClass); /// Looks up a registered dialect class by namespace. Note that this may /// trigger loading of the defining module and can arbitrarily re-enter. llvm::Optional lookupDialectClass(const std::string &dialectNamespace); + /// Looks up a registered raw OpView class by operation name. Note that this + /// may trigger a load of the dialect, which can arbitrarily re-enter. + llvm::Optional + lookupRawOpViewClass(llvm::StringRef operationName); + private: static PyGlobals *instance; /// Module name prefixes to search under for dialect implementation modules. std::vector dialectSearchPrefixes; - /// Map of dialect namespace to bool flag indicating whether the module has - /// been successfully loaded or resolved to not found.
- llvm::StringSet<> loadedDialectModules; /// Map of dialect namespace to external dialect class object. llvm::StringMap dialectClassMap; /// Map of full operation name to external operation class object. llvm::StringMap operationClassMap; /// Map of operation name to custom subclass that directly initializes /// the OpView base class (bypassing the user class constructor). - llvm::StringMap rawOperationClassMap; + llvm::StringMap rawOpViewClassMap; + + /// Set of dialect namespaces that we have attempted to import implementation + /// modules for. + llvm::StringSet<> loadedDialectModulesCache; + /// Cache of operation name to custom OpView subclass that directly + /// initializes the OpView base class (or None for a negative + /// lookup). This is maintained on lookup as a shadow of rawOpViewClassMap + /// in order for repeat lookups of the OpView classes to only incur the cost + /// of one hashtable lookup. + llvm::StringMap rawOpViewClassMapCache; }; } // namespace python diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h --- a/mlir/lib/Bindings/Python/IRModules.h +++ b/mlir/lib/Bindings/Python/IRModules.h @@ -366,6 +366,24 @@ pybind11::handle handle; }; +/// Base class for PyOperation and PyOpView which exposes the primary, user +/// visible methods for manipulating it. +class PyOperationBase { +public: + virtual ~PyOperationBase() = default; + /// Implements the bound 'print' method and helps with others. + void print(pybind11::object fileObject, bool binary, + llvm::Optional largeElementsLimit, bool enableDebugInfo, + bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope); + pybind11::object getAsm(bool binary, + llvm::Optional largeElementsLimit, + bool enableDebugInfo, bool prettyDebugInfo, + bool printGenericOpForm, bool useLocalScope); + + /// Each must provide access to the raw Operation. + virtual PyOperation &getOperation() = 0; +}; + /// Wrapper around PyOperation.
/// Operations exist in either an attached (dependent) or detached (top-level) /// state. In the detached state (as on creation), an operation is owned by @@ -374,9 +392,11 @@ /// is bounded by its top-level parent reference. class PyOperation; using PyOperationRef = PyObjectRef; -class PyOperation : public BaseContextObject { +class PyOperation : public PyOperationBase, public BaseContextObject { public: ~PyOperation(); + PyOperation &getOperation() override { return *this; } + /// Returns a PyOperation for the given MlirOperation, optionally associating /// it with a parentKeepAlive. static PyOperationRef @@ -407,15 +427,6 @@ } void checkValid(); - /// Implements the bound 'print' method and helps with others. - void print(pybind11::object fileObject, bool binary, - llvm::Optional largeElementsLimit, bool enableDebugInfo, - bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope); - pybind11::object getAsm(bool binary, - llvm::Optional largeElementsLimit, - bool enableDebugInfo, bool prettyDebugInfo, - bool printGenericOpForm, bool useLocalScope); - /// Gets the owning block or raises an exception if the operation has no /// owning block. PyBlock getBlock(); @@ -432,6 +443,9 @@ llvm::Optional> successors, int regions, DefaultingPyLocation location, pybind11::object ip); + /// Creates an OpView suitable for this operation. + pybind11::object createOpView(); + private: PyOperation(PyMlirContextRef contextRef, MlirOperation operation); static PyOperationRef createInstance(PyMlirContextRef contextRef, @@ -456,17 +470,18 @@ /// custom ODS-style operation classes. Since this class is subclass on the /// python side, it must present an __init__ method that operates in pure /// python types.
-class PyOpView { +class PyOpView : public PyOperationBase { public: - PyOpView(pybind11::object operation); + PyOpView(pybind11::object operationObject); + PyOperation &getOperation() override { return operation; } static pybind11::object createRawSubclass(pybind11::object userClass); pybind11::object getOperationObject() { return operationObject; } private: + PyOperation &operation; // For efficient, cast-free access from C++ pybind11::object operationObject; // Holds the reference. - PyOperation *operation; // For efficient, cast-free access from C++ }; /// Wrapper around an MlirRegion. @@ -519,7 +534,7 @@ /// block, but still inside the block. PyInsertionPoint(PyBlock &block); /// Creates an insertion point positioned before a reference operation. - PyInsertionPoint(PyOperation &beforeOperation); + PyInsertionPoint(PyOperationBase &beforeOperationBase); /// Shortcut to create an insertion point at the beginning of the block. static PyInsertionPoint atBlockBegin(PyBlock &block); @@ -527,7 +542,7 @@ static PyInsertionPoint atBlockTerminator(PyBlock &block); /// Inserts an operation. - void insert(PyOperation &operation); + void insert(PyOperationBase &operationBase); /// Enter and exit the context manager. pybind11::object contextEnter(); @@ -540,10 +555,10 @@ // Trampoline constructor that avoids null initializing members while // looking up parents. PyInsertionPoint(PyBlock block, llvm::Optional refOperation) - : block(std::move(block)), refOperation(std::move(refOperation)) {} + : refOperation(std::move(refOperation)), block(std::move(block)) {} - PyBlock block; llvm::Optional refOperation; + PyBlock block; // After refOperation: a ctor initializes block from it. }; /// Wrapper around the generic MlirAttribute.
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -407,7 +407,7 @@ PyOperationRef returnOperation = PyOperation::forOperation(parentOperation->getContext(), next); next = mlirOperationGetNextInBlock(next); - return returnOperation.releaseObject(); + return returnOperation->createOpView(); } static void bind(py::module &m) { @@ -457,7 +457,7 @@ while (!mlirOperationIsNull(childOp)) { if (index == 0) { return PyOperation::forOperation(parentOperation->getContext(), childOp) - .releaseObject(); + ->createOpView(); } childOp = mlirOperationGetNextInBlock(childOp); index -= 1; @@ -868,11 +868,12 @@ } } -void PyOperation::print(py::object fileObject, bool binary, - llvm::Optional largeElementsLimit, - bool enableDebugInfo, bool prettyDebugInfo, - bool printGenericOpForm, bool useLocalScope) { - checkValid(); +void PyOperationBase::print(py::object fileObject, bool binary, + llvm::Optional largeElementsLimit, + bool enableDebugInfo, bool prettyDebugInfo, + bool printGenericOpForm, bool useLocalScope) { + PyOperation &operation = getOperation(); + operation.checkValid(); if (fileObject.is_none()) fileObject = py::module::import("sys").attr("stdout"); MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); @@ -885,15 +886,16 @@ PyFileAccumulator accum(fileObject, binary); py::gil_scoped_release(); - mlirOperationPrintWithFlags(get(), flags, accum.getCallback(), + mlirOperationPrintWithFlags(operation.get(), flags, accum.getCallback(), accum.getUserData()); mlirOpPrintingFlagsDestroy(flags); } -py::object PyOperation::getAsm(bool binary, - llvm::Optional largeElementsLimit, - bool enableDebugInfo, bool prettyDebugInfo, - bool printGenericOpForm, bool useLocalScope) { +py::object PyOperationBase::getAsm(bool binary, + llvm::Optional largeElementsLimit, + bool enableDebugInfo, bool prettyDebugInfo, + bool printGenericOpForm, + bool
useLocalScope) { py::object fileObject; if (binary) { fileObject = py::module::import("io").attr("BytesIO")(); @@ -1034,12 +1036,24 @@ ip->insert(*created.get()); } - return created.releaseObject(); + return created->createOpView(); } -PyOpView::PyOpView(py::object operation) - : operationObject(std::move(operation)), - operation(py::cast(this->operationObject)) {} +py::object PyOperation::createOpView() { + MlirIdentifier ident = mlirOperationGetName(get()); + MlirStringRef identStr = mlirIdentifierStr(ident); + auto opViewClass = PyGlobals::get().lookupRawOpViewClass( + llvm::StringRef(identStr.data, identStr.length)); + if (opViewClass) + return (*opViewClass)(getRef().getObject()); + return py::cast(PyOpView(getRef().getObject())); // Generic OpView fallback. +} + +PyOpView::PyOpView(py::object operationObject) + // Casting through the PyOperationBase base class and then back to the + // underlying PyOperation lets us accept any PyOperationBase subclass. + : operation(py::cast(operationObject).getOperation()), + operationObject(operation.getRef().getObject()) {} py::object PyOpView::createRawSubclass(py::object userClass) { // This is... a little gross.
The typical pattern is to have a pure python @@ -1082,11 +1096,12 @@ PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} -PyInsertionPoint::PyInsertionPoint(PyOperation &beforeOperation) - : block(beforeOperation.getBlock()), - refOperation(beforeOperation.getRef()) {} +PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) + : refOperation(beforeOperationBase.getOperation().getRef()), + block((*refOperation)->getBlock()) {} // refOperation is initialized first. -void PyInsertionPoint::insert(PyOperation &operation) { +void PyInsertionPoint::insert(PyOperationBase &operationBase) { + PyOperation &operation = operationBase.getOperation(); if (operation.isAttached()) throw SetPyError(PyExc_ValueError, "Attempt to insert operation that is already attached"); @@ -2501,33 +2516,36 @@ //---------------------------------------------------------------------------- // Mapping of Operation. //---------------------------------------------------------------------------- - py::class_(m, "Operation") - .def_static("create", &PyOperation::create, py::arg("name"), - py::arg("operands") = py::none(), - py::arg("results") = py::none(), - py::arg("attributes") = py::none(), - py::arg("successors") = py::none(), py::arg("regions") = 0, - py::arg("loc") = py::none(), py::arg("ip") = py::none(), - kOperationCreateDocstring) - .def_property_readonly( - "context", - [](PyOperation &self) { return self.getContext().getObject(); }, - "Context that owns the Operation") - .def_property_readonly( - "operands", - [](PyOperation &self) { return PyOpOperandList(self.getRef()); }) - .def_property_readonly( - "regions", - [](PyOperation &self) { return PyRegionList(self.getRef()); }) + py::class_(m, "_OperationBase") + .def("__eq__", + [](PyOperationBase &self, PyOperationBase &other) { + return &self.getOperation() == &other.getOperation(); + }) + .def("__eq__", // Fallback for non-operation arguments. + [](PyOperationBase &self, py::object other) { return false; }) + .def_property_readonly("operands", + [](PyOperationBase &self) { + return
PyOpOperandList( + self.getOperation().getRef()); + }) + .def_property_readonly("regions", + [](PyOperationBase &self) { + return PyRegionList( + self.getOperation().getRef()); + }) .def_property_readonly( "results", - [](PyOperation &self) { return PyOpResultList(self.getRef()); }, + [](PyOperationBase &self) { + return PyOpResultList(self.getOperation().getRef()); + }, "Returns the list of Operation results.") .def("__iter__", - [](PyOperation &self) { return PyRegionIterator(self.getRef()); }) + [](PyOperationBase &self) { + return PyRegionIterator(self.getOperation().getRef()); + }) .def( "__str__", - [](PyOperation &self) { + [](PyOperationBase &self) { return self.getAsm(/*binary=*/false, /*largeElementsLimit=*/llvm::None, /*enableDebugInfo=*/false, @@ -2536,7 +2554,7 @@ /*useLocalScope=*/false); }, "Returns the assembly form of the operation.") - .def("print", &PyOperation::print, + .def("print", &PyOperationBase::print, // Careful: Lots of arguments must match up with print method. py::arg("file") = py::none(), py::arg("binary") = false, py::arg("large_elements_limit") = py::none(), @@ -2544,7 +2562,7 @@ py::arg("pretty_debug_info") = false, py::arg("print_generic_op_form") = false, py::arg("use_local_scope") = false, kOperationPrintDocstring) - .def("get_asm", &PyOperation::getAsm, + .def("get_asm", &PyOperationBase::getAsm, // Careful: Lots of arguments must match up with get_asm method.
py::arg("binary") = false, py::arg("large_elements_limit") = py::none(), @@ -2553,9 +2571,29 @@ py::arg("print_generic_op_form") = false, py::arg("use_local_scope") = false, kOperationGetAsmDocstring); - py::class_(m, "OpView") + py::class_(m, "Operation") + .def_static("create", &PyOperation::create, py::arg("name"), + py::arg("operands") = py::none(), + py::arg("results") = py::none(), + py::arg("attributes") = py::none(), + py::arg("successors") = py::none(), py::arg("regions") = 0, + py::arg("loc") = py::none(), py::arg("ip") = py::none(), + kOperationCreateDocstring) + .def_property_readonly( + "context", + [](PyOperation &self) { return self.getContext().getObject(); }, + "Context that owns the Operation") + .def_property_readonly("opview", &PyOperation::createOpView); + + py::class_(m, "OpView") .def(py::init()) .def_property_readonly("operation", &PyOpView::getOperationObject) + .def_property_readonly( + "context", + [](PyOpView &self) { + return self.getOperation().getContext().getObject(); + }, + "Context that owns the Operation") .def("__str__", [](PyOpView &self) { return py::str(self.getOperationObject()); }); @@ -2577,14 +2615,11 @@ return PyBlockIterator(self.getParentOperation(), firstBlock); }, "Iterates over blocks in the region.") - .def("__eq__", [](PyRegion &self, py::object &other) { - try { - PyRegion *otherRegion = other.cast(); - return self.get().ptr == otherRegion->get().ptr; - } catch (std::exception &e) { - return false; - } - }); + .def("__eq__", + [](PyRegion &self, PyRegion &other) { + return self.get().ptr == other.get().ptr; + }) + .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); //---------------------------------------------------------------------------- // Mapping of PyBlock.
@@ -2613,14 +2648,10 @@ }, "Iterates over operations in the block.") .def("__eq__", - [](PyBlock &self, py::object &other) { - try { - PyBlock *otherBlock = other.cast(); - return self.get().ptr == otherBlock->get().ptr; - } catch (std::exception &e) { - return false; - } + [](PyBlock &self, PyBlock &other) { + return self.get().ptr == other.get().ptr; }) + .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) .def( "__str__", [](PyBlock &self) { @@ -2651,7 +2682,7 @@ }, "Gets the InsertionPoint bound to the current thread or raises " "ValueError if none has been set") - .def(py::init(), py::arg("beforeOperation"), + .def(py::init(), py::arg("beforeOperation"), "Inserts before a referenced operation.") .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, py::arg("block"), "Inserts at the beginning of the block.") @@ -2696,14 +2727,8 @@ }, py::keep_alive<0, 1>(), "Binds a name to the attribute") .def("__eq__", - [](PyAttribute &self, py::object &other) { - try { - PyAttribute otherAttribute = other.cast(); - return self == otherAttribute; - } catch (std::exception &e) { - return false; - } - }) + [](PyAttribute &self, PyAttribute &other) { return self == other; }) + .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) .def( "dump", [](PyAttribute &self) { mlirAttributeDump(self.attr); }, kDumpDocstring) @@ -2793,15 +2818,8 @@ .def_property_readonly( "context", [](PyType &self) { return self.getContext().getObject(); }, "Context that owns the Type") - .def("__eq__", - [](PyType &self, py::object &other) { - try { - PyType otherType = other.cast(); - return self == otherType; - } catch (std::exception &e) { - return false; - } - }) + .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) + .def("__eq__", [](PyType &self, py::object &other) { return false; }) .def( "dump", [](PyType &self) { mlirTypeDump(self.type); }, kDumpDocstring) .def( diff --git a/mlir/lib/Bindings/Python/MainModule.cpp
b/mlir/lib/Bindings/Python/MainModule.cpp --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -30,17 +30,18 @@ PyGlobals::~PyGlobals() { instance = nullptr; } -void PyGlobals::loadDialectModule(const std::string &dialectNamespace) { - if (loadedDialectModules.contains(dialectNamespace)) +void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { + py::gil_scoped_acquire acquire; + if (loadedDialectModulesCache.contains(dialectNamespace)) return; // Since re-entrancy is possible, make a copy of the search prefixes. std::vector localSearchPrefixes = dialectSearchPrefixes; py::object loaded; for (std::string moduleName : localSearchPrefixes) { moduleName.push_back('.'); - moduleName.append(dialectNamespace); + moduleName.append(dialectNamespace.data(), dialectNamespace.size()); try { loaded = py::module::import(moduleName.c_str()); } catch (py::error_already_set &e) { if (e.matches(PyExc_ModuleNotFoundError)) { @@ -54,11 +55,12 @@ // Note: Iterator cannot be shared from prior to loading, since re-entrancy // may have occurred, which may do anything.
- loadedDialectModules.insert(dialectNamespace); + loadedDialectModulesCache.insert(dialectNamespace); } void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, py::object pyClass) { + py::gil_scoped_acquire acquire; py::object &found = dialectClassMap[dialectNamespace]; if (found) { throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") + @@ -69,7 +71,9 @@ } void PyGlobals::registerOperationImpl(const std::string &operationName, - py::object pyClass, py::object rawClass) { + py::object pyClass, + py::object rawOpViewClass) { + py::gil_scoped_acquire acquire; py::object &found = operationClassMap[operationName]; if (found) { throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + @@ -77,11 +81,14 @@ "' is already registered."); } found = std::move(pyClass); - rawOperationClassMap[operationName] = std::move(rawClass); + rawOpViewClassMap[operationName] = std::move(rawOpViewClass); + // Drop any cached (possibly negative) lookup so this registration is seen. + rawOpViewClassMapCache.erase(operationName); } llvm::Optional PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { + py::gil_scoped_acquire acquire; loadDialectModule(dialectNamespace); // Fast match against the class map first (common case). const auto foundIt = dialectClassMap.find(dialectNamespace); @@ -97,6 +104,49 @@ return llvm::None; } +llvm::Optional +PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { + { + py::gil_scoped_acquire acquire; + auto foundIt = rawOpViewClassMapCache.find(operationName); + if (foundIt != rawOpViewClassMapCache.end()) { + if (foundIt->second.is_none()) + return llvm::None; + assert(foundIt->second && "py::object is defined"); + return foundIt->second; + } + } + + // Not found. Load the dialect namespace. + auto split = operationName.split('.'); + llvm::StringRef dialectNamespace = split.first; + loadDialectModule(dialectNamespace); + + // Attempt to find from the canonical map and cache.
+ { + py::gil_scoped_acquire acquire; + auto foundIt = rawOpViewClassMap.find(operationName); + if (foundIt != rawOpViewClassMap.end()) { + if (foundIt->second.is_none()) + return llvm::None; + assert(foundIt->second && "py::object is defined"); + // Positive cache. + rawOpViewClassMapCache[operationName] = foundIt->second; + return foundIt->second; + } else { + // Negative cache. + rawOpViewClassMapCache[operationName] = py::none(); + return llvm::None; + } + } +} + +void PyGlobals::clearImportCache() { + py::gil_scoped_acquire acquire; + loadedDialectModulesCache.clear(); + rawOpViewClassMapCache.clear(); +} + // ----------------------------------------------------------------------------- // Module initialization. // ----------------------------------------------------------------------------- @@ -111,6 +161,7 @@ .def("append_dialect_search_prefix", [](PyGlobals &self, std::string moduleName) { self.getDialectSearchPrefixes().push_back(std::move(moduleName)); + self.clearImportCache(); }) .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, "Testing hook for directly registering a dialect") diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py --- a/mlir/test/Bindings/Python/ir_operation.py +++ b/mlir/test/Bindings/Python/ir_operation.py @@ -293,3 +293,34 @@ pretty_debug_info=True, print_generic_op_form=True, use_local_scope=True) run(testOperationPrint) + + +def testKnownOpView(): + with Context(), Location.unknown(): + Context.current.allow_unregistered_dialects = True + module = Module.parse(r""" + %1 = "custom.f32"() : () -> f32 + %2 = "custom.f32"() : () -> f32 + %3 = addf %1, %2 : f32 + """) + print(module) + + # addf should map to a known OpView class in the std dialect. + # We know the OpView for it defines an 'lhs' attribute. + addf = module.body.operations[2] + # CHECK: