diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -432,6 +432,9 @@
 /** Returns the type of the value. */
 MlirType mlirValueGetType(MlirValue value);
 
+/** Prints the value to the standard error stream. */
+void mlirValueDump(MlirValue value);
+
 /** Prints a value by sending chunks of the string representation and
  * forwarding `userData to `callback`. Note that the callback may be called
  * several times with consecutive chunks of the string. */
diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -23,6 +23,7 @@
 class PyModule;
 class PyOperation;
 class PyType;
+class PyValue;
 
 /// Template for a reference to a concrete type which captures a python
 /// reference to its underlying python object.
@@ -381,6 +382,27 @@
   MlirType type;
 };
 
+/// Wrapper around the generic MlirValue.
+/// Values are managed completely by the operation that resulted in their
+/// definition. For op result values, this is the operation that defines the
+/// value. For block argument values, this is the operation that contains the
+/// block to which the value is an argument (blocks cannot be detached in Python
+/// bindings so such an operation always exists).
+class PyValue {
+public:
+  PyValue(PyOperationRef parentOperation, MlirValue value)
+      : parentOperation(std::move(parentOperation)), value(value) {}
+
+  MlirValue get() { return value; }
+  PyOperationRef &getParentOperation() { return parentOperation; }
+
+  void checkValid() { parentOperation->checkValid(); }
+
+private:
+  PyOperationRef parentOperation;
+  MlirValue value;
+};
+
 void populateIRSubmodule(pybind11::module &m);
 
 } // namespace python
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -85,6 +85,14 @@
 The created block.
 )";
 
+static const char kValueDunderStrDocstring[] =
+    R"(Returns the string form of the value.
+
+If the value is a block argument, this is the assembly form of its type and the
+position in the argument list. If the value is an operation result, this is
+equivalent to printing the operation that produced it.
+)";
+
 //------------------------------------------------------------------------------
 // Conversion utilities.
 //------------------------------------------------------------------------------
@@ -732,6 +740,168 @@
   return mlirTypeEqual(type, other.type);
 }
 
+//------------------------------------------------------------------------------
+// PyValue and subclasses.
+//------------------------------------------------------------------------------
+
+namespace {
+/// CRTP base class for Python MLIR values that subclass Value and should be
+/// castable from it. The value hierarchy is one level deep and is not supposed
+/// to accommodate other levels unless core MLIR changes.
+template <typename DerivedTy> class PyConcreteValue : public PyValue {
+public:
+  // Derived classes must define statics for:
+  //   IsAFunctionTy isaFunction
+  //   const char *pyClassName
+  // and redefine bindDerived.
+  using ClassTy = py::class_<DerivedTy, PyValue>;
+  using IsAFunctionTy = int (*)(MlirValue);
+
+  PyConcreteValue() = default;
+  PyConcreteValue(PyOperationRef operationRef, MlirValue value)
+      : PyValue(std::move(operationRef), value) {}
+  explicit PyConcreteValue(PyValue &orig)
+      : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
+
+  /// Attempts to cast the original value to the derived type and throws a
+  /// Python ValueError on type mismatches.
+  static MlirValue castFrom(PyValue &orig) {
+    if (!DerivedTy::isaFunction(orig.get())) {
+      auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
+      throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast value to ") +
+                                             DerivedTy::pyClassName +
+                                             " (from " + origRepr + ")");
+    }
+    return orig.get();
+  }
+
+  /// Binds the Python module objects to functions of this class.
+  static void bind(py::module &m) {
+    auto cls = ClassTy(m, DerivedTy::pyClassName);
+    cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
+    DerivedTy::bindDerived(cls);
+  }
+
+  /// Implemented by derived classes to add methods to the Python subclass.
+  static void bindDerived(ClassTy &m) {}
+};
+
+/// Python wrapper for MlirBlockArgument.
+class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
+  static constexpr const char *pyClassName = "BlockArgument";
+  using PyConcreteValue::PyConcreteValue;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_property_readonly("owner", [](PyBlockArgument &self) {
+      return PyBlock(self.getParentOperation(),
+                     mlirBlockArgumentGetOwner(self.get()));
+    });
+    c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
+      return mlirBlockArgumentGetArgNumber(self.get());
+    });
+    c.def("set_type", [](PyBlockArgument &self, PyType type) {
+      mlirBlockArgumentSetType(self.get(), type);
+    });
+  }
+};
+
+/// Python wrapper for MlirOpResult.
+class PyOpResult : public PyConcreteValue<PyOpResult> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
+  static constexpr const char *pyClassName = "OpResult";
+  using PyConcreteValue::PyConcreteValue;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_property_readonly("owner", [](PyOpResult &self) {
+      assert(
+          mlirOperationEqual(self.getParentOperation()->get(),
+                             mlirOpResultGetOwner(self.get())) &&
+          "expected the owner of the value in Python to match that in the IR");
+      return self.getParentOperation();
+    });
+    c.def_property_readonly("result_number", [](PyOpResult &self) {
+      return mlirOpResultGetResultNumber(self.get());
+    });
+  }
+};
+
+/// A list of block arguments. Internally, these are stored as consecutive
+/// elements, random access is cheap. The argument list is associated with the
+/// operation that contains the block (detached blocks are not allowed in
+/// Python bindings) and extends its lifetime.
+class PyBlockArgumentList {
+public:
+  PyBlockArgumentList(PyOperationRef operation, MlirBlock block)
+      : operation(std::move(operation)), block(block) {}
+
+  /// Returns the length of the block argument list.
+  intptr_t dunderLen() {
+    operation->checkValid();
+    return mlirBlockGetNumArguments(block);
+  }
+
+  /// Returns `index`-th element of the block argument list.
+  PyBlockArgument dunderGetItem(intptr_t index) {
+    if (index < 0 || index >= dunderLen()) {
+      throw SetPyError(PyExc_IndexError,
+                       "attempt to access out of bounds argument");
+    }
+    PyValue value(operation, mlirBlockGetArgument(block, index));
+    return PyBlockArgument(value);
+  }
+
+  /// Defines a Python class in the bindings.
+  static void bind(py::module &m) {
+    py::class_<PyBlockArgumentList>(m, "BlockArgumentList")
+        .def("__len__", &PyBlockArgumentList::dunderLen)
+        .def("__getitem__", &PyBlockArgumentList::dunderGetItem);
+  }
+
+private:
+  PyOperationRef operation;
+  MlirBlock block;
+};
+
+/// A list of operation results. Internally, these are stored as consecutive
+/// elements, random access is cheap. The result list is associated with the
+/// operation whose results these are, and extends the lifetime of this
+/// operation.
+class PyOpResultList {
+public:
+  PyOpResultList(PyOperationRef operation) : operation(std::move(operation)) {}
+
+  /// Returns the length of the result list.
+  intptr_t dunderLen() {
+    operation->checkValid();
+    return mlirOperationGetNumResults(operation->get());
+  }
+
+  /// Returns `index`-th element in the result list.
+  PyOpResult dunderGetItem(intptr_t index) {
+    if (index < 0 || index >= dunderLen()) {
+      throw SetPyError(PyExc_IndexError,
+                       "attempt to access out of bounds result");
+    }
+    PyValue value(operation, mlirOperationGetResult(operation->get(), index));
+    return PyOpResult(value);
+  }
+
+  /// Defines a Python class in the bindings.
+  static void bind(py::module &m) {
+    py::class_<PyOpResultList>(m, "OpResultList")
+        .def("__len__", &PyOpResultList::dunderLen)
+        .def("__getitem__", &PyOpResultList::dunderGetItem);
+  }
+
+private:
+  PyOperationRef operation;
+};
+
+} // end namespace
+
 //------------------------------------------------------------------------------
 // Standard attribute subclasses.
 //------------------------------------------------------------------------------
@@ -1793,6 +1963,10 @@
       .def_property_readonly(
           "regions",
           [](PyOperation &self) { return PyRegionList(self.getRef()); })
+      .def_property_readonly(
+          "results",
+          [](PyOperation &self) { return PyOpResultList(self.getRef()); },
+          "Returns the list of operation results.")
       .def("__iter__",
            [](PyOperation &self) { return PyRegionIterator(self.getRef()); })
       .def(
@@ -1833,6 +2007,12 @@
 
   // Mapping of PyBlock.
py::class_(m, "Block") + .def_property_readonly( + "arguments", + [](PyBlock &self) { + return PyBlockArgumentList(self.getParentOperation(), self.get()); + }, + "Returns a list of block arguments.") .def_property_readonly( "operations", [](PyBlock &self) { @@ -2015,11 +2195,40 @@ PyTupleType::bind(m); PyFunctionType::bind(m); + // Mapping of Value. + py::class_(m, "Value") + .def_property_readonly( + "context", + [](PyValue &self) { return self.getParentOperation()->getContext(); }, + "Context in which the value lives.") + .def( + "dump", [](PyValue &self) { mlirValueDump(self.get()); }, + kDumpDocstring) + .def( + "__str__", + [](PyValue &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append("Value("); + mlirValuePrint(self.get(), printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }, + kValueDunderStrDocstring) + .def_property_readonly("type", [](PyValue &self) { + return PyType(self.getParentOperation()->getContext(), + mlirValueGetType(self.get())); + }); + PyBlockArgument::bind(m); + PyOpResult::bind(m); + // Container bindings. 
+ PyBlockArgumentList::bind(m); PyBlockIterator::bind(m); PyBlockList::bind(m); PyOperationIterator::bind(m); PyOperationList::bind(m); + PyOpResultList::bind(m); PyRegionIterator::bind(m); PyRegionList::bind(m); } diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -454,6 +454,8 @@ return wrap(unwrap(value).getType()); } +void mlirValueDump(MlirValue value) { unwrap(value).dump(); } + void mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData) { detail::CallbackOstream stream(callback, userData); diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py --- a/mlir/test/Bindings/Python/ir_operation.py +++ b/mlir/test/Bindings/Python/ir_operation.py @@ -102,6 +102,35 @@ run(testTraverseOpRegionBlockIndices) +# CHECK-LABEL: TEST: testBlockArgumentList +def testBlockArgumentList(): + ctx = mlir.ir.Context() + module = ctx.parse_module(r""" + func @f1(%arg0: i32, %arg1: f64, %arg2: index) { + return + } + """) + func = module.operation.regions[0].blocks[0].operations[0] + entry_block = func.regions[0].blocks[0] + assert len(entry_block.arguments) == 3 + # CHECK: Argument 0, type i32 + # CHECK: Argument 1, type f64 + # CHECK: Argument 2, type index + for arg in entry_block.arguments: + print(f"Argument {arg.arg_number}, type {arg.type}") + new_type = mlir.ir.IntegerType.get_signless(ctx, 8 * (arg.arg_number + 1)) + arg.set_type(new_type) + + # CHECK: Argument 0, type i8 + # CHECK: Argument 1, type i16 + # CHECK: Argument 2, type i24 + for arg in entry_block.arguments: + print(f"Argument {arg.arg_number}, type {arg.type}") + + +run(testBlockArgumentList) + + # CHECK-LABEL: TEST: testDetachedOperation def testDetachedOperation(): ctx = mlir.ir.Context() @@ -196,3 +225,26 @@ print(module) run(testOperationWithRegion) + + +# CHECK-LABEL: TEST: testOperationResultList +def testOperationResultList(): + ctx = mlir.ir.Context() + module = 
ctx.parse_module(r""" + func @f1() { + %0:3 = call @f2() : () -> (i32, f64, index) + return + } + func @f2() -> (i32, f64, index) + """) + caller = module.operation.regions[0].blocks[0].operations[0] + call = caller.regions[0].blocks[0].operations[0] + assert len(call.results) == 3 + # CHECK: Result 0, type i32 + # CHECK: Result 1, type f64 + # CHECK: Result 2, type index + for res in call.results: + print(f"Result {res.result_number}, type {res.type}") + + +run(testOperationResultList)