NOTE(review): this patch was recovered from a copy that had been collapsed onto
one line and had every `<identifier...>` span stripped as if it were an HTML
tag (e.g. `isa<...>()` became `isa()`). Line structure and the stripped
template/type parameter lists below are reconstructed; the removed (`-`) lines
and the quant type parameters in particular must be checked against the target
tree before applying. Text above the first `diff --git` line is ignored by
`git apply` and `patch`.

diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_IR_STANDARDTYPES_H
 #define MLIR_IR_STANDARDTYPES_H
 
+#include "mlir/IR/Dialect.h"
 #include "mlir/IR/Types.h"
 
 namespace llvm {
@@ -457,7 +458,12 @@
 
   /// Return true if the specified element type is ok in a memref.
   static bool isValidElementType(Type type) {
-    return type.isIntOrIndexOrFloat() || type.isa<VectorType, ComplexType>();
+    // Note: Non standard/builtin types are allowed to exist within memref
+    // types. Dialects are expected to verify that memref types have a valid
+    // element type within that dialect.
+    return type.isa<ComplexType, FloatType, IntegerType, OpaqueType,
+                    VectorType, IndexType>() ||
+           !type.getDialect().getNamespace().empty();
   }
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp
--- a/mlir/lib/Parser/TypeParser.cpp
+++ b/mlir/lib/Parser/TypeParser.cpp
@@ -217,8 +217,7 @@
     return nullptr;
 
   // Check that memref is formed from allowed types.
-  if (!elementType.isIntOrIndexOrFloat() &&
-      !elementType.isa<VectorType, ComplexType>())
+  if (!BaseMemRefType::isValidElementType(elementType))
     return emitError(typeLoc, "invalid memref element type"), nullptr;
 
   // Parse semi-affine-map-composition.
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -78,6 +78,9 @@
 // CHECK: func @memrefs(memref<1x?x4x?x?xi32, #map{{[0-9]+}}>, memref<8xi8>)
 func @memrefs(memref<1x?x4x?x?xi32, #map0>, memref<8xi8, #map1, #map1>)
 
+// CHECK: func @memref_quant_element(memref<3x!quant.uniform<i8:f32, 1.000000e+00>>)
+func @memref_quant_element(memref<3x!quant.uniform<i8:f32, 1.0>>)
+
 // Test memref affine map compositions.
 
 // CHECK: func @memrefs2(memref<2x4x8xi8, 1>)