This is an archive of the discontinued LLVM Phabricator instance.

Allow only valid vector.shape_cast transitive folding
Closed, Public

Authored by asaadaldien on Oct 8 2021, 4:45 PM.

Details

Summary

When folding A -> B -> C into A -> C, only accept the fold if A -> C is itself a valid shape cast.
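To make the rule concrete, here is a minimal standalone sketch in C++ that models static vector shapes as `std::vector<std::int64_t>` instead of MLIR types. `Shape`, `isValidRankChange` and `canFoldShapeCastChain` are hypothetical names, not the patch's code or MLIR API. Two rules are assumptions drawn from this review: a rank-changing `vector.shape_cast` is treated as valid when every dimension of the lower-rank shape is the product of a contiguous run of dimensions of the higher-rank shape, and a same-rank cast is treated as valid only when both shapes are identical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a static vector shape, e.g. {4, 2, 2} for
// vector<4x2x2xf32>. Positive static sizes are assumed.
using Shape = std::vector<std::int64_t>;

// Assumed rank-changing rule: every dimension of the lower-rank shape `lo`
// must equal the product of a contiguous run of dimensions of `hi`.
// Like the helper discussed in the review, it requires rank(lo) < rank(hi).
bool isValidRankChange(const Shape &lo, const Shape &hi) {
  assert(lo.size() < hi.size() && "expects rank(lo) < rank(hi)");
  std::size_t j = 0;
  for (std::int64_t dim : lo) {
    std::int64_t product = 1;
    // Consume dimensions of `hi` until the running product reaches `dim`.
    while (j < hi.size() && product < dim)
      product *= hi[j++];
    if (product != dim)
      return false;
  }
  // Trailing unit dimensions of `hi` do not change the element order.
  while (j < hi.size() && hi[j] == 1)
    ++j;
  return j == hi.size();
}

// Decide whether shape_cast(shape_cast(A -> B) -> C) may become a single
// shape_cast(A -> C). B drops out: only the validity of A -> C matters.
bool canFoldShapeCastChain(const Shape &a, const Shape &c) {
  if (a == c)
    return true; // The two casts cancel and the result is A itself.
  if (a.size() < c.size())
    return isValidRankChange(a, c); // A -> C adds dimensions.
  if (a.size() > c.size())
    return isValidRankChange(c, a); // A -> C merges dimensions.
  return false; // Same rank, different shape: not a valid shape_cast.
}
```

Ordering the arguments by rank before calling the directional helper is what keeps its assertion from firing. For example, with A = 4x2x2 the sketch accepts C = 8x2 (4x2 merged into 8) but rejects C = 2x4x2, because the ranks match while the shapes differ.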

Event Timeline

asaadaldien created this revision. Oct 8 2021, 4:45 PM
asaadaldien requested review of this revision. Oct 8 2021, 4:45 PM
Herald added a project: Restricted Project. Oct 8 2021, 4:45 PM
asaadaldien retitled this revision from Allow only valid transitive folding to Allow only valid vector.shape_cast transitive folding. Oct 9 2021, 4:02 PM
ThomasRaoux requested changes to this revision. Oct 11 2021, 12:31 PM
ThomasRaoux added inline comments.
mlir/lib/Dialect/Vector/VectorOps.cpp
3645

What if srcType.getRank() == resultType.getRank()? This case would assert right now.

This revision now requires changes to proceed. Oct 11 2021, 12:31 PM
asaadaldien added inline comments. Oct 11 2021, 12:44 PM
mlir/lib/Dialect/Vector/VectorOps.cpp
3645

If it's the same rank, then to be valid it would have to be the same type, so we won't be executing this isValidShapeCast.

ThomasRaoux added inline comments. Oct 11 2021, 12:57 PM
mlir/lib/Dialect/Vector/VectorOps.cpp
3645

But isValidShapeCast has assert(rankA < rankB);
https://github.com/llvm/llvm-project/blob/e356027016c6365b3d8924f54c33e2c63d931492/mlir/lib/Dialect/Vector/VectorOps.cpp#L3465

So we would just assert in this case?

Assert src rank != result rank

ThomasRaoux added inline comments. Oct 11 2021, 2:14 PM
mlir/lib/Dialect/Vector/VectorOps.cpp
3642

How do we know that? Can't we have something like:

%0 = vector.shape_cast %arg0 : vector<4x2x2xf32> to vector<8x2xf32>
%1 = vector.shape_cast %0 : vector<8x2xf32> to vector<2x4x2xf32>

don't fold A -> B -> C when rank(A) == rank(C) && type(A) != type(C)
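Working through the example shows why element counts alone do not settle the question: A = 4x2x2, B = 8x2 and C = 2x4x2 all hold 4·2·2 = 8·2 = 2·4·2 = 16 elements, and each individual cast only regroups contiguous dimensions, yet the folded A -> C would be a same-rank cast between different shapes. The following standalone sketch models the guard from the commit message above; `Shape`, `numElements` and `rejectedBySameRankGuard` are hypothetical names over plain `std::vector<std::int64_t>` shapes, not MLIR API.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical stand-in for a static vector shape, e.g. {8, 2} for
// vector<8x2xf32>.
using Shape = std::vector<std::int64_t>;

// Total number of elements: the product of all dimensions.
std::int64_t numElements(const Shape &s) {
  return std::accumulate(s.begin(), s.end(), std::int64_t{1},
                         std::multiplies<std::int64_t>());
}

// The guard: refuse to fold A -> B -> C when A and C have the same rank but
// different shapes, since such an A -> C would not be a valid shape_cast.
bool rejectedBySameRankGuard(const Shape &a, const Shape &c) {
  return a.size() == c.size() && a != c;
}
```

Because all three shapes agree on `numElements`, a check based only on element count would wrongly accept the fold; comparing rank and shape of A and C is what rejects it.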

asaadaldien added inline comments. Oct 11 2021, 2:25 PM
mlir/lib/Dialect/Vector/VectorOps.cpp
3642

Right, it will always be invalid.

ThomasRaoux accepted this revision. Oct 11 2021, 2:28 PM
ThomasRaoux added inline comments.
mlir/lib/Dialect/Vector/VectorOps.cpp
3646
3649–3651

Can you add a test for this case?

This revision is now accepted and ready to land. Oct 11 2021, 2:28 PM
mlir/lib/Dialect/Vector/VectorOps.cpp
3642

Can you put this if/else logic under isValidShapeCast(typea, typeb)?

nicolasvasilache accepted this revision. Oct 12 2021, 5:52 AM
This revision was automatically updated to reflect the committed changes.