diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -912,6 +912,154 @@ } }; +// TODO: Support construction of bool elements. +// TODO: Support construction of string elements. +/// Python binding for DenseElementsAttr, created from a buffer (e.g. a numpy +/// array) with `get` or from a single scalar attribute with `get_splat`. +class PyDenseElementsAttribute + : public PyConcreteAttribute<PyDenseElementsAttribute> { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; + static constexpr const char *pyClassName = "DenseElementsAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static PyDenseElementsAttribute getFromBuffer(PyMlirContext &contextWrapper, + py::buffer array, + bool signless) { + // Request a contiguous view. In exotic cases, this will cause a copy. + int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; + Py_buffer *view = new Py_buffer(); + if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) { + delete view; + throw py::error_already_set(); + } + py::buffer_info arrayInfo(view); + + MlirContext context = contextWrapper.get(); + // Switch on the types that can be bulk loaded between the Python and + // MLIR-C APIs. The itemsize is checked rather than only asserted: the C + // 'long' formats ("l"/"L") are 4 bytes on LLP64 platforms and 8 bytes on + // LP64 ones, and a mismatched width would misread or over-read the data. + if (arrayInfo.format == "f" && arrayInfo.itemsize == 4) { + // f32 + return PyDenseElementsAttribute( + contextWrapper.getRef(), + bulkLoad(context, mlirDenseElementsAttrFloatGet, + mlirF32TypeGet(context), arrayInfo)); + } else if (arrayInfo.format == "d" && arrayInfo.itemsize == 8) { + // f64 + return PyDenseElementsAttribute( + contextWrapper.getRef(), + bulkLoad(context, mlirDenseElementsAttrDoubleGet, + mlirF64TypeGet(context), arrayInfo)); + } else if ((arrayInfo.format == "i" || arrayInfo.format == "l") && + arrayInfo.itemsize == 4) { + // i32 + MlirType elementType = signless ?
mlirIntegerTypeGet(context, 32) + : mlirIntegerTypeSignedGet(context, 32); + return PyDenseElementsAttribute(contextWrapper.getRef(), + bulkLoad(context, + mlirDenseElementsAttrInt32Get, + elementType, arrayInfo)); + } else if ((arrayInfo.format == "I" || arrayInfo.format == "L") && + arrayInfo.itemsize == 4) { + // unsigned i32 + MlirType elementType = signless ? mlirIntegerTypeGet(context, 32) + : mlirIntegerTypeUnsignedGet(context, 32); + return PyDenseElementsAttribute(contextWrapper.getRef(), + bulkLoad(context, + mlirDenseElementsAttrUInt32Get, + elementType, arrayInfo)); + } else if ((arrayInfo.format == "l" || arrayInfo.format == "q") && + arrayInfo.itemsize == 8) { + // i64 + MlirType elementType = signless ? mlirIntegerTypeGet(context, 64) + : mlirIntegerTypeSignedGet(context, 64); + return PyDenseElementsAttribute(contextWrapper.getRef(), + bulkLoad(context, + mlirDenseElementsAttrInt64Get, + elementType, arrayInfo)); + } else if ((arrayInfo.format == "L" || arrayInfo.format == "Q") && + arrayInfo.itemsize == 8) { + // unsigned i64 + MlirType elementType = signless ? mlirIntegerTypeGet(context, 64) + : mlirIntegerTypeUnsignedGet(context, 64); + return PyDenseElementsAttribute(contextWrapper.getRef(), + bulkLoad(context, + mlirDenseElementsAttrUInt64Get, + elementType, arrayInfo)); + } + + // TODO: Fall back to string-based get.
+ std::string message = "unimplemented array format conversion from format: "; + message.append(arrayInfo.format); + throw SetPyError(PyExc_ValueError, message); + } + + static PyDenseElementsAttribute getSplat(PyType shapedType, + PyAttribute &elementAttr) { + auto contextWrapper = + PyMlirContext::forContext(mlirTypeGetContext(shapedType)); + if (!mlirAttributeIsAInteger(elementAttr.attr) && + !mlirAttributeIsAFloat(elementAttr.attr)) { + std::string message = "Illegal element type for DenseElementsAttr: "; + message.append(py::repr(py::cast(elementAttr))); + throw SetPyError(PyExc_ValueError, message); + } + if (!mlirTypeIsAShaped(shapedType) || + !mlirShapedTypeHasStaticShape(shapedType)) { + std::string message = + "Expected a static ShapedType for the shaped_type parameter: "; + message.append(py::repr(py::cast(shapedType))); + throw SetPyError(PyExc_ValueError, message); + } + MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType.type); + MlirType attrType = mlirAttributeGetType(elementAttr.attr); + if (!mlirTypeEqual(shapedElementType, attrType)) { + std::string message = + "Shaped element type and attribute type must be equal: shaped="; + message.append(py::repr(py::cast(shapedType))); + message.append(", element="); + message.append(py::repr(py::cast(elementAttr))); + throw SetPyError(PyExc_ValueError, message); + } + + MlirAttribute elements = + mlirDenseElementsAttrSplatGet(shapedType.type, elementAttr.attr); + return PyDenseElementsAttribute(contextWrapper->getRef(), elements); + } + + static void bindDerived(ClassTy &c) { + c.def_static("get", PyDenseElementsAttribute::getFromBuffer, + py::arg("context"), py::arg("array"), + py::arg("signless") = true, "Gets from a buffer or ndarray") + .def_static("get_splat", PyDenseElementsAttribute::getSplat, + py::arg("shaped_type"), py::arg("element_attr"), + "Gets a DenseElementsAttr where all values are the same") + .def_property_readonly("is_splat", + [](PyDenseElementsAttribute &self) -> bool
{ + return mlirDenseElementsAttrIsSplat(self.attr); + }); + } + +private: + /// Creates a ranked tensor type of the buffer's shape and `mlirElementType` + /// and bulk loads the raw contents via `ctor` (caller must check itemsize). + template <typename ElementTy> + static MlirAttribute + bulkLoad(MlirContext context, + MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *), + MlirType mlirElementType, py::buffer_info &arrayInfo) { + SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(), + arrayInfo.shape.begin() + arrayInfo.ndim); + auto shapedType = + mlirRankedTensorTypeGet(shape.size(), shape.data(), mlirElementType); + intptr_t numElements = arrayInfo.size; + const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr); + return ctor(shapedType, numElements, contents); + } +}; + } // namespace //------------------------------------------------------------------------------ @@ -1021,11 +1169,13 @@ using PyConcreteType::PyConcreteType; static void bindDerived(ClassTy &c) { - c.def(py::init([](PyMlirContext &context) { - MlirType t = mlirIndexTypeGet(context.get()); - return PyIndexType(context.getRef(), t); - }), - "Create a index type."); + c.def_static( + "get", + [](PyMlirContext &context) { + MlirType t = mlirIndexTypeGet(context.get()); + return PyIndexType(context.getRef(), t); + }, + "Create a index type."); } }; @@ -1037,11 +1187,13 @@ using PyConcreteType::PyConcreteType; static void bindDerived(ClassTy &c) { - c.def(py::init([](PyMlirContext &context) { - MlirType t = mlirBF16TypeGet(context.get()); - return PyBF16Type(context.getRef(), t); - }), - "Create a bf16 type."); + c.def_static( + "get", + [](PyMlirContext &context) { + MlirType t = mlirBF16TypeGet(context.get()); + return PyBF16Type(context.getRef(), t); + }, + "Create a bf16 type."); } }; @@ -1053,11 +1205,13 @@ using PyConcreteType::PyConcreteType; static void bindDerived(ClassTy &c) { - c.def(py::init([](PyMlirContext &context) { - MlirType t = mlirF16TypeGet(context.get()); - return PyF16Type(context.getRef(), t); - }), - "Create a f16 type."); + c.def_static( + "get", + [](PyMlirContext &context) { + MlirType t = mlirF16TypeGet(context.get()); + return PyF16Type(context.getRef(),
t); + }, + "Create a f16 type."); } }; @@ -1069,11 +1223,13 @@ using PyConcreteType::PyConcreteType; static void bindDerived(ClassTy &c) { - c.def(py::init([](PyMlirContext &context) { - MlirType t = mlirF32TypeGet(context.get()); - return PyF32Type(context.getRef(), t); - }), - "Create a f32 type."); + c.def_static( + "get", + [](PyMlirContext &context) { + MlirType t = mlirF32TypeGet(context.get()); + return PyF32Type(context.getRef(), t); + }, + "Create a f32 type."); } }; @@ -1085,11 +1241,13 @@ using PyConcreteType::PyConcreteType; static void bindDerived(ClassTy &c) { - c.def(py::init([](PyMlirContext &context) { - MlirType t = mlirF64TypeGet(context.get()); - return PyF64Type(context.getRef(), t); - }), - "Create a f64 type."); + c.def_static( + "get", + [](PyMlirContext &context) { + MlirType t = mlirF64TypeGet(context.get()); + return PyF64Type(context.getRef(), t); + }, + "Create a f64 type."); } }; @@ -1101,11 +1259,13 @@ using PyConcreteType::PyConcreteType; static void bindDerived(ClassTy &c) { - c.def(py::init([](PyMlirContext &context) { - MlirType t = mlirNoneTypeGet(context.get()); - return PyNoneType(context.getRef(), t); - }), - "Create a none type."); + c.def_static( + "get", + [](PyMlirContext &context) { + MlirType t = mlirNoneTypeGet(context.get()); + return PyNoneType(context.getRef(), t); + }, + "Create a none type."); } }; @@ -1118,7 +1278,7 @@ static void bindDerived(ClassTy &c) { c.def_static( - "get_complex", + "get", [](PyType &elementType) { // The element must be a floating point or integer scalar type. if (mlirTypeIsAIntegerOrFloat(elementType.type)) { @@ -1224,7 +1384,7 @@ static void bindDerived(ClassTy &c) { c.def_static( - "get_vector", + "get", // TODO: Make the location optional and create a default location.
[](std::vector<int64_t> shape, PyType &elementType, PyLocation &loc) { MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(), @@ -1254,7 +1414,7 @@ static void bindDerived(ClassTy &c) { c.def_static( - "get_ranked_tensor", + "get", // TODO: Make the location optional and create a default location. [](std::vector<int64_t> shape, PyType &elementType, PyLocation &loc) { MlirType t = mlirRankedTensorTypeGetChecked( @@ -1286,7 +1446,7 @@ static void bindDerived(ClassTy &c) { c.def_static( - "get_unranked_tensor", + "get", // TODO: Make the location optional and create a default location. [](PyType &elementType, PyLocation &loc) { MlirType t = @@ -1366,7 +1526,7 @@ static void bindDerived(ClassTy &c) { c.def_static( - "get_unranked_memref", + "get", // TODO: Make the location optional and create a default location. [](PyType &elementType, unsigned memorySpace, PyLocation &loc) { MlirType t = mlirUnrankedMemRefTypeGetChecked(elementType.type, @@ -1719,6 +1879,11 @@ "context", [](PyAttribute &self) { return self.getContext().getObject(); }, "Context that owns the Attribute") + .def_property_readonly("type", + [](PyAttribute &self) { + return PyType(self.getContext()->getRef(), + mlirAttributeGetType(self.attr)); + }) .def( "get_named", [](PyAttribute &self, std::string name) { @@ -1796,6 +1961,7 @@ PyIntegerAttribute::bind(m); PyBoolAttribute::bind(m); PyStringAttribute::bind(m); + PyDenseElementsAttribute::bind(m); // Mapping of Type. py::class_<PyType>(m, "Type") diff --git a/mlir/test/Bindings/Python/ir_array_attributes.py b/mlir/test/Bindings/Python/ir_array_attributes.py new file mode 100644 --- /dev/null +++ b/mlir/test/Bindings/Python/ir_array_attributes.py @@ -0,0 +1,213 @@ +# RUN: %PYTHON %s | FileCheck %s +# Note that this is separate from ir_attributes.py since it depends on numpy, +# and we may want to disable if not available.
+ +import gc +import mlir +import numpy as np + +def run(f):  # prints the "TEST:" label, runs f, then checks no Context leaked + print("\nTEST:", f.__name__) + f() + gc.collect() + assert mlir.ir.Context._get_live_count() == 0 + +################################################################################ +# Tests of the array/buffer .get() factory method on unsupported dtype. +################################################################################ + +def testGetDenseElementsUnsupported(): + ctx = mlir.ir.Context() + array = np.array([["hello", "goodbye"]])  # string dtype; get() only bulk loads numeric formats + try: + attr = mlir.ir.DenseElementsAttr.get(ctx, array) + except ValueError as e: + # CHECK: unimplemented array format conversion from format: + print(e) + +run(testGetDenseElementsUnsupported) + +################################################################################ +# Splats. +################################################################################ + +# CHECK-LABEL: TEST: testGetDenseElementsSplatInt +def testGetDenseElementsSplatInt(): + ctx = mlir.ir.Context() + loc = ctx.get_unknown_location() + t = mlir.ir.IntegerType.get_signless(ctx, 32) + element = mlir.ir.IntegerAttr.get(t, 555) + shaped_type = mlir.ir.RankedTensorType.get((2, 3, 4), t, loc) + attr = mlir.ir.DenseElementsAttr.get_splat(shaped_type, element) + # CHECK: dense<555> : tensor<2x3x4xi32> + print(attr) + # CHECK: is_splat: True + print("is_splat:", attr.is_splat) + +run(testGetDenseElementsSplatInt) + + +# CHECK-LABEL: TEST: testGetDenseElementsSplatFloat +def testGetDenseElementsSplatFloat(): + ctx = mlir.ir.Context() + loc = ctx.get_unknown_location() + t = mlir.ir.F32Type.get(ctx) + element = mlir.ir.FloatAttr.get(t, 1.2, loc) + shaped_type = mlir.ir.RankedTensorType.get((2, 3, 4), t, loc) + attr = mlir.ir.DenseElementsAttr.get_splat(shaped_type, element) + # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32> + print(attr) + +run(testGetDenseElementsSplatFloat) + + +# CHECK-LABEL: TEST: testGetDenseElementsSplatErrors +def testGetDenseElementsSplatErrors(): + ctx =
mlir.ir.Context() + loc = ctx.get_unknown_location() + t = mlir.ir.F32Type.get(ctx) + other_t = mlir.ir.F64Type.get(ctx) + element = mlir.ir.FloatAttr.get(t, 1.2, loc) + other_element = mlir.ir.FloatAttr.get(other_t, 1.2, loc)  # f64 element, mismatching the f32 shaped_type + shaped_type = mlir.ir.RankedTensorType.get((2, 3, 4), t, loc) + dynamic_shaped_type = mlir.ir.UnrankedTensorType.get(t, loc)  # unranked, so it has no static shape + non_shaped_type = t  # a scalar type, not a ShapedType + + try: + attr = mlir.ir.DenseElementsAttr.get_splat(non_shaped_type, element) + except ValueError as e: + # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(f32) + print(e) + + try: + attr = mlir.ir.DenseElementsAttr.get_splat(dynamic_shaped_type, element) + except ValueError as e: + # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(tensor<*xf32>) + print(e) + + try: + attr = mlir.ir.DenseElementsAttr.get_splat(shaped_type, other_element) + except ValueError as e: + # CHECK: Shaped element type and attribute type must be equal: shaped=Type(tensor<2x3x4xf32>), element=Attribute(1.200000e+00 : f64) + print(e) + +run(testGetDenseElementsSplatErrors) + + +################################################################################ +# Tests of the array/buffer .get() factory method, in all of its permutations. +################################################################################ + +### float and double arrays.
+ +# CHECK-LABEL: TEST: testGetDenseElementsF32 +def testGetDenseElementsF32(): + ctx = mlir.ir.Context() + array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32) + attr = mlir.ir.DenseElementsAttr.get(ctx, array) + # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf32> + print(attr) + # CHECK: is_splat: False + print("is_splat:", attr.is_splat) + +run(testGetDenseElementsF32) + + +# CHECK-LABEL: TEST: testGetDenseElementsF64 +def testGetDenseElementsF64(): + ctx = mlir.ir.Context() + array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float64) + attr = mlir.ir.DenseElementsAttr.get(ctx, array) + # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64> + print(attr) + +run(testGetDenseElementsF64) + + +### 32 bit integer arrays +# CHECK-LABEL: TEST: testGetDenseElementsI32Signless +def testGetDenseElementsI32Signless(): + ctx = mlir.ir.Context() + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) + attr = mlir.ir.DenseElementsAttr.get(ctx, array) + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> + print(attr) + +run(testGetDenseElementsI32Signless) + + +# CHECK-LABEL: TEST: testGetDenseElementsUI32Signless +def testGetDenseElementsUI32Signless(): + ctx = mlir.ir.Context() + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32) + attr = mlir.ir.DenseElementsAttr.get(ctx, array)  # signless defaults to True, so uint32 also maps to i32 + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> + print(attr) + +run(testGetDenseElementsUI32Signless) + +# CHECK-LABEL: TEST: testGetDenseElementsI32 +def testGetDenseElementsI32(): + ctx = mlir.ir.Context() + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) + attr = mlir.ir.DenseElementsAttr.get(ctx, array, signless=False)  # request a signed (si32) element type + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32> + print(attr) + +run(testGetDenseElementsI32) + + +# CHECK-LABEL: TEST:
testGetDenseElementsUI32 +def testGetDenseElementsUI32(): + ctx = mlir.ir.Context() + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32) + attr = mlir.ir.DenseElementsAttr.get(ctx, array, signless=False)  # request an unsigned (ui32) element type + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32> + print(attr) + +run(testGetDenseElementsUI32) + + +### 64 bit integer arrays +# CHECK-LABEL: TEST: testGetDenseElementsI64Signless +def testGetDenseElementsI64Signless(): + ctx = mlir.ir.Context() + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64) + attr = mlir.ir.DenseElementsAttr.get(ctx, array) + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64> + print(attr) + +run(testGetDenseElementsI64Signless) + + +# CHECK-LABEL: TEST: testGetDenseElementsUI64Signless +def testGetDenseElementsUI64Signless(): + ctx = mlir.ir.Context() + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64) + attr = mlir.ir.DenseElementsAttr.get(ctx, array)  # signless defaults to True, so uint64 also maps to i64 + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64> + print(attr) + +run(testGetDenseElementsUI64Signless) + +# CHECK-LABEL: TEST: testGetDenseElementsI64 +def testGetDenseElementsI64(): + ctx = mlir.ir.Context() + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64) + attr = mlir.ir.DenseElementsAttr.get(ctx, array, signless=False) + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64> + print(attr) + +run(testGetDenseElementsI64) + + +# CHECK-LABEL: TEST: testGetDenseElementsUI64 +def testGetDenseElementsUI64(): + ctx = mlir.ir.Context() + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64) + attr = mlir.ir.DenseElementsAttr.get(ctx, array, signless=False) + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64> + print(attr) + +run(testGetDenseElementsUI64) + diff --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py --- a/mlir/test/Bindings/Python/ir_attributes.py +++ b/mlir/test/Bindings/Python/ir_attributes.py @@ -104,7 +104,7 @@ loc =
ctx.get_unknown_location() # CHECK: default_get: 4.200000e+01 : f32 print("default_get:", mlir.ir.FloatAttr.get( - mlir.ir.F32Type(ctx), 42.0, loc)) + mlir.ir.F32Type.get(ctx), 42.0, loc)) # CHECK: f32_get: 4.200000e+01 : f32 print("f32_get:", mlir.ir.FloatAttr.get_f32(ctx, 42.0)) # CHECK: f64_get: 4.200000e+01 : f64 @@ -127,6 +127,8 @@ iattr = mlir.ir.IntegerAttr(ctx.parse_attr("42")) # CHECK: iattr value: 42 print("iattr value:", iattr.value) + # CHECK: iattr type: i64 + print("iattr type:", iattr.type) # Test factory methods. # CHECK: default_get: 42 : i32 diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py --- a/mlir/test/Bindings/Python/ir_types.py +++ b/mlir/test/Bindings/Python/ir_types.py @@ -135,7 +135,7 @@ def testIndexType(): ctx = mlir.ir.Context() # CHECK: index type: index - print("index type:", mlir.ir.IndexType(ctx)) + print("index type:", mlir.ir.IndexType.get(ctx)) run(testIndexType) @@ -143,13 +143,13 @@ def testFloatType(): ctx = mlir.ir.Context() # CHECK: float: bf16 - print("float:", mlir.ir.BF16Type(ctx)) + print("float:", mlir.ir.BF16Type.get(ctx)) # CHECK: float: f16 - print("float:", mlir.ir.F16Type(ctx)) + print("float:", mlir.ir.F16Type.get(ctx)) # CHECK: float: f32 - print("float:", mlir.ir.F32Type(ctx)) + print("float:", mlir.ir.F32Type.get(ctx)) # CHECK: float: f64 - print("float:", mlir.ir.F64Type(ctx)) + print("float:", mlir.ir.F64Type.get(ctx)) run(testFloatType) @@ -157,7 +157,7 @@ def testNoneType(): ctx = mlir.ir.Context() # CHECK: none type: none - print("none type:", mlir.ir.NoneType(ctx)) + print("none type:", mlir.ir.NoneType.get(ctx)) run(testNoneType) @@ -168,13 +168,13 @@ # CHECK: complex type element: i32 print("complex type element:", complex_i32.element_type) - f32 = mlir.ir.F32Type(ctx) + f32 = mlir.ir.F32Type.get(ctx) # CHECK: complex type: complex<f32> - print("complex type:", mlir.ir.ComplexType.get_complex(f32)) + print("complex type:", mlir.ir.ComplexType.get(f32)) - index =
mlir.ir.IndexType(ctx) + index = mlir.ir.IndexType.get(ctx) try: - complex_invalid = mlir.ir.ComplexType.get_complex(index) + complex_invalid = mlir.ir.ComplexType.get(index) except ValueError as e: # CHECK: invalid 'Type(index)' and expected floating point or integer type. print(e) @@ -225,15 +225,15 @@ # CHECK-LABEL: TEST: testVectorType def testVectorType(): ctx = mlir.ir.Context() - f32 = mlir.ir.F32Type(ctx) + f32 = mlir.ir.F32Type.get(ctx) shape = [2, 3] loc = ctx.get_unknown_location() # CHECK: vector type: vector<2x3xf32> - print("vector type:", mlir.ir.VectorType.get_vector(shape, f32, loc)) + print("vector type:", mlir.ir.VectorType.get(shape, f32, loc)) - none = mlir.ir.NoneType(ctx) + none = mlir.ir.NoneType.get(ctx) try: - vector_invalid = mlir.ir.VectorType.get_vector(shape, none, loc) + vector_invalid = mlir.ir.VectorType.get(shape, none, loc) except ValueError as e: # CHECK: invalid 'Type(none)' and expected floating point or integer type. print(e) @@ -245,17 +245,16 @@ # CHECK-LABEL: TEST: testRankedTensorType def testRankedTensorType(): ctx = mlir.ir.Context() - f32 = mlir.ir.F32Type(ctx) + f32 = mlir.ir.F32Type.get(ctx) shape = [2, 3] loc = ctx.get_unknown_location() # CHECK: ranked tensor type: tensor<2x3xf32> print("ranked tensor type:", - mlir.ir.RankedTensorType.get_ranked_tensor(shape, f32, loc)) + mlir.ir.RankedTensorType.get(shape, f32, loc)) - none = mlir.ir.NoneType(ctx) + none = mlir.ir.NoneType.get(ctx) try: - tensor_invalid = mlir.ir.RankedTensorType.get_ranked_tensor(shape, none, - loc) + tensor_invalid = mlir.ir.RankedTensorType.get(shape, none, loc) except ValueError as e: # CHECK: invalid 'Type(none)' and expected floating point, integer, vector # CHECK: or complex type.
@@ -268,9 +267,9 @@ # CHECK-LABEL: TEST: testUnrankedTensorType def testUnrankedTensorType(): ctx = mlir.ir.Context() - f32 = mlir.ir.F32Type(ctx) + f32 = mlir.ir.F32Type.get(ctx) loc = ctx.get_unknown_location() - unranked_tensor = mlir.ir.UnrankedTensorType.get_unranked_tensor(f32, loc) + unranked_tensor = mlir.ir.UnrankedTensorType.get(f32, loc) # CHECK: unranked tensor type: tensor<*xf32> print("unranked tensor type:", unranked_tensor) try: @@ -295,9 +294,9 @@ else: print("Exception not produced") - none = mlir.ir.NoneType(ctx) + none = mlir.ir.NoneType.get(ctx)  # invalid tensor element type; rejected below try: - tensor_invalid = mlir.ir.UnrankedTensorType.get_unranked_tensor(none, loc) + tensor_invalid = mlir.ir.UnrankedTensorType.get(none, loc) except ValueError as e: # CHECK: invalid 'Type(none)' and expected floating point, integer, vector # CHECK: or complex type. @@ -310,7 +309,7 @@ # CHECK-LABEL: TEST: testMemRefType def testMemRefType(): ctx = mlir.ir.Context() - f32 = mlir.ir.F32Type(ctx) + f32 = mlir.ir.F32Type.get(ctx) shape = [2, 3] loc = ctx.get_unknown_location() memref = mlir.ir.MemRefType.get_contiguous_memref(f32, shape, 2, loc) @@ -321,7 +320,7 @@ # CHECK: memory space: 2 print("memory space:", memref.memory_space) - none = mlir.ir.NoneType(ctx) + none = mlir.ir.NoneType.get(ctx) try: memref_invalid = mlir.ir.MemRefType.get_contiguous_memref(none, shape, 2, loc) @@ -337,9 +336,9 @@ # CHECK-LABEL: TEST: testUnrankedMemRefType def testUnrankedMemRefType(): ctx = mlir.ir.Context() - f32 = mlir.ir.F32Type(ctx) + f32 = mlir.ir.F32Type.get(ctx) loc = ctx.get_unknown_location() - unranked_memref = mlir.ir.UnrankedMemRefType.get_unranked_memref(f32, 2, loc) + unranked_memref = mlir.ir.UnrankedMemRefType.get(f32, 2, loc) # CHECK: unranked memref type: memref<*xf32, 2> print("unranked memref type:", unranked_memref) try: @@ -364,10 +363,9 @@ else: print("Exception not produced") - none = mlir.ir.NoneType(ctx) + none = mlir.ir.NoneType.get(ctx)  # invalid memref element type; rejected below try: - memref_invalid =
mlir.ir.UnrankedMemRefType.get_unranked_memref(none, 2, - loc) + memref_invalid = mlir.ir.UnrankedMemRefType.get(none, 2, loc) except ValueError as e: # CHECK: invalid 'Type(none)' and expected floating point, integer, vector # CHECK: or complex type. @@ -381,7 +379,7 @@ def testTupleType(): ctx = mlir.ir.Context() i32 = mlir.ir.IntegerType(ctx.parse_type("i32")) - f32 = mlir.ir.F32Type(ctx) + f32 = mlir.ir.F32Type.get(ctx) vector = mlir.ir.VectorType(ctx.parse_type("vector<2x3xf32>")) l = [i32, f32, vector] tuple_type = mlir.ir.TupleType.get_tuple(ctx, l) @@ -400,7 +398,7 @@ ctx = mlir.ir.Context() input_types = [mlir.ir.IntegerType.get_signless(ctx, 32), mlir.ir.IntegerType.get_signless(ctx, 16)] - result_types = [mlir.ir.IndexType(ctx)] + result_types = [mlir.ir.IndexType.get(ctx)] func = mlir.ir.FunctionType.get(ctx, input_types, result_types) # CHECK: INPUTS: [Type(i32), Type(i16)] print("INPUTS:", func.inputs)