diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -426,6 +426,12 @@
 public:
   using ShapedType::ShapedType;
 
+  /// Return true if the specified element type is ok in a memref, i.e. an
+  /// integer, index, float, vector, or complex type.
+  static bool isValidElementType(Type type) {
+    return type.isIntOrIndexOrFloat() || type.isa<VectorType, ComplexType>();
+  }
+
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(Type type);
 };
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -408,9 +408,7 @@
                                Optional<Location> location) {
   auto *context = elementType.getContext();
 
-  // Check that memref is formed from allowed types.
-  if (!elementType.isIntOrIndexOrFloat() &&
-      !elementType.isa<VectorType, ComplexType>())
+  if (!BaseMemRefType::isValidElementType(elementType))
     return emitOptionalError(location, "invalid memref element type"),
            MemRefType();
 
@@ -486,9 +484,7 @@
 LogicalResult
 UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
                                                  unsigned memorySpace) {
-  // Check that memref is formed from allowed types.
-  if (!elementType.isIntOrFloat() &&
-      !elementType.isa<ComplexType>())
+  if (!BaseMemRefType::isValidElementType(elementType))
     return emitError(loc, "invalid memref element type");
   return success();
 }
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -17,6 +17,14 @@
 
 // -----
 
+func @illegalmemrefelementtype(memref<?xtensor<i8>>) -> () // expected-error {{invalid memref element type}}
+
+// -----
+
+func @illegalunrankedmemrefelementtype(memref<*xtensor<i8>>) -> () // expected-error {{invalid memref element type}}
+
+// -----
+
 func @indexvector(vector<4 x index>) -> () // expected-error {{vector elements must be int or float type}}
 
 // -----
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -152,6 +152,12 @@
 // CHECK: func @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>)
 func @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>)
 
+// CHECK: func @unranked_memref_with_index_elems(memref<*xindex>)
+func @unranked_memref_with_index_elems(memref<*xindex>)
+
+// CHECK: func @unranked_memref_with_vector_elems(memref<*xvector<10xf32>>)
+func @unranked_memref_with_vector_elems(memref<*xvector<10xf32>>)
+
 // CHECK: func @functions((memref<1x?x4x?x?xi32, #map0>, memref<8xi8>) -> (), () -> ())
 func @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<8xi8, #map1, 0>) -> (), ()->())
 