The static shape information is known in the pattern because it restricts
the outer dims to all be 1s. In this context, inserting a tensor.cast op
is safe, and it gives other analyses more information.
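For illustration, a minimal sketch of the kind of cast this enables (hypothetical %src and sizes, assuming the pattern has already proven the data is exactly one 8x1 inner tile):

  // Hypothetical %src: all outer dims are 1, so the meaningful data is
  // known to be exactly 8x1; refining the type with a cast is safe.
  %cast = tensor.cast %src : tensor<?x?xf32> to tensor<8x1xf32>
  %tile = tensor.extract_slice %cast[0, 0] [8, 1] [1, 1]
      : tensor<8x1xf32> to tensor<8x1xf32>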
Details
- Reviewers
mravishankar chelini nicolasvasilache
Diff Detail
- Repository
- rG LLVM Github Monorepo
Event Timeline
This seems strange to me. There is nothing that prevents a subsequent transformation from collapsing the cast into the slice. Indeed, it is probably better to do so.
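For example (my own sketch of the fold being suggested), the cast can be absorbed by giving the slice a static result type directly:

  // Before folding: a cast refines the type, then the slice is taken.
  %cast = tensor.cast %src : tensor<?x?xf32> to tensor<8x1xf32>
  %tile = tensor.extract_slice %cast[0, 0] [8, 1] [1, 1]
      : tensor<8x1xf32> to tensor<8x1xf32>
  // After folding: the slice reads %src directly and carries the static
  // result type itself, so the cast disappears.
  %tile_folded = tensor.extract_slice %src[0, 0] [8, 1] [1, 1]
      : tensor<?x?xf32> to tensor<8x1xf32>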
Let me add some more context, and we can think about how to handle it correctly. The problem comes from generic + tensor.pack vectorization. The vectorization flow is tile+fuse, with tile sizes that make the outer dims all ones. E.g.,
%6 = scf.for %arg0 = %c0_0 to %c16 step %c1
  %7 = scf.for %arg2 = %c0_0 to %c384 step %c1
    %21 = linalg.generic {
        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                         affine_map<(d0, d1) -> (d0, d1)>],
        iterator_types = ["parallel", "parallel"]}
        ins(%extracted_slice : tensor<?x?xf32>)
        outs(%extracted_slice_5 : tensor<?x?xf32>) {
      ...
    } -> tensor<?x?xf32>
    %pack = tensor.pack %21 inner_dims_pos = [0, 1] inner_tiles = [8, 1]
        into %extracted_slice_15 : tensor<?x?xf32> -> tensor<1x1x8x1xf32>
    ...
Then we generalize the %pack op and kick in the generic vectorizer. The issue actually happens at generalization: it converts the pack op into extract_slice + transpose + insert_slice. The IR after generalization:
%5 = scf.for %arg0 = %c0 to %c16 step %c1 iter_args(%arg1 = %2) -> (tensor<16x384x8x1xf32>) {
  %6 = scf.for %arg2 = %c0 to %c384 step %c1 iter_args(%arg3 = %arg1) -> (tensor<16x384x8x1xf32>) {
    %7 = affine.min affine_map<(d0) -> (d0 * -8 + 128, 8)>(%arg0)
    %8 = affine.min affine_map<(d0) -> (-d0 + 384, 1)>(%arg2)
    %9 = affine.apply affine_map<(d0) -> (d0 * 8)>(%arg0)
    %10 = affine.apply affine_map<(d0) -> (d0 * 8)>(%arg0)
    %extracted_slice = tensor.extract_slice %3[%9, %arg2] [%7, %8] [1, 1] : tensor<128x384xf32> to tensor<?x?xf32>
    %extracted_slice_0 = tensor.extract_slice %4[%10, %arg2] [%7, %8] [1, 1] : tensor<128x384xf32> to tensor<?x?xf32>
    %11 = linalg.generic {
        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                         affine_map<(d0, d1) -> (d0, d1)>],
        iterator_types = ["parallel", "parallel"]}
        ins(%extracted_slice : tensor<?x?xf32>)
        outs(%extracted_slice_0 : tensor<?x?xf32>) {
    ^bb0(%in: f32, %out: f32):
      %13 = arith.addf %in, %in : f32
      linalg.yield %13 : f32
    } -> tensor<?x?xf32>
    %extracted_slice_1 = tensor.extract_slice %arg3[%arg0, %arg2, 0, 0] [1, 1, 8, 1] [1, 1, 1, 1] : tensor<16x384x8x1xf32> to tensor<1x1x8x1xf32>
    %extracted_slice_2 = tensor.extract_slice %11[0, 0] [8, 1] [1, 1] : tensor<?x?xf32> to tensor<8x1xf32>
    %12 = tensor.empty() : tensor<8x1xf32>
    %transposed = linalg.transpose ins(%extracted_slice_2 : tensor<8x1xf32>) outs(%12 : tensor<8x1xf32>) permutation = [0, 1]
    %inserted_slice = tensor.insert_slice %transposed into %extracted_slice_1[0, 0, 0, 0] [1, 1, 8, 1] [1, 1, 1, 1] : tensor<8x1xf32> into tensor<1x1x8x1xf32>
    %inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg3[%arg0, %arg2, 0, 0] [1, 1, 8, 1] [1, 1, 1, 1] : tensor<1x1x8x1xf32> into tensor<16x384x8x1xf32>
    scf.yield %inserted_slice_3 : tensor<16x384x8x1xf32>
  }
If we look into the IR, we find that the generic op still has dynamic shapes. The extract_slice op stops the shape propagation from the transpose back to the generic op, so the generic op keeps its dynamic shape, and vectorization fails on it.
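Boiled down (simplified from the IR above), the blocker looks like this: the static 8x1 type lives only on the slice's result and never reaches the generic op:

  // The consumer knows the data is 8x1, but that knowledge sits on the
  // extract_slice's result type; nothing forces %generic to go static.
  %generic = linalg.generic ... ins(%in : tensor<?x?xf32>)
      outs(%out : tensor<?x?xf32>) { ... } -> tensor<?x?xf32>
  %tile = tensor.extract_slice %generic[0, 0] [8, 1] [1, 1]
      : tensor<?x?xf32> to tensor<8x1xf32>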
In the regular tile+fuse of linalg ops + vectorization, there are direct dependences between the linalg ops, and the linalg canonicalization pattern (InferStaticShapeOfOperands) can infer static shapes for the operands. It basically inserts tensor.cast ops around the linalg ops and folds them into the producers. That's why I'm thinking about inserting some known information (like a tensor.cast op) before the extract_slice op. It is a valid insertion because the input is either aligned or padded. Once the cast op is inserted, it can be used to infer static shapes for the generic op and make vectorization work.
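Sketching the proposed insertion on the reduced example above (valid only under the aligned/padded assumption):

  // Refine the generic's result with a cast before slicing. Shape
  // inference (e.g. InferStaticShapeOfOperands) can then fold the cast
  // into %generic and rewrite it with static operand types.
  %cast = tensor.cast %generic : tensor<?x?xf32> to tensor<8x1xf32>
  %tile = tensor.extract_slice %cast[0, 0] [8, 1] [1, 1]
      : tensor<8x1xf32> to tensor<8x1xf32>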
After writing this down, I realized that we still can't vectorize the generic op if it is not aligned to the inner tile sizes. Hopefully the vector masking trick or whole-program data-layout propagation can handle that better; it's a separate issue. (We can chat offline if we need more bandwidth.)
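For reference, a rough sketch of what the masked read could look like (illustrative names and sizes only; not part of this patch):

  // Read a dynamically sized tile as a fixed 8x1 vector under a mask.
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %pad = arith.constant 0.0 : f32
  %d0 = tensor.dim %t, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %t, %c1 : tensor<?x?xf32>
  %mask = vector.create_mask %d0, %d1 : vector<8x1xi1>
  %v = vector.mask %mask {
    vector.transfer_read %t[%c0, %c0], %pad
        : tensor<?x?xf32>, vector<8x1xf32>
  } : vector<8x1xi1> -> vector<8x1xf32>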