This is an archive of the discontinued LLVM Phabricator instance.

[mlir][Vector] Adds a pattern to fold `arith.extf` into `vector.contract`
ClosedPublic

Authored by manishucsd on Jun 1 2023, 11:33 AM.

Details

Summary

Consider a mixed-precision data type, i.e., F16 input lhs, F16 input rhs, F32 accumulation, and F32 output. This is typically written as F32 <= F16*F16 + F32.

During vectorization from linalg to vector for this mixed-precision case (F32 <= F16*F16 + F32), linalg.matmul introduces an arith.extf on the lhs and rhs input operands.

"linalg.matmul"(%lhs, %rhs, %acc) ({
      ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
        %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
        %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
       %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
        %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
      "linalg.yield"(%acc) : (f32) -> ()
    })

There are backends that natively support mixed-precision data types and do not need the arith.extf. For example, the NVIDIA A100 GPU has mma.sync.aligned.*.f32.f16.f16.f32, which supports mixed precision directly. However, the presence of the arith.extf in the IR introduces unnecessary casts and causes the NVIDIA backend to target F32 Tensor Cores instead of F16 Tensor Cores. This patch adds a folding pattern to fold arith.extf into vector.contract.
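For reference, after vectorization plus this folding, the extensions are absorbed into a single mixed-precision contraction. The sketch below is illustrative (shapes, maps, and names are chosen for the example, not taken from the patch itself):

```mlir
// Sketch: the f16 -> f32 arith.extf ops on the lhs/rhs have been folded
// away; vector.contract now takes f16 inputs with an f32 accumulator.
func.func @mixed_precision_contract(%lhs: vector<64x64xf16>,
                                    %rhs: vector<64x64xf16>,
                                    %acc: vector<64x64xf32>) -> vector<64x64xf32> {
  %0 = vector.contract {
         indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                          affine_map<(d0, d1, d2) -> (d2, d1)>,
                          affine_map<(d0, d1, d2) -> (d0, d1)>],
         iterator_types = ["parallel", "parallel", "reduction"],
         kind = #vector.kind<add>}
       %lhs, %rhs, %acc
       : vector<64x64xf16>, vector<64x64xf16> into vector<64x64xf32>
  return %0 : vector<64x64xf32>
}
```

A backend with native mixed-precision support (e.g., NVIDIA mma.sync) can match this form directly instead of seeing an all-F32 contraction behind the casts.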

Diff Detail

Event Timeline

manishucsd created this revision.Jun 1 2023, 11:33 AM
Herald added a project: Restricted Project. · View Herald Transcript
manishucsd requested review of this revision.Jun 1 2023, 11:33 AM
kuhar added a subscriber: kuhar.Jun 1 2023, 11:49 AM
kuhar added inline comments.
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
1240–1247

You can combine the defining op check with the type check:

auto lhsDef = contractOp.getLhs().getDefiningOp<arith::ExtFOp>();
...

if (!lhsDef || !rhsDef) {

Apply comment from kuhar

manishucsd marked an inline comment as done.Jun 1 2023, 12:50 PM

Thanks, Manish! LG in general. A few comments.

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
79

typo

80

Would it make sense to add these patterns to the regular vector.contract folding? Any reason to keep them separate? I understand that if the target doesn't have "native" support for these flavors, they will be decomposed again into the original arith ops.

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
1242

nit: no curly braces per coding standards

mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
17

Could you please add another test that lowers a vector.contract f16, f16 -> f32 to plain arith instructions again? We have to make sure that the basic end-to-end lowering is working.

mravishankar accepted this revision.Jun 1 2023, 3:44 PM

Looks fine to me, but please wait for others to weigh in as well.

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
80

Would they be decomposed that way? I am not aware of all the lowerings to LLVM, but I am not sure they will automatically undo them... At least for now it makes sense to keep them separate (unless you know it's OK). I think unless there is some representation of what is legal for the target being compiled, mixing all these patterns together is problematic.

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
1242

nit nit: I think if the line spans more than one textual line, braces are expected :)

This revision is now accepted and ready to land.Jun 1 2023, 3:44 PM
manishucsd added inline comments.Jun 1 2023, 4:24 PM
mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
17

I am not sure I follow this. Are we looking for a test MLIR file that looks something like the following?

func.func @vector_contract_f16_f16_f32(%arg0: vector<64x64xf16>, %arg1: vector<64x64xf16>, %arg2: vector<64x64xf32>) -> vector<64x64xf32> {
  %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<64x64xf16>, vector<64x64xf16> into vector<64x64xf32>
  return %result : vector<64x64xf32>
}

The expectation is that fold-arith-extf-into-vector-contract-patterns will do nothing for the above MLIR. Is that what we are looking to add?

manishucsd updated this revision to Diff 527660.Jun 1 2023, 4:31 PM

Fixed typo "Airth" to "Arith"

manishucsd marked an inline comment as done.Jun 1 2023, 4:31 PM
dcaballe requested changes to this revision.Jun 1 2023, 6:20 PM
dcaballe added inline comments.
mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
80

Yes, that sounds good for now

mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
17

Sorry, I wasn't clear enough. I'm asking to add a test with the IR that you posted above (basically the output of the folding transformation that you are adding) and then invoke the vector contract lowering patterns on it to make sure there is a basic lowering for the new folded contract operation. You can take a look at vector-contract-to-parallel-arith-transforms.mlir or vector-contract-to-outerproduct-transforms.mlir and add the new test there. If a similar test exists, then great, we don't have to do anything!

I'm requesting this because we have complaints about dead-end patterns in MLIR upstream, i.e., patterns that generate code which can't be lowered to anything else with MLIR upstream code. We should make sure this is not one of those cases and that a basic lowering to LLVM exists upstream.

This revision now requires changes to proceed.Jun 1 2023, 6:20 PM
manishucsd marked an inline comment as not done.Jun 2 2023, 10:12 AM
manishucsd added inline comments.
mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
17

For the below input IR:

vector-contract-to-parallel-arith-transforms.mlir

func.func @vector_contract_f16_f16_f32(%arg0: vector<64x64xf16>, %arg1: vector<64x64xf16>, %arg2: vector<64x64xf32>) -> vector<64x64xf32> {
  %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<64x64xf16>, vector<64x64xf16> into vector<64x64xf32>
  return %result : vector<64x64xf32>
}

vector-contract-to-parallel-arith-transforms will not lower it to vector.fma, if that is what we are looking for. See below:

$ mlir-opt vector-contract-to-parallel-arith-transforms.mlir --test-transform-dialect-interpreter


"builtin.module"() ({
  "func.func"() <{function_type = (vector<64x64xf16>, vector<64x64xf16>, vector<64x64xf32>) -> vector<64x64xf32>, sym_name = "vector_contract_f16_f16_f32"}> ({
  ^bb0(%arg0: vector<64x64xf16>, %arg1: vector<64x64xf16>, %arg2: vector<64x64xf32>):
    %0 = "vector.contract"(%arg0, %arg1, %arg2) <{indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = [#vector.iterator_type<parallel>, #vector.iterator_type<parallel>, #vector.iterator_type<reduction>], kind = #vector.kind<add>}> : (vector<64x64xf16>, vector<64x64xf16>, vector<64x64xf32>) -> vector<64x64xf32>
    "func.return"(%0) : (vector<64x64xf32>) -> ()
  }) : () -> ()
}) : () -> ()

fold-arith-extf-into-vector-contract-patterns is an NVIDIA GPU-specific folding that is needed to support mixed precision. The output of fold-arith-extf-into-vector-contract-patterns is unrolled into smaller vector.contract ops to match nvgpu.mma.sync shapes and data types, then lowered to the nvgpu dialect.

manishucsd updated this revision to Diff 528040.Jun 2 2023, 4:49 PM

Test on mixed mode vector.contract lowering to mma.sync

dcaballe accepted this revision.Jun 5 2023, 10:44 AM

Thanks, Manish! LGTM

This revision is now accepted and ready to land.Jun 5 2023, 10:44 AM