diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h --- a/mlir/include/mlir/IR/StandardTypes.h +++ b/mlir/include/mlir/IR/StandardTypes.h @@ -456,9 +456,7 @@ using ShapedType::ShapedType; /// Return true if the specified element type is ok in a memref. - static bool isValidElementType(Type type) { - return type.isIntOrIndexOrFloat() || type.isa(); - } + static bool isValidElementType(Type type); /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(Type type); diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -363,6 +363,16 @@ // BaseMemRefType //===----------------------------------------------------------------------===// +/// Return true if the specified element type is ok in a memref. +bool BaseMemRefType::isValidElementType(Type type) { + // Note: Non standard/builtin types are allowed to exist within memref + // types. Dialects are expected to verify that memref types have a valid + // element type within that dialect. + return type.isa() || + !type.getDialect().getNamespace().empty(); +} + unsigned BaseMemRefType::getMemorySpace() const { return static_cast(impl)->memorySpace; } diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp --- a/mlir/lib/Parser/TypeParser.cpp +++ b/mlir/lib/Parser/TypeParser.cpp @@ -217,8 +217,7 @@ return nullptr; // Check that memref is formed from allowed types. - if (!elementType.isIntOrIndexOrFloat() && - !elementType.isa()) + if (!BaseMemRefType::isValidElementType(elementType)) return emitError(typeLoc, "invalid memref element type"), nullptr; // Parse semi-affine-map-composition. diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -78,6 +78,9 @@ // CHECK: func @memrefs(memref<1x?x4x?x?xi32, #map{{[0-9]+}}>, memref<8xi8>) func @memrefs(memref<1x?x4x?x?xi32, #map0>, memref<8xi8, #map1, #map1>) +// CHECK: func @memref_quant_element(memref<3x!quant.uniform>) +func @memref_quant_element(memref<3x!quant.uniform>) + // Test memref affine map compositions. // CHECK: func @memrefs2(memref<2x4x8xi8, 1>)