
`shape` dialect: add some ops

Authored by silvas on Mar 25 2020, 6:23 PM.



shape dialect: add some ops

  • add to_extent_tensor
    • rename create_shape to from_extent_tensor for symmetry
  • add split_at and concat ops for basic shape manipulations
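The following Python sketch illustrates the intended semantics of the conversion and concatenation ops. It models a ranked, fully known `!shape.shape` as a tuple of non-negative extents and an extent tensor as a plain list of ints. The function names mirror the ops in the summary, but they are hypothetical Python stand-ins, not MLIR APIs. Error and unranked cases are deliberately ignored here.

```python
from typing import List, Tuple

# Hypothetical model: a ranked, fully known !shape.shape is a tuple of extents.
Shape = Tuple[int, ...]


def from_extent_tensor(extents: List[int]) -> Shape:
    """Model of shape.from_extent_tensor: 1-D list of extents -> shape."""
    if any(e < 0 for e in extents):
        raise ValueError(f"extents must be non-negative: {extents}")
    return tuple(extents)


def to_extent_tensor(shape: Shape) -> List[int]:
    """Model of shape.to_extent_tensor: shape -> 1-D list of extents."""
    return list(shape)


def concat(lhs: Shape, rhs: Shape) -> Shape:
    """Model of shape.concat: the dims of lhs followed by the dims of rhs."""
    return lhs + rhs


s = from_extent_tensor([8, 3, 4])
print(to_extent_tensor(s))                  # [8, 3, 4]
print(concat(s, from_extent_tensor([5])))   # (8, 3, 4, 5)
```

On valid shapes the two conversion ops are inverses of each other, which is the symmetry the create_shape rename aims for.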

This set of ops is inspired by the requirements of lowering a dynamic-shape-aware batch matmul op. For such an op, the "matrix" dimensions aren't subject to broadcasting but the others are, so we need to slice, broadcast, and reconstruct the final output shape. Furthermore, the broadcasting op used downstream takes a tensor of extents as its preferred shape interface.
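The following minimal Python sketch shows the slice/broadcast/reconstruct steps for the output shape. It assumes shapes are fully known tuples of extents and that the batch dimensions follow NumPy-style broadcasting. `split_at`, `broadcast`, and `batch_matmul_result_shape` are hypothetical helpers for illustration. The negative-index handling in `split_at` (counting from the back) is an assumption about the op's semantics.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def split_at(shape: Shape, index: int) -> Tuple[Shape, Shape]:
    """Model of shape.split_at; a negative index counts from the back."""
    rank = len(shape)
    if not -rank <= index <= rank:
        raise ValueError(f"index {index} out of range for rank {rank}")
    if index < 0:
        index += rank
    return shape[:index], shape[index:]


def broadcast(lhs: Shape, rhs: Shape) -> Shape:
    """NumPy-style broadcasting of two shapes (right-aligned; 1 stretches)."""
    result = []
    for i in range(1, max(len(lhs), len(rhs)) + 1):
        a = lhs[-i] if i <= len(lhs) else 1
        b = rhs[-i] if i <= len(rhs) else 1
        if a != b and a != 1 and b != 1:
            raise ValueError(f"incompatible extents {a} and {b}")
        result.append(b if a == 1 else a)
    return tuple(reversed(result))


def batch_matmul_result_shape(lhs: Shape, rhs: Shape) -> Shape:
    # Slice: peel off the trailing two "matrix" dims, which don't broadcast.
    lhs_batch, (m, k1) = split_at(lhs, -2)
    rhs_batch, (k2, n) = split_at(rhs, -2)
    if k1 != k2:
        raise ValueError(f"contracting extents differ: {k1} vs {k2}")
    # Broadcast only the batch dims, then reconstruct by appending [m, n].
    return broadcast(lhs_batch, rhs_batch) + (m, n)


print(batch_matmul_result_shape((2, 1, 3, 4), (5, 4, 6)))  # (2, 5, 3, 6)
```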

However, this functionality is quite general. to_extent_tensor will clearly be needed long-term to support many common patterns that involve computations on shapes. We can evolve the shape manipulation ops introduced here. The specific choices made here take into account the potentially unranked nature of the !shape.shape type, which means that simply listing the dimensions to extract isn't possible in general.
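The unranked case helps explain why a split-style op fits better than a list of dimensions to extract. The following Python sketch assumes that split_at treats a negative index as counting from the back of the shape (an assumption, not confirmed here). Under that assumption, one index such as -2 peels off the trailing "matrix" dims no matter how many leading dims the shape turns out to have at runtime.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def split_at(shape: Shape, index: int) -> Tuple[Shape, Shape]:
    """Split `shape` into (shape[:index], shape[index:]).

    A negative index counts from the back, so the same index works regardless
    of the number of leading dims (e.g. when the rank is unknown statically).
    """
    rank = len(shape)
    if not -rank <= index <= rank:
        raise ValueError(f"index {index} out of range for rank {rank}")
    if index < 0:
        index += rank
    return shape[:index], shape[index:]


# The same split_at(-2) works for shapes of any rank >= 2.
for s in [(3, 4), (7, 3, 4), (2, 5, 7, 3, 4)]:
    print(split_at(s, -2))
# ((), (3, 4))
# ((7,), (3, 4))
# ((2, 5, 7), (3, 4))
```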


Event Timeline

silvas created this revision. Mar 25 2020, 6:23 PM
Herald added a project: Restricted Project. Mar 25 2020, 6:23 PM
jpienaar added inline comments. Mar 26 2020, 5:27 AM

Instead of drop back and take back, how about an op that mirrors concat (well, mirrors the binary concat), such as split_at, which produces two shapes?
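To see how split_at mirrors binary concat, the following Python sketch models shapes as tuples. It also shows that "drop back" and "take back" behavior falls out of a single split_at. `drop_back` and `take_back` here are hypothetical helpers, and this concrete model assumes the rank is known.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def split_at(shape: Shape, index: int) -> Tuple[Shape, Shape]:
    """Split into (shape[:index], shape[index:]); negative index counts from the back."""
    rank = len(shape)
    if not -rank <= index <= rank:
        raise ValueError(f"index {index} out of range for rank {rank}")
    if index < 0:
        index += rank
    return shape[:index], shape[index:]


def concat(lhs: Shape, rhs: Shape) -> Shape:
    return lhs + rhs


def _split_from_back(shape: Shape, n: int) -> Tuple[Shape, Shape]:
    # Guard n explicitly so an oversized n cannot wrap into a negative index.
    if not 0 <= n <= len(shape):
        raise ValueError(f"cannot take {n} dims from rank {len(shape)}")
    return split_at(shape, len(shape) - n)


def drop_back(shape: Shape, n: int) -> Shape:
    """Hypothetical drop_back: all but the last n dims."""
    return _split_from_back(shape, n)[0]


def take_back(shape: Shape, n: int) -> Shape:
    """Hypothetical take_back: only the last n dims."""
    return _split_from_back(shape, n)[1]


s = (2, 5, 3, 4)
assert concat(*split_at(s, 1)) == s  # concat undoes split_at
print(drop_back(s, 2), take_back(s, 2))  # (2, 5) (3, 4)
```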

silvas updated this revision to Diff 252985. Mar 26 2020, 2:55 PM

Address comments.

silvas updated this revision to Diff 252993. Mar 26 2020, 3:17 PM

Update description

silvas edited the summary of this revision. Mar 26 2020, 3:17 PM
silvas edited the summary of this revision.
Harbormaster completed remote builds in B50628: Diff 252993.

Nice! I think one more round and it will be good.


What happens for unranked and or dynamic tensors?


With your other revision, can these now be IndexTensor? (This can also be done in a follow-up.)


Add behavior for unranked here and below

silvas marked 3 inline comments as done. Mar 26 2020, 5:43 PM
silvas added inline comments.

It returns the shape. Notice that the return type is not statically shaped. For an unranked tensor it will return a tensor<?xi32> (or index someday).


Yeah, I want to do it in a follow-up. Just being consistent for now.


Whether something is unranked or not is a purely static concept, so it's totally orthogonal to an execution example like this.

jpienaar added inline comments. Mar 26 2020, 5:57 PM

Thus far, FromExtentTensorOp takes as input a tensor with known values and produces a fully known shape (e.g., there are no special numeric values that correspond to "unknown"). Another problem is that ToExtentTensor now needs to convert from a shape, which may be an error, to a tensor of ints. The intention is for all shape ops to be side-effect free, and I don't see how that can be maintained for this op.

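What I had thought of here was something like:

```mlir
shape.if_static(%s : !shape.shape) {
  ^bb0(%t: tensor<?xi32>):
    // ... now use %t as an extent tensor
} else {
  // ... handle the case where %s is not a valid static shape
}
```

so that you only convert when it is safe.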
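The following Python sketch models the proposed if_static construct. The idea is that the conversion to an extent tensor only happens on the branch where the shape is known to be valid. `if_static` and `ShapeError` are hypothetical names, and an error-carrying `!shape.shape` is modeled as a separate value type, which is an assumption made for illustration.

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Tuple, TypeVar, Union

T = TypeVar("T")


@dataclass(frozen=True)
class ShapeError:
    """Hypothetical stand-in for a !shape.shape that holds an error."""
    message: str


ShapeValue = Union[Tuple[int, ...], ShapeError]


def if_static(shape: ShapeValue,
              then_fn: Callable[[List[int]], T],
              else_fn: Callable[[], T]) -> T:
    """Model of the proposed shape.if_static region op.

    The shape is converted to an extent tensor (here: a list of ints) only on
    the branch where it is valid, so the conversion never sees an error.
    """
    if isinstance(shape, ShapeError):
        return else_fn()
    return then_fn(list(shape))


# Number of elements when the shape is valid; None otherwise.
print(if_static((2, 3), math.prod, lambda: None))                  # 6
print(if_static(ShapeError("mismatch"), math.prod, lambda: None))  # None
```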

silvas marked an inline comment as done. Mar 26 2020, 6:52 PM
silvas added inline comments.

There is no such thing as a tensor with unknown or dynamic or even partially specified shape at runtime.

One can consider the abstract !shape.shape having a runtime backing representation which represents a dataflow lattice value (such as tracking extent upper bounds to allow buffer allocation before a data-dependent shape op (like "unique") completes), but that would require specialized ops that are aware of the underlying backing representation for the abstract !shape.shape in order to construct such a value, since the choice of lattice is arbitrary.

There are only a small number of ops on !shape.shape that reify the concept of the shape being a lattice value (since, by definition, there just isn't much you can abstractly do with a lattice). shape.join is one of them (and the complementary "most general shape" is the only other one I can think of). All other ops should be defined as if they are runtime computations on concrete shapes. Any particular lattice that a user intends to use will need to substitute in appropriate transfer functions, in one of three ways: statically via dataflow analysis, reified in the IR during lowering (in their own dialect), or by literally giving the runtime manifestation of !shape.shape virtual methods they can override to substitute the transfer functions for the primitives.
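The following Python sketch shows what a lattice-aware op like shape.join might compute. It assumes one possible lattice: `None` means unranked, and a `None` entry inside a tuple means an unknown extent. The join returns the most refined shape consistent with both inputs and errors on contradictions. The representation and the `join` helper are hypothetical, chosen only for illustration.

```python
from typing import Optional, Tuple

# Hypothetical lattice: None = unranked; otherwise a tuple whose entries are
# None (unknown extent) or a known int.
PartialShape = Optional[Tuple[Optional[int], ...]]


class ShapeError(Exception):
    pass


def join(a: PartialShape, b: PartialShape) -> PartialShape:
    """Most refined shape consistent with both a and b."""
    if a is None:
        return b
    if b is None:
        return a
    if len(a) != len(b):
        raise ShapeError(f"rank mismatch: {a} vs {b}")
    result = []
    for x, y in zip(a, b):
        if x is None:
            result.append(y)
        elif y is None or x == y:
            result.append(x)
        else:
            raise ShapeError(f"extent mismatch: {x} vs {y}")
    return tuple(result)


print(join((None, 4), (3, None)))  # (3, 4)
print(join(None, (2, None)))       # (2, None)
```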

As far as side effects in the error case go, you can't dodge them, at least as you've defined them here and at least in my use case. The result of shape.concat/shape.split_at/etc. is used to broadcast a tensor to a particular concrete runtime shape. If the !shape.shape is an error, then you've just passed the error on to the "broadcast this tensor to this shape" op. More generally, you could pass the resulting shape to "{numpy,tf,torch}.zeros(shape)", which, to your point, should be pure.

I feel strongly that we need a dialect that can be used to model runtime shape computations. If you want this dialect to be used purely for abstract dataflow then that's okay and I can create a different dialect.

I think this error handling situation needs more thought. Neither the tf dialect nor the xla_hlo dialect (nor any other dialect I'm aware of that would use this) has any actual IR manifestation for what happens in error cases. E.g., most "matmul" ops that I'm aware of claim to be side-effect free. Hmm....

jpienaar added inline comments. Mar 27 2020, 2:18 PM

I'm not sure I follow. Shape here is effectively an error monad: the ops operating on it propagate the error (as mentioned, this could end up merging multiple errors together), but the ops are side-effect free and can be folded. Otherwise, many of these ops are effectively asserts that can fail (e.g., they add constraints on the number of elements, the maximum extent of dimensions, ...). Once lowered to a runtime executable form, those errors would need to be considered. E.g., a tf.assert or some such needs to be inserted before bridging the gap to tf.zeros or whatnot that operates on a tensor of ints. Those platforms have ops that can handle this, so the lowering handles it there: they have actual runtime ops with cancellation or abort mechanisms.
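The following Python sketch illustrates the "error monad" view. Ops never raise; an invalid input yields an error value, errors flow through downstream ops, and two incoming errors are merged. Everything here (`ShapeError`, the negative-extent check, the merge policy) is a hypothetical model for illustration, not the dialect's defined behavior.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union


@dataclass(frozen=True)
class ShapeError:
    """Hypothetical error value carried by a shape; holds one or more messages."""
    messages: Tuple[str, ...]


Shape = Tuple[int, ...]
ShapeOrError = Union[Shape, ShapeError]


def from_extent_tensor(extents: List[int]) -> ShapeOrError:
    """Produces an error value (rather than raising) for invalid extents."""
    bad = [e for e in extents if e < 0]
    if bad:
        return ShapeError((f"negative extents {bad}",))
    return tuple(extents)


def concat(lhs: ShapeOrError, rhs: ShapeOrError) -> ShapeOrError:
    """Pure and foldable: errors flow through, and two errors are merged."""
    errors = [m for v in (lhs, rhs) if isinstance(v, ShapeError) for m in v.messages]
    if errors:
        return ShapeError(tuple(errors))
    return lhs + rhs


print(concat(from_extent_tensor([2]), from_extent_tensor([3])))  # (2, 3)
print(concat(from_extent_tensor([2, -1]), from_extent_tensor([-3])).messages)
# ('negative extents [-1]', 'negative extents [-3]')
```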

The goal of this dialect is to describe shape computations for multiple different consumers, whether for analysis, compilation, or lowering to a runtime. I don't see there being a single consumer/runtime here. Different consumers/runtimes would handle computations, failures, and error reporting differently, and keeping the dialect abstract means the lowerings to those targets can consider what they need (e.g., perhaps we are in "assume there are no errors" mode and all checks can be elided, or a target is massively distributed and you want error reporting to be totally ordered).

These are defined by runtime behavior; it is just that one of the "runtimes" is symbolic analysis (e.g., runtime behavior includes how to create a constraint programming problem). Being able to reuse these in multiple contexts is important; otherwise we end up defining the shapes of ops separately for each use case. For the runtime case with all values known, you just need a 1-D tensor of ints, and you could reuse normal ops to implement the computations (e.g., tf.slice, xla_hlo.mul, or linalg). The results then just get consumed by runtime-specific functions to allocate/deallocate buffers, so you don't need any special types or ops at that point. Once you are using a tensor of ints, you are free to use arbitrary ops and you've jumped to a new domain. But since bridging these two domains could require handling errors, I want the lowering for each use to handle that.

So specifying these as runtime behavior on shapes makes sense to me. It also makes sense to have this op to convert from !shape.shape to tensor<?xindex> in this dialect, but not for use in shape functions (otherwise you are just using tensor<?xindex> here as a cast to appease the type system).

I might be wrong, though. I was trying to sketch multiple cases in my head (and refining the comment every now and again :)). While I can see cases where such an op would be beneficial, and approaches to using it, I don't know how many of them are realistic. But one thing is true: if we have such an if instead, you still get the conversion. If we find the if is not needed, removing it amounts to replacing the if with a ToExtentTensor and folding the else. That could also be the lowering pattern you use towards the runtime, and it would give you exactly the same shape function you have today.

jpienaar accepted this revision. Mar 27 2020, 3:06 PM
This revision is now accepted and ready to land. Mar 27 2020, 3:06 PM
silvas marked an inline comment as done. Mar 27 2020, 4:24 PM
silvas added inline comments.

We discussed this offline.

There was agreement that an op like ToExtentTensor needs to somehow be guarded against the case that the !shape.shape is an error. For now, the op will abort if the shape is an error.
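The interim behavior agreed on here can be sketched in Python as follows. Converting a valid shape yields its extents, while converting an error value aborts, which is modeled here as raising `RuntimeError`. `ShapeError` and the message format are hypothetical. A real lowering would map the abort to whatever mechanism the target runtime provides.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union


@dataclass(frozen=True)
class ShapeError:
    """Hypothetical error value carried by a !shape.shape."""
    message: str


ShapeOrError = Union[Tuple[int, ...], ShapeError]


def to_extent_tensor(shape: ShapeOrError) -> List[int]:
    """Interim semantics: convert a valid shape, abort on an error value.

    "Abort" is modeled as raising RuntimeError.
    """
    if isinstance(shape, ShapeError):
        raise RuntimeError(f"to_extent_tensor: shape is an error: {shape.message}")
    return list(shape)


print(to_extent_tensor((4, 5)))  # [4, 5]
```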

We were also envisioning different use cases. I don't think ToExtentTensor would ever be used inside a shape transfer function, which is what Jacques was thinking about. But my use case, in a downstream dialect for lowering batch matmul, doesn't arise in a shape transfer function. Shape computations are needed to broadcast the input args to their common shape. That requires passing the extent tensor for each broadcasted shape to a broadcast_to(%t, %extent_tensor) op before passing the result to the lower-level batch matmul op.
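The following Python sketch computes the per-operand extent tensors for that batch matmul flow. Each operand's target is the common (broadcast) batch shape followed by that operand's own two matrix dims, and that list is what would be handed to a broadcast_to-style op. It assumes fully known shapes and NumPy-style broadcasting of the batch dims. `broadcast_shapes` and `batch_matmul_broadcast_targets` are hypothetical helpers. The contracting-dim check is left out to keep the focus on broadcasting.

```python
from itertools import zip_longest
from typing import List, Tuple

Shape = Tuple[int, ...]


def broadcast_shapes(a: Shape, b: Shape) -> Shape:
    """NumPy-style broadcast of two fully known shapes."""
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"cannot broadcast {a} with {b}")
        out.append(y if x == 1 else x)  # a 1 stretches to the other extent
    return tuple(reversed(out))


def batch_matmul_broadcast_targets(lhs: Shape, rhs: Shape) -> Tuple[List[int], List[int]]:
    """Extent tensors to pass to broadcast_to(%t, %extents) for each operand."""
    if len(lhs) < 2 or len(rhs) < 2:
        raise ValueError("batch matmul operands need rank >= 2")
    batch = broadcast_shapes(lhs[:-2], rhs[:-2])
    return list(batch + lhs[-2:]), list(batch + rhs[-2:])


print(batch_matmul_broadcast_targets((2, 1, 3, 4), (5, 4, 6)))
# ([2, 5, 3, 4], [2, 5, 4, 6])
```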

This revision was automatically updated to reflect the committed changes.