diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -334,7 +334,7 @@
 
   // Check that memref is formed from allowed types.
   if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
-      !elementType.isa<ComplexType>())
+      !elementType.isa<ComplexType>() && !elementType.isa<MemRefType>())
     return emitOptionalError(location, "invalid memref element type"),
            MemRefType();
 
@@ -413,7 +413,7 @@
                                     unsigned memorySpace) {
   // Check that memref is formed from allowed types.
   if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
-      !elementType.isa<ComplexType>())
+      !elementType.isa<ComplexType>() && !elementType.isa<MemRefType>())
     return emitOptionalError(*loc, "invalid memref element type");
   return success();
 }
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -143,6 +143,15 @@
 // CHECK: func @unranked_memref_with_complex_elems(memref<*xcomplex>)
 func @unranked_memref_with_complex_elems(memref<*xcomplex>)
 
+// CHECK: func @memref_with_memref_elems(memref<1x?xmemref>)
+func @memref_with_memref_elems(memref<1x?xmemref>)
+
+// CHECK: func @unranked_memref_with_memref_elems(memref<*xmemref>)
+func @unranked_memref_with_memref_elems(memref<*xmemref>)
+
+// CHECK: func @functions_with_memref((memref<1x?x4x?x?xmemref, #map0>, memref<8xi8>) -> (), () -> ())
+func @functions_with_memref((memref<1x?x4x?x?xmemref, #map0, 0>, memref<8xi8, #map1, 0>) -> (), ()->())
+
 // CHECK: func @functions((memref<1x?x4x?x?xi32, #map0>, memref<8xi8>) -> (), () -> ())
 func @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<8xi8, #map1, 0>) -> (), ()->())
 