This is an archive of the discontinued LLVM Phabricator instance.

[mlir][vector] Fold extractOp coming from broadcastOp
Closed · Public

Authored by ThomasRaoux on Oct 2 2020, 12:04 PM.

Details

Summary

Fold an ExtractOp with a scalar result whose operand comes from a BroadcastOp, replacing it with the BroadcastOp source. This is useful for incrementally converting degenerate single-element vectors into scalars.

Diff Detail

Event Timeline

ThomasRaoux created this revision. Oct 2 2020, 12:04 PM
Herald added a project: Restricted Project. Oct 2 2020, 12:04 PM
ThomasRaoux requested review of this revision. Oct 2 2020, 12:04 PM
aartbik requested changes to this revision. Oct 5 2020, 4:27 PM
aartbik added inline comments.
mlir/lib/Dialect/Vector/VectorOps.cpp:817

This transformation is not correct (as written).
For example

func @fold_extract_broadcast(%arg0: vector<4xf32>) -> f32 {
  %0 = vector.broadcast %arg0 : vector<4xf32> to vector<1x2x4xf32>
  %1 = vector.extract %0[0, 1, 2] : vector<1x2x4xf32>
  return %1 : f32
}

will break. You will need to do a bit more analysis of the types (but in that case, you can probably generalize beyond scalars).

This revision now requires changes to proceed. Oct 5 2020, 4:27 PM

Fix bug when the broadcast source type and the vector.extract result type mismatch

ThomasRaoux added inline comments. Oct 5 2020, 7:13 PM
mlir/lib/Dialect/Vector/VectorOps.cpp:817

Thanks for catching that. I failed to consider that the broadcast source can be a vector. I generalized the fold to vectors: as long as the broadcast source type is the same as the extract result type, the transformation is correct. I added a test for the vector case and a negative test as well.
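A minimal sketch of the corrected condition, under a hypothetical, simplified type model (a shape list plus an element-type name, with an empty shape standing for a scalar; `Type`, `extractResultType`, and `extractCancelsBroadcast` are illustrative names, not MLIR API). `vector.extract` drops one leading dimension per position index, and folding to the broadcast source is valid only when that result type equals the source type. In the counterexample from the review, extracting `[0, 1, 2]` from a `vector<4xf32>` broadcast to `vector<1x2x4xf32>` yields a scalar rather than `vector<4xf32>`, so it must not fold.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical simplified type model: an empty shape means a scalar of
// `elementType`; otherwise it is vector<shape x elementType>.
struct Type {
  std::vector<int64_t> shape;
  std::string elementType;
  bool operator==(const Type &other) const {
    return shape == other.shape && elementType == other.elementType;
  }
};

// Result type of vector.extract: one leading dimension is dropped per index.
Type extractResultType(const Type &source, std::size_t numIndices) {
  assert(numIndices <= source.shape.size() && "too many extract indices");
  return Type{std::vector<int64_t>(source.shape.begin() +
                                       static_cast<std::ptrdiff_t>(numIndices),
                                   source.shape.end()),
              source.elementType};
}

// extract(broadcast(x)) may be replaced by x only when the extract produces
// exactly the type of x, i.e. the two ops cancel each other.
bool extractCancelsBroadcast(const Type &broadcastSource,
                             const Type &broadcastResult,
                             std::size_t numIndices) {
  return extractResultType(broadcastResult, numIndices) == broadcastSource;
}
```

For the same broadcast, an extract with two indices (for example `[0, 1]`) yields `vector<4xf32>` and does cancel the broadcast. The scalar case from the summary, an `f32` broadcast to `vector<1xf32>` followed by an extract at `[0]`, is the special case where both sides are scalars.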

aartbik accepted this revision. Oct 5 2020, 7:57 PM
This revision is now accepted and ready to land. Oct 5 2020, 7:57 PM
nicolasvasilache requested changes to this revision. Oct 6 2020, 12:47 AM
nicolasvasilache added inline comments.
mlir/lib/Dialect/Vector/VectorOps.cpp:819

Why worry about the types here?
Shouldn't you just drop the first n-k dimensions from the extract (where the broadcast goes from rank k to rank n) and turn it into vector.extract %a[2] : vector<4xf32> to f32?
Depending on the number of extract indices compared to n-k, you have 3 cases.
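The three cases can be sketched as a small classifier, assuming the broadcast goes from rank k to rank n and the extract uses p position indices, so its result has rank n - p. `classify` and `ExtractFromBroadcastCase` are hypothetical names used only for illustration. Comparing p with n - k, the number of leading dimensions added by the broadcast, separates the cases.

```cpp
#include <cassert>
#include <cstddef>

// The three cases for extract(broadcast(x)), where the broadcast goes from
// rank k to rank n and the extract uses p position indices (result rank n - p).
enum class ExtractFromBroadcastCase {
  Cancel,            // p == n - k: result rank equals the source rank
  ExtractFromSource, // p >  n - k: result rank is smaller than k
  NeedsNewBroadcast, // p <  n - k: result rank is larger than k
};

ExtractFromBroadcastCase classify(std::size_t n, std::size_t k, std::size_t p) {
  assert(k <= n && "broadcast cannot decrease rank");
  assert(p <= n && "extract cannot use more indices than the vector rank");
  std::size_t leading = n - k; // dimensions added by the broadcast
  if (p == leading)
    return ExtractFromBroadcastCase::Cancel;
  if (p > leading)
    return ExtractFromBroadcastCase::ExtractFromSource;
  return ExtractFromBroadcastCase::NeedsNewBroadcast;
}
```

This classifier compares ranks only. Because vector.broadcast can also stretch a trailing source dimension of size 1 (for example `vector<1xf32>` to `vector<2x4xf32>`, where extracting `[0]` gives `vector<4xf32>`), equal ranks do not imply equal types, which is one reason to compare the types themselves, as the revision does.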

This revision now requires changes to proceed. Oct 6 2020, 12:47 AM
ThomasRaoux added inline comments. Oct 6 2020, 8:05 AM
mlir/lib/Dialect/Vector/VectorOps.cpp:819

I was trying to handle only the case where the extract and the broadcast cancel each other. I can also handle the case where the rank of the broadcast source is greater than the rank of the extract result. I don't think I can handle the case where the rank of the extract result is greater than the rank of the broadcast source, since I would need to create a new broadcast operation; my understanding is that the fold method shouldn't create new operations?
What do you think?

mlir/lib/Dialect/Vector/VectorOps.cpp:819

Right, the third case would have to be a canonicalization pattern followed by DCE (if there are no other uses).
It seems undesirable to have both a folding and a canonicalization for the overlap of the 3 cases.

I'd say let's make the folding support the 2 cases it can, with a TODO noting that if/when we want the third, we should move all of this to a canonicalization pattern?

Add the folding case where the extract result rank is smaller than the broadcast source rank.

ThomasRaoux added inline comments. Oct 6 2020, 8:42 AM
mlir/lib/Dialect/Vector/VectorOps.cpp:819

Sounds good. I added case 2, where the result rank is smaller than the broadcast source rank, and added a TODO for the case where the result rank is bigger than the broadcast source rank.
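For case 2, the rewritten extract applies directly to the broadcast source and keeps only the trailing indices that address it, dropping the leading n-k indices as suggested earlier in the review. Below is a minimal sketch of that position computation using a hypothetical shape-list model; `positionOnBroadcastSource` is an illustrative name, not the code that landed. One assumption goes beyond the thread: vector.broadcast can also stretch a trailing source dimension of size 1, so an index into such a dimension must become 0 on the source.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

using Shape = std::vector<int64_t>; // hypothetical: empty shape = scalar

// For extract(broadcast(x)) where the extract result rank is smaller than
// rank(x), compute the position of an equivalent vector.extract applied
// directly to x. Returns std::nullopt for the other two cases.
std::optional<std::vector<int64_t>>
positionOnBroadcastSource(const Shape &sourceShape, const Shape &resultShape,
                          const std::vector<int64_t> &extractPos) {
  assert(sourceShape.size() <= resultShape.size());
  assert(extractPos.size() <= resultShape.size());
  std::size_t k = sourceShape.size();
  std::size_t resultRank = resultShape.size() - extractPos.size();
  if (resultRank >= k)
    return std::nullopt; // cancel case, or a new broadcast would be needed
  // The trailing `keep` indices address source dimensions 0 .. keep-1.
  std::size_t keep = k - resultRank;
  std::vector<int64_t> newPos(
      extractPos.end() - static_cast<std::ptrdiff_t>(keep), extractPos.end());
  for (std::size_t i = 0; i < keep; ++i)
    if (sourceShape[i] == 1)
      newPos[i] = 0; // stretched dimension: every index reads element 0
  return newPos;
}
```

For the review's example (`vector<4xf32>` broadcast to `vector<1x2x4xf32>`, extract at `[0, 1, 2]`) this yields position `[2]`, matching `vector.extract %a[2]`. Returning `std::nullopt` for the remaining cases mirrors the split agreed on here: case 1 folds straight to the source, and case 3 would need a new broadcast op, which a fold hook must not create.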

This revision is now accepted and ready to land. Oct 6 2020, 8:46 AM
aartbik added inline comments. Oct 6 2020, 9:41 AM
mlir/lib/Dialect/Vector/VectorOps.cpp:842

typo: broadcast

This revision was automatically updated to reflect the committed changes.