diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -74,8 +74,7 @@ /// Raises an exception if the mapping already exists. /// This is intended to be called by implementation code. void registerOperationImpl(const std::string &operationName, - pybind11::object pyClass, - pybind11::object rawOpViewClass); + pybind11::object pyClass); /// Returns the custom Attribute builder for Attribute kind. std::optional @@ -86,10 +85,11 @@ std::optional lookupDialectClass(const std::string &dialectNamespace); - /// Looks up a registered raw OpView class by operation name. Note that this - /// may trigger a load of the dialect, which can arbitrarily re-enter. + /// Looks up a registered operation class (deriving from OpView) by operation + /// name. Note that this may trigger a load of the dialect, which can + /// arbitrarily re-enter. std::optional - lookupRawOpViewClass(llvm::StringRef operationName); + lookupOperationClass(llvm::StringRef operationName); private: static PyGlobals *instance; @@ -99,21 +99,16 @@ llvm::StringMap dialectClassMap; /// Map of full operation name to external operation class object. llvm::StringMap operationClassMap; - /// Map of operation name to custom subclass that directly initializes - /// the OpView base class (bypassing the user class constructor). - llvm::StringMap rawOpViewClassMap; /// Map of attribute ODS name to custom builder. llvm::StringMap attributeBuilderMap; /// Set of dialect namespaces that we have attempted to import implementation /// modules for. llvm::StringSet<> loadedDialectModulesCache; - /// Cache of operation name to custom OpView subclass that directly - /// initializes the OpView base class (or an undefined object for negative - /// lookup). 
This is maintained on loopup as a shadow of rawOpViewClassMap - /// in order for repeat lookups of the OpView classes to only incur the cost - /// of one hashtable lookup. - llvm::StringMap rawOpViewClassMapCache; + /// Cache of operation name to external operation class object. This is + /// maintained on lookup as a shadow of operationClassMap in order for repeat + /// lookups of the classes to only incur the cost of one hashtable lookup. + llvm::StringMap operationClassMapCache; }; } // namespace python diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1339,10 +1339,10 @@ checkValid(); MlirIdentifier ident = mlirOperationGetName(get()); MlirStringRef identStr = mlirIdentifierStr(ident); - auto opViewClass = PyGlobals::get().lookupRawOpViewClass( + auto operationCls = PyGlobals::get().lookupOperationClass( StringRef(identStr.data, identStr.length)); - if (opViewClass) - return (*opViewClass)(getRef().getObject()); + if (operationCls) + return PyOpView::constructDerived(*operationCls, *getRef().get()); return py::cast(PyOpView(getRef().getObject())); } @@ -1618,47 +1618,23 @@ /*regions=*/*regions, location, maybeIp); } +pybind11::object PyOpView::constructDerived(const pybind11::object &cls, + const PyOperation &operation) { + // TODO: pybind11 2.6 supports a more direct form. + // Upgrade many years from now. + // auto opViewType = py::type::of(); + py::handle opViewType = py::detail::get_type_handle(typeid(PyOpView), true); + py::object instance = cls.attr("__new__")(cls); + opViewType.attr("__init__")(instance, operation); + return instance; +} + PyOpView::PyOpView(const py::object &operationObject) // Casting through the PyOperationBase base-class and then back to the // Operation lets us accept any PyOperationBase subclass. 
: operation(py::cast(operationObject).getOperation()), operationObject(operation.getRef().getObject()) {} -py::object PyOpView::createRawSubclass(const py::object &userClass) { - // This is... a little gross. The typical pattern is to have a pure python - // class that extends OpView like: - // class AddFOp(_cext.ir.OpView): - // def __init__(self, loc, lhs, rhs): - // operation = loc.context.create_operation( - // "addf", lhs, rhs, results=[lhs.type]) - // super().__init__(operation) - // - // I.e. The goal of the user facing type is to provide a nice constructor - // that has complete freedom for the op under construction. This is at odds - // with our other desire to sometimes create this object by just passing an - // operation (to initialize the base class). We could do *arg and **kwargs - // munging to try to make it work, but instead, we synthesize a new class - // on the fly which extends this user class (AddFOp in this example) and - // *give it* the base class's __init__ method, thus bypassing the - // intermediate subclass's __init__ method entirely. While slightly, - // underhanded, this is safe/legal because the type hierarchy has not changed - // (we just added a new leaf) and we aren't mucking around with __new__. - // Typically, this new class will be stored on the original as "_Raw" and will - // be used for casts and other things that need a variant of the class that - // is initialized purely from an operation. - py::object parentMetaclass = - py::reinterpret_borrow((PyObject *)&PyType_Type); - py::dict attributes; - // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from - // now. 
- // auto opViewType = py::type::of(); - auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); - attributes["__init__"] = opViewType.attr("__init__"); - py::str origName = userClass.attr("__name__"); - py::str newName = py::str("_") + origName; - return parentMetaclass(newName, py::make_tuple(userClass), attributes); -} - //------------------------------------------------------------------------------ // PyInsertionPoint. //------------------------------------------------------------------------------ @@ -2863,7 +2839,7 @@ throw py::value_error( "Expected a '" + clsOpName + "' op, got: '" + std::string(parsedOpName.data, parsedOpName.length) + "'"); - return cls.attr("_Raw")(parsed.getObject()); + return PyOpView::constructDerived(cls, *parsed.get()); }, py::arg("cls"), py::arg("source"), py::kw_only(), py::arg("source_name") = "", py::arg("context") = py::none(), diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -654,8 +654,6 @@ PyOpView(const pybind11::object &operationObject); PyOperation &getOperation() override { return operation; } - static pybind11::object createRawSubclass(const pybind11::object &userClass); - pybind11::object getOperationObject() { return operationObject; } static pybind11::object @@ -666,6 +664,16 @@ std::optional regions, DefaultingPyLocation location, const pybind11::object &maybeIp); + /// Construct an instance of a class deriving from OpView, bypassing its + /// `__init__` method. The derived class will typically define a constructor + /// that provides a convenient builder, but we need to side-step this when + /// constructing an `OpView` for an already-built operation. + /// + /// The caller is responsible for verifying that `operation` is a valid + /// operation to construct `cls` with. 
+ static pybind11::object constructDerived(const pybind11::object &cls, + const PyOperation &operation); + private: PyOperation &operation; // For efficient, cast-free access from C++ pybind11::object operationObject; // Holds the reference. diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -84,8 +84,7 @@ } void PyGlobals::registerOperationImpl(const std::string &operationName, - py::object pyClass, - py::object rawOpViewClass) { + py::object pyClass) { py::object &found = operationClassMap[operationName]; if (found) { throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + @@ -93,7 +92,6 @@ "' is already registered."); } found = std::move(pyClass); - rawOpViewClassMap[operationName] = std::move(rawOpViewClass); } std::optional @@ -130,10 +128,10 @@ } std::optional -PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { +PyGlobals::lookupOperationClass(llvm::StringRef operationName) { { - auto foundIt = rawOpViewClassMapCache.find(operationName); - if (foundIt != rawOpViewClassMapCache.end()) { + auto foundIt = operationClassMapCache.find(operationName); + if (foundIt != operationClassMapCache.end()) { if (foundIt->second.is_none()) return std::nullopt; assert(foundIt->second && "py::object is defined"); @@ -148,22 +146,22 @@ // Attempt to find from the canonical map and cache. { - auto foundIt = rawOpViewClassMap.find(operationName); - if (foundIt != rawOpViewClassMap.end()) { + auto foundIt = operationClassMap.find(operationName); + if (foundIt != operationClassMap.end()) { if (foundIt->second.is_none()) return std::nullopt; assert(foundIt->second && "py::object is defined"); // Positive cache. - rawOpViewClassMapCache[operationName] = foundIt->second; + operationClassMapCache[operationName] = foundIt->second; return foundIt->second; } // Negative cache. 
- rawOpViewClassMap[operationName] = py::none(); + operationClassMapCache[operationName] = py::none(); return std::nullopt; } } void PyGlobals::clearImportCache() { loadedDialectModulesCache.clear(); - rawOpViewClassMapCache.clear(); + operationClassMapCache.clear(); } diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -41,7 +41,6 @@ "Testing hook for directly registering a dialect") .def("_register_operation_impl", &PyGlobals::registerOperationImpl, py::arg("operation_name"), py::arg("operation_class"), - py::arg("raw_opview_class"), "Testing hook for directly registering an operation"); // Aside from making the globals accessible to python, having python manage @@ -68,18 +67,11 @@ [dialectClass](py::object opClass) -> py::object { std::string operationName = opClass.attr("OPERATION_NAME").cast(); - auto rawSubclass = PyOpView::createRawSubclass(opClass); - PyGlobals::get().registerOperationImpl(operationName, opClass, - rawSubclass); + PyGlobals::get().registerOperationImpl(operationName, opClass); // Dict-stuff the new opClass by name onto the dialect class. py::object opClassName = opClass.attr("__name__"); dialectClass.attr(opClassName) = opClass; - - // Now create a special "Raw" subclass that passes through - // construction to the OpView parent (bypasses the intermediate - // child's __init__). - opClass.attr("_Raw") = rawSubclass; return opClass; }); }, diff --git a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi --- a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi @@ -5,7 +5,7 @@ class _Globals: dialect_search_modules: List[str] def _register_dialect_impl(self, dialect_namespace: str, dialect_class: type) -> None: ...
- def _register_operation_impl(self, operation_name: str, operation_class: type, raw_opview_class: type) -> None: ... + def _register_operation_impl(self, operation_name: str, operation_class: type) -> None: ... def append_dialect_search_prefix(self, module_name: str) -> None: ... def register_dialect(dialect_class: type) -> object: ... diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -620,7 +620,7 @@ # addf should map to a known OpView class in the arithmetic dialect. # We know the OpView for it defines an 'lhs' attribute. addf = module.body.operations[2] - # CHECK: