Consider a mixed-precision data type, i.e., F16 lhs input, F16 rhs input, F32 accumulation, and F32 output. This is typically written as F32 <= F16*F16 + F32.
During vectorization from linalg to vector for this mixed-precision data type (F32 <= F16*F16 + F32), linalg.matmul introduces arith.extf on the lhs and rhs input operands, as its region body shows:
  "linalg.matmul"(%lhs, %rhs, %acc) ({
  ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
    %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
    %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
    %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
    %sum = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
    "linalg.yield"(%sum) : (f32) -> ()
  })
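As a rough reference for what this region computes, the Python sketch below emulates its scalar semantics, using the standard `struct` module's `e` (binary16) and `f` (binary32) formats to round values. The helper names `to_f16`, `to_f32`, `matmul_body`, and `matmul` are hypothetical and only illustrate the F32 <= F16*F16 + F32 arithmetic; they are not part of MLIR.

```python
import struct


def to_f16(x: float) -> float:
    """Round x to the nearest IEEE 754 binary16 value (OverflowError if out of range)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_f32(x: float) -> float:
    """Round x to the nearest IEEE 754 binary32 value (OverflowError if out of range)."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def matmul_body(a_f16: float, b_f16: float, acc_f32: float) -> float:
    """Scalar semantics of the region: acc + extf(a) * extf(b), computed in f32."""
    lhs_f32 = to_f32(a_f16)           # arith.extf f16 -> f32 (exact widening)
    rhs_f32 = to_f32(b_f16)           # arith.extf f16 -> f32 (exact widening)
    mul = to_f32(lhs_f32 * rhs_f32)   # arith.mulf in f32
    return to_f32(acc_f32 + mul)      # arith.addf in f32


def matmul(lhs, rhs, acc):
    """Reference F32 <= F16*F16 + F32 matmul on nested lists (lhs: MxK, rhs: KxN, acc: MxN)."""
    m, k, n = len(lhs), len(rhs), len(rhs[0])
    out = [list(row) for row in acc]
    for i in range(m):
        for j in range(n):
            for p in range(k):
                out[i][j] = matmul_body(lhs[i][p], rhs[p][j], out[i][j])
    return out


print(matmul([[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]], [[0.0, 0.0], [0.0, 0.0]]))
# prints [[19.0, 22.0], [43.0, 50.0]]
```

One detail worth noting: the product of two f16 values is always exactly representable in f32 (two 11-bit significands give at most a 22-bit product, and the exponent range fits), so the `to_f32` rounding of `mul` never changes the value. The multiply itself therefore loses nothing if an f16 x f16 -> f32 instruction computes it directly instead of widening first; whether the accumulation order and rounding match exactly is up to the target hardware.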
Some backends natively support mixed-precision data types and do not need the arith.extf. For example, the NVIDIA A100 GPU has mma.sync.aligned.*.f32.f16.f16.f32, which supports this mixed-precision data type directly. However, the presence of arith.extf in the IR introduces an unnecessary cast that makes the NVIDIA backend target F32 Tensor Cores instead of F16 Tensor Cores. This patch adds a folding pattern that folds arith.extf into vector.contract.
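The shape of such a fold can be illustrated with a toy Python model that tracks only element types. This is not MLIR's C++ API: `Value`, `ExtFOp`, `extf`, `ContractOp`, and `fold_extf_into_contract` are hypothetical stand-ins. The sketch assumes the fold fires only when both the lhs and rhs of the contraction come from arith.extf of the same narrow type and the accumulator is already in the widened type, so the folded contraction becomes f16 x f16 -> f32.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(eq=False)
class Value:
    """Toy SSA value: its element type and the extf op that produced it, if any."""
    elem_type: str
    producer: Optional["ExtFOp"] = None


@dataclass(eq=False)
class ExtFOp:
    """Toy arith.extf: widens `source` to `to_type`."""
    source: Value
    to_type: str


def extf(source: Value, to_type: str) -> Value:
    """Build a toy arith.extf and return its result value."""
    return Value(to_type, producer=ExtFOp(source, to_type))


@dataclass(eq=False)
class ContractOp:
    """Toy vector.contract: result = lhs * rhs + acc (only operands are modeled)."""
    lhs: Value
    rhs: Value
    acc: Value


def fold_extf_into_contract(op: ContractOp) -> bool:
    """Rewrite contract(extf(a), extf(b), acc) into contract(a, b, acc).

    Returns True if the op's operands were replaced in place.
    """
    lhs_ext, rhs_ext = op.lhs.producer, op.rhs.producer
    if lhs_ext is None or rhs_ext is None:
        return False  # both inputs must come from an extf
    if lhs_ext.source.elem_type != rhs_ext.source.elem_type:
        return False  # keep lhs and rhs in the same narrow type
    if op.acc.elem_type != lhs_ext.to_type:
        return False  # accumulator must already be in the widened type
    op.lhs, op.rhs = lhs_ext.source, rhs_ext.source
    return True


a, b, acc = Value("f16"), Value("f16"), Value("f32")
contract = ContractOp(extf(a, "f32"), extf(b, "f32"), acc)
print(fold_extf_into_contract(contract), contract.lhs.elem_type)  # prints True f16
```

The rewrite only swaps the contraction's operands; any extf left without other uses becomes dead and would normally be removed by later cleanup such as dead-code elimination.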