diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2181,11 +2181,11 @@
     Results<(outs AnyMemRef:$result)> {
   let summary = "type_cast op converts a scalar memref to a vector memref";
   let description = [{
-    Performs a conversion from a memref with scalar element to a memref with a
-    *single* vector element, copying the shape of the memref to the vector. This
-    is the minimal viable operation that is required to makeke
-    super-vectorization operational. It can be seen as a special case of the
-    `view` operation but scoped in the super-vectorization context.
+    Performs a conversion from a memref with alignable scalar element or vector
+    element to a memref with a *single* vector element, copying the shape of the
+    memref to the vector. This is the minimal viable operation that is required
+    to make super-vectorization operational. It can be seen as a special case
+    of the `view` operation but scoped in the super-vectorization context.

    Syntax:

@@ -2196,9 +2196,18 @@
    Example:

    ```mlir
-    %A  = memref.alloc() : memref<5x4x3xf32>
-    %VA = vector.type_cast %A : memref<5x4x3xf32> to memref<vector<5x4x3xf32>>
+    %A  = memref.alloc() : memref<5x4x4xf32>
+    %VA = vector.type_cast %A : memref<5x4x4xf32> to memref<vector<5x4x4xf32>>
    ```
+
+    ```mlir
+    %B  = memref.alloc() : memref<5x4xvector<3xf32>>
+    %VB = vector.type_cast %B : memref<5x4xvector<3xf32>> to memref<vector<5x4x3xf32>>
+    ```
+
+    Because of a current restriction in the lowering of multi-dimensional
+    vectors, the innermost dim of a memref with scalar element type can only
+    be a power of 2, to meet the alignment requirement of the casted vector.
  }];

  /// Build the canonical memRefType with a single vector.
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4995,6 +4995,17 @@
     return emitOpError(
                "expects concatenated result and operand shapes to be equal: ")
            << resultType;
+
+  if (!getElementTypeOrSelf(sourceType).isa<VectorType>()) {
+    if (auto shape = extractShape(sourceType); !shape.empty()) {
+      auto innerMostDim = shape.back();
+      if (!llvm::isPowerOf2_64(innerMostDim)) {
+        return emitOpError("can only cast to aligned vectors, inner-most dim "
+                           "for scalar type should be power of 2");
+      }
+    }
+  }
+
   return success();
 }

diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1223,6 +1223,13 @@

 // -----

+func.func @type_cast_align(%arg0: memref<5x4x3xf32>) {
+  // expected-error@+1 {{can only cast to aligned vectors}}
+  %0 = vector.type_cast %arg0 : memref<5x4x3xf32> to memref<vector<5x4x3xf32>>
+}
+
+// -----
+
 func.func @store_unsupported_layout(%memref : memref<200x100xf32, affine_map<(d0, d1) -> (200*d0 + 2*d1)>>,
                                     %i : index, %j : index, %value : vector<8xf32>) {
   // expected-error@+1 {{'vector.store' op most minor memref dim must have unit stride}}