diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -629,8 +629,17 @@
       }
     }
     if (bulkLoadElementType) {
-      auto shapedType = mlirRankedTensorTypeGet(
-          shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
+      MlirType shapedType;
+      if (mlirTypeIsAShaped(*bulkLoadElementType)) {
+        if (explicitShape) {
+          throw std::invalid_argument("Shape can only be specified explicitly "
+                                      "when the type is not a shaped type.");
+        }
+        shapedType = *bulkLoadElementType;
+      } else {
+        shapedType = mlirRankedTensorTypeGet(
+            shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
+      }
       size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
       MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
           shapedType, rawBufferSize, arrayInfo.ptr);
diff --git a/mlir/test/python/dialects/builtin.py b/mlir/test/python/dialects/builtin.py
--- a/mlir/test/python/dialects/builtin.py
+++ b/mlir/test/python/dialects/builtin.py
@@ -3,6 +3,7 @@
 from mlir.ir import *
 import mlir.dialects.builtin as builtin
 import mlir.dialects.func as func
+import numpy as np
 
 
 def run(f):
@@ -221,3 +222,27 @@
     # CHECK: %{{.*}}: f32 {custom_dialect.foo = "qux"},
     # CHECK: %{{.*}}: f32)
     print(module)
+
+
+# CHECK-LABEL: testDenseElementsAttr
+@run
+def testDenseElementsAttr():
+  with Context(), Location.unknown():
+    values = np.arange(4, dtype=np.int32)
+    i32 = IntegerType.get_signless(32)
+    print(DenseElementsAttr.get(values, type=i32))
+    # CHECK{LITERAL}: dense<[0, 1, 2, 3]> : tensor<4xi32>
+    print(DenseElementsAttr.get(values, type=i32, shape=(2, 2)))
+    # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
+    print(DenseElementsAttr.get(values, type=VectorType.get((2, 2), i32)))
+    # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
+
+    # An explicit shape is rejected when the given type is already shaped.
+    try:
+      DenseElementsAttr.get(
+          values, type=VectorType.get((2, 2), i32), shape=(2, 2))
+    except ValueError as e:
+      # CHECK: Shape can only be specified explicitly when the type is not a shaped type.
+      print(e)
+    else:
+      assert False, "expected ValueError"