Reviewer note: this patch was recovered from a copy in which line breaks were
collapsed and `<...>` spans were dropped. The template arguments
(DenseFPElementsAttr, FloatAttr, uint32_t), the <numeric> include in the
context, and all indentation are reconstructed -- TODO confirm against the
upstream tree before applying.

diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -25,6 +25,7 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/MathExtras.h"
 #include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/bit.h"
 #include <numeric>
 
 using namespace mlir;
@@ -2804,6 +2805,32 @@
     if (result().getType() == otherOp.source().getType())
       return otherOp.source();
 
+  // Fold bitcasts of constant splat sources; only f16 -> f32 is handled.
+  Attribute sourceConstant = operands.front();
+  if (!sourceConstant)
+    return {};
+
+  Type srcElemType = getSourceVectorType().getElementType();
+  Type dstElemType = getResultVectorType().getElementType();
+
+  if (auto floatPack = sourceConstant.dyn_cast<DenseFPElementsAttr>()) {
+    if (floatPack.isSplat()) {
+      auto splat = floatPack.getSplatValue<FloatAttr>();
+
+      // Casting fp16 into fp32.
+      if (srcElemType.isF16() && dstElemType.isF32()) {
+        uint32_t bits = static_cast<uint32_t>(
+            splat.getValue().bitcastToAPInt().getZExtValue());
+        // Duplicate the 16-bit pattern: the source is a splat, so both
+        // 16-bit halves of every 32-bit result element hold the same bits.
+        bits = (bits << 16) | (bits & 0xffff);
+        APInt intBits(32, bits);
+        APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
+        return DenseElementsAttr::get(getResultVectorType(), floatBits);
+      }
+    }
+  }
+
   return {};
 }
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -556,6 +556,20 @@
   return %0, %2 : vector<4x8xf32>, vector<2xi32>
 }
 
+// CHECK-LABEL: func @bitcast_f16_to_f32
+//              bit pattern: 0x00000000
+//       CHECK: %[[CST0:.+]] = constant dense<0.000000e+00> : vector<4xf32>
+//              bit pattern: 0x40004000
+//       CHECK: %[[CST1:.+]] = constant dense<2.00390625> : vector<4xf32>
+//       CHECK: return %[[CST0]], %[[CST1]]
+func @bitcast_f16_to_f32() -> (vector<4xf32>, vector<4xf32>) {
+  %cst0 = constant dense<0.0> : vector<8xf16> // bit pattern: 0x0000
+  %cst1 = constant dense<2.0> : vector<8xf16> // bit pattern: 0x4000
+  %cast0 = vector.bitcast %cst0: vector<8xf16> to vector<4xf32>
+  %cast1 = vector.bitcast %cst1: vector<8xf16> to vector<4xf32>
+  return %cast0, %cast1: vector<4xf32>, vector<4xf32>
+}
+
 // -----
 
 // CHECK-LABEL: broadcast_folding1