
[mlir][Standard] Canonicalize chains of tensor_cast operations

Authored by herhut on Sep 15 2020, 7:47 AM.



Adds a pattern that replaces a chain of two tensor_cast operations by a single
tensor_cast operation if doing so will not remove constraints on the shapes.
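A minimal, MLIR-free sketch of this rule, assuming shapes are modelled by a hypothetical `Shape` struct (unranked, or ranked with -1 marking a dynamic `?` dimension) and a hypothetical `joinShapes` that returns the most precise shape compatible with both inputs. The intermediate cast is dropped only when joining all three shapes gives the same result as joining the source and result alone, i.e. when the intermediate type contributes no constraint. This illustrates the idea; it is not the code that landed.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Assumed marker for a dynamic ("?") dimension in this sketch.
constexpr int64_t kDynamic = -1;

// Hypothetical stand-in for a tensor type's shape.
struct Shape {
  bool ranked = false;
  std::vector<int64_t> dims;  // only meaningful when ranked
  bool operator==(const Shape &other) const {
    return ranked == other.ranked && dims == other.dims;
  }
};

// Most precise shape compatible with both inputs, or std::nullopt if they
// conflict (different ranks, or two different static sizes).
std::optional<Shape> joinShapes(const Shape &one, const Shape &two) {
  if (!one.ranked) return two;
  if (!two.ranked) return one;
  if (one.dims.size() != two.dims.size()) return std::nullopt;
  Shape join{true, {}};
  join.dims.reserve(one.dims.size());
  for (std::size_t i = 0; i < one.dims.size(); ++i) {
    int64_t a = one.dims[i], b = two.dims[i];
    if (a == kDynamic) {
      join.dims.push_back(b);
    } else if (b == kDynamic || a == b) {
      join.dims.push_back(a);
    } else {
      return std::nullopt;  // static sizes disagree
    }
  }
  return join;
}

// source -> intermediate -> result may become source -> result only if the
// intermediate shape adds no information to the chain.
bool canDropIntermediateCast(const Shape &source, const Shape &intermediate,
                             const Shape &result) {
  std::optional<Shape> partial = joinShapes(source, intermediate);
  if (!partial) return false;  // chain cannot succeed; leave it alone
  std::optional<Shape> all = joinShapes(*partial, result);
  if (!all) return false;
  std::optional<Shape> direct = joinShapes(source, result);
  return direct && *direct == *all;
}
```

With this check, tensor<?x?xf32> -> tensor<1x2xf32> -> tensor<?x?xf32> keeps its middle cast (it asserts the sizes), whereas tensor<1x2xf32> -> tensor<?x?xf32> -> tensor<*xf32> can be folded into a single cast.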


Event Timeline

herhut created this revision. Sep 15 2020, 7:47 AM
Herald added a project: Restricted Project. Sep 15 2020, 7:47 AM
herhut requested review of this revision. Sep 15 2020, 7:47 AM

I considered placing a templated version of joinShapes somewhere (to be used with ShapedType, with a template parameter to determine the type to create), but wasn't sure where it should live or whether it would be useful.

jpienaar added inline comments. Sep 15 2020, 8:08 AM

This is a little restrictive, especially if one considers element types with nested types. Something operating on shapes instead could be reused, with no need for templating or special restrictions on the element type. E.g., a join of two shapes whose result is a vector, or that populates a vector (where the first element is the rank, to allow returning unranked), or one using ShapedTypeComponents. The latter does carry an element type and an extra attribute: the former can be left empty unless equal in the join, and the latter we haven't used yet, so it could be removed too.
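One way to read the shape-only suggestion, as a sketch: encode a shape as a `std::vector<int64_t>` whose first element is the rank (a hypothetical `kUnranked` value of -1 stands for an unranked shape), followed by the dimension sizes, and let the join populate an output vector. Because no element type is involved, the same helper could serve tensors and memrefs. The encoding constants and the name `joinEncodedShapes` are assumptions for illustration only.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical encoding: element 0 is the rank (kUnranked for an unranked
// shape), followed by one entry per dimension; kDynamic marks a "?" size.
constexpr int64_t kUnranked = -1;
constexpr int64_t kDynamic = -1;

// Shape-only join: no element type is involved. Populates `result` and
// returns false if the two shapes conflict.
bool joinEncodedShapes(const std::vector<int64_t> &one,
                       const std::vector<int64_t> &two,
                       std::vector<int64_t> &result) {
  assert(!one.empty() && !two.empty() && "encoding must start with the rank");
  if (one[0] == kUnranked) { result = two; return true; }
  if (two[0] == kUnranked) { result = one; return true; }
  if (one[0] != two[0]) return false;  // rank mismatch
  assert(one.size() == static_cast<std::size_t>(one[0]) + 1 &&
         one.size() == two.size());
  result.clear();
  result.reserve(one.size());  // exactly one push_back per entry below
  result.push_back(one[0]);
  for (std::size_t i = 1; i < one.size(); ++i) {
    if (one[i] == kDynamic)
      result.push_back(two[i]);
    else if (two[i] == kDynamic || one[i] == two[i])
      result.push_back(one[i]);
    else
      return false;  // static sizes disagree
  }
  return true;
}
```

For example, joining {2, 1, -1} (a 1x? shape) with {2, -1, 2} (a ?x2 shape) yields {2, 1, 2}.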


Perhaps expand the comment to say that it won't be used if it contains less information, and remove the empty line below.

ftynse accepted this revision. Sep 15 2020, 8:11 AM
ftynse added a subscriber: ftynse.
ftynse added inline comments.

Drop llvm::


assert that ranks are equal


Reserve space before pushing back in the loop; you are guaranteed to push back exactly once on each iteration.

This revision is now accepted and ready to land. Sep 15 2020, 8:11 AM
frgossen added inline comments. Sep 15 2020, 8:35 AM

I think you could get away without materialising the joined shapes here.

If the cast chain is A -> B -> C, then A <= B or C <= B is a sufficient condition for this rewrite (smaller meaning more concrete shapes).
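The whole-shape condition can be sketched with a hypothetical `Shape` struct (-1 marks a dynamic dimension), where `atLeastAsConcrete(x, y)` plays the role of x <= y. The sketch assumes the direct cast A -> C is itself valid and does not check that A and C are compatible; a join-based check can cover that case by refusing to fold when no join exists.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = -1;  // assumed marker for a "?" dimension

// Hypothetical shape model: unranked, or ranked with per-dimension sizes.
struct Shape {
  bool ranked = false;
  std::vector<int64_t> dims;
};

// x <= y: every constraint y states (its rank and each static size) is also
// stated by x, so x is at least as concrete as y.
bool atLeastAsConcrete(const Shape &x, const Shape &y) {
  if (!y.ranked) return true;
  if (!x.ranked || x.dims.size() != y.dims.size()) return false;
  for (std::size_t i = 0; i < y.dims.size(); ++i)
    if (y.dims[i] != kDynamic && x.dims[i] != y.dims[i]) return false;
  return true;
}

// Sufficient condition for folding A -> B -> C into A -> C: B adds nothing
// if A already implies it (A <= B) or if C re-checks all of it (C <= B).
// Assumes A -> C is a valid cast; that is not verified here.
bool intermediateIsRedundant(const Shape &a, const Shape &b, const Shape &c) {
  return atLeastAsConcrete(a, b) || atLeastAsConcrete(c, b);
}
```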


Would resize(one.getRank()) be useful?


A comment explaining why this is needed could help.
Cannot eliminate the intermediate cast if it is stricter than the resulting cast.

frgossen accepted this revision. Sep 15 2020, 8:35 AM
herhut updated this revision to Diff 292128. Sep 16 2020, 1:02 AM
herhut marked 6 inline comments as done.


I will make the change to use ShapedTypeComponents in a follow-up to enable reuse for the memref case.


That is true if computed for each element of the shape; it is not required to hold for the shapes as a whole, e.g. A = [1,?], B = [?,2] and C = [1,2]. So implementing lessPrecise(ShapeType, ShapeType) and composing the condition above from it does not work. Instead, I would need a specialized castingAtoCisEquivalentToCastingAtoBtoC(), which I shied away from because join seemed more reusable.
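The difference between the per-dimension and whole-shape readings can be checked with a small sketch over ranked shapes of equal rank (a hypothetical `Dims` alias, with -1 marking a dynamic dimension). The shapes in the note below are our own illustration, not the ones from the comment: A = [1,?], B = [1,2], C = [?,2]. Every static size in B is fixed either by A (dimension 0) or by C (dimension 1), so B is redundant, yet neither A <= B nor C <= B holds for the shapes as a whole.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = -1;    // assumed marker for a "?" dimension
using Dims = std::vector<int64_t>;  // ranked shapes only, for brevity

// Whole-shape order: x <= y if x matches every static size of y.
bool atLeastAsConcrete(const Dims &x, const Dims &y) {
  assert(x.size() == y.size());
  for (std::size_t i = 0; i < y.size(); ++i)
    if (y[i] != kDynamic && x[i] != y[i]) return false;
  return true;
}

// "A <= B or C <= B" applied to the shapes as a whole.
bool redundantWholeShape(const Dims &a, const Dims &b, const Dims &c) {
  return atLeastAsConcrete(a, b) || atLeastAsConcrete(c, b);
}

// The same condition applied per dimension: each static size in B must be
// fixed by A or re-checked by C in that dimension. Assumes the chain is
// consistent, i.e. A, B and C agree wherever two of them are static.
bool redundantPerDimension(const Dims &a, const Dims &b, const Dims &c) {
  assert(a.size() == b.size() && b.size() == c.size());
  for (std::size_t i = 0; i < b.size(); ++i)
    if (b[i] != kDynamic && a[i] != b[i] && c[i] != b[i]) return false;
  return true;
}
```

With these shapes, `redundantPerDimension` returns true while `redundantWholeShape` returns false. A and C are also incomparable under `atLeastAsConcrete`, since the order on shapes is only partial.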

So how much do we value performance over code reuse?


Ah, thanks! How about moving ShapedTypeComponents to TypeUtilities.h, or even cooking up a ShapeUtilities.h? If we use it more broadly than just in the InferTypeOpInterface, it should live in a shared location.

frgossen added inline comments. Sep 16 2020, 2:32 AM

Right, the order on shapes is not complete.

The join version is a lot easier to read :-)

herhut closed this revision. Tue, Sep 29, 1:12 AM

This landed a while back in 5e0ded268929b87ddf2c5e077c9185554342f602.