diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/include/mlir/Bindings/Python/IRModule.h
rename from mlir/lib/Bindings/Python/IRModule.h
rename to mlir/include/mlir/Bindings/Python/IRModule.h
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/include/mlir/Bindings/Python/IRModule.h
@@ -21,6 +21,7 @@
 #include "mlir-c/IR.h"
 #include "mlir-c/IntegerSet.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Twine.h"
 
 namespace mlir {
 namespace python {
@@ -155,7 +156,7 @@
 
 /// Wrapper around MlirContext.
 using PyMlirContextRef = PyObjectRef<PyMlirContext>;
-class PyMlirContext {
+class PYBIND11_EXPORT PyMlirContext {
 public:
   PyMlirContext() = delete;
   PyMlirContext(const PyMlirContext &) = delete;
@@ -576,7 +577,8 @@
 /// is bounded by its top-level parent reference.
 class PyOperation;
 using PyOperationRef = PyObjectRef<PyOperation>;
-class PyOperation : public PyOperationBase, public BaseContextObject {
+class PYBIND11_EXPORT PyOperation : public PyOperationBase,
+                                    public BaseContextObject {
 public:
   ~PyOperation() override;
   PyOperation &getOperation() override { return *this; }
@@ -998,6 +1000,60 @@
   MlirValue value;
 };
 
+/// CRTP base class for Python MLIR values that subclass Value and should be
+/// castable from it. The value hierarchy is one level deep and is not supposed
+/// to accommodate other levels unless core MLIR changes.
+///
+/// This lives in a public header so that separately built extension modules
+/// can define their own Value subclasses (e.g. the python_test module).
+/// NOTE(review): only PyMlirContext and PyOperation are marked PYBIND11_EXPORT
+/// here; confirm PyValue needs no export for cross-module subclassing.
+template <typename DerivedTy>
+class PyConcreteValue : public PyValue {
+public:
+  // Derived classes must define statics for:
+  //   IsAFunctionTy isaFunction
+  //   const char *pyClassName
+  // and redefine bindDerived.
+  using ClassTy = pybind11::class_<DerivedTy, PyValue>;
+  using IsAFunctionTy = bool (*)(MlirValue);
+
+  PyConcreteValue() = default;
+  PyConcreteValue(PyOperationRef operationRef, MlirValue value)
+      : PyValue(operationRef, value) {}
+  PyConcreteValue(PyValue &orig)
+      : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
+
+  /// Attempts to cast the original value to the derived type and throws a
+  /// Python ValueError on type mismatches.
+  static MlirValue castFrom(PyValue &orig) {
+    if (!DerivedTy::isaFunction(orig.get())) {
+      auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
+      throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast value to ") +
+                                             DerivedTy::pyClassName +
+                                             " (from " + origRepr + ")");
+    }
+    return orig.get();
+  }
+
+  /// Binds the Python module objects to functions of this class.
+  static void bind(pybind11::module &m) {
+    auto cls = ClassTy(m, DerivedTy::pyClassName);
+    cls.def(pybind11::init<PyValue &>(), pybind11::keep_alive<0, 1>(),
+            pybind11::arg("value"));
+    cls.def_static(
+        "isinstance",
+        [](PyValue &otherValue) -> bool {
+          return DerivedTy::isaFunction(otherValue.get());
+        },
+        pybind11::arg("other_value"));
+    DerivedTy::bindDerived(cls);
+  }
+
+  /// Implemented by derived classes to add methods to the Python subclass.
+  static void bindDerived(ClassTy &m) {}
+};
+
 /// Wrapper around MlirAffineExpr. Affine expressions are owned by the context.
 class PyAffineExpr : public BaseContextObject {
 public:
diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/include/mlir/Bindings/Python/PybindUtils.h
rename from mlir/lib/Bindings/Python/PybindUtils.h
rename to mlir/include/mlir/Bindings/Python/PybindUtils.h
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/include/mlir/Bindings/Python/PybindUtils.h
@@ -57,6 +57,15 @@
   ReferrentTy *referrent = nullptr;
 };
 
+// Defined inline because PybindUtils.cpp is deleted: header-only templates
+// such as PyConcreteValue call this from separately built extension modules.
+inline pybind11::error_already_set
+SetPyError(PyObject *excClass, const llvm::Twine &message) {
+  auto messageStr = message.str();
+  PyErr_SetString(excClass, messageStr.c_str());
+  return pybind11::error_already_set();
+}
+
 } // namespace python
 } // namespace mlir
 
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -13,7 +13,7 @@
 #include 
 #include 
 
-#include "PybindUtils.h"
+#include "mlir/Bindings/Python/PybindUtils.h"
 
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSet.h"
diff --git
a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -8,9 +8,8 @@
 
 #include 
 
-#include "IRModule.h"
-
-#include "PybindUtils.h"
+#include "mlir/Bindings/Python/IRModule.h"
+#include "mlir/Bindings/Python/PybindUtils.h"
 
 #include "mlir-c/AffineMap.h"
 #include "mlir-c/Bindings/Python/Interop.h"
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -9,9 +9,8 @@
 #include 
 #include 
 
-#include "IRModule.h"
-
-#include "PybindUtils.h"
+#include "mlir/Bindings/Python/IRModule.h"
+#include "mlir/Bindings/Python/PybindUtils.h"
 
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -6,10 +6,10 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "IRModule.h"
+#include "mlir/Bindings/Python/IRModule.h"
 
 #include "Globals.h"
-#include "PybindUtils.h"
+#include "mlir/Bindings/Python/PybindUtils.h"
 
 #include "mlir-c/Bindings/Python/Interop.h"
 #include "mlir-c/BuiltinAttributes.h"
@@ -3096,7 +3096,7 @@
   //----------------------------------------------------------------------------
   // Mapping of PyAttribute.
   //----------------------------------------------------------------------------
-  py::class_<PyAttribute>(m, "Attribute", py::module_local())
+  py::class_<PyAttribute>(m, "Attribute")
       // Delegate to the PyAttribute copy constructor, which will also lifetime
       // extend the backing context which owns the MlirAttribute.
       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
@@ -3205,7 +3205,7 @@
   //----------------------------------------------------------------------------
   // Mapping of PyType.
   //----------------------------------------------------------------------------
-  py::class_<PyType>(m, "Type", py::module_local())
+  py::class_<PyType>(m, "Type")
       // Delegate to the PyType copy constructor, which will also lifetime
       // extend the backing context which owns the MlirType.
       .def(py::init<PyType &>(), py::arg("cast_from_type"),
@@ -3259,7 +3259,9 @@
   //----------------------------------------------------------------------------
   // Mapping of Value.
   //----------------------------------------------------------------------------
-  py::class_<PyValue>(m, "Value", py::module_local())
+  // Globally registered (not module_local) so other extensions can subclass
+  // Value. NOTE(review): may conflict if two copies of these bindings load.
+  py::class_<PyValue>(m, "Value")
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
       .def_property_readonly(
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -9,7 +9,7 @@
 #include 
 #include 
 
-#include "IRModule.h"
+#include "mlir/Bindings/Python/IRModule.h"
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/Interfaces.h"
 #include "llvm/ADT/STLExtras.h"
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -6,9 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "IRModule.h"
+#include "mlir/Bindings/Python/IRModule.h"
 #include "Globals.h"
-#include "PybindUtils.h"
+#include "mlir/Bindings/Python/PybindUtils.h"
 
 #include 
 #include 
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -6,9 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "IRModule.h"
-
-#include "PybindUtils.h"
+#include "mlir/Bindings/Python/IRModule.h"
+#include "mlir/Bindings/Python/PybindUtils.h"
 
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -8,10 +8,10 @@
 
 #include 
 
-#include "PybindUtils.h"
+#include "mlir/Bindings/Python/PybindUtils.h"
 
 #include "Globals.h"
-#include "IRModule.h"
+#include "mlir/Bindings/Python/IRModule.h"
 #include "Pass.h"
 
 namespace py = pybind11;
diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h
--- a/mlir/lib/Bindings/Python/Pass.h
+++ b/mlir/lib/Bindings/Python/Pass.h
@@ -9,7 +9,7 @@
 #ifndef MLIR_BINDINGS_PYTHON_PASS_H
 #define MLIR_BINDINGS_PYTHON_PASS_H
 
-#include "PybindUtils.h"
+#include "mlir/Bindings/Python/PybindUtils.h"
 
 namespace mlir {
 namespace python {
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -8,7 +8,7 @@
 
 #include "Pass.h"
 
-#include "IRModule.h"
+#include "mlir/Bindings/Python/IRModule.h"
 
 #include "mlir-c/Bindings/Python/Interop.h"
 #include "mlir-c/Pass.h"
diff --git a/mlir/lib/Bindings/Python/PybindUtils.cpp b/mlir/lib/Bindings/Python/PybindUtils.cpp
deleted file mode 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.cpp
+++ /dev/null
@@ -1,16 +0,0 @@
-//===- PybindUtils.cpp - Utilities for interop with pybind11 --------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "PybindUtils.h"
-
-pybind11::error_already_set
-mlir::python::SetPyError(PyObject *excClass, const llvm::Twine &message) {
-  auto messageStr = message.str();
-  PyErr_SetString(excClass, messageStr.c_str());
-  return pybind11::error_already_set();
-}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -263,14 +263,11 @@
     IRInterfaces.cpp
     IRModule.cpp
     IRTypes.cpp
-    PybindUtils.cpp
     Pass.cpp
 
     # Headers must be included explicitly so they are installed.
     Globals.h
-    IRModule.h
     Pass.h
-    PybindUtils.h
   PRIVATE_LINK_LIBS
     LLVMSupport
   EMBED_CAPI_LINK_LIBS
diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -3,7 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._python_test_ops_gen import *
-from .._mlir_libs._mlirPythonTest import TestAttr, TestType
+from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue
 
 def register_python_test_dialect(context, load=True):
   from .._mlir_libs import _mlirPythonTest
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -2,6 +2,8 @@
 
 from mlir.ir import *
 import mlir.dialects.python_test as test
+import mlir.dialects.tensor as tensor
+import mlir.dialects.arith as arith
 
 def run(f):
   print("\nTEST:", f.__name__)
@@ -302,3 +304,37 @@
       pass
     else:
       raise
+
+
+@run
+# CHECK-LABEL: TEST: testCustomValue
+def testCustomValue():
+  with Context() as ctx, Location.unknown():
+    test.register_python_test_dialect(ctx)
+
+    i8 = IntegerType.get_signless(8)
+
+    class Tensor(test.TestTensorValue):
+      def __getitem__(self, dims):
+        dims = list(dims)
+        for i, d in enumerate(dims):
+          if isinstance(d, int):
+            dims[i] = arith.ConstantOp.create_index(d)
+
+        return tensor.ExtractOp(self, dims).result
+
+    module = Module.create()
+    with InsertionPoint(module.body):
+      t = tensor.EmptyOp([10, 10], i8).result
+      tt = Tensor(t)
+      # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
+      print(tt)
+      el = tt[0, 0]
+
+    # CHECK: module {
+    # CHECK: %[[T0:.*]] = tensor.empty() : tensor<10x10xi8>
+    # CHECK: %[[C0:.*]] = arith.constant 0 : index
+    # CHECK: %[[C00:.*]] = arith.constant 0 : index
+    # CHECK: %[[RES:.*]] = tensor.extract %[[T0]][%[[C0]], %[[C00]]] : tensor<10x10xi8>
+    # CHECK: }
+    print(module)
diff --git a/mlir/test/python/lib/PythonTestCAPI.h b/mlir/test/python/lib/PythonTestCAPI.h
--- a/mlir/test/python/lib/PythonTestCAPI.h
+++ b/mlir/test/python/lib/PythonTestCAPI.h
@@ -27,6 +27,9 @@
 
 MLIR_CAPI_EXPORTED MlirType mlirPythonTestTestTypeGet(MlirContext context);
 
+/// Returns true if the type of `value` is a tensor type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/test/python/lib/PythonTestCAPI.cpp b/mlir/test/python/lib/PythonTestCAPI.cpp
--- a/mlir/test/python/lib/PythonTestCAPI.cpp
+++ b/mlir/test/python/lib/PythonTestCAPI.cpp
@@ -10,6 +10,7 @@
 #include "PythonTestDialect.h"
+#include "mlir-c/BuiltinTypes.h"
 #include "mlir/CAPI/Registration.h"
 #include "mlir/CAPI/Wrap.h"
 
 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(PythonTest, python_test,
                                       python_test::PythonTestDialect)
@@ -29,3 +30,7 @@
 MlirType mlirPythonTestTestTypeGet(MlirContext context) {
   return wrap(python_test::TestTypeType::get(unwrap(context)));
 }
+
+bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value) {
+  return mlirTypeIsATensor(mlirValueGetType(value));
+}
diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -7,10 +7,35 @@
 //===----------------------------------------------------------------------===//
 
 #include "PythonTestCAPI.h"
+#include "mlir/Bindings/Python/IRModule.h"
 #include "mlir/Bindings/Python/PybindAdaptors.h"
+#include "mlir/Bindings/Python/PybindUtils.h"
 
 namespace py = pybind11;
 using namespace mlir::python::adaptors;
+using namespace mlir::python;
+
+/// Python wrapper for MLIR values whose type is a tensor, exposed to Python as
+/// `TestTensorValue`.
+class PyTensor : public PyConcreteValue<PyTensor> {
+public:
+  static constexpr IsAFunctionTy isaFunction =
+      mlirTypeIsAPythonTestTestTensorValue;
+  static constexpr const char *pyClassName = "TestTensorValue";
+  using PyConcreteValue::PyConcreteValue;
+
+  /// Adds a `__str__` that renders the value as `Tensor(<printed value>)`.
+  static void bindDerived(ClassTy &c) {
+    c.def("__str__", [](PyValue &self) {
+      mlir::PyPrintAccumulator printAccum;
+      printAccum.parts.append("Tensor(");
+      mlirValuePrint(self.get(), printAccum.getCallback(),
+                     printAccum.getUserData());
+      printAccum.parts.append(")");
+      return printAccum.join();
+    });
+  }
+};
 
 PYBIND11_MODULE(_mlirPythonTest, m) {
   m.def(
@@ -40,4 +65,6 @@
         return cls(mlirPythonTestTestTypeGet(ctx));
       },
       py::arg("cls"), py::arg("context") = py::none());
+
+  PyTensor::bind(m);
 }