

[mlir][tosa] make Select operator broadcastable in the pass

Authored by tatwaichong on Dec 1 2022, 4:08 PM.



Making Select broadcastable makes this op easier to use.

Change-Id: I4a4bec4f7cbe532e954a5b4fe53136676ab4300c

Diff Detail

Event Timeline

tatwaichong created this revision.Dec 1 2022, 4:08 PM
Herald added a project: Restricted Project.Dec 1 2022, 4:08 PM
tatwaichong requested review of this revision.Dec 1 2022, 4:08 PM
tatwaichong edited the summary of this revision. (Show Details)Dec 1 2022, 4:12 PM
tatwaichong added a reviewer: rsuderman.
tatwaichong added a reviewer: eric-k256.
rsuderman requested changes to this revision.Dec 5 2022, 11:45 AM
rsuderman added inline comments.

Include a notifyMatchFailure for the return. It's useful for debugging.


What happens if pred broadcasts both times? They may not broadcast in the same way.


Ditto notifyMatchFailure.




Seeing this nested ternary makes me concerned the broadcasted predicate may have some issues. It assumes that even if both newPred1 and newPred2 are new, they are still the same broadcast.

This revision now requires changes to proceed.Dec 5 2022, 11:45 AM
  • replace failure with notifyMatchFailure
  • address review comment
  • add more mlir tests
  • add comment in code
tatwaichong marked 3 inline comments as done.Dec 12 2022, 1:48 PM
tatwaichong added inline comments.

This pattern assumes the rank of the output is known. In broadcasting, the reshaped rank of the inputs must match the rank of the output, which must equal the maximum of the input ranks. reshapeLowerToHigher checks whether the reshaped rank of the input arguments matches the rank of the output.

So newPred1 and newPred2 are still the same broadcast if they both become new in the same round of the pass.
For example, here both newPred1 and newPred2 are new:
%3 = "tosa.select"(%0, %1, %2) : (tensor<1xi1>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>

When a predicate gets reshaped, it cannot have the highest rank. Conversely, a lower-rank predicate broadcasts successfully in reshapeLowerToHigher only when it is paired with an input that has the highest rank (equal to the output rank).

Further, consider the example below:
%3 = "tosa.select"(%0, %1, %2) : (tensor<1xi1>, tensor<32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
Calling reshapeLowerToHigher(pred=<1xi1>, input1=<32x8xf32>, output=<1x32x32x8xf32>) returns failure, so we don't get a new predicate.
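The rank rule above can be modeled on shapes alone. The sketch below is a hypothetical Python stand-in for the C++ helper (the name and exact failure behavior are assumptions, not the actual implementation); it left-pads the lower-rank shape with 1s, and fails when the higher-rank operand does not already have the output rank:

```python
def reshape_lower_to_higher(lower, higher, output):
    """Shape-level model of the pass's rank-alignment step (hypothetical
    helper, not the real C++ function). Left-pads `lower` with 1s up to
    `higher`'s rank, but only when `higher` already has the output rank;
    otherwise returns None, mirroring the pattern returning failure()."""
    if len(higher) != len(output):
        return None  # the higher-rank operand must match the output rank
    pad = len(higher) - len(lower)
    return (1,) * pad + tuple(lower)

# pred <1xi1> against input2 <1x32x32x8> (rank == output rank): succeeds
print(reshape_lower_to_higher((1,), (1, 32, 32, 8), (1, 32, 32, 8)))  # (1, 1, 1, 1)
# pred <1xi1> against input1 <32x8> (rank 2 != output rank 4): fails
print(reshape_lower_to_higher((1,), (32, 8), (1, 32, 32, 8)))         # None
```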

To my understanding, the pattern rewriting pass repeatedly tries to match patterns and stops when it reaches convergence, so the broadcasting can be completed over multiple rounds of the pass. Take three input tensors a, b, and c where rank(a) < rank(b) < rank(c). There are six permutations of the set: (a, b, c), (a, c, b), (b, a, c), (b, c, a), (c, a, b), and (c, b, a). The output rank is rank(c), the maximum.

Let me take the case (a, b, c) where (pred, input1, input2) = (32x8, 32x32x8, 1x32x32x8) as an example:
a) First round of the pass
broadcast(pred, input1) fails, as the resulting rank doesn't match the output.
broadcast(pred, input2) succeeds, giving "tosa.reshape"(%arg0) {new_shape = [1, 1, 32, 8]} : (tensor<32x8xi1>) -> tensor<1x1x32x8xi1>

b) Second round of the pass
broadcast(pred', input1) succeeds, giving "tosa.reshape"(%arg1) {new_shape = [1, 32, 32, 8]} : (tensor<32x32x8xf32>) -> tensor<1x32x32x8xf32>

c) Further rounds change nothing, because the IR is already in its final, correct form.
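The two-round convergence described above can be mimicked with a small fixpoint loop over shapes. This is a sketch under assumptions (`pairwise`, `one_round`, and the driver loop are hypothetical models, not the actual MLIR greedy rewrite driver):

```python
def pairwise(a, b, out_rank):
    """Model of reshapeLowerToHigher on one pair of shapes: left-pads the
    lower-rank shape with 1s, but only if the higher-rank shape already
    has the output rank (otherwise the rewrite fails and nothing changes)."""
    lo, hi = (a, b) if len(a) <= len(b) else (b, a)
    if len(lo) == len(hi) or len(hi) != out_rank:
        return a, b, False
    padded = (1,) * (len(hi) - len(lo)) + tuple(lo)
    return (padded, b, True) if len(a) < len(b) else (a, padded, True)

def one_round(pred, in1, in2, out_rank):
    """One application of the Select pattern: broadcast pred against
    input1, then against input2."""
    pred, in1, c1 = pairwise(pred, in1, out_rank)
    pred, in2, c2 = pairwise(pred, in2, out_rank)
    return pred, in1, in2, c1 or c2

# (pred, input1, input2) = (32x8, 32x32x8, 1x32x32x8), output rank 4
shapes, changed, rounds = [(32, 8), (32, 32, 8), (1, 32, 32, 8)], True, 0
while changed:  # the greedy driver reapplies patterns until convergence
    rounds += 1
    *shapes, changed = one_round(*shapes, 4)

print(shapes)  # [(1, 1, 32, 8), (1, 32, 32, 8), (1, 32, 32, 8)]
print(rounds)  # 3: two rounds that change the IR, one that confirms convergence
```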

I added more mlir tests to show its coverage.


My thoughts above more or less answer your question.

As for this nested ternary, I'm not satisfied with it either. Can you suggest a better way to write it?

rsuderman added inline comments.Dec 12 2022, 2:11 PM

I see. It should be possible, by tweaking reshapeLowerToHigher to return the updated values, to guarantee the broadcasting happens all at once. See comment below.


Rather than these predicate checks, set outInput1 and outInput2 to the input values in reshapeLowerToHigher. That avoids all of the predicate selection, since the values are updated at each point.

While you are there, make it return a bool instead of an integer.

  • overload reshapeLowerToHigher to support 3 inputs so that the broadcasting happens all at once

tatwaichong marked an inline comment as done.Jan 6 2023, 2:14 PM
tatwaichong added inline comments.

Yes, it can. I created an overload of reshapeLowerToHigher that supports 3 inputs (the logic should extend to more operands). The tensors are sorted by rank to make broadcasting easier, since the tensor in the last position of the sorted list has the highest rank.
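The sort-by-rank idea can be sketched on shapes alone. This is a hypothetical Python model of the 3-input overload (the name `reshape_lower_to_higher_3` and its exact behavior are assumptions), not the actual C++ code from the diff:

```python
def reshape_lower_to_higher_3(shapes):
    """Model of a 3-input overload (hypothetical): after sorting by rank,
    the last entry has the highest rank, which becomes the target; every
    lower-rank shape is left-padded with 1s so all operands broadcast in
    one application of the pattern."""
    target = len(sorted(shapes, key=len)[-1])
    return [(1,) * (target - len(s)) + tuple(s) for s in shapes]

print(reshape_lower_to_higher_3([(32, 8), (32, 32, 8), (1, 32, 32, 8)]))
# [(1, 1, 32, 8), (1, 32, 32, 8), (1, 32, 32, 8)]
```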

rsuderman added inline comments.Jan 6 2023, 2:17 PM

We shouldn't need another function. The change I wanted was just to invoke the broadcast twice. Specifically:

rebroadcast(input1, input2, result1, result2)
rebroadcast(result1, input3, result1, result3)

Then check if inputN == resultN for each input.

The idea is that we can apply each broadcast and know the result is legal as they compound. Previously there were extra mutations and checks, with no interaction between the first broadcast and the second. Guaranteeing they chain in this way minimizes the number of cases.

tatwaichong added inline comments.Jan 6 2023, 2:47 PM

I see. I'm not sure I fully caught your meaning, so let me double-check my understanding.

Suppose we want to guarantee that the broadcasting for all inputs happens at once, in one round of the pass. Given that the highest-rank tensor is input3 and the others have lower rank, we still need an extra reshapeLowerToHigher call, is that right?

matchAndRewrite () {
  reshapeLowerToHigher(input1, input2, result1, result2)
  reshapeLowerToHigher(result1, input3, result1, result3)

  if (result2.rank != result3.rank)
    reshapeLowerToHigher(result2, result3, result2, result3)
}
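On shapes alone, that chained scheme can be modeled as below. This is a hedged Python sketch: `rebroadcast` is a hypothetical stand-in for the pairwise helper (the real one works on mlir::Value references), and the input shapes are the running <1xi1>/<32x8>/<1x32x32x8> example:

```python
def rebroadcast(a, b):
    """Pairwise rank alignment (hypothetical model): left-pad the
    lower-rank shape with 1s until both shapes have the same rank."""
    rank = max(len(a), len(b))
    pad = lambda s: (1,) * (rank - len(s)) + tuple(s)
    return pad(a), pad(b)

in1, in2, in3 = (1,), (32, 8), (1, 32, 32, 8)

# chain the two pairwise broadcasts so the results compound
r1, r2 = rebroadcast(in1, in2)   # (1, 1), (32, 8)
r1, r3 = rebroadcast(r1, in3)    # (1, 1, 1, 1), (1, 32, 32, 8)

# the extra call discussed above, needed when the second operand of the
# first pair never met the highest-rank tensor
if len(r2) != len(r3):
    r2, r3 = rebroadcast(r2, r3)

# any operand whose shape changed would need a tosa.reshape
needs_reshape = [old != new for old, new in ((in1, r1), (in2, r2), (in3, r3))]
print(r1, r2, r3)     # (1, 1, 1, 1) (1, 1, 32, 8) (1, 32, 32, 8)
print(needs_reshape)  # [True, True, False]
```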

  • roll back to the existing pair-wise broadcasting function; apply broadcasting to 3 different pairs of inputs, chaining them so that the compound broadcasting happens all at once
  • update mlir tests

rsuderman added inline comments.Feb 6 2023, 10:40 AM

You don't need separate input1 and outInput1. Just mutate the referenced value passed in. This avoids doing the swapping in the lowering below.


See comment above - if you just mutate input1 and input2 via references you can more easily track the three inputs and avoid the extra reassigning.

As suggested, mutate the referenced inputs passed in to avoid the extra reassigning and stateful variables in the lowering.

tatwaichong marked 2 inline comments as done.Feb 6 2023, 2:28 PM
rsuderman accepted this revision.Feb 7 2023, 12:02 PM
This revision is now accepted and ready to land.Feb 7 2023, 12:02 PM
This revision was automatically updated to reflect the committed changes.