diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -895,7 +895,7 @@ vector-type ::= `vector` `<` static-dimension-list vector-element-type `>` vector-element-type ::= float-type | integer-type | index-type - static-dimension-list ::= (decimal-literal `x`)+ + static-dimension-list ::= (decimal-literal `x`)* ``` The vector type represents a SIMD style vector, used by target-specific @@ -903,13 +903,13 @@ vector<16 x f32>) we also support multidimensional registers on targets that support them (like TPUs). - Vector shapes must be positive decimal integers. + Vector shapes must be positive decimal integers. 0D vectors are allowed by + omitting the dimension: `vector<f32>`. Note: hexadecimal integer literals are not allowed in vector type declarations, `vector<0x42xi32>` is invalid because it is interpreted as a 2D vector with shape `(0, 42)` and zero shapes are not allowed. - Examples: ```mlir diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -441,9 +441,6 @@ LogicalResult VectorType::verify(function_ref emitError, ArrayRef shape, Type elementType) { - if (shape.empty()) - return emitError() << "vector types must have at least one dimension"; - if (!isValidElementType(elementType)) return emitError() << "vector elements must be int/index/float type but got " diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp --- a/mlir/lib/Parser/TypeParser.cpp +++ b/mlir/lib/Parser/TypeParser.cpp @@ -442,9 +442,7 @@ /// Parse a vector type. 
/// -/// vector-type ::= `vector` `<` non-empty-static-dimension-list type `>` -/// non-empty-static-dimension-list ::= decimal-literal `x` -/// static-dimension-list +/// vector-type ::= `vector` `<` static-dimension-list type `>` /// static-dimension-list ::= (decimal-literal `x`)* /// VectorType Parser::parseVectorType() { @@ -456,8 +454,6 @@ SmallVector dimensions; if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false)) return nullptr; - if (dimensions.empty()) - return (emitError("expected dimension size in vector type"), nullptr); if (any_of(dimensions, [](int64_t i) { return i <= 0; })) return emitError(getToken().getLoc(), "vector types must have positive constant sizes"), diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -949,7 +949,7 @@ // ----- -// expected-error @+1 {{expected dimension size in vector type}} +// expected-error @+1 {{expected non-function type}} func @negative_vector_size() -> vector<-1xi32> // ----- diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -67,8 +67,8 @@ // CHECK: func private @float_types(f80, f128) func private @float_types(f80, f128) -// CHECK: func private @vectors(vector<1xf32>, vector<2x4xf32>) -func private @vectors(vector<1 x f32>, vector<2x4xf32>) +// CHECK: func private @vectors(vector, vector<1xf32>, vector<2x4xf32>) +func private @vectors(vector, vector<1 x f32>, vector<2x4xf32>) // CHECK: func private @tensors(tensor<*xf32>, tensor<*xvector<2x4xf32>>, tensor<1x?x4x?x?xi32>, tensor) func private @tensors(tensor<* x f32>, tensor<* x vector<2x4xf32>>,