diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -2503,6 +2503,13 @@ if (rank != insertOp.getDestVectorType().getRank()) return failure(); + // Requires that shape of insert op src is castable to dstType. + unsigned srcWidth = castSrcType.getElementType().getIntOrFloatBitWidth(); + unsigned dstWidth = castDstType.getElementType().getIntOrFloatBitWidth(); + unsigned unitCastNumel = dstWidth / srcWidth; + if (insertOp.getSourceVectorType().getNumElements() % unitCastNumel != 0) + return failure(); + ArrayAttr newOffsets = insertOp.getOffsets(); assert(newOffsets.size() == rank); SmallVector offsets = getIntValueVector(newOffsets);