Fold an ExtractOp with a scalar result into the source of a BroadcastOp. This makes it possible to incrementally convert degenerate one-element vectors into scalars.
Event Timeline
mlir/lib/Dialect/Vector/VectorOps.cpp | ||
---|---|---|
817 | This transformation is not correct (as written). For example, the following will break: |

```mlir
func @fold_extract_broadcast(%arg0: vector<4xf32>) -> f32 {
  %0 = vector.broadcast %arg0 : vector<4xf32> to vector<1x2x4xf32>
  %1 = vector.extract %0[0, 1, 2] : vector<1x2x4xf32>
  return %1 : f32
}
```

You will need to do a bit more analysis of the types (but in that case, you can probably generalize beyond scalars).
mlir/lib/Dialect/Vector/VectorOps.cpp | ||
---|---|---|
817 | Thanks for catching that. I failed to consider that the broadcast source can be a vector. I generalized the fold to vectors: as long as the type of the broadcast source is the same as the type of the extract result, the transformation is correct. I added a test for the vector case and a negative test as well. |
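
The type condition described in this reply can be sketched outside of MLIR as a plain shape comparison. This is a minimal model, not the code from the patch: `Shape`, `extractResultShape`, and `extractCancelsBroadcast` are hypothetical names, an empty shape stands for a scalar, and element types are ignored.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a vector type's shape; an empty shape models a scalar.
using Shape = std::vector<int64_t>;

// Shape of the value produced by an extract with `numPositions` indices:
// each index peels off one leading dimension of the extracted vector.
// Assumes numPositions <= vectorShape.size().
Shape extractResultShape(const Shape &vectorShape, std::size_t numPositions) {
  return Shape(vectorShape.begin() + static_cast<std::ptrdiff_t>(numPositions),
               vectorShape.end());
}

// The extract and the broadcast cancel out only when the extract result has
// exactly the shape of the broadcast source.
bool extractCancelsBroadcast(const Shape &broadcastSource,
                             const Shape &broadcastResult,
                             std::size_t numPositions) {
  if (numPositions > broadcastResult.size())
    return false;
  return extractResultShape(broadcastResult, numPositions) == broadcastSource;
}
```

With the example from the first comment, `extractCancelsBroadcast({4}, {1, 2, 4}, 3)` is false: the extract yields a scalar while the broadcast source is `vector<4xf32>`, so replacing the extract with the source would change the result type.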
mlir/lib/Dialect/Vector/VectorOps.cpp | ||
---|---|---|
819 | Why worry about the types here? |
mlir/lib/Dialect/Vector/VectorOps.cpp | ||
---|---|---|
819 | I was trying to handle only the case where the extract and the broadcast cancel each other. I can also handle the case where the rank of the broadcast source is greater than the rank of the extract result. I don't think I can handle the case where the rank of the extract result is greater than the rank of the broadcast source, since I would need to create a new broadcast operation. My understanding is that the fold method shouldn't create new operations? |
mlir/lib/Dialect/Vector/VectorOps.cpp | ||
---|---|---|
819 | Right, the third case would have to be a canonicalization pattern followed by DCE (if there are no other uses). I'd say let's make the folding support the two cases it can, with a TODO that if/when we want the third we should move all this to a canonicalization pattern? |
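
The three cases discussed in this exchange can be summarized as a comparison of two ranks. The sketch below is only an illustration of that case split, assuming rank 0 models a scalar; the enum and function names are hypothetical and do not come from `VectorOps.cpp`.

```cpp
#include <cstddef>

// Hypothetical labels for the three situations discussed in the review.
enum class ExtractOfBroadcast {
  FoldToSource,      // equal ranks: reuse the broadcast source (only if the types also match)
  FoldToExtract,     // result rank < source rank: extract directly from the broadcast source
  NeedsNewBroadcast, // result rank > source rank: would require creating a new broadcast
};

// Classify an extract-of-broadcast by comparing the rank of the broadcast
// source with the rank of the extract result. Rank 0 models a scalar.
ExtractOfBroadcast classify(std::size_t broadcastSourceRank,
                            std::size_t extractResultRank) {
  if (extractResultRank == broadcastSourceRank)
    return ExtractOfBroadcast::FoldToSource;
  if (extractResultRank < broadcastSourceRank)
    return ExtractOfBroadcast::FoldToExtract;
  return ExtractOfBroadcast::NeedsNewBroadcast;
}
```

Only the first two outcomes reuse existing values or update the extract in place, which is why they fit in a fold. The third would have to build a new operation, so it belongs in a canonicalization pattern as noted above.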
Add the folding case where the extract result rank is smaller than the broadcast source rank.
mlir/lib/Dialect/Vector/VectorOps.cpp | ||
---|---|---|
819 | Sounds good. I added case 2, where the result rank is smaller than the broadcast source rank, and added a TODO for the case where the result rank is bigger than the broadcast source rank. |
mlir/lib/Dialect/Vector/VectorOps.cpp | ||
---|---|---|
842 | typo: broadcast |