This is an archive of the discontinued LLVM Phabricator instance.

[mlir][tosa] Canonicalize concatenate->slice sequence
Closed (Public)

Authored by lhutton1 on Feb 22 2023, 1:46 AM.

Details

Summary

Adds a canonicalizer for the concatenate->slice sequence where
an output of slice can be replaced with an input of concatenate.

This is useful in the context of operations with complex inputs
and outputs that are legalized from a framework such as TFL.
For example, a TFL graph (FFT->FFT) will be legalized to the
following TOSA graph:

<complex input>
    /     \
slice    slice
    \     /
      FFT
     /   \     -+
  concatenate   |
    /     \     |  Redundant
slice    slice  |
    \     /    -+
      FFT
    /     \
  concatenate
       |
<complex output>

Concatenate and slice operations at the boundaries of the graph are
useful as they maintain the correct correspondence of input/output
tensors to the original TFL graph. However, consecutive
complex operations will result in redundant concatenate->slice
sequences which should be removed from the final TOSA graph.

The canonicalization does not currently handle dynamic types.
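The core check described above can be modeled outside MLIR. The sketch below is a minimal illustration, not the patch's code. It assumes fully static shapes and uses plain std::vector in place of MLIR shape and attribute types. The name findExactInput is hypothetical. The function walks the concatenate inputs, tracks where each input begins along the concatenation axis, and reports the input whose full extent the slice reproduces exactly.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<std::int64_t>;

// Hypothetical helper (not the patch's code): returns the index of the
// concatenate input that the slice reproduces exactly, or -1 if none does.
// Assumes fully static shapes, since the patch does not handle dynamic types.
int findExactInput(const std::vector<Shape> &inputShapes, std::size_t axis,
                   const Shape &sliceStart, const Shape &sliceSize) {
  assert(sliceStart.size() == sliceSize.size());
  std::int64_t offset = 0; // where the current input begins along `axis`
  for (std::size_t i = 0; i < inputShapes.size(); ++i) {
    const Shape &shape = inputShapes[i];
    assert(shape.size() == sliceStart.size());
    bool matches = true;
    for (std::size_t d = 0; d < shape.size() && matches; ++d) {
      // On the concat axis the slice must start at this input's offset; on
      // every other axis it must start at 0. It must span the whole input.
      std::int64_t expectedStart = (d == axis) ? offset : 0;
      matches = sliceStart[d] == expectedStart && sliceSize[d] == shape[d];
    }
    if (matches)
      return static_cast<int>(i);
    offset += shape[axis];
  }
  return -1;
}
```

For example, with inputs of shape 2x3 and 2x5 concatenated on axis 1, a slice with start [0, 3] and size [2, 5] maps to input 1. A slice with start [0, 1] and size [2, 3] matches no input.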

Signed-off-by: Luke Hutton <luke.hutton@arm.com>

Event Timeline

lhutton1 created this revision.Feb 22 2023, 1:46 AM
lhutton1 requested review of this revision.Feb 22 2023, 1:46 AM

Friendly ping for review

I don't think you've selected reviewers yet (Nicolas seems an unlikely pick; perhaps this was some overly broad Herald rule).

Thanks, yes, I believe the reviewer was assigned automatically. I hope the updated list is more sensible.

rsuderman requested changes to this revision.Mar 13 2023, 3:47 PM
rsuderman added inline comments.
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
549

It is worth noting that this canonicalization could still be useful when this check is false, as the slice would then be taking a portion of a concat input before concatenation. It's not necessary for this patch, but I would include a comment about that, as the slice could be pushed up above the concatenation.

568

I would avoid performing the replacement here. The loop should focus on finding the input that matches the target, and the replacement should be performed outside of the loop. Some optimizers perform multiple replacements, so it's usually safer to avoid replaceOp within loop bodies when possible.

573

You would then note if you found a valid replacement, and perform the replaceOp and return success().

This revision now requires changes to proceed.Mar 13 2023, 3:47 PM

Thanks for the review, I just had a question before making the changes

lhutton1 added inline comments.Mar 15 2023, 4:07 PM
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
549

Happy to add a comment, although I'm not sure I fully understand. Could you provide a small example of when this might happen?

lhutton1 updated this revision to Diff 506600.Mar 20 2023, 8:22 AM

Defer replacement to outside loop

lhutton1 marked 2 inline comments as done.Mar 20 2023, 8:22 AM
rsuderman requested changes to this revision.Mar 20 2023, 12:57 PM
rsuderman added inline comments.
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
549

In short, you do not need this check if you support slicing the corresponding input, e.g.:

func.func @main(%arg0 : tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<2x2xf32> {
  %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<8x4xf32>
  %1 = "tosa.slice"(%0) {size = array<i64: 2, 2>, start = array<i64: 0, 0>} : (tensor<8x4xf32>) -> tensor<2x2xf32>
  return %1 : tensor<2x2xf32>
}

Can be represented by

func.func @main(%arg0 : tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<2x2xf32> {
  %1 = "tosa.slice"(%arg0) {size = array<i64: 2, 2>, start = array<i64: 0, 0>} : (tensor<4x4xf32>) -> tensor<2x2xf32>
  return %1 : tensor<2x2xf32>
}

It is worth noting this is not a hard requirement, but validating that the slice covers the whole input is almost the same amount of work.

568

If you perform the sub-slice improvement, you should check that the slice does not span a concatenation boundary, so that it can be taken from a single input.
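The boundary condition can be written as a small predicate. This is a sketch that assumes static sizes along the concatenation axis. The name spansConcatBoundary is a hypothetical helper, not an MLIR or TOSA API. A slice covering [start, start + size) crosses a boundary b exactly when start < b < start + size, so only the interior boundaries between consecutive inputs need checking.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: returns true if the slice range [start, start + size)
// along the concatenation axis crosses a boundary between two concatenate
// inputs. `axisSizes` holds each input's (static) extent along that axis.
bool spansConcatBoundary(const std::vector<std::int64_t> &axisSizes,
                         std::int64_t start, std::int64_t size) {
  assert(size > 0);
  std::int64_t boundary = 0;
  // Only interior boundaries matter; the end of the last input is the end of
  // the concatenated result, which a valid slice never crosses.
  for (std::size_t i = 0; i + 1 < axisSizes.size(); ++i) {
    boundary += axisSizes[i];
    if (start < boundary && boundary < start + size)
      return true;
  }
  return false;
}
```

With the 4x4 example above (axis sizes 4 and 4), a slice starting at 0 with size 2 stays within the first input, while one starting at 3 with size 2 crosses the boundary at 4.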

578

You don't need the else because you return in the main case. I would also change the check to return an error if you do not find a valid substitution value.

This revision now requires changes to proceed.Mar 20 2023, 12:57 PM
lhutton1 updated this revision to Diff 507125.Mar 21 2023, 2:32 PM

Address review comments

  • remove unnecessary else statement
  • replace concat->slice sequence with slice to account for slicing on non-concatenated axis

Apologies for all the comments. This should be the last one and we can land it. I was hinting you could land the slice improvements in a separate follow-up (just leave a comment with details). But thank you for all the changes! Looks like a great PR.

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
540

You will need to clone this into an llvm::SmallVector<...>

546

You can remove the inputOffset math and instead check the sliceStart and sliceSize values below.

557

If we check that the generated slice falls entirely within the input, we know whether we can slice out just this section.

if (sliceStart[axis] >= 0 && (sliceStart[axis] + sliceSize[axis]) < inputType.getDimSize(axis))

566

You will also need to update sliceStart as below:

sliceStart[axis] -= inputType.getDimSize(axis);

This adjusts the slice offset so that it is relative to the start of the next input.
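Putting the containment check and the offset update together gives a loop like the one sketched below. It is an illustration under stated assumptions, not the landed implementation. Plain std::vector stands in for MLIR types, sliceStart is a mutable copy (the role the llvm::SmallVector clone plays), and rebaseSliceOntoInput is a hypothetical name. The sketch uses an inclusive bound (<=), since a slice that ends exactly at the end of an input still lies within it.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical sketch: `inputDims` holds each concatenate input's extent along
// `axis`; `sliceStart` is a mutable copy of the slice start. Returns the index
// of the input that fully contains the slice along `axis`, leaving sliceStart
// rebased into that input, or -1 if the slice straddles inputs or runs past
// the end. On failure sliceStart is left modified and should be discarded.
int rebaseSliceOntoInput(const std::vector<std::int64_t> &inputDims,
                         std::size_t axis,
                         std::vector<std::int64_t> &sliceStart,
                         const std::vector<std::int64_t> &sliceSize) {
  assert(axis < sliceStart.size() && sliceStart.size() == sliceSize.size());
  for (std::size_t i = 0; i < inputDims.size(); ++i) {
    // Inclusive end bound: a slice ending exactly at the input's end fits.
    if (sliceStart[axis] >= 0 &&
        sliceStart[axis] + sliceSize[axis] <= inputDims[i])
      return static_cast<int>(i);
    // Make the offset relative to the start of the next input.
    sliceStart[axis] -= inputDims[i];
  }
  return -1;
}
```

Run on the reproducer at the end of this thread (sizes 6 and 480 on axis 1, start 480, size 6), the sketch selects input 1 with a rebased start of 474, which is in bounds.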

mlir/test/Dialect/Tosa/canonicalize.mlir
468

I would delete this test. tosa.concat has a folder that already folds all cases with a single operand.

481

Reword the name away from "invalid": it implies that the IR is not legal for the dialect, when it is just referring to the optimization. Maybe cross_concat_inputs?

498

Update this one for start = array<i64: 1, 3, 12> so that we pull from the other operand. It also checks the offset updating changes above.

lhutton1 marked 2 inline comments as not done.Mar 21 2023, 2:53 PM

No problem at all, thanks for your patience :)

If we check that the generated slice falls entirely within the input,...

Thanks, I was just thinking about this after I submitted the changes, I'd missed this comment previously.

I'll try to action these shortly

lhutton1 updated this revision to Diff 507184.Mar 21 2023, 4:55 PM
lhutton1 marked 13 inline comments as done.
  • remove unnecessary test case
  • rename test
  • allow partial slice of input on concatenated axis
rsuderman accepted this revision.Mar 21 2023, 5:53 PM
rsuderman added inline comments.
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
554

Good catch! My example was off by one :)

This revision is now accepted and ready to land.Mar 21 2023, 5:53 PM

Found a failure in https://github.com/openxla/iree/issues/14479 with this change. It ends up creating a

%179 = "tosa.slice"(%176) {size = array<i64: 1, 6, 1, 1>, start = array<i64: 0, 480, 0, 0>} : (tensor<1x480x1x1xf32>) -> tensor<1x6x1x1xf32>

which is out of bounds (start 480 + size 6 exceeds the extent of 480 on axis 1). The full repro can be done with

tensorflow/compiler/mlir/lite/flatbuffer_translate -- --tflite-flatbuffer-to-mlir  /tmp/bt.tflite -o /tmp/bt.mlir && tensorflow/compiler/mlir/tf-opt -- --tfl-to-tosa-pipeline /tmp/bt.mlir -o /tmp/tosa.mlir

and looking at the first slice. I'll try to add a smaller reproducer.
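A per-axis bounds check makes it easy to see why the generated slice is invalid. This sketch assumes static shapes. The name sliceInBounds is a hypothetical helper, not an existing TOSA verifier function.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: under this sketch's assumptions, a static slice is in
// bounds when, on every axis, start >= 0, size >= 0 and
// start + size <= the input's extent.
bool sliceInBounds(const std::vector<std::int64_t> &inputShape,
                   const std::vector<std::int64_t> &start,
                   const std::vector<std::int64_t> &size) {
  assert(inputShape.size() == start.size() && start.size() == size.size());
  for (std::size_t d = 0; d < inputShape.size(); ++d)
    if (start[d] < 0 || size[d] < 0 || start[d] + size[d] > inputShape[d])
      return false;
  return true;
}
```

For the slice above, axis 1 gives 480 + 6 = 486 > 480, so the check fails; a start of 474 on that axis would pass.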

Airplane Wi-Fi is making it tough.

func.func @foo(%177: tensor<1x6x1x1xf32>, %179: tensor<1x480x1x1xf32>) -> tensor<1x486x1x1xf32> {

%178 = "tosa.concat"(%177, %176) {axis = 1 : i64} : (tensor<1x6x1x1xf32>, tensor<1x480x1x1xf32>) -> tensor<1x486x1x1xf32>
%179 = "tosa.slice"(%178) {size = array<i64: 1, 6, 1, 1>, start = array<i64: 0, 480, 0, 0>} : (tensor<1x486x1x1xf32>) -> tensor<1x6x1x1xf32>
return %179: tensor<1x6x1x1xf32>

}

I think this should show it (typed, not tested...).

Close

func.func @foo(%arg0: tensor<1x6x1x1xf32>, %arg1: tensor<1x480x1x1xf32>) -> tensor<1x6x1x1xf32> {
  %0 = "tosa.concat"(%arg0, %arg1) {axis = 1 : i64} : (tensor<1x6x1x1xf32>, tensor<1x480x1x1xf32>) -> tensor<1x486x1x1xf32>
  %1 = "tosa.slice"(%0) {size = array<i64: 1, 6, 1, 1>, start = array<i64: 0, 480, 0, 0>} : (tensor<1x486x1x1xf32>) -> tensor<1x6x1x1xf32>
  return %1 : tensor<1x6x1x1xf32>
}

through mlir-opt --canonicalize shows it.