diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -48,7 +48,7 @@ /// Note that this returns void because it is expected that the module /// contains calls to decorators and helpers that register the salient /// entities. - void loadDialectModule(const std::string &dialectNamespace); + void loadDialectModule(llvm::StringRef dialectNamespace); /// Decorator for registering a custom Dialect class. The class object must /// have a DIALECT_NAMESPACE attribute. @@ -65,13 +65,18 @@ /// This is intended to be called by implementation code. void registerOperationImpl(const std::string &operationName, pybind11::object pyClass, - pybind11::object rawClass); + pybind11::object rawOpViewClass); /// Looks up a registered dialect class by namespace. Note that this may /// trigger loading of the defining module and can arbitrarily re-enter. llvm::Optional lookupDialectClass(const std::string &dialectNamespace); + /// Looks up a registered raw OpView class by operation name. Note that this + /// may trigger a load of the dialect, which can arbitrarily re-enter. + llvm::Optional + lookupRawOpViewClass(llvm::StringRef operationName); + private: static PyGlobals *instance; /// Module name prefixes to search under for dialect implementation modules. @@ -85,7 +90,7 @@ llvm::StringMap operationClassMap; /// Map of operation name to custom subclass that directly initializes /// the OpView base class (bypassing the user class constructor). - llvm::StringMap rawOperationClassMap; + llvm::StringMap rawOpViewClassMap; }; } // namespace python diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h --- a/mlir/lib/Bindings/Python/IRModules.h +++ b/mlir/lib/Bindings/Python/IRModules.h @@ -366,6 +366,24 @@ pybind11::handle handle; }; +/// Base class for PyOperation and PyOpView which exposes the primary, user +/// visible methods for manipulating it. +class PyOperationBase { +public: + virtual ~PyOperationBase() = default; + /// Implements the bound 'print' method and helps with others. + void print(pybind11::object fileObject, bool binary, + llvm::Optional largeElementsLimit, bool enableDebugInfo, + bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope); + pybind11::object getAsm(bool binary, + llvm::Optional largeElementsLimit, + bool enableDebugInfo, bool prettyDebugInfo, + bool printGenericOpForm, bool useLocalScope); + + /// Each must provide access to the raw Operation. + virtual PyOperation &getOperation() = 0; +}; + /// Wrapper around PyOperation. /// Operations exist in either an attached (dependent) or detached (top-level) /// state. In the detached state (as on creation), an operation is owned by @@ -374,9 +392,11 @@ /// is bounded by its top-level parent reference. class PyOperation; using PyOperationRef = PyObjectRef; -class PyOperation : public BaseContextObject { +class PyOperation : public PyOperationBase, public BaseContextObject { public: ~PyOperation(); + PyOperation &getOperation() override { return *this; } + /// Returns a PyOperation for the given MlirOperation, optionally associating /// it with a parentKeepAlive. static PyOperationRef @@ -407,15 +427,6 @@ } void checkValid(); - /// Implements the bound 'print' method and helps with others. - void print(pybind11::object fileObject, bool binary, - llvm::Optional largeElementsLimit, bool enableDebugInfo, - bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope); - pybind11::object getAsm(bool binary, - llvm::Optional largeElementsLimit, - bool enableDebugInfo, bool prettyDebugInfo, - bool printGenericOpForm, bool useLocalScope); - /// Gets the owning block or raises an exception if the operation has no /// owning block. PyBlock getBlock(); @@ -432,6 +443,9 @@ llvm::Optional> successors, int regions, DefaultingPyLocation location, pybind11::object ip); + /// Creates an OpView suitable for this operation. + pybind11::object createOpView(); + private: PyOperation(PyMlirContextRef contextRef, MlirOperation operation); static PyOperationRef createInstance(PyMlirContextRef contextRef, @@ -456,17 +470,18 @@ /// custom ODS-style operation classes. Since this class is subclass on the /// python side, it must present an __init__ method that operates in pure /// python types. -class PyOpView { +class PyOpView : public PyOperationBase { public: - PyOpView(pybind11::object operation); + PyOpView(pybind11::object operationObject); + PyOperation &getOperation() override { return operation; } static pybind11::object createRawSubclass(pybind11::object userClass); pybind11::object getOperationObject() { return operationObject; } private: + PyOperation &operation; // For efficient, cast-free access from C++ pybind11::object operationObject; // Holds the reference. - PyOperation *operation; // For efficient, cast-free access from C++ }; /// Wrapper around an MlirRegion. @@ -519,7 +534,7 @@ /// block, but still inside the block. PyInsertionPoint(PyBlock &block); /// Creates an insertion point positioned before a reference operation. - PyInsertionPoint(PyOperation &beforeOperation); + PyInsertionPoint(PyOperationBase &beforeOperationBase); /// Shortcut to create an insertion point at the beginning of the block. static PyInsertionPoint atBlockBegin(PyBlock &block); @@ -527,7 +542,7 @@ static PyInsertionPoint atBlockTerminator(PyBlock &block); /// Inserts an operation. - void insert(PyOperation &operation); + void insert(PyOperationBase &operationBase); /// Enter and exit the context manager. pybind11::object contextEnter(); @@ -540,10 +555,10 @@ // Trampoline constructor that avoids null initializing members while // looking up parents. PyInsertionPoint(PyBlock block, llvm::Optional refOperation) - : block(std::move(block)), refOperation(std::move(refOperation)) {} + : refOperation(std::move(refOperation)), block(std::move(block)) {} - PyBlock block; llvm::Optional refOperation; + PyBlock block; }; /// Wrapper around the generic MlirAttribute. diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -407,7 +407,7 @@ PyOperationRef returnOperation = PyOperation::forOperation(parentOperation->getContext(), next); next = mlirOperationGetNextInBlock(next); - return returnOperation.releaseObject(); + return returnOperation->createOpView(); } static void bind(py::module &m) { @@ -457,7 +457,7 @@ while (!mlirOperationIsNull(childOp)) { if (index == 0) { return PyOperation::forOperation(parentOperation->getContext(), childOp) - .releaseObject(); + ->createOpView(); } childOp = mlirOperationGetNextInBlock(childOp); index -= 1; @@ -868,11 +868,12 @@ } } -void PyOperation::print(py::object fileObject, bool binary, - llvm::Optional largeElementsLimit, - bool enableDebugInfo, bool prettyDebugInfo, - bool printGenericOpForm, bool useLocalScope) { - checkValid(); +void PyOperationBase::print(py::object fileObject, bool binary, + llvm::Optional largeElementsLimit, + bool enableDebugInfo, bool prettyDebugInfo, + bool printGenericOpForm, bool useLocalScope) { + PyOperation &operation = getOperation(); + operation.checkValid(); if (fileObject.is_none()) fileObject = py::module::import("sys").attr("stdout"); MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); @@ -885,15 +886,16 @@ PyFileAccumulator accum(fileObject, binary); py::gil_scoped_release(); - mlirOperationPrintWithFlags(get(), flags, accum.getCallback(), + mlirOperationPrintWithFlags(operation.get(), flags, accum.getCallback(), accum.getUserData()); mlirOpPrintingFlagsDestroy(flags); } -py::object PyOperation::getAsm(bool binary, - llvm::Optional largeElementsLimit, - bool enableDebugInfo, bool prettyDebugInfo, - bool printGenericOpForm, bool useLocalScope) { +py::object PyOperationBase::getAsm(bool binary, + llvm::Optional largeElementsLimit, + bool enableDebugInfo, bool prettyDebugInfo, + bool printGenericOpForm, + bool useLocalScope) { py::object fileObject; if (binary) { fileObject = py::module::import("io").attr("BytesIO")(); @@ -1034,12 +1036,26 @@ ip->insert(*created.get()); } - return created.releaseObject(); + return created->createOpView(); } -PyOpView::PyOpView(py::object operation) - : operationObject(std::move(operation)), - operation(py::cast(this->operationObject)) {} +py::object PyOperation::createOpView() { + // TODO: There is various caching that could be done to reduce work. + // Investigate if view creation is ever a problem. + MlirIdentifier ident = mlirOperationGetName(get()); + MlirStringRef identStr = mlirIdentifierStr(ident); + auto opViewClass = PyGlobals::get().lookupRawOpViewClass( + llvm::StringRef(identStr.data, identStr.length)); + if (opViewClass) + return (*opViewClass)(getRef().getObject()); + return py::cast(PyOpView(getRef().getObject())); +} + +PyOpView::PyOpView(py::object operationObject) + // Casting through the PyOperationBase base-class and then back to the + // Operation lets us accept any PyOperationBase subclass. + : operation(py::cast(operationObject).getOperation()), + operationObject(operation.getRef().getObject()) {} py::object PyOpView::createRawSubclass(py::object userClass) { // This is... a little gross. The typical pattern is to have a pure python @@ -1082,11 +1098,12 @@ PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} -PyInsertionPoint::PyInsertionPoint(PyOperation &beforeOperation) - : block(beforeOperation.getBlock()), - refOperation(beforeOperation.getRef()) {} +PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) + : refOperation(beforeOperationBase.getOperation().getRef()), + block((*refOperation)->getBlock()) {} -void PyInsertionPoint::insert(PyOperation &operation) { +void PyInsertionPoint::insert(PyOperationBase &operationBase) { + PyOperation &operation = operationBase.getOperation(); if (operation.isAttached()) throw SetPyError(PyExc_ValueError, "Attempt to insert operation that is already attached"); @@ -2501,33 +2518,36 @@ //---------------------------------------------------------------------------- // Mapping of Operation. //---------------------------------------------------------------------------- - py::class_(m, "Operation") - .def_static("create", &PyOperation::create, py::arg("name"), - py::arg("operands") = py::none(), - py::arg("results") = py::none(), - py::arg("attributes") = py::none(), - py::arg("successors") = py::none(), py::arg("regions") = 0, - py::arg("loc") = py::none(), py::arg("ip") = py::none(), - kOperationCreateDocstring) - .def_property_readonly( - "context", - [](PyOperation &self) { return self.getContext().getObject(); }, - "Context that owns the Operation") - .def_property_readonly( - "operands", - [](PyOperation &self) { return PyOpOperandList(self.getRef()); }) - .def_property_readonly( - "regions", - [](PyOperation &self) { return PyRegionList(self.getRef()); }) + py::class_(m, "_OperationBase") + .def("__eq__", + [](PyOperationBase &self, PyOperationBase &other) { + return &self.getOperation() == &other.getOperation(); + }) + .def("__eq__", + [](PyOperationBase &self, py::object other) { return false; }) + .def_property_readonly("operands", + [](PyOperationBase &self) { + return PyOpOperandList( + self.getOperation().getRef()); + }) + .def_property_readonly("regions", + [](PyOperationBase &self) { + return PyRegionList( + self.getOperation().getRef()); + }) .def_property_readonly( "results", - [](PyOperation &self) { return PyOpResultList(self.getRef()); }, + [](PyOperationBase &self) { + return PyOpResultList(self.getOperation().getRef()); + }, "Returns the list of Operation results.") .def("__iter__", - [](PyOperation &self) { return PyRegionIterator(self.getRef()); }) + [](PyOperationBase &self) { + return PyRegionIterator(self.getOperation().getRef()); + }) .def( "__str__", - [](PyOperation &self) { + [](PyOperationBase &self) { return self.getAsm(/*binary=*/false, /*largeElementsLimit=*/llvm::None, /*enableDebugInfo=*/false, @@ -2536,7 +2556,7 @@ /*useLocalScope=*/false); }, "Returns the assembly form of the operation.") - .def("print", &PyOperation::print, + .def("print", &PyOperationBase::print, // Careful: Lots of arguments must match up with print method. py::arg("file") = py::none(), py::arg("binary") = false, py::arg("large_elements_limit") = py::none(), @@ -2544,7 +2564,7 @@ py::arg("pretty_debug_info") = false, py::arg("print_generic_op_form") = false, py::arg("use_local_scope") = false, kOperationPrintDocstring) - .def("get_asm", &PyOperation::getAsm, + .def("get_asm", &PyOperationBase::getAsm, // Careful: Lots of arguments must match up with get_asm method. py::arg("binary") = false, py::arg("large_elements_limit") = py::none(), @@ -2553,9 +2573,29 @@ py::arg("print_generic_op_form") = false, py::arg("use_local_scope") = false, kOperationGetAsmDocstring); - py::class_(m, "OpView") + py::class_(m, "Operation") + .def_static("create", &PyOperation::create, py::arg("name"), + py::arg("operands") = py::none(), + py::arg("results") = py::none(), + py::arg("attributes") = py::none(), + py::arg("successors") = py::none(), py::arg("regions") = 0, + py::arg("loc") = py::none(), py::arg("ip") = py::none(), + kOperationCreateDocstring) + .def_property_readonly( + "context", + [](PyOperation &self) { return self.getContext().getObject(); }, + "Context that owns the Operation") + .def_property_readonly("opview", &PyOperation::createOpView); + + py::class_(m, "OpView") .def(py::init()) .def_property_readonly("operation", &PyOpView::getOperationObject) + .def_property_readonly( + "context", + [](PyOpView &self) { + return self.getOperation().getContext().getObject(); + }, + "Context that owns the Operation") .def("__str__", [](PyOpView &self) { return py::str(self.getOperationObject()); }); @@ -2577,14 +2617,11 @@ return PyBlockIterator(self.getParentOperation(), firstBlock); }, "Iterates over blocks in the region.") - .def("__eq__", [](PyRegion &self, py::object &other) { - try { - PyRegion *otherRegion = other.cast(); - return self.get().ptr == otherRegion->get().ptr; - } catch (std::exception &e) { - return false; - } - }); + .def("__eq__", + [](PyRegion &self, PyRegion &other) { + return self.get().ptr == other.get().ptr; + }) + .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); //---------------------------------------------------------------------------- // Mapping of PyBlock. @@ -2613,14 +2650,10 @@ }, "Iterates over operations in the block.") .def("__eq__", - [](PyBlock &self, py::object &other) { - try { - PyBlock *otherBlock = other.cast(); - return self.get().ptr == otherBlock->get().ptr; - } catch (std::exception &e) { - return false; - } + [](PyBlock &self, PyBlock &other) { + return self.get().ptr == other.get().ptr; }) + .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) .def( "__str__", [](PyBlock &self) { @@ -2651,7 +2684,7 @@ }, "Gets the InsertionPoint bound to the current thread or raises " "ValueError if none has been set") - .def(py::init(), py::arg("beforeOperation"), + .def(py::init(), py::arg("beforeOperation"), "Inserts before a referenced operation.") .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, py::arg("block"), "Inserts at the beginning of the block.") @@ -2696,14 +2729,8 @@ }, py::keep_alive<0, 1>(), "Binds a name to the attribute") .def("__eq__", - [](PyAttribute &self, py::object &other) { - try { - PyAttribute otherAttribute = other.cast(); - return self == otherAttribute; - } catch (std::exception &e) { - return false; - } - }) + [](PyAttribute &self, PyAttribute &other) { return self == other; }) + .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) .def( "dump", [](PyAttribute &self) { mlirAttributeDump(self.attr); }, kDumpDocstring) @@ -2793,15 +2820,8 @@ .def_property_readonly( "context", [](PyType &self) { return self.getContext().getObject(); }, "Context that owns the Type") - .def("__eq__", - [](PyType &self, py::object &other) { - try { - PyType otherType = other.cast(); - return self == otherType; - } catch (std::exception &e) { - return false; - } - }) + .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) + .def("__eq__", [](PyType &self, py::object &other) { return false; }) .def( "dump", [](PyType &self) { mlirTypeDump(self.type); }, kDumpDocstring) .def( diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -30,7 +30,7 @@ PyGlobals::~PyGlobals() { instance = nullptr; } -void PyGlobals::loadDialectModule(const std::string &dialectNamespace) { +void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { if (loadedDialectModules.contains(dialectNamespace)) return; // Since re-entrancy is possible, make a copy of the search prefixes. @@ -38,7 +38,7 @@ py::object loaded; for (std::string moduleName : localSearchPrefixes) { moduleName.push_back('.'); - moduleName.append(dialectNamespace); + moduleName.append(dialectNamespace.data(), dialectNamespace.size()); try { loaded = py::module::import(moduleName.c_str()); @@ -69,7 +69,8 @@ } void PyGlobals::registerOperationImpl(const std::string &operationName, - py::object pyClass, py::object rawClass) { + py::object pyClass, + py::object rawOpViewClass) { py::object &found = operationClassMap[operationName]; if (found) { throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + @@ -77,7 +78,7 @@ "' is already registered."); } found = std::move(pyClass); - rawOperationClassMap[operationName] = std::move(rawClass); + rawOpViewClassMap[operationName] = std::move(rawOpViewClass); } llvm::Optional @@ -97,6 +98,40 @@ return llvm::None; } +llvm::Optional +PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { + { + auto foundIt = rawOpViewClassMap.find(operationName); + if (foundIt != rawOpViewClassMap.end()) { + if (foundIt->second.is_none()) + return llvm::None; + assert(foundIt->second && "py::object is defined"); + return foundIt->second; + } + } + + // Not found. Load the dialect namespace. + auto split = operationName.split('.'); + llvm::StringRef dialectNamespace = split.first; + loadDialectModule(dialectNamespace); + + // Attempt to find again and negative cache if not found. + { + auto foundIt = rawOpViewClassMap.find(operationName); + + if (foundIt != rawOpViewClassMap.end()) { + if (foundIt->second.is_none()) + return llvm::None; + assert(foundIt->second && "py::object is defined"); + return foundIt->second; + } + + // Negative cache. + rawOpViewClassMap[operationName] = py::none(); + return llvm::None; + } +} + // ----------------------------------------------------------------------------- // Module initialization. // ----------------------------------------------------------------------------- diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py --- a/mlir/test/Bindings/Python/ir_operation.py +++ b/mlir/test/Bindings/Python/ir_operation.py @@ -293,3 +293,34 @@ pretty_debug_info=True, print_generic_op_form=True, use_local_scope=True) run(testOperationPrint) + + +def testKnownOpView(): + with Context(), Location.unknown(): + Context.current.allow_unregistered_dialects = True + module = Module.parse(r""" + %1 = "custom.f32"() : () -> f32 + %2 = "custom.f32"() : () -> f32 + %3 = addf %1, %2 : f32 + """) + print(module) + + # addf should map to a known OpView class in the std dialect. + # We know the OpView for it defines an 'lhs' attribute. + addf = module.body.operations[2] + # CHECK: