This is an archive of the discontinued LLVM Phabricator instance.

[mlir][linalg] Add more shape information for tensor.pack generalization
AbandonedPublic

Authored by hanchung on Feb 14 2023, 6:30 PM.

Details

Summary

The static shape information is known in the pattern because the pattern
restricts the outer dims to be all 1s. In this context, inserting a
tensor.cast op is safe, and it gives other analyses more information.
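
For illustration only (the %pack_source name and the 8x16 tile below are hypothetical, not taken verbatim from the diff): with all outer dims equal to 1, the pattern knows that the sliced-out tile has exactly the inner-tile shape, so a cast like

%cast = tensor.cast %pack_source : tensor<?x?xf32> to tensor<8x16xf32>  // hypothetical 8x16 inner tile

can be inserted before the generated extract_slice, and downstream analyses then see a static type instead of tensor<?x?xf32>.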

Diff Detail

Event Timeline

hanchung created this revision.Feb 14 2023, 6:30 PM
hanchung requested review of this revision.Feb 14 2023, 6:30 PM
chelini accepted this revision.Feb 16 2023, 11:05 PM
This revision is now accepted and ready to land.Feb 16 2023, 11:05 PM
mravishankar requested changes to this revision.Feb 16 2023, 11:20 PM

This seems strange to me. There is nothing that prevents a subsequent transformation from collapsing the cast into the slice. Indeed, it is probably better to do so.

This revision now requires changes to proceed.Feb 16 2023, 11:20 PM

> This seems strange to me. There is nothing that prevents a subsequent transformation from collapsing the cast into the slice. Indeed, it is probably better to do so.

Let me add some more context, and then we can think about how to handle it correctly. The problem comes from generic + tensor.pack vectorization. The vectorization flow is tile-and-fuse with tiling sizes that make the outer dims all ones. E.g.,

%6 = scf.for %arg0 = %c0_0 to %c16 step %c1
  %7 = scf.for %arg2 = %c0_0 to %c384 step %c1
    %21 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%extracted_slice : tensor<?x?xf32>)
      outs(%extracted_slice_5 : tensor<?x?xf32>) { ... } -> tensor<?x?xf32>
    %pack = tensor.pack %21 inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %extracted_slice_15 : tensor<?x?xf32> -> tensor<1x1x8x1xf32>
    ...

Then we generalize the %pack op and kick in the generic vectorizer. The issue actually happens at generalization, which converts the pack op into extract_slice + transpose + insert_slice. The IR after generalization:

%5 = scf.for %arg0 = %c0 to %c16 step %c1 iter_args(%arg1 = %2) -> (tensor<16x384x8x1xf32>) {
  %6 = scf.for %arg2 = %c0 to %c384 step %c1 iter_args(%arg3 = %arg1) -> (tensor<16x384x8x1xf32>) {
    %7 = affine.min affine_map<(d0) -> (d0 * -8 + 128, 8)>(%arg0)
    %8 = affine.min affine_map<(d0) -> (-d0 + 384, 1)>(%arg2)
    %9 = affine.apply affine_map<(d0) -> (d0 * 8)>(%arg0)
    %10 = affine.apply affine_map<(d0) -> (d0 * 8)>(%arg0)
    %extracted_slice = tensor.extract_slice %3[%9, %arg2] [%7, %8] [1, 1] : tensor<128x384xf32> to tensor<?x?xf32>
    %extracted_slice_0 = tensor.extract_slice %4[%10, %arg2] [%7, %8] [1, 1] : tensor<128x384xf32> to tensor<?x?xf32>
    %11 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice : tensor<?x?xf32>) outs(%extracted_slice_0 : tensor<?x?xf32>) {
    ^bb0(%in: f32, %out: f32):
      %13 = arith.addf %in, %in : f32
      linalg.yield %13 : f32
    } -> tensor<?x?xf32>
    %extracted_slice_1 = tensor.extract_slice %arg3[%arg0, %arg2, 0, 0] [1, 1, 8, 1] [1, 1, 1, 1] : tensor<16x384x8x1xf32> to tensor<1x1x8x1xf32>
    %extracted_slice_2 = tensor.extract_slice %11[0, 0] [8, 1] [1, 1] : tensor<?x?xf32> to tensor<8x1xf32>
    %12 = tensor.empty() : tensor<8x1xf32>
    %transposed = linalg.transpose ins(%extracted_slice_2 : tensor<8x1xf32>) outs(%12 : tensor<8x1xf32>) permutation = [0, 1]
    %inserted_slice = tensor.insert_slice %transposed into %extracted_slice_1[0, 0, 0, 0] [1, 1, 8, 1] [1, 1, 1, 1] : tensor<8x1xf32> into tensor<1x1x8x1xf32>
    %inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg3[%arg0, %arg2, 0, 0] [1, 1, 8, 1] [1, 1, 1, 1] : tensor<1x1x8x1xf32> into tensor<16x384x8x1xf32>
    scf.yield %inserted_slice_3 : tensor<16x384x8x1xf32>
  }

If we look at the IR, we find that the generic op still has dynamic shapes. The extract_slice op stops the shape propagation from the transpose back to the generic op, so the generic op stays dynamically shaped and vectorization fails on it.
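
Concretely, the op that blocks the propagation is the slice of the generic result in the IR above:

%extracted_slice_2 = tensor.extract_slice %11[0, 0] [8, 1] [1, 1] : tensor<?x?xf32> to tensor<8x1xf32>

Its result type is static (tensor<8x1xf32>), but nothing folds that information back into %11 and the generic op that produces it.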

In the regular tile+fuse of linalg ops + vectorization, there are direct dependencies between the linalg ops. The linalg canonicalization pattern (InferStaticShapeOfOperands) can infer static shapes for the other operands: it basically inserts some tensor.cast ops around the linalg ops and folds them into the producers. That's why I'm thinking about inserting the known information (i.e., a tensor.cast op) before the extract_slice op. It is a valid insertion because the input is either aligned or padded. With the cast op inserted, it can be used to infer static shapes for the generic ops and make vectorization work.
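
As a rough sketch of the idea, using the SSA names from the example above (the exact IR the patch emits may differ): the generalization would cast the pack source to its known static tile shape before slicing it.

// Assumes the tile is fully covered (aligned or padded), so the 8x1 static shape is valid.
%cast = tensor.cast %11 : tensor<?x?xf32> to tensor<8x1xf32>
%extracted_slice_2 = tensor.extract_slice %cast[0, 0] [8, 1] [1, 1] : tensor<8x1xf32> to tensor<8x1xf32>

With the cast in place, there is a static type for %11 that can be folded toward the producer, analogous to what InferStaticShapeOfOperands does when two linalg ops are directly connected.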


After writing this down, I found that we still can't vectorize the generic op if it's not aligned to the inner tiling sizes. Hopefully the vector masking trick or whole-program data-layout propagation can handle it better. That's a separate issue. (We can chat offline if we need more bandwidth.)

hanchung abandoned this revision.Feb 23 2023, 10:24 AM