This is an archive of the discontinued LLVM Phabricator instance.

[mlir][Vector] Add pattern to reorder elementwise and broadcast ops
ClosedPublic

Authored by awarzynski on Jun 13 2023, 6:23 AM.

Details

Summary

The new pattern rewrites elementwise(broadcast) as broadcast(elementwise)
when that is safe, for example when the op is ElementwiseMappable and all of its operands are broadcasts.
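The rewrite described in the summary can be illustrated with a small, self-contained C++ model. This is a sketch under stated assumptions, not the MLIR implementation: `Op`, `makeOp` and `sinkBroadcast` are hypothetical names, the `elementwise` flag stands in for the ElementwiseMappable trait, all broadcast sources are assumed to share one type, and the narrow op's result type is assumed to equal that source type (true for ops such as arith.addf, not for comparisons such as arith.cmpf).

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy IR (hypothetical, NOT the MLIR API): each value is the result of an Op.
struct Op;
using OpPtr = std::shared_ptr<Op>;

struct Op {
  std::string name;            // e.g. "arith.addf", "vector.broadcast", "arg"
  std::string type;            // result type, e.g. "f32" or "vector<4xf32>"
  std::vector<OpPtr> operands; // ops producing this op's operands
  bool elementwise;            // stand-in for the ElementwiseMappable trait
};

OpPtr makeOp(std::string name, std::string type, std::vector<OpPtr> operands,
             bool elementwise = false) {
  return std::make_shared<Op>(Op{std::move(name), std::move(type),
                                 std::move(operands), elementwise});
}

// Rewrites elementwise(broadcast(a), broadcast(b), ...) into
// broadcast(elementwise(a, b, ...)). Returns nullptr when not applicable.
OpPtr sinkBroadcast(const OpPtr &op) {
  if (!op->elementwise || op->operands.empty())
    return nullptr;
  // Single pass: check the preconditions and collect the broadcast sources.
  std::vector<OpPtr> sources;
  for (const OpPtr &operand : op->operands) {
    if (operand->name != "vector.broadcast")
      return nullptr; // every operand must come from a broadcast
    // Assumes a well-formed broadcast with exactly one operand (its source).
    const OpPtr &src = operand->operands.front();
    if (!sources.empty() && src->type != sources.front()->type)
      return nullptr; // all broadcast sources must share one type
    sources.push_back(src);
  }
  // Apply the elementwise op to the narrow sources (result type assumed to be
  // the source type), then broadcast to the original result type.
  OpPtr narrow = makeOp(op->name, sources.front()->type, sources, true);
  return makeOp("vector.broadcast", op->type, {narrow});
}
```

For `arith.addf(vector.broadcast(%a), vector.broadcast(%b))` with `%a, %b : f32`, `sinkBroadcast` returns a `vector.broadcast` of the scalar `arith.addf(%a, %b)`. If any operand is not a broadcast, it returns nullptr and the original op is left untouched.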

Event Timeline

awarzynski created this revision.Jun 13 2023, 6:23 AM
Herald added a project: Restricted Project.
awarzynski requested review of this revision.Jun 13 2023, 6:23 AM
kuhar added a reviewer: kuhar.Jun 13 2023, 7:01 AM
kuhar added a subscriber: kuhar.
kuhar added inline comments.
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
906

Member casts are deprecated; use the free cast functions instead: https://mlir.llvm.org/deprecation/

927–933

This is the third loop that inspects operands and casts them to broadcasts. Why not do the checks and collect all the broadcast elements in just one loop?

935

nit: Type vectorTy = *op->result_type_begin(); or op->getResultTypes()[0];?

1366–1367

Why did you decide to add the new pattern here? Is it because the other patterns tend to create code that needs this cleanup around elementwise operations on broadcasts?

Thanks for taking a look; I will send an update shortly.

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
906

Thanks!

927–933

Good point, though I would still separate "checking" from "processing".

935

Thanks, that's neat!

Btw, for consistency with the rest of this file, I will stick with vectorType (instead of using vectorTy).

1366–1367

Why did you decide to add the new pattern here?

ReorderElementwiseOpsOnTranspose is the closest to what I am trying to do, but perhaps a new category of patterns would be better 🤔. Any suggestions?

Is it because the other patterns tend to create code that needs this cleanup around elementwise operations on broadcasts?

Not really. This is fairly generic.

Address comments from Jakub

Thanks a lot for working on this! I think this will have an impact beyond the nD tensor extract! A few comments.

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
907

We should also check some properties of the op. There is a trait, ElementwiseMappable (or similar), that we can use here.
In this regard, I would add a test with a vector.contract whose inputs are all broadcast from scalars. We do not want this transformation to trigger on that one.
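The precondition requested here can be modelled as a standalone check that keeps "checking" separate from "processing". This is again a hypothetical sketch, not MLIR code: `OperandInfo`, `isElementwiseMappable` and `sinkBroadcastPrecondition` are illustrative names, and the hard-coded set of op names only stands in for querying the ElementwiseMappable trait. It rejects a vector.contract even when every input is a broadcast.

```cpp
#include <cassert>
#include <optional>
#include <set>
#include <string>
#include <vector>

// Hypothetical, simplified operand description (not the MLIR API): whether
// the operand is defined by vector.broadcast and, if so, the source type.
struct OperandInfo {
  bool fromBroadcast;
  std::string sourceType;
};

// Stand-in for the ElementwiseMappable trait check; a real pattern would
// query the op's traits instead of a hard-coded list of names.
bool isElementwiseMappable(const std::string &opName) {
  static const std::set<std::string> kElementwise = {"arith.addf", "arith.mulf",
                                                     "arith.subf"};
  return kElementwise.count(opName) != 0;
}

// "Checking" step only: returns the common broadcast source type when the
// rewrite is safe, std::nullopt otherwise.
std::optional<std::string>
sinkBroadcastPrecondition(const std::string &opName,
                          const std::vector<OperandInfo> &operands) {
  if (!isElementwiseMappable(opName) || operands.empty())
    return std::nullopt; // e.g. vector.contract must be left alone
  for (const OperandInfo &operand : operands) {
    if (!operand.fromBroadcast)
      return std::nullopt; // *all* operands must be broadcasts
    if (operand.sourceType != operands.front().sourceType)
      return std::nullopt; // broadcast sources must agree on their type
  }
  return operands.front().sourceType;
}
```

Each early return corresponds to one of the negative tests discussed in this review: a non-elementwise op, an operand that is not a broadcast, and broadcasts from mismatched source types.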

mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
403 ↗(On Diff #530907)
  • Move this to a different file? Maybe a dedicated one?
  • Add a test with a 1-D to 2-D vector broadcast?
  • Add a test with a vector.contract of broadcast scalar operands (it shouldn't trigger)?
  • Any other negative tests?
kuhar added inline comments.Jun 13 2023, 12:02 PM
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
912–915

nit: Create a local variable for the broadcast so that we don't have to cast twice.

1366–1367

I'm not very familiar with the patterns in this file, but this strikes me as a very specific choice, while this pattern seems like a more general cleanup that someone may want to run without all the other reduction-to-contract patterns. If this is intentional, could you explain it in the commit description, and/or mention that we may want a better 'home' populate function for this pattern?

dcaballe added inline comments.Jun 13 2023, 12:18 PM
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
1366–1367

Sorry, I had missed this. Yes, I think it makes sense to create a separate populateVectorBroadcast... to get finer-grained control over this pattern. I can't find any other populate function where this would make sense.

Addressing comments from Diego and Jakub:

  • added more tests and moved to a dedicated file
  • created a dedicated populateVector(...)Patterns
  • added support for broadcasts of shaped types (also added a test to demonstrate)
kuhar added inline comments.Jun 14 2023, 7:38 AM
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
922

You don't need the isa now; just return bcast && bcast.get..;

mlir/test/Dialect/Vector/vector-remove-broadcast.mlir
29–30 ↗(On Diff #531318)

Can you also add some test(s) that demonstrate that your precondition checks work? For instance, the one that checks that *all* arguments should be broadcasts.

awarzynski added inline comments.Jun 14 2023, 8:15 AM
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
922

Argh, I missed that. Thank you :)

mlir/test/Dialect/Vector/vector-remove-broadcast.mlir
29–30 ↗(On Diff #531318)

one that checks that *all* arguments should be broadcasts

Thanks for the suggestion! I'm actually struggling to find other good examples. I could add "vector.broadcast + vector.extractelement" (it's not ElementwiseMappable), but it doesn't feel right. And everything in Arith is ElementwiseMappable (most of the examples that I can think of are a mix of Arith + Vector). Happy to add more if you have other ideas.

Add a test, remove redundant llvm::isa

kuhar accepted this revision.Jun 14 2023, 8:32 AM

LGTM, thanks for the changes. You may want to wait for a thumbs up from @dcaballe before submitting.

This revision is now accepted and ready to land.Jun 14 2023, 8:32 AM
dcaballe accepted this revision.Jun 14 2023, 3:38 PM

Thanks for addressing the feedback. LGTM! Just some minor comments that you can address before submitting.

mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
321–322

Instead of removing the DAG here, could we reduce this test to match only what is important? That would help maintainability.

mlir/test/Dialect/Vector/vector-remove-broadcast.mlir
1 ↗(On Diff #531345)

personal opinion, feel free to ignore: maybe "sink-vector-broadcast" is more accurate?

62–64 ↗(On Diff #531345)

good to add a -DAG here

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
377

personal opinion, feel free to ignore: ->SinkVectorBroadcast?

Thank you both for reviewing and for helping me with this!

@dcaballe, I will address your comments in the final version that I am about to merge.

mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
321–322

Agreed. I've tried to trim it a couple of times, but always got stuck. I will just remove pretty much everything apart from the scalar and the contiguous loads, which are the key bits in this test.

mlir/test/Dialect/Vector/vector-remove-broadcast.mlir
1 ↗(On Diff #531345)

Naming is hard, but it gets better with experience. That's something that you have tons of, so I will happily incorporate your suggestion :)