diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -618,6 +618,19 @@
 /// Returns 1 if the value is an operation result, 0 otherwise.
 MLIR_CAPI_EXPORTED bool mlirValueIsAOpResult(MlirValue value);
 
+/// Replaces all uses of a value with a new value.
+MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesWith(MlirValue value,
+                                                    MlirValue newValue);
+
+/// Replaces uses of a value with a new value if a predicate returns true.
+/// The predicate receives the operation owning the use, the index of the use
+/// in that operation's operands, and the opaque `userData` pointer.
+MLIR_CAPI_EXPORTED void mlirValueReplaceUsesWithIf(
+    MlirValue value, MlirValue newValue,
+    bool (*shouldReplace)(MlirOperation operation, intptr_t operandIndex,
+                          void *userData),
+    void *userData);
+
 /// Returns the block in which this value is defined as an argument. Asserts if
 /// the value is not a block argument.
 MLIR_CAPI_EXPORTED MlirBlock mlirBlockArgumentGetOwner(MlirValue value);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -467,9 +467,13 @@
   return forContext(rawContext).releaseObject();
 }
 
-PyMlirContext *PyMlirContext::createNewContextForInit() {
+PyMlirContext *
+PyMlirContext::createNewContextForInit(bool registerAllDialects) {
   MlirContext context = mlirContextCreate();
-  mlirRegisterAllDialects(context);
+  // TODO: Remove implicit registration of dialects.
+  if (registerAllDialects) {
+    mlirRegisterAllDialects(context);
+  }
   return new PyMlirContext(context);
 }
 
@@ -1204,12 +1208,17 @@
 
 py::object PyOperation::createOpView() {
   checkValid();
-  MlirIdentifier ident = mlirOperationGetName(get());
-  MlirStringRef identStr = mlirIdentifierStr(ident);
-  auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
-      StringRef(identStr.data, identStr.length));
-  if (opViewClass)
-    return (*opViewClass)(getRef().getObject());
+  MlirOperation cOp = get();
+  MlirTypeID registeredTypeID = mlirOperationGetTypeID(cOp);
+  // For registered ops, lookup an appropriate, generated op view.
+  if (!mlirTypeIDIsNull(registeredTypeID)) {
+    MlirIdentifier ident = mlirOperationGetName(cOp);
+    MlirStringRef identStr = mlirIdentifierStr(ident);
+    auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
+        StringRef(identStr.data, identStr.length));
+    if (opViewClass)
+      return (*opViewClass)(getRef().getObject());
+  }
   return py::cast(PyOpView(getRef().getObject()));
 }
 
@@ -2181,7 +2190,8 @@
   // Mapping of MlirContext.
   //----------------------------------------------------------------------------
   py::class_<PyMlirContext>(m, "Context", py::module_local())
-      .def(py::init<>(&PyMlirContext::createNewContextForInit))
+      .def(py::init<>(&PyMlirContext::createNewContextForInit),
+           py::arg("register_all_dialects") = true)
       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
       .def("_get_context_again",
            [](PyMlirContext &self) {
@@ -2569,30 +2579,56 @@
            py::arg("use_local_scope") = false,
            py::arg("assume_verified") = false, kOperationGetAsmDocstring)
       .def(
-          "verify",
+          "detach_from_parent",
           [](PyOperationBase &self) {
-            return mlirOperationVerify(self.getOperation());
+            PyOperation &operation = self.getOperation();
+            operation.checkValid();
+            if (!operation.isAttached())
+              throw py::value_error("Detached operation has no parent.");
+
+            operation.detachFromParent();
+            return operation.createOpView();
           },
-          "Verify the operation and return true if it passes, false if it "
-          "fails.")
+          "Detaches the operation from its parent block.")
+      .def(
+          "erase",
+          [](PyOperationBase &self) { self.getOperation().erase(); },
+          "Erases the operation.")
       .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
            "Puts self immediately after the other operation in its parent "
            "block.")
       .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
            "Puts self immediately before the other operation in its parent "
            "block.")
+      .def_property_readonly(
+          "name",
+          [](PyOperationBase &self) {
+            PyOperation &operation = self.getOperation();
+            operation.checkValid();
+            MlirStringRef name =
+                mlirIdentifierStr(mlirOperationGetName(operation.get()));
+            return py::str(name.data, name.length);
+          },
+          "Returns the name of the operation.")
+      .def_property_readonly(
+          "parent",
+          [](PyOperationBase &self) -> py::object {
+            PyOperation &operation = self.getOperation();
+            operation.checkValid();
+            auto parent = operation.getParentOperation();
+            if (parent)
+              return parent->getObject();
+            return py::none();
+          },
+          "Returns the parent operation, or None if the operation has no "
+          "parent.")
       .def(
-          "detach_from_parent",
+          "verify",
           [](PyOperationBase &self) {
-            PyOperation &operation = self.getOperation();
-            operation.checkValid();
-            if (!operation.isAttached())
-              throw py::value_error("Detached operation has no parent.");
-
-            operation.detachFromParent();
-            return operation.createOpView();
+            return mlirOperationVerify(self.getOperation());
           },
-          "Detaches the operation from its parent block.");
+          "Verify the operation and return true if it passes, false if it "
+          "fails.");
 
   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
       .def_static("create", &PyOperation::create, py::arg("name"),
@@ -2602,25 +2638,9 @@
                   py::arg("successors") = py::none(), py::arg("regions") = 0,
                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
                   kOperationCreateDocstring)
-      .def_property_readonly("parent",
-                             [](PyOperation &self) -> py::object {
-                               auto parent = self.getParentOperation();
-                               if (parent)
-                                 return parent->getObject();
-                               return py::none();
-                             })
-      .def("erase", &PyOperation::erase)
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                              &PyOperation::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
-      .def_property_readonly("name",
-                             [](PyOperation &self) {
-                               self.checkValid();
-                               MlirOperation operation = self.get();
-                               MlirStringRef name = mlirIdentifierStr(
-                                   mlirOperationGetName(operation));
-                               return py::str(name.data, name.length);
-                             })
       .def_property_readonly(
           "context",
           [](PyOperation &self) {
@@ -3026,6 +3046,12 @@
                    "the IR");
             return self.getParentOperation().getObject();
           })
+      .def(
+          "replace_all_uses_with",
+          [](PyValue &self, PyValue &newValue) {
+            mlirValueReplaceAllUsesWith(self.get(), newValue.get());
+          },
+          "Replaces all uses of this value with the given value.")
       .def("__eq__",
            [](PyValue &self, PyValue &other) {
              return self.get().ptr == other.get().ptr;
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -168,7 +168,9 @@
   /// that is also not supported by pybind11. Instead, we use this entry
   /// point which always constructs a fresh context (which cannot alias an
   /// existing one because it is fresh).
-  static PyMlirContext *createNewContextForInit();
+  /// When `registerAllDialects` is true, all dialects are registered in the
+  /// new context.
+  static PyMlirContext *createNewContextForInit(bool registerAllDialects);
 
   /// Returns a context reference for the singleton PyMlirContext wrapper for
   /// the given context.
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -555,6 +555,18 @@
           return mlirTupleTypeGetNumTypes(self);
         },
         "Returns the number of types contained in a tuple.");
+    c.def_property_readonly(
+        "types",
+        [](PyTupleType &self) {
+          py::list result;
+          auto &context = self.getContext();
+          intptr_t n = mlirTupleTypeGetNumTypes(self);
+          for (intptr_t i = 0; i < n; ++i) {
+            result.append(PyType(context, mlirTupleTypeGetType(self, i)));
+          }
+          return result;
+        },
+        "Returns a sequence of the contained types.");
   }
 };
 
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -648,6 +648,21 @@
   return unwrap(value).isa<OpResult>();
 }
 
+void mlirValueReplaceAllUsesWith(MlirValue value, MlirValue newValue) {
+  unwrap(value).replaceAllUsesWith(unwrap(newValue));
+}
+
+void mlirValueReplaceUsesWithIf(MlirValue value, MlirValue newValue,
+                                bool (*shouldReplace)(MlirOperation operation,
+                                                      intptr_t operandIndex,
+                                                      void *userData),
+                                void *userData) {
+  unwrap(value).replaceUsesWithIf(unwrap(newValue), [&](OpOperand &operand) {
+    return shouldReplace(wrap(operand.getOwner()), operand.getOperandNumber(),
+                         userData);
+  });
+}
+
 MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
   return wrap(unwrap(value).cast<BlockArgument>().getOwner());
 }
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -110,11 +110,16 @@
     def _CAPIPtr(self) -> object: ...
     @property
     def attributes(self) -> "OpAttributeMap": ...
+    def erase(self) -> None: ...
     @property
     def location(self) -> "Location": ...
     @property
+    def name(self) -> str: ...
+    @property
     def operands(self) -> "OpOperandList": ...
     @property
+    def parent(self) -> Optional["_OperationBase"]: ...
+    @property
     def regions(self) -> "RegionSequence": ...
     @property
     def result(self) -> "OpResult": ...
@@ -421,7 +426,7 @@
 class Context:
     current: ClassVar["Context"] = ...  # read-only
     allow_unregistered_dialects: bool
-    def __init__(self) -> None: ...
+    def __init__(self, register_all_dialects: bool = True) -> None: ...
     def _CAPICreate(self) -> object: ...
     def _get_context_again(self) -> "Context": ...
     @staticmethod
@@ -807,17 +812,12 @@
                regions: int = 0,
                loc: Optional["Location"] = None,
                ip: Optional["InsertionPoint"] = None) -> "_OperationBase": ...
-    def erase(self) -> None: ...
     @property
     def _CAPIPtr(self) -> object: ...
     @property
     def context(self) -> "Context": ...
     @property
-    def name(self) -> str: ...
-    @property
     def opview(self) -> "OpView": ...
-    @property
-    def parent(self) -> Optional["_OperationBase"]: ...
 
 class OperationIterator:
     def __iter__(self) -> "OperationIterator": ...