diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4950,6 +4950,29 @@
     }
   }
 
+  if (auto intPack = sourceConstant.dyn_cast<DenseIntElementsAttr>()) {
+    if (intPack.isSplat()) {
+      auto splat = intPack.getSplatValue<IntegerAttr>();
+
+      // DenseIntElementsAttr also covers `index` elements, for which
+      // getIntOrFloatBitWidth() asserts, so require real integer types.
+      if (srcElemType.isa<IntegerType>() && dstElemType.isa<IntegerType>()) {
+        uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
+        uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
+
+        // Casting to a larger integer bit width.
+        if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
+          APInt intBits = splat.getValue().zext(dstBitWidth);
+
+          // Duplicate the lower width element.
+          for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
+            intBits = (intBits << srcBitWidth) | intBits;
+          return DenseElementsAttr::get(getResultVectorType(), intBits);
+        }
+      }
+    }
+  }
+
   return {};
 }
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -741,6 +741,20 @@
   return %cast0, %cast1: vector<4xf32>, vector<4xf32>
 }
 
+// CHECK-LABEL: func @bitcast_i8_to_i32
+// bit pattern: 0xA0A0A0A0
+// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<-1600085856> : vector<4xi32>
+// bit pattern: 0x00000000
+// CHECK-DAG: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi32>
+// CHECK: return %[[CST0]], %[[CST1]]
+func.func @bitcast_i8_to_i32() -> (vector<4xi32>, vector<4xi32>) {
+  %cst0 = arith.constant dense<0> : vector<16xi8> // bit pattern: 0x00
+  %cst1 = arith.constant dense<160> : vector<16xi8> // bit pattern: 0xA0
+  %cast0 = vector.bitcast %cst0: vector<16xi8> to vector<4xi32>
+  %cast1 = vector.bitcast %cst1: vector<16xi8> to vector<4xi32>
+  return %cast0, %cast1: vector<4xi32>, vector<4xi32>
+}
+
 // -----
 
 // CHECK-LABEL: broadcast_folding1