diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2503,6 +2503,20 @@
     if (rank != insertOp.getDestVectorType().getRank())
       return failure();
 
+    // vector.bitcast reinterprets only the innermost dimension, so the insert
+    // source's last dimension (not merely its total element count) must hold
+    // a whole number of destination elements; e.g. vector<2x3xf16> has 6
+    // elements, but its last dim of 3 cannot be bitcast to f32. A zero ratio
+    // (destination element narrower than the source element) is rejected
+    // too, which also keeps the modulo below from dividing by zero.
+    unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
+    unsigned destinationWidth =
+        castDstType.getElementType().getIntOrFloatBitWidth();
+    int64_t srcElemsPerDstElem = destinationWidth / sourceWidth;
+    int64_t insertSrcLastDim = insertOp.getSourceVectorType().getShape().back();
+    if (srcElemsPerDstElem == 0 || insertSrcLastDim % srcElemsPerDstElem != 0)
+      return failure();
+
     ArrayAttr newOffsets = insertOp.getOffsets();
     assert(newOffsets.size() == rank);
     SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -507,3 +507,32 @@
   %cast = vector.bitcast %0: vector<16x4x8xf16> to vector<16x4x4xf32>
   return %cast: vector<16x4x4xf32>
 }
+
+// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_odd_shape
+func.func @bubble_up_bitcast_in_strided_slice_insert_odd_shape(%dst: vector<2xf16>, %src: vector<1xf16>) -> vector<1xf32> {
+  // CHECK: vector.insert_strided_slice
+  // CHECK-NEXT: vector.bitcast
+  %0 = vector.insert_strided_slice %src, %dst {offsets = [0], strides = [1]} : vector<1xf16> into vector<2xf16>
+  %cast = vector.bitcast %0: vector<2xf16> to vector<1xf32>
+  return %cast: vector<1xf32>
+}
+
+// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_larger_odd_shape
+func.func @bubble_up_bitcast_in_strided_slice_insert_larger_odd_shape(%dst: vector<8xf16>, %src: vector<3xf16>) -> vector<4xf32> {
+  // CHECK: vector.insert_strided_slice
+  // CHECK-NEXT: vector.bitcast
+  %0 = vector.insert_strided_slice %src, %dst {offsets = [0], strides = [1]} : vector<3xf16> into vector<8xf16>
+  %cast = vector.bitcast %0: vector<8xf16> to vector<4xf32>
+  return %cast: vector<4xf32>
+}
+
+// The total element count (6) is even, but the innermost dimension (3) is
+// odd, so the bitcast must not be bubbled up through the insert.
+// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_odd_last_dim
+func.func @bubble_up_bitcast_in_strided_slice_insert_odd_last_dim(%dst: vector<2x4xf16>, %src: vector<2x3xf16>) -> vector<2x2xf32> {
+  // CHECK: vector.insert_strided_slice
+  // CHECK-NEXT: vector.bitcast
+  %0 = vector.insert_strided_slice %src, %dst {offsets = [0, 0], strides = [1, 1]} : vector<2x3xf16> into vector<2x4xf16>
+  %cast = vector.bitcast %0: vector<2x4xf16> to vector<2x2xf32>
+  return %cast: vector<2x2xf32>
+}