diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#include #include +#include #include "IRModule.h" @@ -1028,8 +1028,7 @@ py::arg("value"), py::arg("context") = py::none(), "Gets a uniqued Type attribute"); c.def_property_readonly("value", [](PyTypeAttribute &self) { - return PyType(self.getContext()->getRef(), - mlirTypeAttrGetValue(self.get())); + return mlirTypeAttrGetValue(self.get()); }); } }; diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -21,6 +21,8 @@ #include "llvm/ADT/SmallVector.h" #include +#include +#include #include namespace py = pybind11; @@ -3145,11 +3147,8 @@ "context", [](PyAttribute &self) { return self.getContext().getObject(); }, "Context that owns the Attribute") - .def_property_readonly("type", - [](PyAttribute &self) { - return PyType(self.getContext()->getRef(), - mlirAttributeGetType(self)); - }) + .def_property_readonly( + "type", [](PyAttribute &self) { return mlirAttributeGetType(self); }) .def( "get_named", [](PyAttribute &self, std::string name) { @@ -3244,7 +3243,7 @@ mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); if (mlirTypeIsNull(type)) throw MLIRError("Unable to parse type", errors.take()); - return PyType(context->getRef(), type); + return type; }, py::arg("asm"), py::arg("context") = py::none(), kContextParseTypeDocstring) @@ -3353,12 +3352,8 @@ return printAccum.join(); }, py::arg("use_local_scope") = false, kGetNameAsOperand) - .def_property_readonly("type", - [](PyValue &self) { - return PyType( - self.getParentOperation()->getContext(), - mlirValueGetType(self.get())); - }) + .def_property_readonly( + "type", [](PyValue &self) { return mlirValueGetType(self.get()); }) .def( "replace_all_uses_with", [](PyValue &self, PyValue &with) { diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -17,6 +17,8 @@ #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" +#include "mlir-c/Bindings/Python/Interop.h" +#include "mlir-c/BuiltinTypes.h" #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" @@ -866,6 +868,14 @@ return DerivedTy::isaFunction(otherType); }, pybind11::arg("other")); + cls.def("__repr__", [](DerivedTy &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append(DerivedTy::pyClassName); + printAccum.parts.append("("); + mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }); DerivedTy::bindDerived(cls); } @@ -960,9 +970,8 @@ return DerivedTy::isaFunction(otherAttr); }, pybind11::arg("other")); - cls.def_property_readonly("type", [](PyAttribute &attr) { - return PyType(attr.getContext(), mlirAttributeGetType(attr)); - }); + cls.def_property_readonly( + "type", [](PyAttribute &attr) { return mlirAttributeGetType(attr); }); DerivedTy::bindDerived(cls); } @@ -1135,18 +1144,101 @@ void populateIRInterfaces(pybind11::module &m); void populateIRTypes(pybind11::module &m); +// Note order here matters; firstly with respect to overhead; more common types +// should match earliest in the if-case tree. Secondly, concrete types that have +// a base type (such as all of the ShapedTypes) should come before the base type +// to prevent a "less-refined-than-possible" match. +#define FORALL_NO_BASE_CONCRETE_TYPES(_) \ + _(Index) \ + _(Integer) \ + _(F16) \ + _(F32) \ + _(F64) \ + _(Complex) \ + _(BF16) \ + _(Float8E4M3B11FNUZ) \ + _(Float8E4M3FN) \ + _(Float8E4M3FNUZ) \ + _(Float8E5M2) \ + _(Float8E5M2FNUZ) \ + _(Function) \ + _(None) \ + _(Opaque) \ + _(Tuple) + +#define FORALL_SHAPE_BASE_CONCRETE_TYPES(_) \ + _(RankedTensor) \ + _(UnrankedTensor) \ + _(UnrankedMemRef) \ + _(MemRef) \ + _(Vector) + +class PyShapedType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped; + static constexpr const char *pyClassName = "ShapedType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); + +private: + void requireHasRank(); +}; + +#define DEFINE_WITH_BASE_CONCRETE_TYPE(CONCRETETYPE, BASETYPE) \ + class Py##CONCRETETYPE##Type \ + : public PyConcreteType { \ + public: \ + static constexpr IsAFunctionTy isaFunction = mlirTypeIsA##CONCRETETYPE; \ + static constexpr const char *pyClassName = #CONCRETETYPE "Type"; \ + using PyConcreteType::PyConcreteType; \ + static void bindDerived(ClassTy &c); \ + }; + +#define DEFINE_SHAPE_BASE_CONCRETE_TYPE(CONCRETETYPE) \ + DEFINE_WITH_BASE_CONCRETE_TYPE(CONCRETETYPE, ShapedType) +FORALL_SHAPE_BASE_CONCRETE_TYPES(DEFINE_SHAPE_BASE_CONCRETE_TYPE) +#undef DEFINE_SHAPE_BASE_CONCRETE_TYPE + +#define DEFINE_NO_BASE_CONCRETE_TYPE(CONCRETETYPE) \ + DEFINE_WITH_BASE_CONCRETE_TYPE(CONCRETETYPE, Type) +FORALL_NO_BASE_CONCRETE_TYPES(DEFINE_NO_BASE_CONCRETE_TYPE) +#undef DEFINE_NO_BASE_CONCRETE_TYPE + } // namespace python } // namespace mlir namespace pybind11 { namespace detail { +using namespace mlir::python; template <> -struct type_caster - : MlirDefaultingCaster {}; +struct type_caster + : MlirDefaultingCaster {}; +template <> +struct type_caster + : MlirDefaultingCaster {}; + +/// Casts MlirType that matches one of the concretes above -> ConcreteType. template <> -struct type_caster - : MlirDefaultingCaster {}; +struct type_caster { + PYBIND11_TYPE_CASTER(MlirType, _("MlirType")); + static handle cast(MlirType t, return_value_policy, handle) { + PyMlirContextRef context = PyMlirContext::forContext(mlirTypeGetContext(t)); + +#define DEFINE_TYPE_MATCH(TTT) \ + if (Py##TTT##Type::isaFunction(t)) { \ + return pybind11::cast(Py##TTT##Type(context, t)).release(); \ + } + FORALL_SHAPE_BASE_CONCRETE_TYPES(DEFINE_TYPE_MATCH) + // just do one (Shaped) + DEFINE_TYPE_MATCH(Shaped) + FORALL_NO_BASE_CONCRETE_TYPES(DEFINE_TYPE_MATCH) +#undef DEFINE_TYPE_MATCH + + return pybind11::cast(PyType(context, t)).release(); + } +}; } // namespace detail } // namespace pybind11 diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -29,687 +29,500 @@ mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type); } -class PyIntegerType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger; - static constexpr const char *pyClassName = "IntegerType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get_signless", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - py::arg("width"), py::arg("context") = py::none(), - "Create a signless integer type"); - c.def_static( - "get_signed", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeSignedGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - py::arg("width"), py::arg("context") = py::none(), - "Create a signed integer type"); - c.def_static( - "get_unsigned", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - py::arg("width"), py::arg("context") = py::none(), - "Create an unsigned integer type"); - c.def_property_readonly( - "width", - [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); }, - "Returns the width of the integer type"); - c.def_property_readonly( - "is_signless", - [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsSignless(self); - }, - "Returns whether this is a signless integer"); - c.def_property_readonly( - "is_signed", - [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsSigned(self); - }, - "Returns whether this is a signed integer"); - c.def_property_readonly( - "is_unsigned", - [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsUnsigned(self); - }, - "Returns whether this is an unsigned integer"); - } -}; - -/// Index Type subclass - IndexType. -class PyIndexType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex; - static constexpr const char *pyClassName = "IndexType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirIndexTypeGet(context->get()); - return PyIndexType(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a index type."); - } -}; - -/// Floating Point Type subclass - Float8E4M3FNType. -class PyFloat8E4M3FNType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN; - static constexpr const char *pyClassName = "Float8E4M3FNType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E4M3FNTypeGet(context->get()); - return PyFloat8E4M3FNType(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a float8_e4m3fn type."); - } -}; - -/// Floating Point Type subclass - Float8M5E2Type. -class PyFloat8E5M2Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2; - static constexpr const char *pyClassName = "Float8E5M2Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E5M2TypeGet(context->get()); - return PyFloat8E5M2Type(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a float8_e5m2 type."); - } -}; - -/// Floating Point Type subclass - Float8E4M3FNUZ. -class PyFloat8E4M3FNUZType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ; - static constexpr const char *pyClassName = "Float8E4M3FNUZType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get()); - return PyFloat8E4M3FNUZType(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a float8_e4m3fnuz type."); - } -}; - -/// Floating Point Type subclass - Float8E4M3B11FNUZ. -class PyFloat8E4M3B11FNUZType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ; - static constexpr const char *pyClassName = "Float8E4M3B11FNUZType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get()); - return PyFloat8E4M3B11FNUZType(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a float8_e4m3b11fnuz type."); - } -}; - -/// Floating Point Type subclass - Float8E5M2FNUZ. -class PyFloat8E5M2FNUZType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ; - static constexpr const char *pyClassName = "Float8E5M2FNUZType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get()); - return PyFloat8E5M2FNUZType(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a float8_e5m2fnuz type."); - } -}; - -/// Floating Point Type subclass - BF16Type. -class PyBF16Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16; - static constexpr const char *pyClassName = "BF16Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirBF16TypeGet(context->get()); - return PyBF16Type(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a bf16 type."); - } -}; - -/// Floating Point Type subclass - F16Type. -class PyF16Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16; - static constexpr const char *pyClassName = "F16Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF16TypeGet(context->get()); - return PyF16Type(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a f16 type."); - } -}; - -/// Floating Point Type subclass - F32Type. -class PyF32Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32; - static constexpr const char *pyClassName = "F32Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF32TypeGet(context->get()); - return PyF32Type(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a f32 type."); - } -}; - -/// Floating Point Type subclass - F64Type. -class PyF64Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64; - static constexpr const char *pyClassName = "F64Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF64TypeGet(context->get()); - return PyF64Type(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a f64 type."); - } -}; - -/// None Type subclass - NoneType. -class PyNoneType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone; - static constexpr const char *pyClassName = "NoneType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirNoneTypeGet(context->get()); - return PyNoneType(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a none type."); - } -}; - -/// Complex Type subclass - ComplexType. -class PyComplexType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex; - static constexpr const char *pyClassName = "ComplexType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &elementType) { - // The element must be a floating point or integer scalar type. - if (mlirTypeIsAIntegerOrFloat(elementType)) { - MlirType t = mlirComplexTypeGet(elementType); - return PyComplexType(elementType.getContext(), t); - } - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point or integer type."); - }, - "Create a complex type"); - c.def_property_readonly( - "element_type", - [](PyComplexType &self) -> PyType { - MlirType t = mlirComplexTypeGetElementType(self); - return PyType(self.getContext(), t); - }, - "Returns element type."); - } -}; - -class PyShapedType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped; - static constexpr const char *pyClassName = "ShapedType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_property_readonly( - "element_type", - [](PyShapedType &self) { - MlirType t = mlirShapedTypeGetElementType(self); - return PyType(self.getContext(), t); - }, - "Returns the element type of the shaped type."); - c.def_property_readonly( - "has_rank", - [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); }, - "Returns whether the given shaped type is ranked."); - c.def_property_readonly( - "rank", - [](PyShapedType &self) { - self.requireHasRank(); - return mlirShapedTypeGetRank(self); - }, - "Returns the rank of the given ranked shaped type."); - c.def_property_readonly( - "has_static_shape", - [](PyShapedType &self) -> bool { - return mlirShapedTypeHasStaticShape(self); - }, - "Returns whether the given shaped type has a static shape."); - c.def( - "is_dynamic_dim", - [](PyShapedType &self, intptr_t dim) -> bool { - self.requireHasRank(); - return mlirShapedTypeIsDynamicDim(self, dim); - }, - py::arg("dim"), - "Returns whether the dim-th dimension of the given shaped type is " - "dynamic."); - c.def( - "get_dim_size", - [](PyShapedType &self, intptr_t dim) { - self.requireHasRank(); - return mlirShapedTypeGetDimSize(self, dim); - }, - py::arg("dim"), - "Returns the dim-th dimension of the given ranked shaped type."); - c.def_static( - "is_dynamic_size", - [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); }, - py::arg("dim_size"), - "Returns whether the given dimension size indicates a dynamic " - "dimension."); - c.def( - "is_dynamic_stride_or_offset", - [](PyShapedType &self, int64_t val) -> bool { - self.requireHasRank(); - return mlirShapedTypeIsDynamicStrideOrOffset(val); - }, - py::arg("dim_size"), - "Returns whether the given value is used as a placeholder for dynamic " - "strides and offsets in shaped types."); - c.def_property_readonly( - "shape", - [](PyShapedType &self) { - self.requireHasRank(); - - std::vector shape; - int64_t rank = mlirShapedTypeGetRank(self); - shape.reserve(rank); - for (int64_t i = 0; i < rank; ++i) - shape.push_back(mlirShapedTypeGetDimSize(self, i)); - return shape; - }, - "Returns the shape of the ranked shaped type as a list of integers."); - c.def_static( - "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); }, - "Returns the value used to indicate dynamic dimensions in shaped " - "types."); - c.def_static( - "get_dynamic_stride_or_offset", - []() { return mlirShapedTypeGetDynamicStrideOrOffset(); }, - "Returns the value used to indicate dynamic strides or offsets in " - "shaped types."); - } - -private: - void requireHasRank() { - if (!mlirShapedTypeHasRank(*this)) { - throw SetPyError( - PyExc_ValueError, - "calling this method requires that the type has a rank."); - } - } -}; - -/// Vector Type subclass - VectorType. -class PyVectorType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector; - static constexpr const char *pyClassName = "VectorType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector shape, PyType &elementType, - DefaultingPyLocation loc) { - PyMlirContext::ErrorCapture errors(loc->getContext()); - MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), - elementType); - if (mlirTypeIsNull(t)) - throw MLIRError("Invalid type", errors.take()); - return PyVectorType(elementType.getContext(), t); - }, - py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(), - "Create a vector type"); - } -}; - -/// Ranked Tensor Type subclass - RankedTensorType. -class PyRankedTensorType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; - static constexpr const char *pyClassName = "RankedTensorType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector shape, PyType &elementType, - std::optional &encodingAttr, DefaultingPyLocation loc) { - PyMlirContext::ErrorCapture errors(loc->getContext()); - MlirType t = mlirRankedTensorTypeGetChecked( - loc, shape.size(), shape.data(), elementType, - encodingAttr ? encodingAttr->get() : mlirAttributeGetNull()); - if (mlirTypeIsNull(t)) - throw MLIRError("Invalid type", errors.take()); - return PyRankedTensorType(elementType.getContext(), t); - }, - py::arg("shape"), py::arg("element_type"), - py::arg("encoding") = py::none(), py::arg("loc") = py::none(), - "Create a ranked tensor type"); - c.def_property_readonly( - "encoding", [](PyRankedTensorType &self) -> std::optional { - MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get()); - if (mlirAttributeIsNull(encoding)) - return std::nullopt; - return PyAttribute(self.getContext(), encoding); - }); - } -}; - -/// Unranked Tensor Type subclass - UnrankedTensorType. -class PyUnrankedTensorType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor; - static constexpr const char *pyClassName = "UnrankedTensorType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &elementType, DefaultingPyLocation loc) { - PyMlirContext::ErrorCapture errors(loc->getContext()); - MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType); - if (mlirTypeIsNull(t)) - throw MLIRError("Invalid type", errors.take()); - return PyUnrankedTensorType(elementType.getContext(), t); - }, - py::arg("element_type"), py::arg("loc") = py::none(), - "Create a unranked tensor type"); - } -}; - -/// Ranked MemRef Type subclass - MemRefType. -class PyMemRefType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef; - static constexpr const char *pyClassName = "MemRefType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector shape, PyType &elementType, - PyAttribute *layout, PyAttribute *memorySpace, - DefaultingPyLocation loc) { - PyMlirContext::ErrorCapture errors(loc->getContext()); - MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull(); - MlirAttribute memSpaceAttr = - memorySpace ? *memorySpace : mlirAttributeGetNull(); - MlirType t = - mlirMemRefTypeGetChecked(loc, elementType, shape.size(), - shape.data(), layoutAttr, memSpaceAttr); - if (mlirTypeIsNull(t)) - throw MLIRError("Invalid type", errors.take()); - return PyMemRefType(elementType.getContext(), t); - }, - py::arg("shape"), py::arg("element_type"), - py::arg("layout") = py::none(), py::arg("memory_space") = py::none(), - py::arg("loc") = py::none(), "Create a memref type") - .def_property_readonly( - "layout", - [](PyMemRefType &self) -> PyAttribute { - MlirAttribute layout = mlirMemRefTypeGetLayout(self); - return PyAttribute(self.getContext(), layout); - }, - "The layout of the MemRef type.") - .def_property_readonly( - "affine_map", - [](PyMemRefType &self) -> PyAffineMap { - MlirAffineMap map = mlirMemRefTypeGetAffineMap(self); - return PyAffineMap(self.getContext(), map); - }, - "The layout of the MemRef type as an affine map.") - .def_property_readonly( - "memory_space", - [](PyMemRefType &self) -> PyAttribute { - MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); - return PyAttribute(self.getContext(), a); - }, - "Returns the memory space of the given MemRef type."); - } -}; - -/// Unranked MemRef Type subclass - UnrankedMemRefType. -class PyUnrankedMemRefType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef; - static constexpr const char *pyClassName = "UnrankedMemRefType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &elementType, PyAttribute *memorySpace, - DefaultingPyLocation loc) { - PyMlirContext::ErrorCapture errors(loc->getContext()); - MlirAttribute memSpaceAttr = {}; - if (memorySpace) - memSpaceAttr = *memorySpace; - - MlirType t = - mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr); - if (mlirTypeIsNull(t)) - throw MLIRError("Invalid type", errors.take()); - return PyUnrankedMemRefType(elementType.getContext(), t); - }, - py::arg("element_type"), py::arg("memory_space"), - py::arg("loc") = py::none(), "Create a unranked memref type") - .def_property_readonly( - "memory_space", - [](PyUnrankedMemRefType &self) -> PyAttribute { - MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); - return PyAttribute(self.getContext(), a); - }, - "Returns the memory space of the given Unranked MemRef type."); - } -}; - -/// Tuple Type subclass - TupleType. -class PyTupleType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple; - static constexpr const char *pyClassName = "TupleType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get_tuple", - [](py::list elementList, DefaultingPyMlirContext context) { - intptr_t num = py::len(elementList); - // Mapping py::list to SmallVector. - SmallVector elements; - for (auto element : elementList) - elements.push_back(element.cast()); - MlirType t = mlirTupleTypeGet(context->get(), num, elements.data()); - return PyTupleType(context->getRef(), t); - }, - py::arg("elements"), py::arg("context") = py::none(), - "Create a tuple type"); - c.def( - "get_type", - [](PyTupleType &self, intptr_t pos) -> PyType { - MlirType t = mlirTupleTypeGetType(self, pos); - return PyType(self.getContext(), t); - }, - py::arg("pos"), "Returns the pos-th type in the tuple type."); - c.def_property_readonly( - "num_types", - [](PyTupleType &self) -> intptr_t { - return mlirTupleTypeGetNumTypes(self); - }, - "Returns the number of types contained in a tuple."); - } -}; - -/// Function type. -class PyFunctionType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction; - static constexpr const char *pyClassName = "FunctionType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector inputs, std::vector results, - DefaultingPyMlirContext context) { - SmallVector inputsRaw(inputs.begin(), inputs.end()); - SmallVector resultsRaw(results.begin(), results.end()); - MlirType t = mlirFunctionTypeGet(context->get(), inputsRaw.size(), - inputsRaw.data(), resultsRaw.size(), - resultsRaw.data()); - return PyFunctionType(context->getRef(), t); - }, - py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(), - "Gets a FunctionType from a list of input and result types"); - c.def_property_readonly( - "inputs", - [](PyFunctionType &self) { - MlirType t = self; - auto contextRef = self.getContext(); - py::list types; - for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; - ++i) { - types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i))); - } - return types; - }, - "Returns the list of input types in the FunctionType."); - c.def_property_readonly( - "results", - [](PyFunctionType &self) { - auto contextRef = self.getContext(); - py::list types; - for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; - ++i) { - types.append( - PyType(contextRef, mlirFunctionTypeGetResult(self, i))); - } - return types; - }, - "Returns the list of result types in the FunctionType."); - } -}; - static MlirStringRef toMlirStringRef(const std::string &s) { return mlirStringRefCreate(s.data(), s.size()); } -/// Opaque Type subclass - OpaqueType. -class PyOpaqueType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque; - static constexpr const char *pyClassName = "OpaqueType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::string dialectNamespace, std::string typeData, - DefaultingPyMlirContext context) { - MlirType type = mlirOpaqueTypeGet(context->get(), - toMlirStringRef(dialectNamespace), - toMlirStringRef(typeData)); - return PyOpaqueType(context->getRef(), type); - }, - py::arg("dialect_namespace"), py::arg("buffer"), - py::arg("context") = py::none(), - "Create an unregistered (opaque) dialect type."); - c.def_property_readonly( - "dialect_namespace", - [](PyOpaqueType &self) { - MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self); - return py::str(stringRef.data, stringRef.length); - }, - "Returns the dialect namespace for the Opaque type as a string."); - c.def_property_readonly( - "data", - [](PyOpaqueType &self) { - MlirStringRef stringRef = mlirOpaqueTypeGetData(self); - return py::str(stringRef.data, stringRef.length); - }, - "Returns the data for the Opaque type as a string."); +} // namespace + +namespace mlir::python { + +void PyIntegerType::bindDerived(ClassTy &c) { + c.def_static( + "get_signless", + [](unsigned width, DefaultingPyMlirContext context) { + MlirType t = mlirIntegerTypeGet(context->get(), width); + return PyIntegerType(context->getRef(), t); + }, + py::arg("width"), py::arg("context") = py::none(), + "Create a signless integer type"); + c.def_static( + "get_signed", + [](unsigned width, DefaultingPyMlirContext context) { + MlirType t = mlirIntegerTypeSignedGet(context->get(), width); + return PyIntegerType(context->getRef(), t); + }, + py::arg("width"), py::arg("context") = py::none(), + "Create a signed integer type"); + c.def_static( + "get_unsigned", + [](unsigned width, DefaultingPyMlirContext context) { + MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width); + return PyIntegerType(context->getRef(), t); + }, + py::arg("width"), py::arg("context") = py::none(), + "Create an unsigned integer type"); + c.def_property_readonly( + "width", + [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); }, + "Returns the width of the integer type"); + c.def_property_readonly( + "is_signless", + [](PyIntegerType &self) -> bool { + return mlirIntegerTypeIsSignless(self); + }, + "Returns whether this is a signless integer"); + c.def_property_readonly( + "is_signed", + [](PyIntegerType &self) -> bool { return mlirIntegerTypeIsSigned(self); }, + "Returns whether this is a signed integer"); + c.def_property_readonly( + "is_unsigned", + [](PyIntegerType &self) -> bool { + return mlirIntegerTypeIsUnsigned(self); + }, + "Returns whether this is an unsigned integer"); +} + +void PyIndexType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirIndexTypeGet(context->get()); + return PyIndexType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a index type."); +} + +void PyFloat8E4M3FNType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E4M3FNTypeGet(context->get()); + return PyFloat8E4M3FNType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float8_e4m3fn type."); +} + +void PyFloat8E5M2Type::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E5M2TypeGet(context->get()); + return PyFloat8E5M2Type(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float8_e5m2 type."); +} + +void PyFloat8E4M3FNUZType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get()); + return PyFloat8E4M3FNUZType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float8_e4m3fnuz type."); +} + +void PyFloat8E4M3B11FNUZType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get()); + return PyFloat8E4M3B11FNUZType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float8_e4m3b11fnuz type."); +} + +void PyFloat8E5M2FNUZType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get()); + return PyFloat8E5M2FNUZType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float8_e5m2fnuz type."); +} + +void PyBF16Type::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirBF16TypeGet(context->get()); + return PyBF16Type(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a bf16 type."); +} + +void PyF16Type::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirF16TypeGet(context->get()); + return PyF16Type(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a f16 type."); +} + +void PyF32Type::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirF32TypeGet(context->get()); + return PyF32Type(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a f32 type."); +} + +void PyF64Type::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirF64TypeGet(context->get()); + return PyF64Type(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a f64 type."); +} + +void PyNoneType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirNoneTypeGet(context->get()); + return PyNoneType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a none type."); +} + +void PyComplexType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &elementType) { + // The element must be a floating point or integer scalar type. + if (mlirTypeIsAIntegerOrFloat(elementType)) { + MlirType t = mlirComplexTypeGet(elementType); + return PyComplexType(elementType.getContext(), t); + } + throw SetPyError( + PyExc_ValueError, + Twine("invalid '") + + py::repr(py::cast(elementType)).cast() + + "' and expected floating point or integer type."); + }, + "Create a complex type"); + c.def_property_readonly( + "element_type", + [](PyComplexType &self) { return mlirComplexTypeGetElementType(self); }, + "Returns element type."); +} + +void PyShapedType::bindDerived(ClassTy &c) { + c.def_property_readonly( + "element_type", + [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); }, + "Returns the element type of the shaped type."); + c.def_property_readonly( + "has_rank", + [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); }, + "Returns whether the given shaped type is ranked."); + c.def_property_readonly( + "rank", + [](PyShapedType &self) { + self.requireHasRank(); + return mlirShapedTypeGetRank(self); + }, + "Returns the rank of the given ranked shaped type."); + c.def_property_readonly( + "has_static_shape", + [](PyShapedType &self) -> bool { + return mlirShapedTypeHasStaticShape(self); + }, + "Returns whether the given shaped type has a static shape."); + c.def( + "is_dynamic_dim", + [](PyShapedType &self, intptr_t dim) -> bool { + self.requireHasRank(); + return mlirShapedTypeIsDynamicDim(self, dim); + }, + py::arg("dim"), + "Returns whether the dim-th dimension of the given shaped type is " + "dynamic."); + c.def( + "get_dim_size", + [](PyShapedType &self, intptr_t dim) { + self.requireHasRank(); + return mlirShapedTypeGetDimSize(self, dim); + }, + py::arg("dim"), + "Returns the dim-th dimension of the given ranked shaped type."); + c.def_static( + "is_dynamic_size", + [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); }, + py::arg("dim_size"), + "Returns whether the given dimension size indicates a dynamic " + "dimension."); + c.def( + "is_dynamic_stride_or_offset", + [](PyShapedType &self, int64_t val) -> bool { + self.requireHasRank(); + return mlirShapedTypeIsDynamicStrideOrOffset(val); + }, + py::arg("dim_size"), + "Returns whether the given value is used as a placeholder for dynamic " + "strides and offsets in shaped types."); + c.def_property_readonly( + "shape", + [](PyShapedType &self) { + self.requireHasRank(); + + std::vector shape; + int64_t rank = mlirShapedTypeGetRank(self); + shape.reserve(rank); + for (int64_t i = 0; i < rank; ++i) + shape.push_back(mlirShapedTypeGetDimSize(self, i)); + return shape; + }, + "Returns the shape of the ranked shaped type as a list of integers."); + c.def_static( + "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); }, + "Returns the value used to indicate dynamic dimensions in shaped " + "types."); + c.def_static( + "get_dynamic_stride_or_offset", + []() { return mlirShapedTypeGetDynamicStrideOrOffset(); }, + "Returns the value used to indicate dynamic strides or offsets in " + "shaped types."); +} + +void PyShapedType::requireHasRank() { + if (!mlirShapedTypeHasRank(*this)) { + throw SetPyError(PyExc_ValueError, + "calling this method requires that the type has a rank."); } -}; +} -} // namespace +void PyVectorType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector shape, PyType &elementType, + DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), + elementType); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyVectorType(elementType.getContext(), t); + }, + py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(), + "Create a vector type"); +} + +void PyRankedTensorType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector shape, PyType &elementType, + std::optional &encodingAttr, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirType t = mlirRankedTensorTypeGetChecked( + loc, shape.size(), shape.data(), elementType, + encodingAttr ? encodingAttr->get() : mlirAttributeGetNull()); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyRankedTensorType(elementType.getContext(), t); + }, + py::arg("shape"), py::arg("element_type"), + py::arg("encoding") = py::none(), py::arg("loc") = py::none(), + "Create a ranked tensor type"); + c.def_property_readonly( + "encoding", [](PyRankedTensorType &self) -> std::optional { + MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get()); + if (mlirAttributeIsNull(encoding)) + return std::nullopt; + return PyAttribute(self.getContext(), encoding); + }); +} + +void PyUnrankedTensorType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &elementType, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyUnrankedTensorType(elementType.getContext(), t); + }, + py::arg("element_type"), py::arg("loc") = py::none(), + "Create a unranked tensor type"); +} + +void PyMemRefType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector shape, PyType &elementType, PyAttribute *layout, + PyAttribute *memorySpace, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull(); + MlirAttribute memSpaceAttr = + memorySpace ? *memorySpace : mlirAttributeGetNull(); + MlirType t = + mlirMemRefTypeGetChecked(loc, elementType, shape.size(), + shape.data(), layoutAttr, memSpaceAttr); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyMemRefType(elementType.getContext(), t); + }, + py::arg("shape"), py::arg("element_type"), + py::arg("layout") = py::none(), py::arg("memory_space") = py::none(), + py::arg("loc") = py::none(), "Create a memref type") + .def_property_readonly( + "layout", + [](PyMemRefType &self) -> PyAttribute { + MlirAttribute layout = mlirMemRefTypeGetLayout(self); + return PyAttribute(self.getContext(), layout); + }, + "The layout of the MemRef type.") + .def_property_readonly( + "affine_map", + [](PyMemRefType &self) -> PyAffineMap { + MlirAffineMap map = mlirMemRefTypeGetAffineMap(self); + return PyAffineMap(self.getContext(), map); + }, + "The layout of the MemRef type as an affine map.") + .def_property_readonly( + "memory_space", + [](PyMemRefType &self) -> PyAttribute { + MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); + return PyAttribute(self.getContext(), a); + }, + "Returns the memory space of the given MemRef type."); +} + +void PyUnrankedMemRefType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &elementType, PyAttribute *memorySpace, + DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirAttribute memSpaceAttr = {}; + if (memorySpace) + memSpaceAttr = *memorySpace; + + MlirType t = + mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyUnrankedMemRefType(elementType.getContext(), t); + }, + py::arg("element_type"), py::arg("memory_space"), + py::arg("loc") = py::none(), "Create a unranked memref type") + .def_property_readonly( + "memory_space", + [](PyUnrankedMemRefType &self) -> PyAttribute { + MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); + return PyAttribute(self.getContext(), a); + }, + "Returns the memory space of the given Unranked MemRef type."); +} + +void PyTupleType::bindDerived(ClassTy &c) { + c.def_static( + "get_tuple", + [](py::list elementList, DefaultingPyMlirContext context) { + intptr_t num = py::len(elementList); + // Mapping py::list to SmallVector. + SmallVector elements; + for (auto element : elementList) + elements.push_back(element.cast()); + return mlirTupleTypeGet(context->get(), num, elements.data()); + }, + py::arg("elements"), py::arg("context") = py::none(), + "Create a tuple type"); + c.def( + "get_type", + [](PyTupleType &self, intptr_t pos) { + return mlirTupleTypeGetType(self, pos); + }, + py::arg("pos"), "Returns the pos-th type in the tuple type."); + c.def_property_readonly( + "num_types", + [](PyTupleType &self) -> intptr_t { + return mlirTupleTypeGetNumTypes(self); + }, + "Returns the number of types contained in a tuple."); +} + +void PyFunctionType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector inputs, std::vector results, + DefaultingPyMlirContext context) { + SmallVector inputsRaw(inputs.begin(), inputs.end()); + SmallVector resultsRaw(results.begin(), results.end()); + MlirType t = mlirFunctionTypeGet(context->get(), inputsRaw.size(), + inputsRaw.data(), resultsRaw.size(), + resultsRaw.data()); + return PyFunctionType(context->getRef(), t); + }, + py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(), + "Gets a FunctionType from a list of input and result types"); + c.def_property_readonly( + "inputs", + [](PyFunctionType &self) { + MlirType t = self; + auto contextRef = self.getContext(); + py::list types; + for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; + ++i) { + types.append(mlirFunctionTypeGetInput(t, i)); + } + return types; + }, + "Returns the list of input types in the FunctionType."); + c.def_property_readonly( + "results", + [](PyFunctionType &self) { + auto contextRef = self.getContext(); + py::list types; + for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; + ++i) { + types.append(mlirFunctionTypeGetResult(self, i)); + } + return types; + }, + "Returns the list of result types in the FunctionType."); +} + +void PyOpaqueType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::string dialectNamespace, std::string typeData, + DefaultingPyMlirContext context) { + MlirType type = + mlirOpaqueTypeGet(context->get(), toMlirStringRef(dialectNamespace), + toMlirStringRef(typeData)); + return PyOpaqueType(context->getRef(), type); + }, + py::arg("dialect_namespace"), py::arg("buffer"), + py::arg("context") = py::none(), + "Create an unregistered (opaque) dialect type."); + c.def_property_readonly( + "dialect_namespace", + [](PyOpaqueType &self) { + MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self); + return py::str(stringRef.data, stringRef.length); + }, + "Returns the dialect namespace for the Opaque type as a string."); + c.def_property_readonly( + "data", + [](PyOpaqueType &self) { + MlirStringRef stringRef = mlirOpaqueTypeGetData(self); + return py::str(stringRef.data, stringRef.length); + }, + "Returns the data for the Opaque type as a string."); +} + +} // namespace mlir::python void mlir::python::populateIRTypes(py::module &m) { PyIntegerType::bind(m); diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py --- a/mlir/test/python/ir/attributes.py +++ b/mlir/test/python/ir/attributes.py @@ -557,3 +557,57 @@ print(f"rank: {len(attr.strides)}") # CHECK: strides are dynamic: [True, True, True] print(f"strides are dynamic: {[s == dynamic for s in attr.strides]}") + + +# CHECK-LABEL: TEST: testConcreteTypesRoundTrip +@run +def testConcreteTypesRoundTrip(): + with Context(), Location.unknown(): + def print_item(attr): + print(repr(attr.type)) + + # CHECK: F32Type(f32) + print_item(Attribute.parse("42.0 : f32")) + # CHECK: F32Type(f32) + print_item(FloatAttr.get_f32(42.0)) + # CHECK: F64Type(f64) + print_item(FloatAttr.get_f64(42.0)) + # CHECK: IntegerType(i32) + print_item(IntegerAttr.get(IntegerType.get_signless(32), 42)) + # CHECK: IntegerType(i64) + print_item(IntegerAttr.get(IntegerType.get_signless(64), 42)) + + def print_container_item(attr_asm): + attr = DenseElementsAttr(Attribute.parse(attr_asm)) + print(repr(attr.type)) + print(repr(attr.type.element_type)) + + # CHECK: RankedTensorType(tensor) + # CHECK: IntegerType(i16) + print_container_item("dense<123> : tensor") + # CHECK: RankedTensorType(tensor) + # CHECK: IntegerType(i32) + print_container_item("dense<123> : tensor") + # CHECK: RankedTensorType(tensor) + # CHECK: IntegerType(i64) + print_container_item("dense<123> : tensor") + + # CHECK: RankedTensorType(tensor) + # CHECK: F16Type(f16) + print_container_item("dense<1.0> : tensor") + # CHECK: RankedTensorType(tensor) + # CHECK: F32Type(f32) + print_container_item("dense<1.0> : tensor") + # CHECK: RankedTensorType(tensor) + # CHECK: F64Type(f64) + print_container_item("dense<1.0> : tensor") + + raw = Attribute.parse("vector<4xf32>") + # CHECK: attr: vector<4xf32> + print("attr:", raw) + type_attr = TypeAttr(raw) + + # CHECK: VectorType(vector<4xf32>) + print(repr(type_attr.value)) + # CHECK: F32Type(f32) + print(repr(type_attr.value.element_type)) diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py --- a/mlir/test/python/ir/builtin_types.py +++ b/mlir/test/python/ir/builtin_types.py @@ -2,6 +2,7 @@ import gc from mlir.ir import * +from mlir.dialects import arith, tensor, func, memref def run(f): print("\nTEST:", f.__name__) @@ -380,15 +381,15 @@ f32 = F32Type.get() shape = [2, 3] loc = Location.unknown() - memref = MemRefType.get(shape, f32, memory_space=Attribute.parse("2")) + memref_f32 = MemRefType.get(shape, f32, memory_space=Attribute.parse("2")) # CHECK: memref type: memref<2x3xf32, 2> - print("memref type:", memref) + print("memref type:", memref_f32) # CHECK: memref layout: affine_map<(d0, d1) -> (d0, d1)> - print("memref layout:", memref.layout) + print("memref layout:", memref_f32.layout) # CHECK: memref affine map: (d0, d1) -> (d0, d1) - print("memref affine map:", memref.affine_map) + print("memref affine map:", memref_f32.affine_map) # CHECK: memory space: 2 - print("memory space:", memref.memory_space) + print("memory space:", memref_f32.memory_space) layout = AffineMapAttr.get(AffineMap.get_permutation([1, 0])) memref_layout = MemRefType.get(shape, f32, layout=layout) @@ -411,7 +412,7 @@ else: print("Exception not produced") - assert memref.shape == shape + assert memref_f32.shape == shape # CHECK-LABEL: TEST: testUnrankedMemRefType @@ -481,9 +482,9 @@ IntegerType.get_signless(16)] result_types = [IndexType.get()] func = FunctionType.get(input_types, result_types) - # CHECK: INPUTS: [Type(i32), Type(i16)] + # CHECK: INPUTS: [IntegerType(i32), IntegerType(i16)] print("INPUTS:", func.inputs) - # CHECK: RESULTS: [Type(index)] + # CHECK: RESULTS: [IndexType(index)] print("RESULTS:", func.results) @@ -509,3 +510,122 @@ print(type(ShapedType.get_dynamic_size())) # CHECK: print(type(ShapedType.get_dynamic_stride_or_offset())) + + +# CHECK-LABEL: TEST: testConcreteTypesRoundTrip +@run +def testConcreteTypesRoundTrip(): + with Context() as ctx, Location.unknown(): + ctx.allow_unregistered_dialects = True + def print_item(typ, v): + cst = arith.ConstantOp(typ, v).result + print(type(cst.type).__name__) + print(repr(cst.type)) + + # CHECK: F16Type + # CHECK: F16Type(f16) + print_item(F16Type.get(), 0.0) + # CHECK: F32Type + # CHECK: F32Type(f32) + print_item(F32Type.get(), 0.0) + # CHECK: F64Type + # CHECK: F64Type(f64) + print_item(F64Type.get(), 0.0) + # CHECK: Float8E4M3B11FNUZType + # CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ) + print_item(Float8E4M3B11FNUZType.get(), 0.0) + # CHECK: Float8E4M3FNType + # CHECK: Float8E4M3FNType(f8E4M3FN) + print_item(Float8E4M3FNType.get(), 0.0) + # CHECK: Float8E4M3FNUZType + # CHECK: Float8E4M3FNUZType(f8E4M3FNUZ) + print_item(Float8E4M3FNUZType.get(), 0.0) + # CHECK: Float8E5M2Type + # CHECK: Float8E5M2Type(f8E5M2) + print_item(Float8E5M2Type.get(), 0.0) + # CHECK: Float8E5M2FNUZType + # CHECK: Float8E5M2FNUZType(f8E5M2FNUZ) + print_item(Float8E5M2FNUZType.get(), 0.0) + # CHECK: BF16Type + # CHECK: BF16Type(bf16) + print_item(BF16Type.get(), 0.0) + # CHECK: IndexType + # CHECK: IndexType(index) + print_item(IndexType.get(), 0) + # CHECK: IntegerType + # CHECK: IntegerType(i32) + print_item(IntegerType.get_signless(32), 0) + + f32 = F32Type.get() + ranked_tensor = tensor.EmptyOp([10, 10], f32).result + # CHECK: RankedTensorType + print(type(ranked_tensor.type).__name__) + # CHECK: RankedTensorType(tensor<10x10xf32>) + print(repr(ranked_tensor.type)) + + cf32 = ComplexType.get(f32) + # CHECK: ComplexType + print(type(cf32).__name__) + # CHECK: ComplexType(complex) + print(repr(cf32)) + + ranked_tensor = tensor.EmptyOp([10, 10], f32).result + # CHECK: RankedTensorType + print(type(ranked_tensor.type).__name__) + # CHECK: RankedTensorType(tensor<10x10xf32>) + print(repr(ranked_tensor.type)) + + vector = VectorType.get([10, 10], f32) + tuple_type = TupleType.get_tuple([f32, vector]) + # CHECK: TupleType + print(type(tuple_type).__name__) + # CHECK: TupleType(tuple>) + print(repr(tuple_type)) + # CHECK: F32Type(f32) + print(repr(tuple_type.get_type(0))) + # CHECK: VectorType(vector<10x10xf32>) + print(repr(tuple_type.get_type(1))) + + index_type = IndexType.get() + @func.FuncOp.from_py_func() + def default_builder(): + c0 = arith.ConstantOp(f32, 0.0) + unranked_tensor_type = UnrankedTensorType.get(f32) + unranked_tensor = tensor.FromElementsOp(unranked_tensor_type, [c0]).result + # CHECK: UnrankedTensorType + print(type(unranked_tensor.type).__name__) + # CHECK: UnrankedTensorType(tensor<*xf32>) + print(repr(unranked_tensor.type)) + + c10 = arith.ConstantOp(index_type, 10) + memref_f32_t = MemRefType.get([10, 10], f32) + memref_f32 = memref.AllocOp(memref_f32_t, [c10, c10], []).result + # CHECK: MemRefType + print(type(memref_f32.type).__name__) + # CHECK: MemRefType(memref<10x10xf32>) + print(repr(memref_f32.type)) + + unranked_memref_t = UnrankedMemRefType.get(f32, Attribute.parse("2")) + memref_f32 = memref.AllocOp(unranked_memref_t, [c10, c10], []).result + # CHECK: UnrankedMemRefType + print(type(memref_f32.type).__name__) + # CHECK: UnrankedMemRefType(memref<*xf32, 2>) + print(repr(memref_f32.type)) + + tuple_type = Operation.parse(f'"test.make_tuple"() : () -> tuple').result + # CHECK: TupleType + print(type(tuple_type.type).__name__) + # CHECK: TupleType(tuple) + print(repr(tuple_type.type)) + + return c0, c10 + + func_op = default_builder.func_op + # CHECK: FunctionType + print(type(func_op.type).__name__) + # CHECK: FunctionType(() -> (f32, index)) + print(repr(func_op.type)) + # CHECK: [] + print(func_op.type.inputs) + # CHECK: [F32Type(f32), IndexType(index)] + print(func_op.type.results)