diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -4949,13 +4949,19 @@ if (intPack.isSplat()) { auto splat = intPack.getSplatValue(); - // Casting int8 into int32. - if (srcElemType.isInteger(8) && dstElemType.isInteger(32)) { - uint32_t bits = static_cast(splat.getValue().getZExtValue()); - // Duplicate the 8-bit pattern. - bits = (bits << 24) | (bits << 16) | (bits << 8) | bits; - APInt intBits(32, bits); - return DenseElementsAttr::get(getResultVectorType(), intBits); + if (dstElemType.isa()) { + uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth(); + uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth(); + + // Casting to a larger integer bit width. + if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) { + APInt intBits = splat.getValue().zext(dstBitWidth); + + // Duplicate the lower width element. + for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++) + intBits = (intBits << srcBitWidth) | intBits; + return DenseElementsAttr::get(getResultVectorType(), intBits); + } } } }