diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -230,6 +230,13 @@
   if (elementType.isIntOrFloat())
     return elementType.getIntOrFloatBitWidth() * getNumElements();
 
+  // A complex value holds two components of its int or float element type.
+  if (auto complexType = elementType.dyn_cast<ComplexType>()) {
+    elementType = complexType.getElementType();
+    assert(elementType.isIntOrFloat() && "unsupported complex element type");
+    return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
+  }
+
   // Tensors can have vectors and other tensors as elements, other shaped types
   // cannot.
   assert(isa<TensorType>() && "unsupported element type");
diff --git a/mlir/test/mlir-tblgen/op-derived-attribute.mlir b/mlir/test/mlir-tblgen/op-derived-attribute.mlir
--- a/mlir/test/mlir-tblgen/op-derived-attribute.mlir
+++ b/mlir/test/mlir-tblgen/op-derived-attribute.mlir
@@ -5,9 +5,14 @@
   // expected-remark @+2 {{element_dtype = f32}}
   // expected-remark @+1 {{size = 320}}
   %0 = "test.derived_type_attr"() : () -> tensor<10xf32>
+
   // expected-remark @+2 {{element_dtype = i79}}
   // expected-remark @+1 {{size = 948}}
   %1 = "test.derived_type_attr"() : () -> tensor<12xi79>
 
+  // expected-remark @+2 {{element_dtype = complex<f32>}}
+  // expected-remark @+1 {{size = 768}}
+  %2 = "test.derived_type_attr"() : () -> tensor<12xcomplex<f32>>
+
   return
 }