diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2462,6 +2462,10 @@
     VectorType castSrcType = bitcastOp.getSourceVectorType();
     VectorType castDstType = bitcastOp.getResultVectorType();
     assert(castSrcType.getRank() == castDstType.getRank());
+    // Skip 0-D vectors: they cannot come from an InsertStridedSliceOp, and
+    // getShape().back() below would be invalid on their empty shape.
+    if (castSrcType.getRank() == 0)
+      return failure();
 
     int64_t castSrcLastDim = castSrcType.getShape().back();
     int64_t castDstLastDim = castDstType.getShape().back();
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -525,3 +525,11 @@
   %cast = vector.bitcast %0: vector<8xf16> to vector<4xf32>
   return %cast: vector<4xf32>
 }
+
+// Make sure the patterns do not crash on 0-D vectors.
+// CHECK-LABEL: func.func @vec_0D
+// CHECK-NEXT: vector.bitcast
+func.func @vec_0D(%arg0: vector<f32>) -> vector<i32> {
+  %0 = vector.bitcast %arg0 : vector<f32> to vector<i32>
+  return %0 : vector<i32>
+}