This is an archive of the discontinued LLVM Phabricator instance.

[DRAFT] Generalize expand_shape to take shape as explicit input
Needs ReviewPublic

Authored by ramiro050 on Jan 1 2023, 10:16 AM.

Details

Summary

*DO NOT SUBMIT*

(This patch is for early design feedback only. Notably, tests have not been
updated and the implementation is incomplete in some cases.)

This patch generalizes tensor.expand_shape and memref.expand_shape to consume
the output shape as a list of SSA values. This enables us to implement generic
reshape operations with dynamic shapes using collapse_shape/expand_shape pairs.

The output_shape input to expand_shape follows the static/dynamic representation
that's also used in tensor.extract_slice.
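For reference, tensor.extract_slice's mixed form interleaves static values (kept in attributes) with dynamic SSA operands. A sketch of that convention (operand names hypothetical):

```mlir
// Mixed static/dynamic operands: the offset %o and size %sz are SSA
// values, while the remaining offsets, sizes, and strides are static.
%s = tensor.extract_slice %t[0, %o] [%sz, 16] [1, 1]
    : tensor<?x32xf32> to tensor<?x16xf32>
```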

Diff Detail

Event Timeline

sanjoy created this revision. Jan 1 2023, 10:16 AM
sanjoy requested review of this revision. Jan 1 2023, 10:16 AM

Generally looks reasonable to me.

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
1544–1545

drop this paragraph

1586

In practice, such APIs are quite painful to use.

Instead we use ArrayRef<OpFoldResult> (see e.g. tensor.insert_slice).
We also have SmallVector<OpFoldResult> getMixedSizes-like APIs.

Last, StaticValueUtils.cpp|h has helpers such as

SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
                                         ValueRange dynamicValues, Builder &b);

to do the mechanical parts for you.

1677

I would add some more sugar, e.g.:

%b = memref.expand_shape %a reassociate[[0, 1], [2]] into [%sz0, %sz1, 32]
    : memref<?x32xf32> into memref<?x?x32xf32>
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
1009

Nit: should this be "must result in"?

"Obtain" does not work anymore with your new phrasing.

1037

I would add some more sugar, e.g.:

%b = tensor.expand_shape %a reassociate[[0, 1], [2]] into [%sz0, %sz1, 32]
        : tensor<?x32xf32> into tensor<?x?x32xf32>
1050

same comment re APIs

mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
44

Make a static member of ExpandShapeOp and attach it to Tablegen, otherwise we easily forget what is available where and people end up reimplementing the same functionality in different places.

48

OpFoldResult-based API should be enough

50

Delete this? I don't see why we need two versions.

197

This seems redundant for no good reason; could we just use std::is_same?

mravishankar added inline comments. Jan 6 2023, 12:19 PM
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
1578

Nit: Maybe just add a keyword output_shape (or something) after the reassociation list to better show what this is.

1586

+1.

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
1037

Same comment as mine above, this form looks good to me.

mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
340–341

I am guessing you wanted comments on whether to just fail on this path?

The additional uses you are mentioning here are, I am guessing, uses in tensor.dim ops. Those need to be handled separately through -resolve-shaped-type-result-dims. This pass is basically a pattern for DimOp, so if you have

%1 = tensor.collapse_shape %0 ...
%2 = tensor.dim %1, ...

it rewrites the tensor.dim %1 in terms of tensor.dim %0. This is done through the ReifyRankedShapedTypeOpInterface, which is implemented by the collapse/expand shape ops.
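As a sketch of that rewrite (types and constant names hypothetical), -resolve-shaped-type-result-dims turns a dim on the reshape result into a dim on its source:

```mlir
// Before: dim queried on the reshape result.
%1 = tensor.collapse_shape %0 [[0, 1]] : tensor<?x4xf32> into tensor<?xf32>
%2 = tensor.dim %1, %c0 : tensor<?xf32>

// After: dim expressed on the reshape source; the collapsed extent is
// the product of the grouped source extents.
%d0 = tensor.dim %0, %c0 : tensor<?x4xf32>
%2 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%d0]
```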

sanjoy marked 3 inline comments as done. Jan 14 2023, 10:02 PM

Thanks for the review! Replied to a few things inline as I continue to work on this.

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
1677

This bit is for collapse_shape -- I needed to change it because I moved assemblyFormat out of MemRef_ReassociativeReshapeOp. The behavior for collapse_shape should be unchanged; LMK if you see otherwise.

But I added the output_shape keyword to expand_shape as you suggested above.

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
1037

Added an output_shape keyword (like you suggested for memref.expand_shape) since we already use into when we print the types.
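With that keyword, the printed form might look something like the following sketch (the exact syntax depends on the final assemblyFormat):

```mlir
// Hypothetical assembly: the output shape [?, ?, 32] carries two
// dynamic extents as SSA values and one static extent inline.
%b = tensor.expand_shape %a [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
    : tensor<?x32xf32> into tensor<?x?x32xf32>
```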

mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
197

Unless I'm missing some C++ trick, that requires adding a dependency from this header on either the memref or the tensor dialect. I assumed that was not OK, but please let me know if it is.

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
1037

Would WFM, but I don't see output_shape spelled out anywhere at the moment.

mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
197

ugh you're right, I always forget about the unfortunate duplication of ops in 2 dialects

ramiro050 commandeered this revision. Mar 9 2023, 11:47 AM
ramiro050 added a reviewer: sanjoy.
ramiro050 added a subscriber: ramiro050.

After talking with @sanjoy, we decided that I will finish the implementation of this patch.

ramiro050 updated this revision to Diff 503870. Mar 9 2023, 11:55 AM

Updating D140821: [DRAFT] Generalize expand_shape to take shape as explicit input

Rebase, address comments, fix crashes when running integration tests.
All that remains is updating lit tests to use new assembly for
expand_shape.

ramiro050 marked 7 inline comments as done. Mar 9 2023, 11:58 AM
ramiro050 added inline comments. Mar 9 2023, 12:09 PM
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
340–341

I've updated this part to handle the creation of the expand op. Should I be adding the handling for the collapse_shape + tensor.dim pattern, or is this something that should already happen during -resolve-shaped-type-result-dims?