This is an archive of the discontinued LLVM Phabricator instance.

[mlir][tensor] Fix transpose permutation in tensor.pack generalization pattern
Closed, Public

Authored by qedawkins on Feb 20 2023, 12:10 PM.

Details

Summary

The generalization pattern for tensor.pack was inverting the
innerDimsPos permutation when normalizing it. As a result, the
transpose op produced by generalization did not perform the correct
transpose. This can be observed in the following example by comparing
the IR generated with and without data layout op (pack/unpack)
propagation.

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
builtin.module {
  func.func @forward(%arg0: tensor<3x5x7xf32>, %arg1: tensor<3x5x7xf32>) -> tensor<1x1x1x5x7x3xf32> {
    %empty = tensor.empty() : tensor<3x5x7xf32>
    %elementwise = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : tensor<3x5x7xf32>, tensor<3x5x7xf32>) outs(%empty : tensor<3x5x7xf32>) {
    ^bb0(%in: f32, %in_1: f32, %out: f32):
      %add = arith.addf %in, %in_1 : f32
      linalg.yield %add : f32
    } -> tensor<3x5x7xf32>
    %pack_empty = tensor.empty() : tensor<1x1x1x5x7x3xf32>
    %pack = tensor.pack %elementwise inner_dims_pos = [1, 2, 0] inner_tiles = [5, 7, 3] into %pack_empty : tensor<3x5x7xf32> -> tensor<1x1x1x5x7x3xf32>
    return %pack : tensor<1x1x1x5x7x3xf32>
  }
}
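
For reference, the packed result type above follows tensor.pack semantics: each tiled source dim shrinks by its tile size, and the tile sizes are appended as trailing dims in inner_dims_pos order. A minimal standalone sketch of that shape computation (not MLIR's implementation; packedShape is a made-up name, and perfectly divisible static tiles are assumed):

#include <cstdint>
#include <iostream>
#include <vector>

// Packed shape: tiled outer dims shrink by their tile size (assuming perfect
// divisibility), and the tile sizes are appended in inner_dims_pos order.
std::vector<int64_t> packedShape(std::vector<int64_t> srcShape,
                                 const std::vector<int64_t> &innerDimsPos,
                                 const std::vector<int64_t> &innerTiles) {
  for (size_t i = 0; i < innerDimsPos.size(); ++i)
    srcShape[innerDimsPos[i]] /= innerTiles[i];
  srcShape.insert(srcShape.end(), innerTiles.begin(), innerTiles.end());
  return srcShape;
}

int main() {
  // tensor<3x5x7xf32>, inner_dims_pos = [1, 2, 0], inner_tiles = [5, 7, 3]
  for (int64_t d : packedShape({3, 5, 7}, {1, 2, 0}, {5, 7, 3}))
    std::cout << d << ' ';            // prints: 1 1 1 5 7 3
  std::cout << '\n';
}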

With the data layout propagation patterns applied through the elementwise op:

#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
module {
  func.func @forward(%arg0: tensor<3x5x7xf32>, %arg1: tensor<3x5x7xf32>) -> tensor<1x1x1x5x7x3xf32> {
    %0 = tensor.empty() : tensor<1x1x1x5x7x3xf32>
    %1 = tensor.empty() : tensor<1x1x1x5x7x3xf32>
    %2 = tensor.empty() : tensor<7x3x5xf32>
    %transposed = linalg.transpose ins(%arg0 : tensor<3x5x7xf32>) outs(%2 : tensor<7x3x5xf32>) permutation = [2, 0, 1]
    %inserted_slice = tensor.insert_slice %transposed into %1[0, 0, 0, 0, 0, 0] [1, 1, 1, 7, 3, 5] [1, 1, 1, 1, 1, 1] : tensor<7x3x5xf32> into tensor<1x1x1x5x7x3xf32>
    %3 = tensor.empty() : tensor<1x1x1x5x7x3xf32>
    %4 = tensor.empty() : tensor<7x3x5xf32>
    %transposed_0 = linalg.transpose ins(%arg1 : tensor<3x5x7xf32>) outs(%4 : tensor<7x3x5xf32>) permutation = [2, 0, 1]
    %inserted_slice_1 = tensor.insert_slice %transposed_0 into %3[0, 0, 0, 0, 0, 0] [1, 1, 1, 7, 3, 5] [1, 1, 1, 1, 1, 1] : tensor<7x3x5xf32> into tensor<1x1x1x5x7x3xf32>
    %5 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%inserted_slice, %inserted_slice_1 : tensor<1x1x1x5x7x3xf32>, tensor<1x1x1x5x7x3xf32>) outs(%0 : tensor<1x1x1x5x7x3xf32>) {
    ^bb0(%in: f32, %in_2: f32, %out: f32):
      %6 = arith.addf %in, %in_2 : f32
      linalg.yield %6 : f32
    } -> tensor<1x1x1x5x7x3xf32>
    return %5 : tensor<1x1x1x5x7x3xf32>
  }
}

Without propagation:

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
  func.func @forward(%arg0: tensor<3x5x7xf32>, %arg1: tensor<3x5x7xf32>) -> tensor<1x1x1x5x7x3xf32> {
    %0 = tensor.empty() : tensor<3x5x7xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : tensor<3x5x7xf32>, tensor<3x5x7xf32>) outs(%0 : tensor<3x5x7xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %4 = arith.addf %in, %in_0 : f32
      linalg.yield %4 : f32
    } -> tensor<3x5x7xf32>
    %2 = tensor.empty() : tensor<1x1x1x5x7x3xf32>
    %3 = tensor.empty() : tensor<7x3x5xf32>
    %transposed = linalg.transpose ins(%1 : tensor<3x5x7xf32>) outs(%3 : tensor<7x3x5xf32>) permutation = [2, 0, 1]
    %inserted_slice = tensor.insert_slice %transposed into %2[0, 0, 0, 0, 0, 0] [1, 1, 1, 7, 3, 5] [1, 1, 1, 1, 1, 1] : tensor<7x3x5xf32> into tensor<1x1x1x5x7x3xf32>
    return %inserted_slice : tensor<1x1x1x5x7x3xf32>
  }
}

Here, data layout propagation is doing a different transpose than the
one generalization comes up with.
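
The mismatch boils down to using inner_dims_pos itself versus its inverse as the transpose permutation. A minimal standalone sketch (plain C++, not MLIR code; applyPerm and invert are made-up names), using the linalg.transpose convention result[i] = input[perm[i]]:

#include <cstdint>
#include <iostream>
#include <vector>

// result[i] = shape[perm[i]], matching linalg.transpose semantics.
std::vector<int64_t> applyPerm(const std::vector<int64_t> &shape,
                               const std::vector<int64_t> &perm) {
  std::vector<int64_t> result;
  for (int64_t p : perm)
    result.push_back(shape[p]);
  return result;
}

// inv[perm[i]] = i.
std::vector<int64_t> invert(const std::vector<int64_t> &perm) {
  std::vector<int64_t> inv(perm.size());
  for (size_t i = 0; i < perm.size(); ++i)
    inv[perm[i]] = static_cast<int64_t>(i);
  return inv;
}

int main() {
  const std::vector<int64_t> shape = {3, 5, 7};
  const std::vector<int64_t> innerDimsPos = {1, 2, 0};

  // Applying inner_dims_pos directly gives the pack's inner-tile order.
  for (int64_t d : applyPerm(shape, innerDimsPos))
    std::cout << d << ' ';
  std::cout << "<- 5 7 3\n";

  // Applying the inverse permutation [2, 0, 1] gives 7x3x5 instead.
  for (int64_t d : applyPerm(shape, invert(innerDimsPos)))
    std::cout << d << ' ';
  std::cout << "<- 7 3 5\n";
}

Applying inner_dims_pos = [1, 2, 0] directly yields 5x7x3, matching the trailing dims of the 1x1x1x5x7x3 pack result, while the inverse [2, 0, 1] yields the 7x3x5 transposes seen in the listings above.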

Diff Detail

Event Timeline

qedawkins created this revision. Feb 20 2023, 12:10 PM
qedawkins published this revision for review. Feb 20 2023, 12:16 PM
hanchung requested changes to this revision. Feb 21 2023, 2:41 PM
hanchung added inline comments.
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
588–595

The implementation of getPackNormalizedInnerPerm is weird to me: it's calling getPackUnpackNormalizedInnerPerm and then generating the inverse permutation.

Why not just use the same method and call invertPermutationVector here?
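
For context, a rough standalone sketch of what this suggestion amounts to (plain std::vector stands in for MLIR's types; the helper body below is a placeholder, not the real getPackUnpackNormalizedInnerPerm from Transforms.cpp): keep the shared pack/unpack normalization helper and invert its result at the call site that needs the inverse.

#include <cstdint>
#include <vector>

// Placeholder stand-in for the shared pack/unpack normalization helper.
std::vector<int64_t>
getPackUnpackNormalizedInnerPerm(const std::vector<int64_t> &innerDimsPos) {
  return innerDimsPos; // illustration only, not the real computation
}

// Invert a permutation: inv[perm[i]] = i.
std::vector<int64_t> invertPermutationVector(const std::vector<int64_t> &perm) {
  std::vector<int64_t> inv(perm.size());
  for (size_t i = 0; i < perm.size(); ++i)
    inv[perm[i]] = static_cast<int64_t>(i);
  return inv;
}

int main() {
  const std::vector<int64_t> innerDimsPos = {1, 2, 0};
  // Call the shared helper, then invert only where the inverse is needed,
  // instead of keeping a second, pack-only helper that bakes the inversion in.
  std::vector<int64_t> perm = getPackUnpackNormalizedInnerPerm(innerDimsPos);
  std::vector<int64_t> inverted = invertPermutationVector(perm); // [2, 0, 1]
  (void)inverted;
}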

This revision now requires changes to proceed. Feb 21 2023, 2:41 PM
qedawkins updated this revision to Diff 499308. Feb 21 2023, 3:03 PM

Use invert helper instead of specialized function

qedawkins added inline comments. Feb 21 2023, 3:07 PM
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
588–595

Thanks, this is much cleaner :)

hanchung accepted this revision. Feb 22 2023, 11:29 AM
This revision is now accepted and ready to land. Feb 22 2023, 11:29 AM
This revision was landed with ongoing or failed builds. Feb 22 2023, 11:53 AM
This revision was automatically updated to reflect the committed changes.