diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -333,7 +333,8 @@ auto *context = elementType.getContext(); // Check that memref is formed from allowed types. - if (!elementType.isIntOrFloat() && !elementType.isa()) + if (!elementType.isIntOrFloat() && !elementType.isa() && + !elementType.isa()) return emitOptionalError(location, "invalid memref element type"), MemRefType(); @@ -411,7 +412,8 @@ Optional loc, MLIRContext *context, Type elementType, unsigned memorySpace) { // Check that memref is formed from allowed types. - if (!elementType.isIntOrFloat() && !elementType.isa()) + if (!elementType.isIntOrFloat() && !elementType.isa() && + !elementType.isa()) return emitOptionalError(*loc, "invalid memref element type"); return success(); } diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -133,6 +133,16 @@ // CHECK: func @complex_types(complex) -> complex func @complex_types(complex) -> complex + +// CHECK: func @memref_with_complex_elems(memref<1x?xcomplex>) +func @memref_with_complex_elems(memref<1x?xcomplex>) + +// CHECK: func @memref_with_vector_elems(memref<1x?xvector<10xf32>>) +func @memref_with_vector_elems(memref<1x?xvector<10xf32>>) + +// CHECK: func @unranked_memref_with_complex_elems(memref<*xcomplex>) +func @unranked_memref_with_complex_elems(memref<*xcomplex>) + // CHECK: func @functions((memref<1x?x4x?x?xi32, #map0>, memref<8xi8>) -> (), () -> ()) func @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<8xi8, #map1, 0>) -> (), ()->())