diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -168,6 +168,12 @@
       mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
 }
 
+/// Checks whether the given type is a valid element type of a tensor or
+/// memref (i.e. an integer, float, vector or complex type).
+int mlirTypeIsTensorOrMemRefElement(MlirType type) {
+  return mlirTypeIsAIntegerOrFloat(type) || mlirTypeIsAVector(type) ||
+         mlirTypeIsAComplex(type);
+}
+
 } // namespace
 
 //------------------------------------------------------------------------------
@@ -516,6 +522,56 @@
   }
 };
 
+/// Function template binding the shaped type API (element type, rank, shape
+/// and dynamic size/stride queries) onto a concrete shaped type class.
+template <typename... Ts>
+void bindShaped(py::class_<Ts...> &c) {
+  c.def_property_readonly(
+      "element_type",
+      [](PyType &self) {
+        MlirType t = mlirShapedTypeGetElementType(self.type);
+        return PyType(t);
+      },
+      py::keep_alive<0, 1>(), "Returns the element type of the shaped type.");
+  c.def_property_readonly(
+      "has_rank",
+      [](PyType &self) -> bool { return mlirShapedTypeHasRank(self.type); },
+      "Returns whether the given shaped type is ranked.");
+  c.def_property_readonly(
+      "rank", [](PyType &self) { return mlirShapedTypeGetRank(self.type); },
+      "Returns the rank of the given ranked shaped type.");
+  c.def_property_readonly(
+      "has_static_shape",
+      [](PyType &self) -> bool {
+        return mlirShapedTypeHasStaticShape(self.type);
+      },
+      "Returns whether the given shaped type has a static shape.");
+  c.def(
+      "is_dynamic_dim",
+      [](PyType &self, intptr_t dim) -> bool {
+        return mlirShapedTypeIsDynamicDim(self.type, dim);
+      },
+      "Returns whether the dim-th dimension of the given shaped type is "
+      "dynamic.");
+  c.def(
+      "get_dim_size",
+      [](PyType &self, intptr_t dim) {
+        return mlirShapedTypeGetDimSize(self.type, dim);
+      },
+      "Returns the dim-th dimension of the given ranked shaped type.");
+  c.def_static(
+      "is_dynamic_size",
+      [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
+      "Returns whether the given dimension size indicates a dynamic "
+      "dimension.");
+  c.def_static(
+      "is_dynamic_stride_or_offset",
+      [](int64_t val) -> bool {
+        return mlirShapedTypeIsDynamicStrideOrOffset(val);
+      },
+      "Returns whether the given value is used as a placeholder for dynamic "
+      "strides and offsets in shaped types.");
+}
+
 /// Vector Type subclass - VectorType.
 class PyVectorType : public PyConcreteType<PyVectorType> {
 public:
@@ -540,6 +596,148 @@
                   "' and expected floating point or integer type.");
             },
             py::keep_alive<0, 2>(), "Create a vector type");
+    // Binds the base shaped type for vector type.
+    bindShaped(c);
+  }
 };
 
+/// Ranked Tensor Type subclass - RankedTensorType.
+class PyRankedTensorType : public PyConcreteType<PyRankedTensorType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
+  static constexpr const char *pyClassName = "RankedTensorType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_ranked_tensor",
+        [](std::vector<int64_t> shape, PyType &elementType) {
+          // The element must be a floating point/integer/vector/complex type.
+          if (mlirTypeIsTensorOrMemRefElement(elementType.type)) {
+            MlirType t = mlirRankedTensorTypeGet(shape.size(), shape.data(),
+                                                 elementType.type);
+            return PyRankedTensorType(t);
+          }
+          throw SetPyError(
+              PyExc_ValueError,
+              llvm::Twine("invalid '") +
+                  py::repr(py::cast(elementType)).cast<std::string>() +
+                  "' and expected floating point, integer, vector or complex "
+                  "type.");
+        },
+        py::keep_alive<0, 2>(), "Create a ranked tensor type");
+    // Binds the base shaped type for ranked tensor type.
+    bindShaped(c);
+  }
+};
+
+/// Unranked Tensor Type subclass - UnrankedTensorType.
+class PyUnrankedTensorType : public PyConcreteType<PyUnrankedTensorType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
+  static constexpr const char *pyClassName = "UnrankedTensorType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_unranked_tensor",
+        [](PyType &elementType) {
+          // The element must be a floating point/integer/vector/complex type.
+          if (mlirTypeIsTensorOrMemRefElement(elementType.type)) {
+            MlirType t = mlirUnrankedTensorTypeGet(elementType.type);
+            return PyUnrankedTensorType(t);
+          }
+          throw SetPyError(
+              PyExc_ValueError,
+              llvm::Twine("invalid '") +
+                  py::repr(py::cast(elementType)).cast<std::string>() +
+                  "' and expected floating point, integer, vector or complex "
+                  "type.");
+        },
+        py::keep_alive<0, 1>(), "Create an unranked tensor type");
+    // Binds the base shaped type for unranked tensor type.
+    bindShaped(c);
+  }
+};
+
+/// Ranked MemRef Type subclass - MemRefType.
+class PyMemRefType : public PyConcreteType<PyMemRefType> {
+public:
+  // NOTE(review): was mlirTypeIsARankedTensor — a copy-paste bug that would
+  // make isinstance/cast checks accept ranked tensors and reject memrefs.
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
+  static constexpr const char *pyClassName = "MemRefType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    // TODO: Add mlirMemRefTypeGet and mlirMemRefTypeGetAffineMap binding once
+    // the affine map binding is completed.
+    c.def_static(
+        "get_contiguous_memref",
+        [](PyType &elementType, std::vector<int64_t> shape,
+           unsigned memorySpace) {
+          // The element must be a floating point/integer/vector/complex type.
+          if (mlirTypeIsTensorOrMemRefElement(elementType.type)) {
+            MlirType t = mlirMemRefTypeContiguousGet(
+                elementType.type, shape.size(), shape.data(), memorySpace);
+            return PyMemRefType(t);
+          }
+          throw SetPyError(
+              PyExc_ValueError,
+              llvm::Twine("invalid '") +
+                  py::repr(py::cast(elementType)).cast<std::string>() +
+                  "' and expected floating point, integer, vector or complex "
+                  "type.");
+        },
+        py::keep_alive<0, 1>(), "Create a memref type");
+    c.def_property_readonly(
+        "num_affine_maps",
+        [](PyMemRefType &self) -> intptr_t {
+          return mlirMemRefTypeGetNumAffineMaps(self.type);
+        },
+        "Returns the number of affine layout maps in the given MemRef type.");
+    c.def_property_readonly(
+        "memory_space",
+        [](PyMemRefType &self) -> unsigned {
+          return mlirMemRefTypeGetMemorySpace(self.type);
+        },
+        "Returns the memory space of the given MemRef type.");
+    // Binds the base shaped type for ranked memref type.
+    bindShaped(c);
+  }
+};
+
+/// Unranked MemRef Type subclass - UnrankedMemRefType.
+class PyUnrankedMemRefType : public PyConcreteType<PyUnrankedMemRefType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
+  static constexpr const char *pyClassName = "UnrankedMemRefType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_unranked_memref",
+        [](PyType &elementType, unsigned memorySpace) {
+          // The element must be a floating point/integer/vector/complex type.
+          if (mlirTypeIsTensorOrMemRefElement(elementType.type)) {
+            MlirType t =
+                mlirUnrankedMemRefTypeGet(elementType.type, memorySpace);
+            return PyUnrankedMemRefType(t);
+          }
+          throw SetPyError(
+              PyExc_ValueError,
+              llvm::Twine("invalid '") +
+                  py::repr(py::cast(elementType)).cast<std::string>() +
+                  "' and expected floating point, integer, vector or complex "
+                  "type.");
+        },
+        py::keep_alive<0, 1>(), "Create an unranked memref type");
+    c.def_property_readonly(
+        "memory_space",
+        [](PyUnrankedMemRefType &self) -> unsigned {
+          return mlirUnrankedMemrefGetMemorySpace(self.type);
+        },
+        "Returns the memory space of the given Unranked MemRef type.");
+    // Binds the base shaped type for unranked memref type.
+    bindShaped(c);
   }
 };
 
@@ -887,5 +1085,9 @@
   PyNoneType::bind(m);
   PyComplexType::bind(m);
   PyVectorType::bind(m);
+  PyRankedTensorType::bind(m);
+  PyUnrankedTensorType::bind(m);
+  PyMemRefType::bind(m);
+  PyUnrankedMemRefType::bind(m);
   PyTupleType::bind(m);
 }
diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -177,6 +177,34 @@
 run(testComplexType)
 
 
+# CHECK-LABEL: TEST: testShapedType
+# A shaped type is not a standard type itself; it is the base class for
+# vectors, memrefs and tensors, so this test case uses an instance of a
+# vector to test the shaped type.
+def testShapedType():
+  ctx = mlir.ir.Context()
+  f32 = mlir.ir.F32Type(ctx)
+  shape = [2, 3]
+  vector = mlir.ir.VectorType.get_vector(shape, f32)
+  # CHECK: element type: f32
+  print("element type:", vector.element_type)
+  # CHECK: whether the given shaped type is ranked: True
+  print("whether the given shaped type is ranked:", vector.has_rank)
+  # CHECK: rank: 2
+  print("rank:", vector.rank)
+  # CHECK: whether the shaped type has a static shape: True
+  print("whether the shaped type has a static shape:", vector.has_static_shape)
+  # CHECK: whether the dim-th dimension is dynamic: False
+  print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0))
+  # CHECK: dim size: 3
+  print("dim size:", vector.get_dim_size(1))
+  # CHECK: False
+  print(vector.is_dynamic_size(3))
+  # CHECK: False
+  print(vector.is_dynamic_stride_or_offset(1))
+
+run(testShapedType)
+
 # CHECK-LABEL: TEST: testVectorType
 def testVectorType():
   ctx = mlir.ir.Context()
@@ -196,6 +224,92 @@
 run(testVectorType)
 
 
+# CHECK-LABEL: TEST: testRankedTensorType
+def testRankedTensorType():
+  ctx = mlir.ir.Context()
+  f32 = mlir.ir.F32Type(ctx)
+  shape = [2, 3]
+  # CHECK: ranked tensor type: tensor<2x3xf32>
+  print("ranked tensor type:",
+        mlir.ir.RankedTensorType.get_ranked_tensor(shape, f32))
+
+  index = mlir.ir.IndexType(ctx)
+  try:
+    tensor_invalid = mlir.ir.RankedTensorType.get_ranked_tensor(shape, index)
+  except ValueError as e:
+    # CHECK: invalid 'Type(index)' and expected floating point, integer, vector
+    # CHECK: or complex type.
+    print(e)
+  else:
+    print("Exception not produced")
+
+run(testRankedTensorType)
+
+# CHECK-LABEL: TEST: testUnrankedTensorType
+def testUnrankedTensorType():
+  ctx = mlir.ir.Context()
+  f32 = mlir.ir.F32Type(ctx)
+  # CHECK: unranked tensor type: tensor<*xf32>
+  print("unranked tensor type:",
+        mlir.ir.UnrankedTensorType.get_unranked_tensor(f32))
+
+  index = mlir.ir.IndexType(ctx)
+  try:
+    tensor_invalid = mlir.ir.UnrankedTensorType.get_unranked_tensor(index)
+  except ValueError as e:
+    # CHECK: invalid 'Type(index)' and expected floating point, integer, vector
+    # CHECK: or complex type.
+    print(e)
+  else:
+    print("Exception not produced")
+
+run(testUnrankedTensorType)
+
+# CHECK-LABEL: TEST: testMemRefType
+def testMemRefType():
+  ctx = mlir.ir.Context()
+  f32 = mlir.ir.F32Type(ctx)
+  shape = [2, 3]
+  memref = mlir.ir.MemRefType.get_contiguous_memref(f32, shape, 2)
+  # CHECK: memref type: memref<2x3xf32, 2>
+  print("memref type:", memref)
+  # CHECK: number of affine layout maps: 0
+  print("number of affine layout maps:", memref.num_affine_maps)
+  # CHECK: memory space: 2
+  print("memory space:", memref.memory_space)
+
+  index = mlir.ir.IndexType(ctx)
+  try:
+    memref_invalid = mlir.ir.MemRefType.get_contiguous_memref(index, shape, 2)
+  except ValueError as e:
+    # CHECK: invalid 'Type(index)' and expected floating point, integer, vector
+    # CHECK: or complex type.
+    print(e)
+  else:
+    print("Exception not produced")
+
+run(testMemRefType)
+
+# CHECK-LABEL: TEST: testUnrankedMemRefType
+def testUnrankedMemRefType():
+  ctx = mlir.ir.Context()
+  f32 = mlir.ir.F32Type(ctx)
+  memref = mlir.ir.UnrankedMemRefType.get_unranked_memref(f32, 2)
+  # CHECK: unranked memref type: memref<*xf32, 2>
+  print("unranked memref type:", memref)
+
+  index = mlir.ir.IndexType(ctx)
+  try:
+    memref_invalid = mlir.ir.UnrankedMemRefType.get_unranked_memref(index, 2)
+  except ValueError as e:
+    # CHECK: invalid 'Type(index)' and expected floating point, integer, vector
+    # CHECK: or complex type.
+    print(e)
+  else:
+    print("Exception not produced")
+
+run(testUnrankedMemRefType)
+
 # CHECK-LABEL: TEST: testTupleType
 def testTupleType():
   ctx = mlir.ir.Context()