diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4945,6 +4945,26 @@
     }
   }
 
+  if (auto intPack = sourceConstant.dyn_cast<DenseIntElementsAttr>()) {
+    if (intPack.isSplat()) {
+      auto splat = intPack.getSplatValue<IntegerAttr>();
+
+      // Widening bitcast between integer types, e.g. vector<16xi8> to
+      // vector<4xi32>. Every source element holds the same bit pattern, so
+      // each destination element is that pattern repeated
+      // dstBitWidth / srcBitWidth times, independent of endianness.
+      // getIntOrFloatBitWidth() is not valid for `index`; require IntegerType.
+      if (srcElemType.isa<IntegerType>() && dstElemType.isa<IntegerType>()) {
+        unsigned srcBitWidth = srcElemType.getIntOrFloatBitWidth();
+        unsigned dstBitWidth = dstElemType.getIntOrFloatBitWidth();
+        if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
+          APInt intBits = APInt::getSplat(dstBitWidth, splat.getValue());
+          return DenseElementsAttr::get(getResultVectorType(), intBits);
+        }
+      }
+    }
+  }
+
   return {};
 }
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -741,6 +741,30 @@
   return %cast0, %cast1: vector<4xf32>, vector<4xf32>
 }
 
+// CHECK-LABEL: func @bitcast_i8_to_i32
+//              bit pattern: 0xA0A0A0A0
+//       CHECK-DAG: %[[CST1:.+]] = arith.constant dense<-1600085856> : vector<4xi32>
+//              bit pattern: 0x00000000
+//       CHECK-DAG: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi32>
+//       CHECK: return %[[CST0]], %[[CST1]]
+func.func @bitcast_i8_to_i32() -> (vector<4xi32>, vector<4xi32>) {
+  %cst0 = arith.constant dense<0> : vector<16xi8>  // bit pattern: 0x00
+  %cst1 = arith.constant dense<160> : vector<16xi8>  // bit pattern: 0xA0
+  %cast0 = vector.bitcast %cst0: vector<16xi8> to vector<4xi32>
+  %cast1 = vector.bitcast %cst1: vector<16xi8> to vector<4xi32>
+  return %cast0, %cast1: vector<4xi32>, vector<4xi32>
+}
+
+// CHECK-LABEL: func @bitcast_i16_to_i64
+//              bit pattern: 0x0001000100010001
+//       CHECK: %[[CST:.+]] = arith.constant dense<281479271743489> : vector<1xi64>
+//       CHECK: return %[[CST]]
+func.func @bitcast_i16_to_i64() -> vector<1xi64> {
+  %cst = arith.constant dense<1> : vector<4xi16>  // bit pattern: 0x0001
+  %cast = vector.bitcast %cst : vector<4xi16> to vector<1xi64>
+  return %cast : vector<1xi64>
+}
+
 // -----
 
 // CHECK-LABEL: broadcast_folding1