diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1528,6 +1528,7 @@
     MlirContext context = contextWrapper->get();
     // Switch on the types that can be bulk loaded between the Python and
     // MLIR-C APIs.
+    // See: https://docs.python.org/3/library/struct.html#format-characters
     if (arrayInfo.format == "f") {
       // f32
       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
@@ -1542,42 +1543,44 @@
           contextWrapper->getRef(),
           bulkLoad(context, mlirDenseElementsAttrDoubleGet,
                    mlirF64TypeGet(context), arrayInfo));
-    } else if (arrayInfo.format == "i") {
-      // i32
-      assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
-      MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
-                                      : mlirIntegerTypeSignedGet(context, 32);
-      return PyDenseElementsAttribute(contextWrapper->getRef(),
-                                      bulkLoad(context,
-                                               mlirDenseElementsAttrInt32Get,
-                                               elementType, arrayInfo));
-    } else if (arrayInfo.format == "I") {
-      // unsigned i32
-      assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
-      MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
-                                      : mlirIntegerTypeUnsignedGet(context, 32);
-      return PyDenseElementsAttribute(contextWrapper->getRef(),
-                                      bulkLoad(context,
-                                               mlirDenseElementsAttrUInt32Get,
-                                               elementType, arrayInfo));
-    } else if (arrayInfo.format == "l") {
-      // i64
-      assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
-      MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
-                                      : mlirIntegerTypeSignedGet(context, 64);
-      return PyDenseElementsAttribute(contextWrapper->getRef(),
-                                      bulkLoad(context,
-                                               mlirDenseElementsAttrInt64Get,
-                                               elementType, arrayInfo));
-    } else if (arrayInfo.format == "L") {
-      // unsigned i64
-      assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
-      MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
-                                      : mlirIntegerTypeUnsignedGet(context, 64);
-      return PyDenseElementsAttribute(contextWrapper->getRef(),
-                                      bulkLoad(context,
-                                               mlirDenseElementsAttrUInt64Get,
-                                               elementType, arrayInfo));
+    } else if (isSignedIntegerFormat(arrayInfo.format)) {
+      if (arrayInfo.itemsize == 4) {
+        // i32, e.g. 'i', or 'l' where C long is 32 bits.
+        MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
+                                        : mlirIntegerTypeSignedGet(context, 32);
+        return PyDenseElementsAttribute(contextWrapper->getRef(),
+                                        bulkLoad(context,
+                                                 mlirDenseElementsAttrInt32Get,
+                                                 elementType, arrayInfo));
+      } else if (arrayInfo.itemsize == 8) {
+        // i64, e.g. 'q', or 'l' where C long is 64 bits.
+        MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
+                                        : mlirIntegerTypeSignedGet(context, 64);
+        return PyDenseElementsAttribute(contextWrapper->getRef(),
+                                        bulkLoad(context,
+                                                 mlirDenseElementsAttrInt64Get,
+                                                 elementType, arrayInfo));
+      } // Other itemsizes (e.g. 1-byte 'b', 2-byte 'h') are not bulk-loaded.
+    } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
+      if (arrayInfo.itemsize == 4) {
+        // unsigned i32, e.g. 'I', or 'L' where C long is 32 bits.
+        MlirType elementType = signless
+                                   ? mlirIntegerTypeGet(context, 32)
+                                   : mlirIntegerTypeUnsignedGet(context, 32);
+        return PyDenseElementsAttribute(contextWrapper->getRef(),
+                                        bulkLoad(context,
+                                                 mlirDenseElementsAttrUInt32Get,
+                                                 elementType, arrayInfo));
+      } else if (arrayInfo.itemsize == 8) {
+        // unsigned i64, e.g. 'Q', or 'L' where C long is 64 bits.
+        MlirType elementType = signless
+                                   ? mlirIntegerTypeGet(context, 64)
+                                   : mlirIntegerTypeUnsignedGet(context, 64);
+        return PyDenseElementsAttribute(contextWrapper->getRef(),
+                                        bulkLoad(context,
+                                                 mlirDenseElementsAttrUInt64Get,
+                                                 elementType, arrayInfo));
+      } // Other itemsizes (e.g. 1-byte 'B', 2-byte 'H') are not bulk-loaded.
     }
 
     // TODO: Fall back to string-based get.
@@ -1650,7 +1653,31 @@
     const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
     return ctor(shapedType, numElements, contents);
   }
+
+  /// Returns true if `format` (a Python struct / buffer-protocol format
+  /// string) starts with an unsigned integer code: 'B', 'H', 'I', 'L' or 'Q'.
+  /// Only the first character is examined, so a format that carries a
+  /// byte-order prefix ('@', '=', '<', '>', '!') is not recognized.
+  static bool isUnsignedIntegerFormat(const std::string &format) {
+    if (format.empty())
+      return false;
+    char code = format[0];
+    return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
+           code == 'Q';
+  }
+
+  /// Returns true if `format` (a Python struct / buffer-protocol format
+  /// string) starts with a signed integer code: 'b', 'h', 'i', 'l' or 'q'.
+  /// Only the first character is examined, so a format that carries a
+  /// byte-order prefix ('@', '=', '<', '>', '!') is not recognized.
+  static bool isSignedIntegerFormat(const std::string &format) {
+    if (format.empty())
+      return false;
+    char code = format[0];
+    return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
+           code == 'q';
+  }
 };
 
 /// Refinement of the PyDenseElementsAttribute for attributes containing integer
 /// (and boolean) values. Supports element access.