
[mlir][shape] Generalize broadcast to a variadic number of shapes
Closed, Public

Authored by tpopp on Feb 1 2021, 2:45 AM.

Details

Summary

Previously, broadcast was a binary op; it now accepts a variadic number of input shapes.
The change is made in such a way that, for now, it is an NFC (no functional change) for
all broadcast operations that were previously legal.
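For reference, here is a minimal Python model of what broadcasting a variadic number of shapes computes, assuming NumPy-style rules: shapes are right-aligned, and in each position the extents must either agree or be 1. `broadcast_shapes` is a hypothetical name and this is not the MLIR implementation. With exactly two arguments it computes the same result as a binary broadcast, which illustrates why generalizing to n inputs need not change results for the two-input case.

```python
from typing import List, Sequence


def broadcast_shapes(*shapes: Sequence[int]) -> List[int]:
    """Hypothetical reference model of n-ary shape broadcasting.

    Shapes are right-aligned against the result; in each position the
    extents must be equal or 1, and the result takes the non-1 extent.
    """
    max_rank = max((len(s) for s in shapes), default=0)
    result = [1] * max_rank
    for shape in shapes:
        offset = max_rank - len(shape)  # right-align this input
        for i, extent in enumerate(shape):
            current = result[offset + i]
            if current == 1:
                result[offset + i] = extent
            elif extent != 1 and extent != current:
                raise ValueError(f"shapes are not broadcastable: {shapes}")
    return result
```

For example, `broadcast_shapes([5, 1, 4], [3, 1], [1])` yields `[5, 3, 4]`.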


Event Timeline

tpopp created this revision. Feb 1 2021, 2:45 AM
tpopp requested review of this revision. Feb 1 2021, 2:45 AM
tpopp updated this revision to Diff 320440. Feb 1 2021, 4:51 AM

Follow naming convention in new functions.

jpienaar added inline comments. Feb 3 2021, 7:35 AM
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
84–89

Do we have places that use this form?

mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
80

The last sentence is true for all rewrites, so redundant here.

90

A comment here would be good to explain why dim is used to get rank.

Also, does dim work on a shape? (It does on a tensor; are we guaranteed to be in the tensor world here?)
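As background for the `dim` question: in this lowering a shape is assumed to be represented as an extent tensor (`tensor<?xindex>`), a 1-D tensor holding one extent per dimension, so the size of its dimension 0 is the rank of the shape. The Python sketch below models extent tensors as plain lists; `rank_of` and `rank_diffs` are hypothetical names, not functions from the patch.

```python
from typing import List, Sequence


def rank_of(extent_tensor: Sequence[int]) -> int:
    # An extent tensor is 1-D, so the size of its only dimension
    # (what `dim` at index 0 would return) is the rank of the shape.
    return len(extent_tensor)


def rank_diffs(extent_tensors: Sequence[Sequence[int]]) -> List[int]:
    # The result rank is the maximum input rank; each input is
    # right-aligned, so its rank difference is how far it is shifted.
    # Assumes at least one input.
    ranks = [rank_of(t) for t in extent_tensors]
    max_rank = max(ranks)
    return [max_rank - r for r in ranks]
```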

188

Could more be reused here? E.g., is the n-ary lowering (excluding computing the max rank) less efficient for the binary case? Or, vice versa, are multiple applications of the binary lowering less efficient than the n-ary lowering?

herhut added inline comments. Feb 3 2021, 11:09 AM
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
115

inBound is confusing. It is true if we are outside the bounds of the index, right?
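A small sketch of the predicate under discussion, assuming each input is right-aligned against the result and that `rank_diff` is the result rank minus the input rank. The hypothetical name `out_of_bounds` states what the flag actually means: it is true when the input has no extent at that output position.

```python
def out_of_bounds(output_dim: int, rank_diff: int) -> bool:
    # Inputs are right-aligned, so an input that is `rank_diff`
    # dimensions shorter than the result has no extent at the first
    # `rank_diff` output positions.
    return output_dim < rank_diff


def input_dim(output_dim: int, rank_diff: int) -> int:
    # Position inside the shorter input; only valid when the output
    # dimension is not out of bounds for that input.
    assert not out_of_bounds(output_dim, rank_diff)
    return output_dim - rank_diff
```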

130

This sentence ends abruptly.

142

You could fold this up. In the then case, return reduceDim. In the else case, do the select.
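A Python sketch of the suggested folding, assuming the reduction over inputs starts from extent 1 and that the inputs are already known to be broadcastable (no error checking). The names are hypothetical; in the IR, the `continue` corresponds to a then-branch that yields the running value, and the conditional expression corresponds to the select in the else-branch.

```python
from typing import Sequence


def broadcast_dim(output_dim: int,
                  extent_tensors: Sequence[Sequence[int]],
                  rank_diffs: Sequence[int]) -> int:
    reduced = 1  # starting value of the reduction over all inputs
    for extents, rank_diff in zip(extent_tensors, rank_diffs):
        if output_dim < rank_diff:
            # "then" case: this input has no extent here, keep the running value.
            continue
        # "else" case: select between the running value and this input's extent.
        extent = extents[output_dim - rank_diff]
        reduced = reduced if extent == 1 else extent
    return reduced
```

Evaluating it for every `output_dim` in `range(max_rank)` produces the full result shape.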

188

Thinking of it, maybe doing multiple 2d broadcasts in a row (in the implementation, not the op) would yield similar performance.

mlir/lib/Dialect/Shape/IR/Shape.cpp
355

Is this intended?

tpopp updated this revision to Diff 321428. Feb 4 2021, 7:48 AM
tpopp marked 6 inline comments as done.

Fix variable names and address other comments.

tpopp added a comment. Feb 4 2021, 7:48 AM

I'm going to wait on fixing the test until I hear your thoughts on the two options for lowering broadcast.

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
84–89

This was the original form of the build method before the number of inputs was made variadic. I thought it would be nice to keep it for the common case and to make the transition essentially an NFC.

mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
90

This was guaranteed in the caller. I'll replicate the check here though in case of future uses from other locations.

188

I think performance differences between the two are negligible. We already fully unroll the tensor::GenerateOp, so they will be roughly the same.

For the binary case, the n-ary lowering needs a starting value for its reductions (unless we make the C++ code much more complex), while the binary lowering can skip that step. On the other hand, with more than 2 inputs, the binary lowering might recompute a small amount of work between invocations. Technically, the n-ary case has the potential to be just as performant; I'm just not sure how clean I can make it look.

I personally find the n-ary lowering easier to read, and I think Stephan has said in the past that he expected the binary case to be easier to read. I think we should choose the implementation that we find easier to read and stick with it. I would like to hear whether you agree, and which one you find easier to read.
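To illustrate the trade-off, here is a Python sketch of both strategies for inputs assumed to be broadcastable. Folding a binary broadcast over the inputs needs no seed but materializes an intermediate shape at each step, while the n-ary version makes one pass per output dimension with the reduction seeded at 1. All names are hypothetical, and the sketch models only the resulting shape, not the generated IR.

```python
from functools import reduce
from typing import List, Sequence


def broadcast2(a: Sequence[int], b: Sequence[int]) -> List[int]:
    # Binary broadcast of two shapes assumed to be broadcastable.
    rank = max(len(a), len(b))
    pa = [1] * (rank - len(a)) + list(a)  # pad on the left with 1s
    pb = [1] * (rank - len(b)) + list(b)
    return [y if x == 1 else x for x, y in zip(pa, pb)]


def broadcast_pairwise(shapes: Sequence[Sequence[int]]) -> List[int]:
    # Chain of binary broadcasts: no seed needed, but each step
    # builds an intermediate shape.
    return list(reduce(broadcast2, shapes))


def broadcast_nary(shapes: Sequence[Sequence[int]]) -> List[int]:
    # One reduction per output dimension, seeded with extent 1.
    max_rank = max(len(s) for s in shapes)
    result = []
    for i in range(max_rank):
        dim = 1  # starting value
        for s in shapes:
            diff = max_rank - len(s)
            if i >= diff and s[i - diff] != 1:
                dim = s[i - diff]
        result.append(dim)
    return result
```

For valid inputs the two functions return the same shape; they differ only in how the work is organized.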

herhut added inline comments. Feb 7 2021, 11:41 PM
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
188

I found the binary case easier because my n-ary case looked like a mess. Looking at your code, this is much nicer.

So let's ship the n-ary case only. The performance difference should be negligible.

Could you also extend cstr_broadcastable accordingly? The two need to be in sync; otherwise, broadcast cannot really be used.
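A sketch of what an n-ary broadcastability check would have to accept to stay in sync with n-ary broadcast, assuming the usual rule that, at each right-aligned position, all extents other than 1 must agree. This models the semantics only; `is_broadcastable` is a hypothetical name, and the follow-up patch is not shown here.

```python
from typing import Sequence


def is_broadcastable(*shapes: Sequence[int]) -> bool:
    max_rank = max((len(s) for s in shapes), default=0)
    for i in range(max_rank):
        # Collect the non-1 extents that land on output position i.
        extents = set()
        for s in shapes:
            diff = max_rank - len(s)
            if i >= diff and s[i - diff] != 1:
                extents.add(s[i - diff])
        # More than one distinct non-1 extent cannot be broadcast.
        if len(extents) > 1:
            return False
    return True
```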

tpopp updated this revision to Diff 322092. Feb 8 2021, 5:56 AM
tpopp marked 4 inline comments as done.

Removed binary broadcast case and updated tests accordingly.

tpopp added inline comments. Feb 9 2021, 3:35 AM
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
188

Binary is removed. I'll extend cstr_broadcastable in a follow up CL.

herhut accepted this revision. Feb 9 2021, 3:45 AM

Please address minor nits before landing.

mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
151

This comment seems lonely.

152–167

Does it still make sense to have this in an extra function?

tpopp updated this revision to Diff 322376. Feb 9 2021, 6:30 AM
tpopp marked an inline comment as done.

Replace naryBroadcast helper with only a single dim calculation.

tpopp added inline comments. Feb 9 2021, 7:03 AM
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
151

Fixed

152–167

I refactored so that only part of this function remains separate, as it is convenient for use in the broadcastable follow-up.

This revision was not accepted when it landed; it landed in state Needs Review. Feb 9 2021, 11:31 PM
This revision was automatically updated to reflect the committed changes.