diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -306,6 +306,23 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrGet( MlirType shapedType, intptr_t numElements, MlirAttribute const *elements); +/// Creates a dense elements attribute with the given Shaped type and elements +/// populated from a packed, row-major opaque buffer of contents. +/// +/// The format of the raw buffer is a densely packed array of values that +/// can be bitcast to the storage format of the element type specified. +/// Types that are not byte aligned will be: +/// - For bitwidth > 1: Rounded up to the next byte. +/// - For bitwidth = 1: Packed into 8bit bytes with bits corresponding to +/// the linear order of the shape type from LSB to MSB, with the final +/// byte padded out to a full 8 bits. +/// +/// A raw buffer of a single element (or for 1-bit, a byte of value 0 or 255) +/// will be interpreted as a splat. User code should be prepared for additional, +/// conformant patterns to be identified as splats in the future. +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrRawBufferGet( + MlirType shapedType, size_t rawBufferSize, const void *rawBuffer); + /// Creates a dense elements attribute with the given Shaped type containing a /// single replicated element (splat). MLIR_CAPI_EXPORTED MlirAttribute diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -183,15 +183,35 @@ } /// Construct a dense elements attribute from a raw buffer representing the - /// data for this attribute. Users should generally not use this methods as - /// the expected buffer format may not be a form the user expects. + /// data for this attribute. Users are encouraged to use one of the + /// constructors above, which provide more safety.
 However, this + /// constructor is useful for tools which may want to interop and can + /// follow the precise definition. + /// + /// The format of the raw buffer is a densely packed array of values that + /// can be bitcast to the storage format of the element type specified. + /// Types that are not byte aligned will be: + /// - For bitwidth > 1: Rounded up to the next byte. + /// - For bitwidth = 1: Packed into 8bit bytes with bits corresponding to + /// the linear order of the shape type from LSB to MSB, with the final + /// byte padded out to a full 8 bits. + /// + /// If `isSplatBuffer` is true, then the raw buffer should contain a + /// single element (or for the case of 1-bit, a single byte of 0 or 255), + /// which will be used to construct a splat. static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef rawBuffer, bool isSplatBuffer); /// Returns true if the given buffer is a valid raw buffer for the given type. /// `detectedSplat` is set if the buffer is valid and represents a splat - /// buffer. + /// buffer. The definition may be expanded over time, but currently, a + /// splat buffer is detected if: + /// - For >1bit: The buffer consists of a single element. + /// - For 1bit: The buffer consists of a single byte with value 0 or 255. + /// + /// User code should be prepared for additional, conformant patterns to be + /// identified as splats in the future. static bool isValidRawBuffer(ShapedType type, ArrayRef rawBuffer, bool &detectedSplat); diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -17,9 +17,57 @@ using namespace mlir; using namespace mlir::python; +using llvm::None; +using llvm::Optional; using llvm::SmallVector; using llvm::Twine; +//------------------------------------------------------------------------------ +// Docstrings (trivial, non-duplicated docstrings are included inline).
+//------------------------------------------------------------------------------ + +static const char kDenseElementsAttrGetDocstring[] = + R"(Gets a DenseElementsAttr from a Python buffer or array. + +When `type` is not provided, then some limited type inferencing is done based +on the buffer format. Support presently exists for 8/16/32/64 signed and +unsigned integers and float16/float32/float64. DenseElementsAttrs of these +types (except splats) can also be converted back to a corresponding buffer. + +For conversions outside of these types, a `type=` must be explicitly provided +and the buffer contents must be bit-castable to the MLIR internal +representation: + + * Integer types (except for i1): the buffer must be byte aligned to the + next byte boundary. + * Floating point types: Must be bit-castable to the given floating point + size. + * i1 (bool): Bit packed into 8bit bytes where the bit pattern matches a + row major ordering. An arbitrary Numpy `bool_` array can be bit packed to + this specification with: `np.packbits(ary, axis=None, bitorder='little')`. + +If a single element buffer is passed (or for i1, a single byte with value 0 +or 255), then a splat will be created. + +Args: + array: The array or buffer to convert. + signless: If inferring an appropriate MLIR type, use signless types for + integers (defaults True). + type: Skips inference of the MLIR element type and uses this instead. The + storage size must be consistent with the actual contents of the buffer. + shape: Overrides the shape of the buffer when constructing the MLIR + shaped type. This is needed when the physical and logical shape differ (as + for i1). + context: Explicit context, if not from context manager. + +Returns: + DenseElementsAttr on success. + +Raises: + ValueError: If the type of the buffer or array cannot be matched to an MLIR + type or if the buffer does not meet expectations.
+)"; + namespace { static MlirStringRef toMlirStringRef(const std::string &s) { @@ -301,7 +349,6 @@ } }; -// TODO: Support construction of bool elements. // TODO: Support construction of string elements. class PyDenseElementsAttribute : public PyConcreteAttribute { @@ -311,7 +358,8 @@ using PyConcreteAttribute::PyConcreteAttribute; static PyDenseElementsAttribute - getFromBuffer(py::buffer array, bool signless, + getFromBuffer(py::buffer array, bool signless, Optional explicitType, + Optional> explicitShape, DefaultingPyMlirContext contextWrapper) { // Request a contiguous view. In exotic cases, this will cause a copy. int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; @@ -321,69 +369,95 @@ throw py::error_already_set(); } py::buffer_info arrayInfo(view); + SmallVector shape; + if (explicitShape) { + shape.append(explicitShape->begin(), explicitShape->end()); + } else { + shape.append(arrayInfo.shape.begin(), + arrayInfo.shape.begin() + arrayInfo.ndim); + } + MlirAttribute encodingAttr = mlirAttributeGetNull(); MlirContext context = contextWrapper->get(); - // Switch on the types that can be bulk loaded between the Python and - // MLIR-C APIs. - // See: https://docs.python.org/3/library/struct.html#format-characters - if (arrayInfo.format == "f") { + + // Detect format codes that are suitable for bulk loading. This includes + // all byte aligned integer and floating point types up to 8 bytes. + // Notably, this excludes bool (which needs to be bit-packed) and + // other exotics which do not have a direct representation in the buffer + // protocol (e.g. complex). An explicit type bypasses this detection.
+ Optional bulkLoadElementType; + if (explicitType) { + bulkLoadElementType = *explicitType; + } else if (arrayInfo.format == "f") { // f32 assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); - return PyDenseElementsAttribute( - contextWrapper->getRef(), - bulkLoad(context, mlirDenseElementsAttrFloatGet, - mlirF32TypeGet(context), arrayInfo)); + bulkLoadElementType = mlirF32TypeGet(context); } else if (arrayInfo.format == "d") { // f64 assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); - return PyDenseElementsAttribute( - contextWrapper->getRef(), - bulkLoad(context, mlirDenseElementsAttrDoubleGet, - mlirF64TypeGet(context), arrayInfo)); + bulkLoadElementType = mlirF64TypeGet(context); + } else if (arrayInfo.format == "e") { + // f16 + assert(arrayInfo.itemsize == 2 && "mismatched array itemsize"); + bulkLoadElementType = mlirF16TypeGet(context); } else if (isSignedIntegerFormat(arrayInfo.format)) { if (arrayInfo.itemsize == 4) { // i32 - MlirType elementType = signless ? mlirIntegerTypeGet(context, 32) - : mlirIntegerTypeSignedGet(context, 32); - return PyDenseElementsAttribute(contextWrapper->getRef(), - bulkLoad(context, - mlirDenseElementsAttrInt32Get, - elementType, arrayInfo)); + bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32) + : mlirIntegerTypeSignedGet(context, 32); } else if (arrayInfo.itemsize == 8) { // i64 - MlirType elementType = signless ? mlirIntegerTypeGet(context, 64) - : mlirIntegerTypeSignedGet(context, 64); - return PyDenseElementsAttribute(contextWrapper->getRef(), - bulkLoad(context, - mlirDenseElementsAttrInt64Get, - elementType, arrayInfo)); + bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64) + : mlirIntegerTypeSignedGet(context, 64); + } else if (arrayInfo.itemsize == 1) { + // i8 + bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) + : mlirIntegerTypeSignedGet(context, 8); + } else if (arrayInfo.itemsize == 2) { + // i16 + bulkLoadElementType = signless ?
 mlirIntegerTypeGet(context, 16) + : mlirIntegerTypeSignedGet(context, 16); } } else if (isUnsignedIntegerFormat(arrayInfo.format)) { if (arrayInfo.itemsize == 4) { // unsigned i32 - MlirType elementType = signless - ? mlirIntegerTypeGet(context, 32) - : mlirIntegerTypeUnsignedGet(context, 32); - return PyDenseElementsAttribute(contextWrapper->getRef(), - bulkLoad(context, - mlirDenseElementsAttrUInt32Get, - elementType, arrayInfo)); + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 32) + : mlirIntegerTypeUnsignedGet(context, 32); } else if (arrayInfo.itemsize == 8) { // unsigned i64 - MlirType elementType = signless - ? mlirIntegerTypeGet(context, 64) - : mlirIntegerTypeUnsignedGet(context, 64); - return PyDenseElementsAttribute(contextWrapper->getRef(), - bulkLoad(context, - mlirDenseElementsAttrUInt64Get, - elementType, arrayInfo)); + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 64) + : mlirIntegerTypeUnsignedGet(context, 64); + } else if (arrayInfo.itemsize == 1) { + // unsigned i8 + bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) + : mlirIntegerTypeUnsignedGet(context, 8); + } else if (arrayInfo.itemsize == 2) { + // unsigned i16 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 16) + : mlirIntegerTypeUnsignedGet(context, 16); } } + if (bulkLoadElementType) { + auto shapedType = mlirRankedTensorTypeGet( + shape.size(), shape.data(), *bulkLoadElementType, encodingAttr); + size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize; + MlirAttribute attr = mlirDenseElementsAttrRawBufferGet( + shapedType, rawBufferSize, arrayInfo.ptr); + if (mlirAttributeIsNull(attr)) { + throw std::invalid_argument( + "DenseElementsAttr could not be constructed from the given buffer. " + "This may mean that the Python buffer layout does not match the " + "layout MLIR expects for the requested type and shape."); + } + return PyDenseElementsAttribute(contextWrapper->getRef(), attr); + }
- std::string message = "unimplemented array format conversion from format: "; - message.append(arrayInfo.format); - throw SetPyError(PyExc_ValueError, message); + throw std::invalid_argument( + std::string("unimplemented array format conversion from format: ") + + arrayInfo.format); } static PyDenseElementsAttribute getSplat(PyType shapedType, @@ -422,47 +496,82 @@ intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } py::buffer_info accessBuffer() { + if (mlirDenseElementsAttrIsSplat(*this)) { + // TODO: Raise an exception. + // Reported as https://github.com/pybind/pybind11/issues/3336 + return py::buffer_info(); + } + MlirType shapedType = mlirAttributeGetType(*this); MlirType elementType = mlirShapedTypeGetElementType(shapedType); + std::string format; if (mlirTypeIsAF32(elementType)) { // f32 - return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue); + return bufferInfo(shapedType); } else if (mlirTypeIsAF64(elementType)) { // f64 - return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue); + return bufferInfo(shapedType); + } else if (mlirTypeIsAF16(elementType)) { + // f16 + return bufferInfo(shapedType, "e"); } else if (mlirTypeIsAInteger(elementType) && mlirIntegerTypeGetWidth(elementType) == 32) { if (mlirIntegerTypeIsSignless(elementType) || mlirIntegerTypeIsSigned(elementType)) { // i32 - return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value); + return bufferInfo(shapedType); } else if (mlirIntegerTypeIsUnsigned(elementType)) { // unsigned i32 - return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value); + return bufferInfo(shapedType); } } else if (mlirTypeIsAInteger(elementType) && mlirIntegerTypeGetWidth(elementType) == 64) { if (mlirIntegerTypeIsSignless(elementType) || mlirIntegerTypeIsSigned(elementType)) { // i64 - return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value); + return bufferInfo(shapedType); } else if (mlirIntegerTypeIsUnsigned(elementType)) { // unsigned i64 - return 
bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value); + return bufferInfo(shapedType); + } + } else if (mlirTypeIsAInteger(elementType) && + mlirIntegerTypeGetWidth(elementType) == 8) { + if (mlirIntegerTypeIsSignless(elementType) || + mlirIntegerTypeIsSigned(elementType)) { + // i8 + return bufferInfo(shapedType); + } else if (mlirIntegerTypeIsUnsigned(elementType)) { + // unsigned i8 + return bufferInfo(shapedType); + } + } else if (mlirTypeIsAInteger(elementType) && + mlirIntegerTypeGetWidth(elementType) == 16) { + if (mlirIntegerTypeIsSignless(elementType) || + mlirIntegerTypeIsSigned(elementType)) { + // i16 + return bufferInfo(shapedType); + } else if (mlirIntegerTypeIsUnsigned(elementType)) { + // unsigned i16 + return bufferInfo(shapedType); } } - std::string message = "unimplemented array format."; - throw SetPyError(PyExc_ValueError, message); + // TODO: Currently crashes the program. Just returning an empty buffer + // for now. + // Reported as https://github.com/pybind/pybind11/issues/3336 + // throw std::invalid_argument( + // "unsupported data type for conversion to Python buffer"); + return py::buffer_info(); } static void bindDerived(ClassTy &c) { c.def("__len__", &PyDenseElementsAttribute::dunderLen) .def_static("get", PyDenseElementsAttribute::getFromBuffer, py::arg("array"), py::arg("signless") = true, + py::arg("type") = py::none(), py::arg("shape") = py::none(), py::arg("context") = py::none(), - "Gets from a buffer or ndarray") + kDenseElementsAttrGetDocstring) .def_static("get_splat", PyDenseElementsAttribute::getSplat, py::arg("shaped_type"), py::arg("element_attr"), "Gets a DenseElementsAttr where all values are the same") @@ -474,21 +583,6 @@ } private: - template - static MlirAttribute - bulkLoad(MlirContext context, - MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *), - MlirType mlirElementType, py::buffer_info &arrayInfo) { - SmallVector shape(arrayInfo.shape.begin(), - arrayInfo.shape.begin() + arrayInfo.ndim); - 
MlirAttribute encodingAttr = mlirAttributeGetNull(); - auto shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(), - mlirElementType, encodingAttr); - intptr_t numElements = arrayInfo.size; - const ElementTy *contents = static_cast(arrayInfo.ptr); - return ctor(shapedType, numElements, contents); - } - static bool isUnsignedIntegerFormat(const std::string &format) { if (format.empty()) return false; @@ -507,7 +601,7 @@ template py::buffer_info bufferInfo(MlirType shapedType, - Type (*value)(MlirAttribute, intptr_t)) { + const char *explicitFormat = nullptr) { intptr_t rank = mlirShapedTypeGetRank(shapedType); // Prepare the data for the buffer_info. // Buffer is configured for read-only access below. @@ -528,9 +622,14 @@ strides.push_back(sizeof(Type) * strideFactor); } strides.push_back(sizeof(Type)); - return py::buffer_info(data, sizeof(Type), - py::format_descriptor::format(), rank, shape, - strides, /*readonly=*/true); + std::string format; + if (explicitFormat) { + format = explicitFormat; + } else { + format = py::format_descriptor::format(); + } + return py::buffer_info(data, sizeof(Type), format, rank, shape, strides, + /*readonly=*/true); } }; // namespace diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -331,6 +331,21 @@ unwrapList(numElements, elements, attributes))); } +MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType, + size_t rawBufferSize, + const void *rawBuffer) { + auto shapedTypeCpp = unwrap(shapedType).cast(); + ArrayRef rawBufferCpp(static_cast(rawBuffer), + rawBufferSize); + bool isSplat = false; + if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp, + isSplat)) { + return mlirAttributeGetNull(); + } + return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp, + isSplat)); +} + MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType, 
MlirAttribute element) { return wrap(DenseElementsAttr::get(unwrap(shapedType).cast(), diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -792,9 +792,16 @@ // Storage width of 1 is special as it is packed by the bit. if (storageWidth == 1) { - // Check for a splat, or a buffer equal to the number of elements. - if ((detectedSplat = rawBuffer.size() == 1)) - return true; + // Check for a splat, or a buffer equal to the number of elements which + // consists of either all 0's or all 1's. + detectedSplat = false; + if (rawBuffer.size() == 1) { + auto rawByte = static_cast(rawBuffer[0]); + if (rawByte == 0 || rawByte == 0xff) { + detectedSplat = true; + return true; + } + } return rawBufferWidth == llvm::alignTo<8>(type.getNumElements()); } // All other types are 8-bit aligned. diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py --- a/mlir/test/python/ir/array_attributes.py +++ b/mlir/test/python/ir/array_attributes.py @@ -11,11 +11,13 @@ f() gc.collect() assert Context._get_live_count() == 0 + return f ################################################################################ # Tests of the array/buffer .get() factory method on unsupported dtype. ################################################################################ +@run def testGetDenseElementsUnsupported(): with Context(): array = np.array([["hello", "goodbye"]]) @@ -25,13 +27,12 @@ # CHECK: unimplemented array format conversion from format: print(e) -run(testGetDenseElementsUnsupported) - ################################################################################ # Splats. 
################################################################################ # CHECK-LABEL: TEST: testGetDenseElementsSplatInt +@run def testGetDenseElementsSplatInt(): with Context(), Location.unknown(): t = IntegerType.get_signless(32) @@ -43,10 +44,9 @@ # CHECK: is_splat: True print("is_splat:", attr.is_splat) -run(testGetDenseElementsSplatInt) - # CHECK-LABEL: TEST: testGetDenseElementsSplatFloat +@run def testGetDenseElementsSplatFloat(): with Context(), Location.unknown(): t = F32Type.get() @@ -56,10 +56,9 @@ # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32> print(attr) -run(testGetDenseElementsSplatFloat) - # CHECK-LABEL: TEST: testGetDenseElementsSplatErrors +@run def testGetDenseElementsSplatErrors(): with Context(), Location.unknown(): t = F32Type.get() @@ -88,32 +87,113 @@ # CHECK: Shaped element type and attribute type must be equal: shaped=Type(tensor<2x3x4xf32>), element=Attribute(1.200000e+00 : f64) print(e) -run(testGetDenseElementsSplatErrors) + +# CHECK-LABEL: TEST: testRepeatedValuesSplat +@run +def testRepeatedValuesSplat(): + with Context(): + array = np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], dtype=np.float32) + attr = DenseElementsAttr.get(array) + # CHECK: dense<1.000000e+00> : tensor<2x3xf32> + print(attr) + # CHECK: is_splat: True + print("is_splat:", attr.is_splat) + # CHECK: () + print(np.array(attr)) # Splats are not exported as buffers yet (empty buffer). + + +# CHECK-LABEL: TEST: testNonSplat +@run +def testNonSplat(): + with Context(): + array = np.array([2.0, 1.0, 1.0], dtype=np.float32) + attr = DenseElementsAttr.get(array) + # CHECK: is_splat: False + print("is_splat:", attr.is_splat) ################################################################################ # Tests of the array/buffer .get() factory method, in all of its permutations.
################################################################################ +### explicitly provided types + +@run +def testGetDenseElementsBF16(): + with Context(): + array = np.array([[2, 4, 8], [16, 32, 64]], dtype=np.uint16) + attr = DenseElementsAttr.get(array, type=BF16Type.get()) + # Note: These values don't mean much since just bit-casting. But they + # shouldn't change. + # CHECK: dense<{{\[}}[1.836710e-40, 3.673420e-40, 7.346840e-40], [1.469370e-39, 2.938740e-39, 5.877470e-39]]> : tensor<2x3xbf16> + print(attr) + +@run +def testGetDenseElementsInteger4(): + with Context(): + array = np.array([[2, 4, 7], [-2, -4, -8]], dtype=np.int8) + attr = DenseElementsAttr.get(array, type=IntegerType.get_signless(4)) + # Note: These values don't mean much since just bit-casting. But they + # shouldn't change. + # CHECK: dense<{{\[}}[2, 4, 7], [-2, -4, -8]]> : tensor<2x3xi4> + print(attr) + + +@run +def testGetDenseElementsBool(): + with Context(): + bool_array = np.array([[1, 0, 1], [0, 1, 0]], dtype=np.bool_) + array = np.packbits(bool_array, axis=None, bitorder="little") + attr = DenseElementsAttr.get( + array, type=IntegerType.get_signless(1), shape=bool_array.shape) + # CHECK: dense<{{\[}}[true, false, true], [false, true, false]]> : tensor<2x3xi1> + print(attr) + + +@run +def testGetDenseElementsBoolSplat(): + with Context(): + zero = np.array(0, dtype=np.uint8) + one = np.array(255, dtype=np.uint8) + # A single 0x00 / 0xff byte is the all-false / all-true i1 splat pattern. + # CHECK: dense : tensor<4x2x5xi1> + print(DenseElementsAttr.get( + zero, type=IntegerType.get_signless(1), shape=(4, 2, 5))) + # CHECK: dense : tensor<4x2x5xi1> + print(DenseElementsAttr.get( + one, type=IntegerType.get_signless(1), shape=(4, 2, 5))) + + ### float and double arrays.
+# CHECK-LABEL: TEST: testGetDenseElementsF16 +@run +def testGetDenseElementsF16(): + with Context(): + array = np.array([[2.0, 4.0, 8.0], [16.0, 32.0, 64.0]], dtype=np.float16) + attr = DenseElementsAttr.get(array) # f16 is inferred from the 'e' format. + # CHECK: dense<{{\[}}[2.000000e+00, 4.000000e+00, 8.000000e+00], [1.600000e+01, 3.200000e+01, 6.400000e+01]]> : tensor<2x3xf16> + print(attr) + # CHECK: {{\[}}[ 2. 4. 8.] + # CHECK: {{\[}}16. 32. 64.]] + print(np.array(attr)) + + # CHECK-LABEL: TEST: testGetDenseElementsF32 +@run def testGetDenseElementsF32(): with Context(): array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32) attr = DenseElementsAttr.get(array) # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf32> print(attr) - # CHECK: is_splat: False - print("is_splat:", attr.is_splat) # CHECK: {{\[}}[1.1 2.2 3.3] # CHECK: {{\[}}4.4 5.5 6.6]] print(np.array(attr)) -run(testGetDenseElementsF32) - # CHECK-LABEL: TEST: testGetDenseElementsF64 +@run def testGetDenseElementsF64(): with Context(): array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float64) @@ -124,11 +204,62 @@ # CHECK: {{\[}}4.4 5.5 6.6]] print(np.array(attr)) -run(testGetDenseElementsF64) +### 16 bit integer arrays +# CHECK-LABEL: TEST: testGetDenseElementsI16Signless +@run +def testGetDenseElementsI16Signless(): + with Context(): + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16) + attr = DenseElementsAttr.get(array) + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16> + print(attr) + # CHECK: {{\[}}[1 2 3] + # CHECK: {{\[}}4 5 6]] + print(np.array(attr)) + + +# CHECK-LABEL: TEST: testGetDenseElementsUI16Signless +@run +def testGetDenseElementsUI16Signless(): + with Context(): + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16) + attr = DenseElementsAttr.get(array) # signless=True by default: i16, not ui16. + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16> + print(attr) + # CHECK: {{\[}}[1 2 3] + # CHECK: {{\[}}4 5 6]]
+ print(np.array(attr)) + + +# CHECK-LABEL: TEST: testGetDenseElementsI16 +@run +def testGetDenseElementsI16(): + with Context(): + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16) + attr = DenseElementsAttr.get(array, signless=False) # Keeps signedness: si16. + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi16> + print(attr) + # CHECK: {{\[}}[1 2 3] + # CHECK: {{\[}}4 5 6]] + print(np.array(attr)) + + +# CHECK-LABEL: TEST: testGetDenseElementsUI16 +@run +def testGetDenseElementsUI16(): + with Context(): + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16) + attr = DenseElementsAttr.get(array, signless=False) # Keeps signedness: ui16. + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui16> + print(attr) + # CHECK: {{\[}}[1 2 3] + # CHECK: {{\[}}4 5 6]] + print(np.array(attr)) ### 32 bit integer arrays # CHECK-LABEL: TEST: testGetDenseElementsI32Signless +@run def testGetDenseElementsI32Signless(): with Context(): array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) @@ -139,10 +270,9 @@ # CHECK: {{\[}}4 5 6]] print(np.array(attr)) -run(testGetDenseElementsI32Signless) - # CHECK-LABEL: TEST: testGetDenseElementsUI32Signless +@run def testGetDenseElementsUI32Signless(): with Context(): array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32) @@ -153,9 +283,9 @@ # CHECK: {{\[}}4 5 6]] print(np.array(attr)) -run(testGetDenseElementsUI32Signless) # CHECK-LABEL: TEST: testGetDenseElementsI32 +@run def testGetDenseElementsI32(): with Context(): array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) @@ -166,10 +296,9 @@ # CHECK: {{\[}}4 5 6]] print(np.array(attr)) -run(testGetDenseElementsI32) - # CHECK-LABEL: TEST: testGetDenseElementsUI32 +@run def testGetDenseElementsUI32(): with Context(): array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32) @@ -180,11 +309,10 @@ # CHECK: {{\[}}4 5 6]] print(np.array(attr)) -run(testGetDenseElementsUI32) - ## 64bit integer arrays # CHECK-LABEL: TEST: testGetDenseElementsI64Signless +@run def testGetDenseElementsI64Signless(): with Context():
 array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64) @@ -195,10 +323,9 @@ # CHECK: {{\[}}4 5 6]] print(np.array(attr)) -run(testGetDenseElementsI64Signless) - # CHECK-LABEL: TEST: testGetDenseElementsUI64Signless +@run def testGetDenseElementsUI64Signless(): with Context(): array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64) @@ -209,9 +336,9 @@ # CHECK: {{\[}}4 5 6]] print(np.array(attr)) -run(testGetDenseElementsUI64Signless) # CHECK-LABEL: TEST: testGetDenseElementsI64 +@run def testGetDenseElementsI64(): with Context(): array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64) @@ -222,10 +349,9 @@ # CHECK: {{\[}}4 5 6]] print(np.array(attr)) -run(testGetDenseElementsI64) - # CHECK-LABEL: TEST: testGetDenseElementsUI64 +@run def testGetDenseElementsUI64(): with Context(): array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64) @@ -236,5 +362,3 @@ # CHECK: {{\[}}4 5 6]] print(np.array(attr)) -run(testGetDenseElementsUI64) -