diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -9,6 +9,7 @@ #include "IRModule.h" #include "Globals.h" +#include "IRTypes.h" #include "PybindUtils.h" #include "mlir-c/Bindings/Python/Interop.h" @@ -3353,12 +3354,8 @@ return printAccum.join(); }, py::arg("use_local_scope") = false, kGetNameAsOperand) - .def_property_readonly("type", - [](PyValue &self) { - return PyType( - self.getParentOperation()->getContext(), - mlirValueGetType(self.get())); - }) + .def_property_readonly( + "type", [](PyValue &self) { return mlirValueGetType(self.get()); }) .def( "replace_all_uses_with", [](PyValue &self, PyValue &with) { diff --git a/mlir/lib/Bindings/Python/IRTypes.h b/mlir/lib/Bindings/Python/IRTypes.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/IRTypes.h @@ -0,0 +1,307 @@ +//===- IRTypes.h - IR types of pybind module -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H +#define MLIR_BINDINGS_PYTHON_IRTYPES_H + +#include "IRModule.h" +#include "mlir-c/Bindings/Python/Interop.h" +#include "mlir-c/BuiltinTypes.h" +#include "mlir-c/IR.h" +#include + +using namespace mlir; +using namespace mlir::python; + +namespace mlir { +namespace python { +class PyIntegerType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger; + static constexpr const char *pyClassName = "IntegerType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Index Type subclass - IndexType. 
+class PyIndexType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex; + static constexpr const char *pyClassName = "IndexType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float8E4M3FNType. +class PyFloat8E4M3FNType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN; + static constexpr const char *pyClassName = "Float8E4M3FNType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float8E5M2Type. +class PyFloat8E5M2Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2; + static constexpr const char *pyClassName = "Float8E5M2Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float8E4M3FNUZType. +class PyFloat8E4M3FNUZType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ; + static constexpr const char *pyClassName = "Float8E4M3FNUZType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float8E4M3B11FNUZType. +class PyFloat8E4M3B11FNUZType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ; + static constexpr const char *pyClassName = "Float8E4M3B11FNUZType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float8E5M2FNUZType. 
+class PyFloat8E5M2FNUZType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ; + static constexpr const char *pyClassName = "Float8E5M2FNUZType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - BF16Type. +class PyBF16Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16; + static constexpr const char *pyClassName = "BF16Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - F16Type. +class PyF16Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16; + static constexpr const char *pyClassName = "F16Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - F32Type. +class PyF32Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32; + static constexpr const char *pyClassName = "F32Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - F64Type. +class PyF64Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64; + static constexpr const char *pyClassName = "F64Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// None Type subclass - NoneType. +class PyNoneType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone; + static constexpr const char *pyClassName = "NoneType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Complex Type subclass - ComplexType. 
+class PyComplexType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex; + static constexpr const char *pyClassName = "ComplexType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +class PyShapedType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped; + static constexpr const char *pyClassName = "ShapedType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); + +private: + void requireHasRank(); +}; + +/// Vector Type subclass - VectorType. +class PyVectorType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector; + static constexpr const char *pyClassName = "VectorType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Ranked Tensor Type subclass - RankedTensorType. +class PyRankedTensorType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; + static constexpr const char *pyClassName = "RankedTensorType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Unranked Tensor Type subclass - UnrankedTensorType. +class PyUnrankedTensorType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor; + static constexpr const char *pyClassName = "UnrankedTensorType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Ranked MemRef Type subclass - MemRefType. +class PyMemRefType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef; + static constexpr const char *pyClassName = "MemRefType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Unranked MemRef Type subclass - UnrankedMemRefType. 
+class PyUnrankedMemRefType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef; + static constexpr const char *pyClassName = "UnrankedMemRefType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Tuple Type subclass - TupleType. +class PyTupleType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple; + static constexpr const char *pyClassName = "TupleType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Function Type subclass - FunctionType. +class PyFunctionType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction; + static constexpr const char *pyClassName = "FunctionType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Opaque Type subclass - OpaqueType. +class PyOpaqueType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque; + static constexpr const char *pyClassName = "OpaqueType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +// Note order here matters; firstly with respect to overhead; more common types +// should match earliest in the if-case tree. Secondly, concrete types that have +// a base type (such as all of the ShapedTypes) should come before the base type +// to prevent a "less-refined-than-possible" match. 
+#define FORALL_CONCRETE_TYPES(_) \ + _(Index) \ + _(Integer) \ + _(F16) \ + _(F32) \ + _(F64) \ + _(RankedTensor) \ + _(UnrankedTensor) \ + _(UnrankedMemRef) \ + _(MemRef) \ + _(Vector) \ + _(Complex) \ + _(BF16) \ + _(Float8E4M3B11FNUZ) \ + _(Float8E4M3FN) \ + _(Float8E4M3FNUZ) \ + _(Float8E5M2) \ + _(Float8E5M2FNUZ) \ + _(Function) \ + _(None) \ + _(Opaque) \ + _(Tuple) \ + _(Shaped) + +} // namespace python +} // namespace mlir + +#include + +namespace pybind11 { +namespace detail { + +/// Casts MlirType that matches one of the concretes above <-> ConcreteType. +template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirType, _("MlirType")); + static handle cast(MlirType t, return_value_policy, handle) { + PyMlirContextRef context = PyMlirContext::forContext(mlirTypeGetContext(t)); + +#define DEFINE_SUBCLASS(TTT) \ + if (Py##TTT##Type::isaFunction(t)) \ + return pybind11::cast(Py##TTT##Type(context, t)).release(); + FORALL_CONCRETE_TYPES(DEFINE_SUBCLASS) +#undef DEFINE_SUBCLASS + + object capsule = reinterpret_steal(mlirPythonTypeToCapsule(t)); + return module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Type") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + } +}; + +} // namespace detail +} // namespace pybind11 + +#endif diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "IRTypes.h" #include "IRModule.h" #include "PybindUtils.h" @@ -29,687 +30,508 @@ mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type); } -class PyIntegerType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger; - static constexpr const char *pyClassName = "IntegerType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - 
"get_signless", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - py::arg("width"), py::arg("context") = py::none(), - "Create a signless integer type"); - c.def_static( - "get_signed", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeSignedGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - py::arg("width"), py::arg("context") = py::none(), - "Create a signed integer type"); - c.def_static( - "get_unsigned", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - py::arg("width"), py::arg("context") = py::none(), - "Create an unsigned integer type"); - c.def_property_readonly( - "width", - [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); }, - "Returns the width of the integer type"); - c.def_property_readonly( - "is_signless", - [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsSignless(self); - }, - "Returns whether this is a signless integer"); - c.def_property_readonly( - "is_signed", - [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsSigned(self); - }, - "Returns whether this is a signed integer"); - c.def_property_readonly( - "is_unsigned", - [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsUnsigned(self); - }, - "Returns whether this is an unsigned integer"); - } -}; - -/// Index Type subclass - IndexType. 
-class PyIndexType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex; - static constexpr const char *pyClassName = "IndexType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirIndexTypeGet(context->get()); - return PyIndexType(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a index type."); - } -}; - -/// Floating Point Type subclass - Float8E4M3FNType. -class PyFloat8E4M3FNType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN; - static constexpr const char *pyClassName = "Float8E4M3FNType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E4M3FNTypeGet(context->get()); - return PyFloat8E4M3FNType(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a float8_e4m3fn type."); - } -}; - -/// Floating Point Type subclass - Float8M5E2Type. -class PyFloat8E5M2Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2; - static constexpr const char *pyClassName = "Float8E5M2Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E5M2TypeGet(context->get()); - return PyFloat8E5M2Type(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a float8_e5m2 type."); - } -}; - -/// Floating Point Type subclass - Float8E4M3FNUZ. 
-class PyFloat8E4M3FNUZType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ; - static constexpr const char *pyClassName = "Float8E4M3FNUZType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get()); - return PyFloat8E4M3FNUZType(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a float8_e4m3fnuz type."); - } -}; - -/// Floating Point Type subclass - Float8E4M3B11FNUZ. -class PyFloat8E4M3B11FNUZType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ; - static constexpr const char *pyClassName = "Float8E4M3B11FNUZType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get()); - return PyFloat8E4M3B11FNUZType(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a float8_e4m3b11fnuz type."); - } -}; - -/// Floating Point Type subclass - Float8E5M2FNUZ. -class PyFloat8E5M2FNUZType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ; - static constexpr const char *pyClassName = "Float8E5M2FNUZType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get()); - return PyFloat8E5M2FNUZType(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a float8_e5m2fnuz type."); - } -}; - -/// Floating Point Type subclass - BF16Type. 
-class PyBF16Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16; - static constexpr const char *pyClassName = "BF16Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirBF16TypeGet(context->get()); - return PyBF16Type(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a bf16 type."); - } -}; - -/// Floating Point Type subclass - F16Type. -class PyF16Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16; - static constexpr const char *pyClassName = "F16Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF16TypeGet(context->get()); - return PyF16Type(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a f16 type."); - } -}; - -/// Floating Point Type subclass - F32Type. -class PyF32Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32; - static constexpr const char *pyClassName = "F32Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF32TypeGet(context->get()); - return PyF32Type(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a f32 type."); - } -}; - -/// Floating Point Type subclass - F64Type. 
-class PyF64Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64; - static constexpr const char *pyClassName = "F64Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF64TypeGet(context->get()); - return PyF64Type(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a f64 type."); - } -}; - -/// None Type subclass - NoneType. -class PyNoneType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone; - static constexpr const char *pyClassName = "NoneType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirNoneTypeGet(context->get()); - return PyNoneType(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a none type."); - } -}; - -/// Complex Type subclass - ComplexType. -class PyComplexType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex; - static constexpr const char *pyClassName = "ComplexType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &elementType) { - // The element must be a floating point or integer scalar type. 
- if (mlirTypeIsAIntegerOrFloat(elementType)) { - MlirType t = mlirComplexTypeGet(elementType); - return PyComplexType(elementType.getContext(), t); - } - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point or integer type."); - }, - "Create a complex type"); - c.def_property_readonly( - "element_type", - [](PyComplexType &self) -> PyType { - MlirType t = mlirComplexTypeGetElementType(self); - return PyType(self.getContext(), t); - }, - "Returns element type."); - } -}; - -class PyShapedType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped; - static constexpr const char *pyClassName = "ShapedType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_property_readonly( - "element_type", - [](PyShapedType &self) { - MlirType t = mlirShapedTypeGetElementType(self); - return PyType(self.getContext(), t); - }, - "Returns the element type of the shaped type."); - c.def_property_readonly( - "has_rank", - [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); }, - "Returns whether the given shaped type is ranked."); - c.def_property_readonly( - "rank", - [](PyShapedType &self) { - self.requireHasRank(); - return mlirShapedTypeGetRank(self); - }, - "Returns the rank of the given ranked shaped type."); - c.def_property_readonly( - "has_static_shape", - [](PyShapedType &self) -> bool { - return mlirShapedTypeHasStaticShape(self); - }, - "Returns whether the given shaped type has a static shape."); - c.def( - "is_dynamic_dim", - [](PyShapedType &self, intptr_t dim) -> bool { - self.requireHasRank(); - return mlirShapedTypeIsDynamicDim(self, dim); - }, - py::arg("dim"), - "Returns whether the dim-th dimension of the given shaped type is " - "dynamic."); - c.def( - "get_dim_size", - [](PyShapedType &self, intptr_t dim) { - self.requireHasRank(); - return mlirShapedTypeGetDimSize(self, dim); - 
}, - py::arg("dim"), - "Returns the dim-th dimension of the given ranked shaped type."); - c.def_static( - "is_dynamic_size", - [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); }, - py::arg("dim_size"), - "Returns whether the given dimension size indicates a dynamic " - "dimension."); - c.def( - "is_dynamic_stride_or_offset", - [](PyShapedType &self, int64_t val) -> bool { - self.requireHasRank(); - return mlirShapedTypeIsDynamicStrideOrOffset(val); - }, - py::arg("dim_size"), - "Returns whether the given value is used as a placeholder for dynamic " - "strides and offsets in shaped types."); - c.def_property_readonly( - "shape", - [](PyShapedType &self) { - self.requireHasRank(); - - std::vector shape; - int64_t rank = mlirShapedTypeGetRank(self); - shape.reserve(rank); - for (int64_t i = 0; i < rank; ++i) - shape.push_back(mlirShapedTypeGetDimSize(self, i)); - return shape; - }, - "Returns the shape of the ranked shaped type as a list of integers."); - c.def_static( - "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); }, - "Returns the value used to indicate dynamic dimensions in shaped " - "types."); - c.def_static( - "get_dynamic_stride_or_offset", - []() { return mlirShapedTypeGetDynamicStrideOrOffset(); }, - "Returns the value used to indicate dynamic strides or offsets in " - "shaped types."); - } - -private: - void requireHasRank() { - if (!mlirShapedTypeHasRank(*this)) { - throw SetPyError( - PyExc_ValueError, - "calling this method requires that the type has a rank."); - } - } -}; - -/// Vector Type subclass - VectorType. 
-class PyVectorType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector; - static constexpr const char *pyClassName = "VectorType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector shape, PyType &elementType, - DefaultingPyLocation loc) { - PyMlirContext::ErrorCapture errors(loc->getContext()); - MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), - elementType); - if (mlirTypeIsNull(t)) - throw MLIRError("Invalid type", errors.take()); - return PyVectorType(elementType.getContext(), t); - }, - py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(), - "Create a vector type"); - } -}; - -/// Ranked Tensor Type subclass - RankedTensorType. -class PyRankedTensorType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; - static constexpr const char *pyClassName = "RankedTensorType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector shape, PyType &elementType, - std::optional &encodingAttr, DefaultingPyLocation loc) { - PyMlirContext::ErrorCapture errors(loc->getContext()); - MlirType t = mlirRankedTensorTypeGetChecked( - loc, shape.size(), shape.data(), elementType, - encodingAttr ? 
encodingAttr->get() : mlirAttributeGetNull()); - if (mlirTypeIsNull(t)) - throw MLIRError("Invalid type", errors.take()); - return PyRankedTensorType(elementType.getContext(), t); - }, - py::arg("shape"), py::arg("element_type"), - py::arg("encoding") = py::none(), py::arg("loc") = py::none(), - "Create a ranked tensor type"); - c.def_property_readonly( - "encoding", [](PyRankedTensorType &self) -> std::optional { - MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get()); - if (mlirAttributeIsNull(encoding)) - return std::nullopt; - return PyAttribute(self.getContext(), encoding); - }); - } -}; - -/// Unranked Tensor Type subclass - UnrankedTensorType. -class PyUnrankedTensorType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor; - static constexpr const char *pyClassName = "UnrankedTensorType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &elementType, DefaultingPyLocation loc) { - PyMlirContext::ErrorCapture errors(loc->getContext()); - MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType); - if (mlirTypeIsNull(t)) - throw MLIRError("Invalid type", errors.take()); - return PyUnrankedTensorType(elementType.getContext(), t); - }, - py::arg("element_type"), py::arg("loc") = py::none(), - "Create a unranked tensor type"); - } -}; - -/// Ranked MemRef Type subclass - MemRefType. -class PyMemRefType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef; - static constexpr const char *pyClassName = "MemRefType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector shape, PyType &elementType, - PyAttribute *layout, PyAttribute *memorySpace, - DefaultingPyLocation loc) { - PyMlirContext::ErrorCapture errors(loc->getContext()); - MlirAttribute layoutAttr = layout ? 
*layout : mlirAttributeGetNull(); - MlirAttribute memSpaceAttr = - memorySpace ? *memorySpace : mlirAttributeGetNull(); - MlirType t = - mlirMemRefTypeGetChecked(loc, elementType, shape.size(), - shape.data(), layoutAttr, memSpaceAttr); - if (mlirTypeIsNull(t)) - throw MLIRError("Invalid type", errors.take()); - return PyMemRefType(elementType.getContext(), t); - }, - py::arg("shape"), py::arg("element_type"), - py::arg("layout") = py::none(), py::arg("memory_space") = py::none(), - py::arg("loc") = py::none(), "Create a memref type") - .def_property_readonly( - "layout", - [](PyMemRefType &self) -> PyAttribute { - MlirAttribute layout = mlirMemRefTypeGetLayout(self); - return PyAttribute(self.getContext(), layout); - }, - "The layout of the MemRef type.") - .def_property_readonly( - "affine_map", - [](PyMemRefType &self) -> PyAffineMap { - MlirAffineMap map = mlirMemRefTypeGetAffineMap(self); - return PyAffineMap(self.getContext(), map); - }, - "The layout of the MemRef type as an affine map.") - .def_property_readonly( - "memory_space", - [](PyMemRefType &self) -> PyAttribute { - MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); - return PyAttribute(self.getContext(), a); - }, - "Returns the memory space of the given MemRef type."); - } -}; - -/// Unranked MemRef Type subclass - UnrankedMemRefType. 
-class PyUnrankedMemRefType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef; - static constexpr const char *pyClassName = "UnrankedMemRefType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &elementType, PyAttribute *memorySpace, - DefaultingPyLocation loc) { - PyMlirContext::ErrorCapture errors(loc->getContext()); - MlirAttribute memSpaceAttr = {}; - if (memorySpace) - memSpaceAttr = *memorySpace; - - MlirType t = - mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr); - if (mlirTypeIsNull(t)) - throw MLIRError("Invalid type", errors.take()); - return PyUnrankedMemRefType(elementType.getContext(), t); - }, - py::arg("element_type"), py::arg("memory_space"), - py::arg("loc") = py::none(), "Create a unranked memref type") - .def_property_readonly( - "memory_space", - [](PyUnrankedMemRefType &self) -> PyAttribute { - MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); - return PyAttribute(self.getContext(), a); - }, - "Returns the memory space of the given Unranked MemRef type."); - } -}; - -/// Tuple Type subclass - TupleType. -class PyTupleType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple; - static constexpr const char *pyClassName = "TupleType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get_tuple", - [](py::list elementList, DefaultingPyMlirContext context) { - intptr_t num = py::len(elementList); - // Mapping py::list to SmallVector. 
- SmallVector elements; - for (auto element : elementList) - elements.push_back(element.cast()); - MlirType t = mlirTupleTypeGet(context->get(), num, elements.data()); - return PyTupleType(context->getRef(), t); - }, - py::arg("elements"), py::arg("context") = py::none(), - "Create a tuple type"); - c.def( - "get_type", - [](PyTupleType &self, intptr_t pos) -> PyType { - MlirType t = mlirTupleTypeGetType(self, pos); - return PyType(self.getContext(), t); - }, - py::arg("pos"), "Returns the pos-th type in the tuple type."); - c.def_property_readonly( - "num_types", - [](PyTupleType &self) -> intptr_t { - return mlirTupleTypeGetNumTypes(self); - }, - "Returns the number of types contained in a tuple."); - } -}; - -/// Function type. -class PyFunctionType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction; - static constexpr const char *pyClassName = "FunctionType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector inputs, std::vector results, - DefaultingPyMlirContext context) { - SmallVector inputsRaw(inputs.begin(), inputs.end()); - SmallVector resultsRaw(results.begin(), results.end()); - MlirType t = mlirFunctionTypeGet(context->get(), inputsRaw.size(), - inputsRaw.data(), resultsRaw.size(), - resultsRaw.data()); - return PyFunctionType(context->getRef(), t); - }, - py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(), - "Gets a FunctionType from a list of input and result types"); - c.def_property_readonly( - "inputs", - [](PyFunctionType &self) { - MlirType t = self; - auto contextRef = self.getContext(); - py::list types; - for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; - ++i) { - types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i))); - } - return types; - }, - "Returns the list of input types in the FunctionType."); - c.def_property_readonly( - "results", - [](PyFunctionType &self) { - 
auto contextRef = self.getContext(); - py::list types; - for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; - ++i) { - types.append( - PyType(contextRef, mlirFunctionTypeGetResult(self, i))); - } - return types; - }, - "Returns the list of result types in the FunctionType."); - } -}; - static MlirStringRef toMlirStringRef(const std::string &s) { return mlirStringRefCreate(s.data(), s.size()); } -/// Opaque Type subclass - OpaqueType. -class PyOpaqueType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque; - static constexpr const char *pyClassName = "OpaqueType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::string dialectNamespace, std::string typeData, - DefaultingPyMlirContext context) { - MlirType type = mlirOpaqueTypeGet(context->get(), - toMlirStringRef(dialectNamespace), - toMlirStringRef(typeData)); - return PyOpaqueType(context->getRef(), type); - }, - py::arg("dialect_namespace"), py::arg("buffer"), - py::arg("context") = py::none(), - "Create an unregistered (opaque) dialect type."); - c.def_property_readonly( - "dialect_namespace", - [](PyOpaqueType &self) { - MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self); - return py::str(stringRef.data, stringRef.length); - }, - "Returns the dialect namespace for the Opaque type as a string."); - c.def_property_readonly( - "data", - [](PyOpaqueType &self) { - MlirStringRef stringRef = mlirOpaqueTypeGetData(self); - return py::str(stringRef.data, stringRef.length); - }, - "Returns the data for the Opaque type as a string."); +} // namespace + +namespace mlir::python { + +void PyIntegerType::bindDerived(ClassTy &c) { + c.def_static( + "get_signless", + [](unsigned width, DefaultingPyMlirContext context) { + MlirType t = mlirIntegerTypeGet(context->get(), width); + return PyIntegerType(context->getRef(), t); + }, + py::arg("width"), py::arg("context") = py::none(), + 
"Create a signless integer type"); + c.def_static( + "get_signed", + [](unsigned width, DefaultingPyMlirContext context) { + MlirType t = mlirIntegerTypeSignedGet(context->get(), width); + return PyIntegerType(context->getRef(), t); + }, + py::arg("width"), py::arg("context") = py::none(), + "Create a signed integer type"); + c.def_static( + "get_unsigned", + [](unsigned width, DefaultingPyMlirContext context) { + MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width); + return PyIntegerType(context->getRef(), t); + }, + py::arg("width"), py::arg("context") = py::none(), + "Create an unsigned integer type"); + c.def_property_readonly( + "width", + [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); }, + "Returns the width of the integer type"); + c.def_property_readonly( + "is_signless", + [](PyIntegerType &self) -> bool { + return mlirIntegerTypeIsSignless(self); + }, + "Returns whether this is a signless integer"); + c.def_property_readonly( + "is_signed", + [](PyIntegerType &self) -> bool { return mlirIntegerTypeIsSigned(self); }, + "Returns whether this is a signed integer"); + c.def_property_readonly( + "is_unsigned", + [](PyIntegerType &self) -> bool { + return mlirIntegerTypeIsUnsigned(self); + }, + "Returns whether this is an unsigned integer"); +} + +void PyIndexType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirIndexTypeGet(context->get()); + return PyIndexType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a index type."); +} + +void PyFloat8E4M3FNType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E4M3FNTypeGet(context->get()); + return PyFloat8E4M3FNType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float8_e4m3fn type."); +} + +void PyFloat8E5M2Type::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType 
t = mlirFloat8E5M2TypeGet(context->get()); + return PyFloat8E5M2Type(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float8_e5m2 type."); +} + +void PyFloat8E4M3FNUZType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get()); + return PyFloat8E4M3FNUZType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float8_e4m3fnuz type."); +} + +void PyFloat8E4M3B11FNUZType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get()); + return PyFloat8E4M3B11FNUZType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float8_e4m3b11fnuz type."); +} + +void PyFloat8E5M2FNUZType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get()); + return PyFloat8E5M2FNUZType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float8_e5m2fnuz type."); +} + +void PyBF16Type::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirBF16TypeGet(context->get()); + return PyBF16Type(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a bf16 type."); +} + +void PyF16Type::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirF16TypeGet(context->get()); + return PyF16Type(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a f16 type."); +} + +void PyF32Type::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirF32TypeGet(context->get()); + return PyF32Type(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a f32 type."); +} + +void PyF64Type::bindDerived(ClassTy &c) { + c.def_static( + "get", + 
[](DefaultingPyMlirContext context) { + MlirType t = mlirF64TypeGet(context->get()); + return PyF64Type(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a f64 type."); +} + +void PyNoneType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirNoneTypeGet(context->get()); + return PyNoneType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a none type."); +} + +void PyComplexType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &elementType) { + // The element must be a floating point or integer scalar type. + if (mlirTypeIsAIntegerOrFloat(elementType)) { + MlirType t = mlirComplexTypeGet(elementType); + return PyComplexType(elementType.getContext(), t); + } + throw SetPyError( + PyExc_ValueError, + Twine("invalid '") + + py::repr(py::cast(elementType)).cast() + + "' and expected floating point or integer type."); + }, + "Create a complex type"); + c.def_property_readonly( + "element_type", + [](PyComplexType &self) -> PyType { + MlirType t = mlirComplexTypeGetElementType(self); + return PyType(self.getContext(), t); + }, + "Returns element type."); +} + +void PyShapedType::bindDerived(ClassTy &c) { + c.def_property_readonly( + "element_type", + [](PyShapedType &self) { + MlirType t = mlirShapedTypeGetElementType(self); + return PyType(self.getContext(), t); + }, + "Returns the element type of the shaped type."); + c.def_property_readonly( + "has_rank", + [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); }, + "Returns whether the given shaped type is ranked."); + c.def_property_readonly( + "rank", + [](PyShapedType &self) { + self.requireHasRank(); + return mlirShapedTypeGetRank(self); + }, + "Returns the rank of the given ranked shaped type."); + c.def_property_readonly( + "has_static_shape", + [](PyShapedType &self) -> bool { + return mlirShapedTypeHasStaticShape(self); + }, + "Returns whether the given shaped type has a static 
shape."); + c.def( + "is_dynamic_dim", + [](PyShapedType &self, intptr_t dim) -> bool { + self.requireHasRank(); + return mlirShapedTypeIsDynamicDim(self, dim); + }, + py::arg("dim"), + "Returns whether the dim-th dimension of the given shaped type is " + "dynamic."); + c.def( + "get_dim_size", + [](PyShapedType &self, intptr_t dim) { + self.requireHasRank(); + return mlirShapedTypeGetDimSize(self, dim); + }, + py::arg("dim"), + "Returns the dim-th dimension of the given ranked shaped type."); + c.def_static( + "is_dynamic_size", + [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); }, + py::arg("dim_size"), + "Returns whether the given dimension size indicates a dynamic " + "dimension."); + c.def( + "is_dynamic_stride_or_offset", + [](PyShapedType &self, int64_t val) -> bool { + self.requireHasRank(); + return mlirShapedTypeIsDynamicStrideOrOffset(val); + }, + py::arg("dim_size"), + "Returns whether the given value is used as a placeholder for dynamic " + "strides and offsets in shaped types."); + c.def_property_readonly( + "shape", + [](PyShapedType &self) { + self.requireHasRank(); + + std::vector shape; + int64_t rank = mlirShapedTypeGetRank(self); + shape.reserve(rank); + for (int64_t i = 0; i < rank; ++i) + shape.push_back(mlirShapedTypeGetDimSize(self, i)); + return shape; + }, + "Returns the shape of the ranked shaped type as a list of integers."); + c.def_static( + "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); }, + "Returns the value used to indicate dynamic dimensions in shaped " + "types."); + c.def_static( + "get_dynamic_stride_or_offset", + []() { return mlirShapedTypeGetDynamicStrideOrOffset(); }, + "Returns the value used to indicate dynamic strides or offsets in " + "shaped types."); +} + +void PyShapedType::requireHasRank() { + if (!mlirShapedTypeHasRank(*this)) { + throw SetPyError(PyExc_ValueError, + "calling this method requires that the type has a rank."); } -}; +} -} // namespace +void 
PyVectorType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector shape, PyType &elementType, + DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), + elementType); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyVectorType(elementType.getContext(), t); + }, + py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(), + "Create a vector type"); +} + +void PyRankedTensorType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector shape, PyType &elementType, + std::optional &encodingAttr, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirType t = mlirRankedTensorTypeGetChecked( + loc, shape.size(), shape.data(), elementType, + encodingAttr ? encodingAttr->get() : mlirAttributeGetNull()); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyRankedTensorType(elementType.getContext(), t); + }, + py::arg("shape"), py::arg("element_type"), + py::arg("encoding") = py::none(), py::arg("loc") = py::none(), + "Create a ranked tensor type"); + c.def_property_readonly( + "encoding", [](PyRankedTensorType &self) -> std::optional { + MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get()); + if (mlirAttributeIsNull(encoding)) + return std::nullopt; + return PyAttribute(self.getContext(), encoding); + }); +} + +void PyUnrankedTensorType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &elementType, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyUnrankedTensorType(elementType.getContext(), t); + }, + py::arg("element_type"), py::arg("loc") = py::none(), + "Create a unranked tensor type"); +} + +void 
PyMemRefType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector shape, PyType &elementType, PyAttribute *layout, + PyAttribute *memorySpace, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull(); + MlirAttribute memSpaceAttr = + memorySpace ? *memorySpace : mlirAttributeGetNull(); + MlirType t = + mlirMemRefTypeGetChecked(loc, elementType, shape.size(), + shape.data(), layoutAttr, memSpaceAttr); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyMemRefType(elementType.getContext(), t); + }, + py::arg("shape"), py::arg("element_type"), + py::arg("layout") = py::none(), py::arg("memory_space") = py::none(), + py::arg("loc") = py::none(), "Create a memref type") + .def_property_readonly( + "layout", + [](PyMemRefType &self) -> PyAttribute { + MlirAttribute layout = mlirMemRefTypeGetLayout(self); + return PyAttribute(self.getContext(), layout); + }, + "The layout of the MemRef type.") + .def_property_readonly( + "affine_map", + [](PyMemRefType &self) -> PyAffineMap { + MlirAffineMap map = mlirMemRefTypeGetAffineMap(self); + return PyAffineMap(self.getContext(), map); + }, + "The layout of the MemRef type as an affine map.") + .def_property_readonly( + "memory_space", + [](PyMemRefType &self) -> PyAttribute { + MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); + return PyAttribute(self.getContext(), a); + }, + "Returns the memory space of the given MemRef type."); +} + +void PyUnrankedMemRefType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &elementType, PyAttribute *memorySpace, + DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirAttribute memSpaceAttr = {}; + if (memorySpace) + memSpaceAttr = *memorySpace; + + MlirType t = + mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", 
errors.take()); + return PyUnrankedMemRefType(elementType.getContext(), t); + }, + py::arg("element_type"), py::arg("memory_space"), + py::arg("loc") = py::none(), "Create a unranked memref type") + .def_property_readonly( + "memory_space", + [](PyUnrankedMemRefType &self) -> PyAttribute { + MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); + return PyAttribute(self.getContext(), a); + }, + "Returns the memory space of the given Unranked MemRef type."); +} + +void PyTupleType::bindDerived(ClassTy &c) { + c.def_static( + "get_tuple", + [](py::list elementList, DefaultingPyMlirContext context) { + intptr_t num = py::len(elementList); + // Mapping py::list to SmallVector. + SmallVector elements; + for (auto element : elementList) + elements.push_back(element.cast()); + MlirType t = mlirTupleTypeGet(context->get(), num, elements.data()); + return PyTupleType(context->getRef(), t); + }, + py::arg("elements"), py::arg("context") = py::none(), + "Create a tuple type"); + c.def( + "get_type", + [](PyTupleType &self, intptr_t pos) -> PyType { + MlirType t = mlirTupleTypeGetType(self, pos); + return PyType(self.getContext(), t); + }, + py::arg("pos"), "Returns the pos-th type in the tuple type."); + c.def_property_readonly( + "num_types", + [](PyTupleType &self) -> intptr_t { + return mlirTupleTypeGetNumTypes(self); + }, + "Returns the number of types contained in a tuple."); +} + +void PyFunctionType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector inputs, std::vector results, + DefaultingPyMlirContext context) { + SmallVector inputsRaw(inputs.begin(), inputs.end()); + SmallVector resultsRaw(results.begin(), results.end()); + MlirType t = mlirFunctionTypeGet(context->get(), inputsRaw.size(), + inputsRaw.data(), resultsRaw.size(), + resultsRaw.data()); + return PyFunctionType(context->getRef(), t); + }, + py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(), + "Gets a FunctionType from a list of input and result types"); + 
c.def_property_readonly( + "inputs", + [](PyFunctionType &self) { + MlirType t = self; + auto contextRef = self.getContext(); + py::list types; + for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; + ++i) { + types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i))); + } + return types; + }, + "Returns the list of input types in the FunctionType."); + c.def_property_readonly( + "results", + [](PyFunctionType &self) { + auto contextRef = self.getContext(); + py::list types; + for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; + ++i) { + types.append(PyType(contextRef, mlirFunctionTypeGetResult(self, i))); + } + return types; + }, + "Returns the list of result types in the FunctionType."); +} + +void PyOpaqueType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::string dialectNamespace, std::string typeData, + DefaultingPyMlirContext context) { + MlirType type = + mlirOpaqueTypeGet(context->get(), toMlirStringRef(dialectNamespace), + toMlirStringRef(typeData)); + return PyOpaqueType(context->getRef(), type); + }, + py::arg("dialect_namespace"), py::arg("buffer"), + py::arg("context") = py::none(), + "Create an unregistered (opaque) dialect type."); + c.def_property_readonly( + "dialect_namespace", + [](PyOpaqueType &self) { + MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self); + return py::str(stringRef.data, stringRef.length); + }, + "Returns the dialect namespace for the Opaque type as a string."); + c.def_property_readonly( + "data", + [](PyOpaqueType &self) { + MlirStringRef stringRef = mlirOpaqueTypeGetData(self); + return py::str(stringRef.data, stringRef.length); + }, + "Returns the data for the Opaque type as a string."); +} + +} // namespace mlir::python void mlir::python::populateIRTypes(py::module &m) { PyIntegerType::bind(m); diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py --- a/mlir/test/python/ir/builtin_types.py +++ 
b/mlir/test/python/ir/builtin_types.py @@ -2,6 +2,7 @@ import gc from mlir.ir import * +from mlir.dialects import arith, tensor, func def run(f): print("\nTEST:", f.__name__) @@ -509,3 +510,123 @@ print(type(ShapedType.get_dynamic_size())) # CHECK: print(type(ShapedType.get_dynamic_stride_or_offset())) + + +# CHECK-LABEL: TEST: testConcreteTypesRoundTrip +@run +def testConcreteTypesRoundTrip(): + with Context(), Location.unknown(): + f16 = F16Type.get() + # CHECK: concrete type instance: f16 + print("concrete type instance:", f16) + cst_f16 = arith.ConstantOp(f16, 0.0) + # CHECK: result type: f16 + print("result type:", cst_f16.type) + # CHECK: python type name: F16Type + print("python type name:", type(cst_f16.type).__name__) + + f32 = F32Type.get() + # CHECK: concrete type instance: f32 + print("concrete type instance:", f32) + cst_f32 = arith.ConstantOp(f32, 0.0) + # CHECK: result type: f32 + print("result type:", cst_f32.type) + # CHECK: python type name: F32Type + print("python type name:", type(cst_f32.type).__name__) + + f64 = F64Type.get() + # CHECK: concrete type instance: f64 + print("concrete type instance:", f64) + cst_f64 = arith.ConstantOp(f64, 0.0) + # CHECK: result type: f64 + print("result type:", cst_f64.type) + # CHECK: python type name: F64Type + print("python type name:", type(cst_f64.type).__name__) + + f8E4M3B11FNUZ = Float8E4M3B11FNUZType.get() + # CHECK: concrete type instance: f8E4M3B11FNUZ + print("concrete type instance:", f8E4M3B11FNUZ) + cst_f8E4M3B11FNUZ = arith.ConstantOp(f8E4M3B11FNUZ, 0.0) + # CHECK: result type: f8E4M3B11FNUZ + print("result type:", cst_f8E4M3B11FNUZ.type) + # CHECK: python type name: Float8E4M3B11FNUZType + print("python type name:", type(cst_f8E4M3B11FNUZ.type).__name__) + + f8E4M3FN = Float8E4M3FNType.get() + # CHECK: concrete type instance: f8E4M3FN + print("concrete type instance:", f8E4M3FN) + cst_f8E4M3FN = arith.ConstantOp(f8E4M3FN, 0.0) + # CHECK: result type: f8E4M3FN + print("result type:", 
cst_f8E4M3FN.type) + # CHECK: python type name: Float8E4M3FNType + print("python type name:", type(cst_f8E4M3FN.type).__name__) + + f8E4M3FNUZ = Float8E4M3FNUZType.get() + # CHECK: concrete type instance: f8E4M3FNUZ + print("concrete type instance:", f8E4M3FNUZ) + cst_f8E4M3FNUZ = arith.ConstantOp(f8E4M3FNUZ, 0.0) + # CHECK: result type: f8E4M3FNUZ + print("result type:", cst_f8E4M3FNUZ.type) + # CHECK: python type name: Float8E4M3FNUZType + print("python type name:", type(cst_f8E4M3FNUZ.type).__name__) + + f8E5M2 = Float8E5M2Type.get() + # CHECK: concrete type instance: f8E5M2 + print("concrete type instance:", f8E5M2) + cst_f8E5M2 = arith.ConstantOp(f8E5M2, 0.0) + # CHECK: result type: f8E5M2 + print("result type:", cst_f8E5M2.type) + # CHECK: python type name: Float8E5M2Type + print("python type name:", type(cst_f8E5M2.type).__name__) + + f8E5M2FNUZ = Float8E5M2FNUZType.get() + # CHECK: concrete type instance: f8E5M2FNUZ + print("concrete type instance:", f8E5M2FNUZ) + cst_f8E5M2FNUZ = arith.ConstantOp(f8E5M2FNUZ, 0.0) + # CHECK: result type: f8E5M2FNUZ + print("result type:", cst_f8E5M2FNUZ.type) + # CHECK: python type name: Float8E5M2FNUZType + print("python type name:", type(cst_f8E5M2FNUZ.type).__name__) + + bf16 = BF16Type.get() + # CHECK: concrete type instance: bf16 + print("concrete type instance:", bf16) + cst_bf16 = arith.ConstantOp(bf16, 0.0) + # CHECK: result type: bf16 + print("result type:", cst_bf16.type) + # CHECK: python type name: BF16Type + print("python type name:", type(cst_bf16.type).__name__) + + index = IndexType.get() + # CHECK: concrete type instance: index + print("concrete type instance:", index) + cst_index = arith.ConstantOp(index, 0) + # CHECK: result type: index + print("result type:", cst_index.type) + # CHECK: python type name: IndexType + print("python type name:", type(cst_index.type).__name__) + + integer = IntegerType.get_signless(32) + # CHECK: concrete type instance: i32 + print("concrete type instance:", integer) + 
cst_integer = arith.ConstantOp(integer, 0) + # CHECK: result type: i32 + print("result type:", cst_integer.type) + # CHECK: python type name: IntegerType + print("python type name:", type(cst_integer.type).__name__) + + ranked_tensor = tensor.EmptyOp([10, 10], f32).result + # CHECK: result type: tensor<10x10xf32> + print("result type:", ranked_tensor.type) + # CHECK: python type name: RankedTensorType + print("python type name:", type(ranked_tensor.type).__name__) + + @func.FuncOp.from_py_func() + def default_builder(): + c0 = arith.ConstantOp(f32, 0.0) + dynamic_shaped_type = UnrankedTensorType.get(f32) + t = tensor.FromElementsOp(dynamic_shaped_type, [c0]).result + # CHECK: result type: tensor<*xf32> + print("result type:", t.type) + # CHECK: python type name: UnrankedTensorType + print("python type name:", type(t.type).__name__)